[Baekjoon] 10830번 행렬 제곱.cpp

세동네·2022년 10월 14일
0
post-thumbnail

[백준] 10830번 행렬 제곱 문제 링크

특정 값의 거듭제곱을 빠르게 연산하는 방식을 활용하는 문제이다.

· 효율적인 거듭제곱

an만큼 거듭제곱하여 result에 저장하는 단계는 다음을 반복한다.

  1. n이 2로 나누어 떨어지지 않는다면 resulta를 곱한다.
  2. a를 제곱한다. 이후, n을 2로 나누고, 나머지는 버린다.

이때 2번 단계에서 a를 제곱할 때마다 a에 그대로 저장한다. 즉, a는 단계를 거칠수록 최초 수의 2배승이 저장된다.

a27 = a16 * a8 * a2 * a1

이고, 위 단계를 거칠수록 아래와 같이 변한다.

  1. a = a, n = 27, result = 1

  2. a = a, n = 27, result = a
    a = a2, n = 13, result = a

  3. a = a2, n = 13, result = a2 * a
    a = a4, n = 6, result = a2 * a

  4. a = a4, n = 6, result = a2 * a
    a = a8, n = 3, result = a2 * a

  5. a = a8, n = 3, result = a8 * a2 * a
    a = a16, n = 1, result = a8 * a2 * a

  6. a = a16, n = 1, result = a16 * a8 * a2 * a
    a = a32, n = 0, result = a16 * a8 * a2 * a ===== Stop =====

n이 2로 나누어 떨어지지 않을 때 resulta를 곱해주는 과정에서 원래는 n을 1 빼주는 연산이 포함되지만, 컴퓨터는 정수형 연산에서 버림 처리를 해주기 때문에 자연스럽게 나눗셈에서 1이 덜어지게 되어 따로 연산을 해주진 않았다.

이 방법을 행렬에 적용하면 된다. 주의할 것은 정수에서 곱셈의 항등원인 1은 행렬에선 단위 행렬로 두어야 한다는 것이다.

#include <iostream>
using namespace std;

int n;
long long b;

long result[5][5] = { 0, };
long matrix[5][5] = { 0, };

void multiply_matrix(long mat1[5][5], long mat2[5][5]) {
	long temp[5][5] = { 0, };
	for (int row1 = 0; row1 < n; row1++) {
		for (int col2 = 0; col2 < n; col2++) {
			for (int row2 = 0, col1 = 0; row2 < n; row2++, col1++) {
				temp[row1][col2] += mat1[row1][col1] * mat2[row2][col2];
				temp[row1][col2] %= 1000;
			}
		}
	}
	for (int row = 0; row < n; row++) {
		for (int col = 0; col < n; col++) {
			mat1[row][col] = temp[row][col];
		}
	}
}

void init() {
	for (int row = 0; row < n; row++) {
		result[row][row] = 1;
		matrix[row][row] = 1;
	}
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);

	cin >> n >> b;

	init();

	for (int row = 0; row < n; row++) {
		for (int col = 0; col < n; col++) {
			cin >> matrix[row][col];
		}
	}

	while (b > 0) {
		if (b % 2 == 1) {
			multiply_matrix(result, matrix);
		}
		multiply_matrix(matrix, matrix);
		b /= 2;
	}

	for (int row = 0; row < n; row++) {
		for (int col = 0; col < n; col++) {
			cout << result[row][col] << " ";
		}
		cout << endl;
	}
}

0개의 댓글