@sanchitintel sanchitintel commented Sep 29, 2025

Problem

In real-world uses of Grouped GEMM, there may not be any C matrices at all. Even when beta is 0, passing a placeholder ptr_C argument to xe_array_epilogue (such as ptr_D, when C and D have the same dtype) still results in some wasted compute: for the default Grouped GEMM example (examples/04_bmg_grouped_gemm.cpp), there's a slowdown of ~0.1 TFLOPs.

Solution

The preferable solution is to use beta=0 and a non-null value for ptr_C, such as ptr_D, e.g. static_cast<const ElementC**>((void*)ptr_D.get()), even when C and D are not the same dtype, because C tiles aren't actually used when beta is 0.

However, with the implementation in this PR, if a user still wants to pass nullptr for ptr_C, then C is not used, irrespective of the beta value.

Given that the aforementioned workaround doesn't result in a discernible performance penalty, I'm not sure if this PR makes sense.
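
The intended semantics described above can be sketched as a host-side reference. This is only an illustration and not the actual xe_array_epilogue code: the names reference_epilogue and reference_grouped_epilogue, the flat per-group element counts, and the float compute type are all assumptions made for the example. It shows C being skipped either when ptr_C is null (whatever beta is) or when beta is 0.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-group reference epilogue (NOT the real xe_array_epilogue).
// Computes D = alpha * acc + beta * C, but never reads C when ptr_C is null
// (irrespective of beta) or when beta is 0.
template <class ElementAcc, class ElementC, class ElementD>
void reference_epilogue(const ElementAcc* acc, const ElementC* ptr_C,
                        ElementD* ptr_D, std::size_t count,
                        float alpha, float beta) {
  assert(acc != nullptr && ptr_D != nullptr);
  const bool use_C = (ptr_C != nullptr) && (beta != 0.0f);
  for (std::size_t i = 0; i < count; ++i) {
    float value = alpha * static_cast<float>(acc[i]);
    if (use_C) {
      value += beta * static_cast<float>(ptr_C[i]);
    }
    ptr_D[i] = static_cast<ElementD>(value);
  }
}

// Hypothetical grouped variant: one pointer per group. The whole ptr_C
// array may be null, in which case every group skips C.
template <class ElementAcc, class ElementC, class ElementD>
void reference_grouped_epilogue(const ElementAcc* const* acc,
                                const ElementC* const* ptr_C,
                                ElementD* const* ptr_D,
                                const std::vector<std::size_t>& counts,
                                float alpha, float beta) {
  for (std::size_t g = 0; g < counts.size(); ++g) {
    const ElementC* c = (ptr_C != nullptr) ? ptr_C[g] : nullptr;
    reference_epilogue(acc[g], c, ptr_D[g], counts[g], alpha, beta);
  }
}
```

Because the null check is folded into a single use_C flag, a caller passing a null ptr_C with a non-zero beta gets the same result as beta=0, which is the behavior this PR proposes.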

@sanchitintel
Author

Hi @rolandschulz, in the main branch, passing nullptr for the C matrix is supported for vanilla GEMM, but not for Grouped GEMM.

This PR fixes that issue.

Thanks!

@rolandschulz rolandschulz left a comment

Need to check that the API is consistent with NVIDIA's. Otherwise LGTM.
