세그먼트 트리는 대규모 Data Set에서 구간 합, 최대/최소 값을 빠르게 구하기 위해 사용하는 자료구조이다. 세그먼트 트리는 아래의 3단계로 구성된다.
① Tree 배열 초기화
2 ^ k ≥ N
을 만족하는 k의 최소 값에 대해 2 ^ k * 2
를 계산한 수치가 된다.2 ^ k ≥ N
: N 개의 모든 데이터를 담기 위해 필요한 리프 노드의 최소 개수2 ^ k * 2
: 트리 배열의 사이즈가 리프 노드 개수의 2배가 되는 전 이진 트리의 속성을 이용한 것이다. (배열의 인덱스는 1부터 시작)② 리프 노드에 데이터 할당
2 ^ k
이므로, 2 ^ k
번부터 차례대로 데이터를 할당한다.③ 나머지 노드에 값 할당
2i
, 오른쪽 자식은 2i + 1
인 것을 이용하여 구간 합, 최대 값, 최소 값 중 알아내고자 하는 값을 계산한다.위에 나타난 트리 배열을 트리 형식으로 나타내면 아래와 같다.
① Query Range를 트리 배열에 알맞게 변경
2 ^ k - 1
만큼 평행 이동하면 된다.② Query 값 구하기
start = (start + 1) // 2
end = (end - 1) // 2
기존 배열이 변경될 경우, 세그먼트 트리를 그에 맞게 업데이트 해야 할 것이다. 세그먼트 트리의 데이터를 업데이트 하는 방법을, 구간 합 트리와 최대/최소 값 트리에서 각각 알아보자.
① 구간 합
② 최대/최소 값
일반 구간 합 배열과 달리, 세그먼트 트리는 데이터의 업데이트를 효율적으로 처리할 수 있다. 따라서 데이터의 변경이 없는 상황에서 구간 합을 구하는 문제는 구간 합 배열로 풀이하고, 데이터의 변경이 잦은 상황에서 구간 합을 구하는 문제는 세그먼트 트리로 풀이해야 한다. (구간 합 배열에 대해서는 아래의 포스팅을 참조하라.
>> 구간 합 배열
데이터의 변경이 잦은 상황에서 구간 합을 구하는 문제이므로, 전형적인 세그먼트 트리 문제이다. (만약 이 문제를 기존 배열에서 직접 구간 합을 구하는 방식으로 구현하거나, 구간 합 배열로 구현할 경우 시간 초과가 발생할 것이다.) 지금까지 배운 내용을 바탕으로 위 문제를 풀어보자.
import sys
input = sys.stdin.readline
N, M, K = map(int, input().split())
length = N
height = 0 # k 값
while length > 0: # 2 ^ k ≥ N을 만족하는 k의 최소 값을 구하는 과정
length //= 2
height += 1
treeSize = 2 ** (height + 1) # 트리 배열의 사이즈(= 2 ^ k * 2)
leafStartIdx = 2 ** height # 리프 노드의 시작 인덱스
tree = [0] * (treeSize)
# 데이터를 리프 노드에 저장
for i in range(leafStartIdx, leafStartIdx + N): # 기존 배열의 원소 수가 리프 노드 수보다 작거나 같으므로, treeSize를 끝 범위로 설정하면 안됨.
tree[i] = int(input())
# 나머지 노드에 값 할당
def setTree(i):
while i > 1: # 루트 노드까지만 구하면 됨.
tree[i // 2] += tree[i] # 자식 노드의 인덱스가 2i, 2i + 1인 것을 이용하여 구간 합 계산
i -= 1
# 데이터 업데이트 함수
def updateValue(index, value):
diff = value - tree[index] # 기존 데이터와 새로운 데이터 간의 차이
while index > 0: # 루트 노드에 이르기까지 모든 부모(조상) 노드의 값에 diff만큼을 더함
tree[index] = tree[index] + diff
index //= 2
# 구간 합 계산 함수
def getIntervalSum(start, end):
intervalSum = 0
while start <= end:
if start % 2 == 1: # start의 인덱스가 홀수이면, 해당 노드의 값을 intervalSum에 포함
intervalSum += tree[start]
if end % 2 == 0: # end의 인덱스가 짝수이면, 해당 노드의 값을 intervalSum에 포함
intervalSum += tree[end]
# 부모 노드로 이동
start = (start + 1) // 2
end = (end - 1) // 2
return intervalSum
setTree(treeSize - 1) # 높은 인덱스에서 낮은 인덱스 순서로 노드에 값을 할당
for i in range(M + K):
query, start, end = map(int, input().split())
if query == 1:
updateValue(start + leafStartIdx -1, end) # 기존 배열의 range를 세그먼트 트리의 range로 변경하하기 위해 2 ^ k - 1만큼 평행 이동
elif query == 2:
start += leafStartIdx -1
end += leafStartIdx -1
print(getIntervalSum(start, end))
트리 자료구조는 인덱스를 계산하는 부분에서 실수하기 쉽기 때문에, 이 부분을 주의 깊게 연습해보기 바란다.
구간에서의 최소 값을 구하는 문제 역시 세그먼트 트리를 활용해 풀이해야 하는 문제이다. (그래도 먼저는 일반 배열을 이용해 최소 값을 구할 때의 시간 복잡도를 계산해보기 바란다. 이 문제에서는 시간 초과가 날 것이 분명하므로, 바로 세그먼트 트리를 사용하기로 한다.)
다만, 데이터 업데이트가 발생하지 않기 때문에 직전 문제에서 구현하였던 updateValue 메서드는 생략할 수 있다.
import sys
input = sys.stdin.readline
N, M = map(int, input().split())
length = N
height = 0 # k 값
while length > 0: # 2 ^ k ≥ N을 만족하는 k의 최소 값을 구하는 과정
length //= 2
height += 1
treeSize = 2 ** (height + 1) # 트리 배열의 사이즈(= 2 ^ k * 2)
leafStartIdx = 2 ** height # 리프 노드의 시작 인덱스
tree = [sys.maxsize] * (treeSize) # 구간 합을 구할 때와 달리 최소 값을 구할 때는, 매우 큰 값으로 초기화해야 함.
# 데이터를 리프 노드에 저장
for i in range(leafStartIdx, leafStartIdx + N): # 기존 배열의 원소 수가 리프 노드 수보다 작거나 같으므로, treeSize를 끝 범위로 설정하면 안됨.
tree[i] = int(input())
# 나머지 노드에 값 할당
def setTree(i):
while i > 1: # 루트 노드까지만 구하면 됨.
tree[i // 2] = min(tree[i // 2], tree[i]) # 자식 노드의 인덱스가 2i, 2i + 1인 것을 이용하여 구간 합 계산
i -= 1
# 최소 값 계산 함수
def getMin(start, end):
Min = sys.maxsize # 구간 합을 구할 때와 달리, 매우 큰 값으로 초기화
while start <= end:
if start % 2 == 1: # start의 인덱스가 홀수이면, 해당 노드의 값이 최소 값인지 검사
Min = min(Min, tree[start])
if end % 2 == 0: # end의 인덱스가 짝수이면, 해당 노드의 값이 최소 값인지 검사
Min = min(Min, tree[end])
# 부모 노드로 이동
start = (start + 1) // 2
end = (end - 1) // 2
return Min
setTree(treeSize - 1) # 높은 인덱스에서 낮은 인덱스 순서로 노드에 값을 할당
for i in range(M):
start, end = map(int, input().split())
start += leafStartIdx -1
end += leafStartIdx -1
print(getMin(start, end))
구간 합을 구할 때와 최소 값을 구할 때, 어떤 부분이 달라지는지를 유심히 살펴보기 바란다.
구간 합이 아닌 구간 곱이라는 점만 빼면, 1번 문제와 유사해보이기에 세그먼트 트리를 이용하여 풀이하기로한다. 다만, 한 가지 주의해야 할 것은 구간 합 세그먼트 트리에서 데이터를 업데이트 할 때에는 변경된 자식의 옆 자식은 고려하지 않아도 되었지만, 구간 곱 세그먼트 트리에서는 변경된 자식의 옆 자식까지 고려하여 부모 노드의 값을 다시 계산해야 한다는 것이다.
예를 들어, 왼쪽 자식이 2, 오른쪽 자식이 3, 부모 노드가 6이었다고 하자. 이 때, 왼쪽 자식이 4로 변경된다면, 6 // 2 * 4
로 부모 노드를 업데이트 하면 되었을 것이다. 만약 이 로직의 문제가 없다면, 오른쪽 자식은 전혀 고려하지 않아도 되었을 것이다.
그러나 위 로직은 부모 노드의 값이 0인 경우에 대해, 제대로 동작하지 않는다. 0을 0으로 나눌 때, ZeroDivisionError가 발생하면서, 프로그램이 종료되어 버릴 것이다.
따라서, 변경된 노드와 그 옆 노드의 값을 곱하여 새롭게 계산된 값으로, 부모 노드의 값을 업데이트해야 한다.
import sys
input = sys.stdin.readline
N, M, K = map(int, input().split())
length = N
height = 0 # k 값
MOD = 1000000007
while length > 0: # 2 ^ k ≥ N을 만족하는 k의 최소 값을 구하는 과정
length //= 2
height += 1
treeSize = 2 ** (height + 1) # 트리 배열의 사이즈(= 2 ^ k * 2)
leafStartIdx = 2 ** height # 리프 노드의 시작 인덱스
tree = [1] * (treeSize) # 구간 곱을 구할 때에는 1로 초기화해야 함.
# 데이터를 리프 노드에 저장
for i in range(leafStartIdx, leafStartIdx + N): # 기존 배열의 원소 수가 리프 노드 수보다 작거나 같으므로, treeSize를 끝 범위로 설정하면 안됨.
tree[i] = int(input())
# 나머지 노드에 값 할당
def setTree(i):
while i > 1: # 루트 노드까지만 구하면 됨.
tree[i // 2] = (tree[i // 2] * tree[i]) % MOD # 자식 노드의 인덱스가 2i, 2i + 1인 것을 이용하여 구간 합 계산
i -= 1
# 데이터 업데이트 함수
def updateValue(index, value):
tree[index] = value
index //= 2
while index > 0: # 루트 노드에 이르기까지 모든 부모(조상) 노드의 값을 새로 구함(변경된 노드의 옆 노드까지 고려)
tree[index] = (tree[index * 2] * tree[index * 2 + 1]) % MOD
index //= 2
def getIntervalMul(start, end):
intervalMul = 1
while start <= end:
if start % 2 == 1: # start의 인덱스가 홀수이면, 해당 노드의 값을 intervalMul에 포함
intervalMul = (intervalMul * tree[start]) % MOD
if end % 2 == 0: # end의 인덱스가 짝수이면, 해당 노드의 값을 intervalMul에 포함
intervalMul = (intervalMul * tree[end]) % MOD
# 부모 노드로 이동
start = (start + 1) // 2
end = (end - 1) // 2
return intervalMul
setTree(treeSize - 1) # 높은 인덱스에서 낮은 인덱스 순서로 노드에 값을 할당
for i in range(M + K):
query, start, end = map(int, input().split())
if query == 1:
updateValue(start + leafStartIdx -1, end) # 기존 배열의 range를 세그먼트 트리의 range로 변경하하기 위해 2 ^ k - 1만큼 평행 이동
elif query == 2:
start += leafStartIdx -1
end += leafStartIdx -1
print(getIntervalMul(start, end))
중요한 내용은 아니지만 참고할만한 내용으로, 문제에서 "1,000,000,007로 나눈 나머지를 출력하라"라는 말의 의미에 대해서 간단히 알아보기로 하자. 위 문제와 같이 매우 큰 수에 대한 곱셈을 계산하는 문제는 정수 Overflow가 발생하기 쉽다. 물론, 파이썬에서는 매우 큰 숫자도 모두 허용하기 때문에 큰 문제가 안 되지만, 코딩 테스트 문제는 파이썬이 아닌 다른 언어로도 풀이할 수 있어야 하기 때문에, 주로 구해진 값의 나머지를 출력하도록 문제를 출제한다.
그러면 어떤 수에 대해 나눈 나머지를 출력해야 할까? 주로 사용되는 숫자는 1000000007(10억 7) 또는 1000000009(10억 9)이다. 이 숫자를 사용하는 이유는 int 값의 범위와 관련이 깊다. C/C++ 기준 int의 범위는 -2147483648 ~ 2147483647
로 약 -20억부터 +20억까지이다. 즉, int 범위의 절반 정도의 연산만 허용하겠다는 의미인 것이다. (10억과 가장 가까운 두 소수를 선택한 것이다.)
int 범위의 절반만 사용하는 이유는 int 범위에 근접한 값을 사용할 경우, 허용된 값끼리의 덧셈만으로도 Overflow가 발생할 수 있으며, int 범위에 비해 턱 없이 작은 값을 사용하면 효율성이 떨어지기 때문에, 적절하게 int 범위의 절반 정도만 사용하기로 한 것이다.
하지만, 여전히 허용된 범위 내에서의 곱셈에서 정수 Overflow가 발생할 가능성이 존재한다. 이 문제를 해결하기 위해 아래와 같은 연산자 분배법칙을 사용할 수 있다.
(A + B) % M = ((A % M) + (B % M)) % M
(A - B) % M = ((A % M) - (B % M)) % M
(A * B) % M = ((A % M) * (B % M)) % M
(A * B) % M = ((A % M) * (B % M)) % M
는 A * B에서 Overflow가 나는 것을 방지하는 유용한 항등식이다. 다만, 파이썬에서는 정수 Overflow가 거의 발생하지 않기 때문에 직접적으로 사용할 일은 없을 것이다. 사실 출력되는 결과에만 Modulus 연산을 적용해도, 문제 풀이에 전혀 지장이 없는 경우가 대부분이다.
그러나 곱셈 결과에 대한 Modulus를 출력하는 문제에서는, 자칫 곱셈의 값이 너무 방대해져 시간 초과가 발생할 수 있다. 위의 구간 곱 구하기 문제가 바로 이러한 문제이다. 구간 곱의 결과를 Modulus 없이 세그먼트 트리에 저장하다보면, 곱셈의 결과가 매우 방대해지면서 결국 시간 초과가 발생한다. 따라서, 이러한 경우에는 곱셈 결과에 바로 Modulus를 적용하여 풀이해야 한다.