MoEGEMM as an extension of GroupGEMM #520
Conversation
The current GroupGEMM in cutlass uses a one-size-fits-all approach: the same tiling scheme is used to compute all output tiles, whichever group they belong to. A possible extension is to select the appropriate collective per group at run time. Since every work-item in a subgroup computes the same output tile (and thus belongs to the same group), there would be no divergence within a subgroup despite the run-time conditionals used to select the appropriate collective. This hypothesis has not been implemented in this PR yet. Thanks!
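A minimal, self-contained sketch of that idea (the names, the threshold, and the dispatch mechanism here are all illustrative assumptions, not code from this PR):

```cpp
#include <cstdio>

// Two hypothetical collectives differing only in tile shape; in the real
// kernel these would be full cutlass mainloop configurations.
template <int TileM, int TileN, int TileK>
struct Collective {
  static void run(int group) {
    std::printf("group %d -> tile (%d, %d, %d)\n", group, TileM, TileN, TileK);
  }
};

using SmallM = Collective<16, 256, 32>;   // the tile used in this example
using LargeM = Collective<256, 256, 32>;  // the usual workgroup tile

// Every work-item in a subgroup computes the same output tile, hence the
// same group, so this run-time branch is uniform within a subgroup.
void run_group(int group, int const* M_per_group) {
  if (M_per_group[group] <= 64) {  // threshold is an assumption
    SmallM::run(group);
  } else {
    LargeM::run(group);
  }
}

int main() {
  int const M_per_group[] = {32, 256};
  run_group(0, M_per_group);
  run_group(1, M_per_group);
  return 0;
}
```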
```cpp
using TileShape = Shape<_16, _256, _32>;
/*
using TiledMma =
    TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
             Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>;
*/

using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
    typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
                            Layout<Shape<_1, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
```
These tile shapes are for small M per expert in this specific example, simply because I was testing those shapes at the time. Replace this tiling with the usual (256, 256, 32) workgroup tile, and it performs better than the cutlass grouped GEMM for num_experts=16, M_per_expert=256, N=16384, K=5120, as used in LLaMA 4. Change `B` to Rowwise, and the performance difference between this one (higher throughput) and the cutlass grouped GEMM is even larger. I have more changes locally, though. Will push to this branch & the vLLM repo once complete. Thanks!
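For reference, the (256, 256, 32) workgroup tile mentioned above could be expressed following the same `TiledMMAHelper` pattern as the diff (a sketch assuming the example's existing includes and namespaces; the subgroup layout is taken from the commented-out configuration above and is an assumption here):

```cpp
using TileShape = Shape<_256, _256, _32>;
using TiledMma =
    typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
                            Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
```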
Background
When multiple GEMMs are to be computed, each with its own canonical `A`, `B`, `C`, `D` matrices, GroupGEMM is useful for ensuring high GPU utilization and for avoiding the launch overhead that would otherwise be incurred by multiple GEMM kernel launches. In cutlass, the vanilla GroupGEMM uses a persistent-kernel approach: the number of workgroups launched is equal to the number of Xe cores, and each workgroup keeps looping as long as there is work (here, one unit of work is the mainloop that computes one output tile of one of the GEMMs submitted through the GroupGEMM API).

For Mixture of Experts (MoE) as used in deep learning models such as LLMs, the MoE GEMM use-case looks like this: each `expert` (corresponding to a `group`) has an associated `weight` of size `N * K`, which is essentially a column-major `B` matrix. All the `B` matrices are contiguous w.r.t. each other, i.e. their total size is `num_groups * N * K`. `N` and `K` are compile-time constants, while `M` for each group is variable. All `A` matrices are also contiguous w.r.t. each other; each set of tokens routed to an expert makes up the `A` matrix for that group. MoEGEMM therefore seems to be a natural candidate for leveraging GroupGEMM.
The problem

The cutlass GroupGEMM API is generic in that it requires pointers to the `A`, `B`, `C`, `D` tensors pertaining to each group. To launch the kernel, the CPU needs to provide an array of these GPU pointers (an array that itself resides on the GPU), as sketched below.
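To make the shape of these lists concrete, a hedged sketch (the struct and member names are illustrative, not the exact cutlass arguments struct):

```cpp
// One device pointer per group for each operand; the arrays themselves must
// also live in device memory before the kernel can be launched.
template <class ElementA, class ElementB, class ElementC, class ElementD>
struct GroupedGemmPointerLists {
  ElementA const* const* ptr_A;  // num_groups entries, array on the GPU
  ElementB const* const* ptr_B;
  ElementC const* const* ptr_C;
  ElementD* const*       ptr_D;
  int num_groups;
};
```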
However, for practical use-cases such as Mixture of Experts (each GroupGEMM `group` corresponds to one MoE `expert`), such lists can't be conveniently precomputed in advance. (It is indeed possible to create them at the beginning of the kernel and then synchronize across all workgroups, but that code can't be a part of a generic GroupGEMM.)
Solution proposed in this PR

Provide only the base `A`, `B`, `C`, `D` pointers, and also pass `N` and `K`, so that the canonical `A`, `B`, `C`, `D` matrices' pointers for each group can be computed on the fly. (A prefix-sum algorithm to compute a cumulative sum of `M` might help, but based on our experimentation it doesn't seem to make much difference, as the small-`M` case is memory-bound anyway.) To keep the changes to the existing code minimal, lists of size one are passed instead of lists with size equal to the number of groups, as would otherwise happen in the default case. A sketch of this on-the-fly pointer computation is shown below.
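A rough illustration of the scheme (a minimal sketch: the function and struct names are hypothetical, a single element type `T` is assumed for all operands, and the real kernel computes these offsets on the device):

```cpp
#include <cstddef>

// Per-group pointers for one MoE group, derived from the base pointers.
template <class T>
struct GroupPointers {
  T const* A;
  T const* B;
  T* C;
  T* D;
};

template <int N, int K, class T>
GroupPointers<T> moe_group_pointers(T const* base_A, T const* base_B,
                                    T* base_C, T* base_D,
                                    int const* M_per_group, int group) {
  // Cumulative M of all preceding groups; a device-side prefix sum could
  // precompute this, but per the experimentation above it makes little
  // difference since the small-M case is memory-bound anyway.
  std::size_t m_offset = 0;
  for (int g = 0; g < group; ++g) m_offset += M_per_group[g];
  return {base_A + m_offset * K,                            // A rows routed to this expert
          base_B + static_cast<std::size_t>(group) * N * K, // weights are group-contiguous
          base_C + m_offset * N,                            // outputs follow A's M layout
          base_D + m_offset * N};
}
```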
The PR adds a new kernel and a tile scheduler for MoEGEMM, while reusing the existing MMA and epilogue collectives (with modified code for the `A`, `B`, `C`, `D` pointer computation). We could instead add a template parameter to make these changes in the existing kernels and use `if constexpr` to separate this path from the default GroupGEMM. While the current implementation in this PR introduces some duplication, the alternative would make the code messier.
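For illustration, that rejected alternative could look roughly like this (a hedged sketch; all names are hypothetical):

```cpp
#include <cstddef>

// A single grouped-kernel path templated on a compile-time flag, with
// `if constexpr` separating the MoE on-the-fly computation from the default
// path that reads precomputed per-group pointer lists.
template <bool IsMoE, class Element>
Element const* group_ptr_A(Element const* base_A,        // MoE path input
                           Element const* const* ptr_A,  // default path input
                           int group, std::size_t m_offset, int K) {
  if constexpr (IsMoE) {
    return base_A + m_offset * K;  // computed from the base pointer, as in this PR
  } else {
    return ptr_A[group];           // read from the precomputed device array
  }
}
```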
Performance

With a small `M` dimension for each GEMM problem, the performance is worse than with a large `M` dimension, due to the lower arithmetic intensity of the former case, but it is still better than launching a separate kernel for each GEMM problem.

Caveat
The example just portrays one way to use the API.
Also, it has mostly been copy-pasted from an existing example, so it can be revised further.