LLaMA: Open and Efficient Foundation Language Models 정리
모델 아키텍처
- 원래의 트랜스포머와 크게 다르지 않음
- 차이
- Pre-normalization: 안정적인 학습을 위해, 계층 출력을 normalize 하지 않고 입력을 normalize함.
- RMSnorm: Scale and re-centering invariance를 갖는 Layer norm과 다르게 Scale invariance만 갖는 정규화 방법. 대신 요구 계산량은 layer norm보다 낮다
- RMSnorm(x)=n1Σixi2
- 활성화 함수: SwiGLU(x,W,V,b,c,β)=swishβ(xW+b)⊙(xV+c). 여기에서 ⊙은 element-wise product를 의미함. swish(x,β)=xσ(βx)
- 종종 swish(x,β) 대신 swish(x)=xσ(x)를 사용함. 다만 llama는 β를 생략했는지 모르겠음
- 선형 변환의 inner layer dimension: 4d대신 324d
- Rotary Embedding: 토큰의 상대적 위치를 임베딩함
자원 사용량 최적화
- 메모리 복잡도가 O(n)인 Lazy Attention 사용. n = 토큰수
- Attention(q,K,V)=Σje(q⋅kj)Σie(q⋅ki)∗vi K,V는 행렬
- 이를 이용하면 한 개의 query 벡터에 대한 attention 연산의 메모리 복잡도는 O(1)이 됨
- n 개의 query 벡터에 대해 attention 연산한 결과를 저장하기 때문에 메모리 복잡도는 O(n)이 된다
- 메모리 복잡도가 O(n2)인 원래 attention보다 메모리 효율적임
- gradient checkpointing
- 모든 레이어의 gradient를 저장하지 않고 선형 계층처럼 높은 계산량을 요구하는 레이어의 gradient를 저장
- 가려진 key/query의 score는 계산하지 않음
- 모델을 여러 GPU에 나누는 model parallel 사용
loss 최적화
- AdamW
- β1=.9,β2=.95
- gradient clipping
- Weight Decay
- Warmup Rate
- cosine learning rate
데이터
- OpenAI의 GPT와 다르게 model capacity 대비 학습 데이터 세트의 크기가 큼
- 전체 데이터 세트를 한 번만 학습함(1 epoch)
Take away
- 큰 데이터 세트가 성능 향상을 가져올 수 있다
- Memory Bandwidth를 고려한 모델 최적화가 중요하다
- 통신과 계산이 최대한 겹치도록 하여 병렬 연산의 효율을 높일 수 있다