From 94a6fc50e037d0c71fadcf2ff313e1fc73c7dabd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 17 Oct 2025 05:27:16 +0000 Subject: [PATCH 1/4] Initial plan From 45eb42be6f9614103174d9b89001b9c2d581d600 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 17 Oct 2025 05:39:43 +0000 Subject: [PATCH 2/4] Refactor space-filling curves into common utilities Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com> --- .../07_gemm_all_scatter/gemm_all_scatter.py | 12 ++--- .../gemm_atomics_all_reduce.py | 14 ++---- .../gemm_one_shot_all_reduce.py | 14 ++---- .../gemm_all_scatter_wg_specialization.py | 19 ++----- .../gemm_all_scatter_producer_consumer.py | 22 ++------- .../gemm_all_scatter_bulk_synchronous.py | 22 ++------- .../all_gather_gemm_pull.py | 11 ++--- .../all_gather_gemm_push.py | 11 ++--- .../gemm_all_scatter_bulk_synchronous.py | 22 ++------- examples/common/utils.py | 49 +++++++++++++++++++ 10 files changed, 83 insertions(+), 113 deletions(-) diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index 8c544fa9..c0abe036 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates import sys import os @@ -47,8 +47,7 @@ def persistent_gemm_all_scatter( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -67,12 +66,7 @@ def persistent_gemm_all_scatter( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py index e692f210..3d241c0a 100644 --- a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py +++ b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates import sys import os @@ -22,15 +22,8 @@ def tile_id_to_index_range( ): num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - - tile_in_group = tile_id % num_pid_in_group - pid_m = first_pid_m + (tile_in_group % group_size_m) - pid_n = tile_in_group // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm_start = pid_m * BLOCK_SIZE_M rn_start = pid_n * BLOCK_SIZE_N @@ -132,8 +125,7 @@ def persistent_gemm_all_reduce( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n diff --git a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py index 915e470b..f34ff557 100644 --- a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py +++ b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates import sys import os @@ -22,15 +22,8 @@ def tile_id_to_index_range( ): num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - - tile_in_group = tile_id % num_pid_in_group - pid_m = first_pid_m + (tile_in_group % group_size_m) - pid_n = tile_in_group // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm_start = pid_m * BLOCK_SIZE_M rn_start = pid_n * BLOCK_SIZE_N @@ -132,8 +125,7 @@ def persistent_gemm_all_reduce( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index ac2d2e35..940ebc07 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates import sys import os @@ -49,8 +49,7 @@ def persistent_gemm_all_scatter_wg_specialization( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -74,12 +73,7 @@ def persistent_gemm_all_scatter_wg_specialization( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -149,12 +143,7 @@ def persistent_gemm_all_scatter_wg_specialization( COMM_SMS = NUM_SMS - GEMM_SMS pid = pid - GEMM_SMS for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) # Begin: See the if segment for explanation: rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M diff --git a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py index a8311943..ee46b3b8 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py +++ b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates import sys import os @@ -45,8 +45,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -65,12 +64,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -159,19 +153,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py index 42961398..fb709fdd 100644 --- a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates import sys import os @@ -44,8 +44,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -64,12 +63,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -155,19 +149,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/14_all_gather_gemm/all_gather_gemm_pull.py b/examples/14_all_gather_gemm/all_gather_gemm_pull.py index c710c8a1..a9d8c78d 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_pull.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_pull.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from examples.common.utils import apply_xcd_reordering, compute_tile_coordinates import torch import iris @@ -35,8 +36,7 @@ def persistent_ag_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -52,12 +52,7 @@ def persistent_ag_gemm( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 for tile_id in range(pid, total_tiles, NUM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/14_all_gather_gemm/all_gather_gemm_push.py b/examples/14_all_gather_gemm/all_gather_gemm_push.py index 7cb4fe4b..d4453da4 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_push.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_push.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from examples.common.utils import apply_xcd_reordering, compute_tile_coordinates import iris import torch @@ -99,8 +100,7 @@ def gemm_push_kernel( world_size: tl.constexpr, ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -121,12 +121,7 @@ def gemm_push_kernel( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 for tile_id in range(pid, total_tiles, NUM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py index 8cb1dbbf..dd47a0cc 100644 --- a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py +++ b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates import sys import os @@ -44,8 +44,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -64,12 +63,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -153,19 +147,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = apply_xcd_reordering(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/common/utils.py b/examples/common/utils.py index 0e6ea948..d57d1883 100644 --- a/examples/common/utils.py +++ b/examples/common/utils.py @@ -167,3 +167,52 @@ def read_realtime(): pack=1, ) return tmp + + +@triton.jit +def apply_xcd_reordering(pid, NUM_XCDS: tl.constexpr, NUM_SMS: tl.constexpr): + """ + Apply XCD (compute die) space-filling curve reordering to program ID. + + This function reorders program IDs to improve locality when multiple compute + dies (XCDs) are present. It ensures that consecutive PIDs are distributed + across different XCDs before moving to the next set of programs within an XCD. + + Args: + pid: The original program ID from tl.program_id(0) + NUM_XCDS: Number of compute dies (XCDs) in the system + NUM_SMS: Total number of streaming multiprocessors + + Returns: + Reordered program ID that optimizes for XCD locality + """ + if NUM_XCDS != 1: + return (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + return pid + + +@triton.jit +def compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M: tl.constexpr): + """ + Compute 2D tile coordinates (pid_m, pid_n) from linear tile_id using swizzling. + + This function implements a space-filling curve that groups tiles along the M + dimension to improve memory coalescing and cache locality. Tiles are organized + into groups of size GROUP_SIZE_M along the M dimension. + + Args: + tile_id: Linear tile index + num_pid_m: Number of tiles in the M dimension + num_pid_n: Number of tiles in the N dimension + GROUP_SIZE_M: Size of tile groups along M dimension for swizzling + + Returns: + Tuple of (pid_m, pid_n) representing the 2D coordinates of the tile + """ + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n From 725bbb232139bfdbb7cd5542ee2c5f2c75d193d2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 17 Oct 2025 05:42:15 +0000 Subject: [PATCH 3/4] Complete refactoring of space-filling curves in all examples Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com> --- .../08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py | 7 +------ .../gemm_one_shot_all_reduce.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py index 3d241c0a..485b7ae1 100644 --- a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py +++ b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py @@ -144,12 +144,7 @@ def persistent_gemm_all_reduce( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py index f34ff557..8b5fe60e 100644 --- a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py +++ b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py @@ -144,12 +144,7 @@ def persistent_gemm_all_reduce( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N From 9b1677b464f3a4361d845784f9109c45a8872d60 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 17 Oct 2025 05:54:39 +0000 Subject: [PATCH 4/4] Rename functions per review feedback: chiplet_reorder and program_id_reorder Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com> --- examples/07_gemm_all_scatter/gemm_all_scatter.py | 6 +++--- .../gemm_atomics_all_reduce.py | 8 ++++---- .../gemm_one_shot_all_reduce.py | 8 ++++---- .../gemm_all_scatter_wg_specialization.py | 8 ++++---- .../gemm_all_scatter_producer_consumer.py | 10 +++++----- .../gemm_all_scatter_bulk_synchronous.py | 10 +++++----- examples/14_all_gather_gemm/all_gather_gemm_pull.py | 6 +++--- examples/14_all_gather_gemm/all_gather_gemm_push.py | 6 +++--- .../gemm_all_scatter_bulk_synchronous.py | 10 +++++----- examples/common/utils.py | 10 +++++----- 10 files changed, 41 insertions(+), 41 deletions(-) diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index c0abe036..4849db9a 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -47,7 +47,7 @@ def persistent_gemm_all_scatter( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -66,7 +66,7 @@ def persistent_gemm_all_scatter( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py index 485b7ae1..2a234574 100644 --- a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py +++ b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -23,7 +23,7 @@ def tile_id_to_index_range( num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm_start = pid_m * BLOCK_SIZE_M rn_start = pid_n * BLOCK_SIZE_N @@ -125,7 +125,7 @@ def persistent_gemm_all_reduce( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -144,7 +144,7 @@ def persistent_gemm_all_reduce( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py index 8b5fe60e..4ee4425c 100644 --- a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py +++ b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -23,7 +23,7 @@ def tile_id_to_index_range( num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm_start = pid_m * BLOCK_SIZE_M rn_start = pid_n * BLOCK_SIZE_N @@ -125,7 +125,7 @@ def persistent_gemm_all_reduce( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -144,7 +144,7 @@ def persistent_gemm_all_reduce( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 940ebc07..dfb4b05c 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -49,7 +49,7 @@ def persistent_gemm_all_scatter_wg_specialization( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -73,7 +73,7 @@ def persistent_gemm_all_scatter_wg_specialization( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -143,7 +143,7 @@ def persistent_gemm_all_scatter_wg_specialization( COMM_SMS = NUM_SMS - GEMM_SMS pid = pid - GEMM_SMS for tile_id in range(pid, total_tiles, COMM_SMS): - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) # Begin: See the if segment for explanation: rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M diff --git a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py index ee46b3b8..ac6da432 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py +++ b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -45,7 +45,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, GEMM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -64,7 +64,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -153,13 +153,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, COMM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py index fb709fdd..d7f9ec38 100644 --- a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -44,7 +44,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, GEMM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -63,7 +63,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -149,13 +149,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, COMM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/14_all_gather_gemm/all_gather_gemm_pull.py b/examples/14_all_gather_gemm/all_gather_gemm_pull.py index a9d8c78d..fc58b5d8 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_pull.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_pull.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from examples.common.utils import apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import chiplet_reorder, program_id_reorder import torch import iris @@ -36,7 +36,7 @@ def persistent_ag_gemm( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -52,7 +52,7 @@ def persistent_ag_gemm( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 for tile_id in range(pid, total_tiles, NUM_SMS): - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/14_all_gather_gemm/all_gather_gemm_push.py b/examples/14_all_gather_gemm/all_gather_gemm_push.py index d4453da4..65685e31 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_push.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_push.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from examples.common.utils import apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import chiplet_reorder, program_id_reorder import iris import torch @@ -100,7 +100,7 @@ def gemm_push_kernel( world_size: tl.constexpr, ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, NUM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -121,7 +121,7 @@ def gemm_push_kernel( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 for tile_id in range(pid, total_tiles, NUM_SMS): - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py index dd47a0cc..447f67f4 100644 --- a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py +++ b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime, apply_xcd_reordering, compute_tile_coordinates +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -44,7 +44,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, GEMM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -63,7 +63,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -147,13 +147,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - pid = apply_xcd_reordering(pid, NUM_XCDS, COMM_SMS) + pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - pid_m, pid_n = compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/common/utils.py b/examples/common/utils.py index d57d1883..bf144711 100644 --- a/examples/common/utils.py +++ b/examples/common/utils.py @@ -170,13 +170,13 @@ def read_realtime(): @triton.jit -def apply_xcd_reordering(pid, NUM_XCDS: tl.constexpr, NUM_SMS: tl.constexpr): +def chiplet_reorder(pid, NUM_XCDS: tl.constexpr, NUM_SMS: tl.constexpr): """ Apply XCD (compute die) space-filling curve reordering to program ID. - This function reorders program IDs to improve locality when multiple compute - dies (XCDs) are present. It ensures that consecutive PIDs are distributed - across different XCDs before moving to the next set of programs within an XCD. + This function reorders program IDs such that you fill an XCD with work + before going to the next XCD, improving locality when multiple compute + dies (chiplets) are present. Args: pid: The original program ID from tl.program_id(0) @@ -192,7 +192,7 @@ def apply_xcd_reordering(pid, NUM_XCDS: tl.constexpr, NUM_SMS: tl.constexpr): @triton.jit -def compute_tile_coordinates(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M: tl.constexpr): +def program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M: tl.constexpr): """ Compute 2D tile coordinates (pid_m, pid_n) from linear tile_id using swizzling.