BOJ 23238 | Best Student

전승민·2023년 4월 21일
0

BOJ 기록

목록 보기
16/68

2일간 많은 실패와 고민을 하며 풀어낸 문제다.
처음에는 수열과 쿼리 6의 코드를 기반으로 작성했는데, O(QN)O(Q\sqrt{N}) 복잡도의 코드에 각 Add 와 Remove 함수 실행시마다 set에 값을 저장하고 검사하도록 하여 O(QNlogN)O(Q\sqrt{N} logN)이라는 절대 통과될 것 같지 않은 복잡도로 코드를 짰다.

수열과 쿼리 6은 쿼리 (i, j)에서 최빈값의 개수를 출력하는 문제였기 때문에 모든 Add와 Remove 함수가 상수 시간 O(1)O(1)으로 실행되었다.

그러나 이 문제는 최빈값 자체를 출력하기 때문에 Remove 연산이 문제가 된다.

Add 연산만 있었다면 값을 계속 추가해주면서 max 값을 업데이트 해 주면 되는데, Remove 연산이 생기면서 골치가 아파진다.

최빈값이 2개 이상일 경우, 이 경우를 판별하는 건 수열과 쿼리 6에서 작성했던 cnt[]와 table[]로 O(1)O(1)에 가능하지만 그 최빈값들이 무엇인지 알아내기 위해서는 set이나 vector 같이 최빈값들을 저장할 수 있는 공간이 필요하다.

최빈값이 여러개면 vector에 push하고, 범위가 업데이트되어 최빈값이 하나가 된다면 vector을 clear해서 O(1)O(1)을 유지하는 전략도 생각해봤지만 이는 (6,8,9)(6, 8, 9) => push 9 => (9)(9) => pop 9 => (?)(?) 같은 상황이 생기면 막혀버린다.

따라서 위 방법을 사용한다면 무조건 remove 연산이 필요하게 된다.

이러한 자료구조를 사용하게 되면 각 쿼리마다 최소 O(logN)O(logN)의 시간이 추가로 붙어 N=100000N = 100000일 때 대략 16배정도 실행 시간이 늘어난다. 수열과 쿼리 6이 332ms로 AC가 났는데, 정확하진 않겠지만 두 문제의 범위가 같으니 대략 5.3초로 본다면 영락없는 TLE다.

실제로 자비없이 1%에서 TLE를 먹기도 했다.

여기서 나온 아이디어인데, 제곱근 분할을 이용해서 모든 버킷 묶음에 대해 최빈값을 구해놓으면 2912번을 풀 때 사용한 이분탐색으로 최빈값의 개수를 구할 수 있다.

버킷의 크기는 BB, 버킷의 개수는 SQSQ개이고, SQ =NSQ ~= \sqrt{N}으로 본다면 O(NN)O(N\sqrt{N})에 전처리가 가능하다.

쿼리에서는 전처리해 놓은 가운데 버킷에 저장한 최빈값과, 버킷 양 옆으로 붙은 나머지 값들도 전부 후보군에 넣어준다.
만약 [l,r][l, r]에서 l,rl, r이 모두 한 버킷 안에 있다면 따로 처리해주었다.

이 값들은 최대 2×B12×B-1개이고, 모든 값은 이분 탐색으로 개수를 찾아주니 쿼리마다 BlogNBlogN으로 해결이 가능하다.

따라서 시간 복잡도는 O(NN+QBlogN)O(N\sqrt{N}+QBlogN)으로 보이는데, 버킷의 크기를 정하기 위해 잘 따져보면 다음과 같다.

전처리 때 (NB)2×B({{N}\over{B}})^2×B를 실행하고, 쿼리마다 (2B1)×2logN(2B-1)×2logN을 수행하니까 N과 Q가 최대인 100,000이라고 가정하면 그래프를 그려보니 B=79B=79 정도에서 최적이었다.

처음에는 전처리만 N2N^2 아닌가 싶어서 갸웃했는데 잘 계산해보니 충분히 가능해보여서 도전하게 되었다.

코드 ( Sqrt Decomposition )

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
constexpr bool local = true;
#else
constexpr bool local = false;
#endif

#define FASTIO ios_base::sync_with_stdio;cin.tie(nullptr);cout.tie(nullptr);

