[Pytorch] Early Stopping Sample

es.Seong·2024년 4월 4일
0

Pytorch를 통해 딥러닝 모델을 학습시키면서 사용자가 매 학습마다 가중치 저장 코드를 넣지 않는다는 가정하에, 인위적으로 학습을 중단하게 된다면 학습 모형가 저장되지 않는다. 만약 매 학습마다 가중치 저장 코드를 넣었다면, 최종적으로 저장된 가중치 파일(.pt or .pth)은 최종 Epoch에서 저장된 파일일 것이다.
하지만 학습 중 Overfitting이 발생했다면 저장된 가중치를 불러와서 Test Loader를 Predict했다면 결과는 어떨까?
아주 낮은 평가 척도를 보이며 일반화 성능이 현저히 떨어질 것이다.

학습 중 결과가 진전이 되지 않을 때 이를 중단하고 최적의 학습 결과를 얻기위해 고안된 것이 Early Stopping이다.
Early Stopping을 Pytorch 프레임워크에 맞게 구현된 코드를 살펴보자.

Early Stopping 코드

class EarlyStopping:
    def __init__(self,model, patience=3, delta=0.0, mode='min', verbose=True):
        """
        patience (int): loss or score가 개선된 후 기다리는 기간. default: 3
        delta  (float): 개선시 인정되는 최소 변화 수치. default: 0.0
        mode     (str): 개선시 최소/최대값 기준 선정('min' or 'max'). default: 'min'.
        verbose (bool): 메시지 출력. default: True
        """
        self.early_stop = False
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        
        self.best_score = np.Inf if mode == 'min' else 0
        self.mode = mode
        self.delta = delta
        self.model = model

    def __call__(self, score):

        if self.best_score is None:
            self.best_score = score
            self.counter = 0
        elif self.mode == 'min':
            if score < (self.best_score - self.delta):
                self.counter = 0
                self.best_score = score
                if self.verbose:
                    # 모델 저장
                    torch.save(self.model.state_dict(), f'best_model.pth')
                    print(f'[EarlyStopping] (Update) Best Score: {self.best_score:.5f} & Model saved')
            else:
                self.counter += 1
                if self.verbose:
                    print(f'[EarlyStopping] (Patience) {self.counter}/{self.patience}, ' \
                          f'Best: {self.best_score:.5f}' \
                          f', Current: {score:.5f}, Delta: {np.abs(self.best_score - score):.5f}')
                
        elif self.mode == 'max':
            if score > (self.best_score + self.delta):
                self.counter = 0
                self.best_score = score
                if self.verbose:
                    # 모델 저장
                    torch.save(model.state_dict(), f'best_model.pth')
                    print(f'[EarlyStopping] (Update) Best Score: {self.best_score:.5f} & Model saved')
            else:
                self.counter += 1
                if self.verbose:
                    print(f'[EarlyStopping] (Patience) {self.counter}/{self.patience}, ' \
                          f'Best: {self.best_score:.5f}' \
                          f', Current: {score:.5f}, Delta: {np.abs(self.best_score - score):.5f}')
                
            
        if self.counter >= self.patience:
            if self.verbose:
                print(f'[EarlyStop Triggered] Best Score: {self.best_score:.5f}')
            # Early Stop
            self.early_stop = True
        else:
            # Continue
            self.early_stop = False

클래스의 파라미터는 총 다섯 개이다.
model : 학습에 사용한 모델.
patience (int): loss or score가 개선된 후 기다리는 기간. default: 3
delta (float): 개선시 인정되는 최소 변화 수치. default: 0.0
mode (str): 개선시 최소/최대값 기준 선정('min' or 'max'). default: 'min'.
verbose (bool): 메시지 출력. default: True

예를 들어 patience가 20이라면 loss가 20회이상 개선이 되지 않으면 학습을 종료한다는 의미이다.
그리고 mode의 경우 클래스 내 score 변수가 커져야 하는지, 작아져야하는지 결정하는 파라미터이다.

mode를 'min'으로 사용하는 경우는 Early Stopping을 사용할 평가척도가 학습을 거듭하며 작아지는 값, 대표적으로 Loss 값이 있다.

mode를 'max'으로 사용하는 경우는 Early Stopping을 사용할 평가척도가 학습을 거듭하며 커지는 값, 대표적으로 Accuracy, mAP, PSNR, SSIM 등과 같은 평가척도가 있다.

Early Stopping 사용 예시

pt=30
es = EarlyStopping(model=model,
                   patience=pt, 
                   delta=0, 
                   mode='min', 
                   verbose=True
                  )

def train(model, train_loader, valid_loader, criterion, optimizer, es ,n_epochs=10):
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0.0
        valid_loss = 0.0
        train_correct = 0
        valid_correct = 0
        valid_total = 0
        
        for data, target in train_loader:
			..........
            ..........
                 
         
        # Early Stopping & save best model.pth
        es(valid_loss)
        if es.early_stop:
            print("Early stopping")
            break  

사용 예시를 살펴보자. 우선 es 변수에 클래스를 선언 후 사용할 파라미터를 넣어준다.
그리고 학습 과정에서 생성된 Validation Loss 값을 es 객체의 파라미터로 사용한다.
즉, Early Stopping 기준으로 Validation Loss로 잡겠다는 의미이다. 학습을 하며 Validation Loss가 일정 횟수(해당 예시에서는 30회) 진전이 없다면 학습을 중단할 것이다.


위 이미지는 Early Stopping 코드가 적용된 모형 학습 Log이다.
첫 이미지에서는 Validation Loss가 이전 Epoch과 비교해서 줄어들어서 Update Log가 보인다.
두 번째 이미지에서는 Validation Loss가 학습이 30회 더 진행되었음에도 줄어들지 않아 학습이 조기종료가 잘된 것을 확인할 수 있다.

profile
Graduate student at Pusan National University, majoring in Artificial Intelligence

0개의 댓글