[Transformer] Multi-Head Self-Attention ( MSA )

Seongkeun·2024년 5월 13일
2

ML

목록 보기
3/6
post-thumbnail

QKV 가 무엇인지 궁금하다면 [Transformer] QKV ( Query, Key, Value ) 를 간단하게 살펴보고 오자.

시작하기 앞서...Single Self-Attention

시작하기 앞서 이 포스팅에서 MSA 수식에 중점을 둔다. Single Self-Attention 에서는 DhD_h 가 사용되지 않는다. 왜냐하면 단일 헤드기 때문이다. Single Self-Attention 메커니즘으로만 놓고보면 수식은 아래와 같다.

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
dkd_k 는 key 의 차원 수를 말한다.

MSA 수식

[q,k,v]=zUqkv,UqkvRD×3Dh[q, k, v] = zU_{qkv}, \quad U_{qkv} \in \mathbb{R}^{D \times 3D_h}

A=softmax(qkTDh),ARN×NA = \text{softmax}\left(\frac{q\cdot k^T}{\sqrt{D_h}}\right), \quad A \in \mathbb{R}^{N \times N}

SA(z)=Av,AvRN×DhSA(z) = A\cdot v, \quad A\cdot v \in \mathbb{R}^{N \times D_h}

MSA(z)=[SA1(z);SA2(z);;SAk(z)]Umsa,UmsaRkDh×DMSA(z) = [SA_1(z); SA_2(z); \ldots; SA_k(z)]U_{msa}, \quad U_{msa} \in \mathbb{R}^{k \cdot D_h \times D}

MSA 는 MHA 라고도 불리우는 것 같지만 이 글에서는 MSA 라고 하겠다. 그리고 MSA에 대한 설명을 진행하기 전에, 내 글보다 더 좋은 글을 찾아서 공유하려고한다. Transformer 에 대한 전반적인 설명이 들어가 있는 좋은 글이다.
Transformer 파헤치기 - Multi-Head Attention

어찌 되었든 위의 MSA 수식을 이제부터 천천히 풀어가겠다.

Self-Attention for MSA

[q,k,v]=zUqkv,UqkvRD×3Dh[q, k, v] = zU_{qkv}, \quad U_{qkv} \in \mathbb{R}^{D \times 3D_h}

A=softmax(qkTDh),ARN×NA = \text{softmax}\left(\frac{q\cdot k^T}{\sqrt{D_h}}\right), \quad A \in \mathbb{R}^{N \times N}

SA(z)=Av,AvRN×DhSA(z) = A\cdot v, \quad A\cdot v \in \mathbb{R}^{N \times D_h}

  • Dh=DkD_h= \frac{D}{k} ( 왼쪽의 kk 는 위의 k(key)k(key) 가 아니라 attention head 의 개수를 의미한다)
  • AA 는 Attention Score
  • kTk^TTT는 Transpose 는 전치를 의미하며 행으로 늘어선 배열과 열로 늘어선 배열의 내적을 구하기 위함이다.
    • TMI ) 수학적 표현으로 벡터의 내적은 일(W)를 구하기 위해 사용된다.
  • SASA 는 MSA의 Self-Attention
  • vv 는 Query, Key, Value 에서의 Value(v)

RD×3Dh\mathbb{R}^{D \times 3D_h} 에서 D×3DhD \times 3D_h

DhD_h( Dk\frac{D}{k} ) 는 헤더들이 수행해야하는 총 수행량이라고 보면 되고, 그 앞의 3은 세 군데(각 q,k,v) 전부에 대한 연산을 해야함 암시한다. 마지막으로 DD 는 위에서 언급한 것들이 총 Dimension 에 대해 연산이 수행되어야 함을 의미한다.

q=zwq,wqRD×Dhq = z \cdot w_q, \quad w_q \in \mathbb{R}^{D \times D_h}

k=zwk,wkRD×Dhk = z \cdot w_k, \quad w_k \in \mathbb{R}^{D \times D_h}

v=zwv,wvRD×Dhv = z \cdot w_v, \quad w_v \in \mathbb{R}^{D \times D_h}

[q,k,v]=zUqkv,UqkvRD×3Dh[q, k, v] = z \cdot U_{qkv}, \quad U_{qkv} \in \mathbb{R}^{D \times 3D_h}

softmax(qkTDh)\text{softmax}\left(\frac{q\cdot k^T}{\sqrt{D_h}}\right)

qkTq\cdot k^Tqq, kTk^T 두 벡터의 행렬곱이자 "행렬 벡터의 내적" 이다. 또한, 내적( 행렬곱 )을 사용하는 이유는 벡터 행렬간의 유사도를 계산하기 위함이다. 머신러닝에서는 두 벡터의 내적을 통해 유사도를 계산하는 것은 일반화 되어있다( 코사인 유사도와 같이 언급됨 ).

위 이미지를 보면 Key 가 이미 Transpose( 전치 ) 되어 있는 상태로 행렬곱이 되었다. 이는 내적을 계산하려면 두 벡터의 차원이 일치해야 하기 때문이다.

왜 차원이 일치해야하는지, 행렬곱의 특성에 대해 알고있다면 바로 수긍 가능하다.

여기 A와 B라는 행렬이 있다. A와 B 는 각각 아래와 같은 행과 열을 갖는다.
A=m×nA=m\times{n}
B=p×nB=p\times{n}

에서, A×BA\times{B} 는 성립할 수가 없다. 그래서 우리는 B를 전치해서 아래와 같이 만들어 줘야한다.

A=m×nA=m\times{n}
BT=n×pB^T=n\times{p}

A×BT=m×pA\times{B^T}=m\times{p}

Dh\sqrt{D_h}qkTq\cdot k^T 내적을 나눠주는 이유는 내적의 결과를 안정적으로 만들기 위함이다. 즉, Dh\sqrt{D_h} 는 내적의 결과를 스케일링 해주는 값이다. 이 스케일링을 통해 그라디언트가 과도하게 커지거나 작아지는 것을 방지한다.

softmaxsoftmax 는 위 연산을 거친 점수들을 확률적으로 변환시킨다. 그리고 output 은 'Attention Weights' 라고도 불리운다. 즉, 각 입력에 대한 출력의 가중치를 의미한다. 이 과정을 거치면 Attention Score 의 유사도는 0 ~ 1사이가 된다.

마지막으로 여기 softmax 를 거쳐서 나온 Attention Score 에 Value(v) 를 곱하면 Self-Attention ( SASA ) 이 된다.

Multi-Head Self-Attention ( MSA )

MSA(z)=[SA1(z);SA2(z);;SAk(z)]Umsa,UmsaRkDh×DMSA(z) = [SA_1(z); SA_2(z); \ldots; SA_k(z)]U_{msa}, \quad U_{msa} \in \mathbb{R}^{k \cdot D_h \times D}

kk : 'head' 의 개수
UU : 변환 행렬 ( transform matrix )

MSA 는 Transformer 기반 모든 모델에서 연산량이 가장 많이 요구되는 부분이기도 하다.

MSA 는 'head' 의 개수만큼 독립적인 SA 를 갖고있고, 각각의 head 가 SA 를 병렬로 계산한다. 그러므로, MSA는 여러 'head' 들이 각각 주어진 SA의 Query, Key, Value 행렬을 계산하고, Attention 가중치를 병렬로 계산한다. 그리고, 이 가중치들은 최종적으로 concat 되어서 통합된 인사이트를 제공해준다.

추가로 위 수식 RkDh×D\mathbb{R}^{k \cdot D_h \times D} 을 덧붙여 말하자면, SASA 연산을 수행하면 차원이 줄어든다. 그렇기에 RkDh×D\mathbb{R}^{k \cdot D_h \times D} 에서 k(head 의 개수)k\text{(head 의 개수)} 가 곱해진 것이다. 이로써, SASA 연산으로 줄어든 차원이 복구된다.

위 포스팅은 [논문리뷰] ViT: OpenAI sora DiTs 의 근간 을 이해하는데 도움이 되는 글이다.
이제 논문을 이해하러 가보자.

profile
지혜는 지식에서 비롯된다

0개의 댓글