[PyTorch] PyTorch를 활용한 DCGAN 구현

double-oh·2021년 1월 30일
0

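PyTorch를 이용해 DCGAN(Deep Convolutional GAN)을 구현해본다. 아래 코드는 DATA_ROOT, IMAGE_SIZE, BATCH_SIZE, N_WORKERS, N_GPU, LATENT_VECTOR_SIZE, GENERATOR_FEATURE_MAP_SIZE, DISCRIMINATOR_FEATURE_MAP_SIZE, TRAIN_IMAGE_CHANNEL_SIZE, LR, BETA1, N_EPOCHS 같은 하이퍼파라미터 상수가 미리 정의되어 있다고 가정한다.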

필요 모듈 import

import random
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
#%matplotlib inline

데이터 로드

# Fix Python's and PyTorch's random seeds so that repeated runs draw the same
# random numbers (weight initialisation, data shuffling, latent noise).
MANUAL_SEED = 1
random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

# Load the training images. ImageFolder expects DATA_ROOT to contain one
# sub-directory per class; the class labels it returns are not used by the GAN.
# Each image is resized and centre-cropped to IMAGE_SIZE x IMAGE_SIZE,
# converted to a [0, 1] tensor and then normalised to [-1, 1] so it matches the
# output range of the generator's final Tanh.
dataset = datasets.ImageFolder(
    root=DATA_ROOT,
    transform=transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=N_WORKERS)

# Use the GPU only when CUDA is available AND at least one GPU was requested.
device = torch.device("cuda" if (torch.cuda.is_available() and N_GPU > 0) else "cpu")
# Report the device that will actually be used.
print("device: ", device)

# Preview up to 64 real training images from one batch as a single grid.
# The batch stays on the CPU: plotting needs it there anyway.
real_batch, _ = next(iter(dataloader))
image = vutils.make_grid(real_batch[:64], padding=2, normalize=True).permute(1, 2, 0)
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(image.numpy())
plt.show()

가중치 초기화

def init_weight(m):
    """Re-initialise one layer's parameters in place using the DCGAN scheme.

    Meant to be passed to ``nn.Module.apply`` so it is called on every
    sub-module of a network.

    * Modules whose class name contains ``"Conv"`` (e.g. ``Conv2d``,
      ``ConvTranspose2d``) get their weight drawn from N(0, 0.02); any bias
      they have is left untouched.
    * Modules whose class name contains ``"BatchNorm"`` get their weight drawn
      from N(1, 0.02) and their bias set to 0.
    * All other modules are left unchanged.

    Parameters that do not exist (e.g. ``BatchNorm2d(..., affine=False)`` has
    ``weight`` and ``bias`` set to ``None``) are skipped instead of raising.

    Args:
        m: the ``nn.Module`` currently being visited.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

생성자

class Generator(nn.Module):
    """DCGAN generator: upsamples a latent tensor into an image.

    The architecture is read from module-level constants at construction time:
    ``LATENT_VECTOR_SIZE`` (input channels), ``GENERATOR_FEATURE_MAP_SIZE``
    (base channel width) and ``TRAIN_IMAGE_CHANNEL_SIZE`` (output channels).

    For a 1x1 spatial input the stages produce 4x4, 8x8, 16x16, 32x32 and
    finally 64x64 maps; the last layer is a Tanh, so outputs lie in [-1, 1].

    Args:
        ngpu: number of GPUs; kept as ``self.ngpu`` but not used by ``forward``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        base = GENERATOR_FEATURE_MAP_SIZE
        # Channel width of each hidden stage, shrinking by 2x per stage.
        stage_channels = [base * 8, base * 4, base * 2, base]

        # Stage 1: kernel 4, stride 1, padding 0 turns a 1x1 input into 4x4.
        layers = [
            nn.ConvTranspose2d(LATENT_VECTOR_SIZE, stage_channels[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(stage_channels[0]),
            nn.ReLU(True),
        ]
        # Stages 2-4: kernel 4, stride 2, padding 1 doubles height and width.
        for in_ch, out_ch in zip(stage_channels, stage_channels[1:]):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Output stage: one more 2x upsampling into image channels, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(base, TRAIN_IMAGE_CHANNEL_SIZE, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, data):
        """Run the latent batch ``data`` through the generator stack."""
        return self.main(data)

# Instantiate the generator directly on the selected device.
generator = Generator(N_GPU).to(device)

# Replicate the model across GPUs 0..N_GPU-1 when several GPUs are in use.
if device.type == "cuda" and N_GPU > 1:
    generator = nn.DataParallel(generator, device_ids=list(range(N_GPU)))

# apply() visits every sub-module recursively and calls init_weight on each.
generator.apply(init_weight)

구별자 (판별자, Discriminator)

class Discriminator(nn.Module):
    """DCGAN discriminator: scores images with a real/fake probability.

    The architecture is read from module-level constants at construction time:
    ``TRAIN_IMAGE_CHANNEL_SIZE`` (input channels) and
    ``DISCRIMINATOR_FEATURE_MAP_SIZE`` (base channel width).

    Four stride-2 convolutions halve the spatial size each time (64x64 -> 4x4
    for a 64x64 input), then a 4x4 convolution with no padding reduces that to
    a single value per image, squashed to [0, 1] by a Sigmoid.

    Args:
        ngpu: number of GPUs; kept as ``self.ngpu`` but not used by ``forward``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        base = DISCRIMINATOR_FEATURE_MAP_SIZE
        # Channel width of each downsampling stage, growing by 2x per stage.
        stage_channels = [base, base * 2, base * 4, base * 8]

        # Stage 1 has no BatchNorm: the convolution feeds LeakyReLU directly.
        layers = [
            nn.Conv2d(TRAIN_IMAGE_CHANNEL_SIZE, stage_channels[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Stages 2-4: convolution -> BatchNorm -> LeakyReLU, each halving H and W.
        for in_ch, out_ch in zip(stage_channels, stage_channels[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Output stage: collapse to one channel and map to a probability.
        layers += [
            nn.Conv2d(stage_channels[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, data):
        """Run the image batch ``data`` through the discriminator stack."""
        return self.main(data)

# Instantiate the discriminator directly on the selected device.
discriminator = Discriminator(N_GPU).to(device)

# Replicate the model across GPUs 0..N_GPU-1 when several GPUs are in use.
if device.type == "cuda" and N_GPU > 1:
    discriminator = nn.DataParallel(discriminator, device_ids=list(range(N_GPU)))

# apply() visits every sub-module recursively and calls init_weight on each.
discriminator.apply(init_weight)

비용함수, 라벨, 옵티마이저 정의

# The discriminator ends in a Sigmoid, so its outputs are probabilities;
# binary cross-entropy is used for both the discriminator and generator losses.
criterion = nn.BCELoss()

# The same 64 latent vectors are reused after every epoch so the images they
# produce can be compared over time to follow the generator's progress.
fixed_noise = torch.randn(64, LATENT_VECTOR_SIZE, 1, 1, device=device)

# BCE targets: 1.0 marks a real sample, 0.0 a generated one.
REAL_LABEL, FAKE_LABEL = 1.0, 0.0

# One Adam optimizer per network, with the same learning rate and betas.
optimizerD = optim.Adam(discriminator.parameters(), betas=(BETA1, 0.999), lr=LR)
optimizerG = optim.Adam(generator.parameters(), betas=(BETA1, 0.999), lr=LR)

학습

# Adversarial training loop. Each iteration first updates the discriminator on
# a real batch and a freshly generated fake batch, then updates the generator
# so that the discriminator classifies its fakes as real.
for epoch in range(N_EPOCHS):
    for data in tqdm(dataloader):
        # ---- (1) Discriminator step: real batch -------------------------
        discriminator.zero_grad()
        # data[1] (the ImageFolder class label) is not used.
        real = data[0].to(device)
        batch_size = real.size(0)
        label = torch.full((batch_size,), REAL_LABEL, dtype=torch.float, device=device)
        output_real = discriminator(real).view(-1)
        loss_real_d = criterion(output_real, label)
        loss_real_d.backward()

        # ---- (1) Discriminator step: fake batch -------------------------
        # Generate the fakes once; the same tensor is reused in the generator
        # step below. Passing fake.detach() here keeps the discriminator loss
        # from back-propagating into the generator.
        noise = torch.randn(batch_size, LATENT_VECTOR_SIZE, 1, 1, device=device)
        fake = generator(noise)
        label = torch.full((batch_size,), FAKE_LABEL, dtype=torch.float, device=device)
        output_fake = discriminator(fake.detach()).view(-1)
        loss_fake_d = criterion(output_fake, label)
        loss_fake_d.backward()
        # Gradients from the real and fake passes have accumulated; apply both.
        optimizerD.step()

        # ---- (2) Generator step -----------------------------------------
        # Score the same fakes with the just-updated discriminator, using the
        # real label as target so the generator is rewarded for fooling it.
        # Gradients this leaves on the discriminator are cleared by
        # discriminator.zero_grad() at the start of the next iteration.
        generator.zero_grad()
        output = discriminator(fake).view(-1)
        label.fill_(REAL_LABEL)
        loss_g = criterion(output, label)
        loss_g.backward()
        optimizerG.step()

    # Visualise progress on the same fixed noise after every epoch.
    # no_grad() avoids building an autograd graph for this inference-only pass.
    with torch.no_grad():
        fake = generator(fixed_noise).cpu()
    image = vutils.make_grid(fake[:64], padding=2, normalize=True).permute(1, 2, 0)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title(f"Generated Images (epoch {epoch + 1}/{N_EPOCHS})")
    plt.imshow(image.numpy())
    plt.show()

학습 중 생성된 이미지 결과


profile
Yes, Code Wins Arguments!!

0개의 댓글