Segment Tree - 세그먼트 트리

전윤지·2021년 6월 4일
0

알고리즘

목록 보기
1/1

1. 세그먼트 트리 (Segment Tree)

  • 여러 개의 데이터가 연속적으로 존재 할 때, 특정 범위의 합을 빠르게 구할 수 있음
    => A~B까지의 합을 구하는데 용이하게 사용

[ 선형 탐색, 세그먼트 트리 비교 ]

1) i번째 데이터를 b로 바꾸기
2) a~c의 데이터의 합 구하기
이 두 가지 단계를 반복한다고 하자.

단순 배열을 사용?

  • 수를 바꾸는데 O(1)
  • 수를 더하는데 O(N)
  • M번 수행한다고 하면, O(MN)의 시간 복잡도를 가짐

세그먼트 트리를 사용?

  • 수를 바꾸는 과정 O(logN)
  • 수를 더하는 과정 O(logN)
  • M번 실행한다치면, O(MlogN)의 시간 복잡도를 가짐

2. 세그먼트 트리 알고리즘

1) 배열 크기 결정

  • 세그먼트 트리를 만들기 위해, 배열의 크기를 얼마나 할당할 것인지 정해야 함
  • 보통 4*N으로 할당 (N:데이터의 개수)
  • 위의 예시 그림에선 4*12=48 크기 만큼을 할당

2) 구간 합 트리 생성

  • 최상단 노드 (root node)에는 전체 원소를 더한 값이 들어감

  • 이후에는 위의 사진에 맞춰, 노드의 index 범위의 합을 구한다
  • 두번째, 세번째 노드를 구하려면?
    • 두번째 : 인덱스 0~5번째 원소의 합
    • 세번째 : 인덱스 6~11번째 원소의 합
      즉, 원래 데이터 범위를 반씩 분할하면서, 그 구간의 합들을 저장
  • 네번째, 다섯번째...도 동일하게 구해줌
    • 네번째 : 인덱스 0~2번째 원소의 합
    • 다섯번째 : 인덱스 3~5번째 원소의 합
  • 구간 합 트리는 인덱스 1부터 시작
    => 1부터 시작하면, 2를 곱했을 때 왼쪽 자식 노드를 가리키므로 효과적으로 계산 가능하기 때문!

[ 최종적으로 생성 된 구간 합 트리 ]


3. Python으로 세그먼트 트리 구현하기

1) init() 구현

  • 구간 합 트리 생성
  • 재귀적으로 init() 호출 해 tree 생성

(1) 배열의 크기를 4*N으로 설정 (N:데이터의 개수)
(2) init() 함수가 (node,start,end)를 가지게 설정
(3) (1,0,N)부터 시작
(4) node가 leaf node일 때 (start=end+1)일 때, tree[node]를 설정하고 반환
(5) node가 leaf node가 아닐 때, (2x,start,mid), (2x+1,mid,end)의 함수값 2개를 더해 tree[node]에 저장 후 반환

class SegTree:
    # (1)~(3) 설정
    def __init__(self,N,A):
            self.A=A
            self.tree=[0]*4*N
            self.init(1,0,N)
        
    # (4),(5) 설정
    def init(self,node,left,right):
        if left+1==right:
            self.tree[node]=self.A[left]
        else:
            mid=(left+right)//2
            self.tree[node]=self.init(node*2,left,mid)+self.init(node*2+1,mid,right)
        return self.tree[node]

2) sum (구간 합 구하기)

  • 항상 O(logN)의 시간 복잡도를 가짐

[start,end)의 범위 합을 구할 때, 현재 노드가 (node,left,right)이면 다음과 같은 3가지 경우가 생김

  • [left,right)가 [start,end)에 완전히 포함되는 경우
    => tree[node] 반환
  • [left,right)가 [start,end)에 부분만 포함되는 경우
    => [left,mid), [mid,right)의 범위로 다시 합을 구해 더하고, 그 값을 반환 함
  • [left,right)와 [start,end)가 교집합이 없는 경우
    => 고려해주지 않아도 된다 (0 return)
 def sum(self,node,left,right,start,end):
        if start<=left and right<=end:
            return self.tree[node]
        if right<=start or end<=left:
            return 0
        mid=(left+right)//2
        return self.sum(node*2,left,mid,start,end)+self.sum(node*2+1,mid,right,start,end)

3) update (값 변경)

  • 특정 원소 값 (target)을 수정 할 때에, 해당 원소를 포함하고 있는 모든 구간의 합 node들을 갱신해야 함

update시에 2가지 고려사항이 있음

  • [left,ringt)안에 target번째가 속한 경우
    => tree[node]+=value 실행 후,
    node가 leaf node가 아니라면 [left,mid), [mid,rignt)범위로 다시 함수 실행
  • 아닌 경우
    => 고려하지 않아도 됨
    def update(self,node,left,right,target,value):
        if left<=target<right:
            self.tree[node]+=value
            if left+1==right: return
            
            # leaf node가 아닌경우, 해당 원소를 포함 한 다른 node 찾아서 update함
            mid=(left+right)//2
            self.update(node*2,left,mid,target,value)
            self.update(node*2+1,mid,right,target,value)

0개의 댓글