세그먼트 트리

개발자·2022년 3월 4일
0
post-thumbnail

📌세그먼트 트리 (Segment Tree)

  • 여러개의 데이터가 존재할 때, 특정 구간의 합을 구하는데 사용되는 자료구조 이다.

  • Tree 종류 중 하나이며, 이진트리 형태로 O(log N)의 시간으로 구할 수 있다.




💬세그먼트 트리 사용이유

10칸 짜리의 배열 int arr = new int[10] 이 존재한다고 생각해보자.

여기서 만약에 2번째 부터 8번째 까지의 합을 구한다고 생각해보자.
그렇다면 2+3+4+5+6+7+8 = 35로 정답을 구할 수 있다.
이렇게 한다면 데이터의 개수가 N개일때 시간복잡도가 O(N) 이 걸리므로 너무 느리게 된다.

만약 이 상황에서 4번째 숫자를 바꿔라. 이렇게 연산이 된다면 더욱 느리게 된다. 이럴 때 더 좋은 방법으로 사용하는것이 세그먼트 트리이다!!



💬세그먼트 트리의 구조

  • 노드안에 적혀있는 숫자들은 배열의 index번호를 의미한다.

  • 리프노드 = 배열의 값

  • 루트노드 = 배열 전체 index값들의 합

  • 이외의 노드 = 자식노드들의 값들에 대한 연산 결과



📌세그먼트 트리 구현

💬 Init

<세그먼트 트리를 배열의 각 구간 합으로 채워주기>

start : 배열의 시작 인덱스, end : 배열의 마지막 인덱스
index : 세그먼트 트리의 인덱스 (무조건 1부터 시작)

public int init(int start, int end, int node) {
		if(start == end) { /* 리프노드이거나 자식노드들이 구간합이 모두구해졌을 경우 */
			return tree[node] = arr[start]; /* 구간합 트리에 넣어준다 */
		}
		/* 반씩 나눠서  재귀적으로 자식노드들의 구간합을 구해준다 */
		int mid = (start+end)/2;
		return tree[node] = init(start, mid, node*2) + init(mid+1, end, node*2+1);
	}

즉 다음과 같이 구성된다는 뜻이다.




💬 Sum

	public int sum(int start, int end, int node, int left, int right) {
		if(left>end || right < start) {
			return 0;
		}
		if(left <=start && end <=right) {
			return tree[node];
		}
		/* 필요한 구간마다 밑에서부터 구간합을 가지고 올라온다 */
		int mid = (start+end)/2;
		return sum(start, mid, node*2, left, right) + sum(mid+1, end, node*2+1, left, right);
	}

다음은 구간의 합을 구하려할때 사용하는 함수이다.




💬 Update

dif = n번째 인덱스 숫자를 m으로 바꾼다고 한다면,
dif = m - arr[n] 이다. 즉 두 숫자의 차이를 나타낸다.

	public void update(int start, int end, int node, int index, int dif) {
		if(index < start || index > end) {
			return;
		}
		tree[node] += dif; /* 변경된 값만큼 더해주고 */
		if(start == end) {
			return;
		}
		/* 변경된 값이 속해있는 구간의 구간합을 모두 바꿔준다 */
		int mid = (start + end)/2;
		update(start, mid, node*2, index, dif);
		update(mid+1, end, node*2+1, index, dif);
	}




📌세그먼트 트리 문제 풀이

https://www.acmicpc.net/problem/2042

문제

어떤 N개의 수가 주어져 있다.
그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다.

만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다.

그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다.
M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다.
그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그
리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

코드

package back_joon.Data_Structures;
import java.util.*;

public class b2042 {

    static long[] input,tree;

    public static void main(String[] args) {

        Scanner sc = new Scanner(System.in);

        int N = sc.nextInt();
        int M = sc.nextInt();
        int K = sc.nextInt();

        input = new long[N+1];
        tree = new long[N*4];

        for(int i=1;i<=N;i++){
            input[i] = sc.nextLong();
        }

        init(1,N,1); // segment tree 생성.

        for(int i=0;i<M+K;i++){
            int a = sc.nextInt();
            int b = sc.nextInt();

            if(a == 1){
                long c = sc.nextLong();
                long diff = c - input[b];
                input[b] = c;
                update(1,N,1,b,diff);
            }else{
                int c = sc.nextInt();
                System.out.println(sum(1,N,1,b,c));
            }
        }


    }

    public static long init(int start,int end,int node){
        if(start == end){
            return tree[node] = input[start];
        }

        int mid = (start + end) / 2;
        return tree[node] = init(start,mid,(node*2)) + init(mid+1,end,(node*2)+1);
    }

    public static long sum(int start,int end,int node,int left,int right){
        if(left > end || right < start){
            return 0;
        }
        if(left <= start && end <= right){
            return tree[node];
        }
        int mid = (start + end) / 2;
        return sum(start,mid,node*2,left,right) + sum(mid+1,end,(node*2)+1,left,right);
    }

    public static void update(int start,int end,int node,int index,long diff){
        if(index < start || index > end){
            return;
        }
        tree[node] += diff;
        if(start == end){
            return;
        }
        int mid = (start + end) / 2;
        update(start,mid,node*2,index,diff);
        update(mid+1,end,node*2+1,index,diff);
    }
}

앞에서 만들었던 함수를 모두 그대로 이용하는 문제였다.
세그먼트트리를 이해했다면 쉽게 풀수있는 문제였습니다!

0개의 댓글