Devlog

[백준 1287] 할 수 있다 (Python) 본문

Problem Solving/코딩문제풀기

[백준 1287] 할 수 있다 (Python)

recoma 2022. 9. 7. 02:07
728x90

저격당한 풀이 입니다.

 

1287번: 할 수 있다

곱하기가 연산자 우선순위가 빠르므로 5+(1+2)*3 = 5+3*3 = 5+9 = 14가 된다. 연산자의 우선순위는 다음과 같다. (), */, +- 여기서 *와 /가 연산자 우선순위가 같고, +와 -가 연산자 우선순위가 같다. ()가

www.acmicpc.net

하지만 eval로는 할 수 없다 ㅠㅠ

문제

괄호가 포함되어 있는 사칙연산 계산기를 만드는 교육적인 문제 입니다. 하지만 이 부분을 신경써야 합니다.

길이는 1000자를 넘지 않는다

반대로 생각해보면 숫자가 최대 998자리 까지 될 수 있습니다. Big Integer를 사용해야 합니다. 그러니 Python의 Decimal를 사용하거나, Java의  BigInteger를 사용하는 것을 추천합니다. 여기서는 Python의 Decimal을 사용합니다.

eval()로 날먹?

Python에는 eval()이라는 함수가 있습니다. 이 함수는 문자열로 되어 있는 코드를 그래도 실행해주는 함수 입니다. 예를 들어

>>> string = '1+4/2'
>>> print(eval(string))
3.0

문자열 식을 만들어 놓고 eval()함수에 그 문자열을 집어넣으면 그 식에 대한 정답을 계산해 줍니다. 그렇기 때문에 eval() 함수가 계산기를 대신 할 수도 있고 이걸 정규표현식과 같이 활용해서 문제를 풀을 수 있었습니다.(물론 전 무식하게도 못했습니다.) 실제로도 eval()를 활용한 코드들이 정답처리 되고 있습니다. 그런데 최근에 eval()에 대한 반례가 게시판에 등장하게 됩니다.

'(' * 499 + '1' + ')' * 499

이 식은 1을 가운데 두고 좌/우 괄호가 499개 씩 있는 식입니다. 즉 (((((1))))) 에서 ()가 499개 있다는 의미가 됩니다. eval은 이 식을 정상적으로 처리하지 못합니다. 

>>> s = '(' * 499 + '1' + ')' * 499
>>> eval(s)
SyntaxError: too many nested parentheses

중첩된 괄호가 너무 많으면 eval이 더 이상 계산을 하지 않고 던져 버립니다. 여러번 테스트를 해 봤는데 중첩된 괄호가 한 200개 쯤 부터 이 에러를 뱉는거 같군요. 그러니 eval,정규식 사용 보다는 그냥 빡구현으로 문제를 해결합시다. 당장 eval로 풀었다고 쳐도 저 저격 데이터로 재채점을 하게 되면, 오답 쳐리가 되서 다시 풀어야 합니다. 아마 저게 재채점 되는 날에는 한자릿수 정답률을 구경할 수 있지 않을 까 생각됩니다. 파이썬에선 거의 다 정규식하고 eval로 풀었던데.

그냥 정공법을 쓰자

문자열 + 파싱 문제 답게 반례가 끔찍하게 많습니다. 그렇기에 이 문제는 계산보다 예외처리가 훨씬 어렵습니다.

전처리

Python Decimal를 사용해야 합니다. 그런데 아무리 Decimal도 따로 커스텀을 해주지 않으면 큰 수를 계산할 수 없습니다. 그렇기 때문에 아래와 같이 조정해 줍니다.

from decimal import Decimal, getcontext
getcontext().prec = 2000

최대 2000자 까지 계산할 수 있게 되었습니다.

그리고 계산하기에 앞서 문자열에 있는 식들도 파싱해야 하는데, 피연산자(숫자)가 두자릿수 이상인 점을 고려해야 합니다. 이렇게 해서 숫자 문자열을 찾으면 Decimal형태로 저장합니다. 이렇게 해서 파싱된 데이터들은 배열에다 저장합니다.

s = input()[:-1]
c = {'+', '-', '/', '*', '(', ')'}
arr = []

# parsing
p = -1
for i, char in enumerate(s):
    if char in c:
        if p == -1:
            arr.append(char)
        else:
            arr.append(Decimal(s[p:i]))
            arr.append(char)
            p = -1
    else:
        if p == -1: p = i
if p != -1: arr.append(Decimal(s[p:]))

계산하기

기본적으로 스택을 활용하며 숫자와 사칙연산들이 들어올 때마다 스택에 저장합니다. 사칙연산이 들어올 경우, 우선순위가 낮은 (-,+)는 스택에 집어넣고 (*,/)의 경우 스택에 일단 집어넣다가, 바로 다음에 숫자가 들어오면, 스택에서 연선과 같이 계산할 숫자 총 두개를 꺼내 연산을 수행한 다음, 연산이 끝난 수를 다시 스택에 집어넣습니다. 이렇게 반복해서 끝에 다다르면, 스택에는 +,-만 있는 식이 남게 되는 데, 이때 나머지들을 모조리 계산해 주기만 하면 됩니다.

