-
Notifications
You must be signed in to change notification settings - Fork 622
[Performance] Improve the inference performance of Eagle3. #4441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,8 +20,7 @@ | |||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| from vllm_ascend.ascend_forward_context import set_ascend_forward_context | ||||||||||||||||||||||||||||||||
| from vllm_ascend.attention.attention_mask import AttentionMaskBuilder | ||||||||||||||||||||||||||||||||
| from vllm_ascend.attention.attention_v1 import (AscendAttentionState, | ||||||||||||||||||||||||||||||||
| AscendMetadata) | ||||||||||||||||||||||||||||||||
| from vllm_ascend.attention.attention_v1 import AscendAttentionState | ||||||||||||||||||||||||||||||||
| from vllm_ascend.attention.utils import AscendCommonAttentionMetadata | ||||||||||||||||||||||||||||||||
| from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
@@ -68,9 +67,6 @@ def __init__(self, | |||||||||||||||||||||||||||||||
| self.hidden_size), | ||||||||||||||||||||||||||||||||
| dtype=self.vllm_config.model_config.dtype, | ||||||||||||||||||||||||||||||||
| device=device) | ||||||||||||||||||||||||||||||||
| self.max_num_tokens = ( | ||||||||||||||||||||||||||||||||
| vllm_config.scheduler_config.max_num_batched_tokens) | ||||||||||||||||||||||||||||||||
| self.token_arange_np = np.arange(self.max_num_tokens) | ||||||||||||||||||||||||||||||||
| # We need +1 here because the arange is used to set query_start_loc, | ||||||||||||||||||||||||||||||||
| # which has one more element than batch_size. | ||||||||||||||||||||||||||||||||
| self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + | ||||||||||||||||||||||||||||||||
|
|
@@ -189,8 +185,10 @@ def generate_token_ids(self, | |||||||||||||||||||||||||||||||
| dtype=torch.int32, | ||||||||||||||||||||||||||||||||
| device=self.device, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| cu_num_tokens, token_indices =\ | ||||||||||||||||||||||||||||||||
| self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens) | ||||||||||||||||||||||||||||||||
| num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) | ||||||||||||||||||||||||||||||||
| cu_num_tokens, token_indices = self._prepare_inputs( | ||||||||||||||||||||||||||||||||
| eagle_attn_metadata.query_start_loc, num_rejected_tokens, | ||||||||||||||||||||||||||||||||
| num_tokens) | ||||||||||||||||||||||||||||||||
| target_token_ids = self.runner.input_ids[token_indices] | ||||||||||||||||||||||||||||||||
| target_positions = positions[token_indices] | ||||||||||||||||||||||||||||||||
| if self.name == SpecDcodeType.EAGLE3: | ||||||||||||||||||||||||||||||||
|
|
@@ -590,88 +588,60 @@ def _propose( | |||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _prepare_inputs( | ||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||
| eagle_attn_metadata: AscendMetadata, | ||||||||||||||||||||||||||||||||
| # [batch_size + 1] | ||||||||||||||||||||||||||||||||
| cu_target_query_lens: torch.Tensor, | ||||||||||||||||||||||||||||||||
| # [batch_size] | ||||||||||||||||||||||||||||||||
| num_rejected_tokens: torch.Tensor, | ||||||||||||||||||||||||||||||||
| num_tokens: int, | ||||||||||||||||||||||||||||||||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| This function is used to prepare the inputs for the spec decode. | ||||||||||||||||||||||||||||||||
| It updates to the common_attn_metadata to account for the rejected | ||||||||||||||||||||||||||||||||
| tokens (and newly sampled tokens). It also returns the token indices | ||||||||||||||||||||||||||||||||
| of the tokens that should be fed to the speculator. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| # E.g. | ||||||||||||||||||||||||||||||||
| # common_attn_metadata.query_start_loc{_cpu}: | ||||||||||||||||||||||||||||||||
| # [0, q1, q1 + q2, q1 + q2 + q3] | ||||||||||||||||||||||||||||||||
| # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] | ||||||||||||||||||||||||||||||||
| # num_rejected_tokens: [n1, n2, n3] | ||||||||||||||||||||||||||||||||
| # This function computes the intermediate values: | ||||||||||||||||||||||||||||||||
| # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] | ||||||||||||||||||||||||||||||||
| # And returns: | ||||||||||||||||||||||||||||||||
| # common_attn_metadata.query_start_loc{_cpu}: | ||||||||||||||||||||||||||||||||
| # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] | ||||||||||||||||||||||||||||||||
| # common_attn_metadata.seq_lens{_cpu}: | ||||||||||||||||||||||||||||||||
| # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] | ||||||||||||||||||||||||||||||||
| # token_indices: [0, 1, ..., q1 - n1 - 1, | ||||||||||||||||||||||||||||||||
| # q1, q1 + 1, ..., q1 + q2 - n2 - 1, | ||||||||||||||||||||||||||||||||
| # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] | ||||||||||||||||||||||||||||||||
| num_rejected_tokens_cpu = num_rejected_tokens.to("cpu") | ||||||||||||||||||||||||||||||||
| cu_target_query_lens = eagle_attn_metadata.query_start_loc | ||||||||||||||||||||||||||||||||
| device = eagle_attn_metadata.query_start_loc.device | ||||||||||||||||||||||||||||||||
| query_start_loc_cpu = cu_target_query_lens.to("cpu") | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] | ||||||||||||||||||||||||||||||||
| new_query_len_per_req = (query_start_loc_cpu[1:] - | ||||||||||||||||||||||||||||||||
| query_start_loc_cpu[:-1]) | ||||||||||||||||||||||||||||||||
| # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] | ||||||||||||||||||||||||||||||||
| new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu | ||||||||||||||||||||||||||||||||
| new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # [q1 - n1, q2 - n2, q3 - n3] -> | ||||||||||||||||||||||||||||||||
| # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] | ||||||||||||||||||||||||||||||||
| new_query_start_loc_cpu = torch.zeros( | ||||||||||||||||||||||||||||||||
| query_start_loc_cpu.shape, | ||||||||||||||||||||||||||||||||
| dtype=torch.int32, | ||||||||||||||||||||||||||||||||
| pin_memory=is_pin_memory_available()) | ||||||||||||||||||||||||||||||||
| new_query_start_loc_np = new_query_start_loc_cpu.numpy() | ||||||||||||||||||||||||||||||||
| np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| total_num_tokens = new_query_start_loc_np[-1] | ||||||||||||||||||||||||||||||||
| # Example assuming num_tokens_per_req_np = [2, 4, 3] | ||||||||||||||||||||||||||||||||
| # this implies that `new_query_start_locs` is: | ||||||||||||||||||||||||||||||||
| # [0, 2, 6, 9] -> | ||||||||||||||||||||||||||||||||
| # [0, 0, 2, 2, 2, 2, 6, 6, 6] | ||||||||||||||||||||||||||||||||
| # _r1_ ____r2____ ___r3__ | ||||||||||||||||||||||||||||||||
| new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], | ||||||||||||||||||||||||||||||||
| new_num_tokens_per_req_np) | ||||||||||||||||||||||||||||||||
| # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> | ||||||||||||||||||||||||||||||||
| # [0, 1, 0, 1, 2, 3, 0, 1, 2] | ||||||||||||||||||||||||||||||||
| # _r1_ ____r2____ ___r3__ | ||||||||||||||||||||||||||||||||
| token_offests = self.token_arange_np[:total_num_tokens] \ | ||||||||||||||||||||||||||||||||
| - new_query_start_locs_expanded | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Expand starting positions to match token pattern | ||||||||||||||||||||||||||||||||
| # [0, q1, q1 + q2] -> | ||||||||||||||||||||||||||||||||
| # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] | ||||||||||||||||||||||||||||||||
| # _r1_ _____r2_______ ___________r3____________ | ||||||||||||||||||||||||||||||||
| old_query_start_locs_expanded = np.repeat( | ||||||||||||||||||||||||||||||||
| query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) | ||||||||||||||||||||||||||||||||
| # Final token indices are: | ||||||||||||||||||||||||||||||||
| # [0, 1, // req 1 | ||||||||||||||||||||||||||||||||
| # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 | ||||||||||||||||||||||||||||||||
| # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 | ||||||||||||||||||||||||||||||||
| token_indices_np = token_offests + old_query_start_locs_expanded | ||||||||||||||||||||||||||||||||
| token_indices = torch.from_numpy(token_indices_np).to( | ||||||||||||||||||||||||||||||||
| device, non_blocking=True) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # need use npu | ||||||||||||||||||||||||||||||||
| # cu_target_query_lens: [0, a, a + b, a + b + c] | ||||||||||||||||||||||||||||||||
| # num_rejected_tokens: [n1, n2, n3] | ||||||||||||||||||||||||||||||||
| # num_tokens_per_req: [a - n1, b - n2, c - n3] | ||||||||||||||||||||||||||||||||
| # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] | ||||||||||||||||||||||||||||||||
| # token_indices: [0, 1, ..., a - n1 - 1, | ||||||||||||||||||||||||||||||||
| # a, a + 1, ..., a + b - n2 - 1, | ||||||||||||||||||||||||||||||||
| # a + b, a + b + 1, ..., a + b + c - n3 - 1] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # [0, a, a + b, a + b + c] -> [a, b, c] | ||||||||||||||||||||||||||||||||
| query_len_per_req = (cu_target_query_lens[1:] - | ||||||||||||||||||||||||||||||||
| cu_target_query_lens[:-1]) | ||||||||||||||||||||||||||||||||
| # [a, b, c] -> [a - n1, b - n2, c - n3] | ||||||||||||||||||||||||||||||||
| num_tokens_per_req = query_len_per_req - num_rejected_tokens | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # [a - n1, b - n2, c - n3] -> | ||||||||||||||||||||||||||||||||
| # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] | ||||||||||||||||||||||||||||||||
| cu_num_tokens = torch.zeros_like(cu_target_query_lens) | ||||||||||||||||||||||||||||||||
| torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| token_indices = torch.empty( | ||||||||||||||||||||||||||||||||
| num_tokens, | ||||||||||||||||||||||||||||||||
| dtype=torch.int32, | ||||||||||||||||||||||||||||||||
| device=cu_target_query_lens.device, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| BLOCK_SIZE = 1024 | ||||||||||||||||||||||||||||||||
| self._prepare_eagle_input_sequential( | ||||||||||||||||||||||||||||||||
| token_indices, | ||||||||||||||||||||||||||||||||
| cu_target_query_lens, | ||||||||||||||||||||||||||||||||
| cu_num_tokens, | ||||||||||||||||||||||||||||||||
| block_size=BLOCK_SIZE, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+620
to
+626
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||
| return cu_num_tokens, token_indices | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor, | ||||||||||||||||||||||||||||||||
| cu_query_lens: torch.Tensor, | ||||||||||||||||||||||||||||||||
| cu_num_tokens: torch.Tensor, block_size: int): | ||||||||||||||||||||||||||||||||
| device = cu_query_lens.device | ||||||||||||||||||||||||||||||||
| dtype = out_tensor.dtype | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| offsets = torch.arange(block_size, device=device, dtype=dtype) | ||||||||||||||||||||||||||||||||
| start_pos = cu_num_tokens[:-1] | ||||||||||||||||||||||||||||||||
| end_pos = cu_num_tokens[1:] | ||||||||||||||||||||||||||||||||
| num_tokens = end_pos - start_pos | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1)) | ||||||||||||||||||||||||||||||||
| values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| mask = (offsets.view(1, -1) < num_tokens.view(-1, 1)) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| global_indices_flat = global_indices[mask] | ||||||||||||||||||||||||||||||||
| values_flat = values[mask] | ||||||||||||||||||||||||||||||||
| out_tensor[global_indices_flat] = values_flat | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The result of
sum(num_rejected_tokens)is a 0-dimensional tensor. When subtracted fromnum_scheduled_tokens(anint), the resultnum_tokensis also a 0-dimensional tensor. However, the_prepare_inputsfunction is type-hinted to accept anintfornum_tokens. This type mismatch could lead to unexpected behavior. Please convert the tensor to a Python integer using.item()for type correctness and clarity.