이전까지 Model을 만들어내는 것 까지 완료를 했다.
그런데 아직 모델은 좋은 생성모델이 아닌 상태. (막무가내 출력)
어떻게 좋아지게 만들까..
import torch
torch.manual_seed(1337)
B,T,C = 4, 8, 2 # Batch, Time, Channel
x = torch.randn(B,T,C)
x.shape
Toy example로 self-attention의 트릭을 살펴보자.
먼저 B, T, C 의 데이터를 만들어본다.
이전 데이터셋을 예시로 들면 T size의 sequence를 가지는 character의 집합이 Batch크기만큼 있는것이고, 각각의 T sequence 위치에서는 C만큼 정보가 담겨있다.
현재는 이 token들 사이에 아무런 communication이 일어나고 있지 않고 있다. 이러한 부분을 우리는 이어주고 싶다. 왜냐면 우리가 일반적으로 다음 단어를 예측하는데 있어서 이전 정보들을 잘 aggregate하는 정보를 활용(context : 문맥을 활용한다)해서 예측을 하는데, 이러한 inductive bias가 반영되었다고 볼 수 있다.
방법은 여러가지가 있는데, 만약에 예를들어서 5번째 위치한 토큰같은 경우는 이전 토큰은 볼 수 있지만, 6, 7, 8번째의 token과는 연결이 되어서는 안된다. 우리는 어떤 단어 혹은 알파벳을 예측할 때 이전의 정보만을 사용할 수 있기 때문이다.
따라서 5번째 token은 4,3,2,1 번째 토큰들만 연결함으로써 6번째 token에 대해 예측할 수 있다. 이를 연결할 수 있는 가장 간단한 방법은 그냥 이전 토큰들에 대한 C 정보들을 average하는 것이다. 그러면 어떤 history에 대한 summary를 담고 있는 Vector를 얻어낼 것이고(C) 이를 통해 다음 예측을 진행할 수 있다.
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
for t in range(T):
xprev = x[b,:t+1] # (t,C)
xbow[b,t] = torch.mean(xprev,axis=0)
x bag of word. average of the prev tokens.
x[0]
tensor([[ 0.1808, -0.0700],
[-0.3596, -0.9152],
[ 0.6258, 0.0255],
[ 0.9545, 0.0643],
[ 0.3612, 1.1679],
[-1.3499, -0.5102],
[ 0.2360, -0.2398],
[-0.9211, 1.5433]])
xbow[0]
tensor([[ 0.1808, -0.0700],
[-0.0894, -0.4926],
[ 0.1490, -0.3199],
[ 0.3504, -0.2238],
[ 0.3525, 0.0545],
[ 0.0688, -0.0396],
[ 0.0927, -0.0682],
[-0.0341, 0.1332]])
이전토큰들의 모든 average를 의미.
근데 이런식으로 for문이 두개하는건 sequence length가 길어질수록 엄청 비효율적임.
이를 matrix multiplication으로 해결
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)
a=
tensor([[1.0000, 0.0000, 0.0000],
[0.5000, 0.5000, 0.0000],
[0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
[6., 4.],
[6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
[4.0000, 5.5000],
[4.6667, 5.3333]])
이렇게 간단한 matrix-multiplication 을 활용해서 for문을 두번 돌리는 방법없이, 매우 쉽게 계산할 수 있다.
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
torch.allclose(xbow, xbow2)
True
softmax를 사용해서 할 수도 있음.
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)
True
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf, -inf, -inf],
[0., 0., 0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., 0., 0., -inf, -inf, -inf],
[0., 0., 0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0., 0., 0.]])
같게 된다.
그런데 위 예시들은 모든 이전 토큰들에 대해서 동일한 가중치로 average된 embedding vector를 사용하고 있다.
이렇게 되면 어떤 단어가 다음 단어를 예측하는데 중요한지, 혹은 연관이 있는지 고려하지 않고 그냥 단순 평균으로 계산하는 것이기 때문에 예측하는데 있어서 정확한 예측이 어려워진다.
이를 모델이 학습하도록 만들어보면 아래와 같은 self-attention을 만들어낼 수 있다. 즉 weighted sum (이전 token들 중에 어떤 것들을 집중해서 볼지를 결정해줌)
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)
# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)
tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
v = value(x)
out = wei @ v
#out = wei @ x
out.shape
torch.Size([4, 8, 16])
그러면 이전 토큰들 중에서 필요한 정보들만 가져오는 거니까 효과적인 모델이 될 것이다.(weighted aggregation -> affinity)
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)
# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)
tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
v = value(x)
out = wei @ v
#out = wei @ x
out.shape
torch.Size([4, 8, 16])
wei[0]
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
[0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
[0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
[0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
grad_fn=<SelectBackward0>)
Notes:
tril
, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.wei
by 1/sqrt(head_size). This makes it so when input Q, K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much.