[BOJ 13925] - 수열과 쿼리 13 (세그먼트 트리, C++, Python)

보양쿠·2023년 7월 7일
0

BOJ

목록 보기
152/252

BOJ 13925 - 수열과 쿼리 13 링크
(2023.07.07 기준 P1)

문제

길이가 N인 수열 A가 주어지고 M개의 쿼리가 주어진다. 각 쿼리에 맞게 출력

  • 1 x y v: Ai = (Ai + v) % MOD (x ≤ i ≤ y)
  • 2 x y v: Ai = (Ai × v) % MOD (x ≤ i ≤ y)
  • 3 x y v: Ai = v (x ≤ i ≤ y)
  • 4 x y: (ΣAi) % MOD 출력 (x ≤ i ≤ y)

알고리즘

Lazy propagation

풀이

더하기와 곱하기가 주어진다. 이를 차곡차곡 lazy에 쌓아야 하는데..

일단, lazy는 곱하기 변수, 더하기 변수. 총 2개를 만들자. 기본 초기값은 곱하기 1, 더하기 0이다. 어떤 수도 곱하기 1 더하기 0을 하면 바뀌는 수는 없다. 즉, 항등원이다.

만약 지금 노드 값은 x, 곱하기 lazy는 a, 더하기 lazy는 b라고 생각해보자. 값은 ax+b다.
여기에 곱하기 v를 하면? (ax+b) * v = avx+bv 가 나온다.
만약 곱하기 c, 더하기 d를 하면? (ax + b) *c + d = acx+bc+d 가 나온다.
결국, 곱하기 lazy에는 곱하기만, 더하기 lazy에는 곱한 후 더하면 된다.

자, 이제 쿼리를 처리해보자.
1번 쿼리는 더하기다. 그러면 ax+b + v = (ax+b) * 1 + v 이므로 {1, v}를 뿌려주자. 물론, 앞이 곱하기, 뒤가 더하기다.
2번 쿼리는 곱하기다. 그러면 (ax+b) * v = (ax+b) * v + 0이므로 {v, 0}을 뿌려주자.
3번 쿼리는 변경이다. 모든 수는 0을 곱하면 0이 된다. 그러므로 {0, v}를 뿌려주자.

코드

  • C++
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

const int MAXN = 100000, MAXH = 1 << (int)ceil(log2(MAXN) + 1);
const ll MOD = 1e9 + 7;

int N;
ll A[MAXN];

struct Lazy{
    ll mul = 1, add = 0;

    bool is_default(){
        return mul == 1 && add == 0;
    }

    void make_default(){
        mul = 1; add = 0;
    }

    void calc(ll _mul, ll _add){ // (ax + b)c + d = acx + bc + d
        mul = (mul * _mul) % MOD;
        add = (add * _mul + _add) % MOD;
    }
};

struct ST{
    ll tree[MAXH];
    Lazy lazy[MAXH];

    void init();

    void _pull(int nd){ // child -> parent
        tree[nd] = (tree[nd << 1] + tree[nd << 1 | 1]) % MOD;
    }

    void _push(int nd, int st, int en){ // parent -> child
        if (lazy[nd].is_default()) return;
        tree[nd] = (tree[nd] * lazy[nd].mul + (en - st + 1) * lazy[nd].add) % MOD;
        if (st != en){
            lazy[nd << 1].calc(lazy[nd].mul, lazy[nd].add);
            lazy[nd << 1 | 1].calc(lazy[nd].mul, lazy[nd].add);
        }
        lazy[nd].make_default();
    }

    void _init(int nd, int st, int en){
        if (st == en){
            tree[nd] = A[st];
            return;
        }
        int mid = (st + en) >> 1;
        _init(nd << 1, st, mid);
        _init(nd << 1 | 1, mid + 1, en);
        _pull(nd);
    }

    void _update(int nd, int st, int en, int l, int r, ll mul, ll add){
        _push(nd, st, en);
        if (r < st || en < l) return;
        if (l <= st && en <= r){
            lazy[nd].calc(mul, add);
            _push(nd, st, en);
            return;
        }
        int mid = (st + en) >> 1;
        _update(nd << 1, st, mid, l, r, mul, add);
        _update(nd << 1 | 1, mid + 1, en, l, r, mul, add);
        _pull(nd);
    }

    void update(int l, int r, ll mul, ll add){
        _update(1, 0, N - 1, l, r, mul, add);
    }

