Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity 정리

Plato·2023년 12월 27일
0

딥러닝

목록 보기
9/21

필요한 배경지식

구조

  • 트랜스포머 모델의 MLP를 sparse Mixture of Experts(MoE)로 바꿈
    • 사용한 라우터는 Shazeer et al. (2017)과 동일함
      • 다만 Shazeer et al. (2017)과 다르게, 라우터가 한 개의 전문가 네트워크(이후에는 전문가 네트워크를 전문가라 칭함)를 선택함
        • 이유 1: 두 개 이상의 전문가를 선택해야, 전문가의 출력을 비교할 수 있기에 의미있는 gradient 신호를 얻을 거라는 추측이 있었음. 하지만 추측과 다르게, 한 개만 선택해도 학습할 수 있음을 확인함
        • 이유 2: 한 개의 전문가만 선택하면 학습과 추론시 계산량이 작아짐

균등한 토큰 분배

  • Expert Capacity
    • 한 번의 gradient update를 위해, 각 전문가가 받는 토큰의 최대 개수. 즉 각 전문가의 batch 크기가 expert capacity를 넘길 수 없음
    • expert capacity를 넘기는 수의 토큰이 특정 전문가에 라우팅 되면, capacity 만큼만 해당 전문가가 처리하고 나머지 토큰은 어떠한 전문가도 처리하지 않음
      • 다만 residual connection으로 인해, 전문가가 처리하지 않더라도 모델의 출력에는 영향을 미침
    • 균등한 토큰 분배가 유일한 사용 이유는 아님
      • 저자가 사용한 프레임워크는 정적 계산을 지원하기에, 텐서 shape를 지정해야 함. 그런데 batch마다 각 전문가에 라우팅 되는 토큰 수가 바뀌기에 최대 expert capacity만큼의 토큰을 수용할 수 있도록 텐서 shape를 미리 지정함
  • 균등한 토큰 분배를 강제하기 위한 auxiliary loss를 negative log likelihood loss에 더함
    • lossauxiliary=αNi=1NfiPiloss_{auxiliary}=α * N * \sum_{i=1}^{N} f_i * P_i
      • fi=ceTf_i=\frac{c_e}{T} 여기에서 cec_e는 전문가 ee에 라우팅 된 토큰 수이고 TT는 배치의 토큰 수
        • fif_i 사용 이유: 균등하게 라우팅할 때, i=1Nfi2\sum_{i=1}^{N}f_i^2이 최소화된다.
      • Pi=1TxBp(xi)P_i=\frac{1}{T}\sum_{x \in \mathbf{B}} p(x_i)
        • 이는 Shazeer et al. (2017)의 중요도를 TT로 나눈 것과 같다
        • pip_i 사용 이유: fif_icec_e가 미분 가능하지 않기에 fi2f_i^2을 사용해서 기울기 신호를 받을 수 없다. 그렇기에 fif_i의 미분 가능한 '근사값'으로 fif_i를 곱한다. 이것이 PiP_i. PiP_i는 미분 가능하기에 기울기 신호를 받을 수 있다.

안정적인 학습 및 미세 조정을 위한 방법

  • 바닐라 트랜스포머에 비해 switch transformer를 안정적으로 학습시키기 어렵다
  • 선택적인 precision
    • float16과 같이 낮은 precision을 갖는 데이터 타입을 사용하면 불안정하게 학습한다
      • 모델의 일부는 float16을 사용하고 나머지에서 float32를 사용하면 안정적으로 학습함
  • Dropout
    • 미세 조정할 때 작은 데이터 세트를 사용하기에 overfitting은 흔한 문제다
      • 특히 동일한 flops를 갖는 dense 모델보다 switch transformer가 많은 파라미터를 갖기에 overfitting이 더 심하게 발생한다.
      • 전문가가 아닌 네트워크는 낮은 드랍아웃 비율을 사용하고 전문가는 높은 드랍아웃 비율을 사용할 때 성능이 제일 좋음을 확인함

병렬 처리

논문에 data parallelism, model parallelism, data and model parallelism, expert and data parallelism 그리고 expert, model and data parallelism을 시각적으로 잘 표현한 그림이 있다. 꼭 참고해 보자.

미세 조정 문제

  • sparse 모델과 dense 모델 둘 다 파라미터 수가 증가함에 따라 perplexity가 낮아졌다
  • 다만 논문에서 테스트한 제일 큰 sparse 모델은, 사전 학습 때의 낮은 perplexity가 upstream 작업에서의 성능 향상으로 이어지지 않음
    • 위에서 말했듯이 overfitting 문제가 있기에 공격적으로 드랍아웃 비율을 설정했지만, 미세 조정의 결과가 좋지 않았음

No Token Left Behind

  • expert capacity를 초과하여 토큰이 버려지는 것을 막기 위해, 다음 전문가에 라우팅 되도록 실험해 봤다. 성능 향상 없음

탐색 방법

  • 탐색이 필요한 이유
    • switch transformer는 한 개의 토큰에 대해 한 개의 전문가 네트워크만 선택하기에 counterfactual reasoning이 불가하다.
  • 시도한 방법 네 가지
    • argmax
      • top-1 전문가 네트워크 선택
      • perplexity: -1.471
    • Sample softmax
      • softmax 레이어의 출력을 확률로 해석하여 샘플링. ex) 전문가 네트워크1의 weight=.9이면 90% 확률로 전문가 네트워크1을 선택
      • perplexity: -1.570
    • Input dropout
      • 전문가의 입력에 드랍아웃 적용
      • perplexity: -1.480
    • Input jitter
      • 전문가의 입력에 multiplicative 노이즈를 추가함
      • perplexity: -1.468
  • 경험적으로 제일 좋은 성능을 보인 input jitter 방식을 사용함

0개의 댓글