Diffusion Model

고영민·2023년 3월 27일
0

논문 및 개념 리뷰

목록 보기
2/4

1. 개요


Figure 1. Diffusion model 개요

기본적인 diffusion model의 구조는 Figure 1과 같이 forward process와 reverse process로 설명될 수 있다. 먼저 forward process는 원본 이미지에 noise를 섞는 과정으로 보통 Gaussian noise가 사용되며, 위의 그림처럼 하나의 step마다 약간씩 노이즈를 추가해가면 결국 많은 step이 지났을 때 노이즈가 너무 많아져서 마치 Gaussian 분포에서 무작위로 픽셀값을 추출한 것만 같은 잡음 이미지가 된다.

다음으로 reverse process가 diffusion model에서 신경망 등을 통해 학습시키고자 하는 과정이며, 이는 forward process와는 반대로 노이즈가 섞여 있는 이미지가 입력되었을 때 해당 노이즈를 제거하는 과정이다. 따라서 forward process에서 몇 번의 step을 통해 최종적으로 완전한 잡음 이미지를 획득한 것처럼, reverse process에서는 잡음 이미지가 들어왔을 때, 몇 번의 step을 통해 완전히 노이즈를 걷어내어 깨끗한 이미지를 얻는 것이다.

이러한 diffusion model은 새로운 데이터를 생성하는 생성 모델에서 주로 사용되는데, 이것은 reverse process를 활용한 것으로, 학습이 잘 되어 있는 diffusion model에 무작위로 생성된 잡음 이미지를 입력하면, 학습 데이터셋에는 포함되어 있지 않은 잡음 이미지더라도 model이 나름대로 알아서 잡음을 걷어내고, 최종적으로는 어떠한 깨끗한 이미지가 생성된다. 어떤 이미지가 생성될 지는 입력되는 잡음 이미지에 따라 결정되며, 최근에는 conditional diffusion model을 통해 이미지를 어떻게 생성할 지 제어하기 위하여 잡음 이미지 뿐만 아니라 추가적인 정보(텍스트, 타겟 이미지 등)을 함께 입력하여 서비스를 제공하곤 한다.

이번 포스트에서는 diffusion model에 대한 식 유도와 함께, conditional diffusion model의 구조 등을 코드 예시와 함께 학습한다.

2. Diffusion model 수식

2.1. Forward Process

Forward process는 깨끗한 이미지에 몇 번의 step에 걸쳐 노이즈을 추가하는 과정이다. 이때, 처음의 깨끗한 데이터를 x0x_0, step을 tt, 최종적인 잡음 이미지를 xTx_T라고 했을 때, Gaussian noise를 추가한다고 가정하면 다음과 같이 쓸 수 있다.

q(x1:Tx0)=t=1Tq(xtxt1),            q(xtxt1)=N(xt;1βtxt1,βtI)q(\mathbf{x}_{1:T}|\mathbf{x}_0) = \prod^T_{t=1}q(\mathbf{x}_t|\mathbf{x}_{t-1}), \;\;\;\;\;\; q(\mathbf{x}_t|\mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t \mathbf{I})

먼저 Markov 과정처럼 어떤 time step에서 노이즈를 추가할 때 바로 이전 step의 이미지에만 영향을 받으며, 이전 step의 값의 일부분만을 Gaussian의 평균값에 반영하여 유지시키고(1βtxt1\sqrt{1-\beta_t}\mathbf{x}_{t-1}), 나머지 부분은 분산을 키워 무작위성을 높인다(βtI\beta_t \mathbf{I}). 이러한 과정을 TT번 반복하면, 왼쪽과 같은 식이 되고, TT가 충분히 크다면, 결국 평균은 0에 가까워지고, 분산은 커져서 완전한 잡음 데이터가 될 것이다.

이후 reverse process에서는 신경망을 통해 노이즈를 걷어내는 과정을 모델링하게되며, 이러한 신경망을 학습시키기 위해서 다양한 이미지에 위와 같은 forward process를 적용하여 time step 별로 노이즈가 섞이기 전 이미지와 섞인 후의 이미지를 수집하고 이를 학습데이터로 활용한다. 다만 성공적인 학습을 위해서는 이러한 데이터가 다량 필요하고, training loop를 한번 돌 때마다 이미지에 forward process를 처음부터 적용하여 데이터를 생성하면 시간이 많이 소요된다(만약 매우 큰 time step에서의 노이즈 섞인 이미지가 필요할 경우 forward process를 t=0부터 t=M까지 수행해야하기 때문). 하지만 수식을 약간 정리하면 임의의 time step에서의 노이즈 섞인 이미지를 바로 sampling할 수 있으며, 이는 학습 속도 개선에 큰 도움을 준다.

