[PRO] 합승 택시 요금

천호영·2022년 7월 16일
0

알고리즘

목록 보기
36/100
post-thumbnail

초반에 제출한 풀이. set을 재귀로 전달할때 .copy()를 안해줘서 삽질을 오래했다. 어찌저찌해서 정확도는 통과하였으나 효율성은 통과하지 못했는데 그래프 최단경로를 DFS로 하게되면 O(N!)O(N!)이라는 말도 안되는 시간이 걸리기 때문이다.

import sys
sys.setrecursionlimit(10**7)

a_dest = None
b_dest = None
adj_node_list = None
adj_fair_matrix = None

def dfs(now_node: int, moving_persons: set, fair_sum: int, visited: list):
    global a_dest, b_dest
    global adj_node_list, adj_fair_matrix
    
    if now_node == a_dest and 'A' in moving_persons:
        moving_persons.remove('A')
    elif now_node == b_dest and 'B' in moving_persons:
        moving_persons.remove('B')
    
    if not moving_persons: # A,B 모두 도달 완료
        return fair_sum
    
    now_min_fair_sum = sys.maxsize
    if len(moving_persons) == 1: # A 혹은 B만 이동중
        for next_node in adj_node_list[now_node]:
            if not visited[next_node]:
                visited[next_node] = True
                next_fair_sum = fair_sum + adj_fair_matrix[now_node][next_node]
                fair = dfs(next_node, moving_persons.copy(), next_fair_sum, visited)
                now_min_fair_sum = min(now_min_fair_sum, fair)
                visited[next_node] = False
    else: # A,B 같이 이동중
        for next_node in adj_node_list[now_node]: # 합승 O
            if not visited[next_node]:
                visited[next_node] = True
                next_fair_sum = fair_sum + adj_fair_matrix[now_node][next_node]
                fair = dfs(next_node, moving_persons.copy(), next_fair_sum, visited)
                now_min_fair_sum = min(now_min_fair_sum, fair)
                visited[next_node] = False
        
        for a_next_node in adj_node_list[now_node]:
            for b_next_node in adj_node_list[now_node]:
                if a_next_node != b_next_node: # 합승 X
                    if not visited[a_next_node] and not visited[b_next_node]:
                        a_next_fair_sum = adj_fair_matrix[now_node][a_next_node]
                        b_next_fair_sum = adj_fair_matrix[now_node][b_next_node]
                        fair = fair_sum

                        visited[a_next_node] = True
                        fair += dfs(a_next_node, {"A"}, a_next_fair_sum, visited)
                        visited[a_next_node] = False

                        visited[b_next_node] = True
                        fair += dfs(b_next_node, {"B"}, b_next_fair_sum, visited)
                        visited[b_next_node] = False
                        
                        now_min_fair_sum = min(now_min_fair_sum, fair)
                
    return now_min_fair_sum
        

def solution(n, s, a, b, fares):
    global a_dest, b_dest, global_min_fair_sum
    global adj_node_list, adj_fair_matrix
    a_dest, b_dest = a, b
    
    adj_node_list = [[] for _ in range(n+1)] # adj_node_list[i]: i번 노트에 연결된 노드번호들
    adj_fair_matrix = [[0]*(n+1) for _ in range(n+1)]
    for fare in fares:
        c,d,f = fare
        adj_node_list[c].append(d)
        adj_node_list[d].append(c)
        adj_fair_matrix[c][d] = f
        adj_fair_matrix[d][c] = f
    
    visited = [False] * (n+1)
    visited[s] = True
    return dfs(s, {'A','B'}, 0, visited)

그래프에서 쓰이는 알고리즘들을 이용해야 하며, 이 문제에서는 플로이드-워셜 알고리즘을 통해 모든 정점에서 모든 정점까지의 최단 거리를 구해 놓아야 한다.

def solution(n, s, a, b, fares):
    matrix = [[float('inf')]*(n+1) for _ in range(n+1)]
    for i in range(1,n+1):
        matrix[i][i]=0
        
    for fare in fares:
        c,d,f = fare
        matrix[c][d] = f
        matrix[d][c] = f
    
    for m in range(1, n+1):
        for f in range(1, n+1):
            for t in range(1, n+1):
                matrix[f][t] = min(matrix[f][t], matrix[f][m]+matrix[m][t])
    
    answer = float('inf')
    for c in range(1, n+1):
        answer = min(answer, matrix[s][c] + matrix[c][a] + matrix[c][b])
        
    return answer

다익스트라로도 풀이가 가능하다.

profile
성장!

0개의 댓글