STL10 Dataset 다운로드
import torch
import os
import torchvision
import torchvision.transforms as transforms
import collections
from tqdm import tqdm
# Root directory that torchvision downloads the STL10 archive into and
# later reads it back from.
data_path = "./data"

# Per-image preprocessing, applied in order: convert to a tensor, resize to
# 224x224, then normalize each of the three channels with the given
# mean / std (these look like the usual ImageNet statistics -- confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# exist_ok=True makes this a no-op when the directory already exists, so no
# separate existence check is needed (which could race with another process),
# and makedirs also creates missing parents if data_path is ever nested.
os.makedirs(data_path, exist_ok=True)

# Build the labeled train / test splits; download=True fetches the archive
# into data_path when it is not already present.
train_set = torchvision.datasets.STL10(
    root=data_path, split="train", download=True, transform=transform
)
test_set = torchvision.datasets.STL10(
    root=data_path, split="test", download=True, transform=transform
)
Dataset 확인
# Shapes of the raw (untransformed) image arrays held by each split.
print(train_set.data.shape)
print(test_set.data.shape)

# Fetch one transformed sample once and show both parts of it; indexing the
# dataset runs the whole transform pipeline, so avoid doing it twice.
sample_image, sample_label = train_set[3]
print(f"image tensor: {sample_image}")
print(f"image label: {sample_label}")

# Read the labels straight from the dataset's label array instead of
# iterating the dataset: iterating would decode, resize and normalize every
# image only to throw the image away. No target_transform is set on these
# datasets, so the values are the same ints that indexing would return.
y_train = train_set.labels.tolist()
y_test = test_set.labels.tolist()

# Per-class sample counts for each split.
counter_train = collections.Counter(y_train)
counter_test = collections.Counter(y_test)
print(counter_train)
print(counter_test)
DataLoader 생성
# Number of samples per mini-batch.
# NOTE(review): no definition of batch_size is visible earlier in this file,
# so the loaders below would raise NameError; 32 is an assumed default --
# replace it with the intended value.
batch_size = 32

# Training loader: reshuffled every epoch. drop_last=True discards the final
# batch when it would be smaller than batch_size.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

# Test loader: fixed order. drop_last=True here also discards a trailing
# partial batch, so some test samples are skipped whenever the test set size
# is not a multiple of batch_size.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)
DataLoader 확인
# Basic facts about the loaders: their type and how many batches each yields.
print(f"TrainLoader Type: {type(train_loader)}")
print(f"TrainLoader Length: {len(train_loader)}")
print(f"TestLoader Length: {len(test_loader)}")

# Pull a single batch to check its shape. Use the builtin next(): the
# iterator's .next() method is the Python 2 spelling and is not available on
# current PyTorch DataLoader iterators.
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(f"TrainLoader Image Size: {images.size()}")

# Walk the whole training loader once, reporting every batch's shape and
# labels (tqdm adds a progress bar around the iteration).
for i, (images, labels) in enumerate(tqdm(train_loader)):
    print(f"TrainLoader Index: {i + 1} | TrainLoader Images Size: {images.size()} | TrainLoader labels: {labels}")