DDPM 이해해보기 - 1

BSH·2024년 1월 18일
0

논문 리뷰

목록 보기
7/8

Personalization 관련 논문을 읽으면서 기본 아키텍쳐와 Diffusion 기본 개념에 대해 무지한게 느껴져 기초가 되는 DDPM을 읽고 이해해보기로 했습니다.

공부한 내용은 하단의 링크에 모아두었고 이해한 내용을 바탕으로 흐름을 정리해보았습니다.

본격적으로 들어가기 전에 사전지식입니다.

  • Variational Inference
  • Markov Chain
  • Bayes rule
  • KL divergence
  • VAE

Diffusion process는 2015 ICML의 unsupervised learning using nonequilibrium thermodynamics 논문이 시초입니다. 연기가 퍼지는 과정을 아주 짧은 시간동안 관찰하면 그 연기는 가우시안 분포를 따라 퍼집니다. 그리고 그 반대의 과정도 가우시안 분포라고 할 수 있습니다. 이미지에 noise를 추가하고 제거하는 것도 이와 동일한 과정으로 보고 diffusion model이 나오게 되었습니다.

Model Architecture

시작하기에 앞서 Model Architecture를 먼저 살펴보겠습니다.

Unet을 사용하며 각 layer에 time step t를 추가적으로 넣어주는 것을 볼 수 있습니다.
timestep t를 저렇게 넣는 이유는 Diffusion 과정에서 각 XtX_{t}에서 Xt1X_{t-1}로 넘어가는 과정마다 모델이 다르기 때문에 추가적인 timestep 정보가 없으면 DDPM의 총 timestep인 1000개나 되는 모델을 직접 만들어야합니다. 이는 사실상 불가능하기 때문에 하나의 모델에 timestep을 넣어주어 t단계의 이미지를 예측하도록 모델을 설계하였습니다.

Forward Process

forward 프로세스는 가우시안(정확히 conditional gaussian)이므로 아래와 같이 표현가능합니다. 식은 N(0,I)\mathcal{N}(0, I)을 맞추기위해 저런 형태를 가지게 되었다고 합니다.

q(XtXt1)=N(Xt;μt1,Σt1)=N(Xt;1βtXt1,βtI)\begin{aligned} q(X_{t}|X_{t-1}) &= \mathcal{N}(X_{t};\mu_{t-1},\,\Sigma_{t-1})\\ &= \mathcal{N}(X_{t};\sqrt{1-\beta_{t}}X_{t-1},\,\beta_{t}I) \end{aligned}

아래식을 보면 timestep이 크다는 가정하에 Var(Xt1)=1,Var(ϵ)=1Var(X_{t-1})=1,\,Var(\epsilon)=1 이므로 분산이 1이 됨을 알 수 있습니다.

Var(Xt)=Var(1βtXt1+βtϵ)=(1βt)Var(Xt1)+βtVar(ϵ)=1βt+βt=1\begin{aligned} Var(X_{t})&=Var(\sqrt{1-\beta_{t}}X_{t-1}+\sqrt{\beta_{t}}\epsilon)\\ &=(1-\beta_{t})Var(X_{t-1})+\beta_{t}Var(\epsilon)\\ &=1-\beta_{t}+\beta_{t}\\ &=1 \end{aligned}

N(Xt;1βtXt1,βtI)\mathcal{N}(X_{t};\sqrt{1-\beta_{t}}X_{t-1},\,\beta_{t}I)에서 backpropagation을 위한 reparameterization trick을 통해 XtX_{t}에서의 아래의 샘플링 식을 구할 수 있습니다.

Xt=1βtXt1+βtϵX_{t}=\sqrt{1-\beta_{t}}X_{t-1}+\sqrt{\beta_{t}}\epsilon

timestep t일때의 이미지를 한번씩 더해가면서 구하는게 아니라 한번만에 구할 수 있습니다.

αt=1βt,at~=i=1Tαiq(XtX0)=N(Xt;at~X0,(1at~)I)\alpha_{t}=1-\beta_{t},\,\tilde{a_{t}}=\prod_{i=1}^T\alpha_{i}\\ q(X_{t}|X_{0})=\mathcal{N}(X_{t};\sqrt{\tilde{a_{t}}}X_{0},\,(1-\tilde{a_{t}})I)
xt=1βtxt1+βtϵt1ϵt1N(0,I)=αtxt1+1αtϵt1=αt(αt1xt2+1αt1ϵt2)+1αtϵt1ϵt2N(0,I)=αtαt1xt2+αt(1αt1)ϵt2+1αtϵt1\begin{aligned} x_{t}&=\sqrt{1-\beta_{t}}x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t-1}\quad\epsilon_{t-1} \sim \mathcal{N}(0, I)\\ &=\sqrt{\alpha_{t}}x_{t-1}+\sqrt{1-\alpha_{t}}\epsilon_{t-1}\\ &=\sqrt{\alpha_{t}}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2})+\sqrt{1-\alpha_{t}}\epsilon_{t-1}\quad\epsilon_{t-2} \sim \mathcal{N}(0, I)\\ &=\sqrt{\alpha_{t}\alpha_{t-1}}x_{t-2}+\sqrt{\alpha_{t}(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_{t}}\epsilon_{t-1} \end{aligned}

αt(1αt1)ϵt2+1αtϵt1N(0,(1αtαt1)I)\sqrt{\alpha_{t}(1-\alpha_{t-1})}\epsilon_{t-2}+\sqrt{1-\alpha_{t}}\epsilon_{t-1}\sim \mathcal{N}(0,\,(1-\alpha_{t}\alpha_{t-1})I)이기 때문에 이를 위 식에 적용하면 아래와 같이 증명이 됩니다.

=αtαt1xt2+1αtαt1ϵt2ϵt2N(0,I)=αtαt1αt2xt3+1αtαt1αt2ϵt3ϵt3N(0,I)=...=αt~x0+1αt~ϵ0ϵ0N(0,I)\begin{aligned} &=\sqrt{\alpha_{t}\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t}\alpha_{t-1}}\epsilon'_{t-2}\quad \epsilon'_{t-2} \sim \mathcal{N}(0, I)\\ &=\sqrt{\alpha_{t}\alpha_{t-1}\alpha_{t-2}}x_{t-3}+\sqrt{1-\alpha_{t}\alpha_{t-1}\alpha_{t-2}}\epsilon'_{t-3}\quad \epsilon'_{t-3} \sim \mathcal{N}(0, I)\\ &=...\\ &=\sqrt{\tilde{\alpha_{t}}}x_{0}+\sqrt{1-\tilde{\alpha_{t}}}\epsilon'_{0} \quad \epsilon'_{0} \sim \mathcal{N}(0, I) \end{aligned}

N(Xt;at~X0,(1at~)I)\mathcal{N}(X_{t};\sqrt{\tilde{a_{t}}}X_{0},\,(1-\tilde{a_{t}})I)에서의 샘플링은 아래와 같습니다.

Xt=αt~X0+(1αt~)ϵX_{t}=\sqrt{\tilde{\alpha_{t}}}X_{0}+\sqrt{(1-\tilde{\alpha_{t}})}\epsilon

그럼 이제 여기까지 X0X_{0}를 통해 time step t일 때의 노이즈 이미지인 XtX_{t}를 한번에 구하는 방법을 알게 되었고 forward과정에서 β\beta는 하이퍼파라미터(학습가능하기도 하다)로 설정했기 때문에 learnable variable은 없다는 것을 보았습니다.


Reverse Process

Forward Process의 역과정인 Reverse Process는 Forward 과정과 유사합니다. DDPM의 핵심이기도 한 이 과정은 PθP_{\theta}가 noise를 추가하는 qq를 보고 noise를 걷어내는 과정을 학습하게 됩니다.

Pθ(Xt1Xt)=N(Xt1;μθ(Xt,t),Σθ(Xt,t))P_{\theta}(X_{t-1}|X_{t})=\mathcal{N}(X_{t-1};\mu_{\theta}(X_{t},\,t),\,\Sigma_{\theta}(X_{t},\,t))

DDPM에서 분산은 time step t에 의존하는 변수이므로 아래와 같이 표현됩니다.

Pθ(Xt1Xt)=N(Xt1;μθ(Xt,t),σt2I)P_{\theta}(X_{t-1}|X_{t})=\mathcal{N}(X_{t-1};\mu_{\theta}(X_{t},\,t),\,\sigma_{t}^{2}I)

이제 이 과정에서 학습해야할 건 μθ\mu_{\theta}입니다. 당장 구하는 방법은 알 수가 없으니 Loss Term으로 넘어가 보겠습니다.


DDPM Loss

VAE와 DDPM은 결국 NLL을 통해서 구하는데 그 전개 방식도 유사합니다. 그럼 먼저 VAE먼저 유도하면서 DDPM으로 넘어가겠습니다. 기존 VAE에서는 logPθlogP_{\theta}로 시작하는데 DDPM loss까지 가기 위해 Negative Log Likelihood로 시작합니다.

VAE loss

Ezq(zx)[logPθ(x)]=logPθ(x)q(zx)dzDefinitionofE=logPθ(x,z)Pθ(zx)q(zx)dz=(logPθ(x,z)Pθ(zx)q(zx)q(zx))q(zx)dz=logPθ(x,z)q(zx)q(zx)dz+logq(zx)Pθ(zx)q(zx)dz=ELBOKL(q(zx)Pθ(zx))KL0logPθ(x,z)q(zx)q(zx)dz=logPθ(xz)Pθ(z)q(zx)q(zx)dz\begin{aligned} \mathbb{E_{z\sim q(z|x)}}[-logP_{\theta}(x)]&=\int{-logP_{\theta}(x)}\cdot q(z|x)dz\quad\because Definition\,of\, \mathbb{E}\\ &=\int{-log{{P_{\theta}(x, z)}\over P_{\theta}(z|x)}} q(z|x)dz\\ &=\int(-log{{{P_{\theta}(x, z)}\over{P_{\theta}(z|x)}}\cdot{{{q(z|x)}\over{q(z|x)}}}})\cdot q(z|x)dz\\ &=\int{-log{{P_{\theta}(x, z)}\over{q(z|x)}}\cdot q(z|x)}dz+\int{-log{{q(z|x)}\over{P_{\theta}(z|x)}}}\cdot q(z|x)dz\\ &=-ELBO-KL(q(z|x)||P_{\theta}(z|x))\quad\because KL \ge 0\\ &\leq \int{-log{{P_{\theta}(x, z)}\over{q(z|x)}}\cdot q(z|x)}dz\\ &=\int{-log{{{P_{\theta}(x|z)P_{\theta}(z)}\over{q(z|x)}}}\cdot q(z|x)dz} \end{aligned}

여기서 Pθ(z)P_{\theta}(z)q(zx)q(z|x)를 묶어서 전개하면 더 들어가면 VAE loss가 나오게 됩니다.
DDPM은 Pθ(xz)P_{\theta}(x|z)q(zx)q(z|x)를 묶어서 전개합니다.

VAE to DDPM

VAE에서는 latent가 z이지만 DDPM에서는 xTx_{T}이기 때문에 이를 바꾸고 전개하겠습니다.

logPθ(xz)Pθ(z)q(zx)q(zx)dz=logPθ(xz)q(zx)q(zx)dz+log(Pθ(z))q(zx)dz=ExTq(xTx0)[logPθ(x0xT)q(xTx0)]+ExTq(xTx0)[logPθ(xT)]=KL(q(xTx0)Pθ(x0xT))+ExTq(xTx0)[logPθ(xT)]\int{-log{{{P_{\theta}(x|z)P_{\theta}(z)}\over{q(z|x)}}}\cdot q(z|x)dz}\\ =\int{-log{{{P_{\theta}(x|z)}\over{q(z|x)}}}\cdot q(z|x)dz} + \int{-log(P_{\theta}(z))\cdot q(z|x)dz}\\ =\mathbb{E_{x_{T}\sim q(x_{T}|x_{0})}}[-log{{{P_{\theta}(x_{0}|x_{T})}\over{q(x_{T}|x_{0})}}}] + \mathbb{E_{x_{T}\sim q(x_{T}|x_{0})}}[-logP_{\theta}(x_{T})]\\ =-KL(q(x_{T}|x_{0})||P_{\theta}(x_{0}|x_{T})) + \mathbb{E_{x_{T}\sim q(x_{T}|x_{0})}}[-logP_{\theta}(x_{T})]

처음에 언급한 "PθP_{\theta}가 noise를 추가하는 qq를 보고 noise를 걷어내는 과정을 학습하게 된다"라는 말이 가장 마지막의 q(xTx0)q(x_{T}|x_{0})Pθ(x0xT)P_{\theta}(x_{0}|x_{T})의 KL Term을 통해 알 수 있습니다.

길이 너무 길면 읽기 힘드니 여기까지 하고 나머지는 다음 포스팅에서 하도록 하겠습니다.


참고 링크

영상

블로그

논문

profile
컴공생

0개의 댓글