From eaf96492cad4c4a2d407d064595ed7195891283e Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 00:27:11 -0700
Subject: [PATCH 01/12] add gemm+rs

---
 examples/13_gemm_reduce_scatter/benchmark.py |   0
 .../gemm_reduce_scatter.py                   | 302 ++++++++++++++++++
 .../13_gemm_reduce_scatter/matmul_wrapper.py | 253 +++++++++++++++
 3 files changed, 555 insertions(+)
 create mode 100644 examples/13_gemm_reduce_scatter/benchmark.py
 create mode 100644 examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
 create mode 100644 examples/13_gemm_reduce_scatter/matmul_wrapper.py

diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
new file mode 100644
index 00000000..90e0f841
--- /dev/null
+++ b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
@@ -0,0 +1,302 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+import triton
+import triton.language as tl
+from examples.common.utils import read_realtime
+
+import sys
+import os
+
+import iris
+
+
+@triton.jit
+def tile_id_to_index_range(
+    tile_id,
+    M,
+    N,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+
+    group_id = tile_id // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+
+    tile_in_group = tile_id % num_pid_in_group
+    pid_m = first_pid_m + (tile_in_group % group_size_m)
+    pid_n = tile_in_group // group_size_m
+
+    rm_start = pid_m * BLOCK_SIZE_M
+    rn_start = pid_n * BLOCK_SIZE_N
+
+    max_m = M - 1
+    max_n = N - 1
+
+    rm = rm_start + tl.arange(0, BLOCK_SIZE_M)
+    rn = rn_start + tl.arange(0, BLOCK_SIZE_N)
+
+    rm = tl.minimum(rm, max_m)
+    rn = tl.minimum(rn, max_n)
+
+    return rm, rn, rm_start, rn_start
+
+
+@triton.jit
+def offset_for_tile(local_tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M_local, N_local):
+    rm, rn, rm_start, rn_start = tile_id_to_index_range(
+        local_tile_id, M_local, N_local, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+    c_mask = (rm[:, None] < M_local) & (rn[None, :] < N_local)
+    return rm, rn, c_mask, rm_start, rn_start
+
+
+@triton.jit
+def extract_submask_and_offset(
+    rm,
+    rn,
+    mask,
+    rm_start,
+    rn_start,
+    start_row,
+    start_col,
+    SUB_BLOCK_SIZE_M: tl.constexpr,
+    SUB_BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    stride_cm_local: tl.constexpr,
+    stride_cn_local: tl.constexpr,
+):
+    sub_rm = tl.arange(0, SUB_BLOCK_SIZE_M) + start_row
+    sub_rn = tl.arange(0, SUB_BLOCK_SIZE_N) + start_col
+
+    sub_rm_2d = sub_rm[:, None]
+    sub_rn_2d = sub_rn[None, :]
+
+    sub_mask = (sub_rm_2d < BLOCK_SIZE_M) & (sub_rn_2d < BLOCK_SIZE_N)
+
+    sub_offset = ((rm_start + sub_rm_2d) * stride_cm_local) + ((rn_start + sub_rn_2d) * stride_cn_local)
+
+    return sub_mask, sub_offset
+
+
+@triton.jit
+def compute_output_partition(cur_rank, world_size, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
+    """
+    Compute the output partition this rank is responsible for.
+    ReduceScatter: each rank owns only a slice of the final result.
+    """
+    # Partition by rows (columns or other schemes are also possible)
+    rows_per_rank = tl.cdiv(M, world_size)
+    start_row = cur_rank * rows_per_rank
+    end_row = min((cur_rank + 1) * rows_per_rank, M)
+
+    return start_row, end_row
+
+
+@triton.jit
+def persistent_gemm_reduce_scatter(
+    A,
+    B,
+    C,
+    c_local,  # changed: local output buffer, not a global one
+    bias_ptr,
+    P,
+    locks,
+    tile_completed,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_cm_local,  # changed: stride of the local output
+    stride_cn_local,
+    stride_bias,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    STREAMK_TILES: tl.constexpr,
+    NUM_XCDS: tl.constexpr,
+    BIAS: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    heap_bases: tl.tensor,
+    cur_rank: tl.constexpr,
+    world_size: tl.constexpr,
+    NOTIFY_REMOTES: tl.constexpr = False,
+    COLLECT_TIMESTAMPS: tl.constexpr = False,
+    mm_begin_timestamp_ptr: tl.tensor = None,
+    mm_end_timestamp_ptr: tl.tensor = None,
+):
+    pid = tl.program_id(0)
+
+    if NUM_XCDS != 1:
+        pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    total_tiles = num_pid_m * num_pid_n
+
+    # Compute the output partition this rank is responsible for
+    start_row, end_row = compute_output_partition(cur_rank, world_size, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N)
+
+    tl.assume(stride_am > 0)
+    tl.assume(stride_ak > 0)
+    tl.assume(stride_bn > 0)
+    tl.assume(stride_bk > 0)
+    tl.assume(stride_cm > 0)
+    tl.assume(stride_cn > 0)
+
+    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
+
+    for tile_id in range(pid, total_tiles, NUM_SMS):
+        if COLLECT_TIMESTAMPS:
+            timestamp = read_realtime()
+            tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)
+
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = tile_id // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
+        pid_n = (tile_id % num_pid_in_group) // group_size_m
+
+        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+
+        rk = tl.arange(0, BLOCK_SIZE_K)
+        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
+        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+        A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
+        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
+
+        loop_k = tl.cdiv(K, BLOCK_SIZE_K)
+        if not EVEN_K:
+            loop_k -= 1
+
+        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
+        for k in range(0, loop_k):
+            a = tl.load(tl.multiple_of(A_BASE, (1, 16)))
+            b = tl.load(tl.multiple_of(B_BASE, (16, 1)))
+            acc += tl.dot(a, b)
+            A_BASE += BLOCK_SIZE_K * stride_ak
+            B_BASE += BLOCK_SIZE_K * stride_bk
+
+        if not EVEN_K:
+            k = loop_k
+            rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+            A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
+            B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
+            A_BASE = tl.multiple_of(A_BASE, (1, 16))
+            B_BASE = tl.multiple_of(B_BASE, (16, 1))
+            a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0)
+            b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0)
+            acc += tl.dot(a, b)
+
+        # Store the intermediate result
+        c = acc.to(C.type.element_ty)
+        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
+        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+        c_mask = (rm[:, None] < M) & (rn[None, :] < N)
+        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
+        tl.store(C_, c, c_mask)
+
+        # Signal the other ranks
+        for remote in range(world_size):
+            if remote != cur_rank:
+                iris.atomic_add(
+                    tile_completed + tile_id,
+                    1,
+                    cur_rank,
+                    remote,
+                    heap_bases,
+                    sem="release",
+                    scope="sys",
+                )
+
+        # Wait for all ranks to finish this tile
+        result = 0
+        while result < (world_size - 1):
+            compare = world_size - 1
+            value = 0
+            result = iris.atomic_cas(
+                tile_completed + tile_id,
+                compare,
+                value,
+                cur_rank,
+                cur_rank,
+                heap_bases,
+                sem="acquire",
+                scope="sys",
+            )
+
+        # Key ReduceScatter change: only gather and store the data that belongs to this rank's partition
+        rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N)
+
+        num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M)
+        num_sub_tiles_n = tl.cdiv(BLOCK_SIZE_N, BLOCK_SIZE_N)
+        total_sub_tiles = num_sub_tiles_m * num_sub_tiles_n
+
+        for sub_tile_idx in range(0, total_sub_tiles):
+            start_row_sub = (sub_tile_idx // num_sub_tiles_n) * BLOCK_SIZE_M
+            start_col_sub = (sub_tile_idx % num_sub_tiles_n) * BLOCK_SIZE_N
+
+            sub_mask, sub_offset = extract_submask_and_offset(
+                rm,
+                rn,
+                mask,
+                rm_start,
+                rn_start,
+                start_row_sub,
+                start_col_sub,
+                BLOCK_SIZE_M,
+                BLOCK_SIZE_N,
+                BLOCK_SIZE_M,
+                BLOCK_SIZE_N,
+                stride_cm,
+                stride_cn,
+            )
+
+            # Key change: check whether this sub-tile falls in the region owned by the current rank
+            global_row_start = rm_start + start_row_sub
+
+            # Only process data belonging to this rank's partition
+            if global_row_start < end_row and (global_row_start + BLOCK_SIZE_M) > start_row:
+                # Compute the valid row range within the sub-tile
+                tile_start_row = max(0, start_row - global_row_start)
+                tile_end_row = min(BLOCK_SIZE_M, end_row - global_row_start)
+
+                # Compute the starting row in the local buffer
+                local_start_row = max(global_row_start, start_row) - start_row
+
+                # Build the row mask
+                row_indices = tl.arange(0, BLOCK_SIZE_M)
+                row_mask = (row_indices >= tile_start_row) & (row_indices < tile_end_row)
+                write_mask = row_mask[:, None] & (tl.arange(0, BLOCK_SIZE_N)[None, :] < BLOCK_SIZE_N)
+
+                # Reduce the data from all ranks
+                acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
+                for remote_rank in range(world_size):
+                    remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask)
+                    acc += remote_data
+
+                # Compute the local offset
+                local_offset = (local_start_row * stride_cm_local +
+                                (rn_start + start_col_sub) * stride_cn_local)
+
+                tl.store(c_local + local_offset, acc, mask=write_mask, cache_modifier=".wt")
+
+        if COLLECT_TIMESTAMPS:
+            timestamp = read_realtime()
+            tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)
\ No newline at end of file
diff --git a/examples/13_gemm_reduce_scatter/matmul_wrapper.py b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
new file mode 100644
index 00000000..aaefcaab
--- /dev/null
+++ b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
@@ -0,0 +1,253 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+import triton
+import random
+import sys
+import os
+
+from gemm__reduce_scatter import persistent_gemm_reduce_scatter
+
+from examples.common.utils import is_triton_interpret_set
+
+gemm_kernel = persistent_gemm_reduce_scatter
+
+
+class matmul_reduce_scatter(torch.autograd.Function):
+    _debug = True
+
+    @staticmethod
+    def set_debug(debug: bool):
+        matmul_reduce_scatter._debug = debug
+
+    @staticmethod
+    def _call(
+        a: torch.Tensor,
+        b: torch.Tensor,
+        c: torch.Tensor,
+        c_local: torch.Tensor,  # changed: local output instead of the global output
+        bias: torch.Tensor,
+        P: torch.Tensor,
+        locks: torch.Tensor,
+        tile_completed: torch.Tensor,
+        rank: int,
+        world_size: int,
+        total_programs_streamk: int,
+        BLK_M: int,
+        BLK_N: int,
+        BLK_K: int,
+        gsize_m: int,
+        two_tiles: bool,
+        num_stages: int,
+        num_warps: int,
+        waves_per_eu: int,
+        mfmaInstrSize: int,
+        kpack: int,
+        heap_bases_ptr: torch.Tensor = None,
+        cu_count: int = 304,
+        COLLECT_TIMESTAMPS: bool = False,
+        mm_begin_timestamp: torch.Tensor = None,
+        mm_end_timestamp: torch.Tensor = None,
+    ):
+        assert a.shape[1] == b.shape[0], "incompatible dimensions"
+        M, K = a.shape
+        _, N = b.shape
+
+        # Key change: compute the size of the output partition each rank owns
+        rows_per_rank = (M + world_size - 1) // world_size
+        local_M = rows_per_rank
+        if rank == world_size - 1:  # the last rank handles the remaining rows
+            local_M = M - rank * rows_per_rank
+
+        # Verify that the local output buffer has the correct size
+        assert c_local.shape[0] == local_M, f"c_local shape mismatch: expected {local_M}, got {c_local.shape[0]}"
+        assert c_local.shape[1] == N, f"c_local shape mismatch: expected {N}, got {c_local.shape[1]}"
+
+        num_xcds = 1
+        if cu_count == 304:
+            num_xcds = 8
+
+        total_blocks_M = triton.cdiv(M, BLK_M)
+        total_blocks_N = triton.cdiv(N, BLK_N)
+        iters_per_tile = triton.cdiv(K, BLK_K)
+        total_tiles = total_blocks_M * total_blocks_N
+        even_k = K % BLK_K == 0
+
+        if total_programs_streamk > 0:
+            total_tiles_streamk = total_tiles % total_programs_streamk
+            total_blocking_tiles = total_tiles - total_tiles_streamk
+            total_iters_streamk = total_tiles_streamk * iters_per_tile
+            total_full_tiles_streamk = total_iters_streamk // total_programs_streamk
+            total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk
+        else:
+            total_blocking_tiles = total_tiles
+            total_tiles_streamk = 0
+            total_full_tiles_streamk = 0
+            total_partial_tiles_streamk = 0
+            total_iters_streamk = 0
+
+        if matmul_reduce_scatter._debug:
+            print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}")
+            print(f"Rank {rank}/{world_size} responsible for {local_M} rows")
+            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
+            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
+            print(f"{total_programs_streamk=}")
+
+        use_bias = False
+        stride_bias = bias.stride(0) if use_bias else 0
+
+        # Key change: launch the ReduceScatter kernel
+        grids = total_programs_streamk
+        kk = gemm_kernel[(grids,)](
+            a,
+            b,
+            c,
+            c_local,  # changed: pass the local output buffer
+            bias,
+            P,
+            locks,
+            tile_completed,
+            M,
+            N,
+            K,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            c_local.stride(0),  # changed: stride of the local output
+            c_local.stride(1),
+            stride_bias,
+            BLOCK_SIZE_M=BLK_M,
+            BLOCK_SIZE_N=BLK_N,
+            BLOCK_SIZE_K=BLK_K,
+            GROUP_SIZE_M=gsize_m,
+            NUM_SMS=total_programs_streamk,
+            STREAMK_TILES=total_tiles_streamk,
+            NUM_XCDS=num_xcds,
+            BIAS=use_bias,
+            EVEN_K=even_k,
+            num_stages=num_stages,
+            num_warps=num_warps,
+            waves_per_eu=waves_per_eu,
+            matrix_instr_nonkdim=mfmaInstrSize,
+            kpack=kpack,
+            heap_bases=heap_bases_ptr,
+            cur_rank=rank,
+            world_size=world_size,
+            COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS,
+            mm_begin_timestamp_ptr=mm_begin_timestamp,
+            mm_end_timestamp_ptr=mm_end_timestamp,
+        )
+
+        if matmul_reduce_scatter._debug and not is_triton_interpret_set():
+            matmul_reduce_scatter.streamk_registers = kk.n_regs
+            matmul_reduce_scatter.streamk_spills = kk.n_spills
+            print(f"{kk.n_regs} registers used, {kk.n_spills} spills")
+
+        return c_local  # changed: return the local output
+
+    @staticmethod
+    def forward(
+        ctx,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        c: torch.Tensor,
+        c_local: torch.Tensor,  # changed: local output buffer
+        bias: torch.Tensor,
+        P: torch.Tensor,
+        locks: torch.Tensor,
+        tile_completed: torch.Tensor,
+        rank: int,
+        world_size: int,
+        grid: int,
+        BLK_M=128,
+        BLK_N=128,
+        BLK_K=32,
+        gsize_m=1,
+        two_tiles=True,
+        num_stages=3,
+        num_warps=4,
+        waves_per_eu=2,
+        mfmaInstrSize=16,
+        kpack=1,
+        heap_bases_ptr: torch.Tensor = None,
+        cu_count: int = 304,
+        COLLECT_TIMESTAMPS: bool = False,
+        mm_begin_timestamp: torch.Tensor = None,
+        mm_end_timestamp: torch.Tensor = None,
+    ):
+        result = matmul_reduce_scatter._call(
+            a=a,
+            b=b,
+            c=c,
+            c_local=c_local,  # changed: pass the local output
+            bias=bias,
+            P=P,
+            locks=locks,
+            tile_completed=tile_completed,
+            rank=rank,
+            world_size=world_size,
+            total_programs_streamk=grid,
+            BLK_M=BLK_M,
+            BLK_N=BLK_N,
+            BLK_K=BLK_K,
+            gsize_m=gsize_m,
+            two_tiles=two_tiles,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            waves_per_eu=waves_per_eu,
+            mfmaInstrSize=mfmaInstrSize,
+            kpack=kpack,
+            heap_bases_ptr=heap_bases_ptr,
+            cu_count=cu_count,
+            COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS,
+            mm_begin_timestamp=mm_begin_timestamp,
+            mm_end_timestamp=mm_end_timestamp,
+        )
+        return result  # changed: return the local output result
+
+
+# # Example usage function
+# def gemm_reduce_scatter_example():
+#     # Initialize parameters
+#     M, N, K = 1024, 1024, 1024
+#     world_size = 4
+#     rank = 0  # set according to the current rank in practice
+
+#     # Compute the number of rows each rank owns
+#     rows_per_rank = (M + world_size - 1) // world_size
+#     local_M = rows_per_rank
+#     if rank == world_size - 1:
+#         local_M = M - rank * rows_per_rank
+
+#     # Create input tensors
+#     a = torch.randn(M, K, device='cuda')
+#     b = torch.randn(K, N, device='cuda')
+
+#     # Create the intermediate buffer and the output buffer
+#     c = torch.zeros(M, N, device='cuda')  # intermediate result buffer
+#     c_local = torch.zeros(local_M, N, device='cuda')  # local output buffer
+
+#     # Create the tensors needed for synchronization
+#     total_tiles = (M + 127) // 128 * (N + 127) // 128
+#     tile_completed = torch.zeros(total_tiles, dtype=torch.int32, device='cuda')
+#     locks = torch.zeros(1024, dtype=torch.int32, device='cuda')  # lock buffer
+#     P = torch.zeros(1, dtype=torch.int32, device='cuda')  # placeholder
+
+#     # Call the ReduceScatter GEMM
+#     result = matmul_reduce_scatter.apply(
+#         a, b, c, c_local, None, P, locks, tile_completed,
+#         rank, world_size, 256,  # grid size = 256
+#         128, 128, 32  # BLK_M, BLK_N, BLK_K
+#     )
+
+#     return result
+
+
+# if __name__ == "__main__":
+#     # Test code
+#     result = gemm_reduce_scatter_example()
+#     print(f"ReduceScatter result shape: {result.shape}")
\ No newline at end of file

From 10deb648433236d8f4407614de28ab034e944896 Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 00:28:37 -0700
Subject: [PATCH 02/12] add gemm+rs

---
 .../13_gemm_reduce_scatter/matmul_wrapper.py | 66 +++++++++----------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/examples/13_gemm_reduce_scatter/matmul_wrapper.py b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
index aaefcaab..2ab909ed 100644
--- a/examples/13_gemm_reduce_scatter/matmul_wrapper.py
+++ b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
@@ -210,44 +210,44 @@ def forward(
     return result  # changed: return the local output result
 
 
-# # Example usage function
-# def gemm_reduce_scatter_example():
-#     # Initialize parameters
-#     M, N, K = 1024, 1024, 1024
-#     world_size = 4
-#     rank = 0  # set according to the current rank in practice
+# Example usage function
+def gemm_reduce_scatter_example():
+    # Initialize parameters
+    M, N, K = 1024, 1024, 1024
+    world_size = 4
+    rank = 0  # set according to the current rank in practice
 
-#     # Compute the number of rows each rank owns
-#     rows_per_rank = (M + world_size - 1) // world_size
-#     local_M = rows_per_rank
-#     if rank == world_size - 1:
-#         local_M = M - rank * rows_per_rank
+    # Compute the number of rows each rank owns
+    rows_per_rank = (M + world_size - 1) // world_size
+    local_M = rows_per_rank
+    if rank == world_size - 1:
+        local_M = M - rank * rows_per_rank
 
-#     # Create input tensors
-#     a = torch.randn(M, K, device='cuda')
-#     b = torch.randn(K, N, device='cuda')
+    # Create input tensors
+    a = torch.randn(M, K, device='cuda')
+    b = torch.randn(K, N, device='cuda')
 
-#     # Create the intermediate buffer and the output buffer
-#     c = torch.zeros(M, N, device='cuda')  # intermediate result buffer
-#     c_local = torch.zeros(local_M, N, device='cuda')  # local output buffer
+    # Create the intermediate buffer and the output buffer
+    c = torch.zeros(M, N, device='cuda')  # intermediate result buffer
+    c_local = torch.zeros(local_M, N, device='cuda')  # local output buffer
 
-#     # Create the tensors needed for synchronization
-#     total_tiles = (M + 127) // 128 * (N + 127) // 128
-#     tile_completed = torch.zeros(total_tiles, dtype=torch.int32, device='cuda')
-#     locks = torch.zeros(1024, dtype=torch.int32, device='cuda')  # lock buffer
-#     P = torch.zeros(1, dtype=torch.int32, device='cuda')  # placeholder
+    # Create the tensors needed for synchronization
+    total_tiles = (M + 127) // 128 * (N + 127) // 128
+    tile_completed = torch.zeros(total_tiles, dtype=torch.int32, device='cuda')
+    locks = torch.zeros(1024, dtype=torch.int32, device='cuda')  # lock buffer
+    P = torch.zeros(1, dtype=torch.int32, device='cuda')  # placeholder
 
-#     # Call the ReduceScatter GEMM
-#     result = matmul_reduce_scatter.apply(
-#         a, b, c, c_local, None, P, locks, tile_completed,
-#         rank, world_size, 256,  # grid size = 256
-#         128, 128, 32  # BLK_M, BLK_N, BLK_K
-#     )
+    # Call the ReduceScatter GEMM
+    result = matmul_reduce_scatter.apply(
+        a, b, c, c_local, None, P, locks, tile_completed,
+        rank, world_size, 256,  # grid size = 256
+        128, 128, 32  # BLK_M, BLK_N, BLK_K
+    )
 
-#     return result
+    return result
 
 
-# if __name__ == "__main__":
-#     # Test code
-#     result = gemm_reduce_scatter_example()
-#     print(f"ReduceScatter result shape: {result.shape}")
\ No newline at end of file
+if __name__ == "__main__":
+    # Test code
+    result = gemm_reduce_scatter_example()
+    print(f"ReduceScatter result shape: {result.shape}")
\ No newline at end of file

From 2b85cb53dbdea76c30eb758dcdb338be8d1c8d2f Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 01:19:56 -0700
Subject: [PATCH 03/12] add benchmark and validation

---
 examples/13_gemm_reduce_scatter/benchmark.py | 302 ++++++++++++++++++
 .../13_gemm_reduce_scatter/matmul_wrapper.py |  64 ++--
 examples/common/validation.py                |  14 +
 3 files changed, 348 insertions(+), 32 deletions(-)

diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py
index e69de29b..817da76a 100644
--- a/examples/13_gemm_reduce_scatter/benchmark.py
+++ b/examples/13_gemm_reduce_scatter/benchmark.py
@@ -0,0 +1,302 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+ +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json + +from examples.common.utils import ( + JSONWriter, + Timestamps, + is_triton_interpret_set, +) + +import iris + +from matmul_wrapper import matmul_reduce_scatter +from examples.common.validation import validate_gemm, validate_gemm_reduce_scatter + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse matrix dimensions and configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A") + parser.add_argument("-n", type=int, default=4608, help="Number of columns in matrix B") + parser.add_argument("-k", type=int, default=36864, help="Common dimension between matrices A and B") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + # For All Scatter, use: 256x64x64 + # For One Shot, use: 256x256x64 + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=64, help="Block size N") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") + + # Best to try 1, 6 or 8 + parser.add_argument("--gsize_m", type=int, default=6, help="Grid size M") + parser.add_argument("--two_tiles", type=str, default="True", help="Use two tiles") + parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") + parser.add_argument("--num_warps", type=int, default=8, help="Number of warps") + parser.add_argument("--waves_per_eu", type=int, default=0, help="Waves per execution unit") + parser.add_argument("--mfmaInstrSize", type=int, default=16, help="MFMA instruction size") + parser.add_argument("--kpack", type=int, default=2, help="K packing size") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + + parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") + parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("-r", "--num_ranks", type=int, default=4, help="Number of ranks/processes") + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + + # GEMM + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "int8": + datatype = torch.int8 + elif args["datatype"] 
== "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})." + assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})." + + A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype) + B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T + C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) + + args["M"] = args["m"] + args["N"] = args["n"] + args["K"] = args["k"] + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + # Splitting + rows_per_gpu = args["k"] // world_size + args["k"] = rows_per_gpu + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + for key, value in args.items(): + json_writer.add_field(key, value) + + # global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype) + global_C = None + local_C = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype) + + total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) + total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + if args["gemm_sms"] >= args["total_sms"]: + print(f"Invalid number of GEMM SMs. {args['gemm_sms']} >= {args['total_sms']}") + exit(1) + + tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + + locks = shmem.zeros((args["gemm_sms"],), device="cuda", dtype=torch.int32) + + P = shmem.zeros( + (args["gemm_sms"], args["BLK_M"] * args["BLK_N"]), + device="cuda", + dtype=torch.float32, + ) + bias = None + + gemm_stream = torch.cuda.Stream() + + json_writer.add_field("gemm_sms", args["gemm_sms"]) + + kernel_timing = { + "gemm": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + } + } + + # Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def preamble(): + shmem.barrier() + tile_completed.zero_() + shmem.barrier() + + def run_experiment(): + nonlocal local_C + nonlocal global_C + nonlocal kernel_timing + + shmem.barrier() + + if args["trace_tiles"]: + timestamps.reset() + shmem.barrier() + + torch.cuda.nvtx.range_push("GEMM + Communication") + with torch.cuda.stream(gemm_stream): + kernel_timing["gemm"]["start_event"].record() + local_C = matmul_reduce_scatter.apply( + local_A, + local_B, + local_C, + local_C, + bias, + P, + locks, + tile_completed, + rank, + world_size, + args["gemm_sms"], + args["BLK_M"], + args["BLK_N"], + args["BLK_K"], + args["gsize_m"], + args["two_tiles"], + args["num_stages"], + args["num_warps"], + args["waves_per_eu"], + args["mfmaInstrSize"], + args["kpack"], + shmem.get_heap_bases(), + cu_count, + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["gemm"]["end_event"].record() + kernel_timing["gemm"]["experiments"] += 1 + + torch.cuda.nvtx.range_pop() + shmem.barrier() + + for k in ["gemm"]: + ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) + kernel_timing[k]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + # Warmup + run_experiment() + + shmem.barrier() + preamble() + shmem.barrier() + + for k in ["gemm"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if not is_triton_interpret_set(): 
+        gemm_registers = matmul_reduce_scatter.streamk_registers
+        gemm_spills = matmul_reduce_scatter.streamk_spills
+
+        json_writer.add_field("gemm_registers", gemm_registers)
+        json_writer.add_field("gemm_spills", gemm_spills)
+
+    if args["validate"]:
+        shmem.info("Validating...")
+
+        matmul_reduce_scatter.set_debug(False)
+        # Validate global result
+        success = validate_gemm_reduce_scatter(A, B, local_C, rank, world_size, shmem, atol=2)
+        passed_str = "passed" if success else "failed"
+        shmem.info(f"Final C validation {passed_str}.")
+
+        # Wait for all to finish validation
+        shmem.barrier()
+        json_writer.add_field("success", success)
+        shmem.info("Validation completed")
+
+    if args["benchmark"]:
+        shmem.info("Benchmarking...")
+        perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
+        triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
+        triton_tflops = perf(triton_ms)
+        shmem.info(f"tile matmul + reduce_scatter (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")
+
+        json_writer.add_field("triton_tflops", triton_tflops)
+        json_writer.add_field("triton_ms", triton_ms)
+
+        for k in ["gemm"]:
+            json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"])
+            json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"])
+
+        # Wait for all to finish benchmarking
+        shmem.barrier()
+
+    if rank == 0:
+        json_writer.flush()
+        json_writer.display()
+
+    if args["trace_tiles"] and rank == 0:
+        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
+        filename = f"gemm_reduce_scatter_tiles_trace_rank{rank}.json"
+        timestamps.to_json(filename, gpu_freq)
+
+    shmem.barrier()
+    dist.destroy_process_group()
+
+
+def main():
+    args = parse_args()
+
+    num_ranks = args["num_ranks"]
+
+    init_url = "tcp://127.0.0.1:29500"
+    mp.spawn(
+        fn=_worker,
+        args=(num_ranks, init_url, args),
+        nprocs=num_ranks,
+        join=True,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/13_gemm_reduce_scatter/matmul_wrapper.py b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
index 2ab909ed..5cbc6ac5 100644
--- a/examples/13_gemm_reduce_scatter/matmul_wrapper.py
+++ b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
@@ -211,43 +211,43 @@ def forward(
     return result  # changed: return the local output result
 
 
 # Example usage function
-def gemm_reduce_scatter_example():
-    # Initialize parameters
-    M, N, K = 1024, 1024, 1024
-    world_size = 4
-    rank = 0  # set according to the current rank in practice
+# def gemm_reduce_scatter_example():
+#     # Initialize parameters
+#     M, N, K = 1024, 1024, 1024
+#     world_size = 4
+#     rank = 0  # set according to the current rank in practice
 
-    # Compute the number of rows each rank owns
-    rows_per_rank = (M + world_size - 1) // world_size
-    local_M = rows_per_rank
-    if rank == world_size - 1:
-        local_M = M - rank * rows_per_rank
+#     # Compute the number of rows each rank owns
+#     rows_per_rank = (M + world_size - 1) // world_size
+#     local_M = rows_per_rank
+#     if rank == world_size - 1:
+#         local_M = M - rank * rows_per_rank
 
-    # Create input tensors
-    a = torch.randn(M, K, device='cuda')
-    b = torch.randn(K, N, device='cuda')
+#     # Create input tensors
+#     a = torch.randn(M, K, device='cuda')
+#     b = torch.randn(K, N, device='cuda')
 
-    # Create the intermediate buffer and the output buffer
-    c = torch.zeros(M, N, device='cuda')  # intermediate result buffer
-    c_local = torch.zeros(local_M, N, device='cuda')  # local output buffer
+#     # Create the intermediate buffer and the output buffer
+#     c = torch.zeros(M, N, device='cuda')  # intermediate result buffer
+#     c_local = torch.zeros(local_M, N, device='cuda')  # local output buffer
 
-    # Create the tensors needed for synchronization
-    total_tiles = (M + 127) // 128 * (N + 127) // 128
-    tile_completed = torch.zeros(total_tiles, dtype=torch.int32, device='cuda')
-    locks = torch.zeros(1024, dtype=torch.int32, device='cuda')  # lock buffer
-    P = torch.zeros(1, dtype=torch.int32, device='cuda')  # placeholder
+#     # Create the tensors needed for synchronization
+#     total_tiles = (M + 127) // 128 * (N + 127) // 128
+#     tile_completed = torch.zeros(total_tiles, dtype=torch.int32, device='cuda')
+#     locks = torch.zeros(1024, dtype=torch.int32, device='cuda')  # lock buffer
+#     P = torch.zeros(1, dtype=torch.int32, device='cuda')  # placeholder
 
-    # Call the ReduceScatter GEMM
-    result = matmul_reduce_scatter.apply(
-        a, b, c, c_local, None, P, locks, tile_completed,
-        rank, world_size, 256,  # grid size = 256
-        128, 128, 32  # BLK_M, BLK_N, BLK_K
-    )
+#     # Call the ReduceScatter GEMM
+#     result = matmul_reduce_scatter.apply(
+#         a, b, c, c_local, None, P, locks, tile_completed,
+#         rank, world_size, 256,  # grid size = 256
+#         128, 128, 32  # BLK_M, BLK_N, BLK_K
+#     )
 
-    return result
+#     return result
 
 
-if __name__ == "__main__":
-    # Test code
-    result = gemm_reduce_scatter_example()
-    print(f"ReduceScatter result shape: {result.shape}")
\ No newline at end of file
+# if __name__ == "__main__":
+#     # Test code
+#     result = gemm_reduce_scatter_example()
+#     print(f"ReduceScatter result shape: {result.shape}")
\ No newline at end of file
diff --git a/examples/common/validation.py b/examples/common/validation.py
index dfe513aa..6fbadd0f 100644
--- a/examples/common/validation.py
+++ b/examples/common/validation.py
@@ -21,3 +21,17 @@ def validate_gemm(A, B, C, shmem, atol=1):
             return False
 
     return True
+
+def validate_gemm_reduce_scatter(A, B, local_C, rank, world_size, shmem, atol=1):
+    # Compute the full ground truth
+    full_result = torch.mm(A, B)
+
+    # Compute the slice this rank is responsible for
+    rows_per_gpu = A.shape[0] // world_size
+    start_row = rank * rows_per_gpu
+    end_row = start_row + local_C.shape[0]
+
+    expected_local = full_result[start_row:end_row, :]
+
+    # Validate the local result
+    return torch.allclose(local_C, expected_local, atol=atol)

From a7d51d236c1cfec7e26db7fc4c96ada42ec39e54 Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 01:32:02 -0700
Subject: [PATCH 04/12] adjust acc buffer to full size

---
 examples/13_gemm_reduce_scatter/benchmark.py | 16 ++++++++++------
 .../13_gemm_reduce_scatter/matmul_wrapper.py |  2 +-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py
index 817da76a..56f3c057 100644
--- a/examples/13_gemm_reduce_scatter/benchmark.py
+++ b/examples/13_gemm_reduce_scatter/benchmark.py
@@ -98,7 +98,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
         print("Unknown datatype.")
         exit(1)
 
-    assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})."
+    assert args["m"] % world_size == 0, f"M ({args['m']}) must be divisible by world size ({world_size})."
     assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})."
 
     A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype)
@@ -125,8 +125,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
 
     # global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype)
     global_C = None
-    local_C = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype)
-
+    # local_C = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype)
+    # Intermediate compute buffer (full size)
+    compute_buffer = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)
+    # Local output buffer (partition size)
+    local_output = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype)
+
     total_blocks_M = triton.cdiv(args["m"], args["BLK_M"])
     total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
     total_tiles = total_blocks_M * total_blocks_N
diff --git a/examples/13_gemm_reduce_scatter/matmul_wrapper.py b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
index 5cbc6ac5..2da8d40a 100644
--- a/examples/13_gemm_reduce_scatter/matmul_wrapper.py
+++ b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
@@ -7,7 +7,7 @@
 import sys
 import os
 
-from gemm__reduce_scatter import persistent_gemm_reduce_scatter
+from gemm_reduce_scatter import persistent_gemm_reduce_scatter
 
 from examples.common.utils import is_triton_interpret_set
 

From ec623a8bd598cc6e1e0f7250a816161461a5f7b3 Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 09:32:42 +0000
Subject: [PATCH 05/12] functionality pass

---
 examples/13_gemm_reduce_scatter/benchmark.py |  4 +-
 .../gemm_reduce_scatter.py                   | 48 ++++++++++++-------
 2 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py
index 56f3c057..220d32de 100644
--- a/examples/13_gemm_reduce_scatter/benchmark.py
+++ b/examples/13_gemm_reduce_scatter/benchmark.py
@@ -172,8 +172,8 @@ def preamble():
         shmem.barrier()
 
     def run_experiment():
-        nonlocal local_C
-        nonlocal global_C
+        nonlocal local_output
+        nonlocal compute_buffer
         nonlocal kernel_timing
 
         shmem.barrier()
diff --git a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
index 90e0f841..55e2c576 100644
--- a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
+++ b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
@@ -209,6 +209,7 @@ def persistent_gemm_reduce_scatter(
         rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
         rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
         c_mask = (rm[:, None] < M) & (rn[None, :] < N)
+        # Both C and C_ are full-size buffers
         C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
         tl.store(C_, c, c_mask)
 
@@ -271,32 +272,47 @@ def persistent_gemm_reduce_scatter(
             # Key change: check whether this sub-tile falls in the region owned by the current rank
             global_row_start = rm_start + start_row_sub
 
-            # Only process data belonging to this rank's partition
+            # Use block pointers in the store phase
             if global_row_start < end_row and (global_row_start + BLOCK_SIZE_M) > start_row:
-                # Compute the valid row range within the sub-tile
                 tile_start_row = max(0, start_row - global_row_start)
                 tile_end_row = min(BLOCK_SIZE_M, end_row - global_row_start)
-
-                # Compute the starting row in the local buffer
                 local_start_row = max(global_row_start, start_row) - start_row
 
-                # Build the row mask
-                row_indices = tl.arange(0, BLOCK_SIZE_M)
-                row_mask = (row_indices >= tile_start_row) & (row_indices < tile_end_row)
-                write_mask = row_mask[:, None] & (tl.arange(0, BLOCK_SIZE_N)[None, :] < BLOCK_SIZE_N)
-
-                # Reduce the data from all ranks
+                # Reduce the data
                 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
                 for remote_rank in range(world_size):
                     remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask)
                     acc += remote_data
 
-                # Compute the local offset
-                local_offset = (local_start_row * stride_cm_local +
-                                (rn_start + start_col_sub) * stride_cn_local)
+                # Build the block pointer
+                row_idx = tl.arange(0, BLOCK_SIZE_M)
+                col_idx = tl.arange(0, BLOCK_SIZE_N)
+
+                # Compute the local offset of each element
+                local_offsets = (local_start_row + row_idx[:, None]) * stride_cm_local + \
+                                (rn_start + start_col_sub + col_idx[None, :]) * stride_cn_local
+
+                # Build the block pointer
+                local_ptr_block = c_local + local_offsets
 
-                tl.store(c_local + local_offset, acc, mask=write_mask, cache_modifier=".wt")
-
+                # Build the write mask
+                valid_mask = (row_idx[:, None] >= tile_start_row) & \
+                             (row_idx[:, None] < tile_end_row) & \
+                             (col_idx[None, :] < BLOCK_SIZE_N) & \
+                             sub_mask
+
+                # Store the block data
+                tl.store(local_ptr_block, acc, mask=valid_mask, cache_modifier=".wt") # for remote_rank in range(world_size):
+                # # C is a full-size buffer
+                # remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask)
+                # acc += remote_data
+
+                # # Compute the local offset
+                # local_offset = (local_start_row * stride_cm_local +
+                #                 (rn_start + start_col_sub) * stride_cn_local)
+
+                # tl.store(c_local + local_offset, acc, mask=write_mask, cache_modifier=".wt")
+
         if COLLECT_TIMESTAMPS:
             timestamp = read_realtime()
-            tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)
\ No newline at end of file
+            tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

From 79415e65924f121bd3715f9ef27a954d727ac45f Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 09:43:23 +0000
Subject: [PATCH 06/12] adapt benchmark

---
 examples/13_gemm_reduce_scatter/benchmark.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py
index 220d32de..d7031da9 100644
--- a/examples/13_gemm_reduce_scatter/benchmark.py
+++ b/examples/13_gemm_reduce_scatter/benchmark.py
@@ -249,7 +249,7 @@ def run_experiment():
 
         matmul_reduce_scatter.set_debug(False)
         # Validate global result
-        success = validate_gemm_reduce_scatter(A, B, local_C, rank, world_size, shmem, atol=2)
+        success = validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
         passed_str = "passed" if success else "failed"
         shmem.info(f"Final C validation {passed_str}.")
 

From e2e79b0a0d758f6f61cbc13ad14bc71291eb47d9 Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 02:59:48 -0700
Subject: [PATCH 07/12] clean

---
 .../gemm_reduce_scatter.py                   | 32 +---------
 .../13_gemm_reduce_scatter/matmul_wrapper.py | 62 +++----------------
 examples/common/validation.py                |  3 -
 3 files changed, 10 insertions(+), 87 deletions(-)

diff --git a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
index 55e2c576..75f067fa 100644
--- a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
+++ b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py
@@ -87,11 +87,6 @@ def extract_submask_and_offset(
 
 @triton.jit
 def compute_output_partition(cur_rank, world_size, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
-    """
-    Compute the output partition this rank is responsible for.
-    ReduceScatter: each rank owns only a slice of the final result.
-    """
-    # Partition by rows (columns or other schemes are also possible)
     rows_per_rank = tl.cdiv(M, world_size)
     start_row = cur_rank * rows_per_rank
     end_row = min((cur_rank + 1) * rows_per_rank, M)
@@ -104,7 +99,7 @@ def persistent_gemm_reduce_scatter(
     A,
     B,
     C,
-    c_local,  # changed: local output buffer, not a global one
+    c_local,
     bias_ptr,
     P,
     locks,
@@ -118,7 +113,7 @@ def persistent_gemm_reduce_scatter(
     stride_bn,
     stride_cm,
     stride_cn,
-    stride_cm_local,  # changed: stride of the local output
+    stride_cm_local,
     stride_cn_local,
     stride_bias,
     BLOCK_SIZE_M: tl.constexpr,
@@ -146,7 +141,6 @@ def persistent_gemm_reduce_scatter(
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     total_tiles = num_pid_m * num_pid_n
 
-    # Compute the output partition this rank is responsible for
     start_row, end_row = compute_output_partition(cur_rank, world_size, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N)
 
     tl.assume(stride_am > 0)
@@ -202,18 +196,15 @@ def persistent_gemm_reduce_scatter(
             b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0)
             acc += tl.dot(a, b)
 
-        # Store the intermediate result
         c = acc.to(C.type.element_ty)
         rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
         rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
         rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
         rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
         c_mask = (rm[:, None] < M) & (rn[None, :] < N)
-        # Both C and C_ are full-size buffers
        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
         tl.store(C_, c, c_mask)
 
-        # Signal the other ranks
         for remote in range(world_size):
             if remote != cur_rank:
                 iris.atomic_add(
@@ -226,7 +217,6 @@ def persistent_gemm_reduce_scatter(
                     scope="sys",
                 )
 
-        # Wait for all ranks to finish this tile
         result = 0
         while result < (world_size - 1):
             compare = world_size - 1
@@ -242,7 +232,6 @@ def persistent_gemm_reduce_scatter(
                 scope="sys",
             )
 
-        # Key ReduceScatter change: only gather and store the data that belongs to this rank's partition
         rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N)
 
         num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M)
@@ -269,49 +258,32 @@ def persistent_gemm_reduce_scatter(
                 stride_cn,
             )
 
-            # Key change: check whether this sub-tile falls in the region owned by the current rank
             global_row_start = rm_start + start_row_sub
 
-            # Use block pointers in the store phase
             if global_row_start < end_row and (global_row_start + BLOCK_SIZE_M) > start_row:
                 tile_start_row = max(0, start_row - global_row_start)
                 tile_end_row = min(BLOCK_SIZE_M, end_row - global_row_start)
                 local_start_row = max(global_row_start, start_row) - start_row
 
-                # Reduce the data
                 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
                 for remote_rank in range(world_size):
                     remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask)
                     acc += remote_data
 
-                # Build the block pointer
                 row_idx = tl.arange(0, BLOCK_SIZE_M)
                 col_idx = tl.arange(0, BLOCK_SIZE_N)
 
-                # Compute the local offset of each element
                 local_offsets = (local_start_row + row_idx[:, None]) * stride_cm_local + \
                                 (rn_start + start_col_sub + col_idx[None, :]) * stride_cn_local
 
-                # Build the block pointer
                 local_ptr_block = c_local + local_offsets
 
-                # Build the write mask
                 valid_mask = (row_idx[:, None] >= tile_start_row) & \
                              (row_idx[:, None] < tile_end_row) & \
                              (col_idx[None, :] < BLOCK_SIZE_N) & \
                              sub_mask
 
-                # Store the block data
                 tl.store(local_ptr_block, acc, mask=valid_mask, cache_modifier=".wt") # for remote_rank in range(world_size):
-                # # C is a full-size buffer
-                # remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask)
-                # acc += remote_data
-
-                # # Compute the local offset
-                # local_offset = (local_start_row * stride_cm_local +
-                #                 (rn_start + start_col_sub) * stride_cn_local)
-
-                # tl.store(c_local + local_offset, acc, mask=write_mask, cache_modifier=".wt")
-
         if COLLECT_TIMESTAMPS:
             timestamp = read_realtime()
diff --git a/examples/13_gemm_reduce_scatter/matmul_wrapper.py b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
index 2da8d40a..78503d7d 100644
--- a/examples/13_gemm_reduce_scatter/matmul_wrapper.py
+++ b/examples/13_gemm_reduce_scatter/matmul_wrapper.py
@@ -26,7 +26,7 @@ def _call(
         a: torch.Tensor,
         b: torch.Tensor,
         c: torch.Tensor,
-        c_local: torch.Tensor,  # changed: local output instead of the global output
+        c_local: torch.Tensor,
         bias: torch.Tensor,
         P: torch.Tensor,
         locks: torch.Tensor,
@@ -54,13 +54,11 @@ def _call(
         M, K = a.shape
         _, N = b.shape
 
-        # Key change: compute the size of the output partition each rank owns
         rows_per_rank = (M + world_size - 1) // world_size
         local_M = rows_per_rank
-        if rank == world_size - 1:  # the last rank handles the remaining rows
+        if rank == world_size - 1:
             local_M = M - rank * rows_per_rank
 
-        # Verify that the local output buffer has the correct size
         assert c_local.shape[0] == local_M, f"c_local shape mismatch: expected {local_M}, got {c_local.shape[0]}"
         assert c_local.shape[1] == N, f"c_local shape mismatch: expected {N}, got {c_local.shape[1]}"
 
@@ -97,13 +95,12 @@ def _call(
         use_bias = False
         stride_bias = bias.stride(0) if use_bias else 0
 
-        # Key change: launch the ReduceScatter kernel
         grids = total_programs_streamk
         kk = gemm_kernel[(grids,)](
             a,
             b,
             c,
-            c_local,  # changed: pass the local output buffer
+            c_local,
             bias,
             P,
             locks,
@@ -117,7 +114,7 @@ def _call(
             b.stride(1),
             c.stride(0),
             c.stride(1),
-            c_local.stride(0),  # changed: stride of the local output
+            c_local.stride(0),
             c_local.stride(1),
             stride_bias,
             BLOCK_SIZE_M=BLK_M,
@@ -147,7 +144,7 @@ def _call(
             matmul_reduce_scatter.streamk_spills = kk.n_spills
             print(f"{kk.n_regs} registers used, {kk.n_spills} spills")
 
-        return c_local  # changed: return the local output
+        return c_local
 
     @staticmethod
     def forward(
@@ -155,7 +152,7 @@ def forward(
         a: torch.Tensor,
         b: torch.Tensor,
         c: torch.Tensor,
-        c_local: torch.Tensor,  # changed: local output buffer
+        c_local: torch.Tensor,
         bias: torch.Tensor,
         P: torch.Tensor,
         locks: torch.Tensor,
@@ -183,7 +180,7 @@ def forward(
             a=a,
             b=b,
             c=c,
-            c_local=c_local,  # changed: pass the local output
+            c_local=c_local,
             bias=bias,
             P=P,
             locks=locks,
@@ -207,47 +204,4 @@ def forward(
             mm_begin_timestamp=mm_begin_timestamp,
             mm_end_timestamp=mm_end_timestamp,
         )
-        return result  # changed: return the local output result
-
-
-# Example usage function
-# def gemm_reduce_scatter_example():
-#     # Initialize parameters
-#     M, N, K = 1024, 1024, 1024
-#     world_size = 4
-#     rank = 0  # set according to the current rank in practice
-
-#     # Compute the number of rows each rank owns
-#     rows_per_rank = (M + world_size - 1) // world_size
-#     local_M = rows_per_rank
-#     if rank == world_size - 1:
-#         local_M = M - rank * rows_per_rank
-
-#     # Create input tensors
-#     a = torch.randn(M, K, device='cuda')
-#     b = torch.randn(K, N, device='cuda')
-
-#     # Create the intermediate buffer and the output buffer
-#     c = torch.zeros(M, N, device='cuda')  # intermediate result buffer
-#     c_local = torch.zeros(local_M, N, device='cuda')  # local output buffer
-
-#     # Create the tensors needed for synchronization
-#     total_tiles = (M + 127) // 128 * (N + 127) // 128
-#     tile_completed = torch.zeros(total_tiles, dtype=torch.int32, device='cuda')
-#     locks = torch.zeros(1024, dtype=torch.int32, device='cuda')  # lock buffer
-#     P = torch.zeros(1, dtype=torch.int32, device='cuda')  # placeholder
-
-#     # Call the ReduceScatter GEMM
-#     result = matmul_reduce_scatter.apply(
-#         a, b, c, c_local, None, P, locks, tile_completed,
-#         rank, world_size, 256,  # grid size = 256
-#         128, 128, 32  # BLK_M, BLK_N, BLK_K
-#     )
-
-#     return result
-
-
-# if __name__ == "__main__":
-#     # Test code
-#     result = gemm_reduce_scatter_example()
-#     print(f"ReduceScatter result shape: {result.shape}")
\ No newline at end of file
+        return result
diff --git a/examples/common/validation.py b/examples/common/validation.py
index 6fbadd0f..d48cd1a5 100644
--- a/examples/common/validation.py
+++ b/examples/common/validation.py
@@ -23,15 +23,12 @@ def validate_gemm(A, B, C, shmem, atol=1):
 
     return True
 
 def validate_gemm_reduce_scatter(A, B, local_C, rank, world_size, shmem, atol=1):
-    # Compute the full ground truth
     full_result = torch.mm(A, B)
 
-    # Compute the slice this rank is responsible for
     rows_per_gpu = A.shape[0] // world_size
     start_row = rank * rows_per_gpu
     end_row = start_row + local_C.shape[0]
 
     expected_local = full_result[start_row:end_row, :]
 
-    # Validate the local result
     return torch.allclose(local_C, expected_local, atol=atol)

From b005c25d16db387c329758ce1ab462f4e5224f8e Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 03:02:47 -0700
Subject: [PATCH 08/12] clean

---
 examples/13_gemm_reduce_scatter/benchmark.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py
index d7031da9..f9682409 100644
--- a/examples/13_gemm_reduce_scatter/benchmark.py
+++ b/examples/13_gemm_reduce_scatter/benchmark.py
@@ -123,12 +123,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     for key, value in args.items():
         json_writer.add_field(key, value)
 
-    # global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype)
     global_C = None
-    # local_C = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype)
-    # Intermediate compute buffer (full size)
     compute_buffer = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)
-    # Local output buffer (partition size)
     local_output = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype)
 
From 837126043126507b91e60c53185012cc8c4e6a62 Mon Sep 17 00:00:00 2001
From: danielhua23
Date: Wed, 17 Sep 2025 07:28:46 -0700
Subject: [PATCH 09/12] add test of gemm rs

---
 .../test_gemm_reduce_scatter_bench.py         | 163 ++++++++++++++++++
 1 file changed, 163 insertions(+)
 create mode 100644 tests/examples/test_gemm_reduce_scatter_bench.py

diff --git a/tests/examples/test_gemm_reduce_scatter_bench.py b/tests/examples/test_gemm_reduce_scatter_bench.py
new file mode 100644
index 00000000..bb4bcb53
--- /dev/null
+++ b/tests/examples/test_gemm_reduce_scatter_bench.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+ +import pytest +import torch +import triton +import triton.language as tl +import numpy as np +import iris + +import importlib.util +from pathlib import Path +from examples.common.utils import ( + Timestamps, +) +current_dir = Path(__file__).parent +import sys +sys.path.append(str(current_dir / "../../examples/13_gemm_reduce_scatter/")) +sys.path.append(str(current_dir / "../../")) +# Import the matmul wrapper +matmul_path = (current_dir / "../../examples/13_gemm_reduce_scatter/matmul_wrapper.py").resolve() +matmul_spec = importlib.util.spec_from_file_location("matmul_wrapper", matmul_path) +matmul_module = importlib.util.module_from_spec(matmul_spec) +matmul_spec.loader.exec_module(matmul_module) + +# Import the validation function +validation_path = (current_dir / "../../examples/common/validation.py").resolve() +validation_spec = importlib.util.spec_from_file_location("validation", validation_path) +validation_module = importlib.util.module_from_spec(validation_spec) +validation_spec.loader.exec_module(validation_module) + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "m, n, k", + [ + (64, 64, 64), # Very small for quick testing + (128, 128, 128), # Small + (256, 256, 256), # Medium + ], +) +@pytest.mark.parametrize( + "BLK_M, BLK_N, BLK_K", + [ + (32, 32, 16), # Small blocks + (64, 64, 32), # Medium blocks + ], +) +def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K): + """Worker function for PyTorch distributed execution.""" + heap_size = 1 << 30 + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + + # GEMM + datatype = dtype + + assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})." + assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})." 
+ + A = shmem.randn(m, k, device="cuda", dtype=datatype) + B = shmem.randn(n, k, device="cuda", dtype=datatype).T + C = shmem.zeros((m, n), device="cuda", dtype=A.dtype) + + M = m + N = n + K = k + + # Splitting + rows_per_gpu = k // world_size + k = rows_per_gpu + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype) + local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype) + + total_blocks_M = triton.cdiv(m, BLK_M) + total_blocks_N = triton.cdiv(n, BLK_N) + total_tiles = total_blocks_M * total_blocks_N + + tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + + locks = shmem.zeros((288,), device="cuda", dtype=torch.int32) + P = shmem.zeros( + (288, BLK_M * BLK_N), + device="cuda", + dtype=torch.float32, + ) + bias = None + gemm_stream = torch.cuda.Stream() + timestamps = Timestamps(num_tiles=total_tiles) + + def preamble(): + shmem.barrier() + tile_completed.zero_() + shmem.barrier() + + def run_experiment(): + nonlocal local_output + nonlocal compute_buffer + + shmem.barrier() + + torch.cuda.nvtx.range_push("GEMM + Communication") + with torch.cuda.stream(gemm_stream): + local_output = matmul_module.matmul_reduce_scatter.apply( + local_A, + local_B, + compute_buffer, + local_output, + bias, + P, + locks, + tile_completed, + rank, + world_size, + 288, + BLK_M, + BLK_N, + BLK_K, + 6, + True, + 1, + 8, + 0, + 16, + 2, + shmem.get_heap_bases(), + cu_count, + False, + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + torch.cuda.nvtx.range_pop() + shmem.barrier() + + # Synchronize across all GPUs + shmem.barrier() + run_experiment() + shmem.barrier() + preamble() + shmem.barrier() + + shmem.info("Validating...") + + matmul_reduce_scatter.set_debug(False) + # Validate global result + success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2) + assert success, ( + f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}" + ) From 19a920815184bd2bde3c46062f736c4bee254923 Mon Sep 17 00:00:00 2001 From: danielhua23 Date: Wed, 17 Sep 2025 07:33:52 -0700 Subject: [PATCH 10/12] remove redundant words --- examples/13_gemm_reduce_scatter/benchmark.py | 4 +--- tests/examples/test_gemm_reduce_scatter_bench.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py index f9682409..a70dd351 100644 --- a/examples/13_gemm_reduce_scatter/benchmark.py +++ b/examples/13_gemm_reduce_scatter/benchmark.py @@ -52,8 +52,7 @@ def parse_args(): default="log.json", help="Output file", ) - # For All Scatter, use: 256x64x64 - # For One Shot, use: 256x256x64 + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M") parser.add_argument("--BLK_N", type=int, default=64, help="Block size N") parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") @@ -123,7 +122,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): for key, value in args.items(): json_writer.add_field(key, value) - global_C = None compute_buffer = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) local_output = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype) diff --git 
a/tests/examples/test_gemm_reduce_scatter_bench.py b/tests/examples/test_gemm_reduce_scatter_bench.py index bb4bcb53..25ccf852 100644 --- a/tests/examples/test_gemm_reduce_scatter_bench.py +++ b/tests/examples/test_gemm_reduce_scatter_bench.py @@ -155,7 +155,7 @@ def run_experiment(): shmem.info("Validating...") - matmul_reduce_scatter.set_debug(False) + matmul_module.matmul_reduce_scatter.set_debug(False) # Validate global result success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2) assert success, ( From 5f4efcec3194a2b21a0ab4315e704fa3e2f38523 Mon Sep 17 00:00:00 2001 From: danielhua23 Date: Wed, 17 Sep 2025 20:19:57 -0700 Subject: [PATCH 11/12] correct ut --- tests/examples/test_gemm_reduce_scatter_bench.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/examples/test_gemm_reduce_scatter_bench.py b/tests/examples/test_gemm_reduce_scatter_bench.py index 25ccf852..ac0ad4d0 100644 --- a/tests/examples/test_gemm_reduce_scatter_bench.py +++ b/tests/examples/test_gemm_reduce_scatter_bench.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +import sys import pytest import torch import triton @@ -15,7 +16,6 @@ Timestamps, ) current_dir = Path(__file__).parent -import sys sys.path.append(str(current_dir / "../../examples/13_gemm_reduce_scatter/")) sys.path.append(str(current_dir / "../../")) # Import the matmul wrapper @@ -41,15 +41,15 @@ @pytest.mark.parametrize( "m, n, k", [ - (64, 64, 64), # Very small for quick testing - (128, 128, 128), # Small - (256, 256, 256), # Medium + (512, 512, 512), # Very small for quick testing + (1024, 1024, 1024), # Small + (2048, 2048, 2048), # Medium ], ) @pytest.mark.parametrize( "BLK_M, BLK_N, BLK_K", [ - (32, 32, 16), # Small blocks + (32, 32, 32), # Small blocks (64, 64, 32), # Medium blocks ], ) From 18ea56276e6481f9ae3ce827d6407295e2a232a0 Mon Sep 17 00:00:00 2001 From: danielhua23 Date: Fri, 19 Sep 2025 01:44:11 -0700 Subject: [PATCH 12/12] disable RS when rank=1 --- .../gemm_reduce_scatter.py | 155 +++++++++--------- 1 file changed, 79 insertions(+), 76 deletions(-) diff --git a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py index 75f067fa..cde5c715 100644 --- a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py +++ b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py @@ -202,88 +202,91 @@ def persistent_gemm_reduce_scatter( rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_mask = (rm[:, None] < M) & (rn[None, :] < N) - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C_, c, c_mask) - - for remote in range(world_size): - if remote != cur_rank: - iris.atomic_add( + if world_size == 1: + C_ = c_local + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, c, c_mask) + else: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, c, c_mask) + for remote in range(world_size): + if remote != cur_rank: + iris.atomic_add( + tile_completed + tile_id, + 1, + cur_rank, + remote, + heap_bases, + sem="release", + scope="sys", + ) + + result = 0 + while result < (world_size - 1): + compare = world_size - 1 + value = 0 + result = iris.atomic_cas( tile_completed + tile_id, - 1, + compare, + value, + cur_rank, cur_rank, - remote, heap_bases, - sem="release", + sem="acquire", scope="sys", ) - 
result = 0 - while result < (world_size - 1): - compare = world_size - 1 - value = 0 - result = iris.atomic_cas( - tile_completed + tile_id, - compare, - value, - cur_rank, - cur_rank, - heap_bases, - sem="acquire", - scope="sys", - ) - - rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N) - - num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M) - num_sub_tiles_n = tl.cdiv(BLOCK_SIZE_N, BLOCK_SIZE_N) - total_sub_tiles = num_sub_tiles_m * num_sub_tiles_n - - for sub_tile_idx in range(0, total_sub_tiles): - start_row_sub = (sub_tile_idx // num_sub_tiles_n) * BLOCK_SIZE_M - start_col_sub = (sub_tile_idx % num_sub_tiles_n) * BLOCK_SIZE_N - - sub_mask, sub_offset = extract_submask_and_offset( - rm, - rn, - mask, - rm_start, - rn_start, - start_row_sub, - start_col_sub, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - stride_cm, - stride_cn, - ) - - global_row_start = rm_start + start_row_sub - - if global_row_start < end_row and (global_row_start + BLOCK_SIZE_M) > start_row: - tile_start_row = max(0, start_row - global_row_start) - tile_end_row = min(BLOCK_SIZE_M, end_row - global_row_start) - local_start_row = max(global_row_start, start_row) - start_row - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for remote_rank in range(world_size): - remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask) - acc += remote_data - - row_idx = tl.arange(0, BLOCK_SIZE_M) - col_idx = tl.arange(0, BLOCK_SIZE_N) - - local_offsets = (local_start_row + row_idx[:, None]) * stride_cm_local + \ - (rn_start + start_col_sub + col_idx[None, :]) * stride_cn_local - - local_ptr_block = c_local + local_offsets - - valid_mask = (row_idx[:, None] >= tile_start_row) & \ - (row_idx[:, None] < tile_end_row) & \ - (col_idx[None, :] < BLOCK_SIZE_N) & \ - sub_mask - - tl.store(local_ptr_block, acc, mask=valid_mask, cache_modifier=".wt") # for remote_rank in range(world_size): + rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N) + + num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M) + num_sub_tiles_n = tl.cdiv(BLOCK_SIZE_N, BLOCK_SIZE_N) + total_sub_tiles = num_sub_tiles_m * num_sub_tiles_n + + for sub_tile_idx in range(0, total_sub_tiles): + start_row_sub = (sub_tile_idx // num_sub_tiles_n) * BLOCK_SIZE_M + start_col_sub = (sub_tile_idx % num_sub_tiles_n) * BLOCK_SIZE_N + + sub_mask, sub_offset = extract_submask_and_offset( + rm, + rn, + mask, + rm_start, + rn_start, + start_row_sub, + start_col_sub, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + stride_cm, + stride_cn, + ) + + global_row_start = rm_start + start_row_sub + + if global_row_start < end_row and (global_row_start + BLOCK_SIZE_M) > start_row: + tile_start_row = max(0, start_row - global_row_start) + tile_end_row = min(BLOCK_SIZE_M, end_row - global_row_start) + local_start_row = max(global_row_start, start_row) - start_row + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for remote_rank in range(world_size): + remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask) + acc += remote_data + + row_idx = tl.arange(0, BLOCK_SIZE_M) + col_idx = tl.arange(0, BLOCK_SIZE_N) + + local_offsets = (local_start_row + row_idx[:, None]) * stride_cm_local + \ + (rn_start + start_col_sub + col_idx[None, :]) * stride_cn_local + + local_ptr_block = c_local + local_offsets + + valid_mask = (row_idx[:, None] >= tile_start_row) & \ + 
(row_idx[:, None] < tile_end_row) & \ + (col_idx[None, :] < BLOCK_SIZE_N) & \ + sub_mask + + tl.store(local_ptr_block, acc, mask=valid_mask, cache_modifier=".wt") if COLLECT_TIMESTAMPS: timestamp = read_realtime()
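
Reviewer note (not part of the patches above): the end-to-end semantics that persistent_gemm_reduce_scatter implements — each rank multiplies its K-slice of A by its K-slice of B, the M x N partial products are summed across ranks, and each rank keeps only its row block of the final C — can be cross-checked against stock PyTorch collectives. Below is a minimal host-side sketch under the same row partitioning that validate_gemm_reduce_scatter checks; it is illustrative only: gemm_reduce_scatter_reference is a hypothetical helper, not code from this PR, and it assumes torch.distributed is already initialized with one GPU per rank (as benchmark.py's _worker does) and that M is divisible by world_size.

    import torch
    import torch.distributed as dist

    def gemm_reduce_scatter_reference(local_A, local_B, world_size):
        # local_A is the M x (K // world_size) column slice of A and local_B the
        # matching (K // world_size) x N row slice of B, exactly as benchmark.py
        # splits them, so the local matmul yields a full M x N partial product.
        partial_C = torch.matmul(local_A, local_B)

        # Reduce-scatter sums the partial products across ranks and leaves each
        # rank holding its M // world_size row block of the final C, mirroring
        # the kernel's compute_output_partition row ownership.
        M, N = partial_C.shape
        local_C = torch.empty(M // world_size, N,
                              device=partial_C.device, dtype=partial_C.dtype)
        dist.reduce_scatter_tensor(local_C, partial_C.contiguous(),
                                   op=dist.ReduceOp.SUM)
        return local_C

Rank r then holds rows [r * M // world_size, (r + 1) * M // world_size) of C, which is the same slice that validate_gemm_reduce_scatter compares against torch.mm(A, B).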