[CH03] 07. 커스텀 데이터셋

SoYeong Gwon·2022년 6월 17일
0

DeepLearning Introduction

목록 보기
4/12
post-thumbnail

본 게시글은 다음 링크(https://wikidocs.net/book/2788)의 wiki docs를 참고하여 작성되었습니다.

1. 커스텀데이터셋

  • torch.utils.data.Dataset을 상속받아 직접 커스텀 데이터셋을 만드는 경우 존재
  • torch.utils.data.Dataset은 파이토치에서 데이터셋을 제공하는 추상 클래스
  • Dataset을 상속받아 다음 메소드를 오버라이드하여 커스텀 데이터셋을 만들어봤음.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # 데이터셋의 전처리

    def __len__(self):
        # 데이터셋의 길이. 즉, 총 샘플의 수를 반환
    
    def __getitem__(self, idx):
        # 데이터셋에서 특정 1개의 샘플을 가져옴. 
  • 기본적인 뼈대는 다음과 같음.
    • init()
      • 데이터의 전처리 수행
    • len()
      • 데이터셋의 길이 리턴
      • len(dataset)을 했을때, 데이터셋의 크기를 리턴
    • getitem()
      • 데이터셋에서 특정 1개의 샘플을 가져옴.
      • dataset[i]를 했을때 i번째 샘플을 가져오도록 하는 인덱싱

2. 커스텀 데이터셋으로 선형회귀 구현하기

import torch 
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Dataset 생성
class CustomDataset(Dataset):
    def __init__(self):
        self.x_data=[[73,80,75],
                     [93,88,93],
                     [89,91,90],
                     [96,98,100],
                     [73,66,70]]
        self.y_data=[[152],[185],[180],[196],[142]]

    # 총 데이터의 개수 리턴
    def __len__(self):
        return len(self.x_data)
    
    # 인덱스를 입력받아 그에 맵핑되는 입출력 데이터를 파이토치의 Tensor 형태로 리턴
    def __getitem__(self,idx):
        x = torch.FloatTensor(self.x_data[idx])
        y = torch.FloatTensor(self.y_data[idx])
        return x,y
dataset = CustomDataset()
dataloader = DataLoader(dataset,batch_size=2, shuffle=True) 
model=torch.nn.Linear(3,1)
optimizer=torch.optim.SGD(model.parameters(),lr=1e-5)
nb_epochs=30
for epoch in range(nb_epochs+1):
    for batch_idx,samples in enumerate(dataloader):
        print(batch_idx)
        print(samples)

        x_train,y_train=samples

        prediction=model(x_train)
        cost=F.mse_loss(prediction,y_train)

        #cost로 prediction 개선 
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        print('Epoch {:4d}/{} Batch {}/{} Cost:{:6f}'.format(epoch, nb_epochs,batch_idx,len(dataloader),cost.item()))
...
0
[tensor([[89., 91., 90.],
        [93., 88., 93.]]), tensor([[180.],
        [185.]])]
Epoch    0/30 Batch 0/3 Cost:1.153904
...

3. Insights

Class Magic Method

  • Class에 선언된 method를 보면 __len__을 확인할 수 있음.
  • 해당 함수는 magic method__len__이 선언된 객체의 길이를 구할 수 있음.
  • 동작원리
    1. len(x) 호출
    2. class 객체 x__len__ 존재 여부 확인
    • 있다면: 객체 x이 길이 반환
    • 없다면: TypeError 발생

0개의 댓글