부스트캠프 week1 RNN 추가학습 D2L

Dae Hee Lee·2021년 8월 8일
0

BoostCamp

목록 보기
9/22

BPTT(Backpropagation Through Time)

Dive into Deep Learning 교재의 BPTT부분을 발췌하여 학습하고자 한다.

BPTT는 교재 8단원 7절에 해당한다. 1~6절 앞부분을 미리 읽고 와도 도움이 될 듯하다. 이번 글에서는 RNN에서 시퀀스 모델에 대한 역전파 알고리즘의 디테일에 대해 알아보고 어떻게 수학이 활용되었는지 알아보겠다.

가장 먼저, 시퀀스 모델에서 Gradient가 어떻게 계산되는지 리뷰해보겠다. MLP의 역전파에서도 그랬듯, Gradient의 계산에는 Chain Rule을 사용한다. 이전 글에서 역전파에 대해 다뤘던 적이 있었으니 참고해도 좋다.

BPTT는 RNN의 역전파 방법들 중 하나로, 모델 변수와 파라미터의 의존성을 얻기 위해 한 번에 한 단계씩 RNN의 Computational Graph를 확장시켜야 한다. 그리고나서 ChainRule에 기반하여 Gradient를 계산하고 저장할 수 있도록 역전파를 진행하는 것이다.

Analysis of Gradients in RNNs

단순화된 RNN 모델을 이용해서 시작해보자. 이 모델은 Hidden State가 어떻게 업데이트되는지에 대한 디테일은 어느정도 무시하는 모델이다. 또, 아래 서술된 수식들은 scalars,vectors, matrices를 명시적으로 구별하는 것들은 아니다. 그런 것들은 분석에서 크게 중요하지 않으며 표기를 더욱 혼란스럽게만 할 뿐이다.

hth_{t}는 hidden state, xtx_{t}는 input, oto_{t}는 output at time step tt로 하자. 이와 같이, 은닉층과 출력층의 가중치로써 whw_{h}wow_{o}를 사용하겠다. 결과적으로 time step에 따른 Hidden state와 출력은 다음과 같이 나타낼 수 있다.

ht=f(xt,ht1,wh),ot=g(ht,wo),\begin{aligned} h_{t} &=f\left(x_{t}, h_{t-1}, w_{h}\right), \\ o_{t} &=g\left(h_{t}, w_{o}\right), \end{aligned}

ff,gg는 각각 은닉층과 출력층의 변환을 의미한다.

이 때, 우리는 Desired label값 y에 대해서 모든 T time step의 목적함수 LL을 다음과 같이 작성할 수 있다.

L(x1,,xT,y1,,yT,wh,wo)=1Tt=1Tl(yt,ot)L\left(x_{1}, \ldots, x_{T}, y_{1}, \ldots, y_{T}, w_{h}, w_{o}\right)=\frac{1}{T} \sum_{t=1}^{T} l\left(y_{t}, o_{t}\right)

역전파에서, 특히 LLwhw_h에 대한 gradients를 계산할 때는 조금 더 Trickier하다. 자세히 표현하자면, Chain Rule에 의해
Lwh=1Tt=1Tl(yt,ot)wh=1Tt=1Tl(yt,ot)otg(ht,wo)hthtwh\begin{aligned} \frac{\partial L}{\partial w_{h}} &=\frac{1}{T} \sum_{t=1}^{T} \frac{\partial l\left(y_{t}, o_{t}\right)}{\partial w_{h}} \\ &=\frac{1}{T} \sum_{t=1}^{T} \frac{\partial l\left(y_{t}, o_{t}\right)}{\partial o_{t}} \frac{\partial g\left(h_{t}, w_{o}\right)}{\partial h_{t}} \frac{\partial h_{t}}{\partial w_{h}} \end{aligned}
로 표현되고, 첫 번째와 두 번째 factor들은 쉽게 계산 가능하다. 세 번째 factor인 ht/wh\partial h_{t} / \partial w_{h}whw_hhth_t의 효과를 순환적으로 계산하기에, 조금 까다롭다.

처음 정의에 의해 hth_tht1h_{t-1}whw_h에 의존한다. 또한 ht1h_{t-1}whw_h에 의존하고 있기에 Chain Rule에 의해 다음과 같이 나타낼 수 있다.
htwh=f(xt,ht1,wh)wh+f(xt,ht1,wh)ht1ht1wh\frac{\partial h_{t}}{\partial w_{h}}=\frac{\partial f\left(x_{t}, h_{t-1}, w_{h}\right)}{\partial w_{h}}+\frac{\partial f\left(x_{t}, h_{t-1}, w_{h}\right)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_{h}}

위의 Gradients는 세 개의 시퀀스를 가지고 있는데, {at},{bt},{ct}\left\{a_{t}\right\},\left\{b_{t}\right\},\left\{c_{t}\right\} 다음을 만족한다. a0=0a_{0}=0 and at=bt+ctat1a_{t}=b_{t}+c_{t} a_{t-1} for t=1,2,t=1,2, \ldots for t1t \geq 1, 는 다음과 같이 나타낼 수 있다.

at=bt+i=1t1(j=i+1tcj)bia_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} c_{j}\right) b_{i}

at,bta_{t}, b_{t}, and ctc_{t} 다음과 같이 대체할 수 있다.

at=htwhbt=f(xt,ht1,wh)whct=f(xt,ht1,wh)ht1\begin{aligned} &a_{t}=\frac{\partial h_{t}}{\partial w_{h}} \\ &b_{t}=\frac{\partial f\left(x_{t}, h_{t-1}, w_{h}\right)}{\partial w_{h}} \\ &c_{t}=\frac{\partial f\left(x_{t}, h_{t-1}, w_{h}\right)}{\partial h_{t-1}} \end{aligned}

결과적으로 다음과 같은 결과값을 얻을 수 있다.
htwh=f(xt,ht1,wh)wh+i=1t1(j=i+1tf(xj,hj1,wh)hj1)f(xi,hi1,wh)wh\frac{\partial h_{t}}{\partial w_{h}}=\frac{\partial f\left(x_{t}, h_{t-1}, w_{h}\right)}{\partial w_{h}}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f\left(x_{j}, h_{j-1}, w_{h}\right)}{\partial h_{j-1}}\right) \frac{\partial f\left(x_{i}, h_{i-1}, w_{h}\right)}{\partial w_{h}}

profile
Today is the day

0개의 댓글