C. Parsa's Humongous Tree #722 Div.2

LONGNEW·2021년 7월 6일
0

CP

목록 보기
18/155

https://codeforces.com/contest/1529/problem/C
시간 1초, 메모리 256MB

input :

  • t (1 ≤ t ≤ 250)
  • n (2 ≤ n ≤ 105)
  • li ri (1 ≤ li ≤ ri ≤ 109).
  • u v (1 ≤ u, v ≤ n, u≠v)

output :

  • For each test case print the maximum possible beauty for Parsa's tree.
    Parsa's의 나무에서 얻을 수 있는 가장 최대의 아름다움을 출력하시오.

조건 :

  • To make Parsa's tree look even more majestic, Nima wants to assign a number av (lv ≤ av ≤ rv) to each vertex v such that the beauty of Parsa's tree is maximized.
    Parsa's의 나무를 훨씬 거대하게 만들기 위해 Nima는 av를 각 정점 v에 배당할 것이다.

  • He defines the beauty of the tree as the sum of |au − av| over all edges (u,v) of the tree.
    나무의 아름다움은 |au − av|의 합으로 나타낸다.


각 정점에서 우리는 l또는 r 중 1가지를 선택할 수 있다. 그러니까 현재까지 총합 중 더 큰 것을 선택할 수 있다는 것이다.

이 문제를 풀려면 해야 할 것이 2가지 존재한다.
1. 트리의 구조를 파악하기
2. 나무의 아름다움을 계산하기.

DFS

트리의 구조를 파악하기 위해선 1을 루트 노드로 놓은 다음에 모든 트리 까지의 경로를 order배열에 저장하도록 한다.
이 때, 부모를 파악하기 힘드므로 parent배열에 부모를 저장해두도록 한다.

BFS를 수행 할 때 그래프의 엣지가 양방향이므로 다시 되돌아가는 경우를 제거 해야 한다. 이를 parent의 값으로 확인 할 수 있다.

DP

인제 트리의 모양은 만들어 뒀다. 그렇다면 최대화된 값을 찾기만 하면 되는데 처음에서 끝으로 이동할지, 끝에서 처음으로 이동할 지를 생각해야 한다.
처음에서 끝으로 이동한다면 마지막에 리프 노드들에 존재하는 값들의 총합을 구하고 이 둘을 비교해야 한다. 근데 그렇게 한다면 리프 노드인지 판단도 힘들고 조건이 많이 필요하다.

왜 리프를 확인 해야 하냐? 각 노드들에는 자식이 존재할 수 있고 이 자식들이 가지는 값들이 다르다. 그러니까 왼쪽으로 가버린다면 더 이상 오른쪽의 총 합을 알지 못하기 때문이다.

그러니 아래에서 루트로 올라가는 방식이 계산하기 편리하다. 각 노드를 고른 이후에는 parent의 위치에 총합을 저장하도록 하는 것이다.

점화식은 어떻게 될까?
parent의 위치에서 왼쪽값을 선택 하는 경우. (현재까지 node의 왼쪽에 저장된 값 + node 왼쪽 - parent 왼쪽 의 절대값 | 현재까지 node의 오른쪽에 저장된 값 + node 오른쪽 - parent 왼쪽 의 절대값)을 비교해서 큰 값을 더해준다.
그러니까 parent 입장에서 왼쪽을 선택할 건데 현재까지 무엇이 더 큰 값인지 모르니까 이걸 비교해야 한다.
왼쪽을 했으면 오른쪽도 동일하게 수행해야 한다.

시간 복잡도

간과 한 것이 하나 있었는데 그것이 시간이다.
첫번째로, 입력이 매우 많다. n의 개수가 105인 것부터 그리고 t가 250까지 커질수 있다는 것.
-> 입력을 더 빠르게 받을 수록 좋다.
sys.stdin.buffer.readline()의 코드가 버퍼를 이용해서인지 속도가 더 빠르다.

이 보다 빠르고 싶다면 아래의 코드를 사용하자.

import sys,os,io
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline

런타임 에러가 발생하지 않으려면 나중에 입력을 받을때 k = input()등의 형태로 사용하면 되는 듯하다.

두번째로, 정수의 계산은 느리다.
-> float의 형태로 입력을 받고 초기화도 0.으로 한다. 정답을 출력할 때 int형으로 형변환 한다.

제대로된 이유는 모르겠지만 차이가 크다.....

import sys,os,io
input = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline

t = int(input())
for i in range(t):
    n = int(input())
    data = [(-1, -1)]

    for j in range(n):
        data.append(tuple(map(float, input().split())))

    graph = [[] for j in range(n + 1)]
    for j in range(n - 1):
        u, v = map(int, input().split())
        graph[u].append(v)
        graph[v].append(u)

    parent = [-1] * (n + 1)
    order = []
    q = [1]
    while q:
        node = q.pop()
        order.append(node)

        for next_node in graph[node]:
            if next_node == parent[node]:
                continue
            parent[next_node] = node
            q.append(next_node)

    order.reverse()
    left = [0.] * (n + 1)
    right = [0.] * (n + 1)
    for node in order:
        prev = parent[node]
        if prev == -1:
            continue

        parent_left, parent_right = data[prev][0], data[prev][1]
        node_left, node_right = data[node][0], data[node][1]

        left[prev] += max(left[node] + abs(parent_left - node_left), right[node] + abs(parent_left - node_right))
        right[prev] += max(left[node] + abs(parent_right - node_left), right[node] + abs(parent_right - node_right))

    print(int(max(left[1], right[1])))

0개의 댓글