(Python) 백준 2096

Lee Yechan·2024년 1월 15일
0

알고리즘 문제 풀이

목록 보기
40/60
post-thumbnail

백준 2096

시간 제한메모리 제한제출정답맞힌 사람정답 비율
1 초4 MB (하단참조)39859152701195436.745%

문제

N줄에 0 이상 9 이하의 숫자가 세 개씩 적혀 있다. 내려가기 게임을 하고 있는데, 이 게임은 첫 줄에서 시작해서 마지막 줄에서 끝나게 되는 놀이이다.

먼저 처음에 적혀 있는 세 개의 숫자 중에서 하나를 골라서 시작하게 된다. 그리고 다음 줄로 내려가는데, 다음 줄로 내려갈 때에는 다음과 같은 제약 조건이 있다. 바로 아래의 수로 넘어가거나, 아니면 바로 아래의 수와 붙어 있는 수로만 이동할 수 있다는 것이다. 이 제약 조건을 그림으로 나타내어 보면 다음과 같다.

explanation on constraint of the game

별표는 현재 위치이고, 그 아랫 줄의 파란 동그라미는 원룡이가 다음 줄로 내려갈 수 있는 위치이며, 빨간 가위표는 원룡이가 내려갈 수 없는 위치가 된다. 숫자표가 주어져 있을 때, 얻을 수 있는 최대 점수, 최소 점수를 구하는 프로그램을 작성하시오. 점수는 원룡이가 위치한 곳의 수의 합이다.

입력

첫째 줄에 N(1 ≤ N ≤ 100,000)이 주어진다. 다음 N개의 줄에는 숫자가 세 개씩 주어진다. 숫자는 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 중의 하나가 된다.

출력

첫째 줄에 얻을 수 있는 최대 점수와 최소 점수를 띄어서 출력한다.


답안

import sys

height = int(sys.stdin.readline())
table = list(map(int, sys.stdin.readline().split()))
current_max, current_min = [0, 0, 0], [0, 0, 0]
previous_max, previous_min = [0, 0, 0], [0, 0, 0]
for i in range(3):
    previous_max[i], previous_min[i] = table[i], table[i]
    current_max[i], current_min[i] = table[i], table[i]
temp_max, temp_min = [], []

for _ in range(1, height):
    table = list(map(int, sys.stdin.readline().split()))
    for i in range(3):
        if i == 0:
            temp_max, temp_min = previous_max[:2], previous_min[:2]
        elif i == 1:
            temp_max, temp_min = previous_max[:], previous_min[:]
        else:
            temp_max, temp_min = previous_max[1:], previous_min[1:]
        current_max[i] = max(temp_max) + table[i]
        current_min[i] = min(temp_min) + table[i]
    previous_max, previous_min = current_max[:], current_min[:]

print(max(current_max), min(current_min))

풀이

맨 처음에는 구현의 편의를 위해, table에 모든 입력을 받아둔 후 코드를 실행하도록 로직을 짰다.

int 값이 하나 당 4byte를 차지한다고 했을 때, 10만개 행 3개 열 4byte == 120만 바이트 == 1.2MB 정도는 입력을 받을 수 있을 것이라고 생각했기 때문이다.

아래는 처음 짠 코드이다. 테이블이 2.4MB정도 메모리를 사용하게 되는데, 이렇게 했더니 메모리 초과로 오답 판정을 받았다.

import sys

height = int(sys.stdin.readline())
table = []
for _ in range(height):
    table.append(list(map(int, sys.stdin.readline().split())))
for i in range(3):
    table[0][i] = [table[0][i], table[0][i]]

for i in range(1, height):
    for j in range(3):
        previous_max, previous_min = [], []
        if j == 0:
            previous = table[i-1][:2]
        elif j == 1:
            previous = table[i-1][:]
        else:
            previous = table[i-1][1:]
        table[i][j] = [max(map(lambda x: x[0], previous)) + table[i][j],
                       min(map(lambda x: x[1], previous)) + table[i][j]]

print(max(map(lambda x: x[0], table[-1])), min(map(lambda x: x[1], table[-1])))

백준 기준, 입력을 받고 로직을 실행하기 전부터 메모리 초과가 된 것인데, python이 돌아가기 위한 기본 메모리 사용량이 생각했던 것보다 많았던 것이다.

