Pytorch-Lightning Commom Use Cases 03 - Early Stopping

한건우·2021년 10월 21일
0

Stopping an epoch early

  • epoch을 조기에 종료하게 하고 싶다면, on_train_batch_start()-1을 리턴하도록 하면됨
  • 만약 이걸 반복하게 하고 싶다면, 전체 epoch이 이러도록 세팅하면 전체 run이 멈춤

Early Stopping based on metric using the EarlyStopping Callback

  • EarlyStopping callback은 validation metric 값을 보고 더이상 성능 향상이 없는 경우 중지함

  • 사용법

    • EarlyStopping callback을 import함
    • 원하는 metric을 log() method를 이용해서 log남김
    • callback을 선언하고, monitor에 원하는 metric을 넣음
    • Trainer callbacks flag에 EarlyStopping callback을 넘김
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


def validation_step(self):
    self.log("val_loss", loss)


trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss")])
  • parameter들을 변경하여 원하는 대로 커스텀 가능
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
  • 추가적으로 극단적인 상황일때 학습을 멈추게 해주는 parameter들
    • stopping_threshold : 성능이 일정 threshold를 넘어갔을 경우 즉시 학습을 멈춤. 원하는 성능 이상 나왔을 때 사용할 것
    • divergence_threshold: 성능이 일정 threshold 밑으로 내려갔을때 즉시 학습을 멈춤. 이 이하로 내려갔을 경우 성능을 복원하는 것이 불가능하다고 생각될 때 사용할 것
    • check_finite: NaN이나 infinite가 metric에 찍히면 즉시 종료
  • 학습 도중 다른 요인 때문에 멈춰야 한다면 EarlyStopping class를 수정하여 호출할 것
class MyEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        # override this to disable early stopping at the end of val loop
        pass

    def on_train_end(self, trainer, pl_module):
        # instead, do it at the end of training loop
        self._run_early_stopping_check(trainer, pl_module)
  • EarlyStopping callback은 매 validation epoch 마다 실행이 되지만, check_val_every_n_epoch, val_check_interval 같은 옵션에 의해서 counting 되는 것이 달라질 수 있음
  • 예를들어 check_val_every_n_epoch=10, patience=3인 경우 최소 40 epoch에서 멈출 수 있음
profile
아마추어 GAN잽이

0개의 댓글