프로그래머스 시소 짝꿍 코드 분석

고봉진·2023년 3월 11일
0

TIL/코딩테스트

목록 보기
12/27
def solution(numbers):
    radix = [0] * 1001
    r1 = [[] for _ in range(1001)]
    r2 = [set() for _ in range(4001)]

    for i in numbers:
        radix[i] += 1
        for j in [i*2, i*3, i*4]:
            r1[i].append(j)
            r2[j].add(i)

    answer = 0
    for n in numbers:
        for j in r1[n]:
            for x in r2[j]:
                if x != n:
                    answer += radix[n]*radix[x]
    answer //= 2
    for n in radix:
        if n > 0:
            answer += (n*~-n)//2

    return answer
    

두 가지 문제

1. 실패

일단 실패부터 해결해야겠다. 어디가 틀린거지?
논리를 설명해보자.

  1. 먼저 각 수가 몇개씩 있는지 센다. 여기서 radix를 사용하고 있으나 dict를 사용하면 메모리 최적화가 될 것 같다. 순서가 중요하지 않기 때문이다.
  2. 각 수에 대해 2, 3, 4를 곱한 값을 그 수 인덱스의 r1 리스트에 저장한다. 중복되는 값이 어차피 없으므로 리스트를 사용했다. (r1리스트를 dict로 바꿔도 될 것 같다.)
  3. 어느 수에 2, 3, 4를 곱한 값에 곱하기 이전 값을 저장한다. 이 때, 중복되는 값이 없어야 순회에 소요되는 시간을 줄일 수 있기에 set를 사용했다. 역시 r2를 dict로 사용할 수 있다.
  4. answer를 0으로 초기화하고, numbers 리스트의 각 숫자들에 대해 아래와 같은 연산을 수행한다.
    1. 각 숫자(n)를 2, 3, 4로 곱한 값들을 사용해, 그 곱한 값들을 자신의 곱한 값들로 갖는 다른 숫자(x)들을 구한다.
    2. n의 개수와 x의 개수를 곱한 값을 answer에 추가한다.
      • 예를 들어 x가 3개 있고, n이 5개가 있는 경우를 생각해보자.
      • ls = [100, 100, 100, 150, 150, 150, 150, 150]를 보면 각 100마다 5개의 150으로 뻗어 나가는 간선이 생긴다.
  5. 양쪽에서 간선이 뻗어 나오기 때문에 answer를 2로 나눈다.
  6. 그리고 각 수의 개수에 대해 서로를 연결하는 간선의 개수를 구해 answer에 더한다. 간선의 개수를 구하는 공식을 사용했다.
    • 공식 : 각 숫자 n에 대해 (n*~-n)//2

논리는 맞는 것 같은데 왜 통과가 안될까?

2. 시간 초과

  1. radix는 단순히 카운터 용도로만 사용한다. dict를 사용해보자. 직접 구하는 방법도 있지만 여기선 collections 모듈의 Counter 클래스를 사용해보자.
  2. r1r2도 같은 모듈의 defaultdict를 사용한다.

아래는 완성된 코드이다.

from collections import Counter, defaultdict

def solution(numbers):
    counter = Counter(numbers)
    r1 = defaultdict(list)
    r2 = defaultdict(set)

    for i in numbers:
        for j in [i*2, i*3, i*4]:
            r1[i].append(j)
            r2[j].add(i)

    answer = 0
    for n in numbers:
        for j in r1[n]:
            for x in r2[j]:
                if x != n:
                    answer += counter[n]*counter[x]

    answer //= 2
    for r in counter.values():
        if r > 1:
            answer += (r*~-r)//2

    return answer
    

여전히 같은 문제에서 통과와 실패, 시간초과가 발생한다. 어떤 경우에서는 순회가 많이 줄어들지 않았나보다. 논리 자체를 수정할 필요가 있는 것 같다. 다시 와서 생각해봐야겠다.

해결

약간의 실수를 바로잡으니 해결되었다.

1. Counterdictionary 활용

from collections import Counter, defaultdict

def solution(numbers):
    counter = Counter(numbers)
    r1 = defaultdict(list)
    r2 = defaultdict(set)

    for i in counter:
        for j in [i*2, i*3, i*4]:
            r1[i].append(j)
            r2[j].add(i)

    answer = 0
    for n in counter:
        for j in r1[n]:
            for x in r2[j]:
                if x != n:
                    answer += counter[n]*counter[x]

    answer //= 2
    for cnt in counter.values():
        answer += (cnt*~-cnt)//2

    return answer

2. radix 활용

def solution(numbers):
    radix = [0] * 1001
    r1 = [set() for _ in range(1001)]
    r2 = [set() for _ in range(4001)]

    for i in numbers:
        radix[i] += 1
        for j in [i*2, i*3, i*4]:
            r1[i].add(j)
            r2[j].add(i)

    answer = 0
    for n in set(numbers):
        for j in r1[n]:
            for x in r2[j]:
                if x != n:
                    answer += radix[n]*radix[x]

    answer //= 2
    for n in radix:
        if n > 0:
            answer += (n*~-n)//2

    return answer
    

다른 풀이

상어고기먹고가자 님의 풀이를 약간 수정해서 올린다. 이 글에서 많은 힌트를 얻었다.

from collections import Counter

def solution(weights):
    answer = 0
    count = Counter(weights)
    for v in count.values():
        if v > 1:
            answer += v * (v-1) / 2

    weights = set(weights)
    check = (3/4, 2/3, 1/2)
    for w in weights:
        for c in check:
            if w*c in weights:
                answer += count[w] * count[w*c]

    return answer
    
profile
이토록 멋진 휴식!

0개의 댓글