BOJ 2104 부분배열 고르기

LONGNEW·2022년 2월 5일
1

BOJ

목록 보기
312/333

https://www.acmicpc.net/problem/2104
시간 2초, 메모리 128MB

input :

  • N(1 ≤ N ≤ 100,000)
  • A[1], …, A[N] (0 <= 원소 <= 1,000,000)

output :

  • 첫째 줄에 최대 점수를 출력

조건 :

  • 어떤 i, j(1 ≤ i ≤ j ≤ N)에 대한 점수는, (A[i] + … + A[j]) × min{A[i], …, A[j]}가 된다.

세그먼트 트리에 대한 연습을 위해 이 문제를 풀어보기로 하였다.

N이 10만이어도 성립이 가능하게 하려면 어떻게 해야하나?
누적합에 대해서는 따로 DP에 저장을 하는 방식을 사용하고, 최솟값을 찾는것에는 세그 트리를 써야 하나?
근데 그것보다 이것도 가정이 반복문 2번이 수행되어야 하기 때문에 문제가 된다.

세그 트리를 이용한 여러 풀이가 있었지만 [BOJ] 2104 | 쥐니모 !! 글을 보고 공부하였다.
세그 트리의 본질을 사용해 거기에 추가적인 정보를 저장하는 방식을 사용했다.

그리고 추가적으로 세그 트리를 활용해서 답을 구해야 한다. 이 때 사용한 방식이 분할 정복으로 퀵 정렬을 구현할 때 사용한 방식이다.

처음에는 1 ~ n 구간에서 점수를 구한다. 이 점수를 계산할 때 우리는 사용하는 idx가 있다.
이 idx를 5라고 하면 그 다음 확인할 구간은
1 ~ 4, 6 ~ n이 되는 것이다.

이러한 과정을 모든 자연수에 대해서 수행하면 답을 구할 수 있게 된다.

다음 풀이

  1. i ~ j까지의 합
  2. 최솟값을 곱하라
  3. 1 <= i <= j <= N

기본적인 반복문은 쓸 수 없음을 알 수 있다. 이 떄는 합을 계산하는 방식을 떠올리자. dp를 우선적으로 떠올리면 된다.
추가적으로 생각할 것은 구간을 잡는 방식이다.
주요 아이디어 : 특정 idx를 i ~ j까지의 최솟값으로 만들려고 한다. 이 때 idx 기준 왼쪽에서 자기 보다 작은 놈이 나오기 전까지의 idx, 오른쪽에 대해서를 구해야 한다.
결국 모든 idx에 대해 수행하게 되는 것이고 이 문제에선 이를 분할 정복으로 생각한 것이다.
기본적으로 세그 트리 => 구간을 나눔 + 추가 정보 가장 작은 수의 인덱스

초반에 너무 어려운 코드를 보고 있어서 이해가 너무 어려웠다.
정말 딱 하나 더 응용된 코드가 있어 다행이란 생각이 들었고 그나마 다행이다.

테스트 케이스

만약 이 문제의 추가적인 테스트 케이스를 찾는 다면 NEERC 2005 아카이브에서 Tests and Jury Solutions을 찾길 바란다.

두번째 링크에 다운을 받는 링크를 걸어 뒀으니 이를 사용해도 좋다.

import sys
sys.setrecursionlimit(100000)

def init(idx, left, right):
    if left == right:
        tree[idx] = (data[left], left)
        return tree[idx]

    mid = (left + right) // 2
    lsum, lidx = init(idx * 2, left, mid)
    rsum, ridx = init(idx * 2 + 1, mid + 1, right)

    tree[idx] = (lsum + rsum, lidx if data[lidx] < data[ridx] else ridx)
    return tree[idx]

# s : search, t : target
def query(idx, sl, sr, tl, tr):
    if tr < sl or sr < tl:
        return (0, 0)
    if tl <= sl and sr <= tr:
        return tree[idx]

    mid = (sl + sr) // 2
    lsum, lidx = query(idx * 2, sl, mid, tl, tr)
    rsum, ridx = query(idx * 2 + 1, mid + 1, sr, tl, tr)

    ret = (lsum + rsum, lidx if data[lidx] < data[ridx] else ridx)
    return ret

def pick(left, right):
    if left == right:
        return data[left] * data[right]

    min_sum, idx = query(1, 1, n, left, right)
    ret = min_sum * data[idx]
    if left <= idx - 1:
        ret = max(ret, pick(left, idx - 1))
    if idx + 1 <= right:
        ret = max(ret, pick(idx + 1, right))
    return ret

n = int(sys.stdin.readline())
tree = [(0, 0)] * (100001 * 4)

data = [0] * 100001
data[0] = float("inf")
for idx, item in enumerate(list(map(int, sys.stdin.readline().split()))):
    data[idx + 1] = item

init(1, 1, n)
print(pick(1, n))
#include "iostream"
#define ll long long

using std::cout; using std::cin; using std::ios;
using std::make_pair; using std::pair; using std::max;

int n;
ll data[100001] = {0x7fffffff};
pair<ll, int> tree[300003];

void fastio(){
    ios::sync_with_stdio(false);
    cout.tie(nullptr);
    cin.tie(nullptr);
}

pair<ll, int> init(int idx, int l, int r){
    if (l == r)
        return tree[idx] = make_pair(data[l], l);

    int mid = (l + r) / 2;
    pair<ll, int> left = init(idx * 2, l, mid);
    pair<ll, int> right = init(idx * 2 + 1, mid + 1, r);
    return tree[idx] = make_pair(left.first + right.first, data[left.second] > data[right.second] ? right.second : left.second);
}

pair<ll, int> query(int idx, int sl, int sr, int tl, int tr){
    if (tr < sl || sr < tl)
        return make_pair(0, 0);
    if (tl <= sl && sr <= tr)
        return tree[idx];

    int mid = (sl + sr) / 2;
    pair<ll, int> left = query(idx * 2, sl, mid, tl, tr);
    pair<ll, int> right = query(idx * 2 + 1, mid + 1, sr, tl, tr);
    return make_pair(left.first + right.first, data[left.second] > data[right.second] ? right.second : left.second);
}

ll pick(int l, int r){
    if (l == r)
        return data[l] * data[r];

    pair<ll, int> res = query(1, 1, n, l, r);
    ll ret = res.first * data[res.second];

    if (l <= res.second - 1)
        ret = max(ret, pick(l, res.second - 1));
    if (res.second + 1 <= r)
        ret = max(ret, pick(res.second + 1, r));

    return ret;
}

int main(){
    fastio();
    cin >> n;
    for (int i = 1; i < n + 1; ++i) {
        cin >> data[i];
    }

    init(1, 1, n);
    cout << pick(1, n);

    return 0;
}

0개의 댓글