ImageNet Normalize

DOYOUNG KIM·2023년 4월 14일
0
T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

Pytorch 사용시 torchvision.transform 을 이용해서 여러 비전 관련 작업들을 수행한다.
여러 사람들이 다양한 데이터를 사용하고 이를 이용해 모델을 학습 시킨다.
위와 같은 정규화 과정의 평균과 표준편차는 대부분이 사요한다. 그럼 이 정규화 코드는 무엇을 의미할까

이미지의 구간 범위를 0-1로 하면 평균은 0.5로 하는것이 일반적이라고 생각되지만 그렇지 않다.

위 수치는 ImageNet 데이터세트의 평균과 표준편차 통계치이다.

데이터간의 차이가 존재하고 이를 표준화 시켜주는 과정이 중요하다.
전체 데이터 간의 화소의 평균, 표준편차를 일괄적으로 적용하는 과정이 필요하고 우리는 양질의 대량 데이터세트인 ImageNet의 표준편차를 표준처럼 사용하는 것이다.

자신의 데이터세트에 맞춘 값을 사용 가능하지만 일반적으로는 자체 평균과 표준이 존재하는 ImageNet의 표준편차를 사용하는 것이 좋다.

표준편차 구하는 예제 - Mnist

import os
import torch
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset
from tqdm.notebook import tqdm
from time import time

N_CHANNELS = 1

dataset = datasets.MNIST("data", download=True,
                 train=True, transform=transforms.ToTensor())
full_loader = torch.utils.data.DataLoader(dataset, shuffle=False, num_workers=os.cpu_count())

before = time()
mean = torch.zeros(1)
std = torch.zeros(1)
print('==> Computing mean and std..')
for inputs, _labels in tqdm(full_loader):
    for i in range(N_CHANNELS):
        mean[i] += inputs[:,i,:,:].mean()
        std[i] += inputs[:,i,:,:].std()
mean.div_(len(dataset))
std.div_(len(dataset))
print(mean, std)

print("time elapsed: ", time()-before)
profile
매일 1%씩 성장하는 개발 공부 블로그 입니다.

0개의 댓글