[Python] Nested list time complexity

이영구·2022년 1월 18일
0

Algorithm

목록 보기
1/19

Python으로 코딩을 하다보면 간결한 표현, 또는 인자로 만들기 위해서 Nested list를 종종 사용한다. list를 3개를 만들어서 관리할 바에야, 한개의 리스트에서 tuple로 3개의 인자를 포함하게 해서 인덱스로 접근하면 간편한 일이기 때문이다.

그런데, nested list를 사용해서 정말 한참동안 해맨 경험이 있으니, 그것이 바로 이 문제다. 백준 10216 Disjoin-set을 이용해서 알고리즘을 설계하고는 O(n^2)이 나올 것 같아 시간 초과를 우려했고, 예상대로 시간초과가 발생했다. 그런데, 문제는 .. 다른 풀이 과정이 나의 알고리즘의 흐름을 그대로 따라간다는 것이었다. 그런데.. Python 으로 된 풀이는 없어서 한참 C로만 보다가 도저히 알수가 없어서 Python 풀이를 간단히 찾아냈지만, 역시나 알수가 없어서 한참이나 고민했다.

문제는 Nested List의 경우는 List * N개 보다 훨씬 많은 시간을 소모한다는데 있다. 아래 코드를 돌려보면, nested_list의 경우는 5.42s, 3개의 list는 2.33 초를 소모하는 것을 확인할 수 있다. 반복의 수가 적으면 상과 없겠지만, 그렇지 안으면 큰 차이를 보인다. 이걸 찾는데.. 시간을 얼마나 쓴건지.. 부디 참고하기 바란다.

import time 

it = 1000000
nest_list = [(i-1, i, i+1) for i in range(1, it)]

a_list = [i-1 for i in range(1, it)]
b_list = [i for i in range(1, it)]
c_list = [i+1 for i in range(1, it)]

start = time.time()

for i in range(it-1):
	ans = nest_list[i][0] + nest_list[i][1] + \
    	nest_list[i][2]

print(f"nest list iteration time {time.time() - start}")

start = time.time()
for i in range(it-1):
	ans = a_list[0] + b_list[1] + c_list[2]
    
print(f"normal 3 list iteration time {time.time() - start}")

풀이는 참고로 올려본다.

# O(n^2) 알고리즘으로 풀어낼 수 있지않을까? disjoint-set과 함께

import sys

DEBUG = True

def log(message):
    if DEBUG:
        print(message)


def find(i):
    if parent[i] != i:
        parent[i] = find(parent[i])
    return parent[i]

def union(irep,jrep):
    
    if size[irep] < size[jrep]:
        parent[irep] = jrep
        size[jrep] += size[irep]
        size[irep] = 0 
    else:
        parent[jrep] = irep
        size[irep] += size[jrep]
        size[jrep] = 0

T = int(sys.stdin.readline())
for _ in range(T):
    N = int(sys.stdin.readline())

    parent = [i for i in range(N)]
    size = [1] * (N)
    ans = N

    x_pos, y_pos, radius = [], [], []
    for _ in range(N):
        x, y, r = map(int, sys.stdin.readline().split())
        x_pos.append(x)
        y_pos.append(y)
        radius.append(r)

    for i in range(N-1):
        for j in range(i+1, N):
            x_diff = x_pos[i] - x_pos[j]
            y_diff = y_pos[i] - y_pos[j]
            dist = radius[i] + radius[j]

            if (x_diff**2) + (y_diff**2) <= dist **2:
                irep = find(i); jrep = find(j)
                if irep != jrep:
                    union(irep,jrep)
                    # parent[irep] = jrep
                    ans -= 1
    sys.stdout.write(f"{ans}\n")

            
profile
In to the code!

0개의 댓글