지난번에 알아보았던, torch.jit.script
로 pt 파일을 내보낼 때, torch.eq
모듈이 지원되지 않는다며, 정상적으로 netron에서 열리지 않는 문제와 Onnx로 변환하였을 때, if,else와 같은 분기가 존재할 경우, Condition에 해당되는 분기만을 포워딩하는 것을 확인할 수 있었다.
따라서 해당 문제를 살펴보려고 한다.
결론부터 말하자면 해당 문제가 발생했던 이유는, 파이토치 모델 클래스를 만들 때, 사용하는 forward()
함수가 문제였다.
클래스에서 인스턴스를 생성할 때, __init__
이라는 생성자 함수가 호출된다. 마찬가지로 인스턴스가 호출될 때, __call__
이라는 호출자 함수가 호출된다.
파이토치에서 forward
함수 역시, torch.nn.module
클래스를 사용하게 되면, 호출 함수로 사용하게 된다.
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
해당 내용은 1178 라인에 존재한다.
forward
함수가 실행되는데, 한 가지 의문스러운 점이 있었다.
https://pytorch.org/docs/master/jit_language_reference.html#id2
def forward(self, x, y, z):
# type: (Optional[int], Optional[int], Optional[int]) -> int
if x is None:
x = 1
x = x + 1
# Refinement for an attribute by assigning it to a local
z = self.z
if y is not None and z is not None:
x = y + z
# Refinement via an `assert`
assert z is not None
x += z
return x
파이토치 레퍼런스에 존재하는 torch.jit.script
의 일부분을 발취해서 가져왔다.
다른 예제들도 그렇고 항상 input 텐서 값들을 기준으로 if 조건을 걸어서 활용하지, boolean값과 같은 다른 상태값으로 분기를 나누지는 않았다.
def forward(self,x):
x = self.custom_layer(x)
x = self.custom_layer2(x)
if self.branch_testing==True:
x = self.custom_layer2(x)
x = self.layer1(x)
else:
x = self.layer1(x)
x = self.layer2(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.fc2(x)
return x
내가 직접 커스터마이징했던 포워딩 구문이다. 나는 input값을 기준으로 분기를 나눴던 것이 아니라, 어떠한 condition에 따라, 분기를 나눴었다.
aten::eq.str_list(str[] a, str[] b) -> (bool):
Expected a value of type 'List[str]' for argument 'a' but instead found type 'NoneType'.
eq(float a, Tensor b) -> (Tensor):
Expected a value of type 'float' for argument 'a' but instead found type 'NoneType'.
eq(int a, Tensor b) -> (Tensor):
Expected a value of type 'int' for argument 'a' but instead found type 'NoneType'.
The original call is:
File "forward_if_simple_script.py", line 57
def forward(self,x):
if self.branch_testing==True:
~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
x = self.custom_layer1(x)
else:
혹시 몰라서 boolean input값을 주지 않고, None 값을 넣어줬을 때, 다음과 같은 오류가 발생했다. 이해가 되지 않았던 점은 boolean값으로 비교해주고 싶었는데, Tensor로 비교하려고 하는 것 같았다.
def forward(self,x):
if x[0][0][0][0] % 2 == 0:
x = self.custom_layer(x)
x = self.custom_layer(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.condtion_func(x)
return x
그래서 약간 허접하지만, 텐서값을 기준으로 분기를 설정해보았다.
두둥, 그러고 pyTorch pt파일과 Onnx파일로 export해보았다.
graph(%self : __torch__.CustomModel,
%x.1 : Tensor):
%3 : int = prim::Constant[value=0]() # forward_if_simple_script.py:67:13
%21 : int = prim::Constant[value=2]() # forward_if_simple_script.py:67:27
%14 : Tensor = aten::select(%x.1, %3, %3) # forward_if_simple_script.py:67:11
%16 : Tensor = aten::select(%14, %3, %3) # forward_if_simple_script.py:67:11
%18 : Tensor = aten::select(%16, %3, %3) # forward_if_simple_script.py:67:11
%20 : Tensor = aten::select(%18, %3, %3) # forward_if_simple_script.py:67:11
%22 : Tensor = aten::remainder(%20, %21) # forward_if_simple_script.py:67:11
%23 : Tensor = aten::eq(%22, %3) # forward_if_simple_script.py:67:11
%25 : bool = aten::Bool(%23) # forward_if_simple_script.py:67:11
%x : Tensor = prim::If(%25) # forward_if_simple_script.py:67:8
block0():
%custom_layer.1 : __torch__.CustomLayer = prim::GetAttr[name="custom_layer"](%self)
%x.9 : Tensor = prim::CallMethod[name="forward"](%custom_layer.1, %x.1) # forward_if_simple_script.py:68:16
-> (%x.9)
block1():
-> (%x.1)
%custom_layer : __torch__.CustomLayer = prim::GetAttr[name="custom_layer"](%self)
%x.23 : Tensor = prim::CallMethod[name="forward"](%custom_layer, %x) # forward_if_simple_script.py:69:12
%layer1 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="layer1"](%self)
%x.27 : Tensor = prim::CallMethod[name="forward"](%layer1, %x.23) # forward_if_simple_script.py:70:12
%layer2 : __torch__.torch.nn.modules.container.___torch_mangle_1.Sequential = prim::GetAttr[name="layer2"](%self)
%x.31 : Tensor = prim::CallMethod[name="forward"](%layer2, %x.27) # forward_if_simple_script.py:71:12
%flatten : __torch__.torch.nn.modules.flatten.Flatten = prim::GetAttr[name="flatten"](%self)
%x.35 : Tensor = prim::CallMethod[name="forward"](%flatten, %x.31) # forward_if_simple_script.py:72:12
%fc1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc1"](%self)
%x.39 : Tensor = prim::CallMethod[name="forward"](%fc1, %x.35) # forward_if_simple_script.py:73:12
%condtion_func : __torch__.torch.nn.modules.linear.___torch_mangle_2.Linear = prim::GetAttr[name="condtion_func"](%self)
%x.43 : Tensor = prim::CallMethod[name="forward"](%condtion_func, %x.39) # forward_if_simple_script.py:74:12
return (%x.43)
graph를 찍어보았을 때, 정상적으로 일단 분기가 나뉜 것을 확인할 수 있었다.
한 가지 기억해야될 점은 Onnx file을 netron으로 시각화해보았을 때도 마찬가지로 정상적으로 분기가 나뉘었다. torch.jit.script
로 만든 pt파일의 경우, 시각화하였을 때, x라는 Input을 넣었을 때, True 조건에 대해서만 간선이 연결된 시각화가 이루어진다. 즉 else 정보도 갖고는 있지만, forward 그래프를 표기할 때, If에 대해서 Onnx처럼 하나의 노드로 보는 것이 아니라, python이나 pytorch의 기본 메서드라고 생각하는 것에 가까운 것 같다.
위쪽이 pytorch pt 모델이고, 아래쪽이 Onnx까지 export했을 때의 경우이다.
그래프 중간에 If문이 포함되어 있는것을 확인할 수 있었다. 이 if가 얼마나 반가운지,,,
두 개의 branch로 나눠지는 것을 확인할 수 있었다.
결론, 3줄 요약
torch.jit.script
에서는 export되지 않는다. trace와 동일하게 export된다.