[PYTORCH]모델 저장

신동혁·2022년 12월 12일
0

Pytorch

목록 보기
3/4

공식문서 링크 : https://tutorials.pytorch.kr/beginner/saving_loading_models.html

1.모델 저장하는 법

  1. torch.save
    직렬화된 객체를 디스크에 저장합니다. 이 함수는 Python의 pickle 을 사용하여 직렬화합니다. 이 함수를 사용하여 모든 종류의 객체의 모델, Tensor 및 사전을 저장할 수 있습니다.

  2. torch.load
    pickle을 사용하여 저장된 객체 파일들을 역직렬화하여 메모리에 올립니다. 이 함수는 데이터를 장치에 불러올 때에도 사용됩니다. (장치 간 모델 저장하기 & 불러오기 참고)

  3. torch.nn.Module.load_state_dict
    역직렬화된 state_dict 를 사용하여 모델의 매개변수들을 불러옵니다.

1) state_dict란?

  • state_dict 가 무엇인가요?
    PyTorch에서 torch.nn.Module 모델의 학습 가능한 매개변수(예. 가중치와 편향)들은 모델의 매개변수에 포함되어 있습니다(model.parameters()로 접근합니다). state_dict 는 간단히 말해 각 계층을 매개변수 텐서로 매핑되는 Python 사전(dict) 객체입니다. 이 때, 학습 가능한 매개변수를 갖는 계층(합성곱 계층, 선형 계층 등) 및 등록된 버퍼들(batchnorm의 running_mean)만이 모델의 state_dict 에 항목을 가짐을 유의하시기 바랍니다. 옵티마이저 객체(torch.optim) 또한 옵티마이저의 상태 뿐만 아니라 사용된 하이퍼 매개변수(Hyperparameter) 정보가 포함된 state_dict 를 갖습니다.

state_dict 객체는 Python 사전이기 때문에 쉽게 저장하거나 갱신하거나 바꾸거나 되살릴 수 있으며, PyTorch 모델과 옵티마이저에 엄청난 모듈성(modularity)을 제공합니다.

2. 모델 저장 및 불러오기 실습

참고한 블로그 링크 : https://justkode.kr/deep-learning/pytorch-save

1) 저장

import torch
import torch.nn as nn

# 모델 생성
x_data = torch.Tensor([
    [0, 0],
    [1, 0],
    [1, 1],
    [0, 0],
    [0, 0],
    [0, 1]
])

y_data = torch.LongTensor([
    0,  # etc
    1,  # mammal
    2,  # birds
    0,
    0,
    2
])

class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.w1 = nn.Linear(2, 10)
        self.bias1 = torch.zeros([10])

        self.w2 = nn.Linear(10, 3)
        self.bias2 = torch.zeros([3])
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        y = self.w1(x) + self.bias1
        y = self.relu(y)

        y = self.w2(y) + self.bias2
        return y

model = DNN()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


# epoch을 돌 때마다 전체모델, 모델state_dict, 여러정보를 각각 저장한다.
for epoch in range(10):
    output = model(x_data)

    loss = criterion(output, y_data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    PATH = '저장할 경로 알아서 설정'
    torch.save(model, PATH + f'{epoch}th_model.pt')  # 전체 모델 저장
    torch.save(model.state_dict(), PATH + f'{epoch}th_model_state_dict.pt')  # 모델 객체의 state_dict 저장
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, PATH + f'{epoch}th_all.tar')

    print("progress:", epoch, "loss=", loss.item())

2) 불러오기

model = torch.load(PATH + 'model.pt')  # 전체 모델을 통째로 불러옴, 클래스 선언 필수
model.load_state_dict(torch.load(PATH + 'model_state_dict.pt'))  # state_dict를 불러 온 후, 모델에 저장

checkpoint = torch.load(PATH + 'all.tar')   # dict 불러오기
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
torch.save(modelA.state_dict(), PATH)  # 저장하기

modelB = TheModelBClass(*args, **kwargs)  # 불러오기
modelB.load_state_dict(torch.load(PATH), strict=False)
  • 다른 모델의 매개변수 사용하기
    모델의 매개변수의 일부만 불러 사용하는 것은 전이학습을 이용할 때 자주 사용합니다. state_dict의 일부만 불러오거나, 적재하려는 모델보다 더 많은 키를 갖고 있는 state_dict를 불러 올때는, load_state_dict() 함수의 파라미터에 strict=False를 입력 해 주면 됩니다.
torch.save(modelA.state_dict(), PATH)  # 저장하기

modelB = TheModelBClass(*args, **kwargs)  # 불러오기
modelB.load_state_dict(torch.load(PATH), strict=False)
profile
개발취준생

0개의 댓글