Functional 모델링 + ResNet 구현

yeoni·2023년 6월 27일
0

Tensorflow

목록 보기
8/15

Functional API

  • tf.keras.Sequential 보다 더 유연하게 모델을 정의할 수 있는 방법
  • 필요한 모델 구조의 80% 정도는 Functional API만으로 구현 가능

ResNet

ResNet의 핵심은 Skip Connection: 몇 개의 층을 건너뛴 입력을 그 층들의 출력에 그대로 더해 주는 연결

Functional API로 ResNet 구현

1) 모델 정의

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# Seed NumPy and TensorFlow with the same value so repeated runs are comparable.
SEED = 7777
np.random.seed(SEED)
tf.random.set_seed(SEED)

from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Flatten, Dense, Add

## Build a small ResNet-style classifier with the Keras Functional API
def _bottleneck(x, filters=64):
    """Apply the 1x1 -> 3x3 -> 1x1 ReLU convolution stack of one residual branch."""
    x = Conv2D(filters, kernel_size=1, padding='same', activation='relu')(x)
    x = Conv2D(filters, kernel_size=3, padding='same', activation='relu')(x)
    return Conv2D(filters, kernel_size=1, padding='same', activation='relu')(x)


def build_resnet(input_shape, num_classes=10):
    """Build a two-block residual CNN classifier.

    Args:
        input_shape: Shape of one input sample, passed straight to
            ``Input`` (the notebook uses ``(32, 32, 3)``).
        num_classes: Number of units in the final softmax layer.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'resnet'``.
    """
    inputs = Input(input_shape)

    # Stem: strided convolution followed by max pooling.
    net = Conv2D(32, kernel_size=3, strides=2,
                 padding='same', activation='relu')(inputs)
    net = MaxPool2D()(net)

    # Residual block 1: the stem outputs 32 channels while the branch outputs
    # 64, so the shortcut goes through a linear 1x1 convolution before Add.
    branch = _bottleneck(net)
    shortcut = Conv2D(64, kernel_size=1, padding='same')(net)
    net = Add()([shortcut, branch])

    # Residual block 2: channel counts already match, so the shortcut is the
    # block input itself.
    branch = _bottleneck(net)
    net = Add()([net, branch])

    net = MaxPool2D()(net)

    # Classification head.
    net = Flatten()(net)
    outputs = Dense(num_classes, activation="softmax")(net)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name='resnet')
    
# Instantiate the network for 32x32 RGB inputs and print its layer table.
CIFAR10_INPUT_SHAPE = (32, 32, 3)
model = build_resnet(CIFAR10_INPUT_SHAPE)
model.summary()

2) 데이터 로더

class Cifar10DataLoader():
    """Load CIFAR-10 via ``tf.keras.datasets`` and serve preprocessed splits.

    Preprocessing scales the images to float32 by dividing by 255 and
    one-hot encodes the integer labels.
    """

    # Number of label classes used for one-hot encoding.
    NUM_CLASSES = 10

    def __init__(self):
        """Load the raw train/test arrays and record the per-sample shape."""
        (self.train_x, self.train_y), \
            (self.test_x, self.test_y) = tf.keras.datasets.cifar10.load_data()
        self.input_shape = self.train_x.shape[1:]  # [32, 32, 3]

    def scale(self, x):
        """Return ``x / 255.0`` cast to float32."""
        return (x / 255.0).astype(np.float32)

    def preprocess_dataset(self, dataset):
        """Scale the features and one-hot encode the targets.

        Args:
            dataset: ``(feature, target)`` pair; ``feature`` is an array of
                images and ``target`` holds integer class ids, shaped
                ``(N, 1)`` or ``(N,)``.

        Returns:
            ``(scaled_x, ohe_y)`` where ``scaled_x`` is float32 with the same
            shape as ``feature`` and ``ohe_y`` has shape ``(N, 10)``.
        """
        feature, target = dataset

        # One vectorized call over the whole array instead of a Python-level
        # loop over every image; the element-wise result is the same.
        scaled_x = self.scale(np.asarray(feature))

        # to_categorical encodes the whole label array at once and already
        # returns (N, num_classes), so no per-label loop or squeeze is needed.
        ohe_y = tf.keras.utils.to_categorical(
            target, num_classes=self.NUM_CLASSES)

        return scaled_x, ohe_y

    def get_train_dataset(self):
        """Return the preprocessed training split."""
        return self.preprocess_dataset((self.train_x, self.train_y))

    def get_test_dataset(self):
        """Return the preprocessed test split."""
        return self.preprocess_dataset((self.test_x, self.test_y))
        
cifar10_loader = Cifar10DataLoader()

# Materialize both preprocessed splits (inspect their shape and dtype here).
(train_x, train_y) = cifar10_loader.get_train_dataset()
(test_x, test_y) = cifar10_loader.get_test_dataset()

3) 학습 로직

# Training configuration: Adam optimizer, categorical cross-entropy loss and
# accuracy as the tracked metric.
# NOTE(review): 0.03 is well above Adam's documented default of 0.001 —
# confirm this learning rate is intentional.
lr = 0.03
loss = tf.keras.losses.categorical_crossentropy
opt = tf.keras.optimizers.Adam(learning_rate=lr)

model.compile(loss=loss, optimizer=opt, metrics=['accuracy'])

4) 학습 & 후처리

# Train for 10 epochs in batches of 128, validating on the test split after
# each epoch; the returned History object feeds the plots below.
hist = model.fit(
    x=train_x,
    y=train_y,
    batch_size=128,
    epochs=10,
    validation_data=(test_x, test_y),
)

# Draw the four curves recorded in hist.history on a 2x2 grid.
# Each entry: (history key, panel title, extra positional args for plt.plot).
curves = (
    ('loss', "loss", ()),
    ('accuracy', "acc", ('b-',)),
    ('val_loss', "val_loss", ()),
    ('val_accuracy', "val_accuracy", ('b-',)),
)

plt.figure(figsize=(10, 5))
for position, (key, title, line_args) in enumerate(curves, start=1):
    plt.subplot(2, 2, position)
    plt.plot(hist.history[key], *line_args)
    plt.title(title)

plt.tight_layout()
plt.show()


Reference
1) 제로베이스 데이터스쿨 강의자료

profile
데이터 사이언스 / just do it

0개의 댓글