# 문제
크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.
예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.
AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.
같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.
행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.
# 입력
첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.
둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)
항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.
# 출력
첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같다.
우선 이 문제를 풀기 위해서는 행렬을 곱했을 때 연산 횟수를 구해야 한다.
행렬 1과 행렬 2의 곱셈을 해보자!
크기가 5 X 3인 행렬 1
과 크기가 3 X 2인 행렬 2
를 곱할 때,
연산 횟수는 3을 두번씩 더한 것을 5번 반복하게 된다.
이것을 계산식으로 나타내면,
행렬 1의 행 X 행렬 1의 열 X 행렬 2의 열
이 된다.
이번에는 행렬이 4개일 때 연산 횟수를 구해보자!
(행렬 1 X 행렬 2) X (행렬 3 X 행렬 4)
순서로 곱한다고 했을 때 연산 횟수를 구해보자!
행렬이 1부터 3까지 있을 때는 아래와 같이 두가지 경우가 있다.
이 두가지 경우를 전부 계산했을 때 더 작은 값이 최소 연산 횟수가 된다.
행렬이 1부터 4까지 있을 때는 세가지 경우가 있다.
마찬가지로 각 경우의 연산 횟수를 계산했을 때, 그 중에서 가장 작은 값이 최소 연산 횟수가 된다.
이렇게 행렬이 4개인 경우에는, 행렬 3개를 곱했을 때의 최소 연산횟수를 알아야
4개일 때의 연산 횟수를 구할 수 있는 것을 알 수 있다.
(마찬가지로 행렬이 3개인 경우에는 행렬 2개를 곱했을 때의 최소 연산횟수를 알아야 계산이 가능하다.)
따라서
현재 구해야하는 범위 안에서 생길 수 있는 더 작은 범위에 해당하는 연산을 먼저 해야 한다.
(구한 최소 연산횟수를 dp 테이블에 저장해두고 필요할 때 가져와서 활용한다.)
이제 이 내용을 활용해서 코드로 구현해봅시다~!
dp = [[0]*(N) for _ in range(N)]
👉 저장할 값
dp[시작행렬][끝행렬] = 최소 연산 횟수
👉 예시
문제에서 주어진 예시 (A의 크기가 5×3, B의 크기가 3×2, C의 크기가 2×6)
(AB)C의 연산 횟수는 5×3×2 + 5×2×6 = 30 + 60 = 90번
A(BC)의 연산 횟수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이므로
dp[A][C] = 90이 된다.
위에서 연산 경우의 수를 확인하면서 살펴봤듯이
행렬 4개를 곱한 연산 횟수를 알기 위해서는 👉 3개를 곱한 연산 횟수를 알아야 하고,
3개를 곱한 연산 횟수를 구하기 위해서는 👉 2개를 곱한 연산 횟수를 알고 있어야 한다.
따라서, 간격이 작은 범위부터 연산 횟수를 계산해 나갈 것이다. (간격이 작은 것부터 계산한다.)
# 1. 간격이 1인 범위 먼저 전부 계산
행렬 1 ~ 2, 행렬 2 ~ 3, 행렬 3 ~ 4
# 2. 다음으로 간격이 2인 범위 전부 계산
행렬 1 ~ 3, 행렬 2 ~ 4
# 3. 마지막으로 간격이 3인 범위 계산
행렬 1 ~ 4
간격이 작은 것부터 하나씩 계산하기 위해서 term
변수를 만들어서 '1'부터 '행렬의 개수 -1'까지 늘려가며 간격으로 활용할 것이다.
적용해보면,
처음에는 term
이 1이므로 곱해야 하는 두 행렬은 행렬 satrt
와 행렬 start+1
이 된다.
(start
가 1이면 행렬 1
과 행렬 2
의 연산횟수를 구하면 된다.)
for term in range(1, N):
for start in range(N): # 현재 범위의 첫행렬: start, 끝행렬: start + term
if start + term == N: # 범위를 벗어나면 무시
break
term이 1일 때는 행렬 start
와 행렬 start+1
을 곱한 연산 횟수를 바로 구하면 되지만,
term이 1보다 크면 괄호로 묶어서 연산 순서를 바꿀 수 있으므로 여러 가지 경우가 생긴다.
예를 들어, term이 3인 경우에는 행렬 4개의 곱을 계산해야 하는데,
(시작 행렬: start
, 끝 행렬: start+3
👉 start
, start+1
, start+2
, start+3
)
괄호로 묶었을 때, 각 괄호 안에 들어있는 행렬의 개수가 1개부터 3개까지 될 수 있다.
👉 즉, 괄호 안의 행렬이 최소 1개
에서 최대 term개
가 된다.
괄호 묶음의 모든 경우를 계산하기 위해서
start
부터 start+term
직전까지 증가하는 t
변수를 활용한다.
이 t
변수를 활용해서
괄호로 묶이는 묶음을 기준으로 왼쪽 묶음과 오른쪽 묶음을 나눠 보자.
왼쪽 묶음에 들어가는 행렬의 개수는 최소 1개
부터 최대 term개
까지 하나씩 늘어난다.
dp[start][t]
왼쪽 묶음의 시작 행렬은 start
로 고정된다.
(왼쪽 묶음의 시작 행렬은 항상 동일)
왼쪽 묶음의 끝은 t
가 된다.
(t
는 start
부터 1씩 증가하므로 왼쪽 묶음의 끝은 start
, start+1
, start+2
, start+3
, ..., 계산 중인 범위의 마지막 행렬 - 1
이 된다.)
dp[t+1][start+term]
t
이므로, 오른쪽 묶음의 시작은 t+1
이 된다.start+term
이 된다.두 묶음의 결과를 가지고 두 행렬의 곱셈 연산 횟수를 구할 때와 똑같이 하면 된다.
왼쪽 묶음의 결과
행L
: 왼쪽 묶음 첫 행렬의 행 == arr[start][0]
열L
: 왼쪽 묶음 끝 행렬의 열 == arr[t][1]
오른쪽 묶음의 결과
행R
: 오른쪽 묶음 첫 행렬의 행 == arr[t][1]
열R
: 오른쪽 묶음 끝 행렬의 열 == arr[start+term][1]
행L
X 열L
X 열R
arr[start][0] * arr[t][1] * arr[start+term][1]
1~3을 더한 값이 최소가 되는 값을 dp[start][start+term]
에 저장한다.
dp[start][start+term] = int(1e9) # 지금 계산할 첫행렬과 끝행렬
for t in range(start, start+term):
dp[start][start+term] = min(dp[start][start+term],
# 👇 1 + 2 + 3
dp[start][t]+dp[t+1][start+term] + arr[start][0] * arr[t][1] * arr[start+term][1])
print(dp[0][N-1])
전체 코드
import sys
N = int(input())
arr = [list(map(int, sys.stdin.readline().split())) for _ in range(N)]
dp = [[0]*(N) for _ in range(N)]
for term in range(1, N):
for start in range(N): # 첫행렬 : i, 끝행렬: i+term
if start + term == N: # 범위를 벗어나면 무시
break
dp[start][start+term] = int(1e9) # 지금 계산할 첫행렬과 끝행렬
for t in range(start, start+term):
dp[start][start+term] = min(dp[start][start+term],
# 👇 1 + 2 + 3
dp[start][t]+dp[t+1][start+term] + arr[start][0] * arr[t][1] * arr[start+term][1])
print(dp[0][N-1])