PyTorch에서 모델 저장하기

박요셉·2023년 6월 1일
0

DeepLearningETC

목록 보기
1/4

아주아주아주 중요하다.
나같이 열악한 환경에서 공부 중이라면 더더욱..
큰 모델을 돌리거나 학습이 오래걸린다면 더더더더욱...

Code


먼저 colab을 google drive에 mount하자.
하면 창이 뜰텐데 확인 누르고 연결해주면 된다.

PATH = '/content/gdrive/MyDrive/Colab Notebooks/inceptionv4.pt'
import os.path

epoch_start = 1

if os.path.exists(PATH):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch_start = checkpoint['epoch'] + 1
    print("successfully loaded!")
    print("epoch saved until here: ", epoch_start-1)
    print("train starts from this epoch: Epoch ", epoch_start)

그 다음 path를 지정하자.
os.path를 통해 먼저 그 경로에 파일이 있는지 확인한다.
있다는 것은 곧 이전에 학습을 했다는 뜻이므로 불러와준다.
ResNet에서는 model, optimizer, epoch_start 만 저장했다.
+) Inception-v4에서 약간 수정했다.

더 필요하다면 저장할 수 있다.
예를 들어 각 단계의 loss라던지.

여튼 성공적으로 load한 지 알기 위해 문구도 출력하고,
어느 epoch까지 학습했는지도 볼 수 있게 만들었다.

print(f"we will start from:{epoch_start}")
# train
for epoch in range(epoch_start, args.n_epochs+1):
    model.train()
    train_loss = 0
    correct, count = 0, 0
    for batch_idx, (images, labels) in enumerate(train_loader, start=1):
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        optimizer.zero_grad()
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, preds = torch.max(output, 1) # torch max output is (max, max_index)
        count += labels.size(0)
        correct += torch.sum(preds == labels)
        
    loss_hist.append(train_loss/count)
    accuracy_hist.append(correct/count)
    print(f"[*] Epoch: {epoch} \tTrain accuracy: {correct/count} \tTrain Loss: {train_loss/count}")
    torch.save({
        'epoch' : epoch,
        'model_state_dict' : model.state_dict(),
        'optimizer_state_dict' : optimizer.state_dict(),
        }, PATH)

end_time = time.time()

print(f"Training time : {end_time - start_time}")

여기서 torch.save 파트를 보자.
각 단계의 epoch, model_state_dict(현 model의 parameter들을 저장함), optimizer_state_dict 를 저장한다.
나같은 경우, epoch마다 저장하게 했다.
더 자주 해도 되고, 더 띄엄띄엄해도 된다.

여튼 미리 지정한 PATH에 계속 저장해서 덮어씌우는 방식을 채택했다.
만일 여러 epoch마다의 parameter가 따로 필요하다면 PATH를 바꾸며 저장하면 된다.

ex. model_epoch001.pt와 같이.
파일 구조나 이런 것들은 다른 블로그에서 자세히 설명해주니 참고하길 바란다.

이렇게 하니... 여태 매번 멈췄다 했다가 하던걸 이어서 할 수 있어 너무 좋다.

여기에 여태 loss, accuracy list도 저장할 수 있게 기능을 추가해보려 한다.
torch.save, 아주 만족스럽다.

profile
개발 폐관수련중, ML, DL 무림 초보

0개의 댓글