αt=1βt\alpha_t = 1- \beta_t, αˉt=i=1tαi\bar{\alpha}_t = \prod^t_{i=1}\alpha_i라고 했을 때, 위와 같이 정의한 forward process에 대해서 임의의 time step의 데이터 xt\mathbf{x}_t에 대해서 다음과 같이 쓸 수 있다.

xt=αtxt1+1αtϵt1                                                  ϵ  is  Gaussian  noiseN(0,I)=αt(αt1xt2+1αt1ϵt2)+1αtϵt1=αtαt1xt2+αtαtαt1ϵt2+1αtϵt1=αtαt1xt2+1αtαt1ϵˉt2                        ϵˉ  means  merging  two  Gaussian==αˉtx0+1αˉtϵ\begin{aligned} \mathbf{x}_t & = \sqrt{\alpha_t}\mathbf{x}_{t-1} + \sqrt{1-\alpha_t}\epsilon_{t-1} \;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\; \rightarrow \;\; \epsilon \; is \; Gaussian \; noise \sim \mathcal{N}(0,\mathbf{I}) \\ & = \sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}\mathbf{x}_{t-2} + \sqrt{1-\alpha_{t-1}}\epsilon_{t-2}) + \sqrt{1-\alpha_t}\epsilon_{t-1} \\ & = \sqrt{\alpha_t \alpha_{t-1}}\mathbf{x}_{t-2} + \sqrt{\alpha_t-\alpha_t\alpha_{t-1}}\epsilon_{t-2} + \sqrt{1-\alpha_t}\epsilon_{t-1} \\ & = \sqrt{\alpha_t \alpha_{t-1}}\mathbf{x}_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\bar{\epsilon}_{t-2} \;\;\;\;\;\;\;\;\;\; \rightarrow \;\; \bar{\epsilon} \; means \; merging \; two \; Gaussian \\ & = \cdots \\ & = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\epsilon \end{aligned}

1번째 줄은 forward process의 정의식을 그대로 따른 것이며, 3번째 줄에서 4번째 줄로 넘어갈 때는 두 Gaussian distribution의 합을 생각하면 된다(두 Gaussian의 합이 만드는 분포의 평균과 분산은 각 분포의 평균의 합, 분산의 합으로 나타난다). 위의 식을 정리하면 다음과 같고, 임의의 time step에서의 노이즈 섞인 이미지를 만들기 위하여 굳이 forward process 전체를 수행하지 않아도 아래의 식을 통해서 바로 sampling 할 수 있음을 의미한다.

q(xtx0)=N(xt;αˉtx0,1αˉtI)q(\mathbf{x}_t|\mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\mathbf{x}_0, \sqrt{1-\bar{\alpha}_t} \mathbf{I})

2.2. Reverse Process

Reverse process는 위의 forward process와는 반대로 입력 데이터의 노이즈를 제거하는 과정이다. 즉, forward process를 표현할 때는 q(xtxt1)q(\mathbf{x}_t|\mathbf{x}_{t-1})을 표현하는 방법을 찾았지만, reverse process에서는 q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_{t})를 찾아야 한다. 이때, forward process야 입력 데이터에 Gaussian noise를 추가하며 진행되니 해당 정보를 이용하여 각각의 과정을 Gaussian distribution으로 표현할 수 있지만, reverse process는 어떤 분포로 나타나는 지 의문이다. Diffusion model의 논문(Sohl-Dickstein, Jascha, et al. "Deep unsupervised learning using nonequilibrium thermodynamics." International conference on machine learning. PMLR, 2015)에서 reverse process를 정의하고 있으며, (Feller, W. "On the theory of stochastic processes, with particular reference to applications." In Proceedings of the Berkeley Symposium on Mathematical Statistics and Probability. The Regents of the University of California, 1949.)에서 forward process가 Gaussian이나 binomial 이고 β\beta가 작을 경우 reverse process 또한 forward process와 같은 형태의 분포로 나타남을 보였다고 한다. 즉 q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_{t}) 또한 Gaussian 분포로 표현될 수 있다.

