Skip to content

Conversation

sanchitintel
Copy link

@sanchitintel sanchitintel commented Sep 18, 2025

Background

When multiple GEMMs are to be computed, each with its own canonical A, B, C, D matrices, GroupGEMM is useful for ensuring high GPU utilization & preventing launch overhead that'd otherwise occur for multiple GEMM kernel launches. In cutlass, the vanilla GroupGEMM uses a persistent kernel approach - the number of workgroups launched are equal to the number of Xe cores, and they loop through until they have work, (in this case, work, is the mainloop to compute one of the output tiles of any one of the GEMMs we try to compute with the GroupGEMM API).

For Mixture of Experts used in Deep Learning models such as LLMs, the MoE GEMM use-case is something like this - each expert (corresponding to a group) has an associated weight sized N * K, which essentially a column-major B matrix. All the B matrices are contiguous w.r.t. each other, i.e. their total size is num_groups * N * K. N, K are compile-time constants. M for each group is variable. All A matrices are also contiguous w.r.t. each other. Each set of tokens routed to an expert makes up the A matrix for that group.

MoEGEMM seems to be a natural candidate for leveraging GroupGEMM.

The problem

The cutlass GroupGEMM API is generic in that it requires pointers of A, B, C,D tensors pertaining to each group.
For launching the kernel, the CPU needs to provide a array of these GPU pointers (that array is also on the GPU).

However, for practical use-cases such as Mixture of Experts (each GroupGEMM group corresponds to oneMoE expert), such lists can't be conveniently pre-computed in advance (it's indeed possible to create it at the beginning of the kernel, and then synchronize across all workgroups, but that code can't be a part of generic Group GEMM).

Solution proposed in this PR

Provide only the base A, B, C, D pointers, and also pass N, K, so that the canonical A, B, C, D matrices' pointers for each group can be computed on-the-fly (a prefix sum algorithm to compute a cumulative sum of M might help but based on our experimentation, it doesn't seem to make much difference, as small M case is memory-bound, anyway).

To have minimal changes from the existing code, pass lists sized one instead of lists with size equal to the number of groups, as otherwise happens in the default case.
The PR adds a new kernel & a tile scheduler for MoEGEMM, while reusing existing MMA & epilogue collectives (but with modified code for A, B, C, D pointer computation).

We could instead add a template parameter to make these changes in the existing kernels and also use if constexpr to separate it from the default GroupGEMM. While the current implementation in this PR introduces duplication, the alternative would make the code messier.

Performance

With small M dimension for each GEMM problem, the performance is worse than that of large M dimension due to lower arithmetic intensity in the former case, but it's better than launching a separate kernel for each GEMM problem.

Caveat

The example just portrays one way to use the API.
Also, it has mostly been copy-pasted from an existing example, so it can be revised further.

@sanchitintel sanchitintel changed the title MoEGEMM based on cutlass GroupGEMM MoEGEMM as an extension of GroupGEMM Sep 19, 2025
@sanchitintel
Copy link
Author

sanchitintel commented Sep 19, 2025

The current Group GEMM in cutlass uses a one-size-fits-all approach - the same tiling scheme is used to compute all output tiles corresponding to any group.
Having looked at sample M input shapes for MoEGEMM, which are quite divergent, I think better performance can be achieved by:

  1. Pre-compiling MMA & epilogue collectives for various tile-shapes
  2. For each expert/group, selecting a different tiling scheme based on its M, but using the same tiling scheme to compute all the output blocks of one expert/group. Selection of tiling scheme should be based on some heuristics (e.g. different tiling scheme for M <= 16 and another different one for 16 < M <= 32) , which can also be hardcoded in the code after experimentation. Might even use a different tiling scheme to handle the M dimension tail (since the rasterization is along the N dimension, M dimension's tail can be distinguished, as all the output tiles with tail_M M-dimension would be the amongst the last (N + WG_N - 1)/WG_N output tiles corresponding to that expert).

There won't be any divergence within a subgroup, despite conditionals being used at run-time to select the appropriate collectives.

This hypothesis has not been implemented in this PR yet.

Thanks!

Comment on lines 652 to 661
using TileShape = Shape<_16, _256, _32>;
/*
using TiledMma =
TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>;
*/

using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pengzhao-intel

This tileshapes are for small M per expert in this specific example, simply because I was testing those shapes at the time.

Replace this tiling with the usual (256, 256, 32) workgroup tile, and it performs better than cutlass grouped GEMM for num_experts=16, M_per_expert=256, N=16384, K=5120 used in LLaMA 4.

Change B to Rowwise, and the performance difference between this one (higher throughput) & cutlass group gemm is even larger.

I have more changes locally, though. Will push to this branch & the vLLM repo once complete.

Thanks!

@sanchitintel sanchitintel marked this pull request as draft September 29, 2025 17:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants