- model을 테스트하는 건 쉽지 않은 일
- Pytorch-Lightning에서는 unittest를 위해 여러가지 debug flag를 제공함
1. fast_dev_run
- 전체 데이터로 테스트 하는게 아니라 소수의 batch로 테스트 할 수 있는 debug flag
n(int)
/ True
로 설정 가능
trainer = Trainer(fast_dev_run=True)
trainer = Trainer(fast_dev_run=7)
fast_dev_run
옵션 사용시 tuner
, checkpoint callbacks
, ealry stopping callbacks
, logger
, logger callbacks
가 비활성화됨
- 1 epoch에 대해서만 실행됨
2. Inspect gradient norms
- 각 weight matrix의 norm을 log 찍어줌
trainer = Trainer(track_grad_norm=2)
3. Log device stats
DeviceStatsMonitor
로 device stats을 log 찍어줌
from pytorch_lightning.callbacks import DeviceStatsMonitor
trainer = Trainer(callbacks=[DeviceStatsMonitor()])
4. Make model overfit on subset of data
- 특정 데이터 subset에 overfitting 시켜, 정상적으로 overfitting 되는지 확인해주는 debug flag
- 만약 overfitting이 정상적으로 안된다면, 전체 데이터셋에서도 작동 안된다는 뜻
trainer = Trainer(overfit_batches=0.01)
trainer = Trainer(overfit_batches=10)
5. Print a summary of your LightningModule
.fit()
함수 호출시, Trainer
가 LightningModule
의 weight summary
를 출력함
- 기본 옵션은 top-level module만 출력하지만, 원한다면
max_depth
옵션을 이용해서 sub-level module도 출력가능
from pytorch_lightning.callbacks import ModelSummary
trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])
example_input_array
옵션을 이용해서 모든 레이어의 input/output size도 출력 가능
Trainer
에서 .fit()
호출시 발생하는 버그는 다음 component를 확인하여 해결가능
ModelSummary
summarize()
6. Shorten epochs
- 전체 데이터셋에서 일부분만 가지고 학습을 하거나, validation 가능하게 해주는 debug flag
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)
trainer = Trainer(limit_train_batches=10, limit_val_batches=5)
7. Set the number of validation sanity steps
- 학습 시작 전에 validation step을 돌려보면서 전체 데이터학습 이후에 validation step에서 crash 나는걸 미리 검증해주는 debug flag
- 기본 옵션은 validation step 2회 실행
trainer = Trainer(num_sanity_val_steps=2)