[Java] 느리게 갱신되는 세그먼트 트리 (Segment Tree with Lazy Propagation)

서정범·2023년 3월 29일
0

느리게 갱신되는 세그먼트 트리

배우기에 앞서...

먼저 해당 방식의 필요성에 대해서 알아보자.

다음 문제를 확인해 봅시다.

문제

크기가 N인 정수 배열 A가 있고, 여기서 다음과 같은 연산을 최대 M번 수행해야 하는 문제가 있습니다.
1. 구간 l,r(l<=r)l, r(l <= r)이 주어졌을 때, A[l]+A[l+1]+...+A[r1]+A[r]A[l] + A[l + 1] + ... + A[r - 1] + A[r]을 구해서 출력하기
2. ii번째 수부터 jj번째 수에 vv를 더하기

세그먼트 트리에서 문제에서 2번 연산은 ii번째 수에 vv를 더하는 것이었습니다. 이번 문제의 2번 연산은 구간에 수를 더하는 것입니다. 세그먼트 트리에서 수를 변경하는 연산을 ii번째 수부터 jj번째 수까지 하나씩 하는 방식을 이용할 수 있습니다. 1번 연산은 구간의 합을 구하는 연산이라 O(logN)O(logN)입니다. 2번 연산은 수 하나를 변경하는 연산을 ji+1j - i + 1번 해야하니 O(NlogN)O(NlogN)입니다. 세그먼트 트리를 사용하지 않고 배열에서 1번, 2번 연산을 수행하는 경우 시간 복잡도 O(N)O(N)인데, 오히려 시간이 더 걸립니다.

느리게 갱신되는 세그먼트 트리를 사용하면 구간 변경을 효율적으로 수행할 수 있습니다.

lazy

나중에 변경해야 하는 값을 lazy[i]에 저장합니다.

이때 lazy의 인덱스는 node 번호라고 생각하면 됩니다.

다음의 예시를 확인해 봅시다.

업데이트 할 때 업데이트 하려는 구간의 안에 노드의 구간이 완전히 포함되는 경우에 주목을 해야 합니다.

먼저, [3,10]에서 3의 경우 노드의 구간 [2,3]에 포함되어 있습니다.

노드의 구간이 [2,3]인데 3만 해당하므로 완전히 포함되는 경우가 아닙니다. 해당 경우에는 일반적인 세그먼트 트리와 동일하게 리프 노드를 수정합니다.

여기서 우리가 주목해야 하는 부분은 파란색 점선이 쳐져있는 노드입니다.

구간 [4,7]의 경우 업데이트 하려는 구간 [3,10]에 완전히 포함되어 있습니다.

이 경우 아래에 있는 노드는 모두 변경하려는 구간 [3,7]에 포함됩니다.

따라서, 이러한 노드의 변경은 나중에 필요할 때 하기로 하고, 그 값을 lazy[i]에 저장해 둡니다.

앞으로 어떤 노드를 방문할 때마다 lazy[i]에 값이 있는지 확인해야 합니다. 만약, lazy[i]값이 0이 아닌 경우에는 노드의 합을 변경하고, 자식 노드에게 lazy[i]값을 전달해야 합니다.

lazy[i]에는 그 노드가 담당하는 구간의 더해야 하는 값이 저장되어 있으니, 합에는 그 구간에 포함된 수의 개수만큼 곱해서 더해야 합니다.

예를 들어, 어떤 노드 node에 구간 [nodeLeft, nodeRight]의 합이 저장되어 있다면, tree[node]에는 lazy[node] * (nodeRight - nodeLeft - 1)을 더해야 합니다.

코드

import java.io.*;

public class Main {
  static void init(long[] arr, long[] tree, int node, int nodeLeft, int nodeRight) {
    if (nodeLeft == nodeRight) {
      tree[node] = arr[nodeLeft];
    } else {
      init(arr, tree, node * 2, nodeLeft, (nodeLeft + nodeRight) / 2);
      init(arr, tree, node * 2 + 1, (nodeLeft + nodeRight) / 2 + 1, nodeRight);
      tree[node] = tree[node * 2] + tree[node * 2 + 1];
    }
  }

  static void update_lazy(long[] tree, long[] lazy, int node, int nodeLeft, int nodeRight) {
    if (lazy[node] != 0) {
      tree[node] += (nodeRight - nodeLeft + 1) * lazy[node];
      if (nodeLeft != nodeRight) {
        lazy[node * 2] += lazy[node];
        lazy[node * 2 + 1] += lazy[node];
      }
      lazy[node] = 0;
    }
  }

