@torch.inference_mode()
def generate(
self,
prompt_tokens: List[List[int]],
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
prompt_tokens (List[List[int]])
: List of tokenized prompts; each prompt is a list of integers.

max_gen_len (int)
: Maximum length of the text sequence to generate.

temperature (float, optional)
: Temperature value for controlling randomness in sampling. Defaults to 0.6. -> Honestly the most striking part for me: a temperature of 0.6 is used to inject randomness when sampling the text sequence.. wild..

top_p (float, optional)
: Top-p probability threshold for nucleus sampling. Defaults to 0.9. -> And what is this one? (how it is actually used is in the sample_top_p sketch after the function body)

logprobs (bool, optional)
: Flag indicating whether to compute token log probabilities. Defaults to False. Whether or not to extract log probabilities; if enabled, they presumably end up in the Optional[List[List[float]]] part of the return type above. (See the small cross-entropy example right below this list.)

echo (bool, optional)
: Flag indicating whether to include prompt tokens in the generated output. Defaults to False. -> Whether or not to include the prompt tokens in the generated output. Obviously keep it False so they don't show up.

Tuple[List[List[int]], Optional[List[List[float]]]]
: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
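Quick side note on the logprobs machinery before the body: with reduction="none", F.cross_entropy returns -log p(target) for every position, so negating it gives per-token log probabilities. A minimal standalone sketch (the shapes here are made up for illustration, not taken from the repo):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 5, 32)            # (batch, seq_len, vocab_size), made-up shapes
targets = torch.randint(0, 32, (1, 5))    # the token ids that were actually produced

# F.cross_entropy expects (batch, vocab, seq_len), hence the transpose used in generate()
token_logprobs = -F.cross_entropy(
    logits.transpose(1, 2),
    targets,
    reduction="none",
)
print(token_logprobs.shape)               # torch.Size([1, 5]) -> one log-prob per token
```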
params = self.model.params
bsz = len(prompt_tokens)  # batch size
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)  # shortest prompt seq len
max_prompt_len = max(len(t) for t in prompt_tokens)  # longest prompt seq len
assert max_prompt_len <= params.max_seq_len  # prompts must fit within the model's max input sequence length
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)  # prompt + generation together must also stay within the max input sequence length
pad_id = self.tokenizer.pad_id # pad index
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")  # first fill every sequence in the batch with the pad index
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")  # up to the length of each input prompt, the actual token values are filled in instead of padding
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)  # log-prob values for every batch element and token
prev_pos = 0  # what is this? -> it turns out to be the start position of the slice of tokens fed into the next forward pass
eos_reached = torch.tensor([False] * bsz, device="cuda")  # one False entry per sequence in the batch
input_text_mask = tokens != pad_id  # True wherever a token differs from pad_id, False where it equals pad_id
if min_prompt_len == total_len:  # the shortest prompt already fills total_len (prompt + generation budget)? see the note below the code
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1,2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)  # seems like a case that almost never happens
stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)  # logits from feeding the tokens from prev_pos up to (but not including) cur_pos
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)  # temperature is applied here, controlling the randomness of probs
next_token = sample_top_p(probs, top_p)  # see sample_top_p (sketched after the function body); sample the next token from the nucleus
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token  # write the predicted next token only at the current position
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1,2),
target=tokens[:, prev_pos+1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
prev_pos = cur_pos
if all(eos_reached):
break
if logprobs:
token_logprobs = token_logprobs.tolist()
out_tokens, out_logprobs = [], []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
for stop_token in self.tokenizer.stop_tokens:
try:
eos_idx = toks.index(stop_token)
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None
except ValueError:
pass
out_tokens.append(toks)
out_logprobs.append(probs)
return (out_tokens, out_logprobs if logprobs else None)
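The loop above hands the temperature-scaled distribution to sample_top_p(probs, top_p), a helper defined elsewhere in the repo. As a rough sketch of how that nucleus-sampling routine works (written from memory, so treat the details as approximate): sort the probabilities in descending order, cut off the tail once the cumulative mass before a token already exceeds top_p, renormalize what is left, sample from that truncated distribution, and map the sampled index back to the original vocabulary id.

```python
import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # probs: (batch, vocab_size) softmax output, already temperature-scaled
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # drop tokens whose preceding cumulative probability already exceeds p
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))     # renormalize the nucleus
    next_token = torch.multinomial(probs_sort, num_samples=1)  # sample within the nucleus
    return torch.gather(probs_idx, -1, next_token)             # map back to original vocab ids
```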
max_gen_len
: the maximum length of the text sequence to generate

min_prompt_len
: the length of the shortest prompt among the input prompt tokens

max_prompt_len
: the length of the longest prompt among the input prompt tokens

assert max_prompt_len <= params.max_seq_len
: here params.max_seq_len is the maximum sequence length the model can take as input (2048 in the args)

total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
: total_len can be smaller than params.max_seq_len (2048); in most cases it will simply end up as max_gen_len + max_prompt_len.
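A quick sanity check with made-up numbers (max_seq_len = 2048 as noted above; the generation length and prompt lengths below are just assumptions for the example):

```python
max_seq_len = 2048                      # model's maximum input length
max_gen_len = 128                       # how many tokens we want to generate (made-up)
prompt_lens = [12, 500]                 # lengths of the tokenized prompts (made-up)

max_prompt_len = max(prompt_lens)       # 500
total_len = min(max_seq_len, max_gen_len + max_prompt_len)
print(total_len)                        # 628 -> the usual case: prompt + generation budget
```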
The condition if min_prompt_len == total_len: checks for the edge case where even the shortest input prompt already fills the entire sequence length (total_len). In that case the model's input is fully occupied by the prompt and there is no room left to generate anything: the generation loop for cur_pos in range(min_prompt_len, total_len) never executes, so this branch just does one forward pass over the full prompt to fill in the prompt token log probabilities.
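A made-up example of when this edge case actually fires (the numbers are assumptions, not from the post):

```python
max_seq_len, max_gen_len = 2048, 64
prompt_lens = [2048, 2048]          # every prompt already fills the whole context

total_len = min(max_seq_len, max_gen_len + max(prompt_lens))   # 2048
min_prompt_len = min(prompt_lens)                              # 2048

# The generation loop has nothing to iterate over, so only the
# single forward pass computing prompt log-probs is executed.
print(list(range(min_prompt_len, total_len)))   # []
```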
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
input_text_mask
: marks which positions hold real prompt tokens (True) and which are padding / positions still to be generated (marked with False).

cur_pos
: the position in the sequence currently being processed by the generation loop.

next_token
: the token just chosen for position cur_pos.

stop_tokens
: a tensor of the token ids that mark the end of generation.

~input_text_mask[:, cur_pos]
: the ~ operator is a bitwise NOT operation that inverts the boolean values in input_text_mask[:, cur_pos]. input_text_mask[:, cur_pos] selects the mask values at the current position across all sequences in the batch; ~ then flips them, so True becomes False and False becomes True. The result is True exactly where cur_pos was not covered by the original prompt, i.e. where the token at this position was actually generated by the model.

torch.isin(next_token, stop_tokens)
: checks, element-wise, whether each next_token exists within the stop_tokens tensor. It is True if the corresponding next_token is in stop_tokens and False otherwise.

eos_reached
: tracks whether the end of a sequence (EOS) has been reached for each sequence in the batch; it starts out as False for all sequences.

|=
: the |= operator is a bitwise OR combined with assignment. It updates eos_reached in place by performing a bitwise OR between eos_reached and the result of the expression on the right-hand side.

The code snippet can be broken down into the following steps:

1. Identify positions that are being generated: ~input_text_mask[:, cur_pos] is True where the current position did not come from the input prompt (i.e., where the mask was False), meaning the token there was produced by the model rather than copied from the prompt.
2. Check whether the next token is a stop token: torch.isin(next_token, stop_tokens) checks if the next_token at the current position is one of the specified stop tokens.
3. Combine the two conditions: (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) creates a boolean tensor that is True only for sequences that have just generated a stop token; a stop token that merely appears inside the prompt does not count.
4. Update eos_reached: eos_reached |= (...) ORs the result into the existing eos_reached tensor, so a sequence ends up flagged if either it was already finished (eos_reached was already True) or it has just generated a stop token.

In short, eos_reached tracks whether the end of a sequence (EOS) has been reached for each sequence in the batch, and updating it with a bitwise OR ensures that once EOS is detected it stays flagged. This sits inside the generation loop, which after every step checks whether every sequence in the batch has finished (if all(eos_reached): break) and stops early if so.
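A toy trace of the update, with dummy token ids and a batch of three sequences (all values invented purely for illustration):

```python
import torch

stop_tokens = torch.tensor([2, 3])                     # dummy stop-token ids
next_token = torch.tensor([7, 3, 2])                   # sampled tokens for a batch of 3 sequences
mask_at_cur_pos = torch.tensor([False, False, True])   # True = cur_pos still belongs to the prompt
eos_reached = torch.tensor([False, False, False])

eos_reached |= (~mask_at_cur_pos) & torch.isin(next_token, stop_tokens)
print(eos_reached)
# tensor([False,  True, False])
# seq 0: generated an ordinary token          -> not finished
# seq 1: generated stop token 3               -> finished
# seq 2: the stop token came from the prompt  -> not counted as finished
```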