[BOJ 16467] - 병아리의 변신은 무죄 (DP, 수학, C++, Python)

보양쿠·2024년 4월 13일
0

BOJ

목록 보기
243/252

BOJ 16467 - 병아리의 변신은 무죄 링크
(2024.04.13 기준 G1)

문제

00일째엔 병아리가 11마리 있다. 병아리는 매일 혼자서 알을 하나 낳는다. 병아리가 낳은 알들은 KK일 후에 병아리로 태어난다. 병아리는 죽지 않고 쭉 병아리로 산다.

KK, NN이 주어질 때, NN일째의 병아리의 수를 100000007100\,000\,007로 나눈 나머지를 출력

알고리즘

행렬 DP

풀이

ii일째의 병아리의 수는 다음과 같은 점화식을 만족한다. dp(i)=dp(i1)+dp(i(K+1))dp(i) = dp(i-1)+dp(i-(K+1)).
병아리는 죽지 않아서 전날의 병아리 수가 그대로 넘어오며, K+1K+1일 전에 모든 병아리의 알들이 깨어나기 때문이다.

NN이 너무 크기 때문에, 1차원 선형 점화식은 행렬 곱으로 나타낼 수 있음을 이용해서 풀어야 한다. M=K+1M = K+1이라고 한다면, BOJ 17272 - 리그 오브 레전설 (Large) 풀이와 완벽하게 동일하다. 참고하자.

K=0K = 0이면, 병아리의 수는 22배씩 늘어난다. 결국 2N2^N % MOD를 구하는 문제가 된다. 이는 빠른 거듭제곱을 이용해서 구하자.

코드

  • C++
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef vector<vector<ll>> matrix;

const ll MOD = 100'000'007;

// 행렬 곱
matrix mul(matrix &A, matrix &B){
    // 행렬 A와 행렬 B의 곱은 A의 열과 B의 행 수가 같아야 한다.
    // 또한 곱의 결과는 A의 행의 수와 B의 열의 수를 크기로 갖는다.
    int n = A.size(), m = B[0].size(), l = A[0].size();
    matrix res(n, vector<ll>(m));
    for (int i = 0; i < n; i++) for (int j = 0; j < m; j++) for (int k = 0; k < l; k++)
        res[i][j] = (res[i][j] + A[i][k] * B[k][j]) % MOD;
    return res;
}

// 빠른 거듭제곱
matrix fpow(matrix mat, ll k){
    int M = mat.size();
    matrix res(M, vector<ll>(M));
    for (int i = 0; i < M; i++) res[i][i] = 1;
    while (k){
        if (k & 1) res = mul(res, mat);
        mat = mul(mat, mat);
        k >>= 1;
    }
    return res;
}

// x^n
ll fpow_(ll x, ll n){
    ll res = 1;
    while (n){
        if (n & 1) res = res * x % MOD;
        x = x * x % MOD;
        n >>= 1;
    }
    return res;
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    /* dp(i) = dp(i-1) + dp(i-(K+1))

    K = 0일 때 단순하게 2의 N제곱을 출력하면 된다.
    이는 빠른 거듭제곱을 이용하자.

    K > 1, M = K+1일 때
    dp(N) = dp(N-1) + dp(N-M)
    dp(N-1) = dp(N-1)
    ...
    dp(N-M+1) = dp(N-M+1)
    위와 같은 M개의 항을 행렬로 표현하고 정리를 하면
    행렬의 N제곱의 0행 0열 원소임을 알 수 있다. */

    // M은 2 이상 11 이하가 될 수 있다. 이에 따라 행렬을 준비하자.
    matrix mat[12];
    for (int i = 2; i <= 11; i++){
        mat[i].resize(i, vector<ll>(i));
        mat[i][0][0] = mat[i][0][i - 1] = 1;
        for (int j = 1; j < i; j++) mat[i][j][j - 1] = 1;
    }

    int T; cin >> T;
    for (int K, N, M; T; T--){
        cin >> K >> N;
        if (K == 0) cout << fpow_(2, N) << '\n';
        else{
            M = K + 1;
            cout << fpow(mat[M], N)[0][0] << '\n';
        }
    }
}
  • Python
import sys; input = sys.stdin.readline
MOD = 100000007

def mul(A, B):
    # 행렬 A와 행렬 B의 곱은 A의 열과 B의 행 수가 같아야 한다.
    # 또한 곱의 결과는 A의 행의 수와 B의 열의 수를 크기로 갖는다.
    n = len(A); m = len(B[0]); l = len(A[0])
    res = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for k in range(l):
                res[i][j] += A[i][k] * B[k][j]
                res[i][j] %= MOD
    return res

# 빠른 거듭제곱
def fpow(mat, k):
    res = [[0] * M for _ in range(M)]
    for i in range(M):
        res[i][i] = 1
    while k:
        if k & 1:
            res = mul(res, mat)
        mat = mul(mat, mat)
        k >>= 1
    return res

# x^n
def fpow_(x, n):
    res = 1
    while n:
        if n & 1:
            res = res * x % MOD
        x = x * x % MOD
        n >>= 1
    return res

''' dp(i) = dp(i-1) + dp(i-(K+1))

K = 0일 때 단순하게 2의 N제곱을 출력하면 된다.
이는 빠른 거듭제곱을 이용하자.

K > 1, M = K+1일 때
dp(N) = dp(N-1) + dp(N-M)
dp(N-1) = dp(N-1)
...
dp(N-M+1) = dp(N-M+1)
위와 같은 M개의 항을 행렬로 표현하고 정리를 하면
행렬의 N제곱의 0행 0열 원소임을 알 수 있다. '''

# M은 2 이상 11 이하가 될 수 있다. 이에 따라 행렬을 준비하자.
mat = [[[0] * i for _ in range(i)] for i in range(12)]
for i in range(2, 12):
    mat[i][0][0] = mat[i][0][i - 1] = 1
    for j in range(1, i):
        mat[i][j][j - 1] = 1

for _ in range(int(input())):
    K, N = map(int, input().split())
    if K == 0:
        print(fpow_(2, N))
    else:
        M = K + 1
        print(fpow(mat[M], N)[0][0])
profile
GNU 16 statistics & computer science

0개의 댓글