[llama3/llama/generation.py][class Llama] def generate

ma-kjh · August 28, 2024
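# imports this method relies on (they live at the top of generation.py)
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

@torch.inference_mode()
def generate(
    self,
    prompt_tokens: List[List[int]],
    max_gen_len: int,
    temperature: float = 0.6,
    top_p: float = 0.9,
    logprobs: bool = False,
    echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]: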
  • A function that generates text sequences.
  • It takes prompts as input and produces text from them.

Args:

  • prompt_tokens (List[List[int]]): A list of tokenized prompts. Each prompt is a list of integers.
  • max_gen_len (int): The maximum length of the generated text sequence.
  • temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. -> Honestly this surprised me the most: a temperature of 0.6 is applied to inject randomness when the text sequence is sampled.
  • top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. -> And what is this one? (Nucleus sampling draws only from the most likely tokens whose cumulative probability reaches top_p.)
  • logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. -> Whether to return log probabilities or not. If so, they presumably end up in the Optional[List[List[float]]] part of the return value above.
  • echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. -> Whether to include the prompt tokens in the generated output. Naturally, it has to be False for them to be left out.

Returns:

  • Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.

Note:

  • This method uses the provided prompts as a basis for generating text.
  • It employs nucleus sampling to produce text with controlled randomness. -> Need to look into this.
  • If logprobs is True, token log probabilities are computed for each generated token.
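To make the temperature and top_p arguments concrete, here is a minimal pure-Python sketch of temperature-scaled softmax followed by nucleus filtering. It is not the repository's sample_top_p; the helper names (softmax_with_temperature, nucleus_filter, sample_next_token) are hypothetical, and plain lists stand in for tensors.

```python
import math
import random
from typing import List, Optional


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits / temperature (temperature must be > 0)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus_filter(probs: List[float], top_p: float) -> List[float]:
    """Zero out tokens outside the top-p nucleus, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = [0.0] * len(probs)
    cumulative = 0.0
    for i in order:
        # Stop once the mass accumulated before this token already exceeds
        # top_p; the most likely token is therefore always kept.
        if cumulative > top_p:
            break
        kept[i] = probs[i]
        cumulative += probs[i]
    total = sum(kept)
    return [p / total for p in kept]


def sample_next_token(
    logits: List[float],
    temperature: float = 0.6,
    top_p: float = 0.9,
    rng: Optional[random.Random] = None,
) -> int:
    """Greedy pick when temperature <= 0, otherwise top-p sampling."""
    if temperature <= 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    probs = nucleus_filter(softmax_with_temperature(logits, temperature), top_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

A temperature below 1 sharpens the distribution (the top token gets more mass), while top_p cuts off the unlikely tail before sampling.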
    params = self.model.params
    bsz = len(prompt_tokens)  # batch size
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    min_prompt_len = min(len(t) for t in prompt_tokens)  # length of the shortest prompt
    max_prompt_len = max(len(t) for t in prompt_tokens)  # length of the longest prompt
    assert max_prompt_len <= params.max_seq_len  # prompts must fit within the model's max sequence length
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)  # prompt plus generation, capped at max_seq_len

    pad_id = self.tokenizer.pad_id  # pad index
    # start by filling every sequence in the batch with the pad index
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
    for k, t in enumerate(prompt_tokens):
        # positions covered by the prompt get the real token ids instead of the pad
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
    if logprobs:
        token_logprobs = torch.zeros_like(tokens, dtype=torch.float)  # log prob for every batch row and position
    prev_pos = 0  # start of the slice that is fed to the model at the next step
    eos_reached = torch.tensor([False] * bsz, device="cuda")  # one False flag per batch row
    input_text_mask = tokens != pad_id  # True where a prompt token is present, False where it is still pad
    if min_prompt_len == total_len:  # every prompt already fills total_len, so nothing is left to generate (rare; see the notes below)
        logits = self.model.forward(tokens, prev_pos)
        token_logprobs = -F.cross_entropy(
            input=logits.transpose(1, 2),
            target=tokens,
            reduction="none",
            ignore_index=pad_id,
        )

    stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
    
    for cur_pos in range(min_prompt_len, total_len):
        # logits for the tokens from prev_pos up to (but not including) cur_pos
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)  # temperature is used here; it controls how random probs are
            next_token = sample_top_p(probs, top_p)  # nucleus sampling; see sample_top_p
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)  # greedy decoding

        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token  # write the chosen token into the current position only
        if logprobs:
            token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens[:, prev_pos + 1 : cur_pos + 1],
                reduction="none",
                ignore_index=pad_id,
            )
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
            torch.isin(next_token, stop_tokens)
        )
        prev_pos = cur_pos
        if all(eos_reached):
            break
        
    if logprobs:
        token_logprobs = token_logprobs.tolist()
    out_tokens, out_logprobs = [], []

    for i, toks in enumerate(tokens.tolist()):
        # cut to max gen len
        start = 0 if echo else len(prompt_tokens[i])
        toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
        probs = None
        if logprobs:
            probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
        # cut at the first stop token, if any
        for stop_token in self.tokenizer.stop_tokens:
            try:
                eos_idx = toks.index(stop_token)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            except ValueError:
                pass
        out_tokens.append(toks)
        out_logprobs.append(probs)
    return (out_tokens, out_logprobs if logprobs else None)
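The post-processing at the end of generate can be tried out in isolation. The sketch below mirrors that final loop for a single row using plain lists; trim_generation is a hypothetical helper name, and the token ids in the usage note are made up.

```python
from typing import List, Sequence


def trim_generation(
    toks: List[int],
    prompt_len: int,
    max_gen_len: int,
    stop_tokens: Sequence[int],
    echo: bool = False,
) -> List[int]:
    """Cut one output row to max_gen_len and drop everything from the first stop token on."""
    # With echo=True the prompt is kept, otherwise the slice starts after it.
    start = 0 if echo else prompt_len
    toks = toks[start : prompt_len + max_gen_len]
    for stop_token in stop_tokens:
        try:
            toks = toks[: toks.index(stop_token)]
        except ValueError:
            pass  # this stop token does not occur in the row
    return toks
```

For a row [5, 6, 7, 8, 9, 99, 0, 0] with a two-token prompt, max_gen_len of 4 and 99 as the stop token, this keeps [7, 8, 9]; with echo=True it keeps [5, 6, 7, 8, 9].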
        
    
    
    
    

max_gen_len: the maximum length of the generated text sequence
min_prompt_len: the length of the shortest prompt among the input prompts
max_prompt_len: the length of the longest prompt among the input prompts

  • assert max_prompt_len <= params.max_seq_len: here params.max_seq_len is the longest sequence the model can take as input (2048 in the args).

  • total_len = min(params.max_seq_len (= 2048), max_gen_len + max_prompt_len)
    - total_len itself can be smaller than max_seq_len. In most cases it will simply be max_gen_len + max_prompt_len.
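A small worked example may help with these length variables. The sketch below rebuilds the padded token buffer and input_text_mask with plain lists instead of tensors; build_token_buffer is a hypothetical name, and the pad id of -1 is an assumption made for the illustration.

```python
from typing import List, Tuple


def build_token_buffer(
    prompt_tokens: List[List[int]],
    max_gen_len: int,
    max_seq_len: int,
    pad_id: int = -1,  # assumed pad id, for illustration only
) -> Tuple[List[List[int]], List[List[bool]], int, int]:
    """Return (tokens, input_text_mask, min_prompt_len, total_len)."""
    min_prompt_len = min(len(t) for t in prompt_tokens)
    max_prompt_len = max(len(t) for t in prompt_tokens)
    assert max_prompt_len <= max_seq_len
    total_len = min(max_seq_len, max_gen_len + max_prompt_len)

    # Every row starts as all-pad, then the prompt overwrites the front.
    tokens = [[pad_id] * total_len for _ in prompt_tokens]
    for k, t in enumerate(prompt_tokens):
        tokens[k][: len(t)] = t
    # True where a prompt token sits, False where the slot is still pad.
    input_text_mask = [[tok != pad_id for tok in row] for row in tokens]
    return tokens, input_text_mask, min_prompt_len, total_len
```

With prompts of length 3 and 1 and max_gen_len = 4, total_len is 7 and generation starts at position 1 (min_prompt_len), so the longer prompt is still being "replayed" for the first two steps.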

  • The condition if min_prompt_len == total_len: checks for a special edge case where the shortest prompt is already as long as total_len. Because total_len is never smaller than max_prompt_len, this can only happen when all prompts have the same length and either max_gen_len is 0 or the prompts already fill params.max_seq_len. The generation loop over range(min_prompt_len, total_len) is then empty, so no tokens are generated; the branch instead runs a single forward pass over the full prompts and stores the negative cross-entropy as token_logprobs.
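The edge case is easier to see with numbers. The sketch below lists the positions the generation loop would visit, and also shows what a "negative cross-entropy" amounts to for one position: the log-softmax value of the target token. Both helper names are hypothetical and plain Python floats stand in for tensors.

```python
import math
from typing import List


def generation_positions(prompt_lens: List[int], max_gen_len: int, max_seq_len: int) -> List[int]:
    """Positions cur_pos that the generation loop would iterate over."""
    min_prompt_len = min(prompt_lens)
    max_prompt_len = max(prompt_lens)
    assert max_prompt_len <= max_seq_len
    total_len = min(max_seq_len, max_gen_len + max_prompt_len)
    return list(range(min_prompt_len, total_len))


def token_logprob(logits: List[float], target: int) -> float:
    """Log probability of `target` under softmax(logits), i.e. minus the cross-entropy."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[target] - log_z
```

generation_positions([4, 4], 0, 2048) and generation_positions([8, 8], 16, 8) both return an empty list, which is exactly the situation the special branch handles.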

eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )

Understanding the Components:

  1. input_text_mask:

    • This is a boolean tensor computed as tokens != pad_id before the loop starts. It is True at positions that hold prompt tokens and False at positions that are still padding, i.e. the slots that generation has to fill.
  2. cur_pos:

    • This is an index representing the current position in the sequence that the model is processing.
  3. next_token:

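    • This is the token written at cur_pos for each sequence: the model's sampled (or argmax) prediction, replaced by the prompt token wherever cur_pos is still inside that sequence's prompt.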
  4. stop_tokens:

    • This is a tensor containing the token ids from self.tokenizer.stop_tokens that signal the end of generation. In the Llama 3 tokenizer these are special end-of-sequence markers such as <|end_of_text|> and <|eot_id|>, not ordinary punctuation.
  5. ~input_text_mask[:, cur_pos]:

    • The ~ operator is a bitwise NOT operation that inverts the boolean values in input_text_mask[:, cur_pos].
    • input_text_mask[:, cur_pos] selects the mask values at the current position across all sequences (if dealing with batch processing). The ~ operator then inverts these values, so True becomes False and False becomes True.
  6. torch.isin(next_token, stop_tokens):

    • This function checks if next_token exists within the stop_tokens tensor.
    • It returns a boolean tensor where each element is True if the corresponding next_token is in stop_tokens and False otherwise.
  7. eos_reached:

    • This is a boolean tensor that tracks whether the end of a sequence (EOS) has been reached for each sequence in the batch. It is initialized to False for all sequences.
  8. |=:

    • The |= operator is a bitwise OR combined with assignment. It updates eos_reached in place by performing a bitwise OR operation between eos_reached and the result of the expression on the right-hand side.
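Since next_token and input_text_mask interact through torch.where in the loop, a tiny list-based analogue may make the components above easier to picture. force_prompt_tokens is a hypothetical helper and the token ids are made up; it only mimics what the torch.where call does for a single column cur_pos.

```python
from typing import List


def force_prompt_tokens(
    mask_col: List[bool],    # input_text_mask[:, cur_pos]
    prompt_col: List[int],   # tokens[:, cur_pos] before the write
    sampled: List[int],      # model prediction for each sequence
) -> List[int]:
    """Keep the prompt token where the mask is True, else take the sampled token."""
    return [p if m else s for m, p, s in zip(mask_col, prompt_col, sampled)]
```

For example, with mask [True, False], prompt column [13, -1] and sampled tokens [42, 77], the result is [13, 77]: the first sequence is still inside its prompt, so its prediction is discarded.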

Putting It All Together:

The code snippet can be broken down into the following steps:

  1. Identify sequences whose current position lies beyond the prompt:

    • ~input_text_mask[:, cur_pos]: This is True for sequences where the mask was False at cur_pos, meaning the slot held only padding and the token there has just been generated rather than copied from the prompt.
  2. Check if the next token is a stop token:

    • torch.isin(next_token, stop_tokens): This checks if the next_token at the current position is one of the specified stop tokens.
  3. Combine the two conditions:

    • (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)): This creates a boolean tensor with one entry per sequence, which is True if:
      • The current position is outside the prompt (the mask is False there), and
      • The next token is a stop token.
  4. Update eos_reached:

    • eos_reached |= (...): The result is OR-ed with the existing eos_reached tensor. This updates eos_reached to mark sequences where either:
      • EOS had already been reached earlier (eos_reached was already True), or
      • The current step meets the combined condition from step 3.
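The steps above can be sketched for one loop iteration with plain Python lists. update_eos_reached is a hypothetical helper, and 99 is a made-up stop token id used only in the usage note.

```python
from typing import Collection, List


def update_eos_reached(
    eos_reached: List[bool],
    mask_col: List[bool],        # input_text_mask[:, cur_pos]
    next_token: List[int],
    stop_tokens: Collection[int],
) -> List[bool]:
    """One step of: eos_reached |= (~mask_col) & isin(next_token, stop_tokens)."""
    return [
        done or ((not in_prompt) and tok in stop_tokens)
        for done, in_prompt, tok in zip(eos_reached, mask_col, next_token)
    ]
```

With eos_reached = [False, False, True], mask column [True, False, False], next tokens [99, 99, 5] and stop tokens {99}, the result is [False, True, True]: a stop token inside a prompt is ignored, a generated one sets the flag, and a flag that is already set stays set.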

Summary of What the Code Does:

  • This code snippet is used to update the eos_reached tensor, which tracks whether the end of a sequence (EOS) has been reached for each sequence in a batch.
  • The EOS is considered to be reached if:
    • The current position is outside the prompt (the mask is False, so the token was generated), and
    • The predicted next_token is a stop token.
  • The result of this condition is combined with the existing eos_reached tensor using a bitwise OR, ensuring that once EOS is detected, it remains flagged as such.

This update runs once per step of the generation loop over cur_pos; as soon as every sequence in the batch has reached EOS, the loop breaks.
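To see the whole mechanism at once, here is a toy version of the loop with a stand-in "model". Everything in it is an assumption made for illustration: toy_generate and countdown are hypothetical, PAD_ID and STOP_ID are made-up ids, and unlike the real method the toy re-reads the full context at every step instead of feeding only the tokens[:, prev_pos:cur_pos] slice.

```python
from typing import Callable, List, Sequence, Tuple

PAD_ID = -1   # assumed pad id
STOP_ID = 0   # hypothetical stop token id


def toy_generate(
    prompts: List[List[int]],
    max_gen_len: int,
    next_token_fn: Callable[[List[int]], int],
    stop_tokens: Sequence[int] = (STOP_ID,),
    max_seq_len: int = 32,
) -> Tuple[List[List[int]], int]:
    """Return the filled token buffer and the number of loop steps taken."""
    min_len = min(len(p) for p in prompts)
    max_len = max(len(p) for p in prompts)
    assert max_len <= max_seq_len
    total_len = min(max_seq_len, max_gen_len + max_len)

    tokens = [list(p) + [PAD_ID] * (total_len - len(p)) for p in prompts]
    mask = [[t != PAD_ID for t in row] for row in tokens]
    eos_reached = [False] * len(prompts)
    steps = 0
    for cur_pos in range(min_len, total_len):
        steps += 1
        for b, row in enumerate(tokens):
            if mask[b][cur_pos]:
                continue  # still inside this prompt: keep the prompt token
            row[cur_pos] = next_token_fn(row[:cur_pos])
            if row[cur_pos] in stop_tokens:
                eos_reached[b] = True
        if all(eos_reached):
            break  # every sequence has produced a stop token
    return tokens, steps


def countdown(context: List[int]) -> int:
    """Stand-in model: predicts the previous token minus one, floored at STOP_ID."""
    return max(context[-1] - 1, STOP_ID)
```

Calling toy_generate([[2], [3, 2]], 6, countdown) stops after three steps even though total_len is 8, because both rows have hit the stop token by then; the remaining slots stay at PAD_ID.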
