[pytorch] Dataset & DataLoader

강콩콩·2022년 3월 13일
0

pytorch

목록 보기
1/7
post-thumbnail

pytorch를 사용하여 DL 모델을 학습할 때, data의 기본 원소가 되는 두 클래스:
Dataset과 DataLoader class를 알아보도록 하겠습니다!
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
공식 홈페이지에도 아주 잘 나와있습니다 :)

😉 데이터 뭉치: Dataset class!

이게 뭔데요?

Dataset stores the samples and their corresponding labels

✔ Dataset class는 모델 학습시 사용할 data와 label을 저장합니다.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

✔ 미리 작성된 torch.utils.data.Dataset class를 상속(서브클래싱)하여 사용합니다.
✔ __init__() 메서드로 인스턴스 생성시 필요한 초기값들을 받습니다. (이미지 리스트, transform 함수 등)
✔ __getitem__() 메서드의 리턴값이 추후 학습에서 사용할 값이 되도록 오버라이딩합니다.

😏 이런 식으로 사용합니다!

import torch
from PIL import Image

class MyDataset(torch.utils.data.Dataset):
    """
    Attributes
    ----------
    img_list : 리스트
        이미지의 경로를 저장한 리스트
    label_list : 리스트
        label의 경로를 저장한 리스트
    phase : 'train' or 'val'
        학습 또는 테스트 여부 결정
    transform : object
        전처리 클래스의 인스턴스
    """

    def __init__(self, img_list, label_list, phase, transform):
        self.img_list = img_list
        self.label_list = label_list
        self.phase = phase  # train 또는 val을 지정
        self.transform = transform  # 이미지의 변형

    def __len__(self):
        '''이미지의 갯수를 반환'''
        return len(self.img_list)

    def __getitem__(self, index):
        '''
        전처리한 이미지 및 라벨 return
        '''
        image_path = self.file_list[index]
        img = Image.open(image_path)
        
        transformed_img = self.transform(img, self.phase)
        label = self.label_list[index]
        
        return transformed_img, label

😁 어때요? 정말 간단하죠!

그리고, transform 인자에 들어갈 인스턴스의 틀이 되는 클래스는 보통 __call__() 메서드를 정의하는 방법으로 많이 사용됩니다.
✔ __call__() : 클래스 인스턴스를 생성 후, () 명령어를 사용하면 실행되는 함수입니다.

from torchvision import models, transforms

class MyTransform():
    """
    Attributes
    ----------
    resize : int
        Transform 수행 후 변경될 width / height 값.
    mean : (R, G, B)
        각 색상 채널의 평균값.
    std : (R, G, B)
        각 색상 채널의 표준 편차.
    """

    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(
                    resize, scale=(0.5, 1.0)),  
                transforms.RandomHorizontalFlip(), 
                transforms.ToTensor(),  # 텐서로 변환
                transforms.Normalize(mean, std)  # 표준화
            ]),
            'val': transforms.Compose([
                transforms.Resize(resize),
                transforms.ToTensor(),  # 텐서로 변환
                transforms.Normalize(mean, std)  # 표준화
            ])
        }

    def __call__(self, img, phase='train'):
        """
        Parameters
        ----------
        phase : 'train' or 'val'
            전처리 모드를 지정.
        """
        return self.data_transform[phase](img)

예시에서는 학습을 위해 RandomResizedCrop() / RandomHorizontalFlip() 만 적용을 하였지만, 더 다양한 Augmentation을 적용할 수도 있습니다. 😁

https://pytorch.org/vision/stable/transforms.html

그래서..!

✔ 실제로 학습을 위한 인스턴스는 아래와 같이 생성됩니다 😎

size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

train_dataset = MyDataset(file_list=train_img_list, phase="train", transform=MyTransform(
    size, mean, std)))

val_dataset = MyDataset(file_list=val_img_list, phase="val", transform=MyTransform(
    size, mean, std)))

✔ train_dataset / val_dataset은 generator로써, 추후 '실제로 반복이 수행될 때' 메모리를 할당하여 작업을 수행합니다.
✔ D/L 모델 학습을 수행할 때, generator를 사용하지 않으면 OOM이 종종 발생하고는 하여, torch에서는 더욱 간단히 쓸 수 있도록 틀을 제공한 것으로 이해되네요 😁

😉 배치 단위로 가져오자 : DataLoader class!

DataLoader wraps an iterable around the Dataset to enable easy access to the samples

✔ DataLoader class는 Dataset class로 정의된 데이터 뭉치를 쉽게 샘플 단위로 가져올 수 있게 합니다.
✔ 즉, 보통 D/L 모델 학습을 위한 mini-batch 학습 단위로 가져오게 해주는 역할을 수행합니다.

🤦‍♀️ 그리고 코드는 넘나 간단.. :)

# DataLoader를 만든다
batch_size = 32

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

# 사전 객체에 정리
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

✔ batch size 단위로 가져오게 되어, train_dataloader의 iteration을 진행하면, [32, 224, 224, 3]의 shape을 가진 Tensor를 얻을 수 있습니다.
✨ 그리고 학습 수행하면 오케이!

👍 추가) Object Detection Task 등, batch에 포함된 데이터의 크기가 각각 다를 때가 있습니다. (각 이미지에 라벨이 몇개씩 있는지 알 수 없으므로) 이 때, DataLoader 인스턴스를 생성할 시 collate_fn 인자를 설정하여 해결할 수 있습니다.

https://pytorch.org/docs/stable/data.html?highlight=collate_fn

마치며

다음 글은 간단히 Fine Tuning을 진행하는 글을 작성하도록 하겠습니다 :>

profile
MLOps, ML Engineer. 데이터에서 시스템으로, 시스템에서 가치로.

0개의 댓글