[Roadmap] CPU Performance Optimization for SGLang and Flashinfer 24'Q4 #1

Open · 19 of 36 tasks · mingfeima opened this issue Dec 16, 2024 · 6 comments

mingfeima (Owner) commented Dec 16, 2024

🚀 The feature, motivation and pitch

The goal of this project is to optimize the performance of SGLang on Intel Xeon Scalable Processors. It targets SPR (4th gen), EMR (5th gen), and GNR (6th gen) with Intel® Advanced Matrix Extensions (AMX) support.

  • Optimize the flashinfer backend on the CPU device (SGLang currently supports the triton and flashinfer backends).
  • Target good efficiency, i.e. performance per dollar.
  • Upstream the optimizations to the main branch and ship them with releases.
  • Focus on AVX-512 and AMX; provide fallbacks for other ISAs.

For the current stage, focus on customer requests first and then gradually increase model coverage:

  • DeepSeek - MHA
  • DeepSeekV2 - MLA

1. Flashinfer Kernel Optimizations on CPU

layernorm

  • rmsnorm
  • fused_add_rmsnorm
  • gemma_rmsnorm
  • gemma_fused_add_rmsnorm
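
For reference, the math these norm kernels implement is small enough to state in a few lines of PyTorch. This is a semantics-only sketch (the actual flashinfer signatures and in-place behavior may differ, and the gemma_* variants scale by 1 + weight):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, reduced over the hidden dimension
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

def fused_add_rmsnorm_ref(x: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6):
    # the fused variant adds the residual first, then normalizes the sum;
    # a CPU kernel would do the add, reduction, and scaling in one pass over the row
    residual = residual + x
    return rmsnorm_ref(residual, weight, eps), residual
```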

activations

  • silu_and_mul
  • gelu_and_mul
  • gelu_tanh_and_mul
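
These fused activation kernels split the gate/up projection output in half and apply the activation to the gate before the elementwise multiply, so a CPU kernel only streams the projected tensor once. A semantics-only sketch:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: [..., 2 * d] with the gate in the first half and the up projection in the second
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate, approximate="tanh") * up
```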

sampling

  • min_p_sampling_from_probs
  • top_k_renorm_prob
  • top_k_top_p_sampling_from_probs
  • top_p_renorm_prob
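
For orientation, a sort-based reference of the top-k/top-p semantics is sketched below. The real kernels avoid the full sort (e.g. via rejection sampling) and have different signatures, but the filtering math is the same:

```python
import torch

def top_p_renorm_prob_ref(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # keep the smallest prefix of sorted probs whose mass reaches top_p, then renormalize
    sorted_probs, idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    mask = (cumsum - sorted_probs) < top_p       # keep tokens while mass before them < top_p
    kept = torch.zeros_like(probs).scatter_(-1, idx, sorted_probs * mask)
    return kept / kept.sum(dim=-1, keepdim=True)

def top_k_renorm_prob_ref(probs: torch.Tensor, top_k: int) -> torch.Tensor:
    vals, idx = probs.topk(top_k, dim=-1)
    kept = torch.zeros_like(probs).scatter_(-1, idx, vals)
    return kept / kept.sum(dim=-1, keepdim=True)

def top_k_top_p_sampling_from_probs_ref(probs, top_k, top_p, generator=None):
    filtered = top_p_renorm_prob_ref(top_k_renorm_prob_ref(probs, top_k), top_p)
    return torch.multinomial(filtered, num_samples=1, generator=generator).squeeze(-1)
```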

attention

  • BatchDecodeWithPagedKVCacheWrapper
  • BatchPrefillWithPagedKVCacheWrapper
  • BatchPrefillWithRaggedKVCacheWrapper

[NOTE]: DeepSeekV2 will choose the triton backend for its MLA structure and won't go through flashinfer; this is defined in python/sglang/srt/model_executor/model_runner.py.
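
As a reference point for what the decode wrapper has to compute on CPU, the sketch below spells out single-query paged-KV attention in plain PyTorch. Shapes and the per-request Python loop are illustrative only; the optimized kernel would instead tile heads and KV pages into AMX/AVX-512 friendly gemms:

```python
import torch

def paged_decode_attention_ref(q, k_cache, v_cache, page_table, kv_lens, scale):
    """Reference decode attention over a paged KV cache (one query token per request).

    q:                [batch, num_heads, head_dim]
    k_cache/v_cache:  [num_pages, page_size, num_kv_heads, head_dim]
    page_table:       [batch, max_pages_per_req] integer physical page ids
    kv_lens:          [batch] valid KV length per request
    """
    batch, num_heads, head_dim = q.shape
    page_size = k_cache.shape[1]
    num_kv_heads = k_cache.shape[2]
    out = torch.empty_like(q)
    for b in range(batch):
        n = int(kv_lens[b])
        n_pages = (n + page_size - 1) // page_size
        pages = page_table[b, :n_pages]
        # gather the pages for this request and trim to the valid KV length
        k = k_cache[pages].reshape(-1, num_kv_heads, head_dim)[:n]   # [n, kv_heads, d]
        v = v_cache[pages].reshape(-1, num_kv_heads, head_dim)[:n]
        # broadcast KV heads for GQA/MQA
        rep = num_heads // num_kv_heads
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
        attn = torch.einsum("hd,nhd->hn", q[b].float(), k.float()) * scale
        w = attn.softmax(dim=-1)
        out[b] = torch.einsum("hn,nhd->hd", w, v.float()).to(q.dtype)
    return out
```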

2. SGLang CPU Device Enabling

benchmarking

  • Fixed-input benchmarking: use a 1K-token prompt and 128 output tokens (a measurement sketch follows this list).
  • ShareGPT benchmark: default dataset, serving mode.
  • ShareGPT benchmark: default dataset, offline mode (TBD; this should be a focus later on).
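
For the fixed-input case, the measurement itself can be a small script against the server's OpenAI-compatible streaming endpoint. A rough sketch; the URL, the "default" model name, and the one-token-per-chunk approximation are assumptions, not SGLang-specific guarantees:

```python
import time
import requests

def measure_ttft_tpot(url: str, prompt: str, max_tokens: int = 128):
    """Measure time-to-first-token and time-per-output-token against an
    OpenAI-compatible streaming /v1/completions endpoint."""
    start = time.perf_counter()
    ttft, n_tokens = None, 0
    with requests.post(url, json={
        "model": "default", "prompt": prompt,
        "max_tokens": max_tokens, "stream": True,
    }, stream=True) as resp:
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data:"):
                continue
            payload = line[len(b"data:"):].strip()
            if payload == b"[DONE]":
                break
            if ttft is None:
                ttft = time.perf_counter() - start   # first streamed chunk
            n_tokens += 1                            # ~one token per chunk (approximation)
    total = time.perf_counter() - start
    if ttft is None:
        ttft = total
    tpot = (total - ttft) / max(n_tokens - 1, 1)
    return ttft, tpot
```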

3. DeepSeekV2 Optimization

The SGLang v0.3 release introduced several optimizations for this model.

  • MLA decoding kernel optimization with weight absorption. It is written in triton, and the performance gain comes from better blocking. For CPU, prototype it in flashinfer first and then merge with the existing APIs if possible. The optimized kernel handles MHA and MQA/GQA/MLA with different parallelization and tiling strategies to reduce memory access to the KV cache.
  • fp8 kv_cache: enable bmm_fp8 with flashinfer. The feature is optional but practically a must on CPU; without it, the path falls back to torch.bmm, which is not very performant. The scheme the GPU currently uses is actually fp8 dynamic quantization, so we may also fuse the quant_A kernel into bmm_fp8. (The CPU may not be able to do this due to the lack of a fast E4M3 conversion implementation; check later 👀. A reference sketch of this path follows the list.)
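
A rough sketch of the fp8 dynamic-quant bmm path described above. The names and the per-tensor scaling granularity are assumptions; torch.float8_e4m3fn (available in recent PyTorch) is storage-only on CPU, so the reference dequantizes to bf16 for the matmul, which is exactly the overhead a fused kernel would try to hide:

```python
import torch

def bmm_fp8_ref(a: torch.Tensor, b_fp8: torch.Tensor, b_scale: torch.Tensor) -> torch.Tensor:
    """a:       [batch, m, k] bf16/fp16 activations
    b_fp8:   [batch, k, n] weights stored as float8_e4m3fn
    b_scale: scalar (or broadcastable) dequant scale for b
    """
    # dynamic quantization of A: pick a scale so the max maps to the e4m3 max (448)
    a_scale = a.abs().amax() / 448.0
    a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
    # CPU has no native fp8 matmul, so dequantize to bf16 and use torch.bmm;
    # a fused kernel would keep the A quantization and the matmul in one pass
    out = torch.bmm(a_fp8.to(torch.bfloat16), b_fp8.to(torch.bfloat16))
    return out * (a_scale * b_scale)
```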

Additional optimizations we need:

  • MHA prefilling kernel optimization (map the triton kernel implementation? check later 👀).
  • FusedMoE kernel enabling on CPU.
  • [Nice to have]: introduce a brgemm micro-kernel or hard-code AMX for the prefilling kernels and the MLA decoding kernel.
  • [Nice to have]: block planning. The current scheme with the KV split from flash-decoding has a load-imbalance issue when serving requests with very different KV lengths (needs input from the profiler; check later 👀). See the split-KV sketch after this list.
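
To make the block-planning point concrete, the sketch below shows the flash-decoding split-KV pattern for a single head: each KV chunk yields a partial output plus its log-sum-exp, and the partials are combined at the end. With a fixed chunk size, requests with long KV get many chunks while short ones get few, which is the load imbalance a planner would need to smooth out. Single-head shapes are assumed for brevity:

```python
import torch

def split_kv_decode_ref(q, k, v, num_splits: int, scale: float):
    """q: [d], k/v: [n, d].  KV is split into `num_splits` chunks; each chunk
    produces a partial output plus its log-sum-exp, combined at the end."""
    n, d = k.shape
    chunk = (n + num_splits - 1) // num_splits
    partial_out, partial_lse = [], []
    for s in range(0, n, chunk):
        ks, vs = k[s:s + chunk].float(), v[s:s + chunk].float()
        logits = (ks @ q.float()) * scale                  # [chunk]
        lse = torch.logsumexp(logits, dim=0)
        partial_out.append((logits - lse).exp() @ vs)      # [d] chunk-local softmax @ V
        partial_lse.append(lse)
    lse = torch.stack(partial_lse)                         # [num_chunks]
    weights = (lse - torch.logsumexp(lse, dim=0)).exp()    # renormalize across chunks
    return (torch.stack(partial_out) * weights[:, None]).sum(dim=0)
```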

4. ⭐ First Token latency reduction

To make Xeon actually useful, reduce the first-token latency as much as possible for long prompt lengths:

  • ⭐⭐ GEMM efficiency: weight prepacking (torch.compile? dynamic packing?).
  • ⭐⭐ Tensor parallel: run large GEMMs with tensor parallelism (TP) across SNC=3 domains (a sketch follows this list).
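
The TP idea for SNC=3 is essentially a column-parallel GEMM across sub-NUMA domains: each domain owns one column shard of the weight so it streams only local memory, and the shard outputs are concatenated. A toy sketch using Python threads (torch.mm releases the GIL); a real implementation would additionally pin each worker and its shard to one SNC domain and give each its own OpenMP thread pool, which this sketch does not do:

```python
import torch
from typing import List
from concurrent.futures import ThreadPoolExecutor

def column_parallel_mm(x: torch.Tensor, weight_shards: List[torch.Tensor]) -> torch.Tensor:
    # x: [m, k]; weight_shards: column shards of a [k, n] weight, one per SNC domain
    with ThreadPoolExecutor(max_workers=len(weight_shards)) as pool:
        outs = list(pool.map(lambda w: torch.mm(x, w), weight_shards))
    return torch.cat(outs, dim=-1)

# illustrative shapes: three shards for SNC=3
w = torch.randn(4096, 12288, dtype=torch.bfloat16)
x = torch.randn(1024, 4096, dtype=torch.bfloat16)
y = column_parallel_mm(x, list(w.chunk(3, dim=-1)))
```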

5. Upstreaming

  • Add a context for CPU to manage dynamic memory allocation and cache intermediate results if necessary (in flashinfer).
  • Remove the dependency on the at::vec::Vectorized<> wrapper. The decision needs to be made after screening the types of operations.
  • Add fallbacks for ISAs other than AVX-512 and AMX, so other vendors can run it.
  • Introduce oneDNN brgemm micro-kernels in the attention calculation.

6. TODO

  • Extend the CPU-backend optimizations from flashinfer to other LLM engines.
  • Extend flashinfer CPU from AOT to JIT mode.
  • Wrap up distributed GEMM with an OSS-acceptable approach.
  • Extend quantization support.
mingfeima (Owner, Author) commented Feb 21, 2025

  • DeepSeek V3 - fp8 block gemm

mingfeima (Owner, Author) commented Mar 24, 2025

Serving mode tuning

  • serving threading runtime debug
  • serving benchmark configs
  • compute_position_torch, clamp_position C++ implementation

TODO: (in priority order)

  • Fuse rope and bmm in the absorb path: rotary_embedding takes 0.284 ms × 60 = 17 ms; bmm takes 9 ms.
  • The shared MoE, with mul and add, takes 0.2 ms × 60 = 12 ms.
  • Enable FlashMLA; the current decode_attn takes 0.13 ms × 60 = 8 ms.
  • Fuse bmm and int8_scaled_mm after GQA in decode: 3.6 ms.
  • at::zeros takes the majority of the time in fused_add_rms_norm when allocating and zeroing the temp buffer: each run takes 0.77 ms, 3.2 ms in total.
  • Fuse set_kv_buffer with decode attn; the index_put is super slow (~3 ms). Move attn_logits into C++ and tune the config for [BLOCK_H, num_kv_splits].
  • Fuse per_row_quant_int8 with int8_scaled_mm (a reference of the two steps follows this list).
  • Replace torch.empty at the Python level to remove Python overhead; call it from C++.
  • Implement all-gather: the whole module takes about 2.2 ms, consisting of narrow and gather, while all-reduce takes 4.8 ms for 120 runs.
  • Add native kernels for compute_position_torch; check the args to see if we need to vectorize this.
  • Remove the hollow Python wrappers when calling C++ kernels.
  • weight_packed_linear for small OC: OC=256 takes 2.6 ms in total and only achieves a memory bandwidth of 486 GB/s.
  • All-reduce performance spikes every few dozen runs; the normal value is 20 us, but the spikes can reach 100 us or even over 200 us.
  • Refactor the shared-memory all-reduce code in the ATen coding style.
  • Remove the IPEX code snippets from "adapt vllm distributed module to sglang" sgl-project/sglang#2244.
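
For the per_row_quant_int8 + int8_scaled_mm fusion item above, here is a reference of what the two separate steps compute today (the fusion would fold the row-wise activation quantization into the GEMM prologue). Kernel names are taken from the list; shapes and scale granularity are assumptions:

```python
import torch

def per_row_quant_int8_ref(a: torch.Tensor):
    # symmetric per-row dynamic quantization of the activation
    scale = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    a_q = torch.round(a / scale).clamp(-128, 127).to(torch.int8)
    return a_q, scale.float()                      # a_q: [m, k] int8, scale: [m, 1]

def int8_scaled_mm_ref(a_q, a_scale, w_q, w_scale, bias=None, out_dtype=torch.bfloat16):
    # w_q: [k, n] int8, w_scale: [n] per-output-channel scale.
    # A real kernel accumulates in int32 (VNNI / AMX-int8); float64 here just keeps
    # the reference numerically exact.
    acc = torch.mm(a_q.double(), w_q.double())
    out = acc * a_scale.double() * w_scale.double()
    if bias is not None:
        out = out + bias.double()
    return out.to(out_dtype)
```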

gau-nernst commented:

Hi @mingfeima. At my company, I'm also working on optimizing LLM inference for CPU servers. Can I get involved with your team so that we can join efforts together? Recently I wrote an MLA decode kernel for CPU in vLLM vllm-project/vllm#14744. I was not aware of your existing efforts. It will be interesting to benchmark against your triton version. I'm also new to CPU optimization, hope to learn from everyone.

mingfeima (Owner, Author) commented:

Great job! Sure, we are open to more CPU contributors :)

For the MLA decoding part, my optimizations are to a) fold H to change the gemv into a gemm, b) apply avx512-bf16 and amx-bf16, and c) use the flash-mla algorithm. I have done a) and b), but c) is still on my TODO list. Right now, for the MLA decoding part you referred to, each run on DeepSeek R1 takes ~4 ms on our machine. A short sketch of point a) follows.
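
A minimal sketch of what a) means for the absorbed MLA decode path (shapes are illustrative): with weight absorption, the per-token queries of all heads share one compressed KV cache that has no head dimension, so stacking the heads turns num_heads separate gemvs against the cache into a single gemm that AVX512-BF16/AMX can tile:

```python
import torch

def mla_decode_scores_folded(q_nope: torch.Tensor, kv_latent: torch.Tensor) -> torch.Tensor:
    # q_nope:    [num_heads, kv_lora_rank]  absorbed queries for one decode token
    # kv_latent: [kv_len, kv_lora_rank]     compressed KV cache shared by all heads
    #
    # per-head gemv: for h in range(num_heads): kv_latent @ q_nope[h]   -> H tiny gemvs
    # folded gemm:   one [num_heads, rank] x [rank, kv_len] matmul      -> one blocked gemm
    return q_nope @ kv_latent.T          # [num_heads, kv_len] attention scores
```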

We also optimized IPEX; this will help improve vLLM once IPEX is used as the attention backend. Right now the overall performance is roughly a TPOT of 60 ms for 1K input and 1K output with DeepSeek R1 671B in int8 on a single-node Xeon CPU. The optimization work is still ongoing!

gau-nernst commented:

Is there a channel where we can discuss things in more detail? We could create one under the SGLang Slack if there isn't one already.

mingfeima (Owner, Author) commented:

We do have a Slack channel for the Intel-SGLang collaboration, but I am not sure whether it is appropriate to invite you there, since we may share some non-public information in it.

Could you please send an email to [email protected] and outline your proposals? We can then discuss how to combine our efforts.
