CUDA: add fused rms norm #14800
Conversation
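For context: RMS norm scales each row by 1/sqrt(mean(x²) + eps), and in transformer blocks it is typically followed immediately by an element-wise multiply with a learned weight tensor, so fusing GGML_OP_RMS_NORM with the following GGML_OP_MUL saves a full extra pass over the activations. A minimal sketch of what such a fused kernel could look like (illustrative only, not this PR's actual code; it assumes one row per thread block and a single-warp block size):

```cpp
// Hypothetical fused RMS-norm + multiply kernel, one row of x per block.
// Assumes block_size <= 32 so a warp shuffle suffices for the reduction.
template <int block_size>
static __global__ void rms_norm_mul_f32(
        const float * x, const float * w, float * dst,
        const int ncols, const float eps) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;

    x   += (size_t) row * ncols;
    dst += (size_t) row * ncols;

    // sum of squares over the row
    float sumsq = 0.0f;
    for (int col = tid; col < ncols; col += block_size) {
        const float v = x[col];
        sumsq += v*v;
    }

    // warp-level reduction of the partial sums
    #pragma unroll
    for (int offset = block_size/2; offset > 0; offset >>= 1) {
        sumsq += __shfl_xor_sync(0xFFFFFFFF, sumsq, offset, block_size);
    }

    const float scale = rsqrtf(sumsq/ncols + eps);

    // fused epilogue: normalize and apply the mul weight in one pass
    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col] * w[col];
    }
}
```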
Sorry, ignore my previous suggestion. I forgot that the kernel modifies the …
Sorry for the confusion; the non-deleted suggestions for changes are how I mean the code to be modified (though my preference is only slight).
I've encountered this error while trying it with Qwen/Qwen3-235B-A22B-Instruct-2507 🥲
…
@exxocism do you have a stack trace? Also, does the problem go away with setting the env variable …?
Some quick performance numbers:
…
I get NaN with …
The code seems to work correctly for …
I just ran …
The PPL value can be fixed with …
Thanks @JohannesGaessler for quickly figuring out the bug! Could you please try again for your PPL values? I could replicate the issue and it seems to be fixed now. Also cc @exxocism if you are willing to give this another try.
@am17an Yes, it works with the env variable …
@am17an Thanks! It works with the new commit. 🎉
In my testing it now also works with the new commit. I think the problem with …
Actually, it looks like the Vulkan fusion operation also does not implement the broadcast; the new test cases are failing for it.
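If I read the failing cases right, the broadcast pattern at issue is along these lines: the MUL operand has fewer dimensions than the normalized tensor and is repeated across the remaining ones, so the fused path has to reproduce the broadcast semantics of the standalone GGML_OP_MUL. A hypothetical example of such a graph (shapes are illustrative):

```cpp
// Illustrative only: w is a single row that broadcasts over the higher
// dimensions of x, exactly as the unfused GGML_OP_MUL would handle it.
struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4096, 512, 4);
struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);
struct ggml_tensor * y = ggml_mul(ctx, ggml_rms_norm(ctx, x, 1e-6f), w);
```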
Compare: 757b81c to ed9f84e
Similar to the Vulkan PR (#14366), perhaps ggml_vk_can_fuse and ggml_cuda_can_fuse can live inside ggml instead of their respective backends, since they don't have backend-specific code. A sketch of that idea follows below.

Decent speedup in PP on my RTX 3090.
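A hedged sketch of what a shared, backend-agnostic check could look like (the function name is hypothetical, and the real ggml_vk_can_fuse/ggml_cuda_can_fuse may check more conditions such as views, strides, and types):

```cpp
#include "ggml.h"

// Hypothetical backend-agnostic fusion check over adjacent graph nodes.
static bool ggml_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int i) {
    if (i + 1 >= cgraph->n_nodes) {
        return false;
    }
    const struct ggml_tensor * norm = cgraph->nodes[i];
    const struct ggml_tensor * mul  = cgraph->nodes[i + 1];

    if (norm->op != GGML_OP_RMS_NORM || mul->op != GGML_OP_MUL) {
        return false;
    }
    // the MUL must consume the norm's output directly
    if (mul->src[0] != norm && mul->src[1] != norm) {
        return false;
    }
    // a complete implementation would also have to verify that no other
    // node consumes the intermediate norm result before it is skipped
    return true;
}
```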