|
| 1 | +/*************************************************************************************************** |
| 2 | + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights |
| 3 | + *reserved. SPDX-License-Identifier: BSD-3-Clause |
| 4 | + * |
| 5 | + * Redistribution and use in source and binary forms, with or without |
| 6 | + * modification, are permitted provided that the following conditions are met: |
| 7 | + * |
| 8 | + * 1. Redistributions of source code must retain the above copyright notice, |
| 9 | + *this list of conditions and the following disclaimer. |
| 10 | + * |
| 11 | + * 2. Redistributions in binary form must reproduce the above copyright notice, |
| 12 | + * this list of conditions and the following disclaimer in the documentation |
| 13 | + * and/or other materials provided with the distribution. |
| 14 | + * |
| 15 | + * 3. Neither the name of the copyright holder nor the names of its |
| 16 | + * contributors may be used to endorse or promote products derived from |
| 17 | + * this software without specific prior written permission. |
| 18 | + * |
| 19 | + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 20 | + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 21 | + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
| 22 | + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE |
| 23 | + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
| 24 | + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
| 25 | + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
| 26 | + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
| 27 | + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
| 28 | + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
| 29 | + *POSSIBILITY OF SUCH DAMAGE. |
| 30 | + * |
| 31 | + **************************************************************************************************/ |
| 32 | + |
| 33 | +/*! \file |
| 34 | + \brief |
| 35 | + Default kernel-level GEMM definitions combine threadblock-scoped matrix |
| 36 | + multiply-add with the appropriate threadblock-scoped epilogue. |
| 37 | +
|
| 38 | + Note, CUTLASS epilogues universally target row-major outputs. Column-major |
| 39 | + outputs are accommodated by exchanging A and B operands and assuming |
| 40 | + transposed layouts. Partial specializations here choose |
| 41 | + 'device::GemmTransposed' to implement this functionality. |
| 42 | +
|
| 43 | +*/ |
| 44 | + |
| 45 | +#pragma once |
| 46 | + |
| 47 | +#include "cutlass/cutlass.h" |
| 48 | + |
| 49 | +#include "cutlass/complex.h" |
| 50 | +#include "cutlass/layout/matrix.h" |
| 51 | +#include "cutlass/numeric_types.h" |
| 52 | + |
| 53 | +#include "gemm_universal_k.h" |
| 54 | +#include "cutlass/gemm/kernel/gemm_universal_streamk.h" |
| 55 | +#include "cutlass/gemm/kernel/default_gemm.h" |
| 56 | +#include "cutlass/gemm/kernel/default_gemm_complex.h" |
| 57 | + |
| 58 | +#include "cutlass/layout/permute.h" |
| 59 | + |
| 60 | +///////////////////////////////////////////////////////////////////////////////////////////////// |
| 61 | + |
| 62 | +namespace cutlass { |
| 63 | +namespace gemm { |
| 64 | +namespace kernel { |
| 65 | + |
| 66 | +///////////////////////////////////////////////////////////////////////////////////////////////// |
| 67 | + |
| 68 | +template < |
| 69 | + /// Element type for A matrix operand |
| 70 | + typename ElementA_, |
| 71 | + /// Layout type for A matrix operand |
| 72 | + typename LayoutA_, |
| 73 | + /// Complex elementwise transformation on A operand |
| 74 | + ComplexTransform TransformA, |
| 75 | + /// Access granularity of A matrix in units of elements |
| 76 | + int kAlignmentA, |
| 77 | + /// Element type for B matrix operand |
| 78 | + typename ElementB_, |
| 79 | + /// Layout type for B matrix operand |
| 80 | + typename LayoutB_, |
| 81 | + /// Complex elementwise transformation on B operand |
| 82 | + ComplexTransform TransformB, |
| 83 | + /// Access granularity of B matrix in units of elements |
| 84 | + int kAlignmentB, |
| 85 | + /// Element type for C and D matrix operands |
| 86 | + typename ElementC_, |
| 87 | + /// Layout type for C and D matrix operands |
| 88 | + typename LayoutC_, |
| 89 | + /// Element type for internal accumulation |
| 90 | + typename ElementAccumulator, |
| 91 | + /// Operator class tag |
| 92 | + typename OperatorClass, |
| 93 | + /// Tag indicating architecture to tune for |
| 94 | + typename ArchTag, |
| 95 | + /// Threadblock-level tile size (concept: GemmShape) |
| 96 | + typename ThreadblockShape, |
| 97 | + /// Warp-level tile size (concept: GemmShape) |
| 98 | + typename WarpShape, |
| 99 | + /// Instruction tile size (concept: GemmShape) |
| 100 | + typename InstructionShape, |
| 101 | + /// Epilogue output operator |
| 102 | + typename EpilogueOutputOp, |
| 103 | + /// Threadblock-level swizzling operator |
| 104 | + typename ThreadblockSwizzle, |
| 105 | + /// Number of stages used in the pipelined mainloop |
| 106 | + int Stages, |
| 107 | + /// Operation performed by GEMM |
| 108 | + typename Operator, |
| 109 | + /// Use zfill or predicate for out-of-bound cp.async |
| 110 | + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, |
| 111 | + /// Gather operand A by using an index array |
| 112 | + bool GatherA = false, |
| 113 | + /// Gather operand B by using an index array |
| 114 | + bool GatherB = false, |
| 115 | + /// Scatter result D by using an index array |
| 116 | + bool ScatterD = false, |
| 117 | + /// Permute result D |
| 118 | + typename PermuteDLayout = layout::NoPermute, |
| 119 | + /// Permute operand A |
| 120 | + typename PermuteALayout_ = layout::NoPermute, |
| 121 | + /// Permute operand B |
| 122 | + typename PermuteBLayout_ = layout::NoPermute, |
| 123 | + /// |
| 124 | + typename Enable = void> |
| 125 | +struct DefaultGemmUniversal; |
| 126 | + |
| 127 | +///////////////////////////////////////////////////////////////////////////////////////////////// |
| 128 | +// |
| 129 | +// Real-valued GEMM kernels |
| 130 | +// |
| 131 | + |
| 132 | +template < |
| 133 | + /// Element type for A matrix operand |
| 134 | + typename ElementA, |
| 135 | + /// Layout type for A matrix operand |
| 136 | + typename LayoutA, |
| 137 | + /// Access granularity of A matrix in units of elements |
| 138 | + int kAlignmentA, |
| 139 | + /// Element type for B matrix operand |
| 140 | + typename ElementB, |
| 141 | + /// Layout type for B matrix operand |
| 142 | + typename LayoutB, |
| 143 | + /// Access granularity of B matrix in units of elements |
| 144 | + int kAlignmentB, |
| 145 | + /// Element type for C and D matrix operands |
| 146 | + typename ElementC, |
| 147 | + /// Layout type for C and D matrix operands |
| 148 | + typename LayoutC, |
| 149 | + /// Element type for internal accumulation |
| 150 | + typename ElementAccumulator, |
| 151 | + /// Operator class tag |
| 152 | + typename OperatorClass, |
| 153 | + /// Tag indicating architecture to tune for |
| 154 | + typename ArchTag, |
| 155 | + /// Threadblock-level tile size (concept: GemmShape) |
| 156 | + typename ThreadblockShape, |
| 157 | + /// Warp-level tile size (concept: GemmShape) |
| 158 | + typename WarpShape, |
| 159 | + /// Warp-level tile size (concept: GemmShape) |
| 160 | + typename InstructionShape, |
| 161 | + /// Epilogue output operator |
| 162 | + typename EpilogueOutputOp, |
| 163 | + /// Threadblock-level swizzling operator |
| 164 | + typename ThreadblockSwizzle, |
| 165 | + /// Number of stages used in the pipelined mainloop |
| 166 | + int Stages, |
| 167 | + /// Operation performed by GEMM |
| 168 | + typename Operator, |
| 169 | + /// Use zfill or predicate for out-of-bound cp.async |
| 170 | + SharedMemoryClearOption SharedMemoryClear, |
| 171 | + /// Gather operand A by using an index array |
| 172 | + bool GatherA, |
| 173 | + /// Gather operand B by using an index array |
| 174 | + bool GatherB, |
| 175 | + /// Scatter result D by using an index array |
| 176 | + bool ScatterD, |
| 177 | + /// Permute result D |
| 178 | + typename PermuteDLayout, |
| 179 | + /// Permute operand A |
| 180 | + typename PermuteALayout, |
| 181 | + /// Permute operand B |
| 182 | + typename PermuteBLayout> |
| 183 | +struct DefaultGemmUniversal< |
| 184 | + ElementA, LayoutA, |
| 185 | + ComplexTransform::kNone, // transform A |
| 186 | + kAlignmentA, ElementB, LayoutB, |
| 187 | + ComplexTransform::kNone, // transform B |
| 188 | + kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, |
| 189 | + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, |
| 190 | + ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, GatherA, GatherB, |
| 191 | + ScatterD, PermuteDLayout, PermuteALayout, PermuteBLayout, |
| 192 | + typename platform::enable_if< |
| 193 | + !cutlass::is_complex<ElementAccumulator>::value>::type> { |
| 194 | + using DefaultGemmKernel = typename kernel::DefaultGemm< |
| 195 | + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, |
| 196 | + LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, |
| 197 | + WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, |
| 198 | + true, Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, |
| 199 | + PermuteDLayout, PermuteALayout, PermuteBLayout>::GemmKernel; |
| 200 | + |
| 201 | + /// Universal kernel without StreamkFeature member type |
| 202 | + template <class SwizzleT, class Enable = void> |
| 203 | + class SelectBase |
| 204 | + : public kernel::GemmUniversal<typename DefaultGemmKernel::Mma, |
| 205 | + typename DefaultGemmKernel::Epilogue, |
| 206 | + SwizzleT> {}; |
| 207 | + |
| 208 | + /// Universal kernel with StreamkFeature member type |
| 209 | + template <class SwizzleT> |
| 210 | + class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> |
| 211 | + : public kernel::GemmUniversalStreamk< |
| 212 | + typename DefaultGemmKernel::Mma, |
| 213 | + typename DefaultGemmKernel::Epilogue, SwizzleT> {}; |
| 214 | + |
| 215 | + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature |
| 216 | + using GemmKernel = SelectBase<ThreadblockSwizzle>; |
| 217 | +}; |
| 218 | + |
| 219 | +///////////////////////////////////////////////////////////////////////////////////////////////// |
| 220 | + |
| 221 | +// |
| 222 | +// Complex-valued GEMM kernels |
| 223 | +// |
| 224 | + |
| 225 | +template < |
| 226 | + /// Element type for A matrix operand |
| 227 | + typename ElementA, |
| 228 | + /// Layout type for A matrix operand |
| 229 | + typename LayoutA, |
| 230 | + /// Complex elementwise transformation on A operand |
| 231 | + ComplexTransform TransformA, |
| 232 | + /// Access granularity of A matrix in units of elements |
| 233 | + int kAlignmentA, |
| 234 | + /// Element type for B matrix operand |
| 235 | + typename ElementB, |
| 236 | + /// Layout type for B matrix operand |
| 237 | + typename LayoutB, |
| 238 | + /// Complex elementwise transformation on B operand |
| 239 | + ComplexTransform TransformB, |
| 240 | + /// Access granularity of B matrix in units of elements |
| 241 | + int kAlignmentB, |
| 242 | + /// Element type for C and D matrix operands |
| 243 | + typename ElementC, |
| 244 | + /// Layout type for C and D matrix operands |
| 245 | + typename LayoutC, |
| 246 | + /// Element type for internal accumulation |
| 247 | + typename ElementAccumulator, |
| 248 | + /// Operator class tag |
| 249 | + typename OperatorClass, |
| 250 | + /// Tag indicating architecture to tune for |
| 251 | + typename ArchTag, |
| 252 | + /// Threadblock-level tile size (concept: GemmShape) |
| 253 | + typename ThreadblockShape, |
| 254 | + /// Warp-level tile size (concept: GemmShape) |
| 255 | + typename WarpShape, |
| 256 | + /// Warp-level tile size (concept: GemmShape) |
| 257 | + typename InstructionShape, |
| 258 | + /// Epilogue output operator |
| 259 | + typename EpilogueOutputOp, |
| 260 | + /// Threadblock-level swizzling operator |
| 261 | + typename ThreadblockSwizzle, |
| 262 | + /// Number of stages used in the pipelined mainloop |
| 263 | + int Stages, |
| 264 | + /// Operation performed by GEMM |
| 265 | + typename Operator, |
| 266 | + /// Use zfill or predicate for out-of-bound cp.async |
| 267 | + SharedMemoryClearOption SharedMemoryClear> |
| 268 | +struct DefaultGemmUniversal< |
| 269 | + ElementA, LayoutA, TransformA, kAlignmentA, ElementB, LayoutB, TransformB, |
| 270 | + kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, |
| 271 | + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, |
| 272 | + ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, false, false, |
| 273 | + false, layout::NoPermute, layout::NoPermute, layout::NoPermute, |
| 274 | + typename platform::enable_if< |
| 275 | + cutlass::is_complex<ElementAccumulator>::value>::type> { |
| 276 | + using DefaultGemmKernel = typename kernel::DefaultGemmComplex< |
| 277 | + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, |
| 278 | + ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, |
| 279 | + InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, |
| 280 | + TransformA, TransformB, Operator, false>::GemmKernel; |
| 281 | + |
| 282 | + /// Universal kernel without StreamkFeature member type |
| 283 | + template <class SwizzleT, class Enable = void> |
| 284 | + class SelectBase |
| 285 | + : public kernel::GemmUniversal<typename DefaultGemmKernel::Mma, |
| 286 | + typename DefaultGemmKernel::Epilogue, |
| 287 | + SwizzleT> {}; |
| 288 | + |
| 289 | + /// Universal kernel with StreamkFeature member type |
| 290 | + template <class SwizzleT> |
| 291 | + class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> |
| 292 | + : public kernel::GemmUniversalStreamk< |
| 293 | + typename DefaultGemmKernel::Mma, |
| 294 | + typename DefaultGemmKernel::Epilogue, SwizzleT> {}; |
| 295 | + |
| 296 | + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature |
| 297 | + using GemmKernel = SelectBase<ThreadblockSwizzle>; |
| 298 | +}; |
| 299 | + |
| 300 | +///////////////////////////////////////////////////////////////////////////////////////////////// |
| 301 | + |
| 302 | +} // namespace kernel |
| 303 | +} // namespace gemm |
| 304 | +} // namespace cutlass |
| 305 | + |
| 306 | +///////////////////////////////////////////////////////////////////////////////////////////////// |
0 commit comments