초반에 제출한 풀이. set을 재귀로 전달할때 .copy()를 안해줘서 삽질을 오래했다. 어찌저찌해서 정확도는 통과하였으나 효율성은 통과하지 못했는데 그래프 최단경로를 DFS로 하게되면 이라는 말도 안되는 시간이 걸리기 때문이다.
import sys
sys.setrecursionlimit(10**7)
a_dest = None
b_dest = None
adj_node_list = None
adj_fair_matrix = None
def dfs(now_node: int, moving_persons: set, fair_sum: int, visited: list):
global a_dest, b_dest
global adj_node_list, adj_fair_matrix
if now_node == a_dest and 'A' in moving_persons:
moving_persons.remove('A')
elif now_node == b_dest and 'B' in moving_persons:
moving_persons.remove('B')
if not moving_persons: # A,B 모두 도달 완료
return fair_sum
now_min_fair_sum = sys.maxsize
if len(moving_persons) == 1: # A 혹은 B만 이동중
for next_node in adj_node_list[now_node]:
if not visited[next_node]:
visited[next_node] = True
next_fair_sum = fair_sum + adj_fair_matrix[now_node][next_node]
fair = dfs(next_node, moving_persons.copy(), next_fair_sum, visited)
now_min_fair_sum = min(now_min_fair_sum, fair)
visited[next_node] = False
else: # A,B 같이 이동중
for next_node in adj_node_list[now_node]: # 합승 O
if not visited[next_node]:
visited[next_node] = True
next_fair_sum = fair_sum + adj_fair_matrix[now_node][next_node]
fair = dfs(next_node, moving_persons.copy(), next_fair_sum, visited)
now_min_fair_sum = min(now_min_fair_sum, fair)
visited[next_node] = False
for a_next_node in adj_node_list[now_node]:
for b_next_node in adj_node_list[now_node]:
if a_next_node != b_next_node: # 합승 X
if not visited[a_next_node] and not visited[b_next_node]:
a_next_fair_sum = adj_fair_matrix[now_node][a_next_node]
b_next_fair_sum = adj_fair_matrix[now_node][b_next_node]
fair = fair_sum
visited[a_next_node] = True
fair += dfs(a_next_node, {"A"}, a_next_fair_sum, visited)
visited[a_next_node] = False
visited[b_next_node] = True
fair += dfs(b_next_node, {"B"}, b_next_fair_sum, visited)
visited[b_next_node] = False
now_min_fair_sum = min(now_min_fair_sum, fair)
return now_min_fair_sum
def solution(n, s, a, b, fares):
global a_dest, b_dest, global_min_fair_sum
global adj_node_list, adj_fair_matrix
a_dest, b_dest = a, b
adj_node_list = [[] for _ in range(n+1)] # adj_node_list[i]: i번 노트에 연결된 노드번호들
adj_fair_matrix = [[0]*(n+1) for _ in range(n+1)]
for fare in fares:
c,d,f = fare
adj_node_list[c].append(d)
adj_node_list[d].append(c)
adj_fair_matrix[c][d] = f
adj_fair_matrix[d][c] = f
visited = [False] * (n+1)
visited[s] = True
return dfs(s, {'A','B'}, 0, visited)
그래프에서 쓰이는 알고리즘들을 이용해야 하며, 이 문제에서는 플로이드-워셜 알고리즘을 통해 모든 정점에서 모든 정점까지의 최단 거리를 구해 놓아야 한다.
def solution(n, s, a, b, fares):
matrix = [[float('inf')]*(n+1) for _ in range(n+1)]
for i in range(1,n+1):
matrix[i][i]=0
for fare in fares:
c,d,f = fare
matrix[c][d] = f
matrix[d][c] = f
for m in range(1, n+1):
for f in range(1, n+1):
for t in range(1, n+1):
matrix[f][t] = min(matrix[f][t], matrix[f][m]+matrix[m][t])
answer = float('inf')
for c in range(1, n+1):
answer = min(answer, matrix[s][c] + matrix[c][a] + matrix[c][b])
return answer
다익스트라로도 풀이가 가능하다.