[Algorithm] 연쇄 행렬 곱셈

건도리 ·2022년 11월 3일
0

Algorithm

목록 보기
1/5

목차

  1. Chained Matrix Multiplication (연쇄 행렬 곱셈)

Chained Matrix Multiplication

A, B, C … 와 같은 여러 행렬이 있을 때, 행렬들의 곱셈 순서에 우선 순위를 다르게 부여함에 따라 요구되는 연산의 수가 다르다.

예를 들어, 우리에게 A, B, C, D 총 네개의 행렬이 있다고 가정하자.

((AB)C)D = (A(BC))D = (AB)(CD) = A((BC)D) = A(B(CD))

괄호를 사용하여 우선순위를 부여함으로써 행렬의 곱셈 연산 수를 최소화 할 수 있다.

예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

연쇄 행렬 곱셈을 브루트 포스를 이용해 모든 경우를 찾게 되면 시간 복잡도는 2^n이 된다.
따라서 동적 계획법을 이용한 연쇄 행렬 곱셈 알고리즘이 훨씬 효율적이다.

연쇄 행렬 곱셈의 구현 핵심은 부분 수열을 이용하는 것이다.

전체 행렬을 2개의 부분 수열로 분리한다

각 부분 수열에 있어, 최소 비용을 구한 후 합쳐준다

분리할 수 있을 때 까지 부분 수열의 길이를 늘려주면서 이 과정을 반복한다

점화식을 구해보면 다음과 같다

M[i,j] = M[i,k] + M[k+1,j] + Mat[i][0] * mat[k][1] * mat[j][1]

이해가 안간다면 바로 위에 예제를 참고해보자.

(AB)C를 구하는 연산은 M[0][1] + M[2][2] + mat[0][0] * mat[1][0] * mat[2][1] 이다

왜 저렇게 표현하죠?

(AB)를 묶었기 때문에 우리는 행렬을 두 부분 행렬로 나눌 수 있다.
	1. AB
	2. C

따라서 M[0][1] = AB, M[2][2] = C 라고 생각하자. 

그리고 mat[0][0] * mat[1][0] * mat[2][1] 은 5 x 2 x 6 이라고 보면 된다. 
(AB를 연산하게 되면 5x3x2 가 되어 5x2 행렬이 된다. 이를 5x2 행렬과 2x6 행렬로 다시 표현한 것)

코드

import java.util.*;

public class Chained_Matrix_Multiplication {
    public static int[][] mat, M, P;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        System.out.println("How many matrix?");
        int N = sc.nextInt();
        mat = new int[N][2];
        M = new int[N][N];
        P = new int[N][N];
        for (int i = 0; i < N; i++) {
            mat[i][0] = sc.nextInt();
            mat[i][1] = sc.nextInt();
        }

        minmat();

        System.out.print("출력할 순서의 인덱스를 차례대로 입력하시오 (i, j) >> ");
        int i = sc.nextInt();
        int j = sc.nextInt();
        order(i-1,j-1);
    }


    public static void minmat() {
        int i, j, k, diagonal;
        int len = mat.length;

        for (i = 0; i < len; i++) {
            M[i][i] = 0;
        }

        for (diagonal = 0; diagonal < len; diagonal++) {
            for (i = 0; i < len - diagonal; i++) {
                j = i + diagonal;
                for (k = i; k < j; k++) {
                    if (M[i][j] != 0) {
                        if(M[i][j]>M[i][k] + M[k + 1][j] + mat[i][0] * mat[k][1] * mat[j][1]){
                            M[i][j] = M[i][k] + M[k + 1][j] + mat[i][0] * mat[k][1] * mat[j][1];
                            P[i][j] = k+1;
                        }
                    } else {
                        M[i][j] = M[i][k] + M[k + 1][j] + mat[i][0] * mat[k][1] * mat[j][1];
                        P[i][j] = k+1;
                    }
                }
            }
        }

        for (i = 0; i < M.length; i++) {
            for (j = 0; j < M.length; j++) {
                System.out.print(M[i][j] + " ");
            }
            System.out.println();
        }

        for (i = 0; i < M.length; i++) {
            for (j = 0; j < M.length; j++) {
                System.out.print(P[i][j] + " ");
            }
            System.out.println();
        }
    }

    public static void order(int i, int j) {
        if (i == j) System.out.print("A" + (i+1));
        else {
            int k = P[i][j];
            System.out.print("(");
            order(i, k-1);
            order(k, j);
            System.out.print(")");
        }
    }
}
profile
배움이 즐거워요 ! 함께 그 즐거움을 나눴으면 좋겠습니다 :)

0개의 댓글