다만, q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_{t})를 직접 계산하기는 어려워 다음과 같이 reverse process를 모델링하며, reverse process가 Gaussian 분포로 나타난다는 것을 알고 있기 때문에, 신경망을 통해 다음과 같이 Gaussian을 표현하는 평균과 분산을 모델링하도록 한다.

pθ(x0:T)=p(xT)t=1Tpθ(xt1xt),                          pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod^T_{t=1}p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t), \;\;\;\;\;\;\;\;\;\;\;\;\; p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_t, t), \Sigma_\theta(\mathbf{x}_t, t))
p(xT)=N(xt1;0,(I))p(\mathbf{x}_T) = \mathcal{N}(\mathbf{x}_{t-1}; 0, \mathbf(I))

2.3. Loss Function

즉, pθp_\thetaqq와 유사해지도록 학습을 시켜야하고 최종적으로는 reverse process를 통해 원래의 데이터 x0\mathbf{x}_0를 복원해내야 한다. 이를 위한 loss function으로 다음과 같이 cross entropy를 사용하면 두 분포가 유사해지도록 학습시킬 수 있다.

Loss=Eq(x0)[logpθ(x0)]Loss = -\mathbb{E}_{q(\mathbf{x}_0)}[\log{p_\theta(\mathbf{x}_0)}]

하지만 VAE에서와 같은 논리로 loss function 안의 pθ(x0)p_\theta(\mathbf{x}_0)는 직접적으로 계산하기가 어려운데, pθ(x0)=pθ(x0:T)dx1:Tp_\theta(\mathbf{x}_0)=\int p_\theta(\mathbf{x}_{0:T}) d\mathbf{x}_{1:T}에서 이전 step에 해당하는 모든 x1:T\mathbf{x}_{1:T}에 대한 적분이 불가하기 때문이다. 따라서 VAE의 해법처럼 다음과 같이 lower bound(negative log를 사용할 경우 upper bound)를 이용하게 되며, 그 유도 방법에는 여러가지가 있을 수 있지만 여기서는 Jensen's inequality를 통한 방법을 소개한다(VAE 포스트도 참고).

LCE=Eq(x0)logpθ(x0)=Eq(x0)log(pθ(x0:T)dx1:T)=Eq(x0)log(q(x1:Tx0)pθ(x0:T)q(x1:Tx0)dx1:T)=Eq(x0)log(Eq(x1:Tx0)pθ(x0:T)q(x1:Tx0))Eq(x0:T)logpθ(x0:T)q(x1:Tx0)=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]=LVLB\begin{aligned} L_\text{CE} &= - \mathbb{E}_{q(\mathbf{x}_0)} \log p_\theta(\mathbf{x}_0) \\ &= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \int p_\theta(\mathbf{x}_{0:T}) d\mathbf{x}_{1:T} \Big) \\ &= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \int q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} d\mathbf{x}_{1:T} \Big) \\ &= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_0)} \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} \Big) \\ &\leq - \mathbb{E}_{q(\mathbf{x}_{0:T})} \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} \\ &= \mathbb{E}_{q(\mathbf{x}_{0:T})}\Big[\log \frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}{p_\theta(\mathbf{x}_{0:T})} \Big] = L_\text{VLB} \end{aligned}

위 처럼 유도된 negative log의 upper bound를 줄이면 자연스럽게 cross entropy 또한 줄어들게 된다. 이제 해당 bound인 LVLBL_\text{VLB}를 좀 더 정리해 보면 다음과 같다((Sohl-Dickstein, Jascha, et al. "Deep unsupervised learning using nonequilibrium thermodynamics." International conference on machine learning. PMLR, 2015)의 loss function).

