[CH 5] 05. 소프트맥스로 MNIST 데이터 분류하기

SoYeong Gwon·2022년 7월 26일
0

DeepLearning Introduction

목록 보기
12/12
post-thumbnail

이번에는 MNIST 데이터에 대해서 이해하고, PYTORCH로 소프트맥스 회귀를 구현하여 분류하는 실습을 진행해보자.

01. MNIST 데이터 이해하기

  • MNIST는 숫자 0부터 9까지의 이미지로 구성된 손글씨 데이터셋
  • 과거 우체국에서 편지의 우편 번호를 인식하기 위해 만들어진 훈련 데이터
  • 총 60,000개의 훈련데이터와 레이블, 총 10,000여개의 테스트 데이터와 레이블로 구성
  • MNIST 문제는 손글씨로 적힌 숫자 이미가 들어오면, 그 이미지가 무슨 숫자인지 맞추는 문제임.
for X,Y in data_loader:
    # 입력 이미지를 [batch_size * 784]의 크기로 reshape
    # 레이블은 원-핫 인코딩
    X = X.view(-1,28*28)
    # view를 통해서 (배치크기 * 784)로 크기 변환 
    # 원래는 (배치크기 * 1 * 28 * 28)

02. 토치비전(torchvision) 소개하기

  • torchvision
    • cv 분야에 유명한 데이터셋, 이미 구현되어있는 유명한 모델들과 이미지 전처리 도구를 포함하는 패키지

03. 분류기 구현을 위한 사전 설정

import library

import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
import random

Set GPU and Seed

USE_CUDA = torch.cuda.is_available() # gpu 사용 가능하면 true or false
device = torch.device("cuda" if USE_CUDA else "cpu")
print("다음 기기로 학습합니다.",device)
# 랜덤 시드 고정
random.seed(777)
torch.manual_seed(777)

if device == 'cuda':
    torch.cuda.manual_seed_all(777)

Hyperparameter

training_epochs = 15
batch_size = 100

04. MNIST 분류기 구현하기

(1) 데이터 불러오기

torchvision을 통해서 MNIST 데이터셋 불러오기

  • root : MNIST 데이터를 다운받을 경로
  • train = True: 훈련 데이터 or 테스트 데이터
  • transform : 파이토치 텐서로 변환
  • download: 해당 경로에 데이터가 없다면 다운
#MNIST dataset
mnist_train = dsets.MNIST(root = 'MNIST_data/', 
                          train = True, 
                          transform = transforms.ToTensor(), 
                          download=True) 
mnist_test = dsets.MNIST(root = 'MNIST_data/',
                         train = False,
                         transform = transforms.ToTensor(),
                         download = True)

(2) 데이터로더 선언

  • dataset : 로드할 대상
  • batch_size : 배치 크기
  • shuffle : 매 에포크마다 미니 배치를 셔플할 것인지의 여부 (bool)
  • drop_last : 마지막 배치를 버릴 것인지 의미
    • 1000개의 데이터, 배치 크기가 128이라고 했을때, 1000을 128로 나누면 총 7개가 나오고, 나머지 104개가 남음
    • 104개로 마지막 배치를 한다고 하면 128개(기존 배치 사이즈)을 충족하지 못함.
    • 이때 남은 104개를 버릴지 말지를 결정하는 hyperparameter가 drop_last = True임 .
    • 이는 다른 미니배치보다 개수가 적은 마지막 배치를 경사하강법에 사용하여 Overfitting 현상을 막는 효과를 낼 수 있음.
# dataset loader
data_loader = DataLoader(dataset = mnist_train,
                         batch_size = batch_size,
                         shuffle = True,
                         drop_last = True)

(3) 모델 설계

  • to() 함수를 사용하여 모델의 매개변수를 지정한 장치의 메모리로 보냄.
  • cpu의 경우 필수는 아니지만, gpu의 경우 to(device)를 꼭 해줘야함.
# MNIST data image of shape 28 * 28 = 784
linear = nn.Linear(784, 10, bias = True).to(device)

(4) 비용함수와 옵티마이저 정의

criterion = nn.CrossEntropyLoss().to(device) # 내부적으로 소프트맥스 함수 포함
optimizer = torch.optim.SGD(linear.parameters(), lr = 0.1)

(5) 학습

for epoch in range(training_epochs): # 15
    avg_cost = 0 
    total_batch = len(data_loader)

    for X, Y in data_loader:
        # 배치 크기가 100이므로 아래의 연산에서 X는 (100,784)의 텐서가 된다. 
        X = X.view(-1,28*28).to(device)
        # 레이블은 원-핫 인코딩이 된 상태가 아니라 0~9의 정수
        Y = Y.to(device)
        # 가설, 비용 선언 
        hypothesis = linear(X)
        cost = criterion(hypothesis,Y)
        
        # 갱신 과정 
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        avg_cost += cost/total_batch

    print('Epoch: ', '%04d' %(epoch + 1), 'cost = ','{:9f}'.format(avg_cost))

print("Learning finished")

(6) 테스트데이터를 사용하여 모델 테스트

with torch.no_grad(): # gradient 계산을 수행하지 않음
    X_test = mnist_test.test_data.view(-1,28*28).float().to(device)
    Y_test = mnist_test.test_labels.to(device)

    prediction = linear(X_test)
    correct_prediction = torch.argmax(prediction,1) == Y_test
    accuracy = correct_prediction.float().mean()

    print('Accuracy:', accuracy.item())

    # MNIST 테스트 데이터에서 무작위로 하나를 뽑아서 예측
    r = random.randint(0,len(mnist_test)-1)
    X_single_data = mnist_test.test_data[r:r+1].view(-1,28*28).float().to(device)
    Y_single_data = mnist_test.test_labels[r:r+1].to(device)

    print('Label: ', Y_single_data.item())
    single_prediction = linear(X_single_data)
    print('Prediction: ', torch.argmax(single_prediction,1).item())

    plt.imshow(mnist_test.test_data[r:r+1].view(28,28),cmap='Greys',interpolation = 'nearest')
    plt.show()

0개의 댓글