세그먼트 트리는 구간을 업데이트 해 주면서 여러가지 알고리즘을 구현하기에 편리한 트리 구조이다.
특히 구간 합을 구할 때 많이 사용하며 일반적으로 구간 합을 구할 때 예시를 보게 된다면,
// sum : 2 to 5
int arr[5] = { 1, 2, 3, 4, 5 };
for(int i = 1; i< 5; i++) {
int sum += arr[i];
}
// time complexity : O(n)
이렇게 배열 인덱스에 하나씩 접근하는 방법이 있다. 이렇게 작성하게 된다면, 시간 복잡도가 O(N)이므로, 데이터가 커질수록 느린 속도를 가질 수 있다. 이런 경우 사용하는것이 O(logN) 만큼 걸리는 세그먼트 트리이다.
쿼리 부분을 어떻게 작성하는지에 따라 여러가지 용도로 사용이 된다.
원리는 그림과 같다.
세그먼트 트리를 하기 전, 꼭 알아두어야 할 사항은 아래와 같다.
Root node 기준 (Root node index : 1)
- 왼쪽 노드 : node * 2
- 오른쪽 노드 : node * 2 + 1
- 부모 노드 : node / 2
세그먼트 트리를 생성하는 코드는 다음과 같다.
// seg tree를 위한 struct
typedef struct tree {
ll value;
ll lazy; // lazy propagation
}tree;
int tree_size = 0;
ll init(tree *T,int node,int start,int end) {
if(start == end) return T[node].value = v[start];
else {
ll mid = (start + end) / 2;
return T[node].value = init(T,node *2, start,mid) + init(T,node * 2 + 1, mid +1, end);
}
}
그림을 보게 되면 F의 값이 새로 추가되는 과정이다. 값이 추가되면, 부모 노드를 거쳐서 Root 노드까지 계속 더해주면 된다.
void update(tree *T, int node,int start,int end, int i, int j, ll dif) {
if(T[node].lazy != 0) {
T[node].value += (end-start + 1) * T[node].lazy;
if(start != end) {
T[node * 2].lazy += T[node].lazy;
T[node * 2 + 1].lazy += T[node].lazy;
}
T[node].lazy = 0;
}
if(j < start || i > end) return;
if(i <= start && end <= j) {
T[node].value += (end - start + 1) * dif;
if(start != end) {
T[node * 2].lazy += dif;
T[node * 2 + 1].lazy += dif;
}
return;
}
int mid = (start + end) / 2;
update(T,node * 2, start,mid,i,j,dif);
update(T,node * 2 + 1, mid+1, end,i,j,dif);
T[node].value = T[node * 2].value + T[node * 2 + 1].value;
}
세그먼트 트리의 구간 합 구하는 코드는 아래와 같다.
ll segtree_sum(tree *T,int node, int start, int end, int i, int j) {
if(T[node].lazy != 0) {
T[node].value += (end - start + 1) * T[node].lazy;
if(start != end) {
T[node *2].lazy += T[node].lazy;
T[node * 2 + 1].lazy += T[node].lazy;
}
T[node].lazy = 0;
}
if(i > end || j < start) return 0;
if(i <= start && end <= j) return T[node].value;
ll mid = (start + end) / 2;
return segtree_sum(T,node *2,start,mid,i,j) + segtree_sum(T, node * 2 + 1,mid +1, end,i,j);
}
i와 j는 i(시작 점) 부터 j(끝 점) 까지의 구간이며, 예외처리 후에 재귀를 통하여 계속 세그먼트 트리의 합을 구하여 준다.
풀 코드 (백준 2042)
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <vector>
#define MAX_LENGTH 1000001
#define ll long long int
using namespace std;
typedef struct tree {
ll value;
ll lazy;
}tree;
ll v[MAX_LENGTH];
ll init(tree *T,int node,int start,int end) {
if(start == end) return T[node].value = v[start];
else {
ll mid = (start + end) / 2;
return T[node].value = init(T,node *2, start,mid) + init(T,node * 2 + 1, mid +1, end);
}
}
void update(tree *T, int node,int start,int end, int i, int j, ll dif) {
if(T[node].lazy != 0) {
T[node].value += (end-start + 1) * T[node].lazy;
if(start != end) {
T[node * 2].lazy += T[node].lazy;
T[node * 2 + 1].lazy += T[node].lazy;
}
T[node].lazy = 0;
}
if(j < start || i > end) return;
if(i <= start && end <= j) {
T[node].value += (end - start + 1) * dif;
if(start != end) {
T[node * 2].lazy += dif;
T[node * 2 + 1].lazy += dif;
}
return;
}
int mid = (start + end) / 2;
update(T,node * 2, start,mid,i,j,dif);
update(T,node * 2 + 1, mid+1, end,i,j,dif);
T[node].value = T[node * 2].value + T[node * 2 + 1].value;
}
ll segtree_sum(tree *T,int node, int start, int end, int i, int j) {
if(T[node].lazy != 0) {
T[node].value += (end - start + 1) * T[node].lazy;
if(start != end) {
T[node *2].lazy += T[node].lazy;
T[node * 2 + 1].lazy += T[node].lazy;
}
T[node].lazy = 0;
}
if(i > end || j < start) return 0;
if(i <= start && end <= j) return T[node].value;
ll mid = (start + end) / 2;
return segtree_sum(T,node *2,start,mid,i,j) + segtree_sum(T, node * 2 + 1,mid +1, end,i,j);
}
int main() {
tree *T;
int n,m,k;
scanf("%d %d %d",&n,&m,&k);
for(int i = 1; i<=n; i++) {
scanf("%lld",&v[i]);
}
T = (tree *)malloc(sizeof(tree) * 4 * MAX_LENGTH);
init(T,1,1,n);
ll change_value = 0;
for(int i = 1; i<=m+k; i++) {
int a,b,c;
ll d;
scanf("%d",&a);
if(a == 1) {
scanf("%d %lld",&b,&d);
if(v[b] != d) {
if((d < 0 && v[b] < 0)){
change_value = d - v[b];
}
else {
change_value = d - v[b];
}
v[b] = d;
}
update(T,1,1,n,b,b,change_value);
change_value = 0;
}
else {
scanf("%d %d",&b,&c);
cout<<segtree_sum(T,1,1,n,b,c)<<endl;
}
}
free(T);
return 0;
}
다음에는 이 코드에서 쓰인 lazy propagation에 대하여 작성을 해 볼 예정이다.