백준 1967번 - 트리의 지름

윤여준·2022년 5월 9일
0

백준 풀이

목록 보기
1/35
post-thumbnail

문제

문제 링크 : https://www.acmicpc.net/problem/1967

풀이 - 1

문제를 읽고 처음 시도한 풀이는 다음과 같다.

from sys import stdin
from collections import deque
n = int(stdin.readline())
l = {}

for i in range(1, n + 1):
    l[i] = []

for i in range(n - 1):
    a, b, c = map(int, stdin.readline().split())
    l[a].append((b,c))

def dfs(x): # x번 노드를 루트 노드로 하는 트리의 지름을 구하는 함수
    child = []

    for i in l[x]:
        child.append(i) # x번 노드의 자식 노드들을 child 리스트에 추가

	# 자식 노드가 없으면 x번 노드를 루트 노드로 하는 트리의 지름이 0이므로 0을 리턴
    if len(child) == 0:
        return 0

    result = [0 for i in range(len(child))]
	
    # x번 노드의 각각의 자식 노드에서 가장 멀리 떨어진 리프 노드까지의 길이를 구함
    j = 0
    for idx, length in child:
        stack = deque()
        stack.append((idx, length))
        while stack:
            idx, length = stack.pop()
            leaf = True

            for i in l[idx]:
                leaf = False
                stack.append((i[0], length + i[1]))

            if leaf == True:
                result[j] = max(result[j], length)
        j+=1
        
    result = sorted(result)
    r = 0
    if len(child) == 1:
        r = result[len(result) - 1]
    else:
        for i in range(1,3):
            r += result[len(result) - i]
    return r


result = 0
for i in range(1, n+1):
    result = max(result, dfs(i))
print(result)

1번 노드부터 n번 노드까지 돌면서 각 노드를 루트 노드로 하는 트리의 지름을 구하고 그 값들의 최댓값을 출력하는 풀이이다.


하지만 메모리를 너무 많이 쓰고 시간이 너무 오래 걸리는 걸 보고 썩 좋은 풀이가 아닌 것 같아서 구글링을 통해 더 좋은 풀이들을 찾아봤고, 그 풀이가 2번 풀이이다.

풀이 - 2

2번 풀이는 https://kyun2da.github.io/2021/05/04/tree's_diameter/ 이 글을 참고했다.

풀이 과정은 다음과 같다.

  1. 루트 노드(1번 노드)에서 가장 멀리 떨어진 노드 n1를 구한다.
  2. n1에서 가장 멀리 떨어진 노드 n2를 구한다.
  3. n1부터 n2까지의 거리가 주어진 트리의 지름이다.

이 풀이의 증명은 https://blog.myungwoo.kr/112 이 글을 참고했다.

import sys
from sys import stdin
sys.setrecursionlimit(10**9)
n = int(stdin.readline())

graph = [[] for i in range(n + 1)]

def dfs(x, w):
    for i in graph[x]:
        a, b = i
        if distance[a] == -1:
            distance[a] = b + w
            dfs(a,w + b)

for i in range(n - 1):
    a, b, c = map(int, stdin.readline().split())
    graph[a].append((b,c))
    graph[b].append((a,c))

# distance 리스트를 통해 방문 여부를 기록함과 동시에 거리(길이)도 기록함
distance = [-1] * (n + 1)
distance[1] = 0
dfs(1,0) # dfs 함수를 이용해 1번 노드로부터 가장 멀리 떨어진 노드를 구함

# 1번 노드로부터 가장 멀리 떨어진 노드의 인덱스를 start로 잡음
start = distance.index(max(distance))
distance = [-1] * (n + 1)
distance[start] = 0
dfs(start,0) # dfs 함수를 이용해 start번 노드로부터 가장 멀리 떨어진 노드를 구함

print(max(distance))

두번째 줄의 sys.setrecursionlimit(10**9)는 재귀의 최대 깊이를 10**9으로 변경하는 코드이다. 백준 채점 서버의 최대 재귀 깊이는 1000으로 되어 있기 때문에 이를 바꿔주지 않으면 오류가 발생한다. https://help.acmicpc.net/judge/rte/RecursionError

profile
Junior Backend Engineer

0개의 댓글