[ BOJ 1167 ] 트리의 지름(Python)

uoayop·2021년 5월 13일
0

알고리즘 문제

목록 보기
45/103
post-thumbnail

문제

https://www.acmicpc.net/problem/1167


문제 풀이

트리의 지름을 구하는 공식을 알면 빠르게 풀 수 있다. (나는몰랐다)

"트리의 지름을 구하는 공식은 임의의 하나의 노드 A에서 가장 거리가 먼 노드 B를 구하고, 이 노드 B에서 가장 거리가 먼 노드 C를 구하게 되었을 때, B와 C 사이의 거리가 트리의 지름이 된다."

[출처] https://suri78.tistory.com/135

우선 입력을 받고 트리를 인접 그래프로 만들어주었다.
더 깔끔하게 받을 수 있을 것 같은디 일단 풀기 급급했다.

cnt = int(input())
graph = defaultdict(list)

for _ in range(cnt):
    temp = list(map(int,input().rsplit()))
    key = temp[0]
    for i in range(1,len(temp)-1,2):
        graph[key].append((temp[i],temp[i+1]))

그리고 주어진 node에서 가장 먼 노드를 구하는 dfs 함수를 만들어주었다.

def dfs(node,cnt):
    global max_cnt
    global far_node
    # 가장 먼 거리 : max_cnt, 가장 먼 노드 : far_node
    if max_cnt < cnt :
        max_cnt = cnt
        far_node = node

    visited[node] = True
    # node와 연결된 노드 : u / node와 u 사이 거리 : c
    for u,c in graph[node]:
    	# u를 방문 안했을 때 방문해준다.
        if not visited[u]:
            # 이동거리에 c를 더해주고 dfs를 호출한다.
            dfs(u,cnt+c)

임의의 노드 1에서부터 가장 먼 노드를 구해준 뒤, 그 노드에서 가장 먼 노드를 구해주면 된다.


코드

import sys
from collections import defaultdict
input = sys.stdin.readline

cnt = int(input())
graph = defaultdict(list)

for _ in range(cnt):
    temp = list(map(int,input().rsplit()))
    key = temp[0]
    for i in range(1,len(temp)-1,2):
        graph[key].append((temp[i],temp[i+1]))

max_cnt = 0
far_node = 0
def dfs(node,cnt):
    global max_cnt
    global far_node
    if max_cnt < cnt :
        max_cnt = cnt
        far_node = node

    visited[node] = True
    for u,c in graph[node]:
        if not visited[u]:
            dfs(u,cnt+c)

visited = [False] * (cnt+1)
dfs(1,0)
visited = [False] * (cnt+1)
dfs(far_node,0)

print(max_cnt)
profile
slow and steady wins the race 🐢

0개의 댓글