Implement persistent matmul scheduling #3812
Open
jacobhinkle wants to merge 42 commits into main from jh/persistent_kernel_impl
+299 −107
Changes from all commits (42 commits, all by jacobhinkle):

851669a  Split Hopper MMA by warp-tile before instruction tile
8b42cd6  Use 4 warpgroups, disable smem epilogue
7c6d417  Merge branch 'main' into hopper_warptile_split
521d5cc  Use warp_tile for tma_m and tma_n
dce16ad  Two warp tiles per CTA in each dim, increase instr to 64_64_16
f5e084c  Also split by K
be705bf  Add ScheduleWithTranslation test (failing)
9de3202  Merge remote-tracking branch 'origin/main' into hopper_warptile_split
41e2b94  Merge remote-tracking branch 'origin/main' into hopper_warptile_split
e010ead  Merge remote-tracking branch 'origin/main' into hopper_warptile_split
5246fb3  Update to fix compilation
1dccf22  Don't do K split. Fix TMA offset
496d8d7  Merge remote-tracking branch 'origin/main' into hopper_warptile_split
dfa6ff8  Add options for warp specialization and persistence strategy
21c508d  Temporarily revert change to scheduleStMatrixForMmaOutput
174deda  Parametrize MLP Benchmark tests to run three configurations
7868900  Unguard most matmul node translation tests on Hopper
9b5e73c  lintrunner
21a2710  Apply suggestions from code review
f1fff43  Reparametrize and place a big comment explaining
a3b8fd4  Update python bindings
bfd65f3  Add more checks for valid configs
694e0fe  Set warp specialization as default on hopper
794285b  Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
95cf199  Guard MLPBenchmarkTest to Hopper only
ffa276e  Merge remote-tracking branch 'origin/hopper_warptile_split' into jh/p…
86d75de  Merge in from #3642. Add persistent change
6d98405  Add BroadcastInputs tests
438e1a0  Remove debug prints
68c07a0  Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
4d0226c  Fix block parallelization
07c93c6  Override params for horizontal fusion tests
37a7282  Merge commit '9dc94c0' into jh/persistent_kernel_impl
74751b3  Merge commit '3ac19f0' into jh/persistent_kernel_impl
2527dc0  Merge commit 'a1baafa' into jh/persistent_kernel_impl
e4486c8  Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
b0359a2  Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
7f161bc  Uncomment correctness checks in tests
0cbb3e6  Guard failing MLPBenchmarkTest cases on Ampere
dec223e  Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
913ba63  Don't do register sharing for OneTilePerCTA
4ef0339  Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
Diff (as reconstructed from the visible hunk):

```diff
@@ -309,6 +309,10 @@ bool fillDefaultHopperHeuristic(
   mparams->tile_sizes = {cta_tile, warp_tile};

+  // Use warp specialization on hopper by default
+  mparams->circular_buffering_strategy =
+      MatmulParams::CircularBufferingStrategy::WarpSpecialized;
+
   // stages and async mem copy
   mparams->circular_buffer_options.smem_circular_buffer_stage = 8;
```

Review comment (on the warp specialization default):
I thought using warp specialization by default was causing some test failures.

Reply:
Not anymore. I think that was before integrating the warp tile split.
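As a rough mirror of what the diff above does, the following is a minimal, self-contained sketch; `MatmulParams` here is a simplified stand-in for the real nvFuser type, modeling only the fields this hunk touches:

```cpp
#include <cassert>

// Simplified stand-in for nvFuser's MatmulParams; only the fields touched
// by this PR's heuristic change are modeled here.
struct MatmulParams {
  enum class CircularBufferingStrategy { Pipelined, WarpSpecialized };
  CircularBufferingStrategy circular_buffering_strategy =
      CircularBufferingStrategy::Pipelined;
  struct {
    int smem_circular_buffer_stage = 2;
  } circular_buffer_options;
};

// Sketch of the new default behavior set in fillDefaultHopperHeuristic:
// warp-specialized circular buffering with 8 smem stages on Hopper.
void fillDefaultHopperHeuristic(MatmulParams* mparams) {
  // Use warp specialization on hopper by default
  mparams->circular_buffering_strategy =
      MatmulParams::CircularBufferingStrategy::WarpSpecialized;
  // stages and async mem copy
  mparams->circular_buffer_options.smem_circular_buffer_stage = 8;
}
```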
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Comment:
@rdspring1 I had to add this condition because of this check:
Fuser/csrc/device_lower/analysis/circular_buffer.cpp, lines 158 to 165 (at fb2eacd)
Commenting out that check leads nvrtc to hang :-D. This means we cannot use register sharing for persistent kernels yet. Do you think this might work if we lift the warp specialization predicate to the top level (i.e. outside of the persistent loop) during lowering?

Reply:
If you merge cta-M, cta-N, cta-K together into an iterdomain, couldn't you circular buffer this loop? e.g., implement a general persistent stream-k first, then add boundary conditions so you don't have mma tiles split over multiple SMs. Then, you wouldn't need special lowering for persistent matmul and stream-k matmul.
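The merged-iterdomain idea can be sketched with host-side index math. This is a minimal illustration under assumptions: `tileCoord`, `countTilesForCta`, and all parameters are hypothetical names, not nvFuser APIs; it linearizes only the M and N tile dimensions for simplicity.

```cpp
#include <cassert>
#include <utility>

// Linearize the (tile_m, tile_n) grid into one index so that a single
// persistent loop over tiles can be circular buffered.
std::pair<int, int> tileCoord(int linear_tile, int num_tiles_n) {
  return {linear_tile / num_tiles_n, linear_tile % num_tiles_n};
}

// Each CTA walks the linearized tile space with a stride equal to the
// number of SMs, as a grid-stride persistent loop would do on device.
// Here we just count how many tiles a given CTA would process.
int countTilesForCta(int cta_id, int num_tiles_m, int num_tiles_n,
                     int num_sms) {
  int count = 0;
  for (int t = cta_id; t < num_tiles_m * num_tiles_n; t += num_sms) {
    std::pair<int, int> mn = tileCoord(t, num_tiles_n);
    (void)mn;  // a real kernel would compute this (m, n) output tile
    ++count;
  }
  return count;
}
```

With a stream-k extension, the K dimension would be merged into the same linear index, with boundary conditions so no mma tile is split across SMs, as suggested above.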
Reply:
Isn't there a `return` in the `LoadWarp`? NVRTC is likely lost with this kernel.

Reply:
Yeah, you're right. There is a return in that if, so it's not a valid persistent kernel, but ptxas compiles it without warnings about register sharing. And actually I had that wrong: the kernel compiles, but it deadlocks when the dma warps return too early. If I comment out the insertion of the `kir::Return`, the kernel succeeds. So maybe we should just not insert the return if we detect an outer non-trivial loop?

Reply:
The kernel will run without the return, but performance tanks. The compiler will not optimize anything with register sharing enabled.
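The failure mode discussed above can be sketched as follows. This is an illustrative hand-written CUDA-style skeleton, not nvFuser's generated code; the warp layout, predicate, and synchronization details are assumptions made for the sake of the example.

```cuda
// Sketch of why an early return in the load warp breaks a persistent kernel:
// if the kir::Return sits inside the warp-specialized branch, the DMA warps
// exit after their first pass, and later iterations of the persistent loop
// have no producer, so consumer-side barrier waits deadlock.
__global__ void persistentMatmulSketch(int num_tiles) {
  bool is_load_warp = (threadIdx.y == 0);  // assumed specialization layout
  for (int tile = blockIdx.x; tile < num_tiles; tile += gridDim.x) {
    if (is_load_warp) {
      // ... issue async loads for this tile, signal the compute warps ...
      // A `return` here is valid for a one-tile-per-CTA kernel, but in a
      // persistent kernel it would terminate the load warp after the first
      // tile, leaving compute warps blocked on the next tile's wait.
    } else {
      // ... wait for the loads, run MMA, write the epilogue ...
    }
  }
  // If a return is needed at all, it belongs here, outside the
  // persistent loop.
}
```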