@chenaoxuan chenaoxuan commented Nov 26, 2025

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

@github-actions
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to the Contributing and Testing guides.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces Triton-based optimizations for rejection sampling and adds a "block verify" method, also known as MagicMTP. The changes include new Triton kernels for both greedy and random sampling, along with PyTorch fallback implementations. My review has identified a critical issue where a non-existent Triton function is called, which will prevent the code from running. Additionally, I've found a performance inefficiency in one of the new kernels and a use of a bare except clause, which is considered poor practice. I have provided suggestions to address these points.

other=0
)
recovered_id = tl.argmax(tmp_target_prob, axis=-1)
max_p = tl.get_element(tmp_target_prob, (recovered_id,))
Contributor

critical

The function tl.get_element does not appear to be a valid function in the triton.language API. This will likely cause a compilation error. To get the value corresponding to the argmax index from a block tensor, you may need to use a different approach. Using tl.reduce with a custom binary operator to find both max value and index simultaneously is a common pattern for this.
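The "reduce with a custom binary operator" pattern mentioned above can be illustrated in plain Python. This is only a sketch of the idea, not Triton code: `combine_max` and `max_with_index` are hypothetical names, and the tie-breaking rule (lowest index wins) is an assumption rather than something stated in the PR. In a kernel, a combine function of this shape is what would be handed to a reduction so that the maximum value and its index come out of a single pass.

```python
from functools import reduce


def combine_max(a, b):
    """Binary operator over (value, index) pairs.

    Keeps the pair with the larger value; on ties, keeps the lower index
    (an assumed tie-breaking rule for this sketch).
    """
    a_val, a_idx = a
    b_val, b_idx = b
    if b_val > a_val or (b_val == a_val and b_idx < a_idx):
        return b
    return a


def max_with_index(probs):
    """Return (max value, index of max) in one reduction over probs."""
    return reduce(combine_max, zip(probs, range(len(probs))))


max_p, recovered_id = max_with_index([0.1, 0.6, 0.3])
```

Because the value travels with its index through the reduction, no separate "get element at index" lookup is needed afterwards.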

Comment on lines 40 to 61
        rejected = False
        for pos in range(num_draft_tokens):
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                    target_argmax_id,
                )
                if draft_token_id != target_argmax_id:
                    # Reject.
                    rejected = True

        if not rejected:
            # If all tokens are accepted, append the bonus token.
            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
                bonus_token_id,
            )
Contributor

high

The current implementation of the loop continues to iterate even after a token has been rejected, which is inefficient. You can return from the function immediately after the first rejection to avoid unnecessary iterations. This simplifies the code by removing the rejected flag and the conditional block after the loop.

        for pos in range(num_draft_tokens):
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                target_argmax_id,
            )
            if draft_token_id != target_argmax_id:
                # Reject and stop processing this request.
                return

        # If the loop completes, all tokens were accepted.
        # Append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
            bonus_token_id,
        )
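
As a sanity check on this suggestion, the two loop shapes can be compared in a plain-Python reference for a single request. This is a sketch under assumptions: the list-based functions, their names, and the `PLACEHOLDER` sentinel are hypothetical stand-ins for the kernel's pointer loads and stores. It only shows that the flag-based and early-exit versions write the same output; whether a given Triton version accepts a `return` inside a `for` loop is a separate question that should be verified against the compiler before adopting the change.

```python
PLACEHOLDER = -1  # hypothetical sentinel for output slots that are never written


def greedy_with_flag(draft, target_argmax, bonus, max_spec_len):
    """Flag-based loop: keeps iterating after a rejection but does no work."""
    out = [PLACEHOLDER] * (max_spec_len + 1)
    rejected = False
    for pos in range(len(draft)):
        if not rejected:
            out[pos] = target_argmax[pos]
            if draft[pos] != target_argmax[pos]:
                rejected = True
    if not rejected:
        # All draft tokens accepted: append the bonus token.
        out[len(draft)] = bonus
    return out


def greedy_early_exit(draft, target_argmax, bonus, max_spec_len):
    """Early-exit loop: stops at the first mismatch."""
    out = [PLACEHOLDER] * (max_spec_len + 1)
    for pos in range(len(draft)):
        out[pos] = target_argmax[pos]
        if draft[pos] != target_argmax[pos]:
            return out
    # Loop completed without rejection: append the bonus token.
    out[len(draft)] = bonus
    return out


partial = greedy_early_exit([5, 7, 9], [5, 8, 9], bonus=11, max_spec_len=3)
full = greedy_early_exit([5, 7, 9], [5, 7, 9], bonus=11, max_spec_len=3)
```

In both versions the target token at the rejecting position is still written, so the output up to and including the first mismatch is identical.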

Comment on lines 149 to 150
except:
    TRITON_ASCEND_AVAILABLE = False
Contributor

high

Using a bare except: is generally discouraged as it can catch unexpected exceptions like SystemExit or KeyboardInterrupt, making it harder to debug issues. It's better to catch specific exceptions. In this case, ImportError seems more appropriate if you only want to handle cases where triton is not installed.

except ImportError:
    TRITON_ASCEND_AVAILABLE = False
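
A minimal, self-contained sketch of the narrower guard, using only the standard library. The helper name `optional_backend_available` and the module names passed to it are hypothetical and chosen for illustration; the PR's actual import list is not shown here. `ModuleNotFoundError` is a subclass of `ImportError`, so a missing package is covered, while `KeyboardInterrupt` and `SystemExit` propagate normally.

```python
import importlib


def optional_backend_available(module_name):
    """Return True if module_name imports cleanly, False if it is missing.

    Only ImportError (including ModuleNotFoundError) is swallowed; any other
    exception raised during import is allowed to surface.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


# "json" ships with Python; the second name is assumed not to be installed.
HAS_JSON = optional_backend_available("json")
HAS_MISSING = optional_backend_available("module_that_does_not_exist_xyz")
```

If the optional backend can fail at import time for reasons other than being absent, those exception types would need to be listed explicitly rather than falling back to a bare `except:`.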

@chenaoxuan chenaoxuan force-pushed the magicmtp branch 6 times, most recently from befa9e5 to 74390bb Compare November 26, 2025 03:25
generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

try:
Collaborator