OX분류하기

InSung-Na·2023년 5월 5일
0

Part 10. Deep Learning

목록 보기
6/7
post-thumbnail

해당 글은 제로베이스데이터스쿨 학습자료를 참고하여 작성되었습니다

OX Classification

1. 개요

  • 딥러닝 기초 이진분류에 대해 학습한다
  • 데이터셋 : OX Images
  • 데이터셋을 구하지 못했다
    • 따라서 코드는 직접 실행하지 않고 입력만 하며, 결과는 학습자료의 것을 참고한다

2. 데이터 수집

  • 데이터셋 없음
  • 코드만 입력

3. 데이터 전처리

  • 이미지 불러오기(실행X)
  • train_test별로 크기 조정
  • 이미지 제너레이터(ImageDataGenerator) 생성

3-1. 이미지 조정

from glob import glob

# Glob pattern for the raw "O" training images.
# The previous pattern ended in "*." which only matches names that end with a
# literal dot, so ordinary image files (foo.png, foo.jpg, ...) were never found.
# "*.*" matches any file that has an extension.
train_raw_path = "./train_raw/O/*.*"
train_raw_O_list = glob(train_raw_path)
train_raw_O_list  # notebook cell output: show the matched paths

img_resize

# !pip install scikit-image
from skimage.transform import rescale, resize
from skimage import color
from skimage.io import imread, imsave
import matplotlib.pyplot as plt
import numpy as np

def img_resize(img):
    """Convert an image array to grayscale and resize it to 28x28.

    Args:
        img: Image array as returned by ``imread``. RGB (H, W, 3),
            RGBA (H, W, 4) and already-grayscale (H, W) inputs are accepted.

    Returns:
        The 28x28 array produced by ``skimage.transform.resize``.
    """
    # rgb2gray only accepts 3-channel input in current scikit-image, so
    # normalise the other common layouts first.
    if img.ndim == 3 and img.shape[-1] == 4:
        # Drop the alpha channel (rgba2rgb blends onto its default background).
        img = color.rgba2rgb(img)
    if img.ndim == 3:
        img = color.rgb2gray(img)
    # A 2-D input is already grayscale and goes straight to resize.
    return resize(img, (28,28))

train_O

from tqdm.notebook import tqdm

def convert_train_O():
    """Convert every raw "O" training image into a 28x28 grayscale PNG.

    Reads the files matched by the module-level ``train_raw_path`` glob
    pattern, resizes each one with ``img_resize`` and writes it to
    ``./train/O/<original file stem>.png``.
    """
    import os  # local import keeps this notebook cell self-contained

    out_dir = "./train/O/"
    # imsave does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    train_raw_O_list = glob(train_raw_path)
    for each in tqdm(train_raw_O_list):
        img = imread(each)
        img_resized = img_resize(img)
        # basename/splitext instead of split("/")[-1][:-4]: independent of the
        # path separator and of the extension length (".jpeg", ".png", ...).
        stem = os.path.splitext(os.path.basename(each))[0]
        save_name = out_dir + stem + ".png"    # rename
        # Store as 8-bit: astype(int) yields int64, which PNG writers either
        # reject or convert with a lossy-conversion warning.
        imsave(save_name, np.round(img_resized*255).astype(np.uint8))


convert_train_O()

train_X

from tqdm.notebook import tqdm

# Glob pattern for the raw "X" training images. The former "*." suffix only
# matched names ending in a literal dot; "*.*" matches any file with an extension.
train_raw_path = "./train_raw/X/*.*"
def convert_train_X():
    """Convert every raw "X" training image into a 28x28 grayscale PNG.

    Reads the files matched by the module-level ``train_raw_path`` glob
    pattern, resizes each one with ``img_resize`` and writes it to
    ``./train/X/<original file stem>.png``.
    """
    import os  # local import keeps this notebook cell self-contained

    out_dir = "./train/X/"
    # imsave does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    train_raw_X_list = glob(train_raw_path)
    for each in tqdm(train_raw_X_list):
        img = imread(each)
        img_resized = img_resize(img)
        # basename/splitext instead of split("/")[-1][:-4]: independent of the
        # path separator and of the extension length (".jpeg", ".png", ...).
        stem = os.path.splitext(os.path.basename(each))[0]
        save_name = out_dir + stem + ".png"    # rename
        # Store as 8-bit: astype(int) yields int64, which PNG writers either
        # reject or convert with a lossy-conversion warning.
        imsave(save_name, np.round(img_resized*255).astype(np.uint8))


convert_train_X()

test_O

from tqdm.notebook import tqdm

# Glob pattern for the raw "O" test images. It used to be an empty string, for
# which glob() returns nothing, so no test image was ever converted.
# NOTE(review): the location is assumed to mirror "./train_raw/O/" - confirm
# against the actual dataset layout.
test_raw_path = "./test_raw/O/*.*"
def convert_test_O():
    """Convert every raw "O" test image into a 28x28 grayscale PNG.

    Reads the files matched by the module-level ``test_raw_path`` glob
    pattern, resizes each one with ``img_resize`` and writes it to
    ``./test/O/<original file stem>.png``.
    """
    import os  # local import keeps this notebook cell self-contained

    out_dir = "./test/O/"
    # imsave does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    test_raw_O_list = glob(test_raw_path)
    for each in tqdm(test_raw_O_list):
        img = imread(each)
        img_resized = img_resize(img)
        # basename/splitext instead of split("/")[-1][:-4]: independent of the
        # path separator and of the extension length (".jpeg", ".png", ...).
        stem = os.path.splitext(os.path.basename(each))[0]
        save_name = out_dir + stem + ".png"    # rename
        # Store as 8-bit: astype(int) yields int64, which PNG writers either
        # reject or convert with a lossy-conversion warning.
        imsave(save_name, np.round(img_resized*255).astype(np.uint8))


