아래 글은 "Introduction to Algorithm" 책과 가장 가까운 두 점 기하학적 접근 이론 및 증명,
Closest Pair 이론 글을 보고 정리하였습니다.
무척이나 플래티넘 2 문제. 가장 가까운 두 점. 플래 2에 지레 겁을 먹고 바로 솔루션을 볼까 싶었지만, 그래도 한번 문제를 보고 풀어보기로 했다. 2차원 좌표 상에서 가장 가까운 두 점의 거리. 단순하게 풀자면 모든 점들의 쌍의 거리를 모두 계산하고 최솟값을 찾으면 된다. 시간복잡도는 O(N^2). 그러나 데이터의 개수가 최대 10만개 이므로, 당연히 터진다. 어떻게든 연산 횟수를 줄일 방법이 뭐가 있을까.
그러다 분할정복이 생각났고, 좌표 평면을 반으로 가르면 연산 횟수가 줄지않을까 생각했다. 점을 x좌표에 대해 정렬하고 중간값을 기준으로 왼쪽면과 오른쪽면으로 나눈다(분할). 왼쪽면에서의 가장 가까운 두 점은 왼쪽에게 맡기고(정복), 오른쪽면에서의 가장 가까운 두 점은 오른쪽에게 맡긴다(정복).
def solution(left, right):
if left == right:
return float('inf')
elif left+1 == right:
return getDistance(points[left], points[right])
mid = (left+right)//2
lv = solution(left, mid)
rv = solution(mid, right)
점쌍은 아래 3가지 경우가 존재한다.
따라서 문제가 되는 것은 3번의 경우이다. 왼쪽점 집합과 오른쪽점 집합을 각각 곱하면 O((N/2)(N/2)) = O(N^2/4) = O(N^2)이다. "그래도 기존 연산횟수보다 4분의 1 정도로 줄었으니 통과하지 않을까?"라는 안일한 생각에 철퇴를 내리듯 당연히 터졌다.
import sys
input = sys.stdin.readline
n = int(input())
points = [list(map(int, input().split())) for _ in range(n)]
points.sort(key=lambda x: x[0])
def getDistance(p1, p2):
return (p1[0]-p2[0])**2 + (p1[1]-p2[1])**2
def solution(left, right):
if left == right:
return float('inf')
elif left+1 == right:
return getDistance(points[left], points[right])
mid = (left+right)//2
lv = solution(left, mid)
rv = solution(mid, right)
result = min(lv, rv)
for i in range(left, mid+1):
for j in range(mid+1, right+1):
temp = (points[i][0]-points[j][0])**2 + (points[i][1]-points[j][1])**2
result = min(result, temp)
return result
print(solution(0, n-1))
가장 가까운 두 점 기하학적 접근 이론 및 증명
Closest Pair 이론
결국 알고리즘 책과 솔루션을 보았다. 나의 풀이와 솔루션은 분할정복하는 곳까지는 똑같지만, 3번의 점쌍을 처리해주는 부분에서 완전히 다르다. 우선 여기서부터는 완전한 기하학의 영역이다.
분할정복으로 양면의 거리의 최솟값을 구했다. 이 거리의 최솟값을 d라 하자. 이는 한면에 존재하는 두 점의 거리는 최소 d라는 것이다. 이제 필요없는 점을 배제할 것이다. 양면을 분할했던 라인 L에서 x좌표상 d보다 멀리 떨어진 점은 필요없다. 어차피 d보다 멀면 볼 필요도 없다. 남은 점들은 target_points라는 새 리스트에 담는다. 그리고 target_points를 y좌표를 기준으로 오름차순 정렬한다.
minv = min(lv, rv)
target_points = []
for i in range(left, right+1):
if (points[mid][0] - points[i][0])**2 < minv:
target_points.append(points[i])
target_points.sort(key=lambda x: x[1])
이는 y좌표 상에서도 한 점에 대해 d보다 멀리 떨어진 점들을 연산에서 배제하기 위함이다. 아래 코드는 한 점에 대해 y좌표 상 d 이내의 점들과의 거리를 재서 최솟 거리를 계산하는 코드이다.
t = len(target_points)
for i in range(t-1):
for j in range(i+1, t):
if (target_points[i][1] - target_points[j][1])**2 < minv:
minv = min(minv, getDistance(target_points[i], target_points[j]))
else:
break
그 비교 횟수는 절대 7번을 넘지않는다고 한다.(증명)
import sys
input = sys.stdin.readline
n = int(input())
points = [list(map(int, input().split())) for _ in range(n)]
points.sort(key=lambda x: x[0])
def getDistance(p1, p2):
return (p1[0]-p2[0])**2 + (p1[1]-p2[1])**2
def solution(left, right):
if left == right:
return float('inf')
elif left+1 == right:
return getDistance(points[left], points[right])
mid = (left+right)//2
lv = solution(left, mid)
rv = solution(mid, right)
minv = min(lv, rv)
target_points = []
for i in range(left, right+1):
if (points[mid][0] - points[i][0])**2 < minv:
target_points.append(points[i])
target_points.sort(key=lambda x: x[1])
t = len(target_points)
for i in range(t-1):
for j in range(i+1, t):
if (target_points[i][1] - target_points[j][1])**2 < minv:
minv = min(minv, getDistance(target_points[i], target_points[j]))
else:
break
return minv
print(solution(0, n-1))