GAN의 학습과정

gogowonji·2022년 3월 3일
0

인공지능

목록 보기
1/1

GAN의 학습과정 - Notion 링크

import torch
import torch.nn as nn

# 알아야 할 것
# x -> Discriminator -> D(x)
# z -> Generator -> G(z) -> Discriminator -> D(G(z))


# Discriminator 구분자 - 진짜를 진짜로 인식하도록, 가짜를 가짜로 인식하도록
D = nn.Sequential(

nn.Linear(784,128),
# 28 * 28 = 784(input = x) -> 128(hidden)

# 활성화함수
nn.ReLU(),
 
nn.Linear(128,1),
# 128(hidden) -> 1(output = D(x))

# 활성화함수 - sigmoid와 유사
nn.Tanh()

)

# Generator 생성자 - 가짜 이미지 생성
G = nn.Sequential(

nn.Linear(100,128),
# 100(input = z) -> 128(hidden)

# 활성화함수
nn.ReLU(),

nn.Linear(128,784),
# 128(hidden) -> 784(output = G(z) = x)

# 활성화함수
nn.Tanh()

)

# LossFunction - Binary Cross Entropy
# LossFunction : 목적 함수가 최소화되도록 학습 수행
# -ylog(h(x)) - (1-y)log(1-h(x))
criterion = nn.BCELoss()

# Optimizer 최적화 함수 - Adam
# lr 학습률 = 0.01

# D
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.01)
# G
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.01)


# 가정
# x be real images of shape (batch_size,784)
# z be random noise of shape (batch_size,100)

# train
while True:

	# D
	loss = criterion(D(x),1) + criterion(D(G(x)),0)
	# D(x) 1에 가까워지도록 (진짜 이미지 학습)
	# D(G(x)) 0에 가까워지도록 (가짜 이미지 학습)
	
	# Backprop
	loss.backward()

	d_optimizer.step()
	
	# G
	loss = criterion(D(G(z)),1)
	# D(G(z)) 1에 가까워지도록 (D를 속이도록)

	# Backprop
	loss.backward()

	g_optimizer.step()

# Alert
# Backprop시,
# G.parameters update O
# D.parameters update X
profile
개발자를 할까 말까

0개의 댓글