[llama3/llama/generation.py][class Llama] At the end of def build, the model is constructed using the Transformer class.
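For reference, the params argument passed in here is the ModelArgs dataclass from model.py. Roughly something like this (field defaults written from memory, so they may not match the actual source exactly):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1          # -1 by default; set from params.json / the tokenizer during build
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_batch_size: int = 32
    max_seq_len: int = 2048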
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size  # default is -1, though..
        self.n_layers = params.n_layers  # 32 decoder layers get stacked.
        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )  # what is this?
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))  # stack 32 of the previously defined TransformerBlock.
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)  # uses something called RMSNorm; look at it later.
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x  # ColumnParallelLinear (look into it)
        )
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )  # probably the positional-embedding step.
That's all of __init__. Nothing here is really difficult, but there are a few things I'm seeing for the first time.
VocabParallelEmbedding : imported via from fairscale.nn.model_parallel.layers import VocabParallelEmbedding; look at it later.
self.norm = RMSNorm(*)
This adds a Norm I haven't seen before; it's what Llama uses, so look into it later.
ColumnParallelLinear : same story as Vocab*.
freqs_cis : this one too.. what is it?
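Quick note to self on RMSNorm: it's basically LayerNorm without mean-centering and without a bias, normalizing by the root-mean-square of the hidden vector. A minimal sketch of what model.py's RMSNorm looks like as far as I remember (details like the default eps may differ):

import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-dimension scale

    def _norm(self, x):
        # divide by sqrt(mean(x^2)) over the last (hidden) dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight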
Now, looking through forward:
@torch.inference_mode()  # then how do you train..? can this be switched off? or maybe it's just not needed in forward..
def forward(self, tokens: torch.Tensor, start_pos: int):  # start_pos ?
    _bsz, seqlen = tokens.shape  # the input tokens come in as [batch_size, sequence_length], so this grabs the shape.
    h = self.tok_embeddings(tokens)  # looks like plain token-to-embedding; the shape should be [batch_size, sequence_length, emb_dim], right?
    self.freqs_cis = self.freqs_cis.to(h.device)
    freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  # only take the slice from the start position through the next seqlen entries.
    mask = None
    if seqlen > 1:
        mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)  # first fill a (seqlen, seqlen) matrix with float `-inf`.
        mask = torch.triu(mask, diagonal=1)  # with diagonal=1, the diagonal and everything below it become 0; only the strict upper triangle keeps -inf.

        # When performing key-value caching, we compute the attention scores
        # only for the new sequence. Thus, the matrix of scores is of size
        # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
        # j > cache_len + i, since row i corresponds to token cache_len + i.
        mask = torch.hstack(
            [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
        ).type_as(h)

    for layer in self.layers:
        h = layer(h, start_pos, freqs_cis, mask)
    h = self.norm(h)
    output = self.output(h).float()
    return output
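About the freqs_cis being sliced above: precompute_freqs_cis precomputes the complex rotation factors for rotary positional embeddings (RoPE), one row per absolute position, which is why forward can just index rows start_pos .. start_pos + seqlen. A sketch of what it computes (written from memory of model.py, so it may differ in small details):

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # per-pair inverse frequencies: theta^(-2i/dim) for i = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # one angle per (position, frequency) pair
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    # complex numbers on the unit circle: exp(i * t * freq), shape (end, dim // 2)
    return torch.polar(torch.ones_like(freqs), freqs)

So self.freqs_cis has shape (max_seq_len * 2, head_dim // 2), and row t holds the rotation applied to a token at absolute position t; slicing with start_pos keeps positions consistent with what is already in the KV cache.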
A quick check of what torch.triu(a, diagonal=1) does, since it's used for the mask above:
>>> a = torch.randn(3, 3)
>>> a
tensor([[ 0.2309, 0.5207, 2.0049],
[ 0.2072, -1.0680, 0.6602],
[ 0.3480, -0.5211, -0.4573]])
>>> torch.triu(a, diagonal=1)
tensor([[ 0.0000, 0.5207, 2.0049],
[ 0.0000, 0.0000, 0.6602],
[ 0.0000, 0.0000, 0.0000]])
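Extending that check to the full mask construction from forward, with illustrative values seqlen=3 and start_pos=2 (i.e. 2 tokens already sitting in the KV cache): the hstack prepends zero columns so the new tokens can always attend to every cached position.

import torch

seqlen, start_pos = 3, 2  # illustrative values, not from the repo
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])
# row i is token cache_len + i, so it may attend to columns j <= start_pos + i.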
The @torch.inference_mode() decorator in PyTorch is used to make inference (i.e., running predictions with a trained model) more efficient by disabling features that are only needed during training. Here's what it does:

Disables gradient calculation: like torch.no_grad(), torch.inference_mode() turns off gradient tracking, which is unnecessary during inference and saves memory and computation. Gradients are only needed to update model parameters during training, not during inference.
Optimizes memory usage: torch.inference_mode() is more aggressive than torch.no_grad() in terms of optimizations. It can bring further memory savings and performance improvements by skipping bookkeeping that is irrelevant during inference, such as version counter updates on tensors.
Read-only operations: tensors created under torch.inference_mode() are treated as read-only "inference tensors" that cannot be used later in operations recorded by autograd, which helps ensure the code runs as efficiently as possible.

@torch.inference_mode()
def predict(model, inputs):
    return model(inputs)

# This will run the model in inference mode, with optimizations applied.
outputs = predict(my_model, my_inputs)

In this example, the predict function executes without tracking gradients, with all the additional optimizations provided by torch.inference_mode(). This is particularly beneficial when deploying models in environments where performance is critical.

In short, @torch.inference_mode() is a decorator that optimizes inference by disabling gradient calculation and applying additional optimizations, making it more efficient than torch.no_grad() for read-only operations during inference.
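A tiny comparison I can run to convince myself of the difference (this snippet is mine, not from the repo):

import torch

w = torch.randn(3, requires_grad=True)

with torch.no_grad():
    a = w * 2            # gradient tracking is paused
print(a.requires_grad)   # False

with torch.inference_mode():
    b = w * 2            # also untracked, plus version-counter/view tracking is skipped
print(b.requires_grad)   # False
print(b.is_inference())  # True: b is an inference tensor and can't be used later
                         # in operations that autograd needs to record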