LVLB=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]=Eq[logt=1Tq(xtxt1)pθ(xT)t=1Tpθ(xt1xt)]=Eq[logpθ(xT)+t=1Tlogq(xtxt1)pθ(xt1xt)]=Eq[logpθ(xT)+t=2Tlogq(xtxt1)pθ(xt1xt)+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlog(q(xt1xt,x0)pθ(xt1xt)q(xtx0)q(xt1x0))+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+t=2Tlogq(xtx0)q(xt1x0)+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+logq(xTx0)q(x1x0)+logq(x1x0)pθ(x0x1)]=Eq[logq(xTx0)pθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)logpθ(x0x1)]=Eq[DKL(q(xTx0)pθ(xT))LT+t=2TDKL(q(xt1xt,x0)pθ(xt1xt))Lt1logpθ(x0x1)L0]\begin{aligned} L_\text{VLB} &= \mathbb{E}_{q(\mathbf{x}_{0:T})} \Big[ \log\frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})} \Big] \\ &= \mathbb{E}_q \Big[ \log\frac{\prod_{t=1}^T q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{ p_\theta(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t) } \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=1}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \Big( \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)}\cdot \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1}\vert\mathbf{x}_0)} \Big) + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \vert \mathbf{x}_0)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\ &= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{q(\mathbf{x}_1 \vert \mathbf{x}_0)} + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big]\\ &= \mathbb{E}_q \Big[ \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_T)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} - \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1) \Big] \\ &= \mathbb{E}_q [\underbrace{D_\text{KL}(q(\mathbf{x}_T \vert \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_T))}_{L_T} + \sum_{t=2}^T \underbrace{D_\text{KL}(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t))}_{L_{t-1}} \underbrace{- \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)}_{L_0} ] \end{aligned}

따라서 최종적인 형태는 다음과 같이 정리할 수 있다.

LVLB=LT+LT1++L0where LT=DKL(q(xTx0)pθ(xT))Lt=DKL(q(xtxt+1,x0)pθ(xtxt+1)) for 1tT1L0=logpθ(x0x1)\begin{aligned} L_\text{VLB} &= L_T + L_{T-1} + \dots + L_0 \\ \text{where } L_T &= D_\text{KL}(q(\mathbf{x}_T \vert \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_T)) \\ L_t &= D_\text{KL}(q(\mathbf{x}_t \vert \mathbf{x}_{t+1}, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_t \vert\mathbf{x}_{t+1})) \text{ for }1 \leq t \leq T-1 \\ L_0 &= - \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1) \end{aligned}

즉, 최종적인 loss는 Gaussian들의 Kullback-Leibler Divergence(KLD)로 계산될 수 있으며, 여기서 LTL_T의 경우 해당 식에 포함되어 있는 xT\mathbf{x}_T는 완전한 Gaussain noise이기 때문에 상수로 나타난다. 그리고 LtL_t의 경우 qqpθp_\theta 사이의 차이를 줄이기 위한 loss이므로 regularizaion term으로 생각할 수 있고, L0L_0의 경우 최종적으로 생성되는 데이터의 분포에 대한 것으로 reconstruction term으로 생각할 수 있다.

위 최종식에서 나타나는 term들을 실제로 계산할 수 있어야 하는데, 먼저 q(xTx0)q(\mathbf{x}_T \vert \mathbf{x}_0)pθ(xT)p_\theta(\mathbf{x}_T)의 경우 xT\mathbf{x}_T가 Gaussain noise이기 때문에 쉽게 계산될 수 있고, 나머지 pθp_\theta로 표현되는 식들은 신경망의 output에서 바로 계산이 된다. 그리고 q(xtxt+1,x0)q(\mathbf{x}_t \vert \mathbf{x}_{t+1}, \mathbf{x}_0) term은 다음과 같이 계산될 수 있다.

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)exp(12((xtαtxt1)2βt+(xt1αˉt1x0)21αˉt1(xtαˉtx0)21αˉt))=exp(12(xt22αtxtxt1+αtxt12βt+xt122αˉt1x0xt1+αˉt1x021αˉt1(xtαˉtx0)21αˉt))=exp(12((αtβt+11αˉt1)xt12(2αtβtxt+2αˉt11αˉt1x0)xt1+C(xt,x0)))\begin{aligned} q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) &= q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) } \\ &\propto \exp \Big(-\frac{1}{2} \big(\frac{(\mathbf{x}_t - \sqrt{\alpha_t} \mathbf{x}_{t-1})^2}{\beta_t} + \frac{(\mathbf{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\ &= \exp \Big(-\frac{1}{2} \big(\frac{\mathbf{x}_t^2 - 2\sqrt{\alpha_t} \mathbf{x}_t \color{blue}{\mathbf{x}_{t-1}} \color{black}{+ \alpha_t} \color{red}{\mathbf{x}_{t-1}^2} }{\beta_t} + \frac{ \color{red}{\mathbf{x}_{t-1}^2} \color{black}{- 2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0} \color{blue}{\mathbf{x}_{t-1}} \color{black}{+ \bar{\alpha}_{t-1} \mathbf{x}_0^2} }{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\ &= \exp\Big( -\frac{1}{2} \big( \color{red}{(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}})} \mathbf{x}_{t-1}^2 - \color{blue}{(\frac{2\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{2\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0)} \mathbf{x}_{t-1} \color{black}{ + C(\mathbf{x}_t, \mathbf{x}_0) \big) \Big)} \end{aligned}

여기서 다음과 같이 식을 정리할 수 있다. 이때, μ~t(xt,x0)\tilde{\boldsymbol{\mu}}_t (\mathbf{x}_t, \mathbf{x}_0)에 대한 식의 3번째 줄에서 4번째 줄로 넘어갈 때에는 앞의 forward process 파트에서 설명한 임의의 time step에 대한 sampling 식 유도를 활용하여 x0\mathbf{x}_0xt\mathbf{x}_t에 대한 식으로 표현하였다.

β~t=1/(αtβt+11αˉt1)=1/(αtαˉt+βtβt(1αˉt1))=1αˉt11αˉtβtμ~t(xt,x0)=(αtβtxt+αˉt11αˉt1x0)/(αtβt+11αˉt1)=(αtβtxt+αˉt11αˉt1x0)1αˉt11αˉtβt=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉtx0=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉt1αˉt(xt1αˉtϵt)=1αt(xt1αt1αˉtϵt)\begin{aligned} \tilde{\beta}_t &= 1/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}) = 1/(\frac{\alpha_t - \bar{\alpha}_t + \beta_t}{\beta_t(1 - \bar{\alpha}_{t-1})}) = \color{green}{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t} \\ \tilde{\boldsymbol{\mu}}_t (\mathbf{x}_t, \mathbf{x}_0) &= (\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1} }}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0)/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}) \\ &= (\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1} }}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0) \color{green}{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t} \\ &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \mathbf{x}_0\\ &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon}_t) \\ &= \color{cyan}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_t \Big)} \end{aligned}

