[KeraTorch 개발일지] 2. 자동 미분과 역전파

yeho.dev·2023년 7월 26일
0

KeraTorch 개발일지

목록 보기
2/2
post-thumbnail

🔄 자동 미분 (auto-grad)

KeraTorch는 자동 미분을 지원한다. 자동 미분이란 함수형 프로그래밍 언어와 같이 어떤 값을 연산할 때 연산 결과뿐만 아니라 연산 과정을 추적하여 자동으로 미분을 해주는 것을 말한다. 이는 딥러닝의 오차 역전파 (backpropagation)에서 필수적이다.

본 프로젝트에서 자동 미분을 위해 구현한 GradArray 클래스와 Grad 클래스에 대해 자세히 살펴보기 전에 역전파 알고리즘이 왜 필요하고 어떤 방식으로 작동하는지 간단히 알아보자.

⏪ 역전파 (backpropagation)

1. 모델 및 오차 정의

필자는 딥러닝을 가능하게 만든 가장 중요한 알고리즘은 역전파라고 생각한다. 오차 역전파는 Geoffrey Hinton이 크게 기여한 알고리즘으로, 신경망의 오차가 계산되었을 때 그 값에 각 파라미터가 기여하는 정도를 미분으로 계산하는 방식을 말한다.

예를 들어 아래와 같은 수식으로 연산되는 퍼셉트론 (perceptron)을 생각해보자.

z(x)=wTx+b=i=1nwixi+by^(x)=σ(z(x))z(\mathbf x)=\mathbf w^T \mathbf x+b=\sum_{i=1}^{n}{w_ix_i}+b \\ \hat y(\mathbf x)=\sigma(z(\mathbf x))

위 모델은 입력 xRn\mathbf x \in \R^n에 대해 스칼라 출력 y^R\hat y \in \R을 만든다. 모델에 비선형성을 주기 위해 비선형 활성화함수 σ\sigma를 적용한다.

우리는 이제 모델을 데이터셋 {(x1,y1),(xN,yN)}\left \{(\mathbf x_1, y_1),\cdots (\mathbf x_N,y_N)\right\}에 fitting 하고자 한다. 이때, 모델이 만든 출력 y^\hat y와 실제 정답 yy의 차이의 제곱을 오차로 정의할 수 있을 것이다. 여러 개의 데이터가 학습에 사용되므로 전체 데이터에 대한 오차는 각 데이터에 대한 오차의 평균이다.

L=1Ni=1N(y^iyi)2wherey^i=y^(xi)L=\frac{1}{N} \sum_{i=1}^{N}{\left( \hat y_i-y_i \right)^2} \quad \text{where} \quad \hat y_i=\hat y(\mathbf x_i)

nn은 입력의 차원 (n=dimxn=\dim \mathbf x), NN은 batch의 크기이다.

2. 경사하강법과 오차 역전파

우리는 모델을 LL을 줄이는 방향으로 업데이트 하기 위해 경사하강법 (gradient descent) 알고리즘을 이용한다. 경사하강법은 아래의 수식을 이용하여 파라미터를 업데이트한다.

wwαLw\mathbf w \leftarrow \mathbf w-\alpha \frac{\partial L}{\partial \mathbf w}

위 식에서 오차 LL에 대한 파라미터 w\mathbf w에 대한 편미분은 w\mathbf w에 대해 LL이 증가하는 정도를 나타낸다. 즉, LLRn\R^n에 그리는 면의 w\mathbf w 축에 대한 기울기이다. 경사하강법은 이 방향의 음의 방향으로 파라미터를 조절한다.

이때 역전파 알고리즘이 이용된다. 연쇄 법칙 (chain rule)을 적용한 아래 식을 살펴보자. 간단하게 표현하기 위해 두 출력은 vector 형태로 작성하겠다.

