CIFAR-10 데이터셋과 nn.Linear를 사용하여 테스트셋에서 Top-1 accuracy 65% 이상 달성해보기
Torchvision은 PyTorch와 함께 사용되는 컴퓨터 비전용 라이브러리이다.
1.0 torch.device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
1.1 torchvision.transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std(0.5, 0.5, 0.5))
])
1.2 torchvision.datasets
torchvision.datasets : Torchvision에서 제공하는 데이터셋을 받아오기
trainset = datasets.CIFAR10(
root='./cifar10',
train=True,
download=True,
transform=transform
)
testset = datasets.CIFAR10(
root='./cifar10',
train=False,
download=True,
transform=transform
)
1.3 torch.utils.data.DataLoader()
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=128,
shuffle=True,
num_workers=0
)