from numpy import argmax
import torchvision.transforms as transforms
import torch
import os
from tqdm import tqdm
"""1. aug 2. train loop 3. val loop 4. save model 5. eval"""
def data_augmentation():
    """Build the image preprocessing pipelines for the train and test splits.

    Both pipelines resize to 224x224, convert to a tensor and normalize with
    mean 0.5 / std 0.2 on each of three channels (assumes 3-channel input --
    TODO confirm against the dataset). The train pipeline additionally applies
    a random horizontal flip (p=0.4) and a random vertical flip (p=0.2)
    between the resize and the tensor conversion.

    Returns:
        dict: ``{'train': transforms.Compose, 'test': transforms.Compose}``.
    """

    def _resize():
        # Fresh instance per pipeline so the two Compose objects share nothing.
        return transforms.Resize((224, 224))

    def _tensor_and_normalize():
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.2, 0.2, 0.2]),
        ]

    # Random flips are applied to the training split only.
    train_steps = [
        _resize(),
        transforms.RandomHorizontalFlip(p=0.4),
        transforms.RandomVerticalFlip(p=0.2),
    ]
    train_steps.extend(_tensor_and_normalize())

    test_steps = [_resize()]
    test_steps.extend(_tensor_and_normalize())

    return {
        'train': transforms.Compose(train_steps),
        'test': transforms.Compose(test_steps),
    }
"""train loop"""
def train(num_epoch, model, train_loader, test_loader, criterion, optimizer,
          save_dir, val_every, device):
    """Train ``model`` and checkpoint its weights into ``save_dir``.

    Runs ``num_epoch`` passes over ``train_loader`` and prints the loss and
    batch accuracy after every optimizer step. Every ``val_every`` epochs the
    model is scored with ``validation``; when the returned average loss is
    lower than the best seen so far, the weights are written with
    ``save_model`` under its default file name. After the last epoch the
    final weights are always written as ``last.pt``.

    Args:
        num_epoch: Number of passes over ``train_loader``.
        model: Callable module; ``model(imgs)`` must return per-class scores
            with the class axis at dim 1 (``torch.max(output, 1)`` is used).
        train_loader: Iterable of ``(imgs, labels)`` batches supporting
            ``len()``.
        test_loader: Iterable of ``(imgs, labels)`` batches handed to
            ``validation``.
        criterion: Loss callable invoked as ``criterion(output, labels)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        save_dir: Directory passed through to ``save_model``.
        val_every: Validate every this many epochs.
        device: Device the batches are moved to before the forward pass.
    """
    print("String... train !!! ")
    # +inf (instead of a magic 9999) guarantees the first validation result
    # is recorded as the best, however large the loss is.
    best_loss = float("inf")

    for epoch in range(num_epoch):
        # Make sure dropout / batch-norm layers are in training mode even if
        # the caller handed over a model that was last used for evaluation.
        model.train()

        for step, (imgs, labels) in enumerate(train_loader, start=1):
            imgs, labels = imgs.to(device), labels.to(device)

            output = model(imgs)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Predicted class = index of the highest score along dim 1.
            _, preds = torch.max(output, 1)
            acc = (labels == preds).float().mean()

            print("Epoch [{}/{}], Step [{}/{}], Loss : {:.4f}, Acc : {:.2f}%".format(
                epoch + 1, num_epoch, step, len(train_loader),
                loss.item(), acc.item() * 100
            ))

        if (epoch + 1) % val_every == 0:
            avg_loss = validation(
                epoch + 1, model, test_loader, criterion, device)
            if avg_loss < best_loss:
                print("Best prediction at epoch : {} ".format(epoch + 1))
                print("Save model in", save_dir)
                best_loss = avg_loss
                save_model(model, save_dir)

    # Always keep the final weights, whether or not they were the best.
    save_model(model, save_dir, file_name="last.pt")
def validation(epoch, model, test_loader, criterion, device):
    """Score ``model`` on ``test_loader`` and return the mean batch loss.

    Switches the model to eval mode, runs every batch under
    ``torch.no_grad()``, prints the accuracy and the average loss, then
    switches the model back to training mode before returning.

    Args:
        epoch: Epoch number, used only in the printed messages.
        model: Callable module; ``model(imgs)`` must return per-class scores
            with the class axis at dim 1.
        test_loader: Iterable of ``(imgs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``;
            its result must support ``.item()``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches
        (an unweighted mean over batches, not over samples).

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    print("Start validation # {}".format(epoch))
    model.eval()

    total = 0
    correct = 0
    total_loss = 0.0
    cnt = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            # Predicted class = index of the highest score along dim 1.
            _, preds = torch.max(outputs, 1)
            total += imgs.size(0)
            correct += (labels == preds).sum().item()
            # Accumulate a plain Python number so the running sum (and the
            # value handed back to the caller) is not a device tensor.
            total_loss += loss.item()
            cnt += 1

    avg_loss = total_loss / cnt
    # The loss is not a percentage, so only the accuracy carries a "%".
    print("Validation # {} Acc : {:.2f}% Average Loss : {:.4f}".format(
        epoch, correct / total * 100, avg_loss
    ))

    # Hand the model back in training mode for the caller's next epoch.
    model.train()
    return avg_loss
def save_model(model, save_dir, file_name="best.pt"):
    """Write ``model.state_dict()`` to ``save_dir/file_name``.

    Args:
        model: Object exposing ``state_dict()``.
        save_dir: Target directory; created (including parents) when it does
            not exist yet. An empty string means the current directory.
        file_name: Name of the checkpoint file inside ``save_dir``.
    """
    # torch.save does not create missing parent directories, so do it here.
    # Skip the empty string: os.makedirs("") raises, while a bare file name
    # relative to the current directory is still a valid target.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, file_name)
    torch.save(model.state_dict(), output_path)
def eval(model, test_loader, device):
    """Print the top-1 accuracy of ``model`` over ``test_loader``.

    Puts the model in eval mode (it is left in eval mode afterwards) and runs
    every batch under ``torch.no_grad()`` with a tqdm progress bar.

    NOTE(review): the name shadows the built-in ``eval`` inside this module;
    it is kept because callers depend on it.

    Args:
        model: Callable module; ``model(imgs)`` must return per-class scores
            with the class axis at dim 1.
        test_loader: Iterable of ``(imgs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    print("Starting evaluation")
    model.eval()

    total = 0
    correct = 0
    with torch.no_grad():
        # Wrap the loader itself rather than enumerate(loader): enumerate has
        # no length, so tqdm could not show a total or an ETA.
        for imgs, labels in tqdm(test_loader):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            # Select the class with the highest score.
            _, preds = torch.max(outputs, 1)
            total += imgs.size(0)
            correct += (labels == preds).sum().item()

    print("Test acc for image : {} ACC : {:.2f}".format(
        total, correct / total * 100))
    print("End test.. ")
util_file.py를 고쳐서 메인에서 돌아가게끔 하려는데, 어렵다.
점심도 안 먹고 지금까지 하고 있는데..
생각대로 코딩을 하는 게 어렵다 :(