[백준] (실패) 1197 최소 스패닝 트리

Hyun·2025년 5월 1일
0

백준

목록 보기
94/96
post-thumbnail

문제

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

입력

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.

출력

첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.

예제 입력 1

3 3
1 2 1
2 3 2
1 3 3

예제 출력 1

3

풀이

처음에 최소 스패닝 트리가 뭔지 몰라서 찾아보았고, 크루스칼 알고리즘을 이용한다는 것을 알게된 후 아래와 같은 수행 동작을 가정했었다.

  1. 가중치 오름 차순으로 간선 정보 정렬
  2. 작은 것부터 하나씩 택하는데
    2-1) 집합에 이미 모든 정점이 포함되어 있으면 종료
    2-2) 두 정점이 이미 포함되어 있으면 다음 간선 택함 (2번으로 돌아감)
  3. 집합에 두 정점 정보를 추가 (자동으로 중복 제거)

그래서 다음과 같은 코드를 작성했었다.
잘못된 풀이

import sys
input = sys.stdin.readline

v, e = map(int, input().split())
v_set = set() # {} 이거랑 set() 이거랑 다른가?
temp_arr = []
w_sum = 0
for _ in range(e):
    temp_arr.append(list(map(int,input().split())))
temp_arr.sort(key=lambda x: x[2]) # 람다를 이용한 정렬 숙지하기

for sub_arr in temp_arr:
    f, s, w = sub_arr
    # 모든 정점이 이미 포함되어 있으면 break
    if len(v_set) == v:
        break
    # 두 정점이 이미 포함되어 있으면 continue
    if f in v_set and s in v_set:
        continue
    
    v_set.add(f)
    v_set.add(s)
    w_sum += w

if len(v_set) == 1:
    print(0)
else:
    print(w_sum)

그러나 위와 같이 풀면, 크루스칼 알고리즘 특성 상 기존에 택한 간선과 연결되어 있는 간선을 택하는게 아니라, 단순히 가중치가 작은 간선을 택하기 때문에 선택한 간선들이 서로 연결되어 있지 않은 상태가 될 수 있고, 이때 모든 정점들이 집합에 포함되면 종료되기 때문에 올바른 최소 스패닝 트리가 아니게 된다.

위 코드의 반례(간선들이 연결되어 있지 않아도 종료됨)

5 7
1 2 1
1 3 1
2 3 1
2 4 10
3 4 10
4 5 1
3 5 100

따라서 여러 개의 집합이 생길 수 있고, 결국에는 이 집합들이 하나로 합쳐져야 한다. 그래서 다음과 같은 로직을 생각해볼 수 있다.

  1. 두 정점 a, b 가 다른 집합에 속하면 -> 두 집합을 합친다(간선 선택 O)
  2. 이미 같은 집합에 속하면 -> 사이클이 생기므로 건너뛴다(간선 선택 X)
  3. 각 정점들이 속한 초기 집합은 자신만이 속한 하나의 집합이다.

그러나 위와 같이 여러 개의 집합을 구현하려면 복잡하다. 따라서 정점별로 자신이 속한 집합이 있다고 가정하고, 해당 집합의 대표노드를 집합의 식별자로 사용한다. 대표노드는 1차원 배열을 사용하면 된다(실제로 여러 집합을 만들지 않고, 정점별 대표 노드를 지정하는 배열을 사용함으로써 그와 동일한 효과를 낸다). 그리고 널리 알려진 union/find 함수들을 사용하여 문제를 풀이한다.

올바른 풀이

import sys
input = sys.stdin.readline

sys.setrecursionlimit(10**6)

def find(parent, x):
    # parent[x] == x 인 경우는 자신의 부모가 자신임(대표 노드)
    if parent[x] != x: 
        # 경로 압축: x 의 부모를 재귀적으로 찾아 최상위 대표 노드로 변경
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a_root = find(parent, a)
    b_root = find(parent, b)
    if a_root != b_root:
        # 두 집합의 대표가 다르면 하나로 합침
        # b_root 집합을 a_root 집합에 연결 (b_root의 대표를 a_root로 지정)
        # 이때 b_root의 하위 노드들은 여전히 parent 배열에 b_root로 기록된 상태
        parent[b_root] = a_root

v, e = map(int, input().split())
edges = []

for _ in range(e):
    a, b, c = map(int, input().split())
    edges.append((c,a,b))
edges.sort()

parent = [i for i in range(v+1)] # 각 정점의 부모 초기화 (자기 자신)
result = 0

for cost, a, b in edges:
    # 두 노드의 대표가 다를 때만 간선 선택 (사이클 방지)
    if find(parent, a) != find(parent, b):
        union(parent, a, b)
        result += cost

print(result)

참고
https://velog.io/@yoopark/baekjoon-1197

profile
better than yesterday

0개의 댓글