[Algorithm] 바둑이 승차(DFS) (Feat. 지역변수와 전역변수)

myeonji·2022년 2월 23일
0

Algorithm

목록 보기
52/89

철수는 그의 바둑이들을 데리고 시장에 가려고 한다. 그런데 그의 트럭은 C킬로그램 넘게 태울수가 없다. 철수는 C를 넘지 않으면서 그의 바둑이들을 가장 무겁게 태우고 싶다.
N마리의 바둑이와 각 바둑이의 무게 W가 주어지면, 철수가 트럭에 태울 수 있는 가장 무거운 무게를 구하는 프로그램을 작성하세요.

💡 전역변수와 지역변수 개념
💡 범위를 더 줄여서 시간초과 잡기

import sys

def DFS(L, sum):  # L은 인덱스, sum은 부분집합의 합
    if sum > c:
        return
    
    if L == n:
        if sum > result:
            result = sum
        return
    else:
        DFS(L+1, sum+a[L])
        DFS(L+1, sum)


if __name__ == '__main__':
    c, n = map(int, sys.stdin.readline().split())
    a = [0]*n
    result = -2147000000  # 가장 작은 값으로 초기화 (최대값 구하기 위해)
    for i in range(n):
        a[i] = int(sys.stdin.readline())
    DFS(0, 0)
    print(result)

def DFS 안에 result와 시간초과 에러가 난다.

1. result 에러

if sum > result: 바로 아래 코드에서 result = sum 으로 result 라는 지역변수가 생성된다.

result 는 지역변수 라고 해석이 되었는데, if sum > result: 를 먼저 실행하니 result 는 값이 없는 변수인 것이다.

전역변수 result 를 써야 하므로,

def DFS(L, sum):  # L은 인덱스, sum은 부분집합의 합
    global result

    if sum > c:
        return

    if L == n:
        if sum > result:
            result = sum
        return
    else:
        DFS(L+1, sum+a[L])
        DFS(L+1, sum)


if __name__ == '__main__':
    c, n = map(int, sys.stdin.readline().split())
    a = [0]*n
    result = -2147000000  # 가장 작은 값으로 초기화 (최대값 구하기 위해)
    for i in range(n):
        a[i] = int(sys.stdin.readline())
    DFS(0, 0)
    print(result)

global 키워드를 붙인다.

2. 시간초과 에러

조건을 더 줄일 수 있다.
굳이 확인하지 않아도 되는 부분은 검사하지 않는다.

부분집합에 넣거나 안 넣거나 를 기준으로 하지 않고,
판단을 했으면 전부 더하는 tsum 을 하나 만든다.

def DFS(L, sum, tsum):  # L은 인덱스, sum은 부분집합의 합
    global result

    if sum + (total - tsum) < result:
        return

    if sum > c:
        return

    if L == n:
        if sum > result:
            result = sum
        return
    else:
        DFS(L+1, sum+a[L], tsum+a[L])
        DFS(L+1, sum, tsum+a[L])


if __name__ == '__main__':
    c, n = map(int, sys.stdin.readline().split())
    a = [0]*n
    result = -2147000000  # 가장 작은 값으로 초기화 (최대값 구하기 위해)
    for i in range(n):
        a[i] = int(sys.stdin.readline())
    total = sum(a)
    DFS(0, 0, 0)
    print(result)

0개의 댓글