[python] 구간합(prefix sum)_백준 문제풀이

이희진·2023년 3월 13일
0

baekjoon

목록 보기
1/1

구간 합 구하기

연속적으로 나열된 n개의 수가 있을 때, 특정 구간의 모든 수를 합한 값을 구하는 문제이다.
M개의 쿼리, N 크기의 구간이 주어질 때, 매번 구간 합을 계산한다면 해당 알고리즘은 O(NM)의 시간 복잡도를 가진다.
이러한 경우에는 M과 N의 범위가 커지면 시간 초과가 발생하여 문제를 해결하기 어려워진다.

접두사 합(Prefix Sum)

그렇다면 여러 번 사용될 만한 정보는 미리 구해 저장해두는건 어떤가? 확인해보면 쿼리는 M개지만 N개의 수는 한 번 주어지고 변경되지 않는다. 그러므로 접두사 합을 사용한다. N개의 수의 위치 각각에 대하여 접두사 합을 미리 구해두면 된다. 해당 알고리즘은 다음과 같다.

  1. N개의 수에 대하여 접두사 합(Prefix Sum)을 계산하여 배열 P에 저장한다.
  2. 이때, index를 헷갈리게 하지 않기 위해 새로운 배열에 0을 추가해둔다.
  3. 매 M개의 쿼리 정보 [L, R]을 확인할 때, 구간 합은 P[R] - P[L-1]이다.

알고리즘을 설계할 때마다 고려해야 하는 점은 여러 번 사용될 만한 정보는 미리 구해서 저장해 놓을수록 유리하다.
prefix_sum 알고리즘은 각각의 합들을 새로운 배열에 저장해뒀다가 나중에 입력에 구간이 들어오면 해당 구간을 index로 바꿔서 출력해주는 간단한 알고리즘이다.

매 입력이 들어올 때 P[right] - P[left-1]을 수행하면 계산시간은 O(1)이 되고,
결과적으로 M개의 데이터와 N개의 입력이 있을 때, 전체 구간 합을 구하는 작업은 O(M+N)의 시간이 보장된다.

백준 11659번 - 구간 합 구하기 4

처음 단순히 파이썬의 sum(arr) 라이브러리를 사용하여 구현했을 때, 예상치 못한 시간 초과가 발생했다.

import sys

N, M = map(int, sys.stdin.readline().split(' '))
arr = list(map(int, sys.stdin.readline().split(' ')))

for i in range(M):
    n, m = map(int, sys.stdin.readline().split(' '))
    print(sum(arr[n-1:m]))

그래서 반복 계산을 줄이기 위해 prefix_sum 방식을 적용해보았고

import sys

N, M = map(int, sys.stdin.readline().split(' '))
arr = list(map(int, sys.stdin.readline().split(' ')))
p = [0]

for i in range(len(arr)):
    p.append(p[i] + arr[i])

for i in range(M):
    n, m = map(int, sys.stdin.readline().split(' '))
    print(p[m]-p[n-1])

시간 초과 없이 해결되었다!

백준 2559번 - 수열

문제 링크

import sys
n, k = map(int, sys.stdin.readline().split(' '))
arr = list(map(int, sys.stdin.readline().split(' ')))
p = [0]
for i in range(len(arr)):
    p.append(p[i] + arr[i])
new_p = []   
for i in range(k, len(arr)+1):
    new_p.append(p[i] - p[i-k])

print(max(new_p))

백준 16139번 - 인간과 컴퓨터 상호작용

참고 - 아스키코드

이 문제는 prefix sum 배열을 2차원 배열로 정의하고, 각각의 소문자 알파벳의 아스키코드에 97에 빼서 0~25 의 인덱스를 나타내도록 하려고 한다.

즉, 문자열의 길이가 x일때 25*x 의 길이를 갖는 배열에 각각의 위치에서 알파벳 개수를 누적합으로 저장해놓는 방법으로 구현하는 것이다.

이 문제를 풀면서 다른 것보다 이차원 배열을 핸들링하는 거에 미숙하다는 것을 많이 느꼈다.

import sys

arr = list(sys.stdin.readline().strip())
p = [[0]*26]
p[0][ord(arr[0])-97] = 1
for i in range(1, len(arr)):
    p.append(p[-1][:])
    index = ord(arr[i]) - 97
    p[i][index] += 1


n = int(sys.stdin.readline())
for i in range(n):
    x, start, end = map(str, sys.stdin.readline().strip().split(' '))
    if start == '0':
        print(p[int(end)][ord(x)-97])
    else:
        print(p[int(end)][ord(x)-97] - p[int(start)-1][ord(x)-97])

