[boostcamp-ai-tech][PyTorch] 모델불러오기(Transfer learning)

whatSup CheatSheet·2022년 1월 29일
0

PyTorch

목록 보기
4/5
post-thumbnail

모델 저장 및 불러오기

model.save()

  • 학습의 결과를 저장하기 위한 함수

  • 모델 형태(architecture)와 parameter를 저장

    • 모델 parameter 저장
    ## 모델 저장(ordered dict 형태)
    MODEL_PATH ="saved"  
    if not os.path.exists(MODEL_PATH):
        os.makedirs(MODEL_PATH)
    torch.save(model.state_dict(), 
               os.path.join(MODEL_PATH, "model.pt"))  # .pt확장자
    
    ## 모델 로드
    new_model = TheModelClass()  # 동일한 모델일 때
    # 모델 로드
    new_model.load_state_dict(torch.load(os.path.join(
        MODEL_PATH, "model.pt")))
    • 모델 architecture 저장
    # 모델 자체를 저장(모델의 architecture 저장)
    torch.save(model, os.path.join(MODEL_PATH, "model_pickle.pt"))  # python pickle 확장자
    model = torch.load(os.path.join(MODEL_PATH, "model_pickle.pt"))
    model.eval()

checkpoints

  • 학습의 중간결과를 저장하여 최선의 결과를 선택
  • epoch, loss, metric값을 지속적으로 확인 및 저장
## 데이터 저장
...(생략)...
for epoch in range(epochs):
...(생략)...

# 단계마다 저장(EPOCHS 단위)
    torch.save({
        'epoch': e,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': epoch_loss,
        }, f"saved/checkpoint_model_{e}_{epoch_loss/len(dataloader)}_{epoch_acc/len(dataloader)}.pt")
        
    print(f'Epoch {e+0:03}: | Loss: {epoch_loss/len(dataloader):.5f} | Acc: {epoch_acc/len(dataloader):.3f}')

## 데이터 로드
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict']
epoch = checkpoint['epoch']
loss = checkpoint['loss']

pretained model Transfer learning

  • 다른 데이터셋으로 만든 모델을 현재 데이터에 적용하는 것(현재 DL에서 가장 일반적인 학습 기법임)
  • backbone architecture가 잘 학습된 모델에서 일부분만 변경하여 학습을 수행함.
  • pretained model을 활용시 모델의 일부분을 frozen시킴
profile
AI Engineer : Lv 0

0개의 댓글