크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.
입력
첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000)
둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.
출력
첫째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력한다.
우선, 행렬을 입력받아 그 둘을 제곱하는 코드는 다음과 같을 것이다. 참고로, 거듭제곱이 아니라 한 번 제곱하는 수식이다.
import sys
input = sys.stdin.readline
N, B = map(int, input().split())
A = [ list(map(int, input().split())) for _ in range(N) ]
arr = [ [0 for _ in range(N)] for _ in range(N) ]
for r in range(N):
for c in range(N):
sum = 0
for i in range(N):
sum += A[r][i] * A[i][c]
arr[r][c] = sum%1000
print(arr)
A[r][i] * A[i][c]
부분은 즉 matrix1의 한 행과 matrix2의 한 열을 곱하는 것이다.
위를 def화 시키면 다음과 같은 것이다. (길이가 같은 다른 두 행렬을 곱한다는 전제다.)
def mul(matrix1, matrix2):
n = len(matrix1)
arr = [ [0 for _ in range(N)] for _ in range(N) ]
for r in range(n):
for c in range(n):
sum = 0
for i in range(n):
sum += matrix1[r][i] * matrix2[i][c]
arr[r][c] = sum%1000
return arr
이제 우리가 만든 이 함수 mul()로 거듭제곱을 할 것이다. 거듭제곱은 분할정복으로 푸는 아이디어를 그대로 사용한다. 다음은 우리가 만들 행렬 거듭제곱 함수 square()의 초기 세팅이다.
def square(a, b):
if b==1:
for x in range(len(a)):
for y in range(len(a)):
a[x][y] %= 1000
return a
tmp = square(a, b//2)
tmp = mul(tmp, tmp)
.
.
.
.
함수는 행렬 a와 정수의 숫자 b를 인자로 받고 있다. if b==1:
의 코드를 보면, 주어진 제곱의 횟수 B가 1일 때는 원소마다 직접 1000으로 나누는 연산을 해준 뒤 바로 연산처리된 행렬 a를 프린트한다.
그 외의 경우들은 모두 변수 tmp를 선언하는데 이는 인자 b를 2분할한 채 호출된 함수 square이다.
이어 더 적어 보자.
.
.
tmp = square(a, b//2)
tmp = mul(tmp, tmp)
if b%2==0:
return tmp
else:
return mul(tmp, a)
result = square(A, B)
for r in result:
print(*r)
if b%2:
즉 B가 홀수인 경우, 인자 b를 2분할 것에 나머지 1회를 더 제곱해줘야 하므로 mul(tmp)에 a를 한 번 더 mul 해준다.
다음의 코드는 최종적으로 정답이다.
import sys
input = sys.stdin.readline
N, B = map(int, input().split())
A = [ list(map(int, input().split())) for _ in range(N) ]
def mul(matrix1, matrix2):
n = len(matrix1)
arr = [ [0 for _ in range(N)] for _ in range(N) ]
for r in range(n):
for c in range(n):
sum = 0
for i in range(n):
sum += matrix1[r][i] * matrix2[i][c]
arr[r][c] = sum%1000
return arr
def square(a, b):
if b==1:
for x in range(len(a)):
for y in range(len(a)):
a[x][y] %= 1000
return a
tmp = square(a, b//2)
tmp = mul(tmp, tmp)
if b%2==0:
return tmp
else:
return mul(tmp, a)
result = square(A, B)
for r in result:
print(*r)