Skip to content

Commit 07adfcc

Browse files
author
puppy
committed
test op-cumsum (compute the top-p cumulative sum in float32, then cast the result back to the original dtype)
1 parent d252e36 commit 07adfcc

File tree

1 file changed: 3 additions and 0 deletions.

vllm_ascend/sample/sampler.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,10 @@ def _apply_top_k_top_p(
51 51
logits.masked_fill_(elements_to_discard, -float("inf"))
52 52

53 53
if p is not None:
54+
old_dtype = probs_sort.dtype
55+
probs_sort = probs_sort.to(torch.float32)
54 56
cumprob = torch.cumsum(probs_sort, dim=-1)
57+
cumprob = cumprob.to(old_dtype)
55 58
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
56 59
top_p_mask[:, -1] = False # at least one
57 60

0 commit comments

Comments
 (0)