[PERF] Decouple projections from GDN custom op #27512
Conversation
Signed-off-by: Vadim Gimpelson <[email protected]>
Code Review
This pull request refactors the Gated Delta Net (GDN) attention mechanism to improve torch.compile compatibility and performance. By decoupling the input/output projections from the core custom operator and introducing a native PyTorch RMSNormGated layer, the changes yield significant decode throughput improvements. The refactoring is well-executed and the code is clear. I have one high-severity suggestion regarding a local import in a performance-critical path, which should be moved to the top level of the module to adhere to best practices and avoid potential overhead.
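For context, here is a minimal sketch of the decoupling pattern the review describes. All names (gdn_core, GDNLayer, in_proj, out_proj) are hypothetical illustrations, not the PR's actual code; the point is that projections kept outside an opaque custom op stay visible to torch.compile, while the custom op itself is reduced to the core computation:

```python
import torch
import torch.nn as nn

# Hypothetical custom op standing in for the GDN core; torch.compile
# treats a custom op as a black box, so nothing inside it can be fused.
@torch.library.custom_op("sketch::gdn_core", mutates_args=())
def gdn_core(mixed_qkv: torch.Tensor) -> torch.Tensor:
    return mixed_qkv.clone()  # placeholder body for the sketch

@gdn_core.register_fake
def _(mixed_qkv: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(mixed_qkv)

class GDNLayer(nn.Module):
    def __init__(self, hidden_size: int, inner_size: int):
        super().__init__()
        # Projections live outside the custom op, so torch.compile can
        # see, fuse, and CUDA-graph them with the surrounding layers.
        self.in_proj = nn.Linear(hidden_size, inner_size, bias=False)
        self.out_proj = nn.Linear(inner_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mixed = self.in_proj(x)         # compiled / fusable
        core_out = gdn_core(mixed)      # opaque op, kept as small as possible
        return self.out_proj(core_out)  # compiled / fusable
```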
💡 Codex Review
Here are some automated review suggestions for this pull request.
CC @heheda12345
Signed-off-by: Vadim Gimpelson <[email protected]>
@ALL
Purpose
This PR is a refactoring of GDN. The main goals are to:
- Allow wider use of torch.compile.
- Add an RMSNormGated class that implements a torch-native gated RMSNorm, and use it for GDN. torch.compile generates good code for RMSNormGated, even better than the custom Triton kernel used before (see the sketch after this list).
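A minimal sketch of what such a torch-native gated RMSNorm can look like, assuming the Mamba-2-style formulation (gate applied via SiLU before normalization); the actual ordering and signature of the PR's RMSNormGated may differ:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormGated(nn.Module):
    """Sketch of a torch-native gated RMSNorm (assumed formulation)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Assumed ordering: apply the gate via SiLU, then RMS-normalize.
        x = x * F.silu(gate)
        # Normalize in fp32 for numerical stability, cast back at the end.
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)
```

Because this is plain elementwise/reduction PyTorch, torch.compile can fuse it into a single kernel, which is the stated advantage over the previous custom Triton kernel.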
Functional Test Result
lm_eval
Before
After
Perf Test Result
Server
Prefill
Before: Total Token throughput (tok/s): 104098.78
After: Total Token throughput (tok/s): 105270.70
Speedup: 1.1%
Decode1
Before: Output token throughput (tok/s): 19212.17
After: Output token throughput (tok/s): 22384.37
Speedup: 16.5%
Decode2
Before: Output token throughput (tok/s): 28821.37
After: Output token throughput (tok/s): 30298.90
Speedup: 5.1%
Decode3
Server (without increasing --max_cudagraph_capture_size)
Before: Output token throughput (tok/s): 16586.93
After: Output token throughput (tok/s): 18953.92
Speedup: 14.3%
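As a quick sanity check, the reported speedups follow directly from the raw throughput numbers above:

```python
# Recompute the reported speedups from the before/after throughputs.
pairs = {
    "Prefill": (104098.78, 105270.70),
    "Decode1": (19212.17, 22384.37),
    "Decode2": (28821.37, 30298.90),
    "Decode3": (16586.93, 18953.92),
}
for name, (before, after) in pairs.items():
    print(f"{name}: {(after / before - 1) * 100:.1f}% speedup")
# Prefill: 1.1%, Decode1: 16.5%, Decode2: 5.1%, Decode3: 14.3%
```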