[BOJ 13907] - 세금 (다익스트라, Python)

보양쿠·2022년 9월 8일
0

BOJ

목록 보기
17/252

예전에 꽤나 재밌고 특이한 다익스트라 문제를 접했었다. 언젠가 풀이를 적어야지 하고 북마크를 해놓았다가 이제 풀이를 적어볼까 한다.

BOJ 13907 - 세금 링크
(2022.09.08 기준 P4)
(치팅하면 세금 더 냄)

문제

세금이 오르면 그 만큼 모든 도로의 통행료가 오르고 S에서 출발하여 D로 도착할 때
세금이 오르기 전과 세금이 오를 때마다의 최소 통행료

알고리즘

각 도로마다 통행료가 있고(가중치가 있는 간선) 출발에서 도착까지의 최소 비용을 구해야 하므로 다익스트라

풀이

모든 간선의 가중치가 차례대로 증가할 때마다 최소 비용을 구해야 한다.
세금 인상 횟수 K는 최대 30,000이므로 당연히 다익스트라를 30,000번 돌리는 것은 불가능.

이 문제는 다익스트라를 갱신할 때마다, 지나온 도시 수도 같이 관리를 해줘야 한다.
그러면 나중에 세금이 인상이 되어도 그 도시 수만큼만 세금을 곱해 더해주면 비교가 가능하기 때문.

다익스트라를 시작하기 전, 시작 도시에서 각 도시마다의 최소 비용을 저장할 distance 배열을 만들어준다. 보통 일차원 배열을 만들어 주는데, 이 문제에선 지나온 도시 수도 같이 저장해야 하기 때문에 이차원 배열로 만들어 준다.
시작 도시는 1번부터이기 때문에 행은 (N + 1) 만큼, 열은 N만큼 만들어 주자.
왜냐면, 지나온 도시 수가 총 도시 수를 넘어가게 되면 모든 도시를 다 돌았다는 뜻이므로 더이상 통행료가 줄어들 수 없기 때문에 필요가 없어진다. 그래서 지나온 도시 수는 N - 1만큼만 봐주면 된다. (처음엔 지나온 도시 수가 0이므로)

그리고 시작 도시가 S이므로 distance[S][0]을 0으로 저장해주고 힙에 [0, 0, S]를 넣어주자.
각 통행료, 지나온 도시 수, 현재 도시이다.

while queue:
    w, visited, here = heapq.heappop(queue) # 통행료, 지나온 도시 수, 현재 도시
    # 최적화
    # visited 도시 수만큼 지나서 한 도시에 도착했을 때
    # 통행료가 visited 도시 수까지의 통행료들보다 크면 스킵해줘야 한다.
    flag = False
    for i in range(visited + 1):
        if distance[here][i] < w:
            flag = True
            break
    # 위에서의 스킵이나 지나온 도시 수가 N - 1만큼 되었다면 스킵
    # 지나온 도시 수가 N - 1이면 모든 도시를 돌았으므로 더이상 통행료를 작게 할 수 없다.
    if visited == N - 1 or flag:
        continue

원래 다익스트라에선 if distance[here] > w 와 같은 최적화를 하는데
이 문제는 이렇게 하자. 코드를 보면 이해가 갈 것이다.

    # 이어진 도시들을 체크
    # 이 때, 도시를 하나 지나는 것이므로 visited + 1을 해주자
    for there, ww in graph[here]:
        if distance[there][visited + 1] > distance[here][visited] + ww:
            distance[there][visited + 1] = distance[here][visited] + ww
            heapq.heappush(queue, [distance[there][visited + 1], visited + 1, there])

그 다음에 이어진 도시들로 하여금 거리를 갱신할 때 코드다.
특별한 것은 없고, 지나온 도시 수에 1을 더해주는 것만 잊지 말자.

이제 세금을 올리기 전 통행료를 구하면 되는데, 이는 그냥 다익스트라 답과 같다.
지나온 도시 수 차례대로 검사하여 답이 갱신해주자.
이 때, 갱신될 때마다 그 때의 도시 수를 따로 저장해두자.
왜냐면, 나중에 세금 올려서 검사할 때마다, 저장한 도시 수보다 지나온 도시 수가 커지면 통행료는 무조건 증가하기 때문에 검사할 필요가 없어지기 때문에 최적화를 위하여 저장하는 것이다.

