[Kernel][CPU] CPU MLA #14744
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
pytest -v -s tests/kernels/test_cache.py -m cpu_model
pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
`pytest -v -s tests/kernels -m cpu_model` doesn't work due to Triton imports (there are probably other issues as well). We can have a separate PR to make `pytest -v -s tests/kernels -m cpu_model` work (as well as improve test coverage for CPU kernels).
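One low-effort way to unblock that follow-up, sketched here as a suggestion rather than code from this PR, is to guard the Triton import at module level so CPU-only collection can skip the affected tests instead of failing at import time:

```python
import pytest

# Guard the optional dependency so `pytest -m cpu_model` can at least
# collect the whole tests/kernels tree on a CPU-only build.
try:
    import triton  # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

requires_triton = pytest.mark.skipif(
    not HAS_TRITON, reason="Triton is not available on this CPU-only build")


@requires_triton
def test_triton_dependent_kernel():
    # Placeholder body: the real Triton-backed kernel tests would run here.
    assert HAS_TRITON
```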
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
In the future, we can merge this with the CUDA test of the same op (i.e. select the correct device at runtime).
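A rough sketch of what that merged test could look like, with the device chosen at runtime; the names and shapes here are placeholders, not the actual test:

```python
import pytest
import torch


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@torch.inference_mode()
def test_concat_and_cache_mla_shared(device):
    # One test body for both backends: pick the device at runtime and
    # skip when the accelerator is missing, instead of duplicating the test.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    kv_c = torch.randn(16, 512, device=device)
    assert kv_c.device.type == device
```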
Apologies for the delay! Overall I think this is quite close to mergeable, just left a few comments.
# for chunked-prefill
if self.chunked_prefill:
    prefill_block_tables = make_tensor_with_pad(
Should we assert here, since `chunked_prefill` is not supported?
I have an assert in `__init__()`. That should be sufficient?
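For reference, the guard being discussed is roughly the following; the class and attribute names are illustrative, not the code in this PR:

```python
class CPUMLAMetadataBuilderSketch:
    """Illustrative only: reject chunked prefill up front so the builder
    never reaches the chunked-prefill branch above."""

    def __init__(self, chunked_prefill: bool = False) -> None:
        # Fail fast in the constructor instead of asserting again inside
        # the build path; this mirrors the "assert in __init__" approach.
        assert not chunked_prefill, (
            "Chunked prefill is not supported by the CPU MLA backend yet")
        self.chunked_prefill = chunked_prefill
```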
torch.cumsum(kv_lens_tensor,
             dim=0,
             dtype=torch.int32,
             out=kv_start_loc[1:])
I think all of the above is fine for now, but we should see what we need to do to reuse more from the common builder, since we may be refactoring that in the future and this may cause issues.
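For context, the `cumsum` pattern above turns per-sequence KV lengths into start offsets into the packed KV data, keeping a leading zero. A standalone example:

```python
import torch

# Per-sequence KV lengths for a batch of 3 sequences.
kv_lens_tensor = torch.tensor([4, 2, 5], dtype=torch.int32)

# kv_start_loc[i] is where sequence i starts in the packed KV data;
# the leading 0 stays and the cumulative sum fills the rest in place.
kv_start_loc = torch.zeros(kv_lens_tensor.numel() + 1, dtype=torch.int32)
torch.cumsum(kv_lens_tensor, dim=0, dtype=torch.int32, out=kv_start_loc[1:])

print(kv_start_loc)  # tensor([ 0,  4,  6, 11], dtype=torch.int32)
```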
I have made some changes. Let me know if they address your concerns. Thank you. Hope to get this PR merged soon.
LGTM, thanks for the contribution!
In this PR, I add preliminary support for MLA on CPU. I'm opening this PR to get feedback and comments from the maintainers on the high level design. The MLA kernel itself is not optimized and I plan to optimize it further (either in this PR or leave it in a future PR).
The main changes can be summarized as follows:

- `concat_and_cache_mla` CPU kernel.
- `mla_decode_kvcache_cpu` kernel. This currently does not follow any existing API. See the code for more details. Only supports decoding 1 query token (a naive reference sketch of the computation is included below).
- `CPUMLABackend`: this largely follows `TorchSDPABackend` for the metadata handling, and re-uses `MLACommonBackend` for MLA-related logic. IPEX's `varlen_attention` is used for prefilling, and the new custom kernel is used for decoding.
- `MLACommon` import works on CPU build, and other minor fixes.

Other areas that I'm also looking into (but not yet implemented):

- `merge_attn_states`, which can be implemented as a CPU kernel, or perhaps `torch.compile()` can codegen it? (See the sketch of the math below.)

I have tested this code with `deepseek-ai/DeepSeek-V2-Lite-Chat` and the outputs look coherent, even though the outputs are different from those without MLA (`VLLM_MLA_DISABLE=1`).
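To make the decode path concrete, here is a naive PyTorch reference of what a single-query-token MLA decode over a paged latent KV cache computes. It is purely illustrative: the tensor names, shapes, and value handling are assumptions, not the `mla_decode_kvcache_cpu` signature.

```python
import torch


def mla_decode_reference(q, kv_cache, block_tables, seq_lens, scale):
    """Naive reference for single-query-token MLA decode over a paged cache.

    q:            [num_seqs, num_heads, head_dim]     (latent + rope dims)
    kv_cache:     [num_blocks, block_size, head_dim]  (one latent entry per
                                                       token, shared across heads)
    block_tables: [num_seqs, max_blocks_per_seq] int tensor
    seq_lens:     [num_seqs] int tensor
    """
    block_size = kv_cache.shape[1]
    outputs = []
    for i in range(q.shape[0]):
        seq_len = int(seq_lens[i])
        num_blocks = (seq_len + block_size - 1) // block_size
        # Gather this sequence's cache blocks and drop the padding past seq_len.
        blocks = kv_cache[block_tables[i, :num_blocks].long()]
        kv = blocks.reshape(-1, kv_cache.shape[-1])[:seq_len]   # [seq_len, head_dim]
        logits = torch.einsum("hd,sd->hs", q[i].float(), kv.float()) * scale
        probs = torch.softmax(logits, dim=-1)
        # For simplicity the full cache entry is used as the "value" here; the
        # real kernel would use only the latent (non-RoPE) part of each entry.
        outputs.append(torch.einsum("hs,sd->hd", probs, kv.float()).to(q.dtype))
    return torch.stack(outputs)                                 # [num_seqs, num_heads, head_dim]
```

A unit test could compare the CPU kernel against a reference like this on small random inputs.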
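On `merge_attn_states`: it combines partial attention outputs computed over disjoint KV chunks using their log-sum-exp values. A small sketch of that math, with shapes assumed for illustration rather than taken from the existing op:

```python
import torch


def merge_attn_states_reference(out_a, lse_a, out_b, lse_b):
    """Merge two partial attention results over disjoint KV chunks.

    out_*: [num_tokens, num_heads, head_dim] partial attention outputs
    lse_*: [num_tokens, num_heads] log-sum-exp of the attention logits
    """
    # Each partial output is a softmax-weighted sum over its own chunk;
    # re-weighting by exp(lse - merged_lse) makes the combined result match
    # attention computed over the union of the two chunks.
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b
```

Since this is a purely elementwise re-weighting, it is also the kind of pattern `torch.compile()` can plausibly codegen, as mentioned above.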
@bigPYJ1151 @Isotr0py Do hope to hear your feedback 🙏
Update 1: I have added some optimizations for the kernel. I'm no expert in optimizing CPU code, so any feedback and advice is welcome. Outline of my approach:
Benchmark script