간단한 모델을 만들어서, torch._C.script
에서 torch.jit.script 타입으로 변환되는 것인지, 아니면 torch.onnx
에서 변환되는 것인지 확인해보려고 합니다.
torch.jit.script 공식 문서를 확인해보면, torch._C.script
에서 if문과 같은 분기가 제대로 변환이 되지 않는 것으로 추측됩니다.
기본적으로 trace의 경우 Input 값도 같이 넣어주고, 변환과정을 거치기 때문에, Input 및 분기와 루프에서의 해당 조건에 종속적입니다. 즉 Path(경로)가 정해져있기 때문에, 여러 branch(분기)로 나누어 모델을 static하게 표현하지 못합니다.
따라서, torch.jit.script
방식을 사용하면 trace에서 발생하는 단점을 해결할 수 있다는 장점이 있습니다. 하지만 trace ↔ 일종의 인터프리터 방식(?) (파이썬의 특장점을 살림)과는 다르게 script ↔ 컴파일 (C,++,Java) 방식을 사용하게 되면 내부 구조를 파악하기 쉽지 않습니다.
script방식을 사용하면 IF분기도 모델에 담기기 때문에, 이번 기회에 간단한 모델로 디버깅해보며 분석해보려고합니다.
※ 지난 포스팅의 문제점
model_instance = model()
tmp = model_instance(data)
script_from_model = torch.jit.script(tmp)
여기서 tmp에 모델 인스턴스의 input까지 같이 넣어주게 되면, tmp 자체는 모델의 최종 결과값이 되게 됩니다. 따라서 지난시간에 torch.jit.script 모듈 안에서 자료형이 리스트타입으로 분리되어 처리되었던 이유도 Inference 거쳤던 결과값을 넣어주고, 이것을 torch.jit.script로 변경해주었기 때문입니다…
Custom Model torch.jit.scriptDebugging
<class 'torch.jit._script.RecursiveScriptModule'>
에 담아주어 파이토치 jit.script의RecursiveScriptModule로 반환해줍니다. 위 디버깅 과정을 통해서 torch.jit.script으로 변환되는 과정에서 forward내부 구문을 직접적으로 처리해주는 과정은 보이지 않았습니다. 또한 forward구문을 타고 그 안의 IF문을 처리하는 부분도 찾을 수 없었습니다. script이기 때문에 torch.nn 아래 타입별로 변환만 해주고 내부 구조자체를 보지 않는 것으로 추측됩니다.