From b81d63bdc9ee8437993bdf53da7435108b4403a4 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Thu, 18 Sep 2025 00:05:40 -0700 Subject: [PATCH 1/8] Enable MoEGEMM --- .../11_bmg_moe_gemm_bf16.cpp | 746 ++++++++++++++++++ examples/11_bmg_moe_gemm_bf16/CMakeLists.txt | 37 + examples/CMakeLists.txt | 1 + .../collective/builders/xe_builder.inl | 2 +- .../epilogue/collective/xe_array_epilogue.hpp | 37 + .../collective/builders/xe_mma_builder.inl | 4 +- .../cutlass/gemm/collective/xe_array_mma.hpp | 33 + include/cutlass/gemm/dispatch_policy.hpp | 13 +- .../cutlass/gemm/kernel/gemm_universal.hpp | 1 + .../cutlass/gemm/kernel/tile_scheduler.hpp | 1 + include/cutlass/gemm/kernel/xe_moe_gemm.hpp | 360 +++++++++ .../gemm/kernel/xe_tile_scheduler_moe.hpp | 481 +++++++++++ 12 files changed, 1710 insertions(+), 6 deletions(-) create mode 100644 examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp create mode 100644 examples/11_bmg_moe_gemm_bf16/CMakeLists.txt create mode 100644 include/cutlass/gemm/kernel/xe_moe_gemm.hpp create mode 100644 include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp diff --git a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp new file mode 100644 index 0000000000..23c0c9a23a --- /dev/null +++ b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp @@ -0,0 +1,746 @@ +/*************************************************************************************************** + * Copyright (c) 2025 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Fused MoE API example based on cutlass Group GEMM + + This example demonstrates fusing multiple GEMM operations into one kernel. + + Note that the scalar arguments to e.g. 
the standard 00_bmg_gemm example, + have been replaced with vector equivalents, as each individual GEMM has its + own inputs and outputs, which needn't be contiguous in memory. For example, + where 00_bmg_gemm receives an `ElementA *` defining Matrix A, grouped gemm + receives a `ElementA **`, i.e. a pointer to pointers, each pointing to a + distinct Matrix A. Likewise, each individual GEMM operation may have its own + alpha and beta factors for linear combination. This example demonstrates two + approaches: the user can provide `options.alpha` and `options.beta`, in which + case they will apply to all GEMMs; otherwise, random values are generated per + GEMM. + + Group GEMM scheduling (cutlass::gemm::GroupScheduler) is more complex than + standard GEMM, because each GEMM may have a unique size, only known at + runtime. Thus, the scheduler will distribute an a priori unknown number of + tiles to each work-group. See + include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp for + implementation. + + Note that for simplicity, this example hard-codes input shapes. + + Verification for this example is a conventional GEMM kernel, executed + iteratively per group. + + To build & run this example (from your build dir): + + $ ninja 11_bmg_fused_moe_bf16 + $ ./examples/sycl/11_bmg_fused_moe_bf16/11_bmg_fused_moe_bf16 + + Note: the code may spill registers once compiled which will result in + sub-optimal performance. This is because of an issue inside Intel Graphics + Compiler (IGC) related to VectorAliasBBThreshold being debugged internally. + To avoid register spills, build the example by setting the environment + variable: $ export IGC_VectorAliasBBThreshold=10000 +*/ +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_array_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "helper.h" +#include "sycl_common.hpp" + +#include + +using namespace cute; +using ProblemShape = + cutlass::gemm::GroupProblemShape>; // per group + +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = float; // <- data type of epilogue operations +using ElementA = bfloat16_t; // <- data type of elements in input matrix A +using ElementB = bfloat16_t; // <- data type of elements in input matrix B +using ElementOutput = bfloat16_t; // <- data type of elements in output matrix D + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_SYCL_PROFILING_ENABLED + +// Command line options parsing +struct GroupGEMMOptions { + + bool error = false; + bool help = false; + + float alpha = 1.f; + float beta = 0.f; + int iterations; + int m=0, n=0, k=0, groups; + int *num_rows_per_expert = nullptr; + std::vector problem_sizes_host; + + GroupGEMMOptions() + : error(false), help(false), alpha(1.f), beta(0.f), iterations(100) { + } + + void parse(const int num_experts, const int 
*num_tokens_per_expert_host, + int moe_n, int moe_k, + const int *num_tokens_per_expert_device = nullptr) { + n = moe_n; + k = moe_k; + groups = num_experts; + iterations = 2; + num_rows_per_expert = const_cast(num_tokens_per_expert_device); + assert(groups > 0); + problem_sizes_host.clear(); + problem_sizes_host.reserve(groups); + for (int i = 0; i < groups; i++) { + problem_sizes_host.push_back({num_tokens_per_expert_host[i], n, k}); + } + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, + std::vector + problem_sizes_host) const { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const &problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ExampleRunner { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + // Host-side allocations + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_C_host; + std::vector stride_D_host; + + std::vector alpha_host; + std::vector beta_host; + + // Device-side allocations + cutlass::DeviceAllocation + problem_sizes; + + // This example defines all matrices in a single allocation (e.g. block_A), + // but this is not a requirement. Matrix base pointers are read from device + // allocation (e.g. 
ptr_A) + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation ptr_ref_D; + + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + + // Note, this is an array of pointers to alpha and beta scaling values per + // group + cutlass::DeviceAllocation alpha_device; + cutlass::DeviceAllocation beta_device; + cutlass::DeviceAllocation block_alpha; + cutlass::DeviceAllocation block_beta; + int* cumsum_host; + cutlass::DeviceAllocation cumsum_device; + + uint64_t seed = 0; + + // + // Methods + // + + bool verify(const GroupGEMMOptions &options) { + bool passed = true; + // Verify against individual reference GEMMs + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), + LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), + LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), + LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), + LayoutD::packed({M, N})); + + // + // Compute reference output + // + cutlass::reference::device::GemmComplex( + {M, N, K}, alpha_host.at(i), ref_A, cutlass::ComplexTransform::kNone, + ref_B, cutlass::ComplexTransform::kNone, beta_host.at(i), ref_C, + ref_D, ElementAccumulator(0), + 1, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // Wait for kernel to finish + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or + // not + passed &= cutlass::reference::device::BlockCompareEqual( + block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), + M * N); + if (!passed) + break; + } + return passed; + } + + /// Allocates device-side data + void allocate(const GroupGEMMOptions &options, const ElementA *block_A_ptr, + const ElementA *block_B_ptr, ElementOutput*block_C_ptr, + int block_A_size, int block_B_size, int block_C_size) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + cumsum_device.reset(options.groups + 1); + cumsum_host = (int32_t*)(malloc((options.groups + 1) * sizeof(int32_t))); + cumsum_host[0] = 0; + // Compute total allocation sizes across group + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + cumsum_host[i + 1] += M + cumsum_host[i]; + // Offset into block allocation of each matrix base pointer + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + 
stride_A_host.push_back( + cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back( + cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back( + cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back( + cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + } + assert(total_elements_A == block_A_size); + // change this assert to static assert because it's known at compile-time + assert(total_elements_B == block_B_size); + assert(total_elements_C == block_C_size); + block_A.reset(const_cast(block_A_ptr), block_A_size); + block_B.reset(const_cast(block_B_ptr), block_B_size); + block_C.reset(total_elements_D); + block_D.reset(block_C_ptr, block_C_size); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + cumsum_device.copy_from_host(cumsum_host); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize_for_moe_gemm(const GroupGEMMOptions &options) { + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + + ptr_A_host.at(0) = block_A.get(); + ptr_B_host.at(0) = block_B.get(); + ptr_C_host.at(0) = block_C.get(); + ptr_D_host.at(0) = block_D.get(); + for (int32_t i = 0; i < options.groups; ++i) { + // Fill host vector of alpha & beta with random values if using per-group + // values + alpha_host.push_back( + (options.alpha == FLT_MAX) + ? static_cast((rand() % 5) + 1) + : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) + ? 
static_cast(rand() % 5) + : options.beta); + // Fill host ptr vectors with offset addresses into device alpha/beta + // blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + // Per-group alpha and beta values - note these are not directly passed to + // kernel - the pointers (alpha_device/beta_device) are passed instead + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize_for_ref_gemm(const GroupGEMMOptions &options) { + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + // Fill host ptr vectors with offset addresses into device alpha/beta + // blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + // Per-group alpha and beta values - note these are not directly passed to + // kernel - the pointers (alpha_device/beta_device) are passed instead + 
block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + } + + /// Populates a Gemm::Arguments structure from the given commandline options + typename Gemm::Arguments + args_from_options(const GroupGEMMOptions &options, + const cutlass::KernelHardwareInfo &hw_info, + const int gemm_N, + const int gemm_K) { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + bool host_problem_shapes_available = false; + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, + // i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ + // between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + using RasterOrderOptions = + typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup< + ProblemShape>::RasterOrderOptions; + + // Per-GEMM problem shape info may only exist on the device. + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), + options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), + stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN}, + options.num_rows_per_expert, + gemm_N, + gemm_K}; + } else { + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), + stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN}, + options.num_rows_per_expert, + gemm_N, + gemm_K}; + } + + return arguments; + } + + cutlass::Status run(const GroupGEMMOptions &options, + const cutlass::KernelHardwareInfo &hw_info, + const ElementA *A_ptr, const ElementB *B_ptr, + ElementOutput *C_ptr, int A_size, int B_size, int D_size, const int gemm_n, const int gemm_k) { + allocate(options, A_ptr, B_ptr, C_ptr, A_size, B_size, D_size); + initialize_for_moe_gemm(options); + + Gemm gemm_op; + auto arguments = args_from_options(options, hw_info, gemm_n, gemm_k); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + initialize_for_ref_gemm(options); + // Verify that the result is correct + bool passed = verify(options); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + if (!passed) + return cutlass::Status::kErrorInternal; + initialize_for_moe_gemm(options); + syclcompat::wait(); + arguments = args_from_options(options, hw_info, gemm_n, gemm_k); + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + syclcompat::wait(); + + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(options.iterations); + double gflops = options.gflops(cute_average_time / 1000.0, + options.problem_sizes_host); + + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) + << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" + << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + } + + return cutlass::Status::kSuccess; + } +}; + +void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, + bfloat16_t *outputs, const int gemm_n, const int gemm_k, + const int *num_rows_per_expert_device, const int num_experts) { + GroupGEMMOptions options; + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a + // given device ID. This information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + int num_tokens_incl_duplicated = 0; + int total_rows_for_each_expert[num_experts]; + cutlass::DeviceAllocation num_rows_per_expert_obj; + num_rows_per_expert_obj.reset( + const_cast(num_rows_per_expert_device), num_experts); + num_rows_per_expert_obj.copy_to_host(total_rows_for_each_expert); + options.parse(num_experts, total_rows_for_each_expert, gemm_n, gemm_k, + num_rows_per_expert_device); + + for (int i = 0; i < num_experts; i++) { + num_tokens_incl_duplicated += total_rows_for_each_expert[i]; + } + size_t A_size = num_tokens_incl_duplicated * gemm_k; + size_t B_size = num_experts * gemm_n * gemm_k; + size_t D_size = num_tokens_incl_duplicated * gemm_n; + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. 
+ hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x16x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; + + // Workgroup-level tile + using TileShape = Shape<_16, _256, _32>; +/* + using TiledMma = + TiledMMA, + Layout, Stride<_4, _1, _0>>>; +*/ + + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper, Layout, + Layout, Stride<_8, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelXeXMX16Group; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp = + cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape, + Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueTileAuto, + float, float, float, LayoutC, 1, ElementOutput, LayoutC, 1, + EpilogueDispatchPolicy, EpilogueOp>::CollectiveOp; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, TileShape, ElementA, + cutlass::gemm::TagToStrideA_t, ElementB, + cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, + void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + runner.run(options, hw_info, activations, weights, outputs, A_size, B_size, + D_size, gemm_n, gemm_k); + num_rows_per_expert_obj.release(); +} + + +int main(int argc, const char **argv) { + constexpr int num_experts = 128; + + int total_rows_for_each_expert[128] = { + 23, 16, 15, 4, 9, 36, 20, 20, 26, 26, 9, 31, 36, 3, 30, 15, 12, 6, + 28, 18, 3, 12, 16, 9, 18, 17, 38, 14, 36, 16, 24, 34, 22, 4, 27, 21, + 16, 39, 30, 19, 6, 35, 23, 29, 1, 11, 29, 13, 6, 25, 27, 26, 19, 8, + 23, 31, 17, 24, 40, 15, 20, 32, 17, 36, 34, 12, 31, 22, 0, 9, 19, 20, + 2, 26, 9, 6, 15, 27, 22, 8, 18, 14, 36, 12, 19, 19, 36, 20, 2, 27, + 23, 36, 29, 14, 4, 28, 5, 1, 36, 5, 31, 36, 26, 32, 6, 21, 32, 39, + 27, 12, 37, 6, 6, 39, 0, 16, 39, 34, 19, 13, 0, 0, 0, 0, 0, 0, 0, 0}; + + int num_tokens_incl_duplicated = 0; + for (int i = 0; i < num_experts; i++) { + num_tokens_incl_duplicated += total_rows_for_each_expert[i]; + } + int n_moe = 3072; + int k_moe = 4096; + + cutlass::DeviceAllocation num_rows_per_expert_device; + cutlass::DeviceAllocation activations_data; + cutlass::DeviceAllocation weights_data; + cutlass::DeviceAllocation output_data; + size_t A_size = num_tokens_incl_duplicated * k_moe; + size_t B_size = num_experts * n_moe * k_moe; + size_t D_size = num_tokens_incl_duplicated * n_moe; + num_rows_per_expert_device.reset(num_experts); + num_rows_per_expert_device.copy_from_host(total_rows_for_each_expert); + activations_data.reset(A_size); + weights_data.reset(B_size); + output_data.reset(D_size); + uint64_t seed = 2023; + initialize_block(activations_data, seed + 2023); + initialize_block(weights_data, seed + 2022); + initialize_block(output_data, seed + 2021); + MoEGEMM(activations_data.get(), weights_data.get(), output_data.get(), n_moe, + k_moe, 
num_rows_per_expert_device.get(), num_experts); + activations_data.release(); + weights_data.release(); + output_data.release(); + num_rows_per_expert_device.release(); + return 0; +} diff --git a/examples/11_bmg_moe_gemm_bf16/CMakeLists.txt b/examples/11_bmg_moe_gemm_bf16/CMakeLists.txt new file mode 100644 index 0000000000..68cf7d7ff7 --- /dev/null +++ b/examples/11_bmg_moe_gemm_bf16/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright (c) Intel Corporation. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 11_bmg_moe_gemm_bf16 + 11_bmg_moe_gemm_bf16.cpp + TEST_COMMAND_OPTIONS +) + +# TODO(codeplay): Remove these once IGC VectorAliasThreshold issue is fixed +set_target_properties( 11_bmg_moe_gemm_bf16 PROPERTIES CXX_COMPILER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" ) +set_target_properties( 11_bmg_moe_gemm_bf16 PROPERTIES CXX_LINKER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" ) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d141f5b7de..9872cb7382 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -111,6 +111,7 @@ if(CUTLASS_ENABLE_SYCL) 08_bmg_gemm_f8 09_bmg_grouped_gemm_f8 10_bmg_grouped_gemm_mixed_dtype + 11_bmg_moe_gemm_bf16 ) add_subdirectory(${EXAMPLE}) endforeach() diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index 809cede6f7..b4b3358076 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -193,7 +193,7 @@ template < //TODO(Codeplay): Should FusionCallbacks use DispatchPolicy IntelXeGroupEpilogue for group gemm? That does not work. 
using FusionCallbacks = typename detail::FusionOpInfo::template FusionCallbacks< - IntelXeXMX16, TileShape_MNK, TileShape_MNK, CopyOpG2R>; + std::conditional_t, TileShape_MNK, TileShape_MNK, CopyOpG2R>; using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< DispatchPolicy, diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 9b879bd14d..88e4b2b50f 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -43,6 +43,7 @@ #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" #include "cutlass/detail/layout.hpp" +#include "../tools/util/include/cutlass/util/packed_stride.hpp" #include "cute/tensor.hpp" @@ -475,6 +476,42 @@ class CollectiveEpilogue< return cute::make_tuple(mC_mnl, mD_mnl); } +template + CUTLASS_DEVICE auto + update_tensor_shape_stride(int32_t const &next_group, + ProblemShape_MNKL const &problem_shape_mnkl, + const int32_t *num_rows_per_expert) { + auto [M, N, K, L] = problem_shape_mnkl; + int32_t cumulative_M = 0; + for (int i = 0; i < next_group; i++) { + cumulative_M += num_rows_per_expert[i]; + } + M = num_rows_per_expert[next_group]; + + TensorC mC_mnl; + TensorD mD_mnl; + if constexpr (is_source_supported) { + ElementC const *ptr_C_curr_batch = + reinterpret_cast(params.ptr_C[0]) + + cumulative_M * N; + mC_mnl = make_tensor( + make_gmem_ptr(ptr_C_curr_batch), + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( + InternalStrideC{}, {M, N, 1}))); + } + + if constexpr (is_destination_supported) { + ElementD *ptr_D_curr_batch = + reinterpret_cast(params.ptr_D[0]) + + cumulative_M * N; + mD_mnl = make_tensor( + make_gmem_ptr(ptr_D_curr_batch), + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( + InternalStrideD{}, {M, N, 1}))); + } + return cute::make_tuple(mC_mnl, mD_mnl); + } + private: Params const& params; FusionCallbacks fusion_callbacks; diff --git a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl index c2ffaa5a5f..9b71d046e7 100644 --- a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl @@ -160,7 +160,7 @@ struct CollectiveBuilder< cutlass::gemm::collective::StageCountAuto, KernelScheduleType, cute::enable_if_t< - cute::is_any_of_v && + cute::is_any_of_v && cute::is_any_of_v && cute::is_any_of_v > @@ -190,7 +190,7 @@ struct CollectiveBuilder< Layout, Layout, Stride>>::TiledMMA; - static constexpr bool IsGroup = cute::is_same_v; + static constexpr bool IsGroup = cute::is_any_of_v; using KernelSchedule = std::conditional_t, KernelXe, KernelScheduleType>; static constexpr int PipelineStages = IsGroup ? 
2 : 3; diff --git a/include/cutlass/gemm/collective/xe_array_mma.hpp b/include/cutlass/gemm/collective/xe_array_mma.hpp index 102a431008..8421d4915c 100644 --- a/include/cutlass/gemm/collective/xe_array_mma.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma.hpp @@ -37,6 +37,7 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" +#include "../tools/util/include/cutlass/util/packed_stride.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -293,6 +294,38 @@ struct CollectiveMma, TileShape_, El return cute::make_tuple(mA, mB); } + +template + CUTLASS_DEVICE auto + update_tensor_shape_stride(Params const &mainloop_params, + int32_t const &next_group, + ProblemShape_MNKL const &problem_shape_mnkl, + const int *num_rows_per_expert) { + int32_t cumulative_M = 0; + for (int i = 0; i < next_group; i++) { + cumulative_M += num_rows_per_expert[i]; + } + + const int32_t M = num_rows_per_expert[next_group]; + const int32_t N = get<1>(problem_shape_mnkl); + const int32_t K = get<2>(problem_shape_mnkl); + + ElementA const *ptr_A_curr_batch = + reinterpret_cast(mainloop_params.ptr_A[0]) + + cumulative_M * K; + ElementB const *ptr_B_curr_batch = + reinterpret_cast(mainloop_params.ptr_B[0]) + + next_group * K * N; + + Tensor mA = make_tensor( + make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K, (int32_t)1), + cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + Tensor mB = make_tensor( + make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K, (int32_t)1), + cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + return cute::make_tuple(mA, mB); + } }; } // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index b742dfd76f..a4873c5f22 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2025 Intel Corporation * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -140,6 +141,7 @@ struct KernelTmaWarpSpecializedCooperativeMixedInput: KernelTmaWarpSpecializedCo struct KernelXe { }; struct KernelXeCooperative { }; struct KernelXePtrArrayCooperative { }; +struct KernelXeMoEGEMM { }; ////////////////////////////////////////////////////////////////////////////// // @@ -1214,9 +1216,14 @@ struct MainloopIntelXeXMX16 { using ClusterShape = Shape<_1,_1,_1>; }; -template -struct MainloopIntelXeXMX16Group : MainloopIntelXeXMX16 { -}; +template +struct MainloopIntelXeXMX16Group + : MainloopIntelXeXMX16 {}; + +// partial specialization for KernelXeMoEGEMM +template +struct MainloopIntelXeXMX16Group + : MainloopIntelXeXMX16 {}; template struct MainloopIntelXeXMX16GroupMixedPrecision : MainloopIntelXeXMX16 { diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 69137d2114..fd8bfd3cf0 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -79,6 +79,7 @@ struct IsCutlass3ArrayKernel +class GemmUniversal< + ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + cute::enable_if_t>> { +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or cute::rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::WorkgroupTileShape; + using WorkgroupTileShape = TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_same_v, + "Only Group Scheduler is supported with this code."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::PersistentTileSchedulerXeMoE; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr int SubgroupSize = + 
CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; + using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + + using MainloopTensors = typename CollectiveMainloop::MainloopTensors; + using EpilogueTensors = typename CollectiveEpilogue::EpilogueTensors; + + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static_assert(cute::is_same_v>); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + const int *M_per_group{nullptr}; + int N = 0; + int K = 0; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void *workspace{nullptr}; + const int *M_per_group{nullptr}; + int N = 0; + int K = 0; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + auto problem_shape = args.problem_shape; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t *workspace_ptr = reinterpret_cast(workspace); + + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape, TileShape{}, ClusterShape{}, hw_info, args.scheduler, + workspace_ptr); + + return {args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace_ptr), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace_ptr), + hw_info, + scheduler, + workspace, + args.M_per_group, + args.N, + args.K}; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + + implementable = implementable && (args.mode == GemmUniversalMode::kGrouped || + (args.mode == GemmUniversalMode::kBatched && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3)); + + implementable = implementable && TileScheduler::can_implement(args.scheduler); + + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, 
typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, -1); + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t *workspace_ptr = reinterpret_cast(workspace); + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, -1); + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + TileScheduler scheduler{params.scheduler}; + scheduler.configure( + const_cast(params.M_per_group), params.N, params.K); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K) + + int thread_idx = int(ThreadIdxX()); + constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + bool did_group_change = true; + int32_t curr_group = -1; + using ProblemShapeMNKL = Shape; + ProblemShapeMNKL problem_shape_MNKL; + MainloopTensors AB_tensors; + EpilogueTensors CD_tensors; + const int32_t N = params.N; + const int32_t K = params.K; + if (work_tile_info.is_valid()) { + curr_group = work_tile_info.L_idx; + problem_shape_MNKL = append<4>(Shape{params.M_per_group[curr_group], N, K}, 1); + } + + while (work_tile_info.is_valid()) { + auto M = get<0>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + Tensor mA_mkl = cute::get_xe_tensor(make_shape(M, K, L)); //(m,k,l) + Tensor mB_nkl = cute::get_xe_tensor(make_shape(N, K, L)); //(n,k,l) + + auto m_coord = work_tile_info.M_idx; + auto n_coord = work_tile_info.N_idx; + + auto gA_mkl = local_tile(mA_mkl, select<0,2>(workgroup_shape), make_coord(m_coord, _, 0)); + auto gB_nkl = local_tile(mB_nkl, select<1,2>(workgroup_shape), make_coord(n_coord, _, 0)); + + CollectiveMainloop collective_mma; + if (did_group_change) { + AB_tensors = collective_mma.update_tensor_shape_stride( + params.mainloop, curr_group, problem_shape_MNKL, + params.M_per_group); + } + auto tile_coord = make_coord(m_coord, n_coord, _, 0); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ int work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, workgroup_shape); + int work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, make_shape(K)), make_shape(K)); + + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(workgroup_shape)); + + // Perform the collective scoped MMA + collective_mma( + accumulators, + gA_mkl, + gB_nkl, + accumulators, + k_tile_iter, work_k_tile_count, + tile_coord, + K, + thread_idx, + params.mainloop, + AB_tensors + ); + + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators, -1, -1); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + + if (did_group_change) { + CD_tensors = epilogue.update_tensor_shape_stride( + curr_group, problem_shape_MNKL, params.M_per_group); + did_group_change = false; + } + + epilogue( + problem_shape_MNKL, + subgroup_shape, + tile_coord, + accumulators, + tiled_mma, + thread_idx, + CD_tensors + ); + } + + // Get next work tile + auto [next_work_tile_info, temp] = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + + did_group_change = curr_group != work_tile_info.L_idx; + + if (did_group_change && work_tile_info.is_valid()) { + curr_group = work_tile_info.L_idx; + problem_shape_MNKL = append<4>(Shape{params.M_per_group[curr_group], N, K}, 1); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp b/include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp new file mode 100644 index 0000000000..c372705472 --- /dev/null +++ b/include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2025 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/gemm_coord.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::gemm::kernel::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Persistent Thread Block (TB) scheduler for MoE GEMM +template +class PersistentTileSchedulerXeMoE { + // + // Data members + // + +private: + uint64_t current_work_linear_idx_ = 0; + uint64_t total_grid_size_ = 0; + int32_t* num_rows_per_expert_ = nullptr; + int32_t K_ = 0; + int32_t N_ = 0; + + // Tracking current group, its starting linear idx and total tiles + struct GroupInfo { + int group_idx = 0; + uint64_t start_linear_idx = 0; + uint64_t total_tiles = 0; + } current_group_info_; + +public: + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t L_idx = 0; + bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool + is_valid() const { + return is_valid_tile; + } + + CUTLASS_HOST_DEVICE + static WorkTileInfo + invalid_work_tile() { + return {-1, -1, -1, false}; + } + + CUTLASS_HOST_DEVICE + bool + is_final_split(uint32_t k_tiles_per_output_tile) const { + return true; + } + + CUTLASS_HOST_DEVICE + int32_t + reduction_subtile_idx() const { + return -1; + } + }; + + using ProblemShape = typename GroupProblemShape::UnderlyingProblemShape; + using Params = PersistentTileSchedulerSm90GroupParams; + using RasterOrder = typename Params::RasterOrder; + using RasterOrderOptions = typename Params::RasterOrderOptions; + + struct Arguments { + int max_swizzle_size = 1; + // Not applying Heuristics for Grouped problems, since largest dimension can change per group + RasterOrderOptions raster_order = RasterOrderOptions::AlongM; + }; + + // Sink scheduler params as a member + Params scheduler_params; + + // + // Methods + // + + CUTLASS_HOST_DEVICE void configure(int32_t* num_rows_per_expert, int32_t N, int32_t K) { + num_rows_per_expert_ = num_rows_per_expert; + N_ = N; + K_ = K; + } + + // Given the inputs, computes the total number of output blocks this problem will compute over + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. 
+ template + CUTLASS_HOST_DEVICE static + dim3 + get_tiled_cta_shape_mnl(const KernelHardwareInfo &hw_info, ClusterShape cluster_shape) { + uint32_t total_ctas = 0; + uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here + + total_ctas = hw_info.sm_count; + + return Params::get_tiled_cta_shape_mnl( + to_gemm_coord(cluster_shape), + total_ctas, cta_in_N_dim + ); + } + + template + static Params + to_underlying_arguments( + GroupProblemShape problem_shapes, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo const& hw_info, + Arguments const& arguments, + [[maybe_unused]] void* workspace=nullptr, + [[maybe_unused]] const uint32_t epilogue_subtile = 1, + [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u + ) { + + // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic + static_assert(cute::is_static::value); + static_assert(cute::is_static::value); + + dim3 problem_blocks = get_tiled_cta_shape_mnl( + hw_info, + cluster_shape); + + Params params; + params.initialize( + problem_blocks, + problem_shapes, + to_gemm_coord(tile_shape), + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order + ); + + return params; + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + [[maybe_unused]] Params const& params, + GroupProblemShape problem_shapes, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + Arguments arguments, + bool truncate_by_problem_size=true) { + + dim3 problem_blocks = get_tiled_cta_shape_mnl( + hw_info, + cluster_shape); + + return Params::get_grid_shape( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order, + /* truncate_by_problem_size = */true + ); + } + + + + static bool + can_implement(Arguments const& args) { + return true; + } + + + + + PersistentTileSchedulerXeMoE() = default; + + CUTLASS_DEVICE explicit PersistentTileSchedulerXeMoE(Params const& params_) : scheduler_params(params_) { + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. 
+#if defined(__CUDA_ARCH__) || defined __SYCL_DEVICE_ONLY__ + if (scheduler_params.raster_order_ == RasterOrder::AlongN) { + current_work_linear_idx_ = uint64_t(BlockIdxX()) + uint64_t(BlockIdxY()) * uint64_t(GridDimX()); + } + else { + current_work_linear_idx_ = uint64_t(BlockIdxX()) * uint64_t(GridDimY()) + uint64_t(BlockIdxY()); + } + + total_grid_size_ = uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()); + +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work() { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work_for_linear_idx(uint64_t linear_idx) { + if (scheduler_params.pre_processed_problem_shapes && linear_idx >= scheduler_params.blocks_across_problem_) { + return WorkTileInfo::invalid_work_tile(); + } + + return get_work_idx_m_and_n(linear_idx, + current_group_info_, + scheduler_params.problem_shapes_, + scheduler_params.cta_shape_, + scheduler_params.cluster_shape_, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cta_shape_m_, + scheduler_params.divmod_cta_shape_n_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); + } + + CUTLASS_DEVICE + void + advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); + } + + // get work_idx_m, work_idx_n from linear_idx while applying swizzle + CUTLASS_DEVICE + WorkTileInfo + get_work_idx_m_and_n( + uint64_t linear_idx, + struct GroupInfo& group_info, + GroupProblemShape &problem_shapes, + GemmCoord cta_shape, + GemmCoord cluster_shape, + FastDivmodU64Pow2 const& divmod_cluster_shape_major, + FastDivmodU64Pow2 const& divmod_cluster_shape_minor, + FastDivmodU64 const& divmod_cta_shape_m, + FastDivmodU64 const& divmod_cta_shape_n, + int32_t log_swizzle_size, + RasterOrder raster_order) { + + bool valid_tile = true; + uint64_t ctas_along_m, ctas_along_n; + int total_problem_groups = problem_shapes.groups(); + ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_m.divisor - 1); + ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_n.divisor - 1); + + auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; + + while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) { + group_info.group_idx++; + + if (group_info.group_idx >= total_problem_groups) + return WorkTileInfo::invalid_work_tile(); + + group_info.start_linear_idx += group_info.total_tiles; + ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_m.divisor - 1); + ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_n.divisor - 1); + + problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; + } + + uint64_t cluster_id, 
cluster_major_offset = 0, cluster_minor_offset = 0; + uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx); + divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); + + // With static schedulers, we launch the grid such that all clusters are in linear (1-D) order, i.e., + // there can only be one cluster in the minor dimension. get_grid_shape() in scheduler params + // puts cluster_shape.m()/n() as the minor dimension based on raster order AlongN/AlongM, respectively. + // Therefore, the offset of a CTA (inside a cluster) in the minor dimension can be inferred + // directly from the blockIdx along the minor dimension. + if (raster_order == RasterOrder::AlongN) { + cluster_minor_offset = BlockIdxX(); + } + else { + cluster_minor_offset = BlockIdxY(); + } + + uint64_t cluster_idx_minor, cluster_idx_major; + + uint64_t cluster_idx_minor_div_swizzle, extra, offset; + + offset = cluster_id & ((1 << log_swizzle_size) - 1); + extra = cluster_id >> log_swizzle_size; + + uint64_t curr_group_cluster_blk_major; + if (raster_order == RasterOrder::AlongN) { + curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_n); + } + else { + curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_m); + } + cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; + cluster_idx_major = extra % curr_group_cluster_blk_major; + + cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; + + auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + + cluster_minor_offset); + auto major_work_idx = static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + + cluster_major_offset); + + if (raster_order == RasterOrder::AlongN) { + return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile}; + } + else { + return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile}; + } + + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the basic tile scheduler, this is always true. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&, Params const&) { + return true; + } + + // Performs the reduction across splits for a given output tile. Since this scheduler does + // not split output tiles, no reduction is needed. + template + CUTLASS_DEVICE + static void + fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} + + // Returns whether the current WorkTileInfo passed in should continue to be used. Since + // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo + // passed in should not be used after having been processed.
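+  // A sketch of the call pattern this implies (hypothetical caller code, for illustration only;
+  // process() stands in for the kernel's mainloop + epilogue):
+  //
+  //   auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
+  //   while (work_tile_info.is_valid()) {
+  //     process(work_tile_info);
+  //     auto next = scheduler.fetch_next_work(work_tile_info);  // never resumes the same tile
+  //     work_tile_info = cute::get<0>(next);
+  //   }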
+ CUTLASS_DEVICE + static bool + continue_current_work(WorkTileInfo&) { + return false; + } + + // The basic tile scheduler does not require any additional workspace + template + static size_t + get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1, uint32_t = 1) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, + uint32_t, const uint32_t = 1, uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape_MNKL problem_shape, TileShape tile_shape) { + // All work units returned by this scheduler cover the entire K iteration + // space of the output tile assigned to the work unit. + return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); + } + + CUTLASS_HOST_DEVICE + static uint32_t + get_work_k_tile_start(WorkTileInfo const&) { + // All work units returned by this scheduler start from K tile 0 + return 0u; + } + + CUTLASS_DEVICE + static bool + need_separate_reduction(Params const& params) { + return false; + } + + CUTLASS_DEVICE + bool + is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { + return false; + } + + CUTLASS_DEVICE + uint32_t + epilgoue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const { + return 0; + } + + template + CUTLASS_DEVICE + void + separate_reduction( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + // Shares the accumulator set with peers in the global workspace + template + CUTLASS_DEVICE + static void + share( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + CUTLASS_DEVICE + static bool + valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return true; + } + + CUTLASS_DEVICE + static bool + requires_separate_reduction(Params const& params) { + return false; + } + + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto + fetch_next_work(WorkTileInfo work_tile_info) { + if (continue_current_work(work_tile_info)) { + return cute::make_tuple(work_tile_info, true); + } + + advance_to_next_work(); + return cute::make_tuple(get_current_work(), true); + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape) { + return get_current_work(); + } + +}; + +} // namespace cutlass::gemm::kernel::detail From 03b74ca0d64d2cd29a85dce71c18dbdd517accf4 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Fri, 26 Sep 2025 11:51:15 -0700 Subject: [PATCH 2/8] Prevent D2H transfer of problem sizes --- .../11_bmg_moe_gemm_bf16.cpp | 29 +++---- include/cutlass/gemm/kernel/xe_moe_gemm.hpp | 79 +++++++++++++++---- .../gemm/kernel/xe_tile_scheduler_moe.hpp | 6 +- 3 files changed, 80 insertions(+), 34 deletions(-) diff --git a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp index 23c0c9a23a..3aa279b1aa 100644 --- a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp +++ b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp @@ -45,6 +45,9 @@ case they will apply to all GEMMs; otherwise, random values are 
generated per GEMM. + While a nullptr can be passed for C, in this example, we don't do that, + because the reference kernel doesn't accept nullptr for C. + Group GEMM scheduling (cutlass::gemm::GroupScheduler) is more complex than standard GEMM, because each GEMM may have a unique size, only known at runtime. Thus, the scheduler will distribute an a priori unknown number of @@ -521,26 +524,25 @@ template struct ExampleRunner { if (host_problem_shapes_available) { arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, - {options.groups, problem_sizes.get(), - options.problem_sizes_host.data()}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN}, options.num_rows_per_expert, + options.groups, gemm_N, gemm_K}; } else { arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, - {options.groups, problem_sizes.get(), nullptr}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN}, options.num_rows_per_expert, + options.groups, gemm_N, gemm_K}; } @@ -645,11 +647,11 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x16x32_LD_N; + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; // Workgroup-level tile - using TileShape = Shape<_16, _256, _32>; + using TileShape = Shape<_256, _256, _32>; /* using TiledMma = TiledMMA, @@ -658,7 +660,7 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 typename TiledMMAHelper, Layout, - Layout, Stride<_8, _1, _0>>>::TiledMMA; + Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; // Dispatch to grouped gemm algorithm @@ -702,16 +704,11 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, int main(int argc, const char **argv) { - constexpr int num_experts = 128; - - int total_rows_for_each_expert[128] = { - 23, 16, 15, 4, 9, 36, 20, 20, 26, 26, 9, 31, 36, 3, 30, 15, 12, 6, - 28, 18, 3, 12, 16, 9, 18, 17, 38, 14, 36, 16, 24, 34, 22, 4, 27, 21, - 16, 39, 30, 19, 6, 35, 23, 29, 1, 11, 29, 13, 6, 25, 27, 26, 19, 8, - 23, 31, 17, 24, 40, 15, 20, 32, 17, 36, 34, 12, 31, 22, 0, 9, 19, 20, - 2, 26, 9, 6, 15, 27, 22, 8, 18, 14, 36, 12, 19, 19, 36, 20, 2, 27, - 23, 36, 29, 14, 4, 28, 5, 1, 36, 5, 31, 36, 26, 32, 6, 21, 32, 39, - 27, 12, 37, 6, 6, 39, 0, 16, 39, 34, 19, 13, 0, 0, 0, 0, 0, 0, 0, 0}; + const int num_experts = 32; + + int total_rows_for_each_expert[num_experts] = { + 148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324, + 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}; int num_tokens_incl_duplicated = 0; for (int i = 0; i < num_experts; i++) { diff --git a/include/cutlass/gemm/kernel/xe_moe_gemm.hpp b/include/cutlass/gemm/kernel/xe_moe_gemm.hpp index 45eb604ca1..b45e8fcfde 100644 --- a/include/cutlass/gemm/kernel/xe_moe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_moe_gemm.hpp @@ -123,14 +123,14 @@ class GemmUniversal< // Device side arguments struct Arguments { GemmUniversalMode mode{}; - ProblemShape problem_shape{}; MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; 
TileSchedulerArguments scheduler{}; const int *M_per_group{nullptr}; - int N = 0; - int K = 0; + int num_experts; + int N; + int K; }; // Kernel entry point API @@ -143,8 +143,9 @@ class GemmUniversal< TileSchedulerParams scheduler{}; void *workspace{nullptr}; const int *M_per_group{nullptr}; - int N = 0; - int K = 0; + int num_experts; + int N; + int K; }; // @@ -156,7 +157,8 @@ class GemmUniversal< Params to_underlying_arguments(Arguments const& args, void* workspace) { CUTLASS_TRACE_HOST("to_underlying_arguments():"); - auto problem_shape = args.problem_shape; + auto dummy_problem_shape = cute::Shape{256, args.N, args.K}; + auto dummy_group_problem_shape = ProblemShape{1, &dummy_problem_shape, nullptr}; // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; @@ -174,19 +176,20 @@ class GemmUniversal< uint8_t *workspace_ptr = reinterpret_cast(workspace); TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape, TileShape{}, ClusterShape{}, hw_info, args.scheduler, + dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace_ptr); return {args.mode, - problem_shape, + dummy_group_problem_shape, CollectiveMainloop::to_underlying_arguments( - args.problem_shape, args.mainloop, workspace_ptr), + dummy_group_problem_shape, args.mainloop, workspace_ptr), CollectiveEpilogue::to_underlying_arguments( - args.problem_shape, args.epilogue, workspace_ptr), + dummy_group_problem_shape, args.epilogue, workspace_ptr), hw_info, scheduler, workspace, args.M_per_group, + args.num_experts, args.N, args.K}; } @@ -199,9 +202,10 @@ class GemmUniversal< (args.mode == GemmUniversalMode::kBatched && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3)); implementable = implementable && TileScheduler::can_implement(args.scheduler); - - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + auto dummy_problem_shape = cute::Shape{256, args.N, args.K}; + auto dummy_group_problem_shape = ProblemShape{1, &dummy_problem_shape, nullptr}; + implementable &= CollectiveMainloop::can_implement(dummy_group_problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(dummy_group_problem_shape, args.epilogue); return implementable; } @@ -255,8 +259,10 @@ class GemmUniversal< SharedStorage& shared_storage = *reinterpret_cast(smem_buf); TileScheduler scheduler{params.scheduler}; + const int32_t N = params.N; + const int32_t K = params.K; scheduler.configure( - const_cast(params.M_per_group), params.N, params.K); + const_cast(params.M_per_group), params.N, params.K, params.num_experts); auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K) @@ -268,12 +274,53 @@ class GemmUniversal< ProblemShapeMNKL problem_shape_MNKL; MainloopTensors AB_tensors; EpilogueTensors CD_tensors; - const int32_t N = params.N; - const int32_t K = params.K; + if (work_tile_info.is_valid()) { curr_group = work_tile_info.L_idx; problem_shape_MNKL = append<4>(Shape{params.M_per_group[curr_group], N, K}, 1); } + /* + using LayoutA_tiny = cutlass::layout::RowMajor; + using LayoutB_tiny = cutlass::layout::ColumnMajor; + using LayoutC_tiny = cutlass::layout::RowMajor; + using LayoutD_tiny = cutlass::layout::RowMajor; + + using GmemTiledCopyA_tiny = XE_2D_U16x16x32_LD_N; + using GmemTiledCopyB_tiny = 
XE_2D_U16x16x16_LD_T; + + // Workgroup-level tile + using TileShape_tiny = Shape<_16, _256, _32>; + + using TiledMma_tiny = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper, Layout, + Layout, Stride<_8, _1, _0>>>::TiledMMA; + + + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy_tiny = + cutlass::gemm::MainloopIntelXeXMX16Group<2, + cutlass::gemm::KernelXeMoEGEMM>; + using EpilogueDispatchPolicy_tiny = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp_tiny = + cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue_tiny = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape_tiny, + Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueTileAuto, + float, float, float, LayoutC_tiny, 1, bfloat16_t, LayoutC_tiny, 1, + EpilogueDispatchPolicy_tiny, EpilogueOp_tiny>::CollectiveOp; + + // Mainloop + using CollectiveMainloop_tiny = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy_tiny, TileShape_tiny, ElementA, + cutlass::gemm::TagToStrideA_t, ElementB, + cutlass::gemm::TagToStrideB_t, TiledMma_tiny, GmemTiledCopyA_tiny, void, + void, cute::identity, // A + GmemTiledCopyB_tiny, void, void, cute::identity // B + >; + */ while (work_tile_info.is_valid()) { auto M = get<0>(problem_shape_MNKL); diff --git a/include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp b/include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp index c372705472..6fc38b961e 100644 --- a/include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp +++ b/include/cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp @@ -54,6 +54,7 @@ class PersistentTileSchedulerXeMoE { int32_t* num_rows_per_expert_ = nullptr; int32_t K_ = 0; int32_t N_ = 0; + int32_t num_experts_ = 0; // Tracking current group, its starting linear idx and total tiles struct GroupInfo { @@ -112,10 +113,11 @@ class PersistentTileSchedulerXeMoE { // Methods // - CUTLASS_HOST_DEVICE void configure(int32_t* num_rows_per_expert, int32_t N, int32_t K) { + CUTLASS_HOST_DEVICE void configure(int32_t* num_rows_per_expert, int32_t N, int32_t K, int32_t num_experts) { num_rows_per_expert_ = num_rows_per_expert; N_ = N; K_ = K; + num_experts_ = num_experts; } // Given the inputs, computes the total number of output blocks this problem will compute over @@ -277,7 +279,7 @@ class PersistentTileSchedulerXeMoE { bool valid_tile = true; uint64_t ctas_along_m, ctas_along_n; - int total_problem_groups = problem_shapes.groups(); + int total_problem_groups = num_experts_; ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_m.divisor - 1); ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) + divmod_cta_shape_n.divisor - 1); From 89126cdbf81dc8252dda34b631364dcd008d6cd7 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Sun, 28 Sep 2025 17:48:46 -0700 Subject: [PATCH 3/8] Remove all D2H & H2D transfers --- .../11_bmg_moe_gemm_bf16.cpp | 38 ++++---- .../collective/builders/xe_builder.inl | 8 +- .../epilogue/collective/xe_array_epilogue.hpp | 21 ++--- .../cutlass/epilogue/fusion/operations.hpp | 6 +- .../cutlass/epilogue/fusion/xe_callbacks.hpp | 14 +-- .../cutlass/gemm/collective/xe_array_mma.hpp | 10 ++- include/cutlass/gemm/kernel/xe_moe_gemm.hpp | 89 +++++++++---------- 7 files changed, 97 insertions(+), 89 deletions(-) diff --git a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp 
b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp index 3aa279b1aa..7052310774 100644 --- a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp +++ b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp @@ -348,10 +348,10 @@ template struct ExampleRunner { // Assign pointers // - std::vector ptr_A_host(options.groups); - std::vector ptr_B_host(options.groups); - std::vector ptr_C_host(options.groups); - std::vector ptr_D_host(options.groups); + std::vector ptr_A_host(1); + std::vector ptr_B_host(1); + std::vector ptr_C_host(1); + std::vector ptr_D_host(1); std::vector ptr_alpha_host(options.groups); std::vector ptr_beta_host(options.groups); @@ -378,17 +378,17 @@ template struct ExampleRunner { } // Allocate device memory & copy from host - ptr_A.reset(options.groups); + ptr_A.reset(1); // Per-group alpha and beta ptr_A.copy_from_host(ptr_A_host.data()); - ptr_B.reset(options.groups); + ptr_B.reset(1); ptr_B.copy_from_host(ptr_B_host.data()); - ptr_C.reset(options.groups); + ptr_C.reset(1); ptr_C.copy_from_host(ptr_C_host.data()); - ptr_D.reset(options.groups); + ptr_D.reset(1); ptr_D.copy_from_host(ptr_D_host.data()); stride_A.reset(options.groups); @@ -489,7 +489,7 @@ template struct ExampleRunner { const int gemm_N, const int gemm_K) { typename Gemm::Arguments arguments; - decltype(arguments.epilogue.thread) fusion_args; + decltype(arguments.fusion_args) fusion_args; bool host_problem_shapes_available = false; if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { // If both alpha/beta are provided (via cmd line args) and are scalar, @@ -524,9 +524,11 @@ template struct ExampleRunner { if (host_problem_shapes_available) { arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, - {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), - stride_D.get()}, + ptr_A.get(), + ptr_B.get(), + nullptr, + ptr_D.get(), + fusion_args, hw_info, {1, RasterOrderOptions::AlongN}, options.num_rows_per_expert, @@ -536,9 +538,11 @@ template struct ExampleRunner { } else { arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, - {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), - stride_D.get()}, + ptr_A.get(), + ptr_B.get(), + nullptr, + ptr_D.get(), + fusion_args, hw_info, {1, RasterOrderOptions::AlongN}, options.num_rows_per_expert, @@ -669,8 +673,10 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, cutlass::gemm::KernelXeMoEGEMM>; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + // ScaledAcc needs to be supported in xe_builder.inl and xe_callbacks.cpp + // This is a workaround using EpilogueOp = - cutlass::epilogue::fusion::LinearCombination; + cutlass::epilogue::fusion::LinearCombination; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index b4b3358076..bdd3e9661f 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -49,10 +49,12 @@ namespace detail { template < class ElementD, class ElementCompute, - class ElementC + class ElementC, + cutlass::FloatRoundStyle RoundStyle_, + bool supportSource_ > struct FusionOpInfo> { constexpr static bool HasBuilder = true; @@ -63,7 +65,7 @@ namespace detail { class> 
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< DispatchPolicy, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, TileShape_MNK, EpilogueTile >; diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 88e4b2b50f..89a7c584fe 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -43,7 +43,6 @@ #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" #include "cutlass/detail/layout.hpp" -#include "../tools/util/include/cutlass/util/packed_stride.hpp" #include "cute/tensor.hpp" @@ -115,8 +114,9 @@ class CollectiveEpilogue< using ElementScalar = typename FusionCallbacks::ElementScalar; static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; - static_assert(cute::is_same_v>, + static_assert(cute::is_any_of_v, + fusion::LinearCombination>, "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; @@ -140,7 +140,7 @@ class CollectiveEpilogue< Layout{}, make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); private: - constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_source_supported = not cute::is_void_v && FusionCallbacks::Operation::IsSourceSupported; constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; public: @@ -244,7 +244,6 @@ class CollectiveEpilogue< Arguments const& args) { constexpr int copy_alignment_bits = 128; constexpr int batch_alignment_bits = 512; - bool implementable = true; bool fusion_implementable = true; @@ -465,7 +464,7 @@ class CollectiveEpilogue< TensorC mC_mnl; TensorD mD_mnl; if constexpr (is_source_supported) { - ElementC const* ptr_C_curr_batch = reinterpret_cast(params.ptr_C[next_group]); + ElementC const* ptr_C_curr_batch = (params.ptr_C == nullptr) ? 
nullptr : reinterpret_cast(params.ptr_C[next_group]); mC_mnl = make_tensor(make_gmem_ptr(ptr_C_curr_batch), make_layout(make_shape(M, N, L), params.dC[next_group])); } @@ -494,20 +493,22 @@ template ElementC const *ptr_C_curr_batch = reinterpret_cast(params.ptr_C[0]) + cumulative_M * N; + auto c_stride = InternalStrideC{}; + cute::get<0>(c_stride) = N; mC_mnl = make_tensor( make_gmem_ptr(ptr_C_curr_batch), - make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( - InternalStrideC{}, {M, N, 1}))); + make_layout(make_shape(M, N, L), c_stride)); } if constexpr (is_destination_supported) { ElementD *ptr_D_curr_batch = reinterpret_cast(params.ptr_D[0]) + cumulative_M * N; + auto d_stride = InternalStrideD{}; + cute::get<0>(d_stride) = N; mD_mnl = make_tensor( make_gmem_ptr(ptr_D_curr_batch), - make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( - InternalStrideD{}, {M, N, 1}))); + make_layout(make_shape(M, N, L), d_stride)); } return cute::make_tuple(mC_mnl, mD_mnl); } diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index c7c94d18f9..d180430a38 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -111,14 +111,16 @@ template< class ElementCompute_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest, + bool supportSource = true > struct LinearCombination : ScaledAcc { using ElementSource = ElementSource_; - static constexpr bool IsSourceSupported = true; + static constexpr bool IsSourceSupported = supportSource; }; + // D = activation(alpha * acc + beta * C) template< template class ActivationFn_, diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 5173d77000..447555c043 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -63,11 +63,12 @@ template < class ElementScalar_, FloatRoundStyle RoundStyle_, class CtaTileShapeMNK_, - class EpilogueTile_ + class EpilogueTile_, + bool supportSource_ > struct FusionCallbacks< epilogue::IntelXeXMX16, - fusion::LinearCombination, + fusion::LinearCombination, CtaTileShapeMNK_, EpilogueTile_ > : Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { @@ -77,7 +78,7 @@ struct FusionCallbacks< using ElementCompute = ElementCompute_; using ElementSource = ElementSource_; using ElementScalar = ElementScalar_; - using Operation = fusion::LinearCombination; + using Operation = fusion::LinearCombination; struct Arguments { ElementScalar alpha = ElementScalar(1); @@ -730,11 +731,12 @@ template < class ElementScalar_, FloatRoundStyle RoundStyle_, class CtaTileShapeMNK_, - class EpilogueTile_ + class EpilogueTile_, + bool supportSource_ > struct FusionCallbacks< epilogue::IntelXeXMX16Group, - fusion::LinearCombination, + fusion::LinearCombination, CtaTileShapeMNK_, EpilogueTile_ > : Sm90LinearCombinationPtrArray::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { @@ -744,7 +746,7 @@ struct FusionCallbacks< using ElementCompute = ElementCompute_; using ElementSource = ElementSource_; using ElementScalar = ElementScalar_; - using Operation = fusion::LinearCombination; + using Operation = fusion::LinearCombination; struct Arguments { ElementScalar alpha = ElementScalar(1); diff 
--git a/include/cutlass/gemm/collective/xe_array_mma.hpp b/include/cutlass/gemm/collective/xe_array_mma.hpp index 8421d4915c..a27a0c395f 100644 --- a/include/cutlass/gemm/collective/xe_array_mma.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma.hpp @@ -37,7 +37,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "../tools/util/include/cutlass/util/packed_stride.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -316,13 +315,16 @@ template ElementB const *ptr_B_curr_batch = reinterpret_cast(mainloop_params.ptr_B[0]) + next_group * K * N; - + auto a_stride = InternalStrideA{}; + cute::get<0>(a_stride) = K; Tensor mA = make_tensor( make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K, (int32_t)1), - cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + a_stride); + auto b_stride = InternalStrideB{}; + cute::get<0>(b_stride) = K; Tensor mB = make_tensor( make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K, (int32_t)1), - cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + b_stride); return cute::make_tuple(mA, mB); } diff --git a/include/cutlass/gemm/kernel/xe_moe_gemm.hpp b/include/cutlass/gemm/kernel/xe_moe_gemm.hpp index b45e8fcfde..b6199dcac5 100644 --- a/include/cutlass/gemm/kernel/xe_moe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_moe_gemm.hpp @@ -36,8 +36,10 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "../tools/util/include/cutlass/util/packed_stride.hpp" #include "cute/tensor.hpp" + /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -123,8 +125,11 @@ class GemmUniversal< // Device side arguments struct Arguments { GemmUniversalMode mode{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; + const ElementA** A_ptr; + const ElementB** B_ptr; + const ElementC** C_ptr; + ElementD** D_ptr; + decltype(EpilogueArguments{}.thread) fusion_args; KernelHardwareInfo hw_info{}; TileSchedulerArguments scheduler{}; const int *M_per_group{nullptr}; @@ -182,9 +187,26 @@ class GemmUniversal< return {args.mode, dummy_group_problem_shape, CollectiveMainloop::to_underlying_arguments( - dummy_group_problem_shape, args.mainloop, workspace_ptr), + dummy_group_problem_shape, + MainloopArguments{ + args.A_ptr, + nullptr, + args.B_ptr, + nullptr + }, + workspace_ptr + ), CollectiveEpilogue::to_underlying_arguments( - dummy_group_problem_shape, args.epilogue, workspace_ptr), + dummy_group_problem_shape, + EpilogueArguments{ + args.fusion_args, + args.C_ptr, + nullptr, + args.D_ptr, + nullptr + }, + workspace_ptr + ), hw_info, scheduler, workspace, @@ -204,8 +226,21 @@ class GemmUniversal< implementable = implementable && TileScheduler::can_implement(args.scheduler); auto dummy_problem_shape = cute::Shape{256, args.N, args.K}; auto dummy_group_problem_shape = ProblemShape{1, &dummy_problem_shape, nullptr}; - implementable &= CollectiveMainloop::can_implement(dummy_group_problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(dummy_group_problem_shape, args.epilogue); + implementable &= CollectiveMainloop::can_implement(dummy_group_problem_shape, + MainloopArguments{ + args.A_ptr, + nullptr, + args.B_ptr, + nullptr + }); + implementable &= CollectiveEpilogue::can_implement(dummy_group_problem_shape, + EpilogueArguments{ + args.fusion_args, + args.C_ptr, + nullptr, 
+ args.D_ptr, + nullptr + }); return implementable; } @@ -279,48 +314,6 @@ class GemmUniversal< curr_group = work_tile_info.L_idx; problem_shape_MNKL = append<4>(Shape{params.M_per_group[curr_group], N, K}, 1); } - /* - using LayoutA_tiny = cutlass::layout::RowMajor; - using LayoutB_tiny = cutlass::layout::ColumnMajor; - using LayoutC_tiny = cutlass::layout::RowMajor; - using LayoutD_tiny = cutlass::layout::RowMajor; - - using GmemTiledCopyA_tiny = XE_2D_U16x16x32_LD_N; - using GmemTiledCopyB_tiny = XE_2D_U16x16x16_LD_T; - - // Workgroup-level tile - using TileShape_tiny = Shape<_16, _256, _32>; - - using TiledMma_tiny = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 - typename TiledMMAHelper, Layout, - Layout, Stride<_8, _1, _0>>>::TiledMMA; - - - // Dispatch to grouped gemm algorithm - using GEMMDispatchPolicy_tiny = - cutlass::gemm::MainloopIntelXeXMX16Group<2, - cutlass::gemm::KernelXeMoEGEMM>; - using EpilogueDispatchPolicy_tiny = cutlass::epilogue::IntelXeXMX16Group; - - using EpilogueOp_tiny = - cutlass::epilogue::fusion::LinearCombination; - - using CollectiveEpilogue_tiny = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape_tiny, - Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueTileAuto, - float, float, float, LayoutC_tiny, 1, bfloat16_t, LayoutC_tiny, 1, - EpilogueDispatchPolicy_tiny, EpilogueOp_tiny>::CollectiveOp; - - // Mainloop - using CollectiveMainloop_tiny = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy_tiny, TileShape_tiny, ElementA, - cutlass::gemm::TagToStrideA_t, ElementB, - cutlass::gemm::TagToStrideB_t, TiledMma_tiny, GmemTiledCopyA_tiny, void, - void, cute::identity, // A - GmemTiledCopyB_tiny, void, void, cute::identity // B - >; - */ while (work_tile_info.is_valid()) { auto M = get<0>(problem_shape_MNKL); From 1008714e442475ecc6533712325e0f7c19ac3c41 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Sun, 28 Sep 2025 23:31:00 -0700 Subject: [PATCH 4/8] Revert one change --- include/cutlass/epilogue/collective/xe_array_epilogue.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 89a7c584fe..9fa8b18c17 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -464,7 +464,7 @@ class CollectiveEpilogue< TensorC mC_mnl; TensorD mD_mnl; if constexpr (is_source_supported) { - ElementC const* ptr_C_curr_batch = (params.ptr_C == nullptr) ? 
nullptr : reinterpret_cast(params.ptr_C[next_group]); + ElementC const* ptr_C_curr_batch = reinterpret_cast(params.ptr_C[next_group]); mC_mnl = make_tensor(make_gmem_ptr(ptr_C_curr_batch), make_layout(make_shape(M, N, L), params.dC[next_group])); } From 5b6f880076ccc8d56dafe27ebc58ff927f901418 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Mon, 29 Sep 2025 06:16:41 -0700 Subject: [PATCH 5/8] Bug fix --- .../epilogue/collective/xe_array_epilogue.hpp | 13 +++++++------ include/cutlass/gemm/collective/xe_array_mma.hpp | 10 ++++------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 9fa8b18c17..2284d5d0d7 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -43,6 +43,7 @@ #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" #include "cutlass/detail/layout.hpp" +#include "../tools/util/include/cutlass/util/packed_stride.hpp" #include "cute/tensor.hpp" @@ -114,6 +115,7 @@ class CollectiveEpilogue< using ElementScalar = typename FusionCallbacks::ElementScalar; static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; + static_assert(cute::is_any_of_v, fusion::LinearCombination>, @@ -244,6 +246,7 @@ class CollectiveEpilogue< Arguments const& args) { constexpr int copy_alignment_bits = 128; constexpr int batch_alignment_bits = 512; + bool implementable = true; bool fusion_implementable = true; @@ -493,22 +496,20 @@ template ElementC const *ptr_C_curr_batch = reinterpret_cast(params.ptr_C[0]) + cumulative_M * N; - auto c_stride = InternalStrideC{}; - cute::get<0>(c_stride) = N; mC_mnl = make_tensor( make_gmem_ptr(ptr_C_curr_batch), - make_layout(make_shape(M, N, L), c_stride)); + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( + InternalStrideC{}, {M, N, 1}))); } if constexpr (is_destination_supported) { ElementD *ptr_D_curr_batch = reinterpret_cast(params.ptr_D[0]) + cumulative_M * N; - auto d_stride = InternalStrideD{}; - cute::get<0>(d_stride) = N; mD_mnl = make_tensor( make_gmem_ptr(ptr_D_curr_batch), - make_layout(make_shape(M, N, L), d_stride)); + make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride( + InternalStrideD{}, {M, N, 1}))); } return cute::make_tuple(mC_mnl, mD_mnl); } diff --git a/include/cutlass/gemm/collective/xe_array_mma.hpp b/include/cutlass/gemm/collective/xe_array_mma.hpp index a27a0c395f..8421d4915c 100644 --- a/include/cutlass/gemm/collective/xe_array_mma.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma.hpp @@ -37,6 +37,7 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" +#include "../tools/util/include/cutlass/util/packed_stride.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -315,16 +316,13 @@ template ElementB const *ptr_B_curr_batch = reinterpret_cast(mainloop_params.ptr_B[0]) + next_group * K * N; - auto a_stride = InternalStrideA{}; - cute::get<0>(a_stride) = K; + Tensor mA = make_tensor( make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K, (int32_t)1), - a_stride); - auto b_stride = InternalStrideB{}; - cute::get<0>(b_stride) = K; + cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); Tensor mB = make_tensor( make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K, (int32_t)1), - 
b_stride); + cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); return cute::make_tuple(mA, mB); } From dad2193efca1d07f885054af9e21eaeecd4e39b8 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Wed, 1 Oct 2025 13:31:56 -0700 Subject: [PATCH 6/8] Eliminate all H2D & D2H copies --- .../11_bmg_moe_gemm_bf16.cpp | 163 +++++------------- .../epilogue/collective/xe_array_epilogue.hpp | 4 +- .../cutlass/gemm/collective/xe_array_mma.hpp | 4 +- 3 files changed, 49 insertions(+), 122 deletions(-) diff --git a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp index 7052310774..69d146d832 100644 --- a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp +++ b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp @@ -280,8 +280,8 @@ template struct ExampleRunner { return passed; } - /// Allocates device-side data - void allocate(const GroupGEMMOptions &options, const ElementA *block_A_ptr, + /// Allocates device-side data for reference GEMM + void allocate_for_ref_gemm(const GroupGEMMOptions &options, const ElementA *block_A_ptr, const ElementA *block_B_ptr, ElementOutput*block_C_ptr, int block_A_size, int block_B_size, int block_C_size) { int64_t total_elements_A = 0; @@ -338,85 +338,10 @@ template struct ExampleRunner { cumsum_device.copy_from_host(cumsum_host); } - /// Initialize operands to be used in the GEMM and reference GEMM - void initialize_for_moe_gemm(const GroupGEMMOptions &options) { + /// Initialize operands to be used in the reference GEMM + void initialize(const GroupGEMMOptions &options) { - problem_sizes.reset(options.groups); - problem_sizes.copy_from_host(options.problem_sizes_host.data()); - - // - // Assign pointers - // - - std::vector ptr_A_host(1); - std::vector ptr_B_host(1); - std::vector ptr_C_host(1); - std::vector ptr_D_host(1); - std::vector ptr_alpha_host(options.groups); - std::vector ptr_beta_host(options.groups); - - // Compute offsets, alpha & beta over group on host - - ptr_A_host.at(0) = block_A.get(); - ptr_B_host.at(0) = block_B.get(); - ptr_C_host.at(0) = block_C.get(); - ptr_D_host.at(0) = block_D.get(); - for (int32_t i = 0; i < options.groups; ++i) { - // Fill host vector of alpha & beta with random values if using per-group - // values - alpha_host.push_back( - (options.alpha == FLT_MAX) - ? static_cast((rand() % 5) + 1) - : options.alpha); - beta_host.push_back((options.beta == FLT_MAX) - ? 
static_cast(rand() % 5) - : options.beta); - // Fill host ptr vectors with offset addresses into device alpha/beta - // blocks - ptr_alpha_host.at(i) = block_alpha.get() + i; - ptr_beta_host.at(i) = block_beta.get() + i; - } - - // Allocate device memory & copy from host - ptr_A.reset(1); - // Per-group alpha and beta - ptr_A.copy_from_host(ptr_A_host.data()); - - ptr_B.reset(1); - ptr_B.copy_from_host(ptr_B_host.data()); - - ptr_C.reset(1); - ptr_C.copy_from_host(ptr_C_host.data()); - - ptr_D.reset(1); - ptr_D.copy_from_host(ptr_D_host.data()); - - stride_A.reset(options.groups); - stride_A.copy_from_host(stride_A_host.data()); - - stride_B.reset(options.groups); - stride_B.copy_from_host(stride_B_host.data()); - - stride_C.reset(options.groups); - stride_C.copy_from_host(stride_C_host.data()); - - stride_D.reset(options.groups); - stride_D.copy_from_host(stride_D_host.data()); - - // Per-group alpha and beta ptrs - alpha_device.reset(options.groups); - alpha_device.copy_from_host(ptr_alpha_host.data()); - beta_device.reset(options.groups); - beta_device.copy_from_host(ptr_beta_host.data()); - - // Per-group alpha and beta values - note these are not directly passed to - // kernel - the pointers (alpha_device/beta_device) are passed instead - block_alpha.copy_from_host(alpha_host.data()); - block_beta.copy_from_host(beta_host.data()); - } - - /// Initialize operands to be used in the GEMM and reference GEMM - void initialize_for_ref_gemm(const GroupGEMMOptions &options) { + uint64_t seed = 2020; problem_sizes.reset(options.groups); problem_sizes.copy_from_host(options.problem_sizes_host.data()); @@ -438,8 +363,10 @@ template struct ExampleRunner { ptr_B_host.at(i) = block_B.get() + offset_B.at(i); ptr_C_host.at(i) = block_C.get() + offset_C.at(i); ptr_D_host.at(i) = block_D.get() + offset_D.at(i); - // Fill host ptr vectors with offset addresses into device alpha/beta - // blocks + // Fill host vector of alpha & beta with random values if using per-group values + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + // Fill host ptr vectors with offset addresses into device alpha/beta blocks ptr_alpha_host.at(i) = block_alpha.get() + i; ptr_beta_host.at(i) = block_beta.get() + i; } @@ -475,9 +402,8 @@ template struct ExampleRunner { alpha_device.copy_from_host(ptr_alpha_host.data()); beta_device.reset(options.groups); beta_device.copy_from_host(ptr_beta_host.data()); - - // Per-group alpha and beta values - note these are not directly passed to - // kernel - the pointers (alpha_device/beta_device) are passed instead + // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers + // (alpha_device/beta_device) are passed instead block_alpha.copy_from_host(alpha_host.data()); block_beta.copy_from_host(beta_host.data()); } @@ -486,6 +412,9 @@ template struct ExampleRunner { typename Gemm::Arguments args_from_options(const GroupGEMMOptions &options, const cutlass::KernelHardwareInfo &hw_info, + const ElementA* A_ptr, + const ElementB* B_ptr, + ElementOutput* D_ptr, const int gemm_N, const int gemm_K) { typename Gemm::Arguments arguments; @@ -494,8 +423,8 @@ template struct ExampleRunner { if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { // If both alpha/beta are provided (via cmd line args) and are scalar, // i.e., same alpha/beta applies to all batches. 
- fusion_args.alpha = options.alpha; - fusion_args.beta = options.beta; + fusion_args.alpha = 1; + fusion_args.beta = 0; fusion_args.alpha_ptr = nullptr; fusion_args.beta_ptr = nullptr; fusion_args.alpha_ptr_array = nullptr; @@ -506,12 +435,12 @@ template struct ExampleRunner { } else { // If pointers to alpha/beta are provided, i.e., alpha/beta can differ // between batches/groups. - fusion_args.alpha = 0; + fusion_args.alpha = 1; fusion_args.beta = 0; fusion_args.alpha_ptr = nullptr; fusion_args.beta_ptr = nullptr; - fusion_args.alpha_ptr_array = alpha_device.get(); - fusion_args.beta_ptr_array = beta_device.get(); + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; // One alpha and beta per each group fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; @@ -523,11 +452,11 @@ template struct ExampleRunner { // Per-GEMM problem shape info may only exist on the device. if (host_problem_shapes_available) { arguments = typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGrouped, - ptr_A.get(), - ptr_B.get(), - nullptr, - ptr_D.get(), + cutlass::gemm::GemmUniversalMode::kGrouped, // this just means grouped GEMM + static_cast((void*)A_ptr), + static_cast((void*)B_ptr), + static_cast((void*)D_ptr), // we could also pass nullptr + static_cast((void*)D_ptr), fusion_args, hw_info, {1, RasterOrderOptions::AlongN}, @@ -538,10 +467,10 @@ template struct ExampleRunner { } else { arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, - ptr_A.get(), - ptr_B.get(), - nullptr, - ptr_D.get(), + static_cast((void*)A_ptr), + static_cast((void*)B_ptr), + static_cast((void*)D_ptr), + static_cast((void*)D_ptr), fusion_args, hw_info, {1, RasterOrderOptions::AlongN}, @@ -557,12 +486,11 @@ template struct ExampleRunner { cutlass::Status run(const GroupGEMMOptions &options, const cutlass::KernelHardwareInfo &hw_info, const ElementA *A_ptr, const ElementB *B_ptr, - ElementOutput *C_ptr, int A_size, int B_size, int D_size, const int gemm_n, const int gemm_k) { - allocate(options, A_ptr, B_ptr, C_ptr, A_size, B_size, D_size); - initialize_for_moe_gemm(options); + ElementOutput *D_ptr, int A_size, int B_size, int D_size, const int gemm_n, const int gemm_k) { + allocate_for_ref_gemm(options, A_ptr, B_ptr, D_ptr, A_size, B_size, D_size); Gemm gemm_op; - auto arguments = args_from_options(options, hw_info, gemm_n, gemm_k); + auto arguments = args_from_options(options, hw_info, A_ptr, B_ptr, D_ptr, gemm_n, gemm_k); size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); @@ -575,15 +503,14 @@ template struct ExampleRunner { CUTLASS_CHECK(gemm_op.run()); syclcompat::wait(); - initialize_for_ref_gemm(options); + initialize(options); // Verify that the result is correct bool passed = verify(options); std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; if (!passed) return cutlass::Status::kErrorInternal; - initialize_for_moe_gemm(options); syclcompat::wait(); - arguments = args_from_options(options, hw_info, gemm_n, gemm_k); + arguments = args_from_options(options, hw_info, A_ptr, B_ptr, D_ptr, gemm_n, gemm_k); CUTLASS_CHECK(gemm_op.can_implement(arguments)); CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); @@ -647,20 +574,15 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, hw_info.device_id); using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; -/* - using TiledMma = - TiledMMA, - Layout, Stride<_4, _1, _0>>>; -*/ using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 typename TiledMMAHelper, Layout, @@ -710,18 +632,23 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, int main(int argc, const char **argv) { - const int num_experts = 32; + const int num_experts = 16; - int total_rows_for_each_expert[num_experts] = { + /* int total_rows_for_each_expert[num_experts] = { 148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324, - 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}; + 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}; */ + + int total_rows_for_each_expert[num_experts]; + for (int i = 0; i < num_experts; i++) { + total_rows_for_each_expert[i] = 512; + } int num_tokens_incl_duplicated = 0; for (int i = 0; i < num_experts; i++) { num_tokens_incl_duplicated += total_rows_for_each_expert[i]; } - int n_moe = 3072; - int k_moe = 4096; + int n_moe = 16384; + int k_moe = 5120; cutlass::DeviceAllocation num_rows_per_expert_device; cutlass::DeviceAllocation activations_data; diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 2284d5d0d7..b745ab02f2 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -494,7 +494,7 @@ template TensorD mD_mnl; if constexpr (is_source_supported) { ElementC const *ptr_C_curr_batch = - reinterpret_cast(params.ptr_C[0]) + + reinterpret_cast((void*)(params.ptr_C)) + cumulative_M * N; mC_mnl = make_tensor( make_gmem_ptr(ptr_C_curr_batch), @@ -504,7 +504,7 @@ template if constexpr (is_destination_supported) { ElementD *ptr_D_curr_batch = - reinterpret_cast(params.ptr_D[0]) + + reinterpret_cast((void*)(params.ptr_D)) + cumulative_M * N; mD_mnl = make_tensor( make_gmem_ptr(ptr_D_curr_batch), diff --git a/include/cutlass/gemm/collective/xe_array_mma.hpp b/include/cutlass/gemm/collective/xe_array_mma.hpp index 8421d4915c..b9d4b939df 100644 --- a/include/cutlass/gemm/collective/xe_array_mma.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma.hpp @@ -311,10 +311,10 @@ template const int32_t K = get<2>(problem_shape_mnkl); ElementA const *ptr_A_curr_batch = - reinterpret_cast(mainloop_params.ptr_A[0]) + + reinterpret_cast((void*)(mainloop_params.ptr_A)) + cumulative_M * K; ElementB const *ptr_B_curr_batch = - reinterpret_cast(mainloop_params.ptr_B[0]) + + reinterpret_cast((void*)(mainloop_params.ptr_B)) + next_group * K * 
N; Tensor mA = make_tensor( From d29944f2435db2401364f1174b8815395cb2c168 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Tue, 7 Oct 2025 11:23:50 -0700 Subject: [PATCH 7/8] Update 11_bmg_moe_gemm_bf16.cpp --- .../11_bmg_moe_gemm_bf16.cpp | 96 ++++++++++++++----- 1 file changed, 70 insertions(+), 26 deletions(-) diff --git a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp index 69d146d832..7e992074f0 100644 --- a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp +++ b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp @@ -142,21 +142,28 @@ struct GroupGEMMOptions { } /// Compute performance in GFLOP/s - double gflops(double runtime_s, + std::tuple gflops(double runtime_s, std::vector problem_sizes_host) const { // Number of real-valued multiply-adds uint64_t fmas = uint64_t(); + uint64_t bytes_loaded = 0; for (auto const &problem : problem_sizes_host) { - fmas += static_cast(get<0>(problem)) * - static_cast(get<1>(problem)) * - static_cast(get<2>(problem)); + auto M = static_cast(get<0>(problem)); + auto N = static_cast(get<1>(problem)); + auto K = static_cast(get<2>(problem)); + fmas += M * N * K; + bytes_loaded += /* sizeof(cutlass::bfloat16_t) */ 2 * (2 * M * N + N * K + M * K); } // Two flops per multiply-add uint64_t flop = uint64_t(2) * uint64_t(fmas); double gflop = double(flop) / double(1.0e9); - return gflop / runtime_s; + double arithmetic_intensity = double(flop) / double(bytes_loaded); + double peak_mwm_bw = 456.0; + double gflops_attainable = std::min(117 * double(1.0e12), arithmetic_intensity * (peak_mwm_bw * 1024 * 1024 * 1024)); + double projected_time = flop/gflops_attainable; + return std::make_tuple(gflop / runtime_s, double(bytes_loaded) / 1024 / 1024 / 1024 / runtime_s, projected_time * 1000); } }; @@ -455,7 +462,7 @@ template struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGrouped, // this just means grouped GEMM static_cast((void*)A_ptr), static_cast((void*)B_ptr), - static_cast((void*)D_ptr), // we could also pass nullptr + nullptr,//static_cast((void*)D_ptr), // we could also pass nullptr static_cast((void*)D_ptr), fusion_args, hw_info, @@ -469,7 +476,7 @@ template struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGrouped, static_cast((void*)A_ptr), static_cast((void*)B_ptr), - static_cast((void*)D_ptr), + nullptr, // static_cast((void*)D_ptr), static_cast((void*)D_ptr), fusion_args, hw_info, @@ -525,7 +532,7 @@ template struct ExampleRunner { float cute_time = timer.seconds() * 1000; double cute_average_time = double(cute_time) / double(options.iterations); - double gflops = options.gflops(cute_average_time / 1000.0, + auto [gflops, mem_bw_util, projected_time] = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); std::cout << " Problem Sizes, Alpha, Beta " << std::endl; @@ -538,6 +545,7 @@ template struct ExampleRunner { std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; std::cout << " GFLOPS : " << gflops << std::endl; + std::cout << " Memory BW utilization : " << mem_bw_util << " GBPs" << std::endl; } return cutlass::Status::kSuccess; @@ -630,26 +638,23 @@ void MoEGEMM(const bfloat16_t *activations, const bfloat16_t *weights, num_rows_per_expert_obj.release(); } - -int main(int argc, const char **argv) { - const int num_experts = 16; - - /* int total_rows_for_each_expert[num_experts] = { - 148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324, - 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 
215, 113, 224, 35}; */ - - int total_rows_for_each_expert[num_experts]; - for (int i = 0; i < num_experts; i++) { - total_rows_for_each_expert[i] = 512; - } - +void launcher(int* M_per_expert, int N, int K, const int& num_experts) { + int n_moe = N; + int k_moe = K; int num_tokens_incl_duplicated = 0; - for (int i = 0; i < num_experts; i++) { - num_tokens_incl_duplicated += total_rows_for_each_expert[i]; + for(int i=0; i < num_experts; i++) { + num_tokens_incl_duplicated += M_per_expert[i]; } - int n_moe = 16384; - int k_moe = 5120; + float M_occupancy = 0.f; + float actual_num_units = 0.f; + int total_num_M_tiles = 0; + for (int i=0; i < num_experts; i++) { + total_num_M_tiles += (M_per_expert[i] + 63)/64; + actual_num_units += M_per_expert[i]/64.0; + } + M_occupancy = actual_num_units / total_num_M_tiles; + std::cout << "\n\n M-occupancy is " << M_occupancy << std::endl; cutlass::DeviceAllocation num_rows_per_expert_device; cutlass::DeviceAllocation activations_data; cutlass::DeviceAllocation weights_data; @@ -658,7 +663,7 @@ int main(int argc, const char **argv) { size_t B_size = num_experts * n_moe * k_moe; size_t D_size = num_tokens_incl_duplicated * n_moe; num_rows_per_expert_device.reset(num_experts); - num_rows_per_expert_device.copy_from_host(total_rows_for_each_expert); + num_rows_per_expert_device.copy_from_host(M_per_expert); activations_data.reset(A_size); weights_data.reset(B_size); output_data.reset(D_size); @@ -672,5 +677,44 @@ int main(int argc, const char **argv) { weights_data.release(); output_data.release(); num_rows_per_expert_device.release(); +} + + +int main(int argc, const char **argv) { + constexpr int num_experts = 32; + constexpr int num_layers = 24; + + int total_rows_for_each_expert[num_layers][num_experts] = { + {148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, 845, 191, 424, 30, 97, 57, 324, 62, 77, 75, 144, 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}, + {666, 214, 448, 87, 4, 28, 48, 13, 74, 40, 546, 397, 487, 350, 26, 95, 517, 487, 295, 58, 637, 97, 139, 33, 126, 15, 352, 311, 995, 193, 135, 135}, + {1016, 30, 36, 452, 469, 473, 232, 0, 493, 14, 954, 6, 4, 6, 279, 3, 94, 106, 96, 48, 49, 113, 142, 169, 75, 99, 25, 220, 249, 289, 4, 1803}, + {350, 229, 703, 154, 8, 64, 80, 339, 2, 56, 5, 312, 1005, 29, 9, 11, 23, 0, 23, 431, 48, 129, 496, 476, 8, 1234, 7, 130, 34, 58, 41, 1554}, + {39, 10, 6, 2, 110, 1, 894, 8, 53, 0, 275, 6, 506, 421, 700, 178, 0, 530, 1623, 15, 231, 74, 6, 222, 1246, 116, 35, 20, 0, 6, 381, 334}, + {399, 5, 201, 6, 134, 93, 1748, 1, 51, 4, 38, 336, 53, 88, 328, 724, 15, 388, 706, 52, 19, 55, 52, 33, 623, 1, 222, 215, 69, 45, 308, 1036}, + {11, 8, 407, 571, 458, 275, 197, 211, 13, 564, 462, 114, 15, 13, 132, 24, 514, 2, 71, 13, 694, 47, 16, 203, 610, 40, 0, 1587, 66, 23, 196, 491}, + {0, 230, 116, 136, 315, 643, 6, 183, 37, 26, 960, 1, 8, 258, 21, 1602, 213, 198, 6, 196, 455, 557, 47, 282, 493, 18, 101, 11, 616, 45, 268, 0}, + {392, 305, 179, 14, 227, 98, 114, 39, 64, 1456, 465, 0, 18, 372, 0, 0, 189, 257, 25, 290, 486, 0, 12, 1534, 468, 4, 555, 35, 146, 0, 161, 143}, + {4, 107, 20, 125, 236, 898, 0, 0, 375, 2, 125, 0, 0, 1429, 36, 195, 1660, 0, 127, 454, 73, 358, 47, 79, 32, 20, 1465, 0, 0, 6, 109, 66}, + {19, 0, 0, 0, 2, 1638, 75, 135, 392, 2, 1494, 3, 23, 5, 4, 58, 0, 0, 71, 1285, 8, 441, 0, 145, 209, 408, 450, 2, 824, 13, 326, 16}, + {4, 2, 14, 0, 30, 206, 41, 131, 0, 429, 16, 895, 35, 21, 44, 128, 12, 0, 417, 0, 838, 917, 42, 115, 109, 1759, 0, 36, 17, 0, 1790, 0}, + {6, 483, 241, 1327, 17, 11, 480, 9, 880, 58, 4, 0, 61, 
30, 16, 176, 9, 309, 26, 0, 0, 1882, 4, 281, 475, 783, 197, 0, 19, 15, 6, 243}, + {370, 1222, 0, 6, 108, 929, 2, 7, 157, 348, 149, 106, 2, 5, 25, 33, 1569, 8, 6, 106, 69, 1298, 0, 2, 529, 520, 0, 421, 0, 25, 26, 0}, + {59, 89, 0, 26, 25, 40, 1873, 141, 527, 371, 262, 62, 16, 0, 127, 234, 1637, 64, 132, 8, 0, 7, 161, 1005, 22, 1, 49, 6, 83, 925, 80, 16}, + {269, 617, 30, 4, 90, 26, 0, 16, 154, 212, 21, 269, 379, 174, 129, 32, 8, 121, 344, 15, 0, 591, 1494, 6, 737, 50, 112, 856, 483, 25, 454, 330}, + {0, 98, 1488, 22, 73, 0, 0, 343, 77, 4, 0, 612, 165, 268, 4, 10, 43, 0, 598, 271, 2, 73, 185, 0, 112, 779, 24, 1626, 0, 0, 0, 1171}, + {0, 0, 0, 189, 266, 1743, 0, 462, 20, 7, 668, 310, 40, 0, 10, 236, 423, 18, 0, 0, 0, 999, 0, 139, 1754, 8, 619, 3, 23, 0, 102, 9}, + {131, 1753, 0, 113, 24, 94, 2, 12, 108, 0, 0, 252, 97, 0, 1319, 233, 93, 1254, 195, 152, 14, 413, 4, 2, 220, 67, 20, 4, 34, 559, 837, 42}, + {55, 76, 0, 8, 0, 3, 1557, 975, 135, 271, 4, 0, 0, 666, 207, 152, 5, 2, 97, 364, 0, 13, 1423, 771, 159, 31, 223, 0, 431, 7, 409, 4}, + {4, 1026, 1799, 166, 694, 753, 0, 16, 0, 240, 1119, 19, 6, 0, 46, 659, 10, 0, 112, 808, 181, 0, 28, 22, 90, 0, 176, 0, 37, 5, 10, 22}, + {44, 0, 4, 153, 299, 1357, 6, 23, 0, 12, 4, 419, 73, 24, 16, 24, 1, 4, 4, 102, 16, 4, 0, 1953, 1850, 0, 908, 4, 0, 13, 708, 23}, + {6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6, 0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18}, + {5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4, 33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9} + }; + + for (int i = 0; i < num_layers; i++) { + launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts); + launcher(total_rows_for_each_expert[i], 2880, 2880, num_experts); + } + return 0; } From 15459826c633fd808b5a235efa2ab344bed63f75 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Tue, 14 Oct 2025 22:11:16 -0700 Subject: [PATCH 8/8] Rebase to use compat API --- examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp index 7e992074f0..17d3297373 100644 --- a/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp +++ b/examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp @@ -274,7 +274,7 @@ template struct ExampleRunner { ); // Wait for kernel to finish - syclcompat::wait(); + compat::wait(); // Check if output from CUTLASS kernel and reference kernel are equal or // not @@ -509,14 +509,14 @@ template struct ExampleRunner { // Run the GEMM CUTLASS_CHECK(gemm_op.run()); - syclcompat::wait(); + compat::wait(); initialize(options); // Verify that the result is correct bool passed = verify(options); std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; if (!passed) return cutlass::Status::kErrorInternal; - syclcompat::wait(); + compat::wait(); arguments = args_from_options(options, hw_info, A_ptr, B_ptr, D_ptr, gemm_n, gemm_k); CUTLASS_CHECK(gemm_op.can_implement(arguments)); @@ -528,7 +528,7 @@ template struct ExampleRunner { for (int iter = 0; iter < options.iterations; ++iter) { CUTLASS_CHECK(gemm_op.run()); } - syclcompat::wait(); + compat::wait(); float cute_time = timer.seconds() * 1000; double cute_average_time = double(cute_time) / double(options.iterations); @@ -718,3 +718,4 @@ int main(int argc, const char **argv) { return 0; } +
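The projected-time figure returned by GroupGEMMOptions::gflops() in PATCH 7 follows a simple roofline model: the attainable flop rate is the minimum of the peak compute rate and the product of arithmetic intensity and peak memory bandwidth. Below is a minimal standalone sketch of that calculation, assuming the same peak figures hard-coded in the example (456 GB/s and 117 TFLOP/s), bf16 (2-byte) operands, and one per-expert GEMM shape taken from the example (M = 512, N = 5760, K = 2880); it is illustrative only, not part of the patch.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed peak rates, matching the constants used in GroupGEMMOptions::gflops().
  const double peak_flops = 117.0 * 1e12;                // bf16 peak, flop/s
  const double peak_bw    = 456.0 * 1024 * 1024 * 1024;  // memory bandwidth, bytes/s
  // One expert's GEMM: bf16 (2-byte) A (MxK), B (KxN), C/D (MxN).
  const std::uint64_t M = 512, N = 5760, K = 2880;
  const double flops = 2.0 * M * N * K;
  const double bytes = 2.0 * (2 * M * N + N * K + M * K);
  const double intensity  = flops / bytes;               // flop per byte
  const double attainable = std::min(peak_flops, intensity * peak_bw);
  std::printf("arithmetic intensity: %.1f flop/B, projected time: %.3f ms\n",
              intensity, 1e3 * flops / attainable);
  return 0;
}

At this shape the product of intensity and bandwidth exceeds the compute peak, so the projection is compute-bound; for the much smaller per-expert M values in the measured token distributions the same formula lands in the bandwidth-bound regime, which is presumably why the example reports achieved memory bandwidth alongside GFLOPS.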