pytorch DataLoader

NYC·2021년 8월 12일
0

딥러닝 이론 정리

목록 보기
4/6

파이토치 DataLoader

튜토리얼을 기반으로 적어보는 사용자 정의 DataLoader
출처 : https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html

기본형

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

__init__ 함수는 Dataset 객체가 생성(instantiate)될 때 한 번만 실행된다. 여기서는 이미지와 주석 파일(annotation_file)이 포함된 디렉토리와 (다음 장에서 자세히 살펴볼) 두가지 변형(transform)을 초기화한다.
주요 목적은 데이터셋의 디렉토리를 가져오는 역할을 하기 위해서 만들어진다.
가지고 있는 데이터의 디렉토리 정보가 담긴 csv,txt,json 등이 없을 경우 직접 함수를 구현해야한다.
camvid 데이터셋을 load 하는 함수를 구현하였다.

    def __init__(self, path, image_set, transforms=None):
        assert image_set in ('train', 'val', 'test'), "image_set is not valid!"
        self.data_dir_path = path
        self.image_set = image_set
        self.transforms = transforms
        self.createIndex()
    def createIndex(self):
        self.img_list = []
        self.segLabel_list = []
        self.exist_list = []
        img_list = os.listdir(os.path.join(self.data_dir_path,self.image_set))
        self.img_list = [os.path.join(self.data_dir_path,self.image_set,file_name) for file_name in img_list]
        #print(self.img_list)
        self.segLabel_list = os.listdir(os.path.join(self.data_dir_path,self.image_set+"annot"))
        self.segLabel_list = [os.path.join(self.data_dir_path,self.image_set+"annot",file_name) for file_name in self.segLabel_list]
        for imgs in self.segLabel_list:
            imgs = cv2.imread(imgs)
            lis_id = [0]*12 # 클래스 개수 
            for num in np.unique(imgs[:,:,0].flatten()):
                lis_id[num] = 1
            self.exist_list.append(lis_id)
  1. 데이터셋의 경로를 가지고 온 후 train, test, valid 를 구분해준다.
  2. def createIndex(self) 함수를 만들어 데이터의 특성에 맞게 데이터를 불러올 수 있도록 해준다. ( __getitem__ )
  • self.img_list는 이미지의 모든 경로를 저장하는 list이다.
  • self.segLabel_list 세그먼트된 이미지 경로를 저장하는 list이다.
  • self.exist_list 세그먼트된 이미지에 어떤 클래스가 있는지 저장하는 list이다.
  • self.exist_list 경우 학습할 때 이미지에 어떤 클래스가 저장되어 있는지가 필요했기 때문에 따로 만들어준 list이다.
  • segLabel_list안에 클래스가 몇 개 있는지 찾고 one_hot_encoding 하였다.

__len__

def __len__(self):
    return len(self.img_labels)

__len__ 함수는 데이터셋의 샘플 개수를 반환한다. 데이터 개수를 표현할 수 있는 list만 있으면 사용할 수 있다.

def __len__(self):
    return len(self.img_list)

__getitem__

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    sample = {"image": image, "label": label}
    return sample

__getitem__ 함수는 주어진 인덱스 idx 에 해당하는 샘플을 데이터셋에서 불러오고 반환합니다. 인덱스를 기반으로, 디스크에서 이미지의 위치를 식별하고, read_image 를 사용하여 이미지를 텐서로 변환하고, self.img_labels 의 csv 데이터로부터 해당하는 정답(label)을 가져오고, (해당하는 경우) 변형(transform) 함수들을 호출한 뒤, 텐서 이미지와 라벨을 Python 사전(dict)형으로 반환합니다.
idx를 통해서 데이터를 가져오는데 sample 딕셔너리를 통해서 한꺼번에 가져온다.
idx를 가져올 때는 batch 개수에 맞춰서 가져오고 옵션에 따라 무작위로 가져갈 수 있다.

def __getitem__(self, idx):
    img = cv2.imread(self.img_list[idx])
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    segLabel = cv2.imread(self.segLabel_list[idx])[:, :, 0]
    exist = np.array(self.exist_list[idx])
    sample = {'img': img,
              'segLabel': segLabel,
              'exist': exist,
              'img_name': self.img_list[idx]}
    if self.transforms is not None:
        sample = self.transforms(sample)
    return sample

idx를 통해서 데이터를 가져가기 때문에 __init__ 에서 변수들을 list 형태로 만든것이다.

collate

batch_size =2일 때, [sample1, sample2] 이렇게 list형식으로 input이 들어온다. batch로 묶일 모든 데이터를 잘 묶어주는(collate) 함수이다.

@staticmethod
def collate(batch):
    if isinstance(batch[0]['img'], torch.Tensor):
        img = torch.stack([b['img'] for b in batch])
    else:
        img = [b['img'] for b in batch]
    if batch[0]['segLabel'] is None:
        segLabel = None
        exist = None
    elif isinstance(batch[0]['segLabel'], torch.Tensor):
        segLabel = torch.stack([b['segLabel'] for b in batch])
        exist = torch.stack([b['exist'] for b in batch])
    else:
        segLabel = [b['segLabel'] for b in batch]
        exist = [b['exist'] for b in batch]
    samples = {'img': img,
               'segLabel': segLabel,
               'exist': exist,
               'img_name': [x['img_name'] for x in batch]}
    return samples
profile
Vision_NLP

0개의 댓글