1 parent d252e36 commit 07adfcc
vllm_ascend/sample/sampler.py
@@ -51,7 +51,10 @@ def _apply_top_k_top_p(
         logits.masked_fill_(elements_to_discard, -float("inf"))
 
     if p is not None:
+        old_dtype = probs_sort.dtype
+        probs_sort = probs_sort.to(torch.float32)
         cumprob = torch.cumsum(probs_sort, dim=-1)
+        cumprob = cumprob.to(old_dtype)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
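
The diff upcasts probs_sort to float32 before the cumulative sum used for top-p filtering and casts the result back afterward, presumably so that half-precision (float16/bfloat16) rounding error does not accumulate across the vocabulary dimension. Below is a minimal sketch of the patched path, not the repository's actual function: apply_top_p and the example tensors are illustrative, and it assumes probs_sort holds per-row probabilities sorted in ascending order (which the `<= 1 - p` comparison and the "at least one" guard suggest) while p holds one top-p threshold per row.

import torch

def apply_top_p(probs_sort: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    old_dtype = probs_sort.dtype
    # Accumulate in float32 so low-precision rounding error does not build up
    # across the vocabulary dimension, then cast back to the original dtype.
    cumprob = torch.cumsum(probs_sort.to(torch.float32), dim=-1).to(old_dtype)
    top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)  # True = token to discard
    top_p_mask[:, -1] = False  # always keep at least one token per row
    return top_p_mask

# Usage sketch with hypothetical shapes: 2 rows, vocabulary of 4.
probs_sort = torch.tensor([[0.05, 0.15, 0.30, 0.50],
                           [0.02, 0.08, 0.20, 0.70]], dtype=torch.bfloat16)
p = torch.tensor([0.9, 0.8])
mask = apply_top_p(probs_sort, p)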