diff --git a/examples/13_gemm_reduce_scatter/benchmark.py b/examples/13_gemm_reduce_scatter/benchmark.py new file mode 100644 index 00000000..a70dd351 --- /dev/null +++ b/examples/13_gemm_reduce_scatter/benchmark.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json + +from examples.common.utils import ( + JSONWriter, + Timestamps, + is_triton_interpret_set, +) + +import iris + +from matmul_wrapper import matmul_reduce_scatter +from examples.common.validation import validate_gemm, validate_gemm_reduce_scatter + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse matrix dimensions and configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A") + parser.add_argument("-n", type=int, default=4608, help="Number of columns in matrix B") + parser.add_argument("-k", type=int, default=36864, help="Common dimension between matrices A and B") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=64, help="Block size N") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") + + # Best to try 1, 6 or 8 + parser.add_argument("--gsize_m", type=int, default=6, help="Grid size M") + parser.add_argument("--two_tiles", type=str, default="True", help="Use two tiles") + parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") + parser.add_argument("--num_warps", type=int, default=8, help="Number of warps") + parser.add_argument("--waves_per_eu", type=int, default=0, help="Waves per execution unit") + parser.add_argument("--mfmaInstrSize", type=int, default=16, help="MFMA instruction size") + parser.add_argument("--kpack", type=int, default=2, help="K packing size") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + + parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") + parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("-r", "--num_ranks", type=int, default=4, help="Number of ranks/processes") + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = 
shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + + # GEMM + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "int8": + datatype = torch.int8 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + assert args["m"] % world_size == 0, f"M ({args['m']}) must be divisible by world size ({world_size})." + assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})." + + A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype) + B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T + C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) + + args["M"] = args["m"] + args["N"] = args["n"] + args["K"] = args["k"] + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + # Splitting + rows_per_gpu = args["k"] // world_size + args["k"] = rows_per_gpu + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + for key, value in args.items(): + json_writer.add_field(key, value) + + compute_buffer = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) + local_output = shmem.zeros((args["m"] // world_size, args["n"]), device="cuda", dtype=A.dtype) + + total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) + total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + if args["gemm_sms"] >= args["total_sms"]: + print(f"Invalid number of GEMM SMs. {args['gemm_sms']} >= {args['total_sms']}") + exit(1) + + tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + + locks = shmem.zeros((args["gemm_sms"],), device="cuda", dtype=torch.int32) + + P = shmem.zeros( + (args["gemm_sms"], args["BLK_M"] * args["BLK_N"]), + device="cuda", + dtype=torch.float32, + ) + bias = None + + gemm_stream = torch.cuda.Stream() + + json_writer.add_field("gemm_sms", args["gemm_sms"]) + + kernel_timing = { + "gemm": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + } + } + + # Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def preamble(): + shmem.barrier() + tile_completed.zero_() + shmem.barrier() + + def run_experiment(): + nonlocal local_output + nonlocal compute_buffer + nonlocal kernel_timing + + shmem.barrier() + + if args["trace_tiles"]: + timestamps.reset() + shmem.barrier() + + torch.cuda.nvtx.range_push("GEMM + Communication") + with torch.cuda.stream(gemm_stream): + kernel_timing["gemm"]["start_event"].record() + local_output = matmul_reduce_scatter.apply( + local_A, + local_B, + compute_buffer, + local_output, + bias, + P, + locks, + tile_completed, + rank, + world_size, + args["gemm_sms"], + args["BLK_M"], + args["BLK_N"], + args["BLK_K"], + args["gsize_m"], + args["two_tiles"], + args["num_stages"], + args["num_warps"], + args["waves_per_eu"], + args["mfmaInstrSize"], + args["kpack"], + shmem.get_heap_bases(), + cu_count, + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["gemm"]["end_event"].record() + kernel_timing["gemm"]["experiments"] += 1 + + torch.cuda.nvtx.range_pop() + shmem.barrier() + + for k in ["gemm"]: + ms = 
kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) + kernel_timing[k]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + # Warmup + run_experiment() + + shmem.barrier() + preamble() + shmem.barrier() + + for k in ["gemm"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if not is_triton_interpret_set(): + gemm_registers = matmul_reduce_scatter.streamk_registers + gemm_spills = matmul_reduce_scatter.streamk_spills + + json_writer.add_field("gemm_registers", gemm_registers) + json_writer.add_field("gemm_spills", gemm_spills) + + if args["validate"]: + shmem.info("Validating...") + + matmul_reduce_scatter.set_debug(False) + # Validate global result + success = validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2) + passed_str = "passed" if success else "failed" + shmem.info(f"Final C validation {passed_str}.") + + # Wait for all to finish validation + shmem.barrier() + json_writer.add_field("success", success) + shmem.info("Validation completed") + + if args["benchmark"]: + shmem.info("Benchmarking...") + perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) + triton_tflops = perf(triton_ms) + shmem.info(f"tile matmul + reduce_scatter (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops") + + json_writer.add_field("triton_tflops", triton_tflops) + json_writer.add_field("triton_ms", triton_ms) + + for k in ["gemm"]: + json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + if args["trace_tiles"] and rank == 0: + gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 + filename = f"gemm_reduce_scatter_tiles_trace_rank{rank}.json" + timestamps.to_json(filename, gpu_freq) + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py new file mode 100644 index 00000000..cde5c715 --- /dev/null +++ b/examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+ +import triton +import triton.language as tl +from examples.common.utils import read_realtime + +import sys +import os + +import iris + + +@triton.jit +def tile_id_to_index_range( + tile_id, + M, + N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + + tile_in_group = tile_id % num_pid_in_group + pid_m = first_pid_m + (tile_in_group % group_size_m) + pid_n = tile_in_group // group_size_m + + rm_start = pid_m * BLOCK_SIZE_M + rn_start = pid_n * BLOCK_SIZE_N + + max_m = M - 1 + max_n = N - 1 + + rm = rm_start + tl.arange(0, BLOCK_SIZE_M) + rn = rn_start + tl.arange(0, BLOCK_SIZE_N) + + rm = tl.minimum(rm, max_m) + rn = tl.minimum(rn, max_n) + + return rm, rn, rm_start, rn_start + + +@triton.jit +def offset_for_tile(local_tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M_local, N_local): + rm, rn, rm_start, rn_start = tile_id_to_index_range( + local_tile_id, M_local, N_local, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + c_mask = (rm[:, None] < M_local) & (rn[None, :] < N_local) + return rm, rn, c_mask, rm_start, rn_start + + +@triton.jit +def extract_submask_and_offset( + rm, + rn, + mask, + rm_start, + rn_start, + start_row, + start_col, + SUB_BLOCK_SIZE_M: tl.constexpr, + SUB_BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + stride_cm_local: tl.constexpr, + stride_cn_local: tl.constexpr, +): + sub_rm = tl.arange(0, SUB_BLOCK_SIZE_M) + start_row + sub_rn = tl.arange(0, SUB_BLOCK_SIZE_N) + start_col + + sub_rm_2d = sub_rm[:, None] + sub_rn_2d = sub_rn[None, :] + + sub_mask = (sub_rm_2d < BLOCK_SIZE_M) & (sub_rn_2d < BLOCK_SIZE_N) + + sub_offset = ((rm_start + sub_rm_2d) * stride_cm_local) + ((rn_start + sub_rn_2d) * stride_cn_local) + + return sub_mask, sub_offset + + +@triton.jit +def compute_output_partition(cur_rank, world_size, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + rows_per_rank = tl.cdiv(M, world_size) + start_row = cur_rank * rows_per_rank + end_row = min((cur_rank + 1) * rows_per_rank, M) + + return start_row, end_row + + +@triton.jit +def persistent_gemm_reduce_scatter( + A, + B, + C, + c_local, + bias_ptr, + P, + locks, + tile_completed, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_cm_local, + stride_cn_local, + stride_bias, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + STREAMK_TILES: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + NOTIFY_REMOTES: tl.constexpr = False, + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + start_row, end_row = compute_output_partition(cur_rank, world_size, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N) + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + 
tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, NUM_SMS): + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + a = tl.load(tl.multiple_of(A_BASE, (1, 16))) + b = tl.load(tl.multiple_of(B_BASE, (16, 1))) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b) + + c = acc.to(C.type.element_ty) + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + if world_size == 1: + C_ = c_local + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, c, c_mask) + else: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, c, c_mask) + for remote in range(world_size): + if remote != cur_rank: + iris.atomic_add( + tile_completed + tile_id, + 1, + cur_rank, + remote, + heap_bases, + sem="release", + scope="sys", + ) + + result = 0 + while result < (world_size - 1): + compare = world_size - 1 + value = 0 + result = iris.atomic_cas( + tile_completed + tile_id, + compare, + value, + cur_rank, + cur_rank, + heap_bases, + sem="acquire", + scope="sys", + ) + + rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N) + + num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M) + num_sub_tiles_n = tl.cdiv(BLOCK_SIZE_N, BLOCK_SIZE_N) + total_sub_tiles = num_sub_tiles_m * num_sub_tiles_n + + for sub_tile_idx in range(0, total_sub_tiles): + start_row_sub = (sub_tile_idx // num_sub_tiles_n) * BLOCK_SIZE_M + start_col_sub = (sub_tile_idx % num_sub_tiles_n) * BLOCK_SIZE_N + + sub_mask, sub_offset = extract_submask_and_offset( + rm, + rn, + mask, + rm_start, + rn_start, + start_row_sub, + start_col_sub, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + stride_cm, + stride_cn, + ) + + global_row_start = rm_start + 
start_row_sub + + if global_row_start < end_row and (global_row_start + BLOCK_SIZE_M) > start_row: + tile_start_row = max(0, start_row - global_row_start) + tile_end_row = min(BLOCK_SIZE_M, end_row - global_row_start) + local_start_row = max(global_row_start, start_row) - start_row + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for remote_rank in range(world_size): + remote_data = iris.load(C + sub_offset, cur_rank, remote_rank, heap_bases, mask=sub_mask) + acc += remote_data + + row_idx = tl.arange(0, BLOCK_SIZE_M) + col_idx = tl.arange(0, BLOCK_SIZE_N) + + local_offsets = (local_start_row + row_idx[:, None]) * stride_cm_local + \ + (rn_start + start_col_sub + col_idx[None, :]) * stride_cn_local + + local_ptr_block = c_local + local_offsets + + valid_mask = (row_idx[:, None] >= tile_start_row) & \ + (row_idx[:, None] < tile_end_row) & \ + (col_idx[None, :] < BLOCK_SIZE_N) & \ + sub_mask + + tl.store(local_ptr_block, acc, mask=valid_mask, cache_modifier=".wt") + + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) diff --git a/examples/13_gemm_reduce_scatter/matmul_wrapper.py b/examples/13_gemm_reduce_scatter/matmul_wrapper.py new file mode 100644 index 00000000..78503d7d --- /dev/null +++ b/examples/13_gemm_reduce_scatter/matmul_wrapper.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import random +import sys +import os + +from gemm_reduce_scatter import persistent_gemm_reduce_scatter + +from examples.common.utils import is_triton_interpret_set + +gemm_kernel = persistent_gemm_reduce_scatter + + +class matmul_reduce_scatter(torch.autograd.Function): + _debug = True + + @staticmethod + def set_debug(debug: bool): + matmul_reduce_scatter._debug = debug + + @staticmethod + def _call( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + c_local: torch.Tensor, + bias: torch.Tensor, + P: torch.Tensor, + locks: torch.Tensor, + tile_completed: torch.Tensor, + rank: int, + world_size: int, + total_programs_streamk: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + two_tiles: bool, + num_stages: int, + num_warps: int, + waves_per_eu: int, + mfmaInstrSize: int, + kpack: int, + heap_bases_ptr: torch.Tensor = None, + cu_count: int = 304, + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + rows_per_rank = (M + world_size - 1) // world_size + local_M = rows_per_rank + if rank == world_size - 1: + local_M = M - rank * rows_per_rank + + assert c_local.shape[0] == local_M, f"c_local shape mismatch: expected {local_M}, got {c_local.shape[0]}" + assert c_local.shape[1] == N, f"c_local shape mismatch: expected {N}, got {c_local.shape[1]}" + + num_xcds = 1 + if cu_count == 304: + num_xcds = 8 + + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + total_tiles = total_blocks_M * total_blocks_N + even_k = K % BLK_K == 0 + + if total_programs_streamk > 0: + total_tiles_streamk = total_tiles % total_programs_streamk + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + total_partial_tiles_streamk = total_iters_streamk % 
total_programs_streamk + else: + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul_reduce_scatter._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"Rank {rank}/{world_size} responsible for {local_M} rows") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + + use_bias = False + stride_bias = bias.stride(0) if use_bias else 0 + + grids = total_programs_streamk + kk = gemm_kernel[(grids,)]( + a, + b, + c, + c_local, + bias, + P, + locks, + tile_completed, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + c_local.stride(0), + c_local.stride(1), + stride_bias, + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + NUM_SMS=total_programs_streamk, + STREAMK_TILES=total_tiles_streamk, + NUM_XCDS=num_xcds, + BIAS=use_bias, + EVEN_K=even_k, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + heap_bases=heap_bases_ptr, + cur_rank=rank, + world_size=world_size, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp_ptr=mm_begin_timestamp, + mm_end_timestamp_ptr=mm_end_timestamp, + ) + + if matmul_reduce_scatter._debug and not is_triton_interpret_set(): + matmul_reduce_scatter.streamk_registers = kk.n_regs + matmul_reduce_scatter.streamk_spills = kk.n_spills + print(f"{kk.n_regs} registers used, {kk.n_spills} spills") + + return c_local + + @staticmethod + def forward( + ctx, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + c_local: torch.Tensor, + bias: torch.Tensor, + P: torch.Tensor, + locks: torch.Tensor, + tile_completed: torch.Tensor, + rank: int, + world_size: int, + grid: int, + BLK_M=128, + BLK_N=128, + BLK_K=32, + gsize_m=1, + two_tiles=True, + num_stages=3, + num_warps=4, + waves_per_eu=2, + mfmaInstrSize=16, + kpack=1, + heap_bases_ptr: torch.Tensor = None, + cu_count: int = 304, + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + result = matmul_reduce_scatter._call( + a=a, + b=b, + c=c, + c_local=c_local, + bias=bias, + P=P, + locks=locks, + tile_completed=tile_completed, + rank=rank, + world_size=world_size, + total_programs_streamk=grid, + BLK_M=BLK_M, + BLK_N=BLK_N, + BLK_K=BLK_K, + gsize_m=gsize_m, + two_tiles=two_tiles, + num_warps=num_warps, + num_stages=num_stages, + waves_per_eu=waves_per_eu, + mfmaInstrSize=mfmaInstrSize, + kpack=kpack, + heap_bases_ptr=heap_bases_ptr, + cu_count=cu_count, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp=mm_begin_timestamp, + mm_end_timestamp=mm_end_timestamp, + ) + return result diff --git a/examples/common/validation.py b/examples/common/validation.py index dfe513aa..d48cd1a5 100644 --- a/examples/common/validation.py +++ b/examples/common/validation.py @@ -21,3 +21,14 @@ def validate_gemm(A, B, C, shmem, atol=1): return False return True + +def validate_gemm_reduce_scatter(A, B, local_C, rank, world_size, shmem, atol=1): + full_result = torch.mm(A, B) + + rows_per_gpu = A.shape[0] // world_size + start_row = rank * rows_per_gpu + end_row = start_row + local_C.shape[0] + + expected_local = full_result[start_row:end_row, :] + + return torch.allclose(local_C, expected_local, atol=atol) diff 
--git a/tests/examples/test_gemm_reduce_scatter_bench.py b/tests/examples/test_gemm_reduce_scatter_bench.py new file mode 100644 index 00000000..ac0ad4d0 --- /dev/null +++ b/tests/examples/test_gemm_reduce_scatter_bench.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import sys +import pytest +import torch +import triton +import triton.language as tl +import numpy as np +import iris + +import importlib.util +from pathlib import Path +from examples.common.utils import ( + Timestamps, +) +current_dir = Path(__file__).parent +sys.path.append(str(current_dir / "../../examples/13_gemm_reduce_scatter/")) +sys.path.append(str(current_dir / "../../")) +# Import the matmul wrapper +matmul_path = (current_dir / "../../examples/13_gemm_reduce_scatter/matmul_wrapper.py").resolve() +matmul_spec = importlib.util.spec_from_file_location("matmul_wrapper", matmul_path) +matmul_module = importlib.util.module_from_spec(matmul_spec) +matmul_spec.loader.exec_module(matmul_module) + +# Import the validation function +validation_path = (current_dir / "../../examples/common/validation.py").resolve() +validation_spec = importlib.util.spec_from_file_location("validation", validation_path) +validation_module = importlib.util.module_from_spec(validation_spec) +validation_spec.loader.exec_module(validation_module) + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "m, n, k", + [ + (512, 512, 512), # Very small for quick testing + (1024, 1024, 1024), # Small + (2048, 2048, 2048), # Medium + ], +) +@pytest.mark.parametrize( + "BLK_M, BLK_N, BLK_K", + [ + (32, 32, 32), # Small blocks + (64, 64, 32), # Medium blocks + ], +) +def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K): + """Worker function for PyTorch distributed execution.""" + heap_size = 1 << 30 + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + + # GEMM + datatype = dtype + + assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})." + assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})." 
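+    # The GEMM is sharded along K (each rank multiplies a K-slice of A with the
+    # matching rows of B) and the result is scattered along M (each rank keeps
+    # m // world_size output rows), hence the divisibility requirements above.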
+ + A = shmem.randn(m, k, device="cuda", dtype=datatype) + B = shmem.randn(n, k, device="cuda", dtype=datatype).T + C = shmem.zeros((m, n), device="cuda", dtype=A.dtype) + + M = m + N = n + K = k + + # Splitting + rows_per_gpu = k // world_size + k = rows_per_gpu + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype) + local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype) + + total_blocks_M = triton.cdiv(m, BLK_M) + total_blocks_N = triton.cdiv(n, BLK_N) + total_tiles = total_blocks_M * total_blocks_N + + tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + + locks = shmem.zeros((288,), device="cuda", dtype=torch.int32) + P = shmem.zeros( + (288, BLK_M * BLK_N), + device="cuda", + dtype=torch.float32, + ) + bias = None + gemm_stream = torch.cuda.Stream() + timestamps = Timestamps(num_tiles=total_tiles) + + def preamble(): + shmem.barrier() + tile_completed.zero_() + shmem.barrier() + + def run_experiment(): + nonlocal local_output + nonlocal compute_buffer + + shmem.barrier() + + torch.cuda.nvtx.range_push("GEMM + Communication") + with torch.cuda.stream(gemm_stream): + local_output = matmul_module.matmul_reduce_scatter.apply( + local_A, + local_B, + compute_buffer, + local_output, + bias, + P, + locks, + tile_completed, + rank, + world_size, + 288, + BLK_M, + BLK_N, + BLK_K, + 6, + True, + 1, + 8, + 0, + 16, + 2, + shmem.get_heap_bases(), + cu_count, + False, + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + torch.cuda.nvtx.range_pop() + shmem.barrier() + + # Synchronize across all GPUs + shmem.barrier() + run_experiment() + shmem.barrier() + preamble() + shmem.barrier() + + shmem.info("Validating...") + + matmul_module.matmul_reduce_scatter.set_debug(False) + # Validate global result + success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2) + assert success, ( + f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}" + )
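
For reference, the semantics that the fused kernel and `validate_gemm_reduce_scatter` implement can be reproduced with plain PyTorch on a single device. The sketch below is illustrative only (the sizes and the `W`/`r` names are assumptions, not part of this change); it mirrors the benchmark's K-sharding of A and B and the row-wise scatter of the output. The benchmark itself is launched directly, e.g. `python examples/13_gemm_reduce_scatter/benchmark.py -r 8 -v -b` from the repository root (assuming `examples.common` is importable), though the exact invocation depends on the local setup.

# Single-device reference for the GEMM + reduce-scatter semantics (illustrative sketch).
# Each rank holds a K-shard of A and B; summing the per-shard partial products recovers
# A @ B, and rank r keeps rows [r * M // W, (r + 1) * M // W) of the reduced result.
import torch

M, N, K, W = 512, 256, 384, 4  # illustrative sizes; M and K divisible by W
A = torch.randn(M, K)
B = torch.randn(K, N)

# "Reduce": sum the partial GEMMs computed from each rank's K-slice.
partials = [A[:, r * K // W:(r + 1) * K // W] @ B[r * K // W:(r + 1) * K // W, :] for r in range(W)]
reduced = sum(partials)

# "Scatter": rank r keeps only its block of M // W output rows.
rows = M // W
for r in range(W):
    local_out = reduced[r * rows:(r + 1) * rows, :]
    expected = (A @ B)[r * rows:(r + 1) * rows, :]
    assert torch.allclose(local_out, expected, atol=1e-3)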