생성적 적대 신경망(GANs: Generative Adversarial Networks)

김우빈·2022년 6월 12일
1

JUST BUILD DEEP_LEARNING

목록 보기
9/9
post-thumbnail

생성적 적대 신경망(GANs: Generative Adversarial Networks)이란?

<이미지 출저: https://guimperarnau.com/blog/2017/03/Fantastic-GANs-and-where-to-find-them >

2014년 이안 굿펠로우(Ian Goodfellow)에의해 제시된 생성적 적대 신경망(GANs: Generative Adversarial Networks)모델은 쉽게 설명하면 생성을 하는 모델입니다. 생성적 적대 신경망(GANs: Generative Adversarial Networks)연구의 초기에는 그럴듯하게 랜덤한 이미지를 생성하는것 부터 시작했습니다.
기계가 생성을 한다니.. 생각해보면 굉장히 어려운 개념이라고 생각 하실 수도 있으나, 사실 CNN, RNN같은 신경망의 이론보다 더 쉽습니다.

생성을 하는 모델?

위에서 제가 생성을 하는 모델이라고 말씀 드렸습니다.
그렇다면 과연 생성적 적대 신경망(GANs: Generative Adversarial Networks)은 무엇을 생성할까요?

일반적인 머신러닝이나 딥러닝 모델이 생성하는 것은, 클래스에 대한 예측값이나, 분류 정도가 될것입니다.
생각해보면 어떠한 "형태"를 만들어 내는것은 아닙니다.
생성적 적대 신경망(GANs: Generative Adversarial Networks)은 "데이터의 형태를 만들어 내고자 하는데 목적이 있습니다.
이미지를 예를 들어 설명 드리면, 픽셀의 분포에 따라, 강아지의 코,눈, 귀등의 모양을 인식 할것입니다. 이것은 사실 사진의 전체적인 채도와 색감은 크게 상관이 없습니다.

결론적으로 분포를 만들어 낸다는것은, 단순히 결과값을 도출하는 함수를 만드는것을 넘어서, 실제적인 형태를 갖춘 데이터를 만드는것입니다.
다시 풀어보면, 생성적 적대 신경망(GANs: Generative Adversarial Networks)은 어떠한 분포나 분산 자체를 만들어 내는 신경망 이라고 정의 할 수 있겠습니다.

왜 적대적일까?

'적대적' 이란 단어의 의미는 뭘까요?
사실 '적대적'이라는 단어에 생성적 적대 신경망(GANs: Generative Adversarial Networks)의 핵심 아이디어가 담겨있습니다.
생성적 적대 신경망(GANs: Generative Adversarial Networks)을 설명할때는 지폐위조범과 경찰로 비유를 많이 합니다.

<이미지출처 : https://files.slack.com/files-pri/T25783BPY-F9SHTP6F9/picture2.png?pub_secret=6821873e68 >

위의 이미지 처럼 지폐위조범은 Generator(생성자), 경찰에겐 Discriminator(판별자)의 역할을 부여합니다.
위조범은 경찰의 단속을 피하기위해, 더더욱 정교한 가짜 지폐를 만들어 낼것이고, 경찰은 더 정교한 기법으로 판별해내는 방법을 개발해 냅니다.
이처럼 각각의 역할을 하는 두개의 모델을 통해 진짜같은 가짜를 생산해 내는 능력을 올려주는것이 생성적 적대 신경망(GANs: Generative Adversarial Networks)의 핵심 아이디어이고, 그래서 '적대적'이라는 단어를 사용하는 것 입니다.

tensorflow로 구현하기

이미지생성을 위해 이미지를 다룰필요가 있습니다. 저번 이미지를 다룬 CNN모델의 합성곱 신경망을 함께 이용하는 심층 합성곱 생성적 적대 신경망(DCGAN:Deep Convolution Generative Adversarial Networks)을 사용해, Mnist 데이터를 훈련한 후, 생성하는 모델을 빌드 해보도록 하겠습니다.

필요 모듈을 import 합니다

import glob
import imageio #gif생성
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
import tensorflow as tf
from IPython import display

데이터를 불러온 후,

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

생성과 판별을 위해 데이터를 나눠줍니다.
그다음 CNN모델을 만들때처럼 정규화 작업을 합니다

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 이미지를 [-1, 1]로 정규화합니다.

버퍼와 배치사이즈를 정의하고 훈련데이터를 나누고 섞어줍니다.

BUFFER_SIZE = 60000
BATCH_SIZE = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

generator 모델 생성

#훈련 되지 않은 generator 생성
generator = tf.keras.Sequential()
generator.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
generator.add(layers.BatchNormalization())
generator.add(layers.LeakyReLU())

generator.add(layers.Reshape((7, 7, 256)))
assert generator.output_shape == (None, 7, 7, 256) #  배치사이즈로 None

generator.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
assert generator.output_shape == (None, 7, 7, 128)
generator.add(layers.BatchNormalization())
generator.add(layers.LeakyReLU())

generator.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert generator.output_shape == (None, 14, 14, 64)
generator.add(layers.BatchNormalization())
generator.add(layers.LeakyReLU())

generator.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert generator.output_shape == (None, 28, 28, 1)

Discriminator 모델 생성

#훈련되지않은 discriminator 생성
discriminator = tf.keras.Sequential()
discriminator.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                  input_shape=[28, 28, 1]))
discriminator.add(layers.LeakyReLU())
discriminator.add(layers.Dropout(0.3))

discriminator.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
discriminator.add(layers.LeakyReLU())
discriminator.add(layers.Dropout(0.3))

discriminator.add(layers.Flatten())
discriminator.add(layers.Dense(1))

훈련 정의

#손실함수 정의
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

#감별자 손실함수
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
#생성자 손실함수
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

#생성자와 감별자의 옵티마이저 정의
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)


#체크포인트 설정
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# 훈련루프 정의

EPOCHS = 60
noise_dim = 100
num_examples_to_generate = 16

#시드고정(GIF로 훈련의 진전도를 시각화하기위해)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

#훈련 함수 정의

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # GIF를 위한 이미지를 바로 생성합니다.
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # 매 15 에포크마다 생성.
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    # print (' 에포크 {} 에서 걸린 시간은 {} 초 입니다'.format(epoch +1, time.time()-start))
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # 마지막 에포크가 끝난 후 생성
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

def generate_and_save_images(model, epoch, test_input):
  # `training`이 False로 맞춰진 것을 주목하세요.
  # 이렇게 하면 (배치정규화를 포함하여) 모든 층들이 추론 모드로 실행됩니다. 
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

훈련

%%time
train(train_dataset, EPOCHS)

다음과 같이 생성됩니다.

gif로 시각화

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  last = -1
  for i,filename in enumerate(filenames):
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)

import IPython
if IPython.version_info > (6,2,0,''):
  display.Image(filename=anim_file)
try:
  from google.colab import files
except ImportError:
  pass
else:
  files.download(anim_file)

다음과같은 코드들로 gif파일을 생성해 생성결과를 확인합니다.

마무리

tensorflow 공식 문서를 보며, GAN에 대한 기초적인 지식을 알 수 있는 모델이 있어 소개 해드렸습니다.
앞으로도 재밌게 빌드할 수 있는 모델들을 가지고 와 소개해보도록 하겠습니다.
긴글 읽어 주셔서 감사합니다.

profile
DeepLearning, MLOps

0개의 댓글