[Kernel][CPU] CPU MLA #14744

Merged: 48 commits into vllm-project:main on Mar 25, 2025

Conversation

@gau-nernst (Contributor) commented on Mar 13, 2025:

In this PR, I add preliminary support for MLA on CPU. I'm opening it now to get feedback and comments from the maintainers on the high-level design. The MLA kernel itself is not optimized yet; I plan to optimize it further (either in this PR or in a follow-up PR).

The main changes can be summarized as follows:

  • Add a concat_and_cache_mla CPU kernel (a pure-Python sketch of its intended semantics is shown after this list).
  • Add the mla_decode_kvcache_cpu kernel. It currently does not follow any existing API (see the code for details) and only supports decoding a single query token.
  • Add CPUMLABackend: this largely follows TorchSDPABackend for the metadata handling and reuses MLACommonBackend for the MLA-related logic. IPEX's varlen_attention is used for prefill, and the new custom kernel is used for decode.
  • Fix various import-related logic so that importing MLACommon works on a CPU build, plus other minor fixes.
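
For reference, the intended semantics of concat_and_cache_mla can be sketched in pure Python roughly as follows (my own illustration; the argument names and the exact op signature may differ from the actual kernel). Each token's compressed KV latent and its RoPE'd key are concatenated along the last dimension and written into the paged cache at the slot given by slot_mapping:

import torch

def concat_and_cache_mla_ref(
    kv_c: torch.Tensor,         # (num_tokens, kv_lora_rank), e.g. 512 for DeepSeek-V2
    k_pe: torch.Tensor,         # (num_tokens, rope_dim), e.g. 64
    kv_cache: torch.Tensor,     # (num_blocks, block_size, kv_lora_rank + rope_dim)
    slot_mapping: torch.Tensor, # (num_tokens,), flat slot index of each token
) -> None:
    block_size = kv_cache.shape[1]
    kv_lora_rank = kv_c.shape[1]
    for i, slot in enumerate(slot_mapping.tolist()):
        block_idx, block_off = divmod(slot, block_size)
        kv_cache[block_idx, block_off, :kv_lora_rank] = kv_c[i]
        kv_cache[block_idx, block_off, kv_lora_rank:] = k_pe[i]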

Other areas that I'm also looking into (but not yet implemented):

  • Chunked prefill: this requires merge_attn_states, which could be implemented as a CPU kernel, or perhaps torch.compile() could codegen it? (A sketch of the underlying merge is shown below.)
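
For context, merging partial attention states reduces to a log-sum-exp-weighted combination of the partial outputs; here is a minimal Python sketch of the math (my own illustration, not vLLM's merge_attn_states API):

import torch

def merge_attn_states_ref(out1, lse1, out2, lse2):
    # out1/out2: (num_heads, v_head_dim) partial attention outputs
    # lse1/lse2: (num_heads,) log-sum-exp of the attention logits of each partial
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)  # rescale both partials to a common maximum
    w2 = torch.exp(lse2 - max_lse)
    out = (out1 * w1[:, None] + out2 * w2[:, None]) / (w1 + w2)[:, None]
    lse = max_lse + torch.log(w1 + w2)
    return out, lse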

I have tested this code with deepseek-ai/DeepSeek-V2-Lite-Chat and the outputs look coherent, even though they differ from the outputs produced with MLA disabled (VLLM_MLA_DISABLE=1).

@bigPYJ1151 @Isotr0py I hope to hear your feedback 🙏

Update 1: I have added some optimizations for the kernel. I'm no expert in optimizing CPU code, so any feedback and advice is welcome. An outline of my approach:

  • Multi-threading is parallelized only across the context dimension. The main reason is that decode attention is only slow (relative to the MLP and linear projections) when the context length is long; when the context length is short, most of the runtime is spent in the MLP and linear projections, so parallelizing across batch and query heads is less important, which simplifies the implementation a bit.
  • Convert the KV cache from BF16/FP16 to FP32 once ahead of time, since BF16/FP16->FP32 conversion is slow on CPU. The converted data can be reused across query heads and between K and V (since V overlaps with K in MLA). Some special care is taken for AVX512 with the BF16 dot product. Thanks to the chunking along the context dimension, the KV cache is converted to FP32 one block at a time, so it does not consume too much memory (and the FP32 KV block hopefully stays in the CPU cache). A simplified Python sketch of this scheme follows.
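
Roughly, the decode loop per (sequence, query head) is then a chunked online softmax over context blocks, converting one KV block to FP32 at a time. Below is a simplified single-head Python reference of this scheme (my own sketch under these assumptions, not the actual C++ kernel, which additionally parallelizes the loop across the context dimension):

import math
import torch

def mla_decode_one_head_ref(q, kv_cache, block_table, seq_len, scale, v_head_dim):
    # q: (head_dim,) one query head; kv_cache: (num_blocks, block_size, head_dim) in BF16/FP16
    block_size = kv_cache.shape[1]
    acc = torch.zeros(v_head_dim)             # FP32 output accumulator
    run_max, run_sum = float("-inf"), 0.0     # running softmax statistics
    for start in range(0, seq_len, block_size):
        n = min(block_size, seq_len - start)
        block_id = int(block_table[start // block_size])
        kv = kv_cache[block_id, :n].float()   # convert one KV block to FP32, reused as both K and V
        logits = (kv @ q.float()) * scale     # (n,) attention logits for this block
        new_max = max(run_max, logits.max().item())
        corr = math.exp(run_max - new_max) if run_max != float("-inf") else 0.0
        p = torch.exp(logits - new_max)       # (n,) unnormalized softmax weights
        acc = acc * corr + p @ kv[:, :v_head_dim]   # V is the first v_head_dim dims of the cache entry
        run_sum = run_sum * corr + p.sum().item()
        run_max = new_max
    return acc / run_sum                      # (v_head_dim,)
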
Benchmark script
# modified from https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_flashmla.py
import random
import os
import argparse

tcmalloc_path = "/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4"
os.environ["LD_PRELOAD"] = f"{tcmalloc_path}:{os.environ.get('LD_PRELOAD', '')}"

import torch
from torch.utils.benchmark import Timer
from dataclasses import dataclass
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
    # cosine-style relative difference: 1 - 2<x, y> / (||x||^2 + ||y||^2)
    x, y = x.double(), y.double()
    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
    assert cos_diff < 1e-5


def cdiv(a, b):
    return (a + b - 1) // b


def ref_mla(
    out: Tensor,  # (bs, num_heads, v_head_dim)
    query: Tensor,  # (bs, num_heads, head_dim)
    kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
    scale: float,
    block_tables: Tensor,  # (bs, max_num_blocks)
    seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@dataclass
class ProblemShape:
    bs: int = 1
    seq_len: int = 256
    num_heads: int = 16
    head_dim: int = 576
    v_head_dim: int = 512
    block_size: int = 16


def test_cpu_mla(args: ProblemShape, perf: bool = False):
    dtype = torch.bfloat16
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)
    random.seed(0)

    bs = args.bs
    head_dim = args.head_dim
    v_head_dim = args.v_head_dim
    scale = head_dim ** (-0.5)

    print(args)
    seq_lens = torch.full((bs,), args.seq_len, dtype=torch.int32)
    seqlen_pad = cdiv(args.seq_len, 256) * 256

    q = torch.randn(bs, args.num_heads, head_dim)
    block_table = torch.arange(bs * seqlen_pad // args.block_size, dtype=torch.int32)
    block_table = block_table.view(bs, seqlen_pad // args.block_size)

    kv_cache = torch.randn(block_table.numel(), args.block_size, head_dim)
    for i in range(bs):
        kv_cache.view(bs, seqlen_pad, head_dim)[i, args.seq_len :] = float("nan")

    out_mla = q.new_zeros(bs, args.num_heads, v_head_dim)
    ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens)

    if perf:
        return

    out_ref = q.new_zeros(bs, args.num_heads, v_head_dim)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)

    torch.testing.assert_close(out_mla, out_ref)
    cal_diff(out_mla, out_ref, "out")

    num_elems = (
        bs * args.seq_len * head_dim  # kv cache
        + bs * args.num_heads * head_dim  # query
        + bs * args.num_heads * v_head_dim  # output
    )
    num_gb = num_elems * dtype.itemsize / 1e9
    print(
        f"Input size: {num_gb * 1e3 : .2f} MB. Make sure this is larger than 2x L3 cache size for accurate benchmark."
    )

    a = torch.randn(num_elems // 2)
    b = torch.randn(num_elems // 2)
    t = (
        Timer(
            "a.copy_(b)",
            globals={**globals(), **locals()},
            num_threads=torch.get_num_threads(),
        )
        .blocked_autorange(min_run_time=1)
        .median
    )
    print(f"Copy: {num_gb / t:.4f} GB/s")

    t = (
        Timer(
            "ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens)",
            globals={**globals(), **locals()},
            num_threads=torch.get_num_threads(),
        )
        .blocked_autorange(min_run_time=1)
        .median
    )
    print(f"CPU MLA: {num_gb / t:.4f} GB/s")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--perf", action="store_true")
    args = parser.parse_args()

    print(f"No. of threads: {torch.get_num_threads()}")
    test_cpu_mla(ProblemShape(bs=1, seq_len=64_000), perf=args.perf)
    if args.perf:
        return
    test_cpu_mla(ProblemShape(bs=1, seq_len=54_321), perf=args.perf)
    test_cpu_mla(ProblemShape(bs=1, seq_len=1243, num_heads=5), perf=args.perf)
    test_cpu_mla(ProblemShape(bs=3, seq_len=1234), perf=args.perf)
    test_cpu_mla(ProblemShape(bs=30, seq_len=2048), perf=args.perf)


if __name__ == "__main__":
    main()


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Mar 15, 2025

mergify bot commented Mar 21, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @gau-nernst.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 21, 2025
@mergify mergify bot removed the needs-rebase label Mar 21, 2025
Comment on lines +41 to +42
pytest -v -s tests/kernels/test_cache.py -m cpu_model
pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
gau-nernst (Contributor, Author):

pytest -v -s tests/kernels -m cpu_model doesn't work due to Triton imports (there are probably other issues as well). We can have a separate PR to make pytest -v -s tests/kernels -m cpu_model work (as well as improve test coverage for CPU kernels)

@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
gau-nernst (Contributor, Author):

In the future, we can merge this with the CUDA test of the same op (i.e., select the correct device at runtime).

@LucasWilkinson (Collaborator) left a comment:

Apologies for the delay! Overall I think this is quite close to mergeable; I've just left a few comments.


# for chunked-prefill
if self.chunked_prefill:
prefill_block_tables = make_tensor_with_pad(
LucasWilkinson (Collaborator):

Should we assert here, since chunked_prefill is not supported?

gau-nernst (Contributor, Author):

I have an assert in __init__(). That should be sufficient?

torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
LucasWilkinson (Collaborator):

I think all of the above is fine for now, but we should see what we need to do to reuse more from the common builder, since we may be refactoring that in the future and this may cause issues.

gau-nernst (Contributor, Author):

I have made some changes. Let me know if they address your concerns. Thank you; I hope to get this PR merged soon.

@LucasWilkinson (Collaborator) left a comment:

LGTM, thanks for the contribution!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) March 25, 2025 01:25
@github-actions github-actions bot added the ready label Mar 25, 2025
@LucasWilkinson LucasWilkinson merged commit 4f044b1 into vllm-project:main Mar 25, 2025
59 checks passed
@gau-nernst gau-nernst deleted the cpu_mla branch March 25, 2025 11:19