그 후론 세금을 차례대로 올리면서 저장한 도시 수까지만 (도시 수 * 세금)을 더한 값으로 비교해주면서 답을 출력해주면 된다. 이 때도 갱신될 때의 도시 수를 저장하자.

코드

import sys; input = sys.stdin.readline
import heapq
from math import inf

def solve():
    N, M, K = map(int, input().split())
    S, D = map(int, input().split())
    graph = [[] for _ in range(N + 1)]
    for _ in range(M):
        a, b, w = map(int, input().split())
        graph[a].append((b, w))
        graph[b].append((a, w))
    P = [int(input()) for _ in range(K)]

    # 다익스트라 시작
    # 도시에 도착할 때 지나온 도시 수도 관리해줘야 함
    # 도시 수만큼 열을 만들어주면 되는데, 도시 수가 넘어가면 모든 도시를 다 돈 이후이므로 필요가 없다.
    distance = [[inf] * N for _ in range(N + 1)] # 행 - 현재 도시, 열 - 지나온 도시 수
    distance[S][0] = 0
    queue = [[0, 0, S]] # 통행료, 지나온 도시 수, 현재 도시
    while queue:
        w, visited, here = heapq.heappop(queue)
        # 최적화
        # visited 도시 수만큼 지나서 한 도시에 도착했을 때
        # 통행료가 visited 도시 수까지의 통행료들보다 크면 스킵해줘야 한다.
        flag = False
        for i in range(visited + 1):
            if distance[here][i] < w:
                flag = True
                break
        # 위에서의 스킵이나 지나온 도시 수가 N - 1만큼 되었다면 스킵
        # 지나온 도시 수가 N - 1이면 모든 도시를 돌았으므로 더이상 통행료를 작게 할 수 없다.
        if visited == N - 1 or flag:
            continue
        # 이어진 도시들을 체크
        # 이 때, 도시를 하나 지나는 것이므로 visited + 1을 해주자
        for there, ww in graph[here]:
            if distance[there][visited + 1] > distance[here][visited] + ww:
                distance[there][visited + 1] = distance[here][visited] + ww
                heapq.heappush(queue, [distance[there][visited + 1], visited + 1, there])

    # 세금을 올리지 전은 그냥 다익스트라 답을 구해준다.
    # 단, 지나온 도시 수 차례대로 검사하고
    # 답이 갱신될 때마다 그 때의 지나온 도시 수를 저장해주자.
    # 왜냐면, 그 도시 수보다 커지면 세금이 커지면 무조건 통행료가 증가하기 때문에 검사할 필요가 없어진다.
    answer = inf
    for visited in range(N):
        if answer > distance[D][visited]:
            answer = distance[D][visited]
            limit = visited
    print(answer) # 세금 올리기 전 통행료 출력

    # 이제부터 세금 올리면서 최소 통행료를 구하자.
    tax = 0
    for p in P:
        tax += p # 총 세금에 세금 차례대로 더해가면서 구한다.
        answer = inf
        for visited in range(limit + 1): # 검사는 저번에 구해놓은 도시 수 제한까지만
            # 지나온 도시 수만큼 세금을 곱해 더해주면 총 통행료가 나온다.
            # 만약 답이 갱신되면 똑같이 그 때의 지나온 도시 수를 저장하자.
            if answer > distance[D][visited] + tax * visited:
                answer = distance[D][visited] + tax * visited
                limit = visited
        print(answer)

solve()

여담

도시 수 저장 최적화를 위는 안하고 밑은 한 결과다. 최적화가 얼마나 중요한지 보여준다.
그리고 현재 Python3 제출 순위. 내가 1등이다!
27일 전에 340ms 였는데 방금 함수로 만들어서 지역 변수로 입력받아 푸니깐 저렇게 더 줄었다.

역시 파이썬은 최적화가 중요하다.

profile
GNU 16 statistics & computer science

2개의 댓글

comment-user-thumbnail
2023년 7월 24일

설명 잘 봤습니다! 근데 코드를 이해하는데 시간이 좀 많이 걸렸네요ㅠ
limit에 대한 자세한 설명이 있었으면 더 좋았을 것 같습니다! 그리고 지나온 도시 수를 저장하는데 visited라는 이름대신 passed_city_cnt와 같은 이름을 사용하셨다면 처음 본 사람이 코드를 이해하는데 훨씬 시간을 아낄 수 있을 것 같습니다.
그럼에도 이렇게 포스팅 해주셔서 잘 이해할 수 있었습니다. 감사합니다!

1개의 답글