Union-find Algorithm
: 서로소 집합을 구하는 알고리즘
서로소 집합 구조를 표현할 때는 트리 자료구조를 이용한다.
트리는 계층을 갖고 있기 때문에 노드들간의 부모-자식 관계가 있다고 가정한다.
언제 쓰는가?
여러개의 노드가 존재할 때, 특정한 2개의 노드가 서로 같은 그래프에 속하는 지 판별
ex) 크루스칼 알고리즘 사이클 판별
union-find algorithm 의 필수 연산
1. find
: x가 속한 집합의 대표값(루트 노드 값)을 반환한다. x가 어떤 집합에 속해 있는지 찾는 연산
[find 알고리즘]
부모노드 테이블을 parent[x]와 x를 똑같이 초기화해놓았기 때문에, 두 값이 같아야 루트 노드이므로 찾을 때까지 재귀적으로 호출한다.
def find(x):
if parent[x] != x:
return find(parent[x])
return x
2. union
: 2개 원소로 이루어진 집합을 하나의 집합으로 합치기(간선 잇기)
- union 연산
A의 루트 노드 A'과 B의 루트 노드 B'를 찾기 (find)
A'를 B'의 부모 노드로 설정 (A' < B')
- 모든 union 연산을 처리할 때까지 반복
def union(a, b):
a = find(a)
b = find(b)
if a < b:
parent[b] = a
else:
parent[a] = b
알고리즘 실행 과정
- 부모테이블을 초기화한다.
- 모든 원소가 자기 자신을 부모로 가지도록 설정한다.
- union 연산을 통해 입력받은 데이터대로 A, B를 연결한다.
- 이때 find 연산을 통해 A, B의 루트노드를 각각 찾는다.
- A의 부모노드가 더 작다고 했을 때, A의 부모노드를 B의 부모노드로 설정한다. (일반적으로 번호가 작은 원소가 부모 노드가 되도록 구현한다)
- 모든 union 연산이 처리될 때까지 3~5를 반복한다.
- 부모테이블 초기화
자기 자신을 부모로 설정한다.(parent[x] == x)
- 각각의 union 연산을 수행한다. [(4,5) (3,4) (2,3) (1,2)]
=> union(4, 5)
4와 5의 루트노드를 각각 찾는다. 두 수의 루트노드 중 더 큰 루트노드를 더 작은 루트노드로 설정
- union(3,4)
- union(2,3)
- union(1,2)
최악의 경우 5의 루트노드를 찾기 위해 O(V)의 시간이 걸린다.
경로압축을 통해 개선된 find 알고리즘
위 코드를 쓸 경우 계속 부모노드를 타고 올라가야 되기 때문에 트리의 최하위 자식노드의 부모노드를 찾는데 시간이 오래걸릴 수 있다.
개선된 find 알고리즘을 통해 union 연산을 하면 루트노드에 더 빠르게 접근할 수 있다.
기존과 차이점은 return x
가 아닌 return parent[x]
다. 재귀를 돌면서 부모 테이블을 최상위 부모노드로 바꿔준다.

def find_parent(x):
if parent[x] != x:
parent[x] = find_parent(parent[x])
return parent[x]
개선된 알고리즘으로 각각의 union 연산을 다시 수행해보자. [(4,5) (3,4) (2,3) (1,2)]
union-find 소스 코드
def find_parent(x):
if parent[x] != x:
parent[x] = find_parent(parent[x])
return parent[x]
def union(a, b):
a = find_parent(a)
b = find_parent(b)
if a < b:
parent[b] = a
else:
parent[a] = b
v, e = map(int, input().split())
parent = [0] * (v + 1)
for i in range(1, v + 1):
parent[i] = i
for i in range(e):
a, b = map(int, input().split())
union(parent, a, b)
for i in range(1, v+1):
print(find(parent, i), end=' ')
union-find는 무방향 그래프 내에서 사이클을 판별할 때 사용할 수 있다.
for i in range(e):
a, b, cost = edges[i]
if find(a) != find(b):
union(a, b)