#define debug if constexpr (local) std::cout
#define endl '\n'

struct Query{
	int l, r;
	int idx;
};

int n, m;
vector<int> oarr;
vector<int> arr;
vector<int> zip;
vector<Query> q;

const int SZ = 79; // size of bucket
const int SQ = 111111 / SZ + 15; // bucket quantity

int ans[110001];
int bucketmax[SQ][SQ];
int bucketmaxcnt;
int bmax;
int bucket_cnt[100001];

vector<vector<int>> studentidx;

void arrayclear(int (&a)[100001]){
	for (int &i : a){
		i = 0;
	}
}

void preprocessing(){	
	for (int k = 0; k * SZ < n; k++){
		bucketmaxcnt = 0;
		bmax = 0;
		arrayclear(bucket_cnt);
		for (int i = k*SZ; i <= n; i++){
			if (i == 0) continue;
			
			int t = arr[i];
			bucket_cnt[t]++;
			if (bucket_cnt[t] == bucketmaxcnt){ // bigger value
				bmax = max(bmax, t);
			}
			else if (bucket_cnt[t] > bucketmaxcnt){ // change maxvalue
				bmax = t;
				bucketmaxcnt = bucket_cnt[t];
			}
			
			if (i % SZ == SZ-1 || i == n){ // Succession
				bucketmax[k][i/SZ+1] = bmax;
			}
		}
	}
}

bool _cmp(Query a, Query b){
	int al = a.l / SZ;
	int bl = b.l / SZ;
	
	if (al != bl) return al < bl;
	return a.r < b.r;
}

int findzip(int x){
	return lower_bound(zip.begin(), zip.end(), x) - zip.begin();
}


int main(){
	FASTIO;
	oarr.resize(100001);
	arr.resize(100001);
	studentidx.resize(100001);
	
	#ifdef LOCAL
	ifstream cin("bstudent013.in");
	ofstream cout("result.ans");
	#endif
	
	cin >> n >> m;
	for (int i = 1; i <= n; i++){
		int t; cin >> t;
		oarr[i] = t;
		zip.push_back(t);
	}
	
	sort(zip.begin(), zip.end());
	zip.erase(unique(zip.begin(), zip.end()), zip.end());
	
	for (int i = 1; i <= n; i++){
		arr[i] = findzip(oarr[i]);
		studentidx[arr[i]].push_back(i);
	} // compression
	
	preprocessing();
	
	for (int i = 0; i < m; i++){
		int l, r; cin >> l >> r;
		q.push_back({l, r, i});
	}
	
	sort(q.begin(), q.end(), _cmp);

	for (int i = 0; i < m; i++){
		int l = q[i].l, r = q[i].r;
		int idx = q[i].idx;
		int mxc = 0;
		int mxv = 0;
		
		if (l / SZ == r / SZ) { // same bucket
			for (int i = l; i <= r; i++){
				int t = upper_bound(studentidx[arr[i]].begin(), studentidx[arr[i]].end(), r) - lower_bound(studentidx[arr[i]].begin(), studentidx[arr[i]].end(), l);
				if (t == mxc) {
					mxv = max(mxv, arr[i]); // if same cnt => bigger value
				}
				else if (t > mxc){
					mxv = arr[i];
					mxc = t;
				}
			}
		}
		else{
			vector<int> doubt;
			int lp = (l/SZ) + 1;
			if (l % SZ == 0) lp--;
			
			for (int i = l; i < (lp)*SZ; i++){
				doubt.push_back(arr[i]);
			}
			doubt.push_back( bucketmax[lp][r/SZ] );
			
			for (int i = r; i >= (r/SZ)*SZ; i--){
				doubt.push_back(arr[i]);
			}
			
			for (auto &i : doubt){
				int t = upper_bound(studentidx[i].begin(), studentidx[i].end(), r) - lower_bound(studentidx[i].begin(), studentidx[i].end(), l);
				if (t == mxc) {
					mxv = max(mxv, i); // if same cnt => bigger value
				}
				else if (t > mxc){
					mxv = i;
					mxc = t;
				}
			}
			
		}
		
		ans[idx] = zip[mxv];
	}
	
	for (int i = 0; i < m; i++){
        cout << ans[i] << endl;
    }
	
}
profile
알고리즘 공부한거 끄적이는 곳

0개의 댓글