130 changes: 50 additions & 80 deletions vllm_ascend/spec_decode/eagle_proposer.py
@@ -20,8 +20,7 @@

from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType

@@ -68,9 +67,6 @@ def __init__(self,
self.hidden_size),
dtype=self.vllm_config.model_config.dtype,
device=device)
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -189,8 +185,10 @@ def generate_token_ids(self,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, token_indices =\
self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
Reviewer comment (Contributor), severity: high

The result of sum(num_rejected_tokens) is a 0-dimensional tensor. When subtracted from num_scheduled_tokens (an int), the result num_tokens is also a 0-dimensional tensor. However, the _prepare_inputs function is type-hinted to accept an int for num_tokens. This type mismatch could lead to unexpected behavior. Please convert the tensor to a Python integer using .item() for type correctness and clarity.

Suggested change:
-            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+            num_tokens = num_scheduled_tokens - torch.sum(num_rejected_tokens).item()
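For context, a minimal standalone check of the behavior the reviewer describes (the values below are illustrative, not part of the PR):

import torch

num_scheduled_tokens = 10
num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)

# Python's built-in sum() folds the tensor's elements with `+`, so the
# result is a 0-dimensional tensor, and so is the subtraction result.
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
print(type(num_tokens))  # <class 'torch.Tensor'>

# .item() extracts a plain Python int, matching the `num_tokens: int` hint.
num_tokens = num_scheduled_tokens - torch.sum(num_rejected_tokens).item()
print(type(num_tokens))  # <class 'int'>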

cu_num_tokens, token_indices = self._prepare_inputs(
eagle_attn_metadata.query_start_loc, num_rejected_tokens,
num_tokens)
target_token_ids = self.runner.input_ids[token_indices]
target_positions = positions[token_indices]
if self.name == SpecDcodeType.EAGLE3:
@@ -590,88 +588,60 @@ def _propose(

def _prepare_inputs(
self,
eagle_attn_metadata: AscendMetadata,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for the spec decode.
It updates the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
# E.g.
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# common_attn_metadata.seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
cu_target_query_lens = eagle_attn_metadata.query_start_loc
device = eagle_attn_metadata.query_start_loc.device
query_start_loc_cpu = cu_target_query_lens.to("cpu")

# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = (query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

# [q1 - n1, q2 - n2, q3 - n3] ->
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32,
pin_memory=is_pin_memory_available())
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

total_num_tokens = new_query_start_loc_np[-1]
# Example assuming num_tokens_per_req_np = [2, 4, 3]
# this implies that `new_query_start_locs` is:
# [0, 2, 6, 9] ->
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
# _r1_ ____r2____ ___r3__
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
new_num_tokens_per_req_np)
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
token_offsets = self.token_arange_np[:total_num_tokens] \
- new_query_start_locs_expanded

# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
# _r1_ _____r2_______ ___________r3____________
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offsets + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)

# This computation needs to run on the NPU.
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]

# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens

# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])

token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_target_query_lens.device,
)
BLOCK_SIZE = 1024
self._prepare_eagle_input_sequential(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
Reviewer comment (Contributor) on lines +620 to +626, severity: critical

The BLOCK_SIZE is hardcoded to 1024. The _prepare_eagle_input_sequential method uses this block_size to create an offsets tensor, assuming that the number of tokens per request will not exceed this value. However, the number of tokens per request can be up to max_num_batched_tokens (default 2560), which is larger than 1024. If a request has more than 1024 tokens, this will lead to incorrect indexing and corrupt output. This is a critical bug. To fix this, block_size should be determined dynamically based on the maximum number of tokens per request in the current batch.

Suggested change:
-        BLOCK_SIZE = 1024
-        self._prepare_eagle_input_sequential(
-            token_indices,
-            cu_target_query_lens,
-            cu_num_tokens,
-            block_size=BLOCK_SIZE,
-        )
+        if num_tokens > 0:
+            block_size = int(torch.max(num_tokens_per_req).item())
+            self._prepare_eagle_input_sequential(
+                token_indices,
+                cu_target_query_lens,
+                cu_num_tokens,
+                block_size=block_size,
+            )
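To make the failure mode concrete, here is a small CPU-only sketch (all values are made up, and block_size is deliberately smaller than one request's token count):

import torch

# One request: target query length 4, 3 accepted tokens, block_size = 2.
cu_query_lens = torch.tensor([0, 4])
cu_num_tokens = torch.tensor([0, 3])
out = torch.full((3,), -1)  # -1 stands in for uninitialized empty() memory

offsets = torch.arange(2)  # hardcoded block_size, smaller than 3
num_tokens = cu_num_tokens[1:] - cu_num_tokens[:-1]
mask = offsets.view(1, -1) < num_tokens.view(-1, 1)
idx = (cu_num_tokens[:-1].view(-1, 1) + offsets.view(1, -1))[mask]
val = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))[mask]
out[idx] = val
print(out)  # tensor([ 0,  1, -1]): the third token index is never written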

return cu_num_tokens, token_indices

def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
                                    cu_query_lens: torch.Tensor,
                                    cu_num_tokens: torch.Tensor, block_size: int):
    # Vectorized scatter: for each request, write the indices of its
    # accepted target tokens into out_tensor, entirely on the device.
    device = cu_query_lens.device
    dtype = out_tensor.dtype

    # Per-request offsets 0..block_size-1. block_size must be at least
    # the largest number of accepted tokens in any request.
    offsets = torch.arange(block_size, device=device, dtype=dtype)
    start_pos = cu_num_tokens[:-1]
    end_pos = cu_num_tokens[1:]
    num_tokens = end_pos - start_pos

    # Destination slots in out_tensor and the source token indices,
    # both of shape [num_reqs, block_size].
    global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
    values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))

    # Keep only the first num_tokens[i] entries of each row.
    mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))

    global_indices_flat = global_indices[mask]
    values_flat = values[mask]
    out_tensor[global_indices_flat] = values_flat
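As a sanity check, the same computation can be traced on the CPU with toy inputs (assumed values: two requests keeping 2 and 3 tokens out of target queries of length 3 and 4):

import torch

cu_query_lens = torch.tensor([0, 3, 7], dtype=torch.int32)  # [0, q1, q1+q2]
cu_num_tokens = torch.tensor([0, 2, 5], dtype=torch.int32)  # after rejection
out = torch.empty(5, dtype=torch.int32)

offsets = torch.arange(4, dtype=out.dtype)  # block_size >= max tokens per request
num_tokens = cu_num_tokens[1:] - cu_num_tokens[:-1]
mask = offsets.view(1, -1) < num_tokens.view(-1, 1)
idx = (cu_num_tokens[:-1].view(-1, 1) + offsets.view(1, -1))[mask]
val = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))[mask]
out[idx] = val
print(out)  # tensor([0, 1, 3, 4, 5], dtype=torch.int32)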