[PyTorch] 모델 load시 Unexpected key(s) in state_dic 문제

안수진·2023년 11월 16일
0

Capstone Design Project

목록 보기
2/2

🙅‍♀️다른 모델을 내 모델로 불러올 때 Unexpected key(s) in state_dic 문제

종합설계 발표를 위해서 팀원들과 모여 기능을 확인하던 중 학습시킨 KoBERT 모델을 load하면서 아래와 같은 에러가 발생했다.

"/root/ML_Server/KoBERTModel/predict.py", line 20, in load_model
    model.load_state_dict(torch.load('KoBERTModel/model/train.pt'))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BERTClassifier:
        Unexpected key(s) in state_dict: "bert.embeddings.position_ids". 

이 에러는 PyTorch모델의 load_state_dict 함수를 호출할 때 발생하는 것으로 보입니다. 이 오류는 모델을 로드할 때 발견된 예상하지 못한 키(Unexpected key)로 인해 발생한다는 GPT의 답변을 얻었다.

기존 코드

def load_model():
    global model
    model = BERTClassifier(bertmodel,  dr_rate=0.4).to(device)

    model.load_state_dict(torch.load('KoBERTModel/model/train.pt'))
    model.eval() 

🙆‍♀️해결방법

위 코드에서 model.load_state_dict()의 인자에 strictFalse로 설정하여 일치하지 않는 키들을 무시하도록 설정하는 것이다.

부분적으로 모델을 불러오거나, 모델의 일부를 불러오는 것은 전이학습 또는 새로운 복잡한 모델을 학습할 때 일반적인 시나리오이다.
학습된 매개변수를 사용하면, 일부만 사용한다 하더라도 학습 과정을 빠르게 시작할 수 있고, 처음부터 시작하는 것보다 훨씬 빠르게 모델이 수렴하도록 도울 것이다.

몇몇 키를 제외하고 state_dict의 일부를 불러오거나, 적재하려는 모델보다 더 많은 키를 갖고 있는 state_dict를 불러올 때에는 load_state_dict() 함수에서 strict 인자를 False로 설정하여 일치하지 않는 키들을 무시하도록 해야 한다.

profile
멋쟁이 개발자 지망생

0개의 댓글