https://codeforces.com/contest/1529/problem/C
시간 1초, 메모리 256MB
input :
output :
조건 :
To make Parsa's tree look even more majestic, Nima wants to assign a number av (lv ≤ av ≤ rv) to each vertex v such that the beauty of Parsa's tree is maximized.
Parsa's의 나무를 훨씬 거대하게 만들기 위해 Nima는 av를 각 정점 v에 배당할 것이다.
He defines the beauty of the tree as the sum of |au − av| over all edges (u,v) of the tree.
나무의 아름다움은 |au − av|의 합으로 나타낸다.
각 정점에서 우리는 l또는 r 중 1가지를 선택할 수 있다. 그러니까 현재까지 총합 중 더 큰 것을 선택할 수 있다는 것이다.
이 문제를 풀려면 해야 할 것이 2가지 존재한다.
1. 트리의 구조를 파악하기
2. 나무의 아름다움을 계산하기.
트리의 구조를 파악하기 위해선 1을 루트 노드로 놓은 다음에 모든 트리 까지의 경로를 order
배열에 저장하도록 한다.
이 때, 부모를 파악하기 힘드므로 parent
배열에 부모를 저장해두도록 한다.
BFS를 수행 할 때 그래프의 엣지가 양방향이므로 다시 되돌아가는 경우를 제거 해야 한다. 이를 parent의 값으로 확인 할 수 있다.
인제 트리의 모양은 만들어 뒀다. 그렇다면 최대화된 값을 찾기만 하면 되는데 처음에서 끝으로 이동할지, 끝에서 처음으로 이동할 지를 생각해야 한다.
처음에서 끝으로 이동한다면 마지막에 리프 노드들에 존재하는 값들의 총합을 구하고 이 둘을 비교해야 한다. 근데 그렇게 한다면 리프 노드인지 판단도 힘들고 조건이 많이 필요하다.
왜 리프를 확인 해야 하냐? 각 노드들에는 자식이 존재할 수 있고 이 자식들이 가지는 값들이 다르다. 그러니까 왼쪽으로 가버린다면 더 이상 오른쪽의 총 합을 알지 못하기 때문이다.
그러니 아래에서 루트로 올라가는 방식이 계산하기 편리하다. 각 노드를 고른 이후에는 parent의 위치에 총합을 저장하도록 하는 것이다.
점화식은 어떻게 될까?
parent의 위치에서 왼쪽값을 선택 하는 경우. (현재까지 node의 왼쪽에 저장된 값 + node 왼쪽 - parent 왼쪽 의 절대값 | 현재까지 node의 오른쪽에 저장된 값 + node 오른쪽 - parent 왼쪽 의 절대값)을 비교해서 큰 값을 더해준다
.
그러니까 parent 입장에서 왼쪽을 선택할 건데 현재까지 무엇이 더 큰 값인지 모르니까 이걸 비교해야 한다.
왼쪽을 했으면 오른쪽도 동일하게 수행해야 한다.
간과 한 것이 하나 있었는데 그것이 시간이다.
첫번째로, 입력이 매우 많다. n의 개수가 105인 것부터 그리고 t가 250까지 커질수 있다는 것.
-> 입력을 더 빠르게 받을 수록 좋다.
sys.stdin.buffer.readline()
의 코드가 버퍼를 이용해서인지 속도가 더 빠르다.
이 보다 빠르고 싶다면 아래의 코드를 사용하자.
import sys,os,io
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
런타임 에러가 발생하지 않으려면 나중에 입력을 받을때 k = input()
등의 형태로 사용하면 되는 듯하다.
두번째로, 정수의 계산은 느리다.
-> float의 형태로 입력을 받고 초기화도 0.
으로 한다. 정답을 출력할 때 int형으로 형변환 한다.
제대로된 이유는 모르겠지만 차이가 크다.....
import sys,os,io
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
t = int(input())
for i in range(t):
n = int(input())
data = [(-1, -1)]
for j in range(n):
data.append(tuple(map(float, input().split())))
graph = [[] for j in range(n + 1)]
for j in range(n - 1):
u, v = map(int, input().split())
graph[u].append(v)
graph[v].append(u)
parent = [-1] * (n + 1)
order = []
q = [1]
while q:
node = q.pop()
order.append(node)
for next_node in graph[node]:
if next_node == parent[node]:
continue
parent[next_node] = node
q.append(next_node)
order.reverse()
left = [0.] * (n + 1)
right = [0.] * (n + 1)
for node in order:
prev = parent[node]
if prev == -1:
continue
parent_left, parent_right = data[prev][0], data[prev][1]
node_left, node_right = data[node][0], data[node][1]
left[prev] += max(left[node] + abs(parent_left - node_left), right[node] + abs(parent_left - node_right))
right[prev] += max(left[node] + abs(parent_right - node_left), right[node] + abs(parent_right - node_right))
print(int(max(left[1], right[1])))