불균형 데이터 처리: Focal Loss

Mollang·2022년 11월 25일
1

현재 7개의 타겟 클래스를 맞추는 다중 분류 태스크를 Bert 모델로 학습시키고 있다.

Bert모델로 학습할 때 확인한 평균 best score(f1-score)는 56~60 사이인데, 실제 예측 점수는 38점을 기록했다. 이전에 다뤘던 text 데이터도 NLP 태스크인데, 당시엔 f1-score와 실제 예측 스코어간의 차이가 크지 않았다. (평균적으로 0.3 정도의 차이 존재)

큰 점수 격차의 원인은 "데이터 불균형"으로 추측된다. 본인은 데이터 불균형을 해소할 수 있는 방안으로 데이터 증강을 생각했다. 관련 레퍼런스를 더 찾던 와중에 focal loss라는 개념을 발견하였다.

다중 분류 문제에서 타겟 클래스 간 분포가 불균형할 경우, 분포도가 높은 클래스 (더 많은 데이터를 갖고 있는 클래스)에 높은 가중치를 두게 된다. 모델 자체적으로 예측 성능을 높게 평가하므로, accuracy는 높아질 수 있으나, 분포가 낮은 클래스에 대한 예측은 낮게 나타나게 된다.

다중 분류 문제에서 보통 Cross Entropy 손실 함수를 사용한다. 본인도 이번 대회에서 Cross Entropy 손실 함수를 사용했다.
(대부분의 손실함수는 학습 중 발생한 데이터 셋의 예측 오류를 합산한 것이다. 세부적인 가중치나 공식은 다르지만, 손실함수는 전체 데이터에 대한 예측 오류 합산의 개념으로 설계된다. 예를 들어 10000개의 데이터터에서 10개를 배치 데이터 셋으로 예측하고, 이 예측 오류를 실제 데이터와 비교하여 10개 배치 데이터 전체에 대한 예측 오류를 계산한다. 이를 1000번 수행하면서 전체 데이터에 대한 예측 오류를 합산한다.)

Cross Entropy Loss의 경우, 잘 분류한 경우보다 잘못예측한 경우에 패널티를 부여하는 것에 초점을 둔다.

요약

러프하게 요약하자면, 잘 예측한 경우 보상도 없고 패널티도 없다. 그러나 잘 예측하지 못할 경우 패널티가 굉장히 크게 부과된다. (*Foreground와 Background에 대한 내용도 이해에 필요하지만 설명 역량이 부족하여 여기선 건너 뛴다_ 세부 내용은 [참고자료 3] 확인) 여기서 발생하는 문제점을 개선하기 위해 등장한 것이, Focal Loss이다.

Cross Entropy Loss는 확률이 낮은 케이스에 패널티를 주는 역할만 수행하고 확률이 높은 케이스에 어떤 보상도 하지 않지만, Focal Loss는 확률이 높은 케이스에는 확률이 낮은 케이스 보다 Loss를 더 크게 낮추는 보상을 준다.

Keras를 사용하면 RetinaNet의 focal loss를 바로 사용가능하지만, Pytorch를 사용할 경우 안타깝게도 새로 코드를 구현해야한다.
https://deep-learning-study.tistory.com/616

*파이토치 코드는 위 링크 참조

참고 자료1
참고 자료2
참고 자료3
참고 자료4

0개의 댓글