diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index 8c544fa9..4849db9a 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -47,8 +47,7 @@ def persistent_gemm_all_scatter( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -67,12 +66,7 @@ def persistent_gemm_all_scatter( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py index e692f210..2a234574 100644 --- a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py +++ b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -22,15 +22,8 @@ def 
tile_id_to_index_range( ): num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - - tile_in_group = tile_id % num_pid_in_group - pid_m = first_pid_m + (tile_in_group % group_size_m) - pid_n = tile_in_group // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm_start = pid_m * BLOCK_SIZE_M rn_start = pid_n * BLOCK_SIZE_N @@ -132,8 +125,7 @@ def persistent_gemm_all_reduce( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -152,12 +144,7 @@ def persistent_gemm_all_reduce( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py index 915e470b..4ee4425c 100644 --- a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py +++ b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from 
examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -22,15 +22,8 @@ def tile_id_to_index_range( ): num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - - tile_in_group = tile_id % num_pid_in_group - pid_m = first_pid_m + (tile_in_group % group_size_m) - pid_n = tile_in_group // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm_start = pid_m * BLOCK_SIZE_M rn_start = pid_n * BLOCK_SIZE_N @@ -132,8 +125,7 @@ def persistent_gemm_all_reduce( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -152,12 +144,7 @@ def persistent_gemm_all_reduce( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index ac2d2e35..dfb4b05c 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ 
b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -49,8 +49,7 @@ def persistent_gemm_all_scatter_wg_specialization( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -74,12 +73,7 @@ def persistent_gemm_all_scatter_wg_specialization( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -149,12 +143,7 @@ def persistent_gemm_all_scatter_wg_specialization( COMM_SMS = NUM_SMS - GEMM_SMS pid = pid - GEMM_SMS for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) # Begin: See the if segment for explanation: rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M diff --git 
a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py index a8311943..ac6da432 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py +++ b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -45,8 +45,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -65,12 +64,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -159,19 +153,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = 
tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py index 42961398..d7f9ec38 100644 --- a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, chiplet_reorder, program_id_reorder import sys import os @@ -44,8 +44,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -64,12 +63,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -155,19 +149,13 @@ def persistent_all_scatter( ): pid = 
tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/14_all_gather_gemm/all_gather_gemm_pull.py b/examples/14_all_gather_gemm/all_gather_gemm_pull.py index c710c8a1..fc58b5d8 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_pull.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_pull.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from examples.common.utils import chiplet_reorder, program_id_reorder import torch import iris @@ -35,8 +36,7 @@ def persistent_ag_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -52,12 +52,7 @@ def persistent_ag_gemm( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 for tile_id in range(pid, total_tiles, NUM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, 
GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N diff --git a/examples/14_all_gather_gemm/all_gather_gemm_push.py b/examples/14_all_gather_gemm/all_gather_gemm_push.py index 7cb4fe4b..65685e31 100644 --- a/examples/14_all_gather_gemm/all_gather_gemm_push.py +++ b/examples/14_all_gather_gemm/all_gather_gemm_push.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from examples.common.utils import chiplet_reorder, program_id_reorder import iris import torch @@ -99,8 +100,7 @@ def gemm_push_kernel( world_size: tl.constexpr, ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, NUM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -121,12 +121,7 @@ def gemm_push_kernel( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 for tile_id in range(pid, total_tiles, NUM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py index 8cb1dbbf..447f67f4 100644 --- a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py +++ b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from examples.common.utils import read_realtime, 
chiplet_reorder, program_id_reorder import sys import os @@ -44,8 +44,7 @@ def persistent_gemm( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (GEMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, GEMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -64,12 +63,7 @@ def persistent_gemm( timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -153,19 +147,13 @@ def persistent_all_scatter( ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + pid = chiplet_reorder(pid, NUM_XCDS, COMM_SMS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n for tile_id in range(pid, total_tiles, COMM_SMS): - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m, pid_n = program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) diff --git a/examples/common/utils.py b/examples/common/utils.py index 0e6ea948..bf144711 100644 --- a/examples/common/utils.py +++ b/examples/common/utils.py @@ -167,3 
+167,60 @@ def read_realtime():
         pack=1,
     )
     return tmp
+
+
+@triton.jit
+def chiplet_reorder(pid, NUM_XCDS: tl.constexpr, NUM_SMS: tl.constexpr):
+    """
+    Remap a program ID with the XCD (compute die) interleaving formula.
+
+    Programs whose original IDs share the same remainder modulo NUM_XCDS are
+    given a contiguous run of NUM_SMS // NUM_XCDS new IDs. When NUM_XCDS is 1
+    the ID is returned unchanged.
+
+    NOTE(review): presumably the hardware hands consecutive program IDs to
+    XCDs round-robin, and NUM_SMS is a multiple of NUM_XCDS (otherwise some
+    remapped IDs collide) -- confirm against the launch code.
+
+    Args:
+        pid: Program ID to remap
+        NUM_XCDS: Number of XCDs (compile-time constant)
+        NUM_SMS: Count split evenly across the XCDs (compile-time constant)
+
+    Returns:
+        The remapped program ID
+    """
+    remapped = pid
+    if NUM_XCDS != 1:
+        xcd = pid % NUM_XCDS
+        slot = pid // NUM_XCDS
+        remapped = xcd * (NUM_SMS // NUM_XCDS) + slot
+    return remapped
+
+
+@triton.jit
+def program_id_reorder(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M: tl.constexpr):
+    """
+    Map a linear tile index to (pid_m, pid_n) using grouped ordering along M.
+
+    Tiles are taken in groups spanning up to GROUP_SIZE_M tile rows and all
+    num_pid_n tile columns. Inside a group the M index advances fastest; the
+    last group is narrower when num_pid_m is not a multiple of GROUP_SIZE_M.
+
+    Args:
+        tile_id: Linear tile index
+        num_pid_m: Number of tiles along M
+        num_pid_n: Number of tiles along N
+        GROUP_SIZE_M: Maximum tile rows per group (compile-time constant)
+
+    Returns:
+        Tuple (pid_m, pid_n) of tile coordinates
+    """
+    tiles_per_group = GROUP_SIZE_M * num_pid_n
+    group_start_m = (tile_id // tiles_per_group) * GROUP_SIZE_M
+    # The final group may hold fewer than GROUP_SIZE_M rows.
+    rows_in_group = min(num_pid_m - group_start_m, GROUP_SIZE_M)
+    tile_in_group = tile_id % tiles_per_group
+    pid_m = group_start_m + (tile_in_group % rows_in_group)
+    pid_n = tile_in_group // rows_in_group
+    return pid_m, pid_n