    ll _query(int nd, int st, int en, int l, int r){
        _push(nd, st, en);
        if (r < st || en < l) return 0;
        if (l <= st && en <= r) return tree[nd];
        int mid = (st + en) >> 1;
        return (_query(nd << 1, st, mid, l, r) + _query(nd << 1 | 1, mid + 1, en, l, r)) % MOD;
    }

    ll query(int l, int r){
        return _query(1, 0, N - 1, l, r);
    }
}st;

void ST::init(){
    _init(1, 0, N - 1);
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> N;
    for (int i = 0; i < N; i++) cin >> A[i];

    st.init();

    int M, q, x, y, v;
    cin >> M;
    while (M--){
        cin >> q;
        if (q == 1){ // +
            cin >> x >> y >> v;
            st.update(x - 1, y - 1, 1, v);
        }
        else if (q == 2){ // *
            cin >> x >> y >> v;
            st.update(x - 1, y - 1, v, 0);
        }
        else if (q == 3){ // =
            cin >> x >> y >> v;
            st.update(x - 1, y - 1, 0, v);
        }
        else{
            cin >> x >> y;
            cout << st.query(x - 1, y - 1) << '\n';
        }
    }
}
  • Python (PyPy3)
import sys; input = sys.stdin.readline
from math import ceil, log2
MOD = 1000000007

class Lazy:
    def __init__(self):
        self.mul = 1
        self.add = 0

    def is_default(self):
        return self.mul == 1 and self.add == 0

    def make_default(self):
        self.mul = 1
        self.add = 0

    def calc(self, mul, add): # (ax + b)c + d = acx + bc + d
        self.mul = (self.mul * mul) % MOD
        self.add = (self.add * mul + add) % MOD

class ST:
    def __init__(self):
        self.N = N
        self.H = 1 << ceil(log2(self.N) + 1)
        self.tree = [0] * self.H
        self.lazy = [Lazy() for _ in range(self.H)]

        self._init(1, 0, self.N - 1)

    def _pull(self, nd): # child -> parent
        self.tree[nd] = (self.tree[nd << 1] + self.tree[nd << 1 | 1]) % MOD

    def _push(self, nd, st, en): # parent -> child
        if self.lazy[nd].is_default():
            return
        self.tree[nd] = (self.tree[nd] * self.lazy[nd].mul + (en - st + 1) * self.lazy[nd].add) % MOD
        if st != en:
            self.lazy[nd << 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
            self.lazy[nd << 1 | 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
        self.lazy[nd].make_default()

    def _init(self, nd, st, en):
        if st == en:
            self.tree[nd] = A[st]
            return
        mid = (st + en) >> 1
        self._init(nd << 1, st, mid)
        self._init(nd << 1 | 1, mid + 1, en)
        self._pull(nd)

    def _update(self, nd, st, en, l, r, mul, add):
        self._push(nd, st, en)
        if r < st or en < l:
            return
        if l <= st and en <= r:
            self.lazy[nd].calc(mul, add)
            self._push(nd, st, en)
            return
        mid = (st + en) >> 1
        self._update(nd << 1, st, mid, l, r, mul, add)
        self._update(nd << 1 | 1, mid + 1, en, l, r, mul, add)
        self._pull(nd)

    def update(self, l, r, mul, add):
        return self._update(1, 0, self.N - 1, l, r, mul, add)

    def _query(self, nd, st, en, l, r):
        self._push(nd, st, en)
        if r < st or en < l:
            return 0
        if l <= st and en <= r:
            return self.tree[nd]
        mid = (st + en) >> 1
        return (self._query(nd << 1, st, mid, l, r) + self._query(nd << 1 | 1, mid + 1, en, l, r)) % MOD

    def query(self, l, r):
        return self._query(1, 0, self.N - 1, l, r)

N = int(input())
A = list(map(int, input().split()))

st = ST()

for _ in range(int(input())):
    q, *query = map(int, input().split())
    if q == 1: # +
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        v = int(query[2])
        st.update(x, y, 1, v)
    elif q == 2: # *
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        v = int(query[2])
        st.update(x, y, v, 0)
    elif q == 3: # =
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        v = int(query[2])
        st.update(x, y, 0, v)
    else:
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        print(st.query(x, y))
profile
GNU 16 statistics & computer science

0개의 댓글