Add MagicMTP(block verify) and Triton optimization #4443
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces Triton-based optimizations for rejection sampling and adds a "block verify" method, also known as MagicMTP. The changes include new Triton kernels for both greedy and random sampling, along with PyTorch fallback implementations. My review has identified a critical issue where a non-existent Triton function is called, which will prevent the code from running. Additionally, I've found a performance inefficiency in one of the new kernels and a use of a bare except clause, which is considered poor practice. I have provided suggestions to address these points.
        other=0
    )
    recovered_id = tl.argmax(tmp_target_prob, axis=-1)
    max_p = tl.get_element(tmp_target_prob, (recovered_id,))
The function tl.get_element does not appear to be a valid function in the triton.language API. This will likely cause a compilation error. To get the value corresponding to the argmax index from a block tensor, you may need to use a different approach. Using tl.reduce with a custom binary operator to find both max value and index simultaneously is a common pattern for this.
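The "reduce with a custom binary operator" pattern mentioned above can be modelled in plain Python. This is only a sketch of the idea, not Triton code: `combine_max` and `max_with_index` are hypothetical names, and a real kernel would need the equivalent combine function written against the Triton version and backend in use.

```python
from functools import reduce


def combine_max(a, b):
    """Binary operator on (value, index) pairs: keep the pair with the larger value.

    Ties go to the smaller index, so the result matches a left-most argmax.
    """
    a_val, a_idx = a
    b_val, b_idx = b
    if b_val > a_val or (b_val == a_val and b_idx < a_idx):
        return b
    return a


def max_with_index(probs):
    """Return (max value, index of max) from one reduction over the input."""
    return reduce(combine_max, ((p, i) for i, p in enumerate(probs)))


# Hypothetical probabilities; the names mirror the snippet under review.
max_p, recovered_id = max_with_index([0.1, 0.6, 0.3])
```

Because the operator carries the value and its index together, no separate element lookup by index is needed after the reduction.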
    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                target_argmax_id,
            )
            if draft_token_id != target_argmax_id:
                # Reject.
                rejected = True

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
            bonus_token_id,
        )
The current implementation of the loop continues to iterate even after a token has been rejected, which is inefficient. You can return from the function immediately after the first rejection to avoid unnecessary iterations. This simplifies the code by removing the rejected flag and the conditional block after the loop.
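As a rough illustration of the semantics being discussed, the per-request greedy verification can be written as a plain Python reference. This is a sketch for reasoning about the logic, not the kernel itself: `greedy_verify` is a hypothetical name, and the `-1` placeholder for unwritten output slots is an assumption.

```python
def greedy_verify(draft_token_ids, target_argmax_ids, bonus_token_id, placeholder=-1):
    """Plain-Python model of greedy verification for a single request.

    The target's argmax token is written at each position. Processing stops
    at the first position where the draft token disagrees with it. The bonus
    token is appended only if every draft token was accepted.
    """
    num_draft_tokens = len(draft_token_ids)
    # Assumption: unwritten output slots hold a placeholder value.
    output = [placeholder] * (num_draft_tokens + 1)
    for pos in range(num_draft_tokens):
        output[pos] = target_argmax_ids[pos]
        if draft_token_ids[pos] != target_argmax_ids[pos]:
            # Reject: later positions are left untouched.
            return output
    # All draft tokens were accepted.
    output[num_draft_tokens] = bonus_token_id
    return output
```

A reference like this can be used to check a kernel's output on small inputs. Whether an early `return` inside a `for` loop compiles depends on the Triton version and backend, which this sketch does not verify; the flag-based form gives the same result.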
    for pos in range(num_draft_tokens):
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
            target_argmax_id,
        )
        if draft_token_id != target_argmax_id:
            # Reject and stop processing this request.
            return
    # If the loop completes, all tokens were accepted.
    # Append the bonus token.
    bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
    tl.store(
        output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
        bonus_token_id,
    )

    except:
        TRITON_ASCEND_AVAILABLE = False
Using a bare except: is generally discouraged as it can catch unexpected exceptions like SystemExit or KeyboardInterrupt, making it harder to debug issues. It's better to catch specific exceptions. In this case, ImportError seems more appropriate if you only want to handle cases where triton is not installed.
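The difference can be seen in a small plain-Python sketch. All function names here are hypothetical and exist only for illustration; the two guards differ only in the `except` clause.

```python
def guarded_bare(action):
    """Run action(); report False on any exception at all."""
    try:
        action()
    except:  # noqa: E722 - deliberately bare, to show the problem
        return False
    return True


def guarded_import_error(action):
    """Run action(); report False only when an import fails."""
    try:
        action()
    except ImportError:
        return False
    return True


def missing_module():
    # Stands in for an import of a package that is not installed.
    raise ImportError("No module named 'triton'")


def user_interrupt():
    # Stands in for the user pressing Ctrl+C during the import.
    raise KeyboardInterrupt
```

Both guards turn `missing_module` into `False`. With `user_interrupt`, the bare version silently returns `False`, while the `ImportError` version lets `KeyboardInterrupt` propagate to the caller.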
    except ImportError:
        TRITON_ASCEND_AVAILABLE = False

befa9e5 to 74390bb
Signed-off-by: chenaoxuan <[email protected]>
                                   generate_uniform_probs)
    from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

    try:
We can use HAS_TRITON here instead; please refer to https://github.com/vllm-project/vllm/blob/d9d342d214b8c13f71215318a6d9252cc4a5ca47/vllm/triton_utils/importing.py#L12
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?