백준 10986 - 나머지 합

첫번째는 모든 구간 합을 저장해가면서 나머지가 m인지 체크하는 방법이었다.
메모리 공간을 비효율적으로 쓰는 느낌이 들었는데 아니나다를까 메모리 초과가 났다.

for i in range(1, n):
    next_p = []
    for j in p[i-1]:
        v = j+arr[i]
        next_p.append(v)
        if v % m == 0:
            result += 1
    next_p.append(arr[i])
    if arr[i]%m == 0:
        result += 1
    p.append(next_p)

if arr[0]%m == 0:
    print(result + 1)
else:
    print(result)

그래서 메모리를 적게 쓰도록 수정했더니 시간 초과가 났다. 구글링하여 다른 사람들의 풀이를 보았는데 나머지를 기록해놓고, 나머지가 0이거나 나머지가 같은 구간끼리 빼면 된다는 것을 알았다.
수학 머리를 전혀 안 쓰고 풀려고 했구나...

import sys
input = sys.stdin.readline

N,M= map(int, input().split())
num = list(map(int, input().split()))
sum = 0
numRemainder = [0] * M

for i in range(N):
  sum += num[i]
  numRemainder[sum % M] += 1

result = numRemainder[0]
for i in numRemainder:
  result += i*(i-1)//2

print(result)

백준 11660 - 구간 합 구하기 5

이 문제는 2차원 배열의 구간 합을 구하는 문제이다.
배열을 표로 그려보면 알 수 있듯이
각 구간합 p[x][y] 는 p[x][y-1] + p[x-1][y] - p[x-1][y-1] + arr[x-1]이 된다는 걸 체크해야 한다. 처음에는 중복으로 들어가는 p[x-1][y-1] 값을 생각을 못했는데 p를 출력해가면서 값의 위치와 계산 과정을 그려보니 중복이 된다는 걸 알 수 있었다.

import sys
input = sys.stdin.readline

N, M = map(int, input().split(' '))
p = [[0]*(N+1) for _ in range(N+1)]
for n in range(1, N+1):
    arr = list(map(int, input().split(' ')))
    for i in range(1, len(arr)+1):
        p[n][i] = p[n][i-1] + p[n-1][i] - p[n-1][i-1] + arr[i-1]

for m in range(M):
    x1, y1, x2, y2 = map(int, input().split(' '))
    print(p[x2][y2] - p[x2][y1-1] - p[x1-1][y2] + p[x1-1][y1-1])

백준 25682 - 체스판 다시 칠하기 2


이 문제도 다른 사람의 풀이를 참고하여 풀었다.

먼저 무엇을 기준으로 구간합을 계산할 것인가?
1. 정답 체스판을 기준으로 맞으면 0, 틀리면 1이라고 생각해보자.
2. 정답 체스판은 그럼 두가지 케이스가 있다. 첫 칸이 흑인 경우와 백인 경우, 그러나 한 경우만 계산하면 된다. 왜냐 흑인 경우의 변경해야 하는 케이스는 백인 경우 정답인 케이스가 되기 때문에, 흑인 경우 max 값을 칸수에서 빼주면 백인 경우 min 값이 된다는 걸 알 수 있다.
3. 그 다음 누적합에서 k x k 크기의 누적합을 계산하여 min, max를 체크해준다.

import sys
input = sys.stdin.readline
n, m, k = map(int, input().split(' '))
p = [[0]*(m+1) for _ in range(n+1)]

for i in range(1, n+1):
    arr = list(input())
    for j in range(1, m+1):
        if (i+j)%2 == 0 and arr[j-1] == 'B':
            p[i][j] = p[i-1][j] + p[i][j-1] - p[i-1][j-1]
        elif (i+j)%2 == 1 and arr[j-1] == 'W':
            p[i][j] = p[i-1][j] + p[i][j-1] - p[i-1][j-1]
        else:
            p[i][j] = p[i-1][j] + p[i][j-1] - p[i-1][j-1] + 1
    

min_ = float('inf')
max_ = -float('inf')
for r in range(k, n+1):
    for c in range(k, m+1):
        min_ = min(p[r][c] - p[r-k][c] - p[r][c-k] + p[r-k][c-k], min_)
        max_ = max(p[r][c] - p[r-k][c] - p[r][c-k] + p[r-k][c-k], max_)

print(min(min_, max_, k*k-min_, k*k-max_))

조만간 내 힘으로 다시 풀어보자..!

0개의 댓글