query(l, r, num)을 날리면 해당 구간에 num보다 작은 숫자가 몇개 있는지 알려준다. 이제 이걸로 이분탐색을 하면서 답을 구하면 된다. [3, 6, 9, 12]에서 3번째 수가 무엇인지 구하기 위해 쿼리를 날렸다 하자. num이 7, 8, 9인 경우에 자기보다 작은 숫자가 2개 있다고 답할 것이다. 이중에 가장 큰 값인 9가 실제로 배열에 존재하는 숫자이고, 정답이다. 10부터는 자기보다 작은 숫자가 3개 있다고 답할 것이기 때문이다.
#include <bits/stdc++.h>
using namespace std;
constexpr int MAX = 100000;
int N, M;
vector<int> arr(MAX);
vector<vector<int>> tree(4*MAX);
void init(int start, int end, int node) {
if (start == end) {
tree[node].push_back(arr[start]);
return;
}
int mid = (start+end)/2;
init(start, mid, node*2);
init(mid+1, end, node*2+1);
vector<int>& left = tree[node*2];
vector<int>& right = tree[node*2+1];
auto l = left.begin();
auto r = right.begin();
while(l != left.end() || r != right.end()) {
if (l == left.end()) {
tree[node].push_back(*r++);
}
else if (r == right.end()) {
tree[node].push_back(*l++);
}
else {
if (*l < *r) tree[node].push_back(*l++);
else tree[node].push_back(*r++);
}
}
}
int query(int start, int end, int left, int right, int num, int node) {
if (end < left || right < start) {
return 0;
}
else if (left <= start && end <= right) {
int cnt = lower_bound(tree[node].begin(), tree[node].end(), num) - tree[node].begin();
return cnt;
}
return query(start, (start+end)/2, left, right, num, node*2) + query((start+end)/2+1, end, left, right, num, node*2+1);
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> N >> M;
for (int i = 0; i < N; i++) {
cin >> arr[i];
}
init(0, N-1, 1);
for (int i = 0; i < M; i++) {
int a, b, c;
cin >> a >> b >> c;
int l = -1e9;
int r = 1e9;
int mid;
int ans = -1e9;
while (l <= r) {
mid = (l+r)/2;
int q = query(0, N-1, a-1, b-1, mid, 1);
if (q < c) {
ans = max(ans, mid);
l = mid + 1;
}
else r = mid - 1;
}
cout << ans << '\n';
}
return 0;
}
처음에 세그먼트 트리에 저장된 각각의 정렬된 구간에서 여러 번 이분 탐색을 했더니 TLE를 받았다. 쿼리에 배열에 존재하지 않는 숫자를 넣으면 안될 것 같아서 저렇게 풀었는데..
머리를 조금만 굴리면 O(NlogN + Mlog^2N*log2e9)에 풀 수 있다.
맨 처음에 작성한 코드인데 얘는 로그 4제곱이라 10%쯤에서 TLE를 받는다. 시간복잡도를 잘 계산해야겠다.
#include <bits/stdc++.h>
using namespace std;
constexpr int MAX = 100000;
int N, M;
vector<int> arr(MAX);
vector<vector<int>> tree(4*MAX);
void init(int start, int end, int node) {
if (start == end) {
tree[node].push_back(arr[start]);
return;
}
int mid = (start+end)/2;
init(start, mid, node*2);
init(mid+1, end, node*2+1);
vector<int>& left = tree[node*2];
vector<int>& right = tree[node*2+1];
auto l = left.begin();
auto r = right.begin();
while(l != left.end() || r != right.end()) {
if (l == left.end()) {
tree[node].push_back(*r++);
}
else if (r == right.end()) {
tree[node].push_back(*l++);
}
else {
if (*l < *r) tree[node].push_back(*l++);
else tree[node].push_back(*r++);
}
}
}
int query(int start, int end, int left, int right, int num, int node) {
if (end < left || right < start) {
return 0;
}
else if (left <= start && end <= right) {
int cnt = lower_bound(tree[node].begin(), tree[node].end(), num) - tree[node].begin();
return cnt;
}
return query(start, (start+end)/2, left, right, num, node*2) + query((start+end)/2+1, end, left, right, num, node*2+1);
}
int getkth(int left, int right, int sz, int k, int node) {
int lo = 0;
int hi = sz;
while (lo <= hi) {
int mid = (lo+hi)/2;
int q = query(0, N-1, left, right, tree[node][mid], 1);
if (q == k) return tree[node][mid];
else if (q > k) hi = mid-1;
else lo = mid+1;
}
return 0;
}
int solve(int start, int end, int left, int right, int k, int node) {
if (end < left || right < start) return 0;
else if (left <= start && end <= right) {
return getkth(left, right, end-start, k, node);
}
return solve(start, (start+end)/2, left, right, k, node*2) + solve((start+end)/2+1, end, left, right, k , node*2+1);
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> N >> M;
for (int i = 0; i < N; i++) {
cin >> arr[i];
}
init(0, N-1, 1);
for (int i = 0; i < M; i++) {
int a, b, c;
cin >> a >> b >> c;
cout << solve(0, N-1, a-1, b-1, c-1, 1) << '\n';
}
return 0;
}
제 뇌 쿼리 부하도 해결해 주나요?