```python
import torch


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative
        probability mass exceeds the threshold p. The distribution is
        renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # probs_sum - probs_sort is the cumulative mass of the strictly
    # higher-ranked tokens. If that mass alone already exceeds p
    # (e.g. 0.95), this token lies outside the nucleus, so mask it (True).
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```
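A minimal usage sketch on a toy distribution (the 4-token vocabulary and the value p=0.8 are made up for illustration):

```python
import torch

# Hypothetical 4-token vocabulary with a hand-picked distribution (batch of 1).
probs = torch.tensor([[0.1, 0.4, 0.3, 0.2]])

torch.manual_seed(0)
next_token = sample_top_p(probs, p=0.8)
# Sorted probs are [0.4, 0.3, 0.2, 0.1] -> with p=0.8 the nucleus is
# tokens {1, 2, 3}; token 0 (prob 0.1) can never be sampled.
print(next_token)  # e.g. tensor([[1]])
```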
The LLM's final output (the next token) is ultimately determined through sampling, and the method LLaMA 3.1 uses for this is `sample_top_p`. The function takes `probs` and `p` as inputs, which mean the following.

`probs` (torch.Tensor): the softmax probabilities at the last position (i.e. for predicting the very next token) among the logits produced by the final Linear layer. Temperature is applied when they are computed (llama `generate.py`):

```python
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
```

`p` is the threshold: it can be read as "sample one token from among only the top-probability tokens."

-> Note that if temperature is high (close to 1.0), the probability mass stays spread out across the ~128K-token vocabulary; if it is low (close to 0.0), one logit dominates, and sampling will almost always pick the single most likely token.
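To make the temperature effect concrete, a small sketch with five made-up logits (not a real model's output):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])

for temperature in (1.0, 0.1):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {probs}")

# T=1.0 -> mass spread over several tokens: ~[0.56, 0.21, 0.12, 0.08, 0.03]
# T=0.1 -> essentially one-hot on the largest logit: ~[1.00, 0.00, 0.00, 0.00, 0.00]
```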
In the context of the `sample_top_p` function, "cumulative probability mass" refers to the sum of probabilities for a sequence of tokens when sorted in descending order of their individual probabilities.

Probability Distribution (`probs`): the input tensor of per-token probabilities described above.

Sorting the Probabilities:

```python
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
```

The sorted probabilities are stored in `probs_sort`, and `probs_idx` holds the corresponding indices of the tokens after sorting.

Calculating Cumulative Probability Mass:

```python
probs_sum = torch.cumsum(probs_sort, dim=-1)
```

`torch.cumsum` computes the cumulative sum of the sorted probabilities. This means each entry in `probs_sum` contains the sum of all probabilities up to and including that position in the sorted list. For example, if `probs_sort = [0.4, 0.3, 0.2, 0.1]`, then `probs_sum` would be `[0.4, 0.7, 0.9, 1.0]`.
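This is easy to check directly (same made-up numbers as in the example):

```python
import torch

probs_sort = torch.tensor([0.4, 0.3, 0.2, 0.1])
print(torch.cumsum(probs_sort, dim=-1))  # tensor([0.4000, 0.7000, 0.9000, 1.0000])
```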
Top-p Sampling:

Tokens are kept or dropped based on the threshold `p`.

```python
mask = probs_sum - probs_sort > p
```

identifies tokens where the cumulative probability exceeds `p` after subtracting the current token's probability, i.e. tokens whose higher-ranked predecessors already account for more than `p` of the mass on their own. These tokens will be excluded from sampling.
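Continuing the same example, the mask for a hypothetical p = 0.8 comes out as follows (a sketch; any p in [0.7, 0.9) gives the same mask here):

```python
import torch

probs_sort = torch.tensor([0.4, 0.3, 0.2, 0.1])
probs_sum = torch.cumsum(probs_sort, dim=-1)
# Mass of the strictly higher-ranked tokens: [0.0, 0.4, 0.7, 0.9]
mask = probs_sum - probs_sort > 0.8
print(mask)  # tensor([False, False, False,  True]) -> only the last token is dropped
```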
Renormalization and Sampling:

```python
next_token = torch.multinomial(probs_sort, num_samples=1)
```

samples from the renormalized distribution (the masked entries are first zeroed with `probs_sort[mask] = 0.0` and the remainder rescaled by `div_`, as shown in the function above). Finally, `torch.gather(probs_idx, -1, next_token)` maps the sampled position back to the token's original index. Given a threshold `p`, top-p sampling involves including tokens until this cumulative mass exceeds `p`. The intuition is that you only sample from the most probable tokens that together cover at least a fraction `p` of the total probability distribution, making the sampling process focus on the most likely tokens while still allowing for some diversity.

In summary, cumulative probability mass in this situation refers to the sum of probabilities up to a certain point in the sorted list, and it is used to decide which tokens are eligible for sampling in the top-p (nucleus) sampling method.

Continuing the running example with, say, p = 0.8, the last token is masked out and the surviving probabilities are renormalized by their sum, 0.9:

```
probs_sort = [0.4/0.9, 0.3/0.9, 0.2/0.9, 0.0]
           ≈ [0.444, 0.333, 0.222, 0.0]
```
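Putting the masking and renormalization steps together reproduces those values end to end (a sketch, again assuming p = 0.8):

```python
import torch

probs_sort = torch.tensor([0.4, 0.3, 0.2, 0.1])
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > 0.8   # [False, False, False, True]
probs_sort[mask] = 0.0
probs_sort /= probs_sort.sum()
print(probs_sort)  # tensor([0.4444, 0.3333, 0.2222, 0.0000])
```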