  static long query(long[] tree, long[] lazy, int node, int nodeLeft, int nodeRight, int left, int right) {
    update_lazy(tree, lazy, node, nodeLeft, nodeRight);
    if (right < nodeLeft || nodeRight < left) {
      return 0;
    }
    if (left <= nodeLeft && nodeRight <= right) {
      return tree[node];
    }
    long lSum = query(tree, lazy, node * 2, nodeLeft, (nodeLeft + nodeRight) / 2 , left, right);
    long rSum = query(tree, lazy, node * 2 + 1, (nodeLeft + nodeRight) / 2 + 1, nodeRight, left, right);
    return lSum + rSum;
  }

  static void update_range(long[] tree, long[] lazy, int node, int nodeLeft, int nodeRight, int left, int right, long diff) {
    update_lazy(tree, lazy, node, nodeLeft, nodeRight);
    if (right < nodeLeft || nodeRight < left) {
      return;
    }
    if (left <= nodeLeft && nodeRight <= right) {
      tree[node] += (nodeRight - nodeLeft + 1) * diff;
      if (nodeLeft != nodeRight) {
        lazy[node * 2] += diff;
        lazy[node * 2 + 1] += diff;
      }
      return;
    }
    update_range(tree, lazy, node * 2, nodeLeft, (nodeLeft + nodeRight) / 2, left, right, diff);
    update_range(tree, lazy, node * 2 + 1, (nodeLeft + nodeRight) / 2 + 1, nodeRight, left, right, diff);
    tree[node] = tree[node * 2] + tree[node * 2 + 1];
  }
  public static void main(String args[]) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
    String[] line = br.readLine().split(" ");
    int n = Integer.parseInt(line[0]);
    int m = Integer.parseInt(line[1]);
    int k = Integer.parseInt(line[2]);
    m += k;
    long[] a = new long[n];
    for (int i=0; i<n; i++) {
      a[i] = Long.parseLong(br.readLine());
    }
    int h = (int)Math.ceil(Math.log(n) / Math.log(2));
    int tree_size = (1 << (h+1));
    long[] tree = new long[tree_size];
    long[] lazy = new long[tree_size];
    init(a, tree, 1, 0, n-1);
    while (m-- > 0) {
      line = br.readLine().split(" ");
      int what = Integer.parseInt(line[0]);
      if (what == 1) {
        int left = Integer.parseInt(line[1]);
        int right = Integer.parseInt(line[2]);
        long diff = Long.parseLong(line[3]);
        update_range(tree, lazy, 1, 0, n-1, left-1, right-1, diff);
      } else {
        int left = Integer.parseInt(line[1]);
        int right = Integer.parseInt(line[2]);
        bw.write(query(tree, lazy, 1, 0, n-1, left-1, right-1)+"\n");
      }
    }
    bw.flush();
  }
}

해당 코드를 순서대로 분석해보자.

기본적으로 init()는 기존의 segment tree와 동일한 방식이다.

update_laze()의 경우 lazy의 값이 0이 아닌경우 해당 노드의 lazy값을 범위의 크기에 맞게 노드의 값에 더해주고 본인의 lazy값을 자식 노드의 lazy값으로 내려주는 것이다. 그 이후 본인 노드의 lazy값은 0으로 바꿔줍니다.

이 함수의 기능은 update함수든 query함수든 해당 노드를 방문할 때 lazy값이 있으면 업데이트를 해주려는 기능의 함수이기 때문에 자식 노드로 내려주고 본인 노드의 lazy는 초기화 시켜주는 것입니다.

그리고 query함수에서는 update_laze()함수를 실행하고 나머지는 기존의 방식과 동일합니다.

update함수의 경우에는 update_laze()함수를 실행하고 그 이후 변경하는 노드에 한해서 본인 노드는 수정을 하고 본인이 리프 노드가 아닐 경우 자식 노드에게 변경하려는 값을 lazy[child Node]에 건네주는 것입니다.

다른 방식

다른 상황을 가정하고 문제를 봐보자.

만약, 구간 내의 값에 x를 더해주는 것이 아니라 x로 교체한다고 했을 경우 어떻게 처리해야 할 것인가?

기존에는 lazy값에 추가적으로 계속해서 더해주는 방식으로 처리했지만, 이 상황의 경우 기존의 값을 완전히 바꿔버리는 것이기 때문에 가지고 있던 lazy값에 추가적으로 더해주는 방식이 아니라 대체하는 방식을 사용해야 할 것이다.

코드 자체도 다른 식으로 구현해놨으니 본인에게 맞는 방식을 사용하면 됩니다.

public class Main {
  static int DEFAULT_VALUE = 0;
  static int MAX_VALUE = Integer.MAX_VALUE;  // for min
  static int MIN_VALUE = Integer.MIN_VALUE; // for max

  int merge(int left, int right) {
    return left + right;  // sum
    // return min(left, right);  // min
    // return max(left, right);  // max
  }

  int mergeBlock(int value, int size) {
    return value * size;  // sum
    // return value;  // min
    // return value;  // max
  }

  int N;  // size
  int[] tree; //  Segment Tree
  int[] lazyValue;  //  Lazy Value
  boolean[] lazyExist;


  void init(int arr[], int size) {
    N = size;
    int h = (int)Math.ceil(Math.log(N) / Math.log(2));
    int tree_size = 1 << h;
    tree = new int[tree_size];
    lazyValue = new int[tree_size];
    lazyExist = new boolean[tree_size];

    // 주워진 배열의 범위: 0 ~ N - 1
    initRec(arr, 1, 0, N - 1);
  }

  // inclusive
  int update(int left, int right, int newValue) {
    return updateRec(left, right, newValue, 1,  0, N - 1);
  }

  // inclusive
  int query(int left, int right) {
    return queryRec(left, right, 1, 0, N - 1);
  }


  private int pushDown(int newValue, int node, int nodeLeftRange, int nodeRightRange) {
    if (nodeLeftRange == nodeRightRange)
      return tree[node] = newValue;

    lazyExist[node] = true;
    lazyValue[node] = newValue;
    return tree[node] = mergeBlock(newValue, nodeLeftRange - nodeRightRange + 1);
  }

  private int initRec(int arr[], int node, int nodeLeft, int nodeRight) {
    if (nodeLeft == nodeRight)
      return tree[node] = arr[nodeLeft];

    int mid = (nodeLeft + nodeRight) / 2;

    int leftVal = initRec(arr, node * 2, nodeLeft, mid);
    int rightVal = initRec(arr, node * 2 + 1, mid + 1, nodeRight);
    return tree[node] = merge(leftVal, rightVal);
  }

  private int updateRec(int left, int right, int newValue, int node, int nodeLeft, int nodeRight) {
    if (right < nodeLeft || nodeRight < left)
      return tree[node];

    if (nodeLeft == nodeRight)
      return tree[node] = newValue;

    if (left <= nodeLeft && nodeRight <= right) {
      lazyExist[node] = true;
      lazyValue[node] = newValue;
      return tree[node] = mergeBlock(newValue, nodeRight - nodeLeft + 1);
    }

    int mid = (nodeLeft + nodeRight) / 2;
    if (lazyExist[node]) {
      lazyExist[node] = false;
      pushDown(lazyValue[node], node * 2, nodeLeft, mid);
      pushDown(lazyValue[node], node * 2 + 1, mid + 1, nodeRight);
      lazyValue[node] = DEFAULT_VALUE;
    }

    int leftVal = updateRec(left, right, newValue, node * 2, nodeLeft, mid);
    int rightVal = updateRec(left, right, newValue, node * 2 + 1, mid + 1, nodeRight);
    return tree[node] = merge(leftVal, rightVal);
  }

  private int queryRec(int left, int right, int node, int nodeLeft, int nodeRight) {
    if (right < nodeLeft || nodeRight < left)
      return DEFAULT_VALUE;

    if (left <= nodeLeft && nodeRight <= right)
      return tree[node];

    int mid = (nodeLeft + nodeRight) / 2;

    if (lazyExist[node]) {
      lazyExist[node] = false;
      pushDown(lazyValue[node], node * 2, nodeLeft, mid);
      pushDown(lazyValue[node], node * 2 + 1, mid + 1, nodeRight);
      lazyValue[node] = DEFAULT_VALUE;
    }

    return merge(queryRec(left, right, node * 2, nodeLeft, mid),
            queryRec(left, right, node * 2 + 1, mid + 1, nodeRight));
  }
}

Reference

profile
개발정리블로그

0개의 댓글