Journal: 2021 CVPR
Github: https://github.com/kekmodel/MPL-pytorch
Approach
- Teacher 모델이 unlabeled pool에 대해 inference해서 pseudo labeled 미니배치를 생성한다.
- Student 모델이 teacher가 레이블링한 pseudo labeled 미니배치를 학습한다.
- 학습된 student 모델이 labeled 미니배치를 inference한 성능을 피드백으로 받아서 teacher 모델도 학습된다. Teacher 모델을 학습할 때는 보조적인 loss를 쓰면 성능이 향상되며, 여기서는 labeled dataset에 대해 supervised learning해서 얻는 loss와 unlabeled dataset에 대해 semi-supervised learning을 해서 얻는 UDA loss(Unsupervised Data Augmentation)를 같이 썼다.
- MPL로 학습 후 수렴된 student 모델은 labeled 데이터셋에서 finetuning 한다.
- MPL할 때 teacher와 student는 동시에 학습된다. 동시에 학습시키려면 loss function을 수정하면 된다.
- Teacher와 student 모두 EfficientNet-L2 사용했다. → 이 모델 말고도 여러 모델로 실험 했다.
- Labeled dataset = ImageNet, unlabeled data pool = JFT-200M
Pseudo Labels
MPL과 비슷한 기존 방법: Pseudo Labels (aka Self-training)
- Self-training: Semi-supervised learning의 방법 중 하나로, 레이블링 데이터에서 학습된 모델이 unlabeled data에 대해 inference한 값을 pseudo label로 활용하는 방식이다.
- 모델: 한 쌍의 모델이 있는데, 하나는 teacher 다른 하나는 student다.
- Teacher: unlabeled images로부터 psuedo label을 생성한다. 여기서 생성된 pseudo labeled images + labeled images가 student의 학습 데이터셋이 된다.
- Student: Pseudo labeled images + labeled images를 학습해서 결국에는 teacher보다 높은 성능의 모델이 된다. 이는 많은 pseudo labeled images와 data augmentation 같은 기법을 적용함으로써 달성 가능하다.
- 단점: psuedo label이 부정확하면 student가 부정확하게 레이블링된 데이터를 학습하게 된다. 이러면 결국 student의 성능은 teacher보다 높아질 수 없다. 이런 문제를 “Confirmation Bias (확증 편향)” 문제라고 부른다.
- Confirmation Bias 문제 해결방법: 만들어낸 pseudo labels가 student에 어떤 영향을 미치는지를 관찰하면서 → teacher는 bias를 고치게 된다.
- Teacher가 더 나은 pseudo labels를 만들 수 있도록 → student는 피드백 신호을 준다.
- 피드백 신호: Student가 labeled dataset에 대해서 낸 퍼포먼스를 말한다. 이 피드백 신호는 student가 학습하는 동안 계속 teacher에게 전달된다 (병렬 학습).
- Teacher와 student의 병렬 학습:
- Teacher가 레이블링 한 pseudo labeled 미니배치를 → student가 학습한다.
- Student가 labeled dataset에서 뽑은 미니배치에 대해 내는 퍼포먼스를 → teacher가 피드백 신호로 받아서 학습한다.
Pseudo Labels as an optimization problem
Pseudo Labels
- Student는 unlabeled data에서의 cross-entropy loss를 최소화하기 위해 학습한다.
- Teacher가 unlabeled dataset에서 만든 pseudo labels와 student가 unlabeled dataset에 대해 prediction한 값 사이의 cross-entropy loss를 구한다.
- Unlabeled dataset에 대한 모든 cross-entropy loss의 기댓값을 구한다.
- 기댓값을 최소화하는 student의 파라미터를 찾는다.
- Teacher의 파라미터는 고정되어 변하지 않는다. → MPL과 다른 부분.
- Teacher가 사전에 잘 학습된 모델이라는 가정 하에, 학습 완료된 student의 파라미터는 labeled data에서 낮은 loss를 보일 것이다.
- 학습 완료된 최적의 student 파라미터는 항상 teacher의 파라미터에 의존적일 수밖에 없다 → teacher가 만든 pseudo label로 학습하기 때문이다.
Meta Pseudo Labels
- Labeled data에 대한 student의 퍼포먼스를 가지고 teacher의 파라미터를 최적화해서 → pseudo label도 최적화 되어 student의 퍼포먼스를 높일 수 있다.
Practical approximation
MPL을 구현하기 위해서는 Meta learing의 개념을 몇 개 차용하고, multi-step인 /argmin{\theta{S}}를 \theta_{S}에 대한 one-step gradient 업데이트로 근사화 한다.
Student update & Teacher update
-
MPL의 Student는 Pseudo Labels처럼 식 (1)에 여전히 의존적이지만, 다른 점은 teacher의 파라미터는 더 이상 고정이 아니고 → teacher가 optimization될 때마다 \theta_{T}도 바뀐다.
-
그리고 중요한 점은, 식 (3)의 one-step approximation에서 student의 파라미터 업데이트가 재사용 된다. → 이것은 student와 teacher의 업데이트 간 alternating optimization procedure를 발생시킨다.
-
Student: Unlabeled data 배치 x{u}에 대해 → Teacher의 prediction T(x{u}; theta_{T})을 얻고 → Student의 SGD를 사용해서 식 (1) Student parameter를 최적화한다.
-
Teacher: Labeled data 배치 (x{l}, y{l})에 대해 → Student의 파라미터 업데이트를 재사용하고 → Teacher의 SGD를 사용해서 식 (3) Approximated teacher objective를 최적화한다.
Teacher’s auxiliary losses
- 지금까지 설명된 MPL 메소드만으로도 좋은 성능을 내는데, 만약 teacher에 다른 loss들을 추가하면 더 높은 성능을 낸다.
- 그래서 논문에서는 teacher를 학습할 때 supervised learning loss + semi-supervised learning loss를 같이 사용한다.
- Teacher의 loss 두 가지: Supervised & UDA objectives
- Supervised objective: Teacher를 labeled data에서 학습시킨다.
- Semi-supervised objective: Teacher를 unlabeled data에서 추가적으로 학습시킨다. 이때 UDA objective를 사용하였다.
- 이렇게 학습되는 teacher가 만들어내는 pseudo labels에 대해 student는 unlabeled data만 학습하게 되고, 수렴한 후에는 labeled data에서 finetuning을 해서 정확도를 높일 수 있도록 한다.
- 논문 저자들도 MPL을 학습시키는게 너무 cost가 크다고 느꼈는지 lite version MPL도 같이 제시한다. → Appendix E
Datasets
- 특징은 구글만 독점적으로 사용하고 공개하지 않는 거대 데이터셋인 JFT 데이터셋을 사용하지 않는다는 점이다.
- Labeled dataset
- 4,000 CIFAR-10, 7,300 SVHN, 40 data shards ImageNet → 하이퍼 파라미터 튜닝용
- CIFAR-10 45,000, 65,000 SVHN, 1,230,000 ImageNet →
- Unlabeled dataset: YFCC100M dataset
Results
- EfficentNet-B7 모델로 ImageNet ILSRVC 2012 validation set 에서 86.9% accuracy.
- 2개의 네트워크를 메모리에 올리지 않아도 되게 되었다 → ?
Appendix E
- Teacher 모델 T를 수렴할 때까지 학습시킨다.
- 이 모델로 student의 학습 데이터셋의 모든 타겟 분포를 계산한다. 중요한건 이때까지는 아직 student가 메모리에 로드되지 않는다. 그래서 메모리 공간을 크게 차지하지 않도록 만든다.
- Reduced teacher T’을 MLP처럼 작고 효율적인 모델로 매개변수화 (parameterize)해서 student와 함께 학습하도록 한다.
- Reduced teacher T’은 teacher T가 예측한 분포를 입력으로 받아서 → student가 학습할 수 있도록 보정된 분포를 출력한다.
⇒ Reduced MPL이 잘 작동할 수 밖에 없는 이유: 큰 모델이 T이기 때문에 정확할 것이고, 그래서 reduced 모델인 T’은 identity map(항등 함수?)에 가까울 것이다. Identity map은 MLP로 핸들링 된다. → 항등 함수에 가깝다는 말이 무엇일까? 정확하다는 표현인가?
- Model Architectures
- EfficinetNet-B0 for CIFAR-10 and SVHN, EfficientNet-B7 for ImageNet.
- Teacher model: small 5-layer perceptron
Conclusion
- Main idea: Teacher가 student로부터 피드백을 받아서 더 나은 pseudo label을 생성한다.
- Two main updates
- Teacher가 생성한 pseudo label로 student가 학습한다.
- Student가 주는 피드백을 가지고 teacher가 학습한다.
- Standard low-resource 데이터셋과 large scale 실험을 통해 → 다른 semi-supervised learning 메소드들보다 좋은 퍼포먼스를 냈고 → 이건 student가 teacher에게 주는 피드백의 역할이 중요함을 보여준다.
훌륭한 글이네요. 감사합니다.