PyTorch Model Summary

Sunwoo Pi·2023년 8월 7일
0

What I Learned Today

목록 보기
3/3
post-thumbnail

torchinfo

PyTorch Model을 summarize해주는 많은 Library들이 존재하지만 torchinfo 하나만 있으면 다른 모든 것들을 대부분 대체 가능하기에 torchinfo를 사용하는 것을 적극 추천한다.

torchinfo는 TensorFlow의 model.summary()와 유사하게 print(your_pytorch_model)가 제공하는 정보를 보완하여 PyTorch Model을 summarize해준다. 이는 네트워크를 debugging할 때 큰 도움이 된다.

python -m pip install torchinfo # torchinfo Library 설치
from torchvision import models
from torchinfo import summary

model = models.efficientnet_v2_s(weights='DEFAULT')

summary(model, input_size=(16, 3, 224, 224), col_width=20, depth=5, row_settings=["depth", "var_names"], col_names=["input_size", "kernel_size", "output_size", "params_percent"])
# 밑의 출력 정보 예시의 경우 'depth=1' 설정

summary()의 parameter들은 마음대로 설정 가능하지만 나는 위의 설정에서 depth 인자값만 바꾸어 가며 주로 사용한다.

  • depth : Model Layer를 어느 정도의 깊이까지 summary를 할 것인지에 대한 parameter
  • row_settings : 각 Layerd에서 어느 정보를 확인할 것인지에 대한 parameter
    • "depth" : Layer의 깊이 정보
    • "var_names" : 해당 Layer의 변수명 (변수명을 통해 해당 Layer에 직접 접근하여 모델 개발 용이)

torchinfo 및 parameter의 자세한 정보는 https://github.com/TylerYep/torchinfo 에서 확인 가능하다.

profile
어려운 게 제일 싫어😝

1개의 댓글

comment-user-thumbnail
2023년 8월 7일

글 잘 봤습니다.

답글 달기