From 2faf2e700d5921a50433297000e519a47ed0a705 Mon Sep 17 00:00:00 2001
From: chenaoxuan
Date: Tue, 25 Nov 2025 19:26:44 +0800
Subject: [PATCH] Add MagicMTP (block verify) and Triton optimization

Signed-off-by: chenaoxuan
---
 mypy.ini                                |   3 +
 vllm_ascend/sample/rejection_sampler.py | 331 +++++++++++++++++++++---
 2 files changed, 292 insertions(+), 42 deletions(-)

diff --git a/mypy.ini b/mypy.ini
index 7778a6f1bde..a1d9a916822 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -27,3 +27,6 @@ ignore_missing_imports = True
 [mypy-msprobe.*]
 ignore_missing_imports = True
 allow_untyped_imports = True
+
+[mypy-triton.*]
+ignore_missing_imports = True
\ No newline at end of file
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index a17f534045e..49f046234ee 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -4,12 +4,149 @@
 import torch
 import torch.nn as nn
 import vllm.v1.sample.rejection_sampler as rs
+from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
                                               apply_sampling_constraints,
                                               generate_uniform_probs)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
+if HAS_TRITON:
+
+    @triton.jit(do_not_specialize=["max_spec_len"])
+    def rejection_greedy_sample_kernel(
+        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens_ptr,  # [batch_size]
+        draft_token_ids_ptr,  # [num_tokens]
+        target_argmax_ptr,  # [num_tokens]
+        bonus_token_ids_ptr,  # [batch_size]
+        is_greedy_ptr,  # [batch_size] or None
+        max_spec_len,
+    ):
+        req_idx = tl.program_id(0)
+        # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
+        # re-compilation may happen during runtime when is_greedy_ptr is None.
+        is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
+                                                               req_idx)
+        if not is_greedy:
+            # Early exit for non-greedy sampling requests.
+            return
+
+        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                                   req_idx - 1)
+        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+        num_draft_tokens = end_idx - start_idx
+
+        rejected = False
+        for pos in range(num_draft_tokens):
+            if not rejected:
+                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                    target_argmax_id,
+                )
+                if draft_token_id != target_argmax_id:
+                    # Reject.
+                    rejected = True
+
+        if not rejected:
+            # If all tokens are accepted, append the bonus token.
+            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                num_draft_tokens,
+                bonus_token_id,
+            )
+
+    @triton.jit(do_not_specialize=["max_spec_len"])
+    def rejection_random_sample_block_verify_kernel(
+        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens_ptr,  # [batch_size]
+        draft_token_ids_ptr,  # [num_tokens]
+        draft_probs_ptr,  # [num_tokens, vocab_size] or None
+        target_probs_ptr,  # [num_tokens, vocab_size]
+        bonus_token_ids_ptr,  # [batch_size]
+        uniform_probs_ptr,  # [num_tokens]
+        is_greedy_ptr,  # [batch_size]
+        max_spec_len,
+        vocab_size,
+        NO_DRAFT_PROBS: tl.constexpr,
+        SUB_BLOCK: tl.constexpr = 1500,
+    ):
+        req_idx = tl.program_id(0)
+        is_greedy = tl.load(is_greedy_ptr + req_idx)
+        if is_greedy:
+            # Early exit for greedy sampling requests.
+            return
+
+        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                                   req_idx - 1)
+        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+        num_draft_tokens = end_idx - start_idx
+
+        rejected = False
+        pi = 1.0
+        uniform_prob = 1.0
+        last_accepted_token_pos = -1
+
+        for pos in range(num_draft_tokens):
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  draft_token_id)
+            tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+            uniform_prob = uniform_prob * tmp_uniform_prob
+
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(draft_probs_ptr +
+                                     (start_idx + pos) * vocab_size +
+                                     draft_token_id)
+
+            pi = min(pi * target_prob / draft_prob, 1.0)
+            if draft_prob > 0 and pi >= uniform_prob:
+                last_accepted_token_pos = pos
+                rejected = False
+            else:
+                rejected = True
+
+        if last_accepted_token_pos > -1:
+            for pos in range(last_accepted_token_pos + 1):
+                token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                    token_id)
+
+        if rejected:
+            loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
+            global_recovered_id = -1
+            global_max_p = -1.0
+            for loop_i in range(loop):
+                vocab_start = loop_i * SUB_BLOCK
+                vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
+                tmp_target_prob = tl.load(
+                    target_probs_ptr +
+                    (start_idx + last_accepted_token_pos + 1) * vocab_size +
+                    vocab_offset,
+                    mask=vocab_offset < vocab_size,
+                    other=0)
+                recovered_id = tl.argmax(tmp_target_prob, axis=-1)
+                max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
+                if max_p > global_max_p:
+                    global_max_p = max_p
+                    global_recovered_id = vocab_start + recovered_id
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                last_accepted_token_pos + 1, global_recovered_id)
+        else:
+            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                num_draft_tokens, bonus_token_id)
+
+
 PLACEHOLDER_TOKEN_ID = -1
 GREEDY_TEMPERATURE = -1
 # Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +271,9 @@ def rejection_sample(
     assert bonus_token_ids.is_contiguous()
     assert target_probs.shape == (num_tokens, vocab_size)
 
+    # When num_speculative_tokens >= 3, use block verify.
+    using_block_verify = max_spec_len >= 3
+
     # Create output buffer.
     output_token_ids = torch.empty(
         (batch_size, max_spec_len + 1),
@@ -149,25 +289,36 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-            rejection_greedy_sample_spec_len_1_pytorch(
-                output_token_ids,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-            )
-        else:
-            rejection_greedy_sample_pytorch(
+        if HAS_TRITON:
+            rejection_greedy_sample_kernel[(batch_size, )](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
                 target_argmax,
                 bonus_token_ids,
-                num_draft_tokens,
-                max_spec_len,
                 is_greedy,
+                max_spec_len,
             )
+        else:
+            if min(num_draft_tokens) == 1 and max(
+                    num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+                rejection_greedy_sample_spec_len_1_pytorch(
+                    output_token_ids,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                )
+            else:
+                rejection_greedy_sample_pytorch(
+                    output_token_ids,
+                    cu_num_draft_tokens,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                    num_draft_tokens,
+                    max_spec_len,
+                    is_greedy,
+                )
 
     if sampling_metadata.all_greedy:
         return output_token_ids
@@ -178,37 +329,68 @@ def rejection_sample(
         num_draft_tokens,
         sampling_metadata.generators,
         device,
-    )
-
-    # Sample recovered tokens for each position.
-    # [num_tokens]
-    recovered_token_ids = sample_recovered_tokens(
-        max_spec_len,
-        num_draft_tokens,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        sampling_metadata,
-        device,
-    )
+    ).to(torch.float32)
+
+    if not using_block_verify:
+        # Sample recovered tokens for each position.
+        # [num_tokens]
+        recovered_token_ids = sample_recovered_tokens(
+            max_spec_len,
+            num_draft_tokens,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            sampling_metadata,
+            device,
+        )
 
-    # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
-        output_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        bonus_token_ids,
-        recovered_token_ids,
-        uniform_probs,
-        is_greedy,
-        max_spec_len,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-        # num_warps=1,
-    )
+        # Rejection sampling for random sampling requests.
+        rejection_random_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+            # num_warps=1,
+        )
+    else:
+        # MagicMTP: Improving acceptance rate with Block Verify.
+        if HAS_TRITON:
+            rejection_random_sample_block_verify_kernel[(batch_size, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                multibuffer=True,
+            )
+        else:
+            rejection_random_sample_block_verify_pytorch(output_token_ids,
+                                                         cu_num_draft_tokens,
+                                                         draft_token_ids,
+                                                         draft_probs,
+                                                         target_probs,
+                                                         bonus_token_ids,
+                                                         uniform_probs,
+                                                         is_greedy,
+                                                         max_spec_len,
+                                                         vocab_size,
+                                                         IS_NGRAM=draft_probs
+                                                         is None)
 
     return output_token_ids
 
@@ -504,4 +686,69 @@ def sample_recovered_tokens_pytorch(
             target_probs[token_idx, draft_token_id] = orig_prob
 
 
+def rejection_random_sample_block_verify_pytorch(
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    draft_probs,  # [num_tokens, vocab_size] or None
+    target_probs,  # [num_tokens, vocab_size]
+    bonus_token_ids,  # [batch_size]
+    uniform_probs,  # [num_tokens]
+    is_greedy,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    IS_NGRAM=False,
+):
+    batch_size = output_token_ids.shape[0]
+
+    for req_idx in range(batch_size):
+        if is_greedy[req_idx]:
+            continue
+
+        if req_idx == 0:
+            start_idx = 0
+        else:
+            start_idx = cu_num_draft_tokens[req_idx - 1].item()
+        end_idx = cu_num_draft_tokens[req_idx].item()
+        num_draft_tokens = end_idx - start_idx
+
+        rejected = False
+        pi = 1.0
+        uniform_prob = 1.0
+        last_accepted_token_pos = -1
+        for pos in range(num_draft_tokens):
+            draft_token_id = draft_token_ids[start_idx + pos].item()
+
+            target_prob = target_probs[start_idx + pos, draft_token_id].item()
+            uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()
+
+            if IS_NGRAM:
+                draft_prob = 1.0
+            else:
+                draft_prob = draft_probs[start_idx + pos,
+                                         draft_token_id].item()
+
+            pi = min(pi * target_prob / draft_prob, 1.0)
+
+            if draft_prob > 0 and pi >= uniform_prob:
+                last_accepted_token_pos = pos
+                rejected = False
+            else:
+                rejected = True
+
+        if last_accepted_token_pos > -1:
+            for pos in range(last_accepted_token_pos + 1):
+                draft_token_id = draft_token_ids[start_idx + pos].item()
+                output_token_ids[req_idx, pos] = draft_token_id
+
+        if rejected:
+            recovered_token_id = torch.argmax(
+                target_probs[start_idx + last_accepted_token_pos + 1]).item()
+            output_token_ids[req_idx,
+                             last_accepted_token_pos + 1] = recovered_token_id
+        else:
+            bonus_token_id = bonus_token_ids[req_idx].item()
+            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+
+
 rs.expand_batch_to_tokens = expand_batch_to_tokens
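
Note (editor's illustration, not part of the patch): the block-verify rule added above
accepts the longest draft prefix whose cumulative probability ratio still dominates the
running product of uniform samples, and recovers one token greedily when the prefix is
rejected. Below is a minimal, self-contained sketch of that acceptance rule for a single
request; the tensor names and toy sizes are illustrative assumptions, not identifiers
from vLLM or from this patch.

    import torch

    torch.manual_seed(0)
    vocab_size, num_draft_tokens = 8, 3
    draft_token_ids = torch.randint(vocab_size, (num_draft_tokens, ))
    draft_probs = torch.softmax(torch.randn(num_draft_tokens, vocab_size), -1)
    target_probs = torch.softmax(torch.randn(num_draft_tokens, vocab_size), -1)
    uniform_probs = torch.rand(num_draft_tokens)
    bonus_token_id = 7  # hypothetical bonus token for this request

    pi = 1.0            # running (clamped) product of p_target / p_draft
    uniform_prod = 1.0  # running product of the uniform samples
    last_accepted = -1
    rejected = False
    for pos in range(num_draft_tokens):
        tok = draft_token_ids[pos].item()
        uniform_prod *= uniform_probs[pos].item()
        ratio = target_probs[pos, tok].item() / draft_probs[pos, tok].item()
        pi = min(pi * ratio, 1.0)
        # The whole prefix 0..pos is accepted as one block while the
        # cumulative ratio still dominates the product of uniforms.
        if draft_probs[pos, tok].item() > 0 and pi >= uniform_prod:
            last_accepted = pos
            rejected = False
        else:
            rejected = True

    output = draft_token_ids[:last_accepted + 1].tolist()
    if rejected:
        # Recover one token greedily from the target distribution at the
        # first position after the accepted block (argmax, as in the patch).
        output.append(target_probs[last_accepted + 1].argmax().item())
    else:
        # All draft tokens accepted: append the bonus token.
        output.append(bonus_token_id)
    print(output)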