BOJ 17161 | 편집 거리 (Hard)

전승민·2023년 4월 30일
1

BOJ 기록

목록 보기
36/68

LCS 5를 풀었기 때문에 16MB라는 메모리 제한에서 익숙한 히르쉬버그의 향기가 느껴지는 문제였다.

히르쉬버그를 이미 공부했기 때문에 이 문제는 처음부터 구현할 방법이 바로 떠올랐다.
사소한 실수 한 번 이후에 바로 AC를 맞은 의외로 쉬웠던 문제였다.

이 문제 유형은 간단한 편집 거리 역추적이라서 골드 III 정도로 기본 문제가 있을만 한데 의외로 가장 쉬운 문제도 메모리 최적화 기법이 필요해서 플레티넘 V에 위치하고 있다.

이 문제에서는 LCS 5에서 그러했듯 토글링으로 dp를 최적화하고 히르쉬버그로 역추적해서 출력하면 된다.

DPDP는 다음과 같은 점화식으로 나타낼 수 있다.

DP[i][j]={DP[i1][j1] (S[i]==T[j])min(DP[i][j1],DP[i1][j1],DP[i1][j]) elseDP[i][j] = \begin{cases} DP[i-1][j-1]\ (S[i] == T[j]) \\ min(DP[i][j-1], DP[i-1][j-1], DP[i-1][j])\ else \end{cases}

문자열 S,TS, T는 모두 1-based라고 하고, DP[0][j]=jDP[0][j] = j, DP[i][0]=iDP[i][0] = i으로 전부 초기화한 상태로 DPDP 배열을 채워나간다.

DP[S.size()][T.size()]DP[S.size()][T.size()]에서 SSTT의 최종 편집 거리가 나오고, 이를 이용해서 토글링으로 최적화 할 수 있다.

이 표는 일반적으로 편집 거리를 구했을 때 DPDP 배열이다.
다음은 히르쉬버그를 사용하기 위해 두 부분으로 나누어서 편집 거리를 구한 DPDP 배열이다.

토글링을 사용했기 때문에 결국 마지막에는 빨간색 행과 노란색 행만이 남아있다.
SS는 이미 두 부분으로 나눠졌으니 TT만 나누면 되는데, 분할한 두 부분의 편집 거리의 합이 전체 편집 거리의 합이 된다는 것을 생각하면 쉽다.
최소의 편집 거리를 구해야 하므로 Red[i]+Yellow[i+1]Red[i] + Yellow[i+1]가 최소가 되는 ii를 찾으면, 그 부분을 기준으로 분할하면 된다.

여기서는 파란색으로 색칠한 부분의 합이 최소이므로 SNOWSNOWNOWNOW로, BALLBALLRAPRAP으로 분리된다.

왼쪽 위와 오른쪽 아래 중에서 왼쪽 위를 먼저 계산하는 것이 출력 형태를 생각하면 더 유리하므로 왼쪽 위를 보자.

이 부분도 이렇게 둘로 나누어지고, 이 분할은 행이 하나가 될 때까지 일어난다.
NONOSNOSNO를 먼저 계산하자.

드디어 행이 하나인 부분 문자열로 분할되었다. 왼쪽 위를 먼저 계산해야 하므로 NNSNSN을 보자.

행이 하나일 때 DPDP 배열은 DP[0][j]=jDP[0][j] = j, DP[i][0]=iDP[i][0] = i 초기화 후 하얀 칸을 전부 구해주면 된다.
이 때는 토글링할 필요 없이 편하게 구해주면 되고, 이제 백트래킹을 통해 역추적을 할 수 있다.

역추적은 그리디로 현재 칸에서 두 문자가 같을 때는 무조건 COPY가 일어나고 대각선으로 진행, 아닐 때는 1. min 값, 2. 대각선 -> 왼쪽 or 위쪽 우선순위로 진행하면 된다. 두 문자가 다를 때 대각선 이동은 MODIFY, 왼쪽 이동은 ADD, 위쪽 이동은 DELETE다.

아무리 커봐야 2 x N 배열이므로 일반적인 방법으로 구현하면 된다.

항상 왼쪽 위를 먼저 진행하므로 결국 정순으로 최종 결과값이 출력되기 때문에 역추적 후 바로 결과를 출력해도 괜찮다.

마지막 출력 순서는 왼쪽 위 범위부터 오른쪽 아래 범위까지 차례대로이므로 안심하고 바로 출력해도 된다.

여담이지만 내 코드에는 전부 debug << 가 달려 있는데 이건 코드를 정리해서 올리지 않아서 등장하는 것으로 열심히 디버깅해서 결과값을 확인한 내 흔적을 엿볼 수 있다.

코드

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
constexpr bool local = true;
#else
constexpr bool local = false;
#endif

#define FASTIO ios_base::sync_with_stdio;cin.tie(nullptr);cout.tie(nullptr);

#define debug if constexpr (local) std::cout
#define endl '\n'

string S, T, ans;

void ED(string S, string T){
	int nS = S.size();
	int nT = T.size();
	
	S = ' ' + S; T = ' ' + T;
	
	if (nS == 1){
		vector<int> dp[2];
		dp[0].resize(nT+1);
		dp[1].resize(nT+1);
		for (int i = 0; i <= nT; i++) dp[0][i] = i;
		dp[1][0] = 1;
		for (int i = 1; i <= nT; i++){
			if (S[1] == T[i]) dp[1][i] = dp[0][i-1];
			else dp[1][i] = min(dp[1][i-1], min(dp[0][i-1], dp[0][i])) + 1;
		}
		
		int x = nS, y = nT;
		vector<pair<char, char>> rst;
		while (x != 0 || y != 0){
			if (x == 0 && y != 0){ //  ADD ONLY
				rst.push_back(make_pair('a', T[y]));
				y--;
				continue;
				
			}
			else if (x != 0 && y == 0){ // DELETE ONLY
				rst.push_back(make_pair('d', S[x]));
				x--;
				continue;
			}
			
			int mnv = min(dp[x-1][y], min(dp[x-1][y-1], dp[x][y-1]));
			if (dp[x-1][y-1] == mnv){
				if (dp[x][y] == mnv){ // COPY
					rst.push_back(make_pair('c', S[x]));
				}
				else{ // MODIFY
					rst.push_back(make_pair('m', T[y]));
				}
				x--; y--;
			}
			else if (dp[x-1][y] == mnv){ //DELETE
				rst.push_back(make_pair('d', ' '));
				x--;
			}
			else{ // ADD
				rst.push_back(make_pair('a', T[y]));
				y--;
			}
		}
		
		reverse(rst.begin(), rst.end());
		for (auto &i: rst) cout << i.first << ' ' << i.second << endl;
		return;
	}

	vector<int> prev(nT+2), upper(nT+2), lower(nT+2);
	
	int mid = nS/2;
	
	for (int i = 0; i <= nT; i++) prev[i] = i;
	//for (auto &i: prev) debug << i << ' '; debug << endl;
	for (int i = 1; i <= mid; i++){
		upper[0] = i;
		for (int j = 1; j <= nT; j++){
			if ( S[i] == T[j] ) upper[j] = prev[j-1];
			else upper[j] = min(upper[j-1], min(prev[j], prev[j-1])) + 1;
		}
		prev = upper;
	}
	
	for (int i = nT+1; i >= 1; i--) prev[i] = nT+1 - i;
	//for (auto &i: prev) debug << i << ' '; debug << endl;
	for (int i = nS; i >= mid + 1; i--){
		lower[nT+1] = nS-i+1;
		for (int j = nT; j >= 1; j--){
			if (S[i] == T[j]) lower[j] = prev[j+1];
			else lower[j] = min(lower[j+1], min(prev[j], prev[j+1])) + 1;
		}
		prev = lower;
	}
	
	//DEBUG
	//for (auto &i: upper) debug << i << ' '; debug << endl;
	//for (auto &i: lower) debug << i << ' '; debug << endl;
	
	int mnv = 9999999, idx = 0;
	for (int i = 0; i < upper.size()-1; i++){
		if (mnv > upper[i] + lower[i+1]){
			mnv = upper[i] + lower[i+1];
			idx = i;
		}
	}
	string LT = T.substr(1, idx);
	string RT = T.substr(idx+1);
	string US = S.substr(1, mid);
	string DS = S.substr(mid+1);
	//cout << LT << ' ' << RT << ' ' << US << ' ' << DS << endl;
	
	ED(US, LT);
	ED(DS, RT);
}

int main(){
	cin >> S >> T;
	int nS = S.size();
	int nT = T.size();
	
	ED(S, T);
}
profile
알고리즘 공부한거 끄적이는 곳

0개의 댓글