필요한 배경지식
구조
- 트랜스포머 모델의 MLP를 sparse Mixture of Experts(MoE)로 바꿈
- 사용한 라우터는 Shazeer et al. (2017)과 동일함
- 다만 Shazeer et al. (2017)과 다르게, 라우터가 한 개의 전문가 네트워크(이후에는 전문가 네트워크를 전문가라 칭함)를 선택함
- 이유 1: 두 개 이상의 전문가를 선택해야, 전문가의 출력을 비교할 수 있기에 의미있는 gradient 신호를 얻을 거라는 추측이 있었음. 하지만 추측과 다르게, 한 개만 선택해도 학습할 수 있음을 확인함
- 이유 2: 한 개의 전문가만 선택하면 학습과 추론시 계산량이 작아짐
균등한 토큰 분배
- Expert Capacity
- 한 번의 gradient update를 위해, 각 전문가가 받는 토큰의 최대 개수. 즉 각 전문가의 batch 크기가 expert capacity를 넘길 수 없음
- expert capacity를 넘기는 수의 토큰이 특정 전문가에 라우팅 되면, capacity 만큼만 해당 전문가가 처리하고 나머지 토큰은 어떠한 전문가도 처리하지 않음
- 다만 residual connection으로 인해, 전문가가 처리하지 않더라도 모델의 출력에는 영향을 미침
- 균등한 토큰 분배가 유일한 사용 이유는 아님
- 저자가 사용한 프레임워크는 정적 계산을 지원하기에, 텐서 shape를 지정해야 함. 그런데 batch마다 각 전문가에 라우팅 되는 토큰 수가 바뀌기에 최대 expert capacity만큼의 토큰을 수용할 수 있도록 텐서 shape를 미리 지정함
- 균등한 토큰 분배를 강제하기 위한 auxiliary loss를 negative log likelihood loss에 더함
- lossauxiliary=α∗N∗∑i=1Nfi∗Pi
- fi=Tce 여기에서 ce는 전문가 e에 라우팅 된 토큰 수이고 T는 배치의 토큰 수
- fi 사용 이유: 균등하게 라우팅할 때, ∑i=1Nfi2이 최소화된다.
- Pi=T1∑x∈Bp(xi)
- 이는 Shazeer et al. (2017)의 중요도를 T로 나눈 것과 같다
- pi 사용 이유: fi의 ce가 미분 가능하지 않기에 fi2을 사용해서 기울기 신호를 받을 수 없다. 그렇기에 fi의 미분 가능한 '근사값'으로 fi를 곱한다. 이것이 Pi. Pi는 미분 가능하기에 기울기 신호를 받을 수 있다.
안정적인 학습 및 미세 조정을 위한 방법
- 바닐라 트랜스포머에 비해 switch transformer를 안정적으로 학습시키기 어렵다
- 선택적인 precision
- float16과 같이 낮은 precision을 갖는 데이터 타입을 사용하면 불안정하게 학습한다
- 모델의 일부는 float16을 사용하고 나머지에서 float32를 사용하면 안정적으로 학습함
- Dropout
- 미세 조정할 때 작은 데이터 세트를 사용하기에 overfitting은 흔한 문제다
- 특히 동일한 flops를 갖는 dense 모델보다 switch transformer가 많은 파라미터를 갖기에 overfitting이 더 심하게 발생한다.
- 전문가가 아닌 네트워크는 낮은 드랍아웃 비율을 사용하고 전문가는 높은 드랍아웃 비율을 사용할 때 성능이 제일 좋음을 확인함
병렬 처리
논문에 data parallelism, model parallelism, data and model parallelism, expert and data parallelism 그리고 expert, model and data parallelism을 시각적으로 잘 표현한 그림이 있다. 꼭 참고해 보자.
미세 조정 문제
- sparse 모델과 dense 모델 둘 다 파라미터 수가 증가함에 따라 perplexity가 낮아졌다
- 다만 논문에서 테스트한 제일 큰 sparse 모델은, 사전 학습 때의 낮은 perplexity가 upstream 작업에서의 성능 향상으로 이어지지 않음
- 위에서 말했듯이 overfitting 문제가 있기에 공격적으로 드랍아웃 비율을 설정했지만, 미세 조정의 결과가 좋지 않았음
No Token Left Behind
- expert capacity를 초과하여 토큰이 버려지는 것을 막기 위해, 다음 전문가에 라우팅 되도록 실험해 봤다. 성능 향상 없음
탐색 방법
- 탐색이 필요한 이유
- switch transformer는 한 개의 토큰에 대해 한 개의 전문가 네트워크만 선택하기에 counterfactual reasoning이 불가하다.
- 시도한 방법 네 가지
- argmax
- top-1 전문가 네트워크 선택
- perplexity: -1.471
- Sample softmax
- softmax 레이어의 출력을 확률로 해석하여 샘플링. ex) 전문가 네트워크1의 weight=.9이면 90% 확률로 전문가 네트워크1을 선택
- perplexity: -1.570
- Input dropout
- 전문가의 입력에 드랍아웃 적용
- perplexity: -1.480
- Input jitter
- 전문가의 입력에 multiplicative 노이즈를 추가함
- perplexity: -1.468
- 경험적으로 제일 좋은 성능을 보인 input jitter 방식을 사용함