괄호가 없으면 이정도 알고리즘만해도 충분히 풀 수 있지만, 여기서는 괄호가 추가됩니다.

괄호가 들어가 있는 계산식을 아래와 같이 트리로 표현할 수 있습니다. 트리의 레벨이 높을 수록 계산의 우선순위가 높으며 문자열 식을 토대로 위와 같은 트리를 만든 다음, 우선순위가 높은순 대로 계산만하면 됩니다.

하지만 이는 계산식이 잘못되지 않았을 경우에만 트리로 풀 수 있으며, 이번 문제는 계산식이 잘못된 경우도 있어, 당장 트리로 만들어서 풀기에는 여러 예외적인 부분들을 처리해야 하기 때문에 구현이 복잡하게 됩니다. 하지만 위의 그림과 비슷한 방식으로 풀 수 있습니다. 재귀를 사용합니다

 

def f(a, s, i, flag=True):
    """
    a -> 사칙연산 식이 들어있는 배열
    s -> 스택
    i -> 현재 배열 a에 대한 인덱스 위치
    flag -> 괄호로 인해 재귀되었으면 True, 아니면 False
    """

")"가 보이거나, 혹은 식의 맨 마지막에 도달할 때 까지 배열 a를 선형탐색합니다. 그중에 "("를 만나면 그자리에서 동일한 함수를 재귀호출합니다. "("를 만날 때마다 계속 호출 하다보면, 중간에 괄호가 더이 상 없는 식이 보이게 됩니다. 계속 진행해서 ")"가 보이면 (+,-)만 남은 식들을 모조리 정리하고 이에 나온 결과값을 스택에 저장해서 재귀를 마칩니다.

즉, "("가 보이면 재귀함수를 호출하고, ")"가 보이면 스택의 요소 갯수를 "("가 시작되었을 때의 요소 갯수가 될 때 까지 연산을 수행한 다음, 이에 나온 결과값을 스택에 저장하고 함수를 끝냅니다.

예외 처리

인접한 부분

박스 안에 있는 식이 포함되어 있을 경우 올바른 식이 아닙니다.

사칙연산 (-,+,*,/)

(+, ++
맨 처음에 사칙연산으로 시작하는 경우: (ex: +3, +4-6, -3, *)

피연산자 (숫자)

)3, )45

괄호 "("

)(, 3(,

괄호 ")"

+), (), ) (식이 ")"로 시작하는 경우)

전체적인 부분

모든 괄호가 항상 매칭되어야 합니다. 어느 한쪽이라도 매칭이 안되어 있으면 올바른 식이 아닙니다.

(3+4, )3+5, 4-3), 4-5(

 

코드 보기(Python)

더보기
import sys
from decimal import Decimal, getcontext
input = sys.stdin.readline
getcontext().prec = 2000

def _exit():
    print('ROCK')
    exit(0)

s = input()[:-1]
c = {'+', '-', '/', '*', '(', ')'}
cal = {
    '+': lambda a, b: a + b,
    '-': lambda a, b: a - b,
    '*': lambda a, b: a * b,
    '/': lambda a, b: a // b,
}
arr = []

# parsing
p = -1
for i, char in enumerate(s):
    if char in c:
        if p == -1:
            arr.append(char)
        else:
            arr.append(Decimal(s[p:i]))
            arr.append(char)
            p = -1
    else:
        if p == -1: p = i
if p != -1: arr.append(Decimal(s[p:]))

def f(a, s, i, flag=True):
    p = len(s) if not flag else len(s) - 1
    while i < len(a):
        e = a[i]
        if e in cal.keys():
            # 사칙연산
            if not s or not isinstance(s[-1], Decimal): _exit()
            s.append(e)
        elif e == '(':
            if s and isinstance(s[-1], Decimal): _exit()
            s.append(e)
            i = f(a, s, i+1)
        elif e == ')':
            if not s or not isinstance(s[-1], Decimal) or not flag: _exit()
            break
        else:
            if s and (a[i-1] == ')' or isinstance(s[-1], Decimal)): _exit()
            if s and s[-1] in ('*', '/'):
                alc, x1, x2 = s.pop(), s.pop(), e
                y = cal[alc](x1, x2)
                s.append(y)
            else:
                s.append(e)
        i += 1
    
    if i == len(a) and flag: _exit()

    y = s.pop()
    if not isinstance(y, Decimal): _exit()
    while len(s) > p:
        alc = s.pop()
        if alc == '(': break
        if alc not in cal.keys(): _exit()
        x = s.pop()
        if not isinstance(x, Decimal): _exit()
        y = cal[alc](x, y)
    s.append(y)
    return i

s = []
f(arr, s, 0, flag=False)
print(s[0])

게시판 반례 아니었으면 영원히 못풀었다 진짜

728x90
반응형