Pytorch-Lightning Commom Use Cases 02 - Debugging

한건우·2021년 10월 20일
0
  • model을 테스트하는 건 쉽지 않은 일
  • Pytorch-Lightning에서는 unittest를 위해 여러가지 debug flag를 제공함

1. fast_dev_run

  • 전체 데이터로 테스트 하는게 아니라 소수의 batch로 테스트 할 수 있는 debug flag
  • n(int) / True로 설정 가능
    • n(int)
      • n 개 만큼의 batch를 실험
    • True
      • 1 개 만큼의 batch를 실험
# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)
  • fast_dev_run 옵션 사용시 tuner, checkpoint callbacks, ealry stopping callbacks, logger, logger callbacks 가 비활성화됨
  • 1 epoch에 대해서만 실행됨

2. Inspect gradient norms

  • 각 weight matrix의 norm을 log 찍어줌
# the 2-norm
trainer = Trainer(track_grad_norm=2)

3. Log device stats

  • DeviceStatsMonitor로 device stats을 log 찍어줌
from pytorch_lightning.callbacks import DeviceStatsMonitor

trainer = Trainer(callbacks=[DeviceStatsMonitor()])

4. Make model overfit on subset of data

  • 특정 데이터 subset에 overfitting 시켜, 정상적으로 overfitting 되는지 확인해주는 debug flag
  • 만약 overfitting이 정상적으로 안된다면, 전체 데이터셋에서도 작동 안된다는 뜻
# use only 1% of training data (and use the same training dataloader (with shuffle off) in val and test)
trainer = Trainer(overfit_batches=0.01)

# similar, but with a fixed 10 batches no matter the size of the dataset
trainer = Trainer(overfit_batches=10)

5. Print a summary of your LightningModule

  • .fit() 함수 호출시, TrainerLightningModuleweight summary를 출력함
  • 기본 옵션은 top-level module만 출력하지만, 원한다면 max_depth 옵션을 이용해서 sub-level module도 출력가능
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])
  • example_input_array 옵션을 이용해서 모든 레이어의 input/output size도 출력 가능

  • Trainer에서 .fit() 호출시 발생하는 버그는 다음 component를 확인하여 해결가능
  • ModelSummary
  • summarize()

6. Shorten epochs

  • 전체 데이터셋에서 일부분만 가지고 학습을 하거나, validation 가능하게 해주는 debug flag
# use only 10% of training data and 1% of val data
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)

# use 10 batches of train and 5 batches of val
trainer = Trainer(limit_train_batches=10, limit_val_batches=5)

7. Set the number of validation sanity steps

  • 학습 시작 전에 validation step을 돌려보면서 전체 데이터학습 이후에 validation step에서 crash 나는걸 미리 검증해주는 debug flag
  • 기본 옵션은 validation step 2회 실행
# DEFAULT
trainer = Trainer(num_sanity_val_steps=2)
profile
아마추어 GAN잽이

0개의 댓글