세그먼트 트리를 사용하는 곳과 이유
- 구간 합을 구할 때 자주 사용한다.
- 구간 합을 구하기 위해서는 단순히 for 문을 사용하면 O(N)의 시간 복잡도가 나오지만 세그먼트 트리를 이용하면 O(logN)의 시간 복잡도를 가지기 때문에 더 빠른 처리가 가능하다.
세그먼트 트리

- 배열 A는 인덱스 0부터 시작한다.
- 트리 B는 인덱스 1부터 시작한다. 그 이유는 이렇게 할 경우 왼쪽 자식 노드의 인덱스는 부모 노드의 2배, 오른쪽 자식 노드의 인덱스는 부모 노드의 2배 + 1이 되기 대문이다.
- 노드 위에 쓰여 있는 값은 배열 A를 기준으로 합의 구간이다.
세그먼트 트리 초기화 방법
- 배열 A를 이용하여 세그먼트 트리 배열을 초기화해야한다. 이때 노드 번호가 트리 배열의 인덱스가 되며 다음과 같다.
int init(int start, int end, int node) {
if (start == end) {
tree[node] = A[start];
return tree[node];
}
int mid = (start + end) / 2;
int left = init(start, mid, node * 2);
int right = init(mid + 1, end, node * 2 + 1);
tree[node] = left + right;
return tree[node];
}
- start == end인 경우는 누적합의 시작과 끝이 같다는 의미로 이는 곧 start만을 의미하여 tree[node]에 A[start]를 넣는다.
- mid는 시작과 끝의 중간 값으로 이를 통해서 왼쪽 자식 노드의 끝, 오른쪽 자식 노드의 시작을 표현할 수 있다.
- left는 왼쪽 자식 노드의 값으로 이는 왼쪽 자식 노드 까지의 구간 합이라고 볼 수 있다.
- right는 오른쪽 자식 노드의 값으로 이는 오른쪽 자식 노드까지의 구간 합이라고 볼 수 있다.
- 해당 노드의 값은 left와 right를 더한 값이 된다.
- 재귀적 방식으로 초기화한다.
세그먼트 트리에서 구간합 구하는 방법
int sum(int start, int end, int node, int left, int right) {
if (left > end || right < start) {
return 0;
}
if (left <= start && end <= right) {
return tree[node];
}
int mid = (start + end) / 2;
return sum(start, mid, node * 2, left, right) + sum(mid + 1, end, node * 2 + 1, left, right);
}
세그먼트 트리에서 값을 업데이트 하는 방법
- 단순히 값을 업데이트하고 세그먼트 트리를 사용하면 O(N)의 시간복잡도가 갱신 될 때마다 발생한다.
void update(int start, int end, int node, int index, int value) {
if (index > end || index < start) {
return;
}
if (start == end) {
tree[node] = value;
return;
}
int mid = (start + end) / 2;
if (mid < index) {
update(mid + 1, end, node * 2 + 1, index, value);
}
else {
update(start, mid, node * 2, index, value);
}
tree[node] = tree[node * 2] + tree[node * 2 + 1];
}
- 우선 해당 인덱스 번호에 해당하는 노드를 찾을 때까지 재귀적으로 들어가서 찾는다.
- start == end인 경우 찾은 경우이며 이때 해당 노드 값을 업데이트한다.
- 재귀적으로 빠져 나오면서 부모 노드의 값을 자식 노드들의 합으로 업데이트한다.
- 이렇게 하면 시간 복잡도가 O(logN)이 된다.
세그먼트 트리 예시 코드
#include <iostream>
using namespace std;
int A[10001] = { 0 };
int tree[100000] = { 0 };
int init(int start, int end, int node) {
if (start == end) {
tree[node] = A[start];
return tree[node];
}
int mid = (start + end) / 2;
int left = init(start, mid, node * 2);
int right = init(mid + 1, end, node * 2 + 1);
tree[node] = left + right;
return tree[node];
}
int sum(int start, int end, int node, int left, int right) {
if (left > end || right < start) {
return 0;
}
if (left <= start && end <= right) {
return tree[node];
}
int mid = (start + end) / 2;
return sum(start, mid, node * 2, left, right) + sum(mid + 1, end, node * 2 + 1, left, right);
}
void update(int start, int end, int node, int index, int value) {
if (index > end || index < start) {
return;
}
if (start == end) {
tree[node] = value;
return;
}
int mid = (start + end) / 2;
if (mid < index) {
update(mid + 1, end, node * 2 + 1, index, value);
}
else {
update(start, mid, node * 2, index, value);
}
tree[node] = tree[node * 2] + tree[node * 2 + 1];
}
int main() {
int N;
cin >> N;
for (int i = 0; i < N; i++) {
int tmp;
cin >> tmp;
A[i] = tmp;
}
init(0, N - 1, 1);
cout << sum(0, N - 1, 1, 2, 4) << endl;
update(0, N - 1, 1, 1, 4);
cout << sum(0, N - 1, 1, 1, 4) << endl;
return 0;
}