From 167f1383b60b666a5c9e070f10261f63912bad3d Mon Sep 17 00:00:00 2001
From: liumail1122
Date: Wed, 26 Nov 2025 00:23:17 +0800
Subject: [PATCH] [Performance] Improve the inference performance of Eagle3.

vLLM version: v0.11.0
vLLM main: vllm-project/vllm

Signed-off-by: liumail202512
---
 vllm_ascend/spec_decode/eagle_proposer.py | 130 +++++++++-------
 1 file changed, 50 insertions(+), 80 deletions(-)

diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 4d076ac117f..01ee382ac64 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -20,8 +20,7 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
-                                                AscendMetadata)
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
@@ -68,9 +67,6 @@ def __init__(self,
                           self.hidden_size),
                          dtype=self.vllm_config.model_config.dtype,
                          device=device)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -189,8 +185,10 @@ def generate_token_ids(self,
                 dtype=torch.int32,
                 device=self.device,
             )
-            cu_num_tokens, token_indices =\
-                self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
+            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+            cu_num_tokens, token_indices = self._prepare_inputs(
+                eagle_attn_metadata.query_start_loc, num_rejected_tokens,
+                num_tokens)
             target_token_ids = self.runner.input_ids[token_indices]
             target_positions = positions[token_indices]
             if self.name == SpecDcodeType.EAGLE3:
@@ -590,88 +588,60 @@ def _propose(
 
     def _prepare_inputs(
         self,
-        eagle_attn_metadata: AscendMetadata,
+        # [batch_size + 1]
+        cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        This function is used to prepare the inputs for the spec decode.
-        It updates to the common_attn_metadata to account for the rejected
-        tokens (and newly sampled tokens). It also returns the token indices
-        of the tokens that should be fed to the speculator.
-        """
-        # E.g.
-        #  common_attn_metadata.query_start_loc{_cpu}:
-        #  [0, q1, q1 + q2, q1 + q2 + q3]
-        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
-        #  num_rejected_tokens: [n1, n2, n3]
-        # This function computes the intermediate values:
-        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
-        # And returns:
-        #  common_attn_metadata.query_start_loc{_cpu}:
-        #  [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
-        #  common_attn_metadata.seq_lens{_cpu}:
-        #  [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
-        #  token_indices: [0, 1, ..., q1 - n1 - 1,
-        #                  q1, q1 + 1, ..., q1 + q2 - n2 - 1,
-        #                  q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
-        num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
-        cu_target_query_lens = eagle_attn_metadata.query_start_loc
-        device = eagle_attn_metadata.query_start_loc.device
-        query_start_loc_cpu = cu_target_query_lens.to("cpu")
-
-        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
-        new_query_len_per_req = (query_start_loc_cpu[1:] -
-                                 query_start_loc_cpu[:-1])
-        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
-        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
-        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
-
-        # [q1 - n1, q2 - n2, q3 - n3] ->
-        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
-        new_query_start_loc_cpu = torch.zeros(
-            query_start_loc_cpu.shape,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available())
-        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
-        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
-
-        total_num_tokens = new_query_start_loc_np[-1]
-        # Example assuming num_tokens_per_req_np = [2, 4, 3]
-        # this implies that `new_query_start_locs` is:
-        # [0, 2, 6, 9] ->
-        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
-        #  _r1_  ____r2____  ___r3__
-        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
-                                                  new_num_tokens_per_req_np)
-        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
-        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
-        #  _r1_  ____r2____  ___r3__
-        token_offests = self.token_arange_np[:total_num_tokens] \
-            - new_query_start_locs_expanded
-
-        # Expand starting positions to match token pattern
-        # [0, q1, q1 + q2] ->
-        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
-        #  _r1_  _____r2_______  ___________r3____________
-        old_query_start_locs_expanded = np.repeat(
-            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
-        # Final token indices are:
-        # [0, 1,                                   // req 1
-        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
-        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
-        token_indices_np = token_offests + old_query_start_locs_expanded
-        token_indices = torch.from_numpy(token_indices_np).to(
-            device, non_blocking=True)
-
-        # need use npu
+        # cu_target_query_lens: [0, a, a + b, a + b + c]
+        # num_rejected_tokens: [n1, n2, n3]
+        # num_tokens_per_req: [a - n1, b - n2, c - n3]
+        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
+        # token_indices: [0, 1, ..., a - n1 - 1,
+        #                 a, a + 1, ..., a + b - n2 - 1,
+        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
+
+        # [0, a, a + b, a + b + c] -> [a, b, c]
         query_len_per_req = (cu_target_query_lens[1:] -
                              cu_target_query_lens[:-1])
+        # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-
+        token_indices = torch.empty(
+            num_tokens,
+            dtype=torch.int32,
+            device=cu_target_query_lens.device,
+        )
+        BLOCK_SIZE = 1024
+        self._prepare_eagle_input_sequential(
+            token_indices,
+            cu_target_query_lens,
+            cu_num_tokens,
+            block_size=BLOCK_SIZE,
+        )
         return cu_num_tokens, token_indices
+
+    def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
+                                        cu_query_lens: torch.Tensor,
+                                        cu_num_tokens: torch.Tensor,
+                                        block_size: int):
+        device = cu_query_lens.device
+        dtype = out_tensor.dtype
+
+        offsets = torch.arange(block_size, device=device, dtype=dtype)
+        start_pos = cu_num_tokens[:-1]
+        end_pos = cu_num_tokens[1:]
+        num_tokens = end_pos - start_pos
+
+        global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
+        values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))
+
+        mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))
+
+        global_indices_flat = global_indices[mask]
+        values_flat = values[mask]
+        out_tensor[global_indices_flat] = values_flat
\ No newline at end of file
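
Reviewer note: below is a minimal CPU-only sketch of the bookkeeping that
_prepare_inputs now keeps entirely on-device (the old code round-tripped
query_start_loc through CPU and numpy). The toy query lengths and rejection
counts are illustrative only, not taken from the patch.

    import torch

    # Toy batch: per-request query lens [5, 5, 4] -> [0, 5, 10, 14].
    cu_target_query_lens = torch.tensor([0, 5, 10, 14], dtype=torch.int32)
    num_rejected_tokens = torch.tensor([1, 2, 0], dtype=torch.int32)

    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens

    # Exclusive cumsum gives the new query_start_loc after rejection.
    cu_num_tokens = torch.zeros_like(cu_target_query_lens)
    torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
    print(cu_num_tokens.tolist())  # [0, 4, 7, 11]

The final entry (11 = 14 - 1 - 2 - 0) matches the num_tokens value that
generate_token_ids now passes in, which is what lets token_indices be
preallocated with torch.empty.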
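Also for reference, a self-contained sketch of the batched scatter that
_prepare_eagle_input_sequential uses to build token_indices without a Python
loop over requests, checked against a naive per-request loop (naive_loop is
a hypothetical reference written here only for the comparison). The scatter
assumes no request keeps more than block_size tokens after rejection, which
appears to be the assumption behind BLOCK_SIZE = 1024.

    import torch

    def scatter_token_indices(out_tensor, cu_query_lens, cu_num_tokens,
                              block_size=1024):
        # Mirrors _prepare_eagle_input_sequential on CPU, with int64
        # offsets so the tensors can be used directly as indices.
        offsets = torch.arange(block_size, dtype=torch.int64)
        start_pos = cu_num_tokens[:-1]
        num_tokens = cu_num_tokens[1:] - start_pos
        # Row i holds the output slots / source positions of request i,
        # padded out to block_size columns.
        global_indices = start_pos.view(-1, 1) + offsets.view(1, -1)
        values = cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)
        # Keep only the first (q_i - n_i) columns of each row.
        mask = offsets.view(1, -1) < num_tokens.view(-1, 1)
        out_tensor[global_indices[mask]] = values[mask]

    def naive_loop(out_tensor, cu_query_lens, cu_num_tokens):
        # Hypothetical per-request reference implementation.
        for i in range(len(cu_num_tokens) - 1):
            n = int(cu_num_tokens[i + 1]) - int(cu_num_tokens[i])
            for j in range(n):
                out_tensor[int(cu_num_tokens[i]) + j] = \
                    int(cu_query_lens[i]) + j

    cu_query_lens = torch.tensor([0, 5, 10, 14])  # [0, a, a+b, a+b+c]
    cu_num_tokens = torch.tensor([0, 4, 7, 11])   # after rejection
    got = torch.empty(11, dtype=torch.int64)
    want = torch.empty(11, dtype=torch.int64)
    scatter_token_indices(got, cu_query_lens, cu_num_tokens)
    naive_loop(want, cu_query_lens, cu_num_tokens)
    assert torch.equal(got, want)
    print(got.tolist())  # [0, 1, 2, 3, 5, 6, 7, 10, 11, 12, 13]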