[알고리즘] 세그먼트 트리(Segment Tree)

bbolddagu·2023년 5월 8일
0

알고리즘

목록 보기
4/11

📌세그먼트 트리란?

세그먼트 트리(Segment Tree)는 배열 또는 리스트와 같은 자료구조에서 구간 쿼리를 효율적으로 처리하기 위한 자료구조입니다.
주로 배열의 구간 합, 최소값, 최대값 등을 효율적으로 계산하는 데 사용됩니다.

  • 세그먼트 트리는 트리 구조로 표현되며, 각 노드는 배열의 일부 구간에 대한 정보를 저장합니다. 트리의 루트 노드는 전체 배열에 대한 정보를 가지고 있고, 하위 노드로 내려갈수록 구간이 반으로 줄어들어 구체적인 구간에 대한 정보를 저장합니다.

  • 세그먼트 트리는 일반적으로 재귀 함수를 이용해 구현하며, 배열의 크기가 2의 거듭제곱이 아닐 경우 2의 거듭제곱으로 맞춰주기 위해 추가적인 노드를 만들어야 합니다. 또한, 세그먼트 트리를 이용해 구간 합을 구할 때는 구간의 시작점과 끝점을 알고 있어야 합니다.

  • 세그먼트 트리는 배열의 크기에 대해 O(N log N)의 공간 복잡도를 가지며, 구간 쿼리와 업데이트 연산을 O(log N)의 시간 복잡도로 처리할 수 있습니다.

✋주의! tree 크기 할당✋
일반적으로 세그먼트 트리는 완전 이진 트리의 형태를 갖기 때문에, 입력 배열의 크기 n에 대해 트리의 크기는 보통 4n보다 크거나 같게 설정됩니다. 만약 입력 배열의 크기 n이 2의 거듭제곱인 경우에는 세그먼트 트리의 크기는 2n과 같습니다.

🌟세그먼트 트리 구현 (Python)

# 예
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# 실제로는 데이터의 개수 N에 4를 곱한 크기만큼 미리 세그먼트 트리의 공간을 할당한다.
tree = [0] * (len(arr) * 4)


def build_tree(node, start, end, arr, tree):
    if start == end:
        tree[node] = arr[start]
    else:
        mid = (start + end) // 2
        build_tree(2 * node, start, mid, arr, tree)
        build_tree(2 * node + 1, mid + 1, end, arr, tree)
        tree[node] = tree[2 * node] + tree[2 * node + 1]
        

def update_tree(node, start, end, idx, val, tree):
    if start == end:
        tree[node] = val
    else:
        mid = (start + end) // 2
        if start <= idx <= mid:
            update_tree(2 * node, start, mid, idx, val, tree)
        else:
            update_tree(2 * node + 1, mid + 1, end, idx, val, tree)
        tree[node] = tree[2 * node] + tree[2 * node + 1]
        

def query_tree(node, start, end, l, r, tree):
    if r < start or end < l:
        return 0
    if l <= start and end <= r:
        return tree[node]
    mid = (start + end) // 2
    p1 = query_tree(2 * node, start, mid, l, r, tree)
    p2 = query_tree(2 * node + 1, mid + 1, end, l, r, tree)
    return p1 + p2
    

build_tree(1, 0, len(arr) - 1, arr, tree)
print(query_tree(1, 0, len(arr) - 1, 0, 9))  # 0부터 9까지의 구간 합 (1 + 2 + ... + 9 + 10)
print(query_tree(1, 0, len(arr) - 1, 0, 2))  # 0부터 2까지의 구간 합 (1 + 2 + 3)
print(query_tree(1, 0, len(arr) - 1, 6, 7))  # 0부터 2까지의 구간 합 (7 + 8)

# arr[0]을 +4만큼 수정
update_tree(1, 0, len(arr) - 1, 0, 4)
print(query_tree(1, 0, len(arr) - 1, 0, 2))   # 0부터 2까지의 구간 합 ((1 + 4) + 2 + 3)

# arr[9]를 -11만큼 수정
update_tree(1, 0, len(arr) - 1, 9, -11)
print(query_tree(1, 0, len(arr) - 1, 8, 9))   # 8부터 9까지의 구간 합 (9 + (10 - 11))

동작과정

  • update_tree 함수

    • startendidx를 포함하는 구간인지 확인합니다.
    • startend가 동일하다면 해당 노드는 갱신할 인덱스를 표현하고 있습니다. 따라서 해당 노드 값을 val로 갱신합니다.
    • 그렇지 않은 경우 구간을 반으로 나누어 재귀적으로 update_tree 함수를 호출합니다.
      • 만약 idx가 왼쪽 구간 [start, mid]에 속한다면, 왼쪽 자식 노드인 2 * node를 탐색하여 update_tree 함수를 호출합니다.
      • 그렇지 않으면 idx는 오른쪽 구간 [mid + 1, end]에 속하므로, 오른쪽 자식 노드인 2 * node + 1을 탐색하여 update_tree 함수를 호출합니다.
    • 갱신된 값으로 상위 노드의 값을 업데이트합니다.
  • query_tree 함수

    • 현재 노드가 담당하는 구간 [start, end]와 주어진 구간 [l, r]이 아예 겹치지 않는 경우:

      • 현재 노드의 값을 0으로 반환하고 재귀 호출을 종료합니다.
    • 현재 노드가 담당하는 구간 [start, end]가 주어진 구간 [l, r]에 완전히 포함되는 경우:

      • 현재 노드의 값을 반환합니다. (트리 노드에 저장된 구간의 합)
    • 위의 두 경우에 해당하지 않는 경우:

      • 현재 노드의 구간을 [start, end]라고 하면, 구간 [l, r]와 [start, end]이 부분적으로 겹치는 경우입니다.
      • 현재 노드의 구간을 두 개의 하위 구간으로 분할하여 재귀 호출합니다.
      • 재귀 호출을 통해 얻은 두 하위 구간의 합을 더하여 반환합니다.

📒참고

0개의 댓글