torch forward if 타입별 비교

게으른 개미개발자·2022년 10월 5일
0

model_conversion

목록 보기
5/13

지난번에 pyTorch모델에서 forward함수를 변환하는 데 있어서, If 구문에 들어가는 타입에 따라서 torch.jit.script가 정상적으로 변환이 되는지, 안 되는지를 알 수 있었다.

그렇다면 어떤 구문에서 변환이 제대로 안 되는지, 정보를 알려주는 것도 하나의 방법이 될 수 있다.

우선, 크게 두 가지 실험을 해 볼 예정이다.

  1. torch.jit.script 변환 코드에서 .graph.code 를 확인해보면서, IF 구문이 Tensor 혹은 Datatype에 따라서 어떤 식으로 변환되는지 확인해보기
  2. torch.jit.tracetorch.jit.script 두 가지 버전을 비교해볼 때, 마찬가지로 어떤 식으로 변환되는지 .graph.code 확인해보고 변환이 제대로 안되는 부분을 캐치해낼 수 있는지 확인해 볼 예정이다.
  3. 1번 방식을 활용한다면 결국 텍스트이기때문에, 정규식을 활용하여 parsing을 할 것이다.
  4. 만약에 Onnx처럼 graph를 자료구조 형태로 볼 수 있다면, 다른 모듈들도 응용해서 사용이 가능하기 때문에, 활용해볼만한 가치가 있다.

먼저 1번의 경우를 테스트해보겠다.

torch.jit.script.code 활용하여 IF문 파악

#IF문이 tensor 값 관련
def forward(self,x):
        if x[0][0][0][0] % 2 == 0:
            x = self.custom_layer(x)
        x = self.custom_layer(x)
        x = self.layer1(x)    
        x = self.layer2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.condtion_func(x)
        return x
#IF문이 boolean 혹은 다른 값 관련
def forward(self,x):
        if self.branch_testing==True:
            x = self.custom_layer1(x)
        else:
            x = self.custom_layer2(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        if self.condition_testing == True:
            x = self.condition_func1(x)
        else:
            x = self.condition_func2(x)
        return x
#IF문이 boolean 혹은 다른 값 관련
def forward(self,
x: Tensor) -> Tensor:
branch_testing = self.branch_testing
if torch.eq(branch_testing, True):
custom_layer1 = self.custom_layer1
x0 = (custom_layer1).forward(x, )
else:
custom_layer2 = self.custom_layer2
x0 = (custom_layer2).forward(x, )
layer1 = self.layer1
x1 = (layer1).forward(x0, )
layer2 = self.layer2
x2 = (layer2).forward(x1, )
flatten = self.flatten
x3 = (flatten).forward(x2, )
fc1 = self.fc1
x4 = (fc1).forward(x3, )
condition_testing = self.condition_testing
if torch.eq(condition_testing, True):
condition_func1 = self.condition_func1
x5 = (condition_func1).forward(x4, )
else:
condition_func2 = self.condition_func2
x5 = (condition_func2).forward(x4, )
return x5
#IF문이 tensor 값 관련
def forward(self,
x: Tensor) -> Tensor:
_0 = torch.select(torch.select(x, 0, 0), 0, 0)
_1 = torch.select(torch.select(_0, 0, 0), 0, 0)
_2 = bool(torch.eq(torch.remainder(_1, 2), 0))
if _2:
custom_layer = self.custom_layer
x0 = (custom_layer).forward(x, )
else:
x0 = x
custom_layer0 = self.custom_layer
x1 = (custom_layer0).forward(x0, )
layer1 = self.layer1
x2 = (layer1).forward(x1, )
layer2 = self.layer2
x3 = (layer2).forward(x2, )
flatten = self.flatten
x4 = (flatten).forward(x3, )
fc1 = self.fc1
x5 = (fc1).forward(x4, )
condtion_func = self.condtion_func
return (condtion_func).forward(x5, )

script_cell.code 를 사용하면 위와 같이 forward 함수가 torch.jit.script 로 저장했을 때 내용을 확인할 수 있다.

def forward(self,
x: Tensor) -> Tensor:

첫 번째로 확인할 수 있는 정보는 forward 함수에서 사용하는 파라미터 정보이다.

다양한 종류의 파이토치 모델들이 존재하는데, 클래스별로 모델이 모듈화되어 있으며, 다양한 위치에서 파라미터를 선언한다. forward 함수 안에, 파라미터를 정의한 경우에, 해당 구문을 통해서 forward 함수안에 정의된 값을 찾을 수 있다.

커스텀 모델에는 x라는 1개의 input Tensor만을 갖고 있기 때문에 위와 같이 나온 것을 확인할 수 있었다.

def forward(self,x: torch.Tensor,y: torch.Tensor,is_bool: bool,is_str: str,
    is_List:list,is_float:float,is_integer:int,is_tuple:tuple  ):
    return x

위의 forward 함수와 같이 파라미터를 여러개 넣어주고, 데이터 타입을 선언해주면 아래와 같은 결과를 얻기도 한다.

하지만 자료형을 지정해주지 않으면, 호출자함수인 forward 함수에서 파라미터는 torch.Tensor 라고 자료형을 판단하게 된다.

만약에 __init__ 부분에서 파라미터들을 선언해주고, forward 함수에서 self 를 활용하여 해당 변수를 사용하게 되는 경우에 다음과 같이 모든 조건문이 출력되게 된다.

def forward(self,x):
        if self.is_bool == True:
            return 

        if self.is_float < 0:
            return

        if self.is_integer > sum(self.is_List):
            return

결국 1번 방법으로 처리할 수 있는 방법은

  1. forward 함수안에 파라미터가 어떤 데이터타입인지 확인해본다.
  2. 인스턴스 변수인 경우에, self 포함된 구절과 if 포함된 부분을 보고 어떤 데이터타입인지 유추하여 처리한다.

torch.jit.trace IF문 파악

input = torch.randn(1,3,256,256)

condition_testing = True
branch_testing = True
custom_model = CustomModel(condition_testing,branch_testing)

trace_v = torch.jit.trace(custom_model,input)
print(trace_v.code)

custom_model.eval()
with torch.no_grad():
    torch.jit.save(trace_v,os.getcwd()+'/if_trace.pt')

IF분기가 포함되어 있는 위의 모델을 동일하게 trace 버전을 사용해서 code로 출력해보았다.

예상대로 condition에 해당되는 분기만을 택하여, export 되는 것을 확인할 수 있었다.

torch.load는 어떻게 torch.jit.script와 torch.jit.trace를 parsing할까

load_script = torch.jit.load(os.getcwd()+'/if_tensor.pt')
print(type(load_script))

inference_input = torch.randn(1,3,256,256)
print(load_script(inference_input))

torch.jit.load를 활용하여 pt파일을 load하게되면 RecursiveScriptModule 로 불러오게 되고, 일반적인 pytorch 모델을 inference하기 위해 불러오는 모델과 동일한 타입으로 모델을 불러오게 된다.

https://pytorch.org/docs/stable/jit.html#migrating-to-pytorch-1-2-recursive-scripting-api

https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html#torch.jit.ScriptModule

https://pytorch.org/docs/stable/jit_builtin_functions.html

→ torchscript 지원되는 operation 정보

profile
특 : 미친듯한 게으름과 부지런한 생각이 공존하는 사람

0개의 댓글