Dataset과 DataLoader

SSW·2022년 9월 15일
0

PyTorch

목록 보기
1/1

STL10 Dataset 다운로드

import torch
import os
import torchvision
import torchvision.transforms as transforms
import collections

from tqdm import tqdm

# Root directory that torchvision uses to store the downloaded STL10 files.
data_path = "./data"

# Preprocessing applied to every sample, in order: convert the image to a
# tensor, resize it to 224x224, then normalize each of the three channels with
# the mean / std tuples given below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Generate data directory.
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it has no
# check-then-create race and also works when data_path is a nested path.
os.makedirs(data_path, exist_ok=True)

# Download (when not already present under data_path) and load both splits,
# applying the same preprocessing to each.
train_set = torchvision.datasets.STL10(
    root=data_path, split="train", download=True, transform=transform
)
test_set = torchvision.datasets.STL10(
    root=data_path, split="test", download=True, transform=transform
)

Dataset 확인

# Shapes of the raw image arrays held by each split.
print(train_set.data.shape)
print(test_set.data.shape)

# Fetch the sample once; indexing the dataset runs the transform pipeline, so
# indexing twice would preprocess the same image twice.
sample_image, sample_label = train_set[3]
print(f"image tensor: {sample_image}")
print(f"image label: {sample_label}")

# Per-class sample counts.
# Labels are read from the dataset's stored label array instead of iterating
# the dataset, which would load and transform every image just to discard it.
# tolist() converts the NumPy integers to plain Python ints, matching what
# iterating the dataset yields when no target_transform is set.
y_train = train_set.labels.tolist()
y_test = test_set.labels.tolist()
counter_train = collections.Counter(y_train)
counter_test = collections.Counter(y_test)
print(counter_train)
print(counter_test)

DataLoader 생성

# Number of samples per mini-batch. This name was referenced but never defined
# in the script (NameError); 32 is an arbitrary default -- tune it as needed.
batch_size = 32

# Training loader: reshuffled every epoch; the final incomplete batch is dropped.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

# Test loader: fixed order.
# NOTE(review): drop_last=True silently skips trailing test samples whenever
# the test set size is not a multiple of batch_size -- confirm this is intended
# before using this loader for evaluation metrics.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

DataLoader 확인

# Basic facts about the loaders: their type and the number of batches in each.
print(f"TrainLoader Type: {type(train_loader)}")
print(f"TrainLoader Length: {len(train_loader)}")
print(f"TestLoader Length: {len(test_loader)}")

# Inspect the actual values produced by train_loader.
# Use the builtin next(): the iterator's .next() method is not available on
# current PyTorch DataLoader iterators and raises AttributeError there.
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(f"TrainLoader Image Size: {images.size()}")

# Walk through every training batch, showing a progress bar and printing the
# 1-based batch index, the image batch size and the labels of each batch.
for i, (images, labels) in enumerate(tqdm(train_loader)):
    print(f"TrainLoader Index: {i + 1} | TrainLoader Images Size: {images.size()} | TrainLoader labels: {labels}")
profile
ssw

0개의 댓글