아래는 두 번째로 짠 코드인데, 테이블만 저장한 뒤(1.2MB) 다른 것들은 그때그때 행을 보고 연산하는 식으로 바꿨지만, 이 역시도 메모리 초과 판정을 받았다.

import sys

height = int(sys.stdin.readline())
table = []
for _ in range(height):
    table.append(list(map(int, sys.stdin.readline().split())))
current_max, current_min = [0, 0, 0], [0, 0, 0]
previous_max, previous_min = [0, 0, 0], [0, 0, 0]
for i in range(3):
    previous_max[i], previous_min[i] = table[0][i], table[0][i]
temp_max, temp_min = [], []

for i in range(1, height):
    for j in range(3):
        if j == 0:
            temp_max, temp_min = previous_max[:2], previous_min[:2]
        elif j == 1:
            temp_max, temp_min = previous_max[:], previous_min[:]
        else:
            temp_max, temp_min = previous_max[1:], previous_min[1:]
        current_max[j] = max(temp_max) + table[i][j]
        current_min[j] = min(temp_min) + table[i][j]
    previous_max, previous_min = current_max[:], current_min[:]

print(max(current_max), min(current_min))

테이블 값의 하나의 행에 대하여 연산을 완료했다면, 다시는 그 값이 필요하지 않으므로 이것을 이용해서, 값이 필요하지 않은 경우 메모리에 들고 있지 않도록 코드를 다시 짰고, 이번에는 메모리 초과 없이 ‘정답입니다’ 판정을 받게 되었다.

우선 이 문제는 DP로 쉽게 풀 수 있다.

대각선 왼쪽 위, 바로 위, 대각선 오른쪽 위쪽으로부터 누적된 합의 최댓값과 최솟값을 메모이제이션을 통해 저장해 이 과정을 반복하면 된다.

그런데 메모리를 아끼는 것이 관건이므로, 10만개의 행을 메모리에 모두 들고 있지 않고, 현재 행의 실제 값과 메모이제이션된 값만 들고 있으면 되는 것이다.

하나씩 설명을 하면,

table: 테이블의 실제 값 중, 현재 보고 있는 row 값 ([a, b, c])
previous_max, previous_min: 테이블의 누적된 합 중 최대/최소를 저장한, 이전 row 값
temp_max, temp_min:(0, 1, 2) 값 중 하나를 갖는 j에 대해, 이전 row에서 현재 테이블 값에 더해질 수 있는 값들의 후보
current_max, current_min: 테이블의 누적된 합 중 최대/최소를 저장한, 현재 보고 있는 row 값

맨 처음으로, table에 실제 테이블의 row 값을 덮어쓴다.

temp_max, temp_min에 이전 테이블의 누적합을 저장한 메모이제이션 배열로부터, 현재 테이블 값에 더해질 수 있는 후보를 저장한다.

자신과 거리가 2 이상 떨어진 칸과는 더해질 수 없으므로, 만약 첫번째 열이라면 이전 행의 1, 2열과, 만약 세번째 열이라면 이전 행의 2, 3열이 후보가 된다.

current_max[j] = max(temp_max) + table[i][j]
current_min[j] = min(temp_min) + table[i][j]

와 같이, 그 후보 중 최대/최소 값을 실제 테이블의 값과 더해 메모이제이션된 값을 갱신한다.

previous_max, previous_min = current_max[:], current_min[:]

이때, 리스트를 deep copy하는 것을 잊지 않도록 한다. reference만 넘겨주면 로직이 수행될 수 없다.

테이블의 행의 개수 n에 대해 이것을 n-1회 반복하면 메모이제이션된 리스트에 지금까지 모든 행에 대한 누적합의 최대/최솟값이 저장되게 되므로, 이중 최대/최소인 값을 출력하면 된다.

print(max(current_max), min(current_min))

10만개의 행을 들고 있지 않더라도 문제를 풀 수 있도록 만드는 것이 이 문제의 핵심이었던 것 같다.

이전에 계산한 결과값이 필요 없다는 사실을 잘 캐치해서 필자와 같이 몇번이나 메모리 초과 판정을 받지 않도록 주의하면 좋을 것 같다.

profile
이예찬

0개의 댓글