[PyTorch] Dataset, DataLoader 직접 사용하기

Jaeyeon Kim·2023년 3월 21일
1

Deep-Learning

목록 보기
2/4

학습을 시키다보면, 데이터를 어떻게 먹이는가 는 중요한 요소이다.
데이터셋의 크기가 크다보니 모든 데이터를 메모리에 올리고 학습을 진행하면
Out of Memory라는 무시무시한 문구를 보게 되기 때문이다.

이를 위해서 데이터셋을 쪼개서 학습을 진행하는 배치 학습을 진행하고,
이 때 데이터를 나눠서 가져올 수 있게끔 도와주는 DataLoader 클래스가 있다.

내가 가진 데이터셋은 ['id', 'text', 'summary'] 의 컬럼으로 이루어진 글자 데이터셋이다.
이를 어떤 식으로 DataLoader를 통해 배치를 불러오는지 알아보자.

CustomDataset 클래스 생성

데이터셋을 csv나 dataframe으로 가지고 있는 것보다는
내가 운용하기 좋게 바꿔놓는 것이 좋다.
이럴 때 쓰는 것이 Dataset 클래스를 상속 받아 CustomDataset 클래스를 만드는 것이다.

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
	def __init__(self, text, summary):
    	self.text = text
        self.summary = self.summary
        
    def __len__(self):
    	return len(self.text)
    
    def __getitem__(self, index):
    	text = self.text[index]
        summary = self.summary[index]
        item = {"text": text, "summary": summary}
        return item

CustomDataset 불러오기

import pandas as pd 

DATA_PATH = {path}
train_df = pd.read_csv(data_path)
MyDataset = CustomDataset(train_df['text'], train_df['summary'])

DataLoader로 불러오기

이렇게 데이터셋을 지정해주면, 데이터 로더에서 불러올 수 있다.

from torch.utils.data import DataLoader

MyDataLoader = DataLoader(MyDataset, batch_size=4, shuffle=True)

batch_size를 조정해주며, 한 번의 iteration에서 불러올 데이터의 개수를 수정할 수 있다.

next(iter(MyDataLoader))를 통해 데이터가 잘 나오는지 확인할 수 있다.

학습 시 사용하기

이를 학습 시킬 때,

...
for batch in MyDataLoader:
	text = batch['text']
    summary = batch['summary']
...

처럼 사용하여 한 배치씩 꺼내서 쓸 수 있다.

profile
낭만과 열정으로 뭉친 개발자 🔥

0개의 댓글