예전에 책에서 얼핏 본 것 같은 문제라 생각보다 간단하게 해결할 수 있었다!
위 사진과 같이 주어진 히스토그램을 반으로 나눈 후, 아래의 3가지 경우로 나누어서 분할 정복을 이용하여 문제를 해결하면 된다.
- 정답이 되는 직사각형이 왼쪽 부분에만 있다.
- 정답이 되는 직사각형이 오른쪽 부분에만 있다.
- 정답이 되는 직사각형이 왼쪽과 오른쪽에 걸쳐 있다.
번 직사각형부터 번 직사각형까지 고려한 답을 라 하면, 라고 하였을 때, 1번과 2번의 경우는 와 를 통해 구할 수 있다.
3번의 경우 정답이 되는 직사각형은 항상 와 번 직사각형을 포함하게 되므로 여기서부터 탐색을 진행하면 되는데, 이때 현재 포함하고 있는 영역의 좌우를 보고 히스토그램의 높이가 높은 방향으로 뻗어나가면서 탐색하는 범위의 전체 직사각형을 포함할 때까지 보면 된다.
이 방법이 성립하는 이유는, 만약 더 낮은 방향으로 뻗어 나갔을 때에 정답이 되는 직사각형이 만들어진다고 해도 그 낮은 직사각형을 제외하고 반대 방향으로 뻗어 나갔을 때의 넓이가 항상 더 크거나 같기 때문이다.
마지막으로 나는 불필요한 edge condition과 너무 많은 recursion을 줄이기 위해 작은 수의 입력에 대해서는 일일이 모든 경우를 찾는 brute-force 방법을 적용하여 코드를 작성하였다.
// BOJ 6549. 히스토그램에서 가장 큰 직사각형
#include <cstdio>
#include <vector>
#include <algorithm>
long long solve(std::vector<long long> &hist, long long lo, long long hi) {
// Base case
if (lo > hi) return 0;
if (lo == hi) return hist[lo];
// Brute-force approach for small input.
if (hi - lo <= 100) {
long long result {-1};
for (long long size = 1; size <= hi - lo + 1; size++) {
for (long long i = lo; i <= hi; i++) {
long long j = i + size - 1;
if (j > hi) continue;
long long area = size * (*std::min_element(hist.begin() + i, hist.begin() + j + 1));
result = std::max(result, area);
}
}
return result;
}
long long mid {(lo + hi) / 2};
// Case 1: The rectangle with maximum area is on the left half or right half.
long long result = std::max(solve(hist, lo, mid), solve(hist, mid + 1, hi));
// Case 2: It lies both on the left and right half.
// Starting from hist[mid] and hist[mid + 1],
// extend the rectangle to the direction of larger hist value.
long long i {mid}; long long j {mid + 1}; long long h {std::min(hist[i], hist[j])};
while (i >= lo || j <= hi) {
result = std::max(result, (j - i + 1) * h);
if (i == lo && j == hi) break;
if (i == lo) {
j++;
h = std::min(h, hist[j]);
} else if (j == hi) {
i--;
h = std::min(h, hist[i]);
} else if (hist[i - 1] >= hist[j + 1]) {
i--;
h = std::min(h, hist[i]);
} else if (hist[i - 1] < hist[j + 1]) {
j++;
h = std::min(h, hist[j]);
}
}
return result;
}
int main() {
while (1) {
long long n;
std::vector<long long> hist;
// Input
scanf("%d", &n);
if (n == 0) break;
for (int i=0; i<n; i++) {
long long h;
scanf("%lld", &h);
hist.push_back(h);
}
// Output
printf("%lld\n", solve(hist, 0, n - 1));
}
return 0;
}