BJ_4792 레드블루 스패닝 트리

이영구·2022년 2월 8일
0

Algorithm

목록 보기
2/19

Disjoint set 문제를 공부하는 중이다. 제법 simple한 방법이라고 생각하는 응용 범위가 넓고 정형화 되어 있지 않아 제법 어려운 문제들이 많이 등장한다.
disjoint_set으로 문제를 계속 풀어가다가 발견한 이문제, 뭔가 disjoint_set으로만 풀기는 어렵다고 생각하면서도 열심히 풀었는데, 역시나 안되서 살펴보니 Kruskal's algorithm을 사용해야 하는 문제였다.

Kruskal algorithm은 내가 algorithm에 대해 좀더 몰랐던 시절에 익혔을 때는 정말 어렵게 느껴졌었는데, disjoint set을 알고나니 나름 쉬워보였다.
알고리즘 풀이를 참고하고 나름 열심히 풀어서 문제를 제출했는데, 실패.. googling 한 결과들은 거의 C로 작성되어 있긴 했지만, algorithm 흐름은 제대로 파악했는데 33%에서 왜 자꾸 시간 초과를 띄우는지 정말로 알수가 없었다.

Kruskal algorithm을 위한 sorting 과정이 혹시라도 문제가 될까 처음부터 sorting할 필요없이 red, blue간선으로 나누어 진행을 했는데도, 여전히 시간초과..
무엇이 문제일까 곰곰히 고민하며 이것 저것 찾아보다가 단서를 발견했다. 내가 무심코 사용한 list.pop(0).. 0 인덱스면 O(1)이라고 생각한 것은 어떤 발상일까.. 결과적으로 pop(0)는 O(n) operation이었고, 다음부터는 pop(0)를 쓰지 말자고 다짐해 본다.

그런데, 정말 신기하지 않아? 똑같은 알고리즘에서 pop(0) 대신 pop()을 사용한 것으로 속도가 이다지 달라진 다는 것은?
(more than 2 seconds)

# dis-joint set, spanning tree notion catch

import sys

DEBUG = False

def log(message):
    if DEBUG:
        print(message)

def find(i):
    if parent[i] != i:
        parent[i] = find(parent[i])
    return parent[i]


def union(irep,jrep):
    if size[irep] < size[jrep]:
        parent[irep] = jrep
        size[jrep] += size[irep]
        size[irep] = 0
    else:
        parent[jrep] = irep
        size[irep] += size[jrep]
        size[jrep] = 0

def Kruscal(blue_graph, red_graph):
    tree_edges = 0 # 간선 개수
    mst_cost = 0 # 가중치 합
    graph = blue_graph+red_graph
    
    for _ in range(len(graph)):        
        color, i, j= graph.pop()
        irep, jrep = find(i), find(j)
        if irep != jrep:
            union(irep,jrep)
            # mst.append((i,j))
            if color == 'B':
                mst_cost += 1
        tree_edges += 1
    
    return mst_cost


while True:
    n, m, k = map(int, sys.stdin.readline().split())
    if n == 0:
        break
    parent = [i for i in range(n+1)]
    size = [1] * (n+1)
    
    blue_graph = []
    red_graph = []
    for _ in range(m):
        color, i, j = sys.stdin.readline().split()
        i, j = int(i), int(j)
        if color == 'B':
            blue_graph.append((color, i,j))
       else:
            red_graph.append((color, i,j))

    log(("min_graph: ", blue_graph))
    log(("max_graph: ", red_graph))

    min_value = Kruscal(blue_graph, red_graph)

    parent = [i for i in range(n+1)]
    size = [1] * (n+1)
    max_value  = Kruscal(red_graph, blue_graph)

    log(("min_value: ", min_value))
    log(("max_value: ", max_value))


    if min_value <= k <= max_value:
        sys.stdout.write("1\n")
    else:
        sys.stdout.write("0\n")
profile
In to the code!

0개의 댓글