모든 정점 쌍 사이의 최단거리를 구하는 알고리즘
아래와 같은 테이블처럼 모든 정점 쌍 사이의 최단거리를 구해주는 알고리즘이다.
1 | 2 | 3 | |
---|---|---|---|
1 | 0 | 4 | 1 |
2 | 4 | 0 | 5 |
3 | 1 | 5 | 0 |
무방향 그래프, 방향 그래프에 상관없이 사용할 수 있는 알고리즘이다.
간선값이 음수여도 알고리즘은 잘 동작하지만 음수인 사이클이 있으면 문제가 생긴다.
구현이 쉬운편에 속한다. 주어진 간선들을 우선 최단거리 테이블에 채우고 빈 값은 무한으로 채워준다.
그런 다음 하나의 정점(A)에서 다른 정점(B)으로 갈 때의 값과 정점A에서 또 다른 정점(C)을 경유해서 정점B로 가는 값을 비교해 더 작은 값으로 갱신해준다.
이런식으로 경유정점을 계속 바꾸면서 최단거리를 업데이트 해준다.
경유정점에 대한 순회 V 그리고 특정 정점에서 다른 정점으로의 이동 조회관련 순회 V2 으로 시간 복잡도는 O(V3)이다.
boj-11404
import sys
sys.stdin = open('', 'r')
input = sys.stdin.readline
N = int(input())
M = int(input())
INF = int(10e9)
costs = [[INF] * (N+1) for _ in range(N+1)]
for i in range(1, N+1):
costs[i][i] = 0
for _ in range(M):
start, arrive, cost = map(int, input().split())
costs[start][arrive] = min(costs[start][arrive], cost )
for k in range(1, N+1):
for i in range(1, N+1):
for j in range(1, N+1):
costs[i][j] = min(costs[i][j] , costs[i][k] + costs[k][j])
for i in range(1, N+1):
for j in range(1, N+1):
if costs[i][j] == INF: print(0, end= ' ')
else: print(costs[i][j], end=' ')
print()
컴퓨터는 1초에 3~5억번 정도 연산을 한다고 생각하면 편하다.
정점이 1000개인 이 문제는 1000**3인 10억개의 연산을 해야하기 때문에 불가능할 것 처럼 보이지만 플로이드 알고리즘은 단순 사칙연산이 주를 이루기 때문에 정점 1000개까지는 플로이드 알고리즘으로 풀어볼만 하다.
시간이 간당간당할 때 줄이는 방법 → 연산보다 대입이 느리기 때문에 min을 통해서 매번 대입이 일어나게 하는 것 보다 if문으로 검사해서 꼭 필요할 때만 대입이 일어나게 하면 더 빨라진다.
(dp문제처럼 대입이 빈번하게 일어난다면 비슷한 방법으로 시간을 줄일 수 있다.)
실제로 시간의 1/3 정도가 절약됐다.
순회를 돌면서 최단거리테이블과 next테이블을 같이 채워준다.
next테이블의 초기값 설정은 경로가 있다면 도착 노드를 next테이블에 채워주면 된다.
최단거리 테이블이 갱신이 된다면 next테이블의 값을 next[출발노드][경유노드]의 값으로 업데이트 해준다.
import sys
input = sys.stdin.readline
N = int(input())
M = int(input())
INF = int(10e9)
costs = [[INF] * (N+1) for _ in range(N+1)]
nexts = [[0] * (N+1) for _ in range(N+1)]
for i in range(1, N+1):
costs[i][i] = 0
for _ in range(M):
start, arrive, cost = map(int, input().split())
costs[start][arrive] = min(costs[start][arrive], cost)
## 초기화
nexts[start][arrive] = arrive
for k in range(1, N+1):
for i in range(1, N+1):
for j in range(1, N+1):
### 이 부분에서 nexts 값을 갱신해준다.
if costs[i][k] + costs[k][j] < costs[i][j]:
costs[i][j] = costs[i][k] + costs[k][j]
nexts[i][j] = nexts[i][k]
for i in range(1, N+1):
for j in range(1, N+1):
if costs[i][j] == INF: print(0, end= ' ')
else: print(costs[i][j], end=' ')
print()
### 갯수와 경로를 순서대로 프린트해야하고
for i in range(1, N+1):
for j in range(1, N+1):
if costs[i][j] == 0 or costs[i][j] == INF:
print(0)
continue
stack = []
start = i
### nexts테이블의 값이 j와 같아진다는 건 최종경로라는 의미로 최종경로에 닿을 때 까지 stack에 노드값을 넣어준다.
while start != j:
stack.append(start)
start = nexts[start][j]
stack.append(j)
print(len(stack), end=' ')
for node in stack:
print(node, end= ' ')
print()