convert_test_O()

test_X

from tqdm.notebook import tqdm

# Glob pattern for the raw "X" test images. It used to be an empty string, for
# which glob() returns nothing, so no test image was ever converted.
# NOTE(review): the location is assumed to mirror "./train_raw/X/" - confirm
# against the actual dataset layout.
test_raw_path = "./test_raw/X/*.*"
def convert_test_X():
    """Convert every raw "X" test image into a 28x28 grayscale PNG.

    Reads the files matched by the module-level ``test_raw_path`` glob
    pattern, resizes each one with ``img_resize`` and writes it to
    ``./test/X/<original file stem>.png``.
    """
    import os  # local import keeps this notebook cell self-contained

    out_dir = "./test/X/"
    # imsave does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    test_raw_X_list = glob(test_raw_path)
    for each in tqdm(test_raw_X_list):
        img = imread(each)
        img_resized = img_resize(img)
        # basename/splitext instead of split("/")[-1][:-4]: independent of the
        # path separator and of the extension length (".jpeg", ".png", ...).
        stem = os.path.splitext(os.path.basename(each))[0]
        save_name = out_dir + stem + ".png"    # rename
        # Store as 8-bit: astype(int) yields int64, which PNG writers either
        # reject or convert with a lossy-conversion warning.
        imsave(save_name, np.round(img_resized*255).astype(np.uint8))


convert_test_X()

3-2. Image_generator

import numpy as np
from keras.models import Sequential
# Conv2D / MaxPooling2D are imported from keras.layers: the
# keras.layers.convolutional submodule is not importable in recent Keras
# releases, while keras.layers exposes the same classes in old and new versions.
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
import keras

np.random.seed(13)

# Both generators only rescale pixel values from [0, 255] to [0, 1];
# no augmentation is applied.
train_datagen = ImageDataGenerator(rescale=1./255)

# Class labels come from the sub-directory names under "./train";
# batches hold 3 images of size 28x28 with one-hot ("categorical") labels.
train_generator = train_datagen.flow_from_directory("./train", target_size=(28,28), batch_size=3, class_mode="categorical")

test_datagen = ImageDataGenerator(rescale=1./255)

# Same settings for the held-out images under "./test".
test_generator = test_datagen.flow_from_directory("./test", target_size=(28,28), batch_size=3, class_mode="categorical")

4. 모델링

# Small CNN for the two-class O/X task: two 3x3 convolutions, one 2x2
# max-pool, then a dense head ending in a 2-way softmax.
# The input is a 28x28 image with 3 channels.
model = Sequential([
    Conv2D(32, kernel_size=(3,3), activation="relu", input_shape=(28,28,3)),
    Conv2D(64, kernel_size=(3,3), activation="relu"),
    MaxPooling2D(pool_size=(2,2)),
    Flatten(),
    Dense(128, activation="relu"),
    Dense(2, activation="softmax"),
])

# One-hot labels with a softmax output -> categorical cross-entropy.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Train for 50 epochs, drawing 15 batches per epoch from the training generator
# and validating on 5 batches from the test generator.
# Model.fit accepts generators directly; fit_generator is deprecated and has
# been removed from recent Keras/TensorFlow releases.
hist = model.fit(train_generator, steps_per_epoch=15, epochs=50, validation_data=test_generator, validation_steps=5)

# Plot the loss and accuracy curves for training and validation on one figure.
plt.figure(figsize=(12,6))
plt.plot(hist.history["loss"], label="loss")
plt.plot(hist.history["val_loss"], label="val_loss")
plt.plot(hist.history["accuracy"], label="accuracy")
plt.plot(hist.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()

5. 모델 평가

# Evaluate on 5 batches from the test generator.
# Fixed typo: the method is `evaluate`; `evalute` raised AttributeError.
scores = model.evaluate(test_generator, steps=5)

# scores[1] / metrics_names[1] is the second reported value, i.e. the metric
# passed to compile() after the loss.
print("%s: %.2f%%" %(model.metrics_names[1], scores[1]*100))	# 100.0%
model.predict(test_generator)
n = 1  # index of the test image shown by the first prediction call below

def show_prediction_result(n):
    """Show the n-th test image together with the model's predicted class index.

    Args:
        n: Index into ``test_generator.filepaths``.
    """
    img = imread(test_generator.filepaths[n])
    # Scale to [0, 1] before predicting: the generators used for training and
    # evaluation apply rescale=1./255, so raw 0-255 pixels would not match the
    # input range the network was trained on.
    scaled = img / 255.
    # gray2rgb gives the 3 channels the model expects; expand_dims adds the
    # batch axis.
    pred = model.predict(np.expand_dims(color.gray2rgb(scaled), axis=0))
    title = "Predict : " + str(np.argmax(pred))
    plt.imshow(scaled, cmap="gray")
    plt.title(title)
    plt.show()


show_prediction_result(n)
show_prediction_result(40)

0개의 댓글