아주아주아주 중요하다.
나같이 열악한 환경에서 공부 중이라면 더더욱..
큰 모델을 돌리거나 학습이 오래걸린다면 더더더더욱...
먼저 colab을 google drive에 mount하자.
하면 창이 뜰텐데 확인 누르고 연결해주면 된다.
PATH = '/content/gdrive/MyDrive/Colab Notebooks/inceptionv4.pt'
import os.path
epoch_start = 1
if os.path.exists(PATH):
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_start = checkpoint['epoch'] + 1
print("successfully loaded!")
print("epoch saved until here: ", epoch_start-1)
print("train starts from this epoch: Epoch ", epoch_start)
그 다음 path를 지정하자.
os.path
를 통해 먼저 그 경로에 파일이 있는지 확인한다.
있다는 것은 곧 이전에 학습을 했다는 뜻이므로 불러와준다.
ResNet에서는 model, optimizer, epoch_start 만 저장했다.
+) Inception-v4에서 약간 수정했다.
더 필요하다면 저장할 수 있다.
예를 들어 각 단계의 loss라던지.
여튼 성공적으로 load한 지 알기 위해 문구도 출력하고,
어느 epoch까지 학습했는지도 볼 수 있게 만들었다.
print(f"we will start from:{epoch_start}")
# train
for epoch in range(epoch_start, args.n_epochs+1):
model.train()
train_loss = 0
correct, count = 0, 0
for batch_idx, (images, labels) in enumerate(train_loader, start=1):
images, labels = images.to(device), labels.to(device)
output = model(images)
optimizer.zero_grad()
loss = criterion(output, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, preds = torch.max(output, 1) # torch max output is (max, max_index)
count += labels.size(0)
correct += torch.sum(preds == labels)
loss_hist.append(train_loss/count)
accuracy_hist.append(correct/count)
print(f"[*] Epoch: {epoch} \tTrain accuracy: {correct/count} \tTrain Loss: {train_loss/count}")
torch.save({
'epoch' : epoch,
'model_state_dict' : model.state_dict(),
'optimizer_state_dict' : optimizer.state_dict(),
}, PATH)
end_time = time.time()
print(f"Training time : {end_time - start_time}")
여기서 torch.save
파트를 보자.
각 단계의 epoch, model_state_dict(현 model의 parameter들을 저장함), optimizer_state_dict 를 저장한다.
나같은 경우, epoch마다 저장하게 했다.
더 자주 해도 되고, 더 띄엄띄엄해도 된다.
여튼 미리 지정한 PATH에 계속 저장해서 덮어씌우는 방식을 채택했다.
만일 여러 epoch마다의 parameter가 따로 필요하다면 PATH를 바꾸며 저장하면 된다.
ex. model_epoch001.pt
와 같이.
파일 구조나 이런 것들은 다른 블로그에서 자세히 설명해주니 참고하길 바란다.
이렇게 하니... 여태 매번 멈췄다 했다가 하던걸 이어서 할 수 있어 너무 좋다.
여기에 여태 loss, accuracy list도 저장할 수 있게 기능을 추가해보려 한다.
torch.save
, 아주 만족스럽다.