diff --git a/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp b/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp index b231825fe7..3e9fa8e6dd 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp @@ -39,7 +39,7 @@ This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions. To support more input shapes using these instructions, rows of the input/output matrices are padded - to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these + to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these instructions. The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the @@ -161,14 +161,14 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementD = typename Gemm::ElementD; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -200,7 +200,7 @@ struct ExampleRunner { bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { auto [M, N, K, L] = problem_size; - + // Padded values // The inner dimension is padded. Since this example is all RowMajor, // we require the following: @@ -208,7 +208,7 @@ struct ExampleRunner { int N_C = cute::round_up(N, AlignElemC); int N_D = cute::round_up(N, AlignElemD); int K_A = cute::round_up(K, AlignElemA); - + int AlignmentOuter = AlignmentPtr / AlignmentInner; int M_ACD = cute::round_up(M, AlignmentOuter); int K_B = cute::round_up(K, AlignmentOuter); @@ -383,9 +383,13 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - // The 2D block copy operations used for the A and B matrices - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + // [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects + // appropriate 2D block copy operations for matrices A and B. Alternatively, you can + // explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI, + // or XE_LOAD_2D_TRANSPOSE. + // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md + using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; + using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; @@ -393,21 +397,21 @@ int main(int argc, const char** argv) // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional // hardware (sub-groups for Intel BMG) and iterations by each sub-group. // - // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom - // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The - // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses + // the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with + //float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1). + // The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a // single contiguous chunk of the work-group TileShape. For this configuration, this implies that // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for // performance reasons. - using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + // For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // This is the 'default' epilogue operation (Linear Combination) which performs everything in: // (D = alpha * (A*B) + beta * C) @@ -418,22 +422,21 @@ int main(int argc, const char** argv) // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch // policy/architecture) and defines the epilogue arguments. - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any // auxiliary data required using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, // Epilogue tile (void = automatic) ElementAccumulator, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation ElementOutput, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation - FusionCallBacks, - XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C - void, void, - XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D - void, void>; + FusionCallbacks, + void, // The copy atom used to load matrix C (void = automatic) + void>; // The copy atom used to store matrix D (void = automatic) // GEMM Mainloop - iteration over blocks in K dimension using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< diff --git a/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp b/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp index 67e1193e75..98d0704a93 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp @@ -136,13 +136,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -348,42 +347,50 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + // [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects + // appropriate 2D block copy operations for matrices A and B. Alternatively, you can + // explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI, + // or XE_LOAD_2D_TRANSPOSE. + // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md + using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; + using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - // The Tile of this layout describes how 8x4x1 sub-groups tile the TileShape of <256, 256, 32>. - // This permutation (which can be thought of as a scatter operation on the default tiling) - // ensures that each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations) - // See 0t_mma_atom.md#TiledMMAs for more info. - // Sub-groups are arranged row-major (stride 4,1,0) for performance reasons. - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional + // hardware (sub-groups for Intel BMG) and iterations by each sub-group. + // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses + // the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with + //float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1). + // The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // single contiguous chunk of the work-group TileShape. For this configuration, this implies that + // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for + // performance reasons. + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< diff --git a/examples/00_bmg_gemm/legacy/00_bmg_gemm_padded.cpp b/examples/00_bmg_gemm/legacy/00_bmg_gemm_padded.cpp new file mode 100644 index 0000000000..b231825fe7 --- /dev/null +++ b/examples/00_bmg_gemm/legacy/00_bmg_gemm_padded.cpp @@ -0,0 +1,467 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm Example. + + This example constructs and executes a simple CUTLASS GEMM kernel on Intel BMG hardware, and + verifies its correctness with a reference implementation + (cutlass::reference::device::GemmComplex). The example also provides a performance measurement + for the GEMM in TFLOPS. + + This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions. + To support more input shapes using these instructions, rows of the input/output matrices are padded + to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these + instructions. + + The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the + batch size is defined by `options.l`. The tile shape, which defines how much work is executed by + a single work-group, is defined at compile time by: + ``` + using TileShape = Shape<_256, _256, _32>; + ``` + That is, each work-group processes a tile of M=256, N=256, and iterates over `options.k` in + blocks of K=32. + + Performance of GEMM on BMG is heavily dependent on prefetching the A and B matrices. That is, + executing Intel specific prefetch instructions for future iterations to ensure that the required + blocks of A and B are resident in cache before they are needed. + + To build & run this example (from your build dir): + + $ ninja 00_bmg_gemm + $ ./examples/sycl/00_bmg_gemm/00_bmg_gemm + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// The alignment requirement in bytes on inner dimmension that will work for both PVC and BMG +constexpr int AlignmentInner = 16; +// The alignment requirement in bytes on outer dimmension that will work for both PVC and BMG +constexpr int AlignmentPtr = 64; + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(20), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + static constexpr int AlignElemA = AlignmentInner / sizeof(ElementA); + static constexpr int AlignElemB = AlignmentInner / sizeof(ElementB); + static constexpr int AlignElemC = AlignmentInner / sizeof(ElementB); + static constexpr int AlignElemD = AlignmentInner / sizeof(ElementD); + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; // Reference GEMM result for verification + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + // Padded values + // The inner dimension is padded. Since this example is all RowMajor, + // we require the following: + int N_B = cute::round_up(N, AlignElemB); + int N_C = cute::round_up(N, AlignElemC); + int N_D = cute::round_up(N, AlignElemD); + int K_A = cute::round_up(K, AlignElemA); + + int AlignmentOuter = AlignmentPtr / AlignmentInner; + int M_ACD = cute::round_up(M, AlignmentOuter); + int K_B = cute::round_up(K, AlignmentOuter); + + cutlass::TensorRef ref_A(block_A.get(), LayoutA(K_A)); + cutlass::TensorRef ref_B(block_B.get(), LayoutB(N_B)); + cutlass::TensorRef ref_C(block_C.get(), LayoutC(N_C)); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD(N_D)); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M_ACD * K_A, // batch_stride_A + K_B * N_B, // batch_stride_B + M_ACD * N_C, // batch_stride_C + M_ACD * N_D // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library compat for e.g. default in-order queue + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // Padded values + int N_B = cute::round_up(N, AlignElemB); + int N_C = cute::round_up(N, AlignElemC); + int N_D = cute::round_up(N, AlignElemD); + int K_A = cute::round_up(K, AlignElemA); + + int AlignmentOuter = AlignmentPtr / AlignmentInner; + int M_ACD = cute::round_up(M, AlignmentOuter); + int K_B = cute::round_up(K, AlignmentOuter); + + // Complete the stride by combining static layout info (StrideA) with runtime size info (M,K,L) + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M_ACD, K_A, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N_B, K_B, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M_ACD, N_C, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M_ACD, N_D, L)); + + block_A.reset(M_ACD * K_A * L); + block_B.reset(K_B * N_B * L); + block_C.reset(M_ACD * N_C * L); + block_D.reset(M_ACD * N_D * L); + block_ref_D.reset(M_ACD * N_D * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Warning: Invalid problem size: " + << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l + << ".\nThis size is not directly supported by the selected kernel.\n" + << "However, this example applies padding as needed, so it will still run correctly." + << std::endl; + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // The 2D block copy operations used for the A and B matrices + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional + // hardware (sub-groups for Intel BMG) and iterations by each sub-group. + // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom + // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The + // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // single contiguous chunk of the work-group TileShape. For this configuration, this implies that + // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for + // performance reasons. + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // This is the 'default' epilogue operation (Linear Combination) which performs everything in: + // (D = alpha * (A*B) + beta * C) + // aside from the (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for more + // complex epilogue examples. + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch + // policy/architecture) and defines the epilogue arguments. + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any + // auxiliary data required + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + ElementOutput, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + FusionCallBacks, + XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + void, void, + XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + void, void>; + + // GEMM Mainloop - iteration over blocks in K dimension + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + ElementInputB, + cutlass::gemm::TagToStrideB_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + // Define the whole kernel (mainloop and epilogue) + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Defer global problem shape definition to runtime + CollectiveMainloop, + CollectiveEpilogue + >; + + // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g. + // persistent scratch memory if required. + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/00_bmg_gemm/legacy/00_bmg_gemm_with_sycl_queue.cpp b/examples/00_bmg_gemm/legacy/00_bmg_gemm_with_sycl_queue.cpp new file mode 100644 index 0000000000..67e1193e75 --- /dev/null +++ b/examples/00_bmg_gemm/legacy/00_bmg_gemm_with_sycl_queue.cpp @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Intel BMG Gemm Example with non-default SYCL queue. + This example modifies 00_bmg_gemm to use a non-default queue. The main changes are passing the + queue to gemm_op.initialize and gemm_op.run. Otherwise, changes are made to allocate memory with + the correct queue. + + To build & run this example (from your build dir): + $ ninja 00_bmg_gemm_with_sycl_queue + $ ./examples/sycl/00_bmg_gemm_with_sycl_queue/00_bmg_gemm_with_sycl_queue + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(20), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + struct Memory { + ElementA* block_A; + ElementB* block_B; + ElementC* block_C; + ElementOutput* block_D; + ElementOutput* block_ref_D; + sycl::queue q; + + Memory(sycl::queue q, ProblemShapeType problem_shape_MNKL) : q(q) { + auto [M, N, K, L] = problem_shape_MNKL; + block_A = sycl::malloc_device(static_cast(M) * K * L, q); + block_B = sycl::malloc_device(static_cast(N) * K * L, q); + block_C = sycl::malloc_device(static_cast(M) * N * L, q); + block_D = sycl::malloc_device(static_cast(M) * N * L, q); + block_ref_D = sycl::malloc_device(static_cast(M) * N * L, q); + } + + ~Memory() { + sycl::free(block_A, q); + sycl::free(block_B, q); + sycl::free(block_C, q); + sycl::free(block_D, q); + sycl::free(block_ref_D, q); + } + + // delete other constructors so avoiding leaks is easy + Memory(const Memory&) = delete; + Memory(Memory&&) noexcept = delete; + Memory& operator=(const Memory&) = delete; + Memory& operator=(Memory&&) noexcept = delete; + }; + + // + // Methods + // + + bool verify(Memory& mem, const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(mem.block_A, LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(mem.block_B, LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(mem.block_C, LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(mem.block_ref_D, LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + mem.block_ref_D, mem.block_D, M * N * L); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size, Memory& mem) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + cutlass::initialize_block(mem.block_A, M * K * L, seed + 2023); + cutlass::initialize_block(mem.block_B, N * K * L, seed + 2022); + cutlass::initialize_block(mem.block_C, M * N * L, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + auto q = compat::create_queue(); + Memory mem(q, problem_size); + initialize(problem_size, mem); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {mem.block_A, stride_A, mem.block_B, stride_B}, + {{options.alpha, options.beta}, mem.block_C, stride_C, mem.block_D, stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + if (workspace_size != 0) { + return cutlass::Status::kErrorInternal; + } + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, nullptr, &q)); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run(&q)); + + q.wait_and_throw(); + + // Verify that the result is correct + bool passed = verify(mem, problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(&q); + } + + q.wait_and_throw(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // The Tile of this layout describes how 8x4x1 sub-groups tile the TileShape of <256, 256, 32>. + // This permutation (which can be thought of as a scatter operation on the default tiling) + // ensures that each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations) + // See 0t_mma_atom.md#TiledMMAs for more info. + // Sub-groups are arranged row-major (stride 4,1,0) for performance reasons. + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/00_bmg_gemm/legacy/CMakeLists.txt b/examples/00_bmg_gemm/legacy/CMakeLists.txt index 1d40199221..92ff95ffc5 100644 --- a/examples/00_bmg_gemm/legacy/CMakeLists.txt +++ b/examples/00_bmg_gemm/legacy/CMakeLists.txt @@ -39,3 +39,19 @@ cutlass_example_add_executable( TEST_LARGE TEST_SMALL_SHAPE ) + +set(TEST_SMALL_SHAPE_PADDABLE --m=1 --n=1 --k=2 --l=2) +cutlass_example_add_executable( + 00_bmg_gemm_padded_legacy + 00_bmg_gemm_padded.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES + TEST_SMALL_SHAPE_PADDABLE +) + +cutlass_example_add_executable( + 00_bmg_gemm_with_sycl_queue_legacy + 00_bmg_gemm_with_sycl_queue.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) diff --git a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp index 0d330b0360..e9e82c59de 100644 --- a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp +++ b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp @@ -149,13 +149,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -343,38 +342,35 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyA = void; + using GmemTiledCopyB = void; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // Linear Combination + element-wise GELU epilogue using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< diff --git a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp index 1a21713b34..ebc87331d5 100644 --- a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp +++ b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp @@ -149,13 +149,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -343,38 +342,35 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyA = void; + using GmemTiledCopyB = void; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // The Linear Combination with ReLU epilogue using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< diff --git a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp index d4f040ad33..b5625398ec 100644 --- a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp +++ b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp @@ -148,13 +148,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -342,38 +341,35 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyA = void; + using GmemTiledCopyB = void; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // The Linear Combination with SiLu epilogue using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_gelu.cpp b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_gelu.cpp new file mode 100644 index 0000000000..0d330b0360 --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_gelu.cpp @@ -0,0 +1,405 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm with GELU Activation Fn epilogue + + This example constructs and executes a standard GEMM fused with a GELU (Gaussian Error Linear + Unit) activation epilogue. Aside from the epilogue operation, it is identical to 00_bmg_gemm. + + CUTLASS 3.x epilogues are implemented using the Epilogue Visitor Tree design pattern, and + typically combine 'Linear Combination' (i.e. `D = alpha * A*B + beta * C`) with an additional + epilogue operation. + + In this case, the GELU Element-wise activation function is applied: + + // D = GELU(alpha * (A*B) + beta * C) + + To build & run this example (from your build dir): + + $ ninja 05_bmg_gemm_with_epilogue_gelu + $ ./examples/sycl/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_gelu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + compat::wait(); + + using TensorView = cutlass::TensorView; + for(int batch = 0, offset = 0; batch < L; batch++, offset += M * N) { + cutlass::reference::device::TensorGeLu(TensorView(block_ref_D.get() + offset, LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // Linear Combination + element-wise GELU epilogue + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_relu.cpp b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_relu.cpp new file mode 100644 index 0000000000..1a21713b34 --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_relu.cpp @@ -0,0 +1,405 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm with ReLU Activation Fn epilogue + + This example constructs and executes a standard GEMM fused with a ReLU (Rectified Linear Unit) + activation epilogue. Aside from the epilogue operation, it is identical to 00_bmg_gemm. + + CUTLASS 3.x epilogues are implemented using the Epilogue Visitor Tree design pattern, and + typically combine 'Linear Combination' (i.e. `D = alpha * A*B + beta * C`) with an additional + epilogue operation. + + In this case, the ReLU Element-wise activation function is applied: + + // D = ReLU(alpha * (A*B) + beta * C) + + To build & run this example (from your build dir): + + $ ninja 05_bmg_gemm_with_epilogue_relu + $ ./examples/sycl/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + compat::wait(); + + using TensorView = cutlass::TensorView; + for(int batch = 0, offset = 0; batch < L; batch++, offset += M * N) { + cutlass::reference::device::TensorReLu(TensorView(block_ref_D.get() + offset, LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // The Linear Combination with ReLU epilogue + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_silu.cpp b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_silu.cpp new file mode 100644 index 0000000000..d4f040ad33 --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_silu.cpp @@ -0,0 +1,404 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Intel BMG Gemm with SiLu Activation Fn epilogue + + This example constructs and executes a standard GEMM fused with a SiLu (Sigmoid Linear Unit) + activation epilogue. Aside from the epilogue operation, it is identical to + 05_bmg_gemm_with_epilogue_relu. + + The SiLu Element-wise activation function is applied as: + + // D = SiLu(alpha * (A*B) + beta * C) + + To build & run this example (from your build dir): + + $ ninja 05_bmg_gemm_with_epilogue_silu + $ ./examples/sycl/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu + + Call with `--help` for information about available options +*/ + + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_silu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + compat::wait(); + + using TensorView = cutlass::TensorView; + for(int batch = 0, offset = 0; batch < L; batch++, offset += M * N) { + cutlass::reference::device::TensorSiLu(TensorView(block_ref_D.get() + offset, LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // The Linear Combination with SiLu epilogue + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/CMakeLists.txt b/examples/05_bmg_gemm_with_epilogues/legacy/CMakeLists.txt new file mode 100644 index 0000000000..4172c755dd --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_BATCHES --l=2) + +cutlass_example_add_executable( + 05_bmg_gemm_with_epilogue_gelu_legacy + 05_bmg_gemm_with_epilogue_gelu.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) + +cutlass_example_add_executable( + 05_bmg_gemm_with_epilogue_relu_legacy + 05_bmg_gemm_with_epilogue_relu.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) + +cutlass_example_add_executable( + 05_bmg_gemm_with_epilogue_silu_legacy + 05_bmg_gemm_with_epilogue_silu.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 20e64b32a9..47a5f2d227 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -108,6 +108,7 @@ if(CUTLASS_ENABLE_SYCL) 04_bmg_grouped_gemm 04_bmg_grouped_gemm/legacy 05_bmg_gemm_with_epilogues + 05_bmg_gemm_with_epilogues/legacy 06_bmg_flash_attention 07_bmg_dual_gemm 08_bmg_gemm_f8 diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 54c3e2ab73..7df8a7455d 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -318,6 +318,60 @@ struct FusionCallbacks< using Impl::Impl; }; +template < + // int FragmentSize, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + class CopyOpR2G_, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombSoftmaxRow, + CtaTileShapeMNK, + EpilogueTile +> : XeLinCombSoftmaxRow { + + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Impl = XeLinCombSoftmaxRow::type, ElementCompute, CopyOpR2G_, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombSoftmaxRow; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementOutput* output_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {output_ptr} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + template< class StrideAux, class CopyOpG2R, @@ -522,13 +576,14 @@ struct FusionCallbacks< }, // end ternary op {} // ternary args : multiply_add }; // end ternary op - } + } }; // Ctor inheritance using Impl::Impl; }; + // D = alpha * acc + beta * C + per-column bias template< int StagesC,