여러개의 데이터가 존재할 때, 특정 구간의 합을 구하는데 사용되는 자료구조 이다.
Tree 종류 중 하나이며, 이진트리 형태로 O(log N)의 시간으로 구할 수 있다.
10칸 짜리의 배열 int arr = new int[10] 이 존재한다고 생각해보자.
여기서 만약에 2번째 부터 8번째 까지의 합을 구한다고 생각해보자.
그렇다면 2+3+4+5+6+7+8 = 35로 정답을 구할 수 있다.
이렇게 한다면 데이터의 개수가 N개일때 시간복잡도가 O(N) 이 걸리므로 너무 느리게 된다.
만약 이 상황에서 4번째 숫자를 바꿔라. 이렇게 연산이 된다면 더욱 느리게 된다. 이럴 때 더 좋은 방법으로 사용하는것이 세그먼트 트리이다!!
노드안에 적혀있는 숫자들은 배열의 index번호를 의미한다.
리프노드 = 배열의 값
루트노드 = 배열 전체 index값들의 합
이외의 노드 = 자식노드들의 값들에 대한 연산 결과
<세그먼트 트리를 배열의 각 구간 합으로 채워주기>
start : 배열의 시작 인덱스, end : 배열의 마지막 인덱스
index : 세그먼트 트리의 인덱스 (무조건 1부터 시작)
public int init(int start, int end, int node) {
if(start == end) { /* 리프노드이거나 자식노드들이 구간합이 모두구해졌을 경우 */
return tree[node] = arr[start]; /* 구간합 트리에 넣어준다 */
}
/* 반씩 나눠서 재귀적으로 자식노드들의 구간합을 구해준다 */
int mid = (start+end)/2;
return tree[node] = init(start, mid, node*2) + init(mid+1, end, node*2+1);
}
즉 다음과 같이 구성된다는 뜻이다.
public int sum(int start, int end, int node, int left, int right) {
if(left>end || right < start) {
return 0;
}
if(left <=start && end <=right) {
return tree[node];
}
/* 필요한 구간마다 밑에서부터 구간합을 가지고 올라온다 */
int mid = (start+end)/2;
return sum(start, mid, node*2, left, right) + sum(mid+1, end, node*2+1, left, right);
}
다음은 구간의 합을 구하려할때 사용하는 함수이다.
dif = n번째 인덱스 숫자를 m으로 바꾼다고 한다면,
dif = m - arr[n] 이다. 즉 두 숫자의 차이를 나타낸다.
public void update(int start, int end, int node, int index, int dif) {
if(index < start || index > end) {
return;
}
tree[node] += dif; /* 변경된 값만큼 더해주고 */
if(start == end) {
return;
}
/* 변경된 값이 속해있는 구간의 구간합을 모두 바꿔준다 */
int mid = (start + end)/2;
update(start, mid, node*2, index, dif);
update(mid+1, end, node*2+1, index, dif);
}
https://www.acmicpc.net/problem/2042
어떤 N개의 수가 주어져 있다.
그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다.
만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다.
그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다.
M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다.
그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그
리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.
입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.
첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.
package back_joon.Data_Structures;
import java.util.*;
public class b2042 {
static long[] input,tree;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int N = sc.nextInt();
int M = sc.nextInt();
int K = sc.nextInt();
input = new long[N+1];
tree = new long[N*4];
for(int i=1;i<=N;i++){
input[i] = sc.nextLong();
}
init(1,N,1); // segment tree 생성.
for(int i=0;i<M+K;i++){
int a = sc.nextInt();
int b = sc.nextInt();
if(a == 1){
long c = sc.nextLong();
long diff = c - input[b];
input[b] = c;
update(1,N,1,b,diff);
}else{
int c = sc.nextInt();
System.out.println(sum(1,N,1,b,c));
}
}
}
public static long init(int start,int end,int node){
if(start == end){
return tree[node] = input[start];
}
int mid = (start + end) / 2;
return tree[node] = init(start,mid,(node*2)) + init(mid+1,end,(node*2)+1);
}
public static long sum(int start,int end,int node,int left,int right){
if(left > end || right < start){
return 0;
}
if(left <= start && end <= right){
return tree[node];
}
int mid = (start + end) / 2;
return sum(start,mid,node*2,left,right) + sum(mid+1,end,(node*2)+1,left,right);
}
public static void update(int start,int end,int node,int index,long diff){
if(index < start || index > end){
return;
}
tree[node] += diff;
if(start == end){
return;
}
int mid = (start + end) / 2;
update(start,mid,node*2,index,diff);
update(mid+1,end,node*2+1,index,diff);
}
}
앞에서 만들었던 함수를 모두 그대로 이용하는 문제였다.
세그먼트트리를 이해했다면 쉽게 풀수있는 문제였습니다!