- 
                Notifications
    You must be signed in to change notification settings 
- Fork 22
Refactor space-filling curves from examples into common utilities #252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: neoblizz <[email protected]>
Co-authored-by: neoblizz <[email protected]>
        
          
                examples/common/utils.py
              
                Outdated
          
        
      | This function reorders program IDs to improve locality when multiple compute | ||
| dies (XCDs) are present. It ensures that consecutive PIDs are distributed | ||
| across different XCDs before moving to the next set of programs within an XCD. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot Slightly incorrect, it ensures the program IDs are reordered such that you "fill" an XCD with work before going to the next.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed the docstring in commit 2d6f8c4. The description now correctly states that the function "fills an XCD with work before going to the next XCD".
        
          
                examples/common/utils.py
              
                Outdated
          
        
      |  | ||
|  | ||
| @triton.jit | ||
| def apply_xcd_reordering(pid, NUM_XCDS: tl.constexpr, NUM_SMS: tl.constexpr): | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot Better name: chiplet_reorder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to chiplet_reorder in commit 2d6f8c4.
        
          
                examples/common/utils.py
              
                Outdated
          
        
      |  | ||
|  | ||
| @triton.jit | ||
| def compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M: tl.constexpr): | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot Rename program_id_reorder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to program_id_reorder in commit 2d6f8c4.
…reorder Co-authored-by: neoblizz <[email protected]>
Overview
This PR addresses issue #251 by refactoring repeated space-filling curve patterns from multiple example files into reusable utility functions in
examples/common/utils.py.Problem
The same space-filling curve patterns were duplicated across 9+ example files:
This code duplication made maintenance difficult and increased the risk of inconsistencies.
Solution
Created two new Triton JIT utility functions in
examples/common/utils.py:1.
chiplet_reorder(pid, NUM_XCDS, NUM_SMS)Applies XCD (compute die) space-filling curve reordering to program IDs. This reorders program IDs such that you fill an XCD (chiplet) with work before going to the next XCD, improving locality when multiple compute dies are present.
2.
program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M)Computes 2D tile coordinates
(pid_m, pid_n)from lineartile_idusing swizzling. This implements a space-filling curve that groups tiles along the M dimension to improve memory coalescing and cache locality.Both functions include comprehensive docstrings explaining their purpose, parameters, and return values.
Changes
Updated 9 example files to use the new utility functions:
examples/07_gemm_all_scatter/gemm_all_scatter.pyexamples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.pyexamples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.pyexamples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.pyexamples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.pyexamples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.pyexamples/14_all_gather_gemm/all_gather_gemm_push.pyexamples/14_all_gather_gemm/all_gather_gemm_pull.pyexamples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.pyImpact
Testing
Fixes #251
Original prompt
Fixes #251
💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.