이 문제를 풀기 위해 유니온 파인드(union-find)
알고리즘이 필요하다.
유니온 파인드 알고리즘이란?
두 노드가 같은 그래프에 있는지 확인하는 알고리즘으로, 두 노드가 같은 그래프에 없다면 상호 배타적 집합(Disjoint Set, 서로소 집합)이라고도 한다.
다음과 같은 파티 상황이 있다고 해 보자.
단순하게 파티를 하나하나 보면서 그 안에 진실을 아는 A가 있는지 확인하면 안 된다. 4번 파티의 D는 나중에 A와 같은 파티에 있기 때문에 진실을 알 수 있기 때문이다. 그러면 검사하기 전에 누가 진실을 아는 사람 그룹에 들어가는지 체크하면 되지 않을까? 이 때 유니온 파인드 알고리즘이 쓰인다.
먼저 find
연산은 트리에서 자식 노드의 부모를 찾는 연산이다. union
연산을 할 때 루트가 작은 쪽으로 합치기 때문에 find
를 먼저 해줘야 하는 것이 포인트다.
다음은 union
연산으로, find
의 결과로 얻은 부모 노드의 값이 작은 쪽으로 tree를 합쳐준다.
def find(parent, x):
# 자기 자신이 root가 아니라면 root를 찾을 때까지 반복
if parent[x] != x:
parent[x] = find(parent, parent[x])
return parent[x]
def union(parent, a, b):
a = find(parent, a)
b = find(parent, b)
# 부모 노드의 크기가 작은 쪽으로 tree를 합친다.
if a < b:
parent[b] = a
else:
parent[a] = b
자 이제 이 문제에 맞게 union
연산을 변형해보자. 특정 파티에 참여하는 사람들을 두 명씩 묶고, 그 중에 진실을 아는 사람이 있다면 그 사람을 기준으로 트리를 합쳐주고, 아니라면 부모 노드가 작은 순서로 합치면 된다.
그래서 내가 제출한 코드는 다음과 같다.
from sys import stdin
input = stdin.readline
def find(parent, x):
if parent[x] != x:
parent[x] = find(parent, parent[x])
return parent[x]
def union(parent, a, b, know_truth):
a = find(parent, a)
b = find(parent, b)
if a in know_truth and b in know_truth:
return
if a in know_truth:
parent[b] = a
elif b in know_truth:
parent[a] = b
else:
if a < b:
parent[b] = a
else:
parent[a] = b
N, M = map(int, input().split())
know_truth = list(map(int, input().split()))[1:]
parties = []
parent = list(range(N + 1))
for _ in range(M):
party_info = list(map(int, input().split()))
party_len = party_info[0]
party = party_info[1:]
for i in range(party_len - 1):
union(parent, party[i], party[i + 1], know_truth)
parties.append(party)
ans = 0
for party in parties:
for i in range(len(party)):
if find(parent, party[i]) in know_truth:
break
else:
ans += 1
print(ans)