따라서 최종적으로 q(xt1xt,x0)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)는 다음과 같이 Gaussian으로 표현될 수 있다.

q(xt1xt,x0)=N(xt1;μ~(xt,x0),β~tI)q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \color{blue}{\tilde{\boldsymbol{\mu}}}(\mathbf{x}_t, \mathbf{x}_0), \color{red}{\tilde{\beta}_t} \mathbf{I})

DDPM(Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in neural information processing systems 33 (2020): 6840-6851.) 논문에서는 이러한 Loss function을 정리하여 더욱 사용하기 쉬운 형태로 표현한다.

우선 앞에서도 나왔듯이 reverse process의 각 step 또한 Gaussian으로 표현할 수 있으며, DDPM에서는 Gaussian의 분산 , Σθ(xt,t)\Sigma_\theta(\mathbf{x}_t, t)을 특정 수치의 상수, σt2I\sigma_t^2 I로 설정(reverse process에서는 noise를 제거하여 깨끗한 원본 데이터를 생성하는 것이 목표이기 때문에 무작위성을 의미하는 분산은 크게 중요하지 않음)하고, βt\beta_t 또한 상수(논문에서는 β1=104\beta_1=10^{-4} 에서 βT=0.02\beta_T=0.02까지 선형적으로 증가)로 두어 식을 간소화 한다. 특히 σt2=βt\sigma_t^2=\beta_t로 설정할 때나 σt2=βt~=1αˉt11αˉtβt\sigma_t^2= \tilde{\beta_t} = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_t로 둘 때나 실험적으로 성능차이가 별로 없었다고 한다.

위의 loss 식에서 LTL_Tθ\theta로 결정되는 부분이 없고, βt\beta_t 또한 학습되는 파라미터가 아니기 때문에 학습 loss에서 무시 가능하다.

LtL_t는 두 분포의 KLD로 표현되고 있는데 reverse process의 각 step 또한 Gaussian으로 표현할 수 있으므로 아래와 같이 두 Gaussian의 KLD 식으로 간단하게 표현이 가능하다.

Lt=Eq[12Σθ(xt,t)22μ~t(xt,x0)μθ(xt,t)2]L_t = \mathbb{E}_{q} \Big[\frac{1}{2 \| \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) \|^2_2} \| \color{blue}{\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)} - \color{green}{\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)} \|^2 \Big]

