Generative adversarial networks

Generative adversarial networks, ACM, 2020

GAN은 2014년, Ian Goodfellow의 "Generative Adversarial Network"라는 논문에서 처음 제시됨
적대적인 과정을 통해 생성 모델을 추정하기 위한 새로운 프레임워크를 제안

  • generative model(생성자, G): 학습 데이터의 분포를 capture
  • discriminator model(판별자, D): 학습 데이터인지 생성된 데이터인지 판별
    • 두 명의 플레이어가 min-max game을 하는 것과 같다고 생각하면 됨

Introduction

  • 딥러닝은 discriminative model(고차원, 풍부한 feature의 입력을 class label에 mapping)에서 두드러진 성공을 보여줌
  • backpropagation, dropout 등의 다양한 알고리즘과 각 layer의 gradient를 network 전체에 잘 전달할 수 있게 도와주는 activation function을 기반으로 딥러닝은 우수한 성능을 보여줌

하지만 MLE(Maximum likehood estimation)과 같은 전략에서 나오는 확률론적인 계산의 어려움으로 인해 generative model은 intractable problem이 존재

  • intractable problem: 보통 O(N2)O(N^2)이상의 time complexity를 가지면 intractable(다루기 어렵다)하다고 말함

  • generative model은 샘플링 된 데이터가 어느 분포에서 나왔는지 추정하는 것이 목적
    • MLE는 확률 밀도 함수를 모델링하는 방법 중 하나
  • 결국 실제 데이터(빨간 점)들의 확률 분포(초록 분포)를 알고 있을 때, generative model(검정 분포)로 실제 데이터의 분포를 근사하는 방향으로 학습

  • discriminative model은 실제 데이터와 생성된 데이터를 분류
  • discriminative model은 경찰, generative model은 위조지폐를 만드는 사람에 비유하면 쉬움
    • 이런 게임으로 인한 두 model간의 경쟁은 위조 데이터와 진짜 데이터를 구분이 불가능한 방향으로 학습하게 지도가능

수식설명
xxPdata(x)P_{data}(x)부터 샘플링한 학습 데이터
zz$P_z(z)로부터 샘플링한 데이터(noise)
Pdata(x)P_{data}(x)실제 데이터의 분포
Pz(z)P_z(z)노이즈의 분포
G(z)G(z)generative model: 노이즈 zz로 생성한 데이터
D(x)D(x)discriminator model: 입력 데이터 x가 학습 데이터일 확률 [0,1]

  • Generator:
    if D(G(z))==1:D(G(z)) == 1:
    • generator는 loss가 최소화 되는 방향으로 학습을 진행(데이터를 잘 생성하도록)
    • noise zz로부터 생성한 데이터 G(z)G(z)가 discriminator를 완벽하게 속임
    • (discriminator는 생성된 데이터가 실제 데이터라고 100% 확신)
      • log(1D(G(z)))=log(11)=log(0)=log(1-D(G(z))) = log(1-1) = log(0) = - \infty

  • Discriminator:
    if D(G(z))==0,D(x)==1:D(G(z)) == 0, D(x) == 1:
    • noise zz로부터 생성한 데이터 G(z)G(z)와 실제 데이터 xx를 discriminator가 완벽하게 구별 가능
    • (discriminator는 생성된 데이터가 가짜 데이터라고 100% 확신)
    • log(D(x))+log(1D(G(z)))=log(1)+log(10)=2log(1)log(D(x)) + log(1-D(G(z))) = log(1) + log(1-0) = 2log(1)

Architecture

  • 실제 동작과정은 다음과 같음

    • Generator의 input으로 latent vector zz를 넣어 가짜 이미지를 생성하고, 실제 이미지와 generator로 생성된 가짜 이미지를 discriminator가 비교(판별)하는 구조
    • discriminator를 통해 추출된 결과(loss)를 가지고 generator와 discriminator가 역전파를 통해 network 학습
  • generative model과 discriminative model을 모두 충분히 학습하고 나면 pdata=p(g)p_{data} = p_(g)가 되는 지점에 도달 (global optimality of pg=pdatap_g = p_{data} )

    • discriminative model은 실제 데이터와 가짜 데이터를 구분할 수 없게 됨 (D(x)=12D(x) = \frac{1}{2})

Adversarial network

  • 매 iter에서 discriminator를 최적화 시키고 generator를 업데이트 시키는 것은 불가능하며, 유한한 크기의 데이터 셋 x를 가졌을 대 discriminator가 overfitting 될 수 있음
    • 따라서 k번 만큼 discriminator를 업데이트하고, 한번의 generator를 업데이트하는 방식으로 진행

  1. 학습 초기: generator는 실제 데이터와 큰 차이를 보이기 때문에 discriminator가 높은 확률로 정답을 맞춤(빠르게 학습됨)
    • (위 그림): log(1D(G(z)))log(1-D(G(z)))를 minimize 하는 수식을 log(D(G(Z)))log(D(G(Z)))를 maximize 하는 방식으로 변경

장점

  • Markov chain이 필요하지 않고 역전파만 사용하여 학습할 수 있음
  • 학습 중에 추론이 필요하지 않고 미분 가능한 다양한 함수를 모델에 사용가능

단점

  • pg(x)p_{g}(x)를 명시적으로 표현할 수 없음
  • 학습 중 discriminator가 generator와 잘 동기화되어야 함
  • generator가 충분한 다양성을 갖기 위해 많은 zz 값을 동일한 xx 값으로 축소하는 model collapse 현상이 발생할 수 있음
    • 따라서 discriminator를 업데이트하지 않고 generator를 너무 많이 훈련하는 것은 안됨
profile
python, Artificial intelligence

0개의 댓글

Powered by GraphCDN, the GraphQL CDN