어제 torch.jit.script
활용하여 pt 파일을 생성해보려 했지만, 제대로 되지 않았다.
따라서, yolo v7의 소스코드에 있는 export.py
를 분석하여, 어떤 방식으로 pyTorch 모델을 architecture와 parameter가 담긴 pt파일이 생성된다.
어떤 방식으로 이루어지는지 알아보도록 하겠다.
python export.py --weights yolov7-tiny.pt --grid --end2end --simplify \
--topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640
위의 명령어를 활용하여 export 시킨다.
model = attempt_load(opt.weights, map_location=device) # load FP32 model
attempt_load() 함수 부분
model = Ensemble()
https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html
모듈 정보를 담기위해 초기화해주는 작업
ckpt = torch.load(w, map_location=map_location)
for w in weights if isinstance(weights, list) else [weights]:
#attempt_download(w)
ckpt = torch.load(w, map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32
ckpt[’model’]
안에 있는 정보를 model 인스턴스에 담아준다.
for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True # pytorch 1.7.0 compatibility
elif type(m) is nn.Upsample:
m.recompute_scale_factor = None # torch 1.11.0 compatibility
elif type(m) is Conv:
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
torch.nn을 상속받은 model을 load하고 node별로 type에 따라서 알맞는 초기화를 진행해줌.
recompute_scale_factor
초기화 https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html_non_persistent_buffers_set
labels = model.names
labels 안에 객체 탐지할 클래스 정보가 담겨있음.
gs = int(max(model.stride)) # grid size (max stride)
model script 방식으로 save할 때, stride 정보가 models/yolo.py
안의 Model 클래스에서 계산할 때, Datatype 및 size 등 다양한 문제가 존재했었는데, 인스턴스 로드했을 때는 정상적으로 잡혀서 값이 들어간 것을 볼 수 있음.
# Input
img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device) # image size(1,3,320,192) iDetection
input size를 조절해줌.
아직은 모르겠지만 여기서 문제가 있을거라 예상됨…
input tensor를 dummy input으로 (32,3,256,256)으로 맞춰줬었는데, (1,3,640,640)으로 세팅해줌.
이후, 모델을 다시 한 번 update 해준다.
# Update model
for k, m in model.named_modules():
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, models.common.Conv): # assign export-friendly activations
if isinstance(m.act, nn.Hardswish):
m.act = Hardswish()
elif isinstance(m.act, nn.SiLU):
m.act = SiLU()
# elif isinstance(m, models.yolo.Detect):
# m.forward = m.forward_export # assign forward (optional)
model.model[-1].export = not opt.grid # set Detect() layer grid export
해당되는 활성화함수가 해당 레이어에 제대로 담겼는지, 그게 아니면 업데이트해주는 과정으로 보인다.
그리고 마지막 레이어인 Detect 레이어에서는 Detect Layer를 export 할 것인지, 아닌지에 대해 선택해준다.
f = opt.weights.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, img, strict=False)
ts.save(f)
결국 앞서 초기화하고 설정해준 pt 정보를 f라는 변수에 넣어주고, torch.jit.trace
를 활용하여 ts라는 변수에 담아주게 된다. 이후, ts.save(f)
명령어를 활용하여 저장해주게 된다.
# ONNX export
try:
import onnx
print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', '.onnx') # filename
model.eval()
기존의 가중치 파일이 담긴 f를 onnx 파일로 변경해주고, model.eval() 모드로 변경해줘서, Dropout,이나 BatchNormalization과 같이 evaluation 과정에서 필요없는 모듈을 onnx에 담지 않도로 해주었다.
class End2End(nn.Module):
'''export onnx or tensorrt model with NMS operation.'''
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
super().__init__()
device = device if device else torch.device('cpu')
assert isinstance(max_wh,(int)) or max_wh is None
self.model = model.to(device)
self.model.model[-1].end2end = True
self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
self.end2end.eval()
def forward(self, x):
x = self.model(x)
x = self.end2end(x)
return x
변환이 되지 않는 NMS의 경우는 위와 같이 커스텀해서 변환해주는 옵션인거 같다.
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=output_names,
dynamic_axes=dynamic_axes)
결국 허무하지만, onnx export 하는 것은 일반적인 방법론과 똑같았다.
결국은 trace인가…
다시 한 번 요약을 하자면
torch.save
를 활용하여 저장해준다.torch.nn.Upsample
recompute scale factor 초기화torch.jit.trace
사용해서 Trace방식으로 모델 Export마지막 레이어인 Detect() 클래스
output names 수정해줘야함
output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
onnx.export
모듈 활용하여 Onnx 모델로 Export 해 줌.