pytorch 튜토리얼 [2]

·2023년 7월 10일
0

Pytorch 튜토리얼

목록 보기
2/4

Dataset과 DataLoader

data 샘플 처리하는 코드는 지저분하고 유지보수 어려움
더 나은 가독성과 모듈성 위해 dataset 코드를 모델 학습 코드로부터 분리하는 것이 이상적이다.
pytorch는
torch.utils.data.DataLoadertorch.utils.data.Dataset 의 두 가지 data 기본 요소를 제공하고 미리 준비해준 dataset 뿐만 아니라 가지고 있는 데이터를 사용할 수 있다.
Dataset은 샘플과 정답(label)을 저장하고
DataLoader는 dataset을 샘플에 쉽게 접근할 수 있도록 순회 가능한 객체(iterable)로 감싼다.

pytorch의 도메인 특화 라이브러리들은 미리 준비해둔 다양한 dataset 제공
dataset은 torch.utils.data.Dataset의 하위 클래스로 개별 data를 특정하는 함수가 구현되어 있다. 이러한 dataset은 모델을 만들어보고 성능을 측정하는데 사용할 수 있다.

Dataset 불러오기

TorchVision에서 Fashion-MNIST 예제 확인

Fashion-MNIST는 이미지 dataset으로 60,000개 학습 예제
10,000 테스트 예제로 이루어짐
각 예제는 흑백의 28x28 이미지와 10개의 분류 중 하나인 정답(label)로 구성

다음 매개변수 사용해 FashionMNIST dataset 불러온다.

  • root는 학습/테스트 데이터가 저장되는 경로
  • train은 학습용, 테스트용 dataset 여부 지정
  • download=True는 root에 data 없는 경우 인터넷에서 download함
  • transform과 target_transform은 특징과 변형을 지정
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Dataset 순회하고 시각화

Dataset에 list처럼 직접 접근할 수 있다.
training_data[index], matplotlib 사용해 데이터 시각화

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

File에서 사용자 정의 dataset 만들기

사용자 정의 Dataset 클래스는 반드시 3개의 함수를 구현해야 한다.

__init__
__len__
__getitem__

아래 구현 살펴보면 FashionMNIST 이미지들은 img_dir 디렉토리에 저장되고
정답은 annotations_files csv 파일에 별도로 저장된다.

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

init

init 함수는 Dataset 객체가 생성(instantiate)될 때 한번만 실행된다.
여기서 이미지와 주석파일 포함된 디렉토리와 두가지 변형을초기화한다.
labels.csv 파일은 다음과 같다.

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

len

len 함수는 Dataset 샘플 개수 반환

def __len__(self):
	return len(self.img_labels)

getitem

주어진 인덱스 idx에 해당하는 샘플을 dataset에서 불러오고 반환한다.
인덱스 기반으로 디스크에서 이미지 위치 식별하고
read_image 사용해 이미지를 tensor로 변환
self.img_labels의 csv data로부텉 해당하는 label 가져오고
변형함수들 호출한 뒤, tensor 이미지와 라벨을 Python 사전형(dict)으로반환한다.

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    sample = {"image": image, "label": label}
    return sample

DataLoader로 학습용 데이터 준비

Dataset은 dataset의 특징을 가져오고 하나의 샘플에 정답을 지정하는 일을 한번에 한다.
모델 학습할 때, 일반적으로 샘플들을 minibatch로 전달, 매 epoch마다 데이터 다시 섞어 과적합 막고, Python의 multiprocessing 사용해 데이터 검색 속도 높이려고 한다.

DataLoader는 간단한 API로 이러한 복잡한 과정들을 순회 가능한 객체(iterable)다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=Ture)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoader를 통해 순회하기 (iterate)

DataLoader에 데이터셋 불러온 뒤 필요에 따라 데이터셋을 순회할 수 있다.
아래 각 순회는(iteration) train_features와 train_labels의 묶음(batch)를 반환한다.
shuffle=True 로 지정했으므로, 모든 배치를 순회한 뒤 데이터가 섞인다.

# 이미지와 정답(label)을 표시합니다.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

Out

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 0

0개의 댓글