[pytorch] Image Classification : 🐶 (Using ViT)

강콩콩 · March 25, 2022

Following on from last time, let's build a dog classification model using the Dataset / DataLoader classes! 😎

✨ Why dogs, you ask? 🐶 Because dogs are cute :)

😉 Alright then, LET'S DIG IN!

Preparing the Data

✔ For the data, I used the Stanford Dogs Dataset from Kaggle.
โœ” https://www.kaggle.com/datasets/jessicali9530/stanford-dogs-dataset

  • It's about 800 MB with roughly 120 labels (breeds), which seemed like a good size :)

✔ The data was split into train : val = 0.85 : 0.15.

import os
import shutil

root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
org_image_path = os.path.join(root_path, "archive/images/Images/")
labels = os.listdir(org_image_path)

labels_cnt_list = []

for l in labels:
    labels_cnt_list.append(len(os.listdir(os.path.join(org_image_path, l))))

os.makedirs(os.path.join(root_path, 'data/'), exist_ok=True)
os.makedirs(os.path.join(root_path, 'data/train/'), exist_ok=True)
os.makedirs(os.path.join(root_path, 'data/val/'), exist_ok=True)

for l in labels:
    os.makedirs(os.path.join(root_path, 'data/train/', l), exist_ok=True)
    os.makedirs(os.path.join(root_path, 'data/val/', l), exist_ok=True)

train_img_list = []
train_label_list = []
val_img_list = []
val_label_list = []

# Copy the first 85% of each label's images to data/train/, the rest to data/val/
for idx, l in enumerate(labels):
    num_train = int(labels_cnt_list[idx] * 0.85)
    tmp_image_name_list = os.listdir(os.path.join(org_image_path, l))
    for cnt, fname in enumerate(tmp_image_name_list):
        if cnt < num_train:
            dst_path = os.path.join(root_path, 'data/train/', l, fname)
            shutil.copy(os.path.join(org_image_path, l, fname), dst_path)
            train_img_list.append(dst_path)
            train_label_list.append(idx)
        else:
            dst_path = os.path.join(root_path, 'data/val/', l, fname)
            shutil.copy(os.path.join(org_image_path, l, fname), dst_path)
            val_img_list.append(dst_path)
            val_label_list.append(idx)
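
✨ A quick sanity check (my own addition, not part of the original split script): every image lands in exactly one of the two lists, so the split sizes should add up to the total image count.

print(len(train_img_list), len(val_img_list))
# Expected: the two splits together cover every image
print(len(train_img_list) + len(val_img_list) == sum(labels_cnt_list))  # True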

ViT?

✔ ViT: the Vision Transformer!
✔ It is the paper that showed computer vision tasks can be handled well using only self-attention, dropping the CNN architectures that had always been a staple of CV :)
✔ A ready-to-use PyTorch implementation is available, so we can simply use it.
✨ We'll be using the timm (PyTorch Image Models) package!
😎 timm packages up PyTorch implementations of high-performing computer vision deep learning models, so you can build a model quickly and simply :)

pip install timm

https://arxiv.org/abs/2010.11929
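
✨ As a small sketch of what timm gives us (my own addition, not from the original walkthrough; the patch_embed attribute and exact model names can differ between timm versions), the snippet below lists a few pretrained ViT variants and shows how a 224×224 image becomes a sequence of 16×16 patch tokens that self-attention then operates on.

import timm
import torch

# List a few pretrained ViT variants shipped with timm (names vary by version)
print(timm.list_models('vit_*', pretrained=True)[:5])

# With patch size 16 and a 224x224 input, ViT sees (224 / 16) ** 2 = 196 patch tokens
vit = timm.create_model('vit_base_patch16_224', pretrained=False)
patch_tokens = vit.patch_embed(torch.randn(1, 3, 224, 224))
print(patch_tokens.shape)  # expected: torch.Size([1, 196, 768])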

Dataset code

We'll reuse the Dataset and DataLoader classes implemented last time. 😉

Declaring the Dataset class

import torch
from PIL import Image

class MyDataset(torch.utils.data.Dataset):
    """
    Attributes
    ----------
    img_list : list
        List of image file paths.
    label_list : list
        List of label indices (one per image).
    phase : 'train' or 'val'
        Selects training or validation preprocessing.
    transform : object
        Instance of the preprocessing class.
    """

    def __init__(self, img_list, label_list, phase, transform):
        self.img_list = img_list
        self.label_list = label_list
        self.phase = phase  # 'train' or 'val'
        self.transform = transform  # image preprocessing

    def __len__(self):
        '''Return the number of images.'''
        return len(self.img_list)

    def __getitem__(self, index):
        '''
        Return the preprocessed image and its label.
        '''
        # img_list holds file paths, so load the image with PIL first
        img_path = self.img_list[index]
        img = Image.open(img_path).convert('RGB')

        transformed_img = self.transform(img, self.phase)
        label = self.label_list[index]

        return transformed_img, label

Declaring the Transform class

from torchvision import transforms

class MyTransform():
    """
    Attributes
    ----------
    resize : int
        Target width / height after the transform.
    mean : (R, G, B)
        Mean of each color channel.
    std : (R, G, B)
        Standard deviation of each color channel.
    """

    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(
                    (resize, resize), scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),  # convert to tensor
                transforms.Normalize(mean, std)  # normalize
            ]),
            'val': transforms.Compose([
                transforms.Resize((resize, resize)),
                transforms.ToTensor(),  # convert to tensor
                transforms.Normalize(mean, std)  # normalize
            ])
        }

    def __call__(self, img, phase='train'):
        """
        Parameters
        ----------
        phase : 'train' or 'val'
            Selects the preprocessing mode.
        """
        return self.data_transform[phase](img)

😎 Great! All the classes we need up front are now declared.
😁 Next, let's write the code that actually runs training.

Train Code

Creating Dataset instances

size = 224
mean = (0.485, 0.456, 0.406)  # standard ImageNet channel statistics
std = (0.229, 0.224, 0.225)

train_dataset = MyDataset(img_list=train_img_list, label_list=train_label_list, phase="train",
                          transform=MyTransform(size, mean, std))

val_dataset = MyDataset(img_list=val_img_list, label_list=val_label_list, phase="val",
                        transform=MyTransform(size, mean, std))

image_datasets = {'train' : train_dataset, 'val' : val_dataset}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
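
✨ As an optional check (my addition), pulling a single sample confirms the Dataset returns a preprocessed tensor and an integer label index:

img, label = train_dataset[0]
print(img.shape, label)  # expected: torch.Size([3, 224, 224]) and an int in [0, 119]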

Creating DataLoader instances

batch_size = 32

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

# Collect the loaders into a dictionary
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}
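
✨ Another quick check (my addition): one batch from the train loader should come out as (batch_size, 3, 224, 224) plus a tensor of 32 label indices.

inputs, labels = next(iter(dataloaders_dict['train']))
print(inputs.shape, labels.shape)  # expected: torch.Size([32, 3, 224, 224]) torch.Size([32])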

ViT model

import timm

num_classes = 120
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)
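
✨ Passing num_classes=120 tells timm to swap the pretrained classification head for a fresh 120-way linear layer. A quick way to confirm (my addition; the head attribute name may vary across timm versions):

print(model.head)  # expected: Linear(in_features=768, out_features=120, bias=True)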

Loss function & Optimizer

import torch.optim as optim
import torch.nn as nn
from torch.optim import lr_scheduler

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam takes no momentum argument
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

Now, the real training begins! :)

Writing the Train loop


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

epochs = 10
best_acc = 0.0

for epoch in range(epochs):
    print("{}/{} epoch running now".format(epoch, epochs - 1))

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders_dict[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Track gradients only during training
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    # Backpropagate to compute each tensor's gradient
                    loss.backward()
                    # Update the weights according to the optimizer's rule
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        if phase == 'train':
            exp_lr_scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # Save the best model so far (by validation accuracy)
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), './best_model.pth')
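
✨ To reuse the fine-tuned weights later, something like the following should work (my addition, a minimal sketch assuming the same num_classes and checkpoint path as above):

# Rebuild the architecture and load the saved weights for inference
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
model.load_state_dict(torch.load('./best_model.pth', map_location=device))
model = model.to(device)
model.eval()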

😎 With that, we've written a simple transfer learning exercise for a ViT model on the dog dataset.
😁 That said, the optimizer and learning rate scheduler here differ from what the paper specifies, so that's something to improve later :)
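
✨ As I recall, the paper fine-tunes with SGD with momentum and a cosine learning-rate schedule rather than Adam + StepLR; a rough sketch in that direction would look like the following (my addition, the hyperparameters are placeholders, not the paper's exact values):

from torch.optim import SGD, lr_scheduler

# Closer in spirit to the paper's fine-tuning recipe: SGD + momentum, cosine decay
optimizer = SGD(model.parameters(), lr=0.003, momentum=0.9)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)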

Visualization Code

https://tutorials.pytorch.kr/beginner/transfer_learning_tutorial.html

👍 The visualization code uses the function from the tutorial page above.

import matplotlib.pyplot as plt

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders_dict['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                # class_names and imshow are helpers from the tutorial;
                # class_names maps a predicted index to a breed name
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

😉 It can be used like this:

visualize_model(model)

Wrapping Up

🎉 WOW! We successfully performed transfer learning with a model that has shown impressively strong performance!
😋 Of course, reproducing the paper's benchmarks would require the paper's exact training scenario, optimizer, learning rate scheduling, and so on, but hey, we got it working :)
😁 Next time, I'll take on a more advanced implementation!
