[ BOJ / Python ] 1504번 특정한 최단 경로

황승환·2022년 2월 11일
0

Python

목록 보기
167/498


이번 문제는 다익스트라 알고리즘을 통해 해결하였다. 그래프는 양방향 인접 리스트로 구현하였고, 1부터 v1까지의 최단거리+v1부터 v2까지의 최단거리+v2부터 n까지의 최단거리와 1부터 v2까지의 최단거리+v2부터 v1까지의 최단거리+v1부터 n까지의 최단거리 중 더 작은 값을 결과로 출력하는 방식으로 접근하였다. 이 접근 방식을 그대로 사용하게 되면 6-1, 즉 5번의 다익스트라 알고리즘이 실행되게 된다. 이는 아무리 생각해도 비효율적이었지만 우선은 구현해보자는 생각으로 작성하였다. 당연하게도 시간초과가 계속해서 발생하였다. 이런 저런 방법을 생각해보다가 다른 사람의 코드를 보았다. 나와 다익스트라 알고리즘의 구성은 거의 유사했지만 다익스트라 함수의 반환 값이 리스트라는 점이 달랐다.

리스트 전체를 반환하여 이를 따로 저장하고, 다익스트라의 인자로는 시작 위치만 주어지게 된다. 이렇게 하면 시작 위치부터 n 사이의 모든 노드의 거리가 계산되어 리스트에 저장되므로 단 3번의 다익스트라 함수 호출로 정답을 구할 수 있다.

path=Dijkstra(1)
path1=Dijkstra(v1)
path2=Dijkstra(v2)
result=min(path[v1]+path1[v2]+path2[n], path[v2]+path2[v1]+path1[n])

여기서 path는 1을 시작점으로, path1은 v1을 시작점으로, path2는 v2를 시작점으로 하는 리스트를 저장하게 된다. 그리고 path[v1]은 1부터 v1까지의 최소거리, path1[v2]는 v1부터 v2까지의 최소거리, path2[n]은 v2부터 n까지의 최소거리가 되므로 처음에 접근하고자 했던 방식과 같다.

이렇게 함수 호출의 횟수를 리스트 반환을 통해 5번에서 3번으로 줄여 해결하였다. 이렇게 수정한 코드에서도 시간초과가 발생하여 input을 sys.stdin.readline으로 변경해주었더니 성공하였다.

  • input을 sys.stdin.readline으로 선언한다.
  • n, e를 입력받는다.
  • graph를 2차원 리스트로 선언한다.
  • e번 반복하는 for문을 돌린다.
    -> a, b, c를 입력받는다.
    -> graph[a](b, c)를 넣는다.
    -> graph[b](a, c)를 넣는다.
  • v1, v2를 입력받는다.
  • INF 변수에 sys.maxsize를 저장한다.
  • Dijkstra함수를 start를 인자로 갖도록 선언한다.
    -> 시작점에서 각 노드까지의 거리를 저장할 리스트 dist를 INF n+1개로 채운다.
    -> dist[start]를 0으로 갱신한다.
    -> 다익스트라에 사용할 큐 q를 최소힙으로 선언하고 (0, start)를 넣어준다.
    -> q가 존재하는 동안 반복하는 while문을 돌린다.
    --> q에서 distance, cur을 추출한다.
    --> 만약 distance가 dist[cur]보다 클 경우 다음 반복으로 넘어간다.
    --> graph[cur]을 순회하는 nxt, dst에 대한 for문을 돌린다.
    ---> cost에 distance+dst를 저장한다.
    ---> 만약 dist[nxt]가 cost보다 클 경우,
    ----> dist[nxt]를 cost로 갱신한다.
    ----> q에 (cost, nxt)를 넣는다.
    -> dist를 반환한다.
  • path에 Dijkstra(1)의 반환 리스트를 저장한다.
  • path1에 Dijkstra(v1)의 반환 리스트를 저장한다.
  • path2에 Dijkstra(v2)의 반환 리스트를 저장한다.
  • 결과를 저장할 변수 result에 path[v1]+path1[v2]+path2[n]path[v2]+path2[v1]+path1[n] 중 더 작은 값을 저장한다.
  • 만약 result가 INF보다 작을 경우,
    -> result를 출력한다.
  • 그 외에는 -1을 출력한다.

Code

import heapq
import sys
input=sys.stdin.readline
n, e=map(int, input().split())
graph=[[] for _ in range(n+1)]
for _ in range(e):
    a, b, c=map(int, input().split())
    graph[a].append((b, c))
    graph[b].append((a, c))
v1, v2=map(int, input().split())
INF=sys.maxsize
def Dijkstra(start):
    dist=[INF for _ in range(n+1)]
    dist[start]=0
    q=[]
    heapq.heappush(q, (0, start))
    while q:
        distance, cur=heapq.heappop(q)
        if distance>dist[cur]:
            continue
        for nxt, dst in graph[cur]:
            cost=distance+dst
            if cost<dist[nxt]:
                dist[nxt]=cost
                heapq.heappush(q, (cost, nxt))
    return dist
path=Dijkstra(1)
path1=Dijkstra(v1)
path2=Dijkstra(v2)
result=min(path[v1]+path1[v2]+path2[n], path[v2]+path2[v1]+path1[n])
if result<INF:
    print(result)
else:
    print(-1)

profile
꾸준함을 꿈꾸는 SW 전공 학부생의 개발 일기

0개의 댓글