{y^[y^1y^N]y[y1yN]\begin{cases} \hat \mathbf y \coloneqq \begin{bmatrix}\hat y_1 & \cdots & \hat y_N\end{bmatrix}^\top\\ \mathbf y \coloneqq \begin{bmatrix} y_1 & \cdots & y_N\end{bmatrix}^\top \end{cases}
Lw=Ly^y^w=Ly^y^zzw\frac{\partial L}{\partial \mathbf w}=\frac{\partial L}{\partial \hat \mathbf y}\frac{\partial \hat \mathbf y}{\partial \mathbf w}=\frac{\partial L}{\partial \hat \mathbf y}\frac{\partial \hat \mathbf y}{\partial z}\frac{\partial z}{\partial \mathbf w}

LL에 대한 wiw_i의 편미분을 직접 계산하기 복잡하기 때문에 미분의 연쇄 법칙을 이용하여 다른 편미분의 곱으로 계산한다. 우선 첫 번째 항을 계산해보자.

L=1Ni=1N(y^iyi)2=1Ni=1N(y^i22y^iyi+yi2)=1N(y^y^+yy2y^y)\begin{aligned} L&=\frac{1}{N} \sum_{i=1}^{N}{\left( \hat y_i-y_i \right)^2}\\ &=\frac{1}{N}{\sum_{i=1}^{N}{\left( \hat y_i^2-2\hat y_iy_i+y_i^2 \right)}}\\ &=\frac{1}{N}\left( \hat \mathbf y^\top \hat \mathbf y + \mathbf y^\top \mathbf y - 2\hat \mathbf y^\top \mathbf y\right) \end{aligned}
Ly^=1NLy^(y^y^+yy2y^y)=2N(y^y)\begin{aligned} \frac{\partial L}{\partial \hat \mathbf y}&= \frac{1}{N}\frac{\partial L}{\partial \hat \mathbf y}\left( \hat \mathbf y^\top \hat \mathbf y + \mathbf y^\top \mathbf y - 2\hat \mathbf y^\top \mathbf y\right)\\ &=\frac 2 N \left( \hat \mathbf y-\mathbf y \right)^\top \end{aligned}

두 번째 항 또한 계산해보자.

y^z=[y^1zy^Nz]=[σ(z(x1))zσ(z(xN))z]=[σ(z(x1))σ(z(xN))]\begin{aligned} \frac{\partial \hat \mathbf y}{\partial z}&=\begin{bmatrix} \frac{\partial \hat y_1}{\partial z} & \cdots & \frac{\partial \hat y_N}{\partial z} \end{bmatrix}^\top\\ &=\begin{bmatrix} \frac{\partial \sigma(z(\mathbf x_1))}{\partial z} & \cdots & \frac{\partial \sigma(z(\mathbf x_N))}{\partial z} \end{bmatrix}^\top\\ &=\begin{bmatrix} {\sigma'(z(\mathbf x_1))} & \cdots & {\sigma'(z(\mathbf x_N))} \end{bmatrix}^\top \end{aligned}

세 번째 항은 아래와 같이 계산된다. 간단한 선형 결합이므로 미분 또한 상수로 표현된다.

zw=x\frac{\partial z}{\partial \mathbf w}=\mathbf x^\top

최종적으로 파라미터 업데이트를 위해 계산해야 했던 미분식은 아래와 같이 표현된다.

Lw=2N(y^y)[σ(z(x1))σ(z(xN))]x\frac{\partial L}{\partial \mathbf w}=\frac 2 N \left( \hat \mathbf y-\mathbf y \right)^\top \begin{bmatrix} \sigma'(z(\mathbf x_1))\\ \vdots\\ \sigma'(z(\mathbf x_N)) \end{bmatrix} \mathbf x^\top

이로써 위 미분값을 이용해 파라미터 w\mathbf w를 업데이트 할 수 있게 되었다.

❗️ 역전파 알고리즘의 중요성

하나의 층으로 구성된 퍼셉트론의 예시에서는 오차 역전파가 강력한 도구로 보이지 않지만, 위와 같은 역전파 매커니즘이 깊은 네트워크에도 동일하게 적용된다는 점에서 딥러닝에 역전파 알고리즘은 필수적이다.

🏠 GradArray 클래스

GradArray 클래스는 PyTorch의 torch.Tensor와 유사한 역할을 하는 클래스로, 실제 array 값과 값이 연산되기 위해 사용된 연산자 정보를 저장한다.

GradArray가 지니는 중요한 속성을 몇 가지 살펴보겠다.

  • GradArray._array: numpy.ndarray 타입으로, 실제 array 값을 저장한다.

  • GradArray._grad: numpy.ndarray 타입으로, 역전파가 실행되었을 때 연산된 미분 값을 저장한다.

  • GradArray._grad_op: Grad 타입으로, 해당 GradArray가 생성되기 위해 수행된 연산 정보를 저장한다. 연산되지 않고 값만 할당된 GradArray의 경우에는 None으로 설정된다.

위 3가지 속성으로 GradArray는 역전파를 지원한다. 자세한 역전파 과정과 연산은 다음 포스트에서 살펴보겠다.

⛰️ Grad 클래스

역전파를 지원하기 위해 연산 정보를 저장하는 추상 클래스이다. 연산의 정보를 저장하기 위해 연산에 사용된 입력 GradArray를 저장하며, 역전파 과정에서 미분값이 입력 GradArray로 전달된다.

Grad 추상 클래스는 자체로는 아무런 연산 정보를 갖지 않기 때문에 Grad.backward()는 추장 메소드로 선언되어 있다. Grad 추상 클래스를 상속한 AddGrad, TransposeGrad, MatMulGrad, PowerGrad 등의 클래스는 backward()를 적절히 정의함으로써 역전파 알고리즘에서 호출될 수 있도록 구현한다.

🔥 KeraTorch는 OOP를 적극적으로 수용한다!

✈️ 결론 및 다음 포스트

이번 포스트에서는 자동 미분 및 오차 역전파와 GradArray, Grad 클래스의 대략적인 구조와 역할에 알아봤다. 다음 포스트에서는 KeraTorch의 핵심이라고 볼 수 있는 GradArrayGrad 클래스의 작동 방식과 연산을 세부적으로 알아보겠다.

2개의 댓글

comment-user-thumbnail
2023년 7월 26일

공감하며 읽었습니다. 좋은 글 감사드립니다.

1개의 답글