on_train_batch_start()
가 -1
을 리턴하도록 하면됨run
이 멈춤EarlyStopping
callback은 validation metric 값을 보고 더이상 성능 향상이 없는 경우 중지함
사용법
EarlyStopping
callback을 import함log()
method를 이용해서 log남김monitor
에 원하는 metric을 넣음Trainer
callbacks flag에 EarlyStopping
callback을 넘김from pytorch_lightning.callbacks.early_stopping import EarlyStopping
def validation_step(self):
self.log("val_loss", loss)
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss")])
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
stopping_threshold
: 성능이 일정 threshold를 넘어갔을 경우 즉시 학습을 멈춤. 원하는 성능 이상 나왔을 때 사용할 것divergence_threshold
: 성능이 일정 threshold 밑으로 내려갔을때 즉시 학습을 멈춤. 이 이하로 내려갔을 경우 성능을 복원하는 것이 불가능하다고 생각될 때 사용할 것check_finite
: NaN이나 infinite가 metric에 찍히면 즉시 종료EarlyStopping
class를 수정하여 호출할 것class MyEarlyStopping(EarlyStopping):
def on_validation_end(self, trainer, pl_module):
# override this to disable early stopping at the end of val loop
pass
def on_train_end(self, trainer, pl_module):
# instead, do it at the end of training loop
self._run_early_stopping_check(trainer, pl_module)
EarlyStopping
callback은 매 validation epoch 마다 실행이 되지만, check_val_every_n_epoch
, val_check_interval
같은 옵션에 의해서 counting 되는 것이 달라질 수 있음check_val_every_n_epoch=10
, patience=3
인 경우 최소 40 epoch에서 멈출 수 있음