여기서 μ~t(xt,x0)\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)는 위에서 보았듯이 1αt(xt1αt1αˉtϵt)\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_t \Big)로 표현될 수 있다. DDPM에서는 이 과정에서 reparameterize를 수행하여 식을 간소화하는데, loss를 줄이기 위해서는 μθ(xt,t)\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)μ~t(xt,x0)=1αt(xt1αt1αˉtϵt)\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) = \frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_t \Big)에 가까워져야하고, 해당 식에서 xt,αˉt\mathbf{x}_t, \bar{\alpha}_t 등은 데이터 입력 시에 주어지는 값이다. 따라서 학습으로 추론되어야 하는 값은 노이즈 term인 ϵt\boldsymbol{\epsilon}_t로 결정되며, 네트워크로 직접 μθ(xt,t)\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)를 추론하기보다 ϵt\boldsymbol{\epsilon}_t를 대신 추론한다(직접 μθ(xt,t)\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)을 추론해도 학습이 가능하지만 논문에서는 둘 다 실험을 해보았을 때 reparameterize를 수행하는 것이 더 성능이 좋았다고 한다. 왜 그런지는 자세히 설명하지는 않았는데, 평균값은 약간 편향된 분포가 될 수 있어서 그런수도...?). 따라서 다음과 같이 loss를 쓸 수 있다.

Lt=Eq[12Σθ(xt,t)22μ~t(xt,x0)μθ(xt,t)2]=Ex0,ϵ[12Σθ221αt(xt1αt1αˉtϵt)1αt(xt1αt1αˉtϵθ(xt,t))2]=Ex0,ϵ[(1αt)22αt(1αˉt)Σθ22ϵtϵθ(xt,t)2]=Ex0,ϵ[(1αt)22αt(1αˉt)Σθ22ϵtϵθ(αˉtx0+1αˉtϵt,t)2]\begin{aligned} L_t &= \mathbb{E}_{q} \Big[\frac{1}{2 \| \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) \|^2_2} \| \color{blue}{\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)} - \color{green}{\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)} \|^2 \Big] \\ &= \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}} \Big[\frac{1}{2 \|\boldsymbol{\Sigma}_\theta \|^2_2} \| \color{blue}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_t \Big)} - \color{green}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\boldsymbol{\epsilon}}_\theta(\mathbf{x}_t, t) \Big)} \|^2 \Big] \\ &= \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}} \Big[\frac{ (1 - \alpha_t)^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \boldsymbol{\Sigma}_\theta \|^2_2} \|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2 \Big] \\ &= \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}} \Big[\frac{ (1 - \alpha_t)^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \boldsymbol{\Sigma}_\theta \|^2_2} \|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon}_t, t)\|^2 \Big] \end{aligned}

앞의 weight term까지 제거하면 최종적인 DDPM의 loss funcion을 다음과 같이 구할 수 있다.

Ltsimple=Et[1,T],x0,ϵt[ϵtϵθ(xt,t)2]=Et[1,T],x0,ϵt[ϵtϵθ(αˉtx0+1αˉtϵt,t)2]\begin{aligned} L_t^\text{simple} &= \mathbb{E}_{t \sim [1, T], \mathbf{x}_0, \boldsymbol{\epsilon}_t} \Big[\|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2 \Big] \\ &= \mathbb{E}_{t \sim [1, T], \mathbf{x}_0, \boldsymbol{\epsilon}_t} \Big[\|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon}_t, t)\|^2 \Big] \end{aligned}

t=1t=1일 때, 위 loss 식은 L0L_0에 대응되는데, t=1t=1이면 loss 식은 q(x_0|x_1)과 p(x_0|x_1)사이의 KDL를 줄이기 위한 loss가 되고, 이는 최초의 loss에서 L0L_0부분이 Eq[LT+Lt+L0]=Eq[+p(x0x1)]\mathbb{E}_q[L_T+L_t+L_0]=\mathbb{E}_q[\cdots+p(x_0|x_1)]가 두 분포의 KDL을 표현하는 것과 일치한다.

참고자료

  1. https://www.youtube.com/watch?v=uFoGaIVHfoE&t=3289s
  2. https://lilianweng.github.io/posts/2021-07-11-diffusion-models
  3. https://greeksharifa.github.io/generative%20model/2020/07/31/Variational-AutoEncoder/
  4. https://process-mining.tistory.com/182
  5. https://ffighting.net/deep-learning-paper-review/diffusion-model/diffusion-model-basic/
  6. https://xoft.tistory.com/32
  7. https://kimjy99.github.io/%EB%85%BC%EB%AC%B8%EB%A6%AC%EB%B7%B0/sr3/

0개의 댓글