세그먼트 트리(Segment Tree)는 배열 또는 리스트와 같은 자료구조에서 구간 쿼리를 효율적으로 처리하기 위한 자료구조입니다.
주로 배열의 구간 합, 최소값, 최대값 등을 효율적으로 계산하는 데 사용됩니다.
세그먼트 트리는 트리 구조로 표현되며, 각 노드는 배열의 일부 구간에 대한 정보를 저장합니다. 트리의 루트 노드는 전체 배열에 대한 정보를 가지고 있고, 하위 노드로 내려갈수록 구간이 반으로 줄어들어 구체적인 구간에 대한 정보를 저장합니다.
세그먼트 트리는 일반적으로 재귀 함수를 이용해 구현하며, 배열의 크기가 2의 거듭제곱이 아닐 경우 2의 거듭제곱으로 맞춰주기 위해 추가적인 노드를 만들어야 합니다. 또한, 세그먼트 트리를 이용해 구간 합을 구할 때는 구간의 시작점과 끝점을 알고 있어야 합니다.
✋주의! tree 크기 할당✋
일반적으로 세그먼트 트리는 완전 이진 트리의 형태를 갖기 때문에, 입력 배열의 크기 n에 대해 트리의 크기는 보통 4n보다 크거나 같게 설정됩니다. 만약 입력 배열의 크기 n이 2의 거듭제곱인 경우에는 세그먼트 트리의 크기는 2n과 같습니다.
# 예
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# 실제로는 데이터의 개수 N에 4를 곱한 크기만큼 미리 세그먼트 트리의 공간을 할당한다.
tree = [0] * (len(arr) * 4)
def build_tree(node, start, end, arr, tree):
if start == end:
tree[node] = arr[start]
else:
mid = (start + end) // 2
build_tree(2 * node, start, mid, arr, tree)
build_tree(2 * node + 1, mid + 1, end, arr, tree)
tree[node] = tree[2 * node] + tree[2 * node + 1]
def update_tree(node, start, end, idx, val, tree):
if start == end:
tree[node] = val
else:
mid = (start + end) // 2
if start <= idx <= mid:
update_tree(2 * node, start, mid, idx, val, tree)
else:
update_tree(2 * node + 1, mid + 1, end, idx, val, tree)
tree[node] = tree[2 * node] + tree[2 * node + 1]
def query_tree(node, start, end, l, r, tree):
if r < start or end < l:
return 0
if l <= start and end <= r:
return tree[node]
mid = (start + end) // 2
p1 = query_tree(2 * node, start, mid, l, r, tree)
p2 = query_tree(2 * node + 1, mid + 1, end, l, r, tree)
return p1 + p2
build_tree(1, 0, len(arr) - 1, arr, tree)
print(query_tree(1, 0, len(arr) - 1, 0, 9)) # 0부터 9까지의 구간 합 (1 + 2 + ... + 9 + 10)
print(query_tree(1, 0, len(arr) - 1, 0, 2)) # 0부터 2까지의 구간 합 (1 + 2 + 3)
print(query_tree(1, 0, len(arr) - 1, 6, 7)) # 0부터 2까지의 구간 합 (7 + 8)
# arr[0]을 +4만큼 수정
update_tree(1, 0, len(arr) - 1, 0, 4)
print(query_tree(1, 0, len(arr) - 1, 0, 2)) # 0부터 2까지의 구간 합 ((1 + 4) + 2 + 3)
# arr[9]를 -11만큼 수정
update_tree(1, 0, len(arr) - 1, 9, -11)
print(query_tree(1, 0, len(arr) - 1, 8, 9)) # 8부터 9까지의 구간 합 (9 + (10 - 11))
update_tree
함수
start
와 end
가 idx
를 포함하는 구간인지 확인합니다.start
와 end
가 동일하다면 해당 노드는 갱신할 인덱스를 표현하고 있습니다. 따라서 해당 노드 값을 val
로 갱신합니다.update_tree
함수를 호출합니다.idx
가 왼쪽 구간 [start
, mid
]에 속한다면, 왼쪽 자식 노드인 2 * node
를 탐색하여 update_tree
함수를 호출합니다.idx
는 오른쪽 구간 [mid + 1
, end
]에 속하므로, 오른쪽 자식 노드인 2 * node + 1
을 탐색하여 update_tree
함수를 호출합니다.query_tree
함수
현재 노드가 담당하는 구간 [start
, end
]와 주어진 구간 [l
, r
]이 아예 겹치지 않는 경우:
현재 노드가 담당하는 구간 [start
, end
]가 주어진 구간 [l
, r
]에 완전히 포함되는 경우:
위의 두 경우에 해당하지 않는 경우:
start
, end
]라고 하면, 구간 [l
, r
]와 [start
, end
]이 부분적으로 겹치는 경우입니다.