GAN의 학습과정 - Notion 링크
import torch
import torch.nn as nn
# 알아야 할 것
# x -> Discriminator -> D(x)
# z -> Generator -> G(z) -> Discriminator -> D(G(z))
# Discriminator 구분자 - 진짜를 진짜로 인식하도록, 가짜를 가짜로 인식하도록
D = nn.Sequential(
nn.Linear(784,128),
# 28 * 28 = 784(input = x) -> 128(hidden)
# 활성화함수
nn.ReLU(),
nn.Linear(128,1),
# 128(hidden) -> 1(output = D(x))
# 활성화함수 - sigmoid와 유사
nn.Tanh()
)
# Generator 생성자 - 가짜 이미지 생성
G = nn.Sequential(
nn.Linear(100,128),
# 100(input = z) -> 128(hidden)
# 활성화함수
nn.ReLU(),
nn.Linear(128,784),
# 128(hidden) -> 784(output = G(z) = x)
# 활성화함수
nn.Tanh()
)
# LossFunction - Binary Cross Entropy
# LossFunction : 목적 함수가 최소화되도록 학습 수행
# -ylog(h(x)) - (1-y)log(1-h(x))
criterion = nn.BCELoss()
# Optimizer 최적화 함수 - Adam
# lr 학습률 = 0.01
# D
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.01)
# G
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.01)
# 가정
# x be real images of shape (batch_size,784)
# z be random noise of shape (batch_size,100)
# train
while True:
# D
loss = criterion(D(x),1) + criterion(D(G(x)),0)
# D(x) 1에 가까워지도록 (진짜 이미지 학습)
# D(G(x)) 0에 가까워지도록 (가짜 이미지 학습)
# Backprop
loss.backward()
d_optimizer.step()
# G
loss = criterion(D(G(z)),1)
# D(G(z)) 1에 가까워지도록 (D를 속이도록)
# Backprop
loss.backward()
g_optimizer.step()
# Alert
# Backprop시,
# G.parameters update O
# D.parameters update X