Model compression방법으로 knowledge distillation를 설명하도록 하겠습니다.
Knowledge distillation은 teacher network와 student network의 ensemble을 기반으로 한 방법이라 설명할 수 있습니다. 여기서 knowledge distillation이란 teacher network(capacity가 큰 모델)를 학습한 이후, 이 모델로부터 student network(teacher network보다 capacity가 작은 모델)에게 knowledge를 transfer함을 의미합니다.
우선 Knowledge Distillation이 어떻게 동작하는지 순차적으로 설명하도록 하겠습니다.
첫번째, distillation loss에 대해서 설명하도록 하겠습니다.
👉 input → Teacher model → output(logits) → softmax(temperature scaling) : soft labels
input → Student model → output(logits) → softmax(temperature scaling) : soft predictions
본 연구에서는 기존에 많이 다뤄지는 softmax function에 temperature scaling을 적용하였습니다. temperature scaling에 관하여 간략하게 설명하도록 하겠습니다. 수식은 아래과 같습니다.
Temperature scaling =
좀 더 쉬운 이해를 돕기위해, 임의의 logit값으로 softmax/softmax with temperature scaling 각각에 대해 시각화를 진행 해보았습니다. 왼쪽 그림을 보시면 정확한 수치를 파악하기 힘들 정도로 값이 작은 그래프가 존재합니다.
그러나 후자의 경우, 값들이 좀더 smooth해짐을 알 수 있습니다. 학습에 더 잘 반영할 수 있도록 (정보를 좀 더 잘 전달하기 위해서) temperature scaling을 적용하였습니다.
Temperature scaling은 모든 class에 대해 single scalar parameter 를 logit vector 에 나눠주는 방법입니다.
Distillation loss의 경우 teacher model로 학습했을때의 output과 student model로 학습했을때의 output의 확률분포를 최소화 하는 방향으로 학습합니다. 수식은 다음과 같습니다.
는 softmax, /는 각각 teacher/student model의 output logits, 는 temperature를 의미합니다.
두번째, student loss에 대해서 설명하도록 하겠습니다.
👉 input → Student model → output(logits) → softmax(t=1) : hard predictions
hard label(ground truth : one-hot)
Student loss는 hard predictions, hard label간의 loss를 최소화 하는 방향으로 학습을 진행합니다. 수식은 다음과 같습니다.
는 Student model의 output logits을 의미하고 는 ground truth를 의미합니다.
최종적으로 distillation loss, student loss를 더하여 total loss값을 최소화 하는 방향으로 학습을 진행합니다. total loss는 다음과 같습니다.
: distillation loss / student loss에 가하는 weight hyperparameter