Pytorch: model save/load

danbibibi·2022년 2월 5일
0

PyTorch 🔥

목록 보기
10/20

모델 저장

import torch

torch.save(model.state_dict(), PATH)

모델 불러오기

GPU에서 save, CPU에서 load

import torch

device = torch.device('cpu')
model = CNN(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

GPU에서 save, GPU에서 load

import torch

device = torch.device('cuda')
model = CNN(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

CPU에서 save, GPU에서 load

import torch

device = torch.device('cuda')
model = CNN(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # 사용할 GPU 장치 번호 선택
model.to(device)  # CUDA Tensor 형 변환
profile
블로그 이전) https://danbibibi.tistory.com

0개의 댓글