diff --git a/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp b/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp index 251a4d1f10..b7bb1d2753 100644 --- a/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp +++ b/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp @@ -156,7 +156,7 @@ struct ExampleRunner { using ElementAcc = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - using ElementC = typename Gemm::ElementC; + using ElementC = typename CollectiveEpilogue::ElementOutput; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; @@ -375,7 +375,7 @@ int main(int argc, const char** argv) // aside from the (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for more // complex epilogue examples. using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + ElementOutput, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch // policy/architecture) and defines the epilogue arguments. diff --git a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp index bdda0536d2..1dbe943c7d 100644 --- a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp +++ b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp @@ -190,7 +190,6 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementC = typename Gemm::ElementC; using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; @@ -199,7 +198,8 @@ struct ExampleRunner { using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementAccumulator = ElementOutput; + using ElementAccumulator = ElementAccumulator; + using ElementC = typename CollectiveEpilogue::ElementOutput; using StrideA = typename Gemm::GemmKernel::InternalStrideA; using StrideB = typename Gemm::GemmKernel::InternalStrideB; @@ -585,7 +585,7 @@ int main(int argc, const char** argv) using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + ElementOutput, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index e8b1709aad..0c2b2fd51a 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -90,7 +90,7 @@ class CollectiveEpilogue< using DispatchPolicy = IntelXeXMX16Group; using CtaTileMNK = CtaTileMNK_; using FusionCallbacks = FusionCallbacks_; - using ElementC = ElementC_; + using ElementC = typename FusionCallbacks::ElementSource; using ElementAccumulator = ElementC_; using StrideC = StrideC_; using InternalStrideC = cute::remove_pointer_t; @@ -115,7 +115,7 @@ class CollectiveEpilogue< static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; static_assert(cute::is_same_v>, + fusion::LinearCombination>, "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; @@ -372,6 +372,7 @@ class CollectiveEpilogue< Tensor tCgD = thread_xe_store_d.partition_D(gD); Tensor trC = make_tensor(Shape>{}); + auto trC_frag = recast>(trC); Tensor trD_compute = make_tensor(Shape>{}); // Because Sm90 uses shared memory, they are not tied to using the same accumulator values @@ -421,6 +422,8 @@ class CollectiveEpilogue< constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + constexpr bool is_same_dtype_accum_and_output = std::is_same_v; + auto synchronize = [&] () {}; CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { @@ -428,8 +431,15 @@ class CollectiveEpilogue< for (int epi_m = 0; epi_m < FragsM; epi_m++) { if (is_C_load_needed) { - //cordinates for C and D are the same + if constexpr (is_same_dtype_accum_and_output) { + //cordinates for C and D are the same copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC); + } else { + Tensor trC_ori = make_tensor(Shape>{}); + copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC_ori); + auto trC_ori_frag = recast>(trC_ori); + *(trC_frag.data()) = cutlass::NumericArrayConverter{}(*(trC_ori_frag.data())); + } } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); @@ -438,7 +448,13 @@ class CollectiveEpilogue< CUTLASS_PRAGMA_UNROLL for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { - trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + if constexpr (is_same_dtype_accum_and_output) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } else { + // align dtypes firstly + auto tmp = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + trD_compute_frag(epi_v) = cutlass::NumericArrayConverter{}(tmp); + } } cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 003f5de776..779a32b84f 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -89,7 +89,7 @@ class CollectiveEpilogue< using DispatchPolicy = IntelXeXMX16; using CtaTileMNK = CtaTileMNK_; using FusionCallbacks = FusionCallbacks_; - using ElementC = ElementC_; + using ElementC = typename FusionCallbacks::ElementSource;; using ElementAccumulator = ElementC_; using StrideC = StrideC_; using ElementD = ElementD_; @@ -350,6 +350,7 @@ class CollectiveEpilogue< Tensor tCgD = thread_xe_store_d.partition_D(gD); Tensor trC = make_tensor(Shape>{}); + auto trC_frag = recast>(trC); Tensor trD_compute = make_tensor(Shape>{}); // Because Sm90 uses shared memory, they are not tied to using the same accumulator values @@ -398,7 +399,9 @@ class CollectiveEpilogue< FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); - + + constexpr bool is_same_dtype_accum_and_output = std::is_same_v; + auto synchronize = [&] () {}; CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { @@ -407,7 +410,14 @@ class CollectiveEpilogue< cst_callbacks.begin_loop(epi_m, epi_n); if (is_C_load_needed) { - copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); + if constexpr (is_same_dtype_accum_and_output) { + copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); + } else { + Tensor trC_ori = make_tensor(Shape>{}); + copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC_ori); + auto trC_ori_frag = recast>(trC_ori); + *(trC_frag.data()) = cutlass::NumericArrayConverter{}(*(trC_ori_frag.data())); + } } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); @@ -416,7 +426,13 @@ class CollectiveEpilogue< CUTLASS_PRAGMA_UNROLL for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { - trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + if constexpr (is_same_dtype_accum_and_output) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } else { + // align dtypes firstly + auto tmp = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + trD_compute_frag(epi_v) = cutlass::NumericArrayConverter{}(tmp); + } } cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 5173d77000..f78f7b2862 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -343,7 +343,7 @@ template < class ElementOutput_, class ElementCompute_, class ElementAux, - class ElementSource, + class ElementSource_, class ElementScalar, int AlignmentAux, FloatRoundStyle RoundStyle, @@ -355,28 +355,29 @@ struct FusionCallbacks< epilogue::IntelXeXMX16, fusion::LinCombDeEltAct< GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, - ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + ElementAux, ElementSource_, ElementScalar, AlignmentAux, RoundStyle >, CtaTileShapeMNK, EpilogueTile, CopyOpG2R > : XeLinCombDeEltAct< cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, ElementOutput_, - ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle + ElementCompute_, ElementAux, ElementSource_, ElementScalar, RoundStyle > { using ElementOutput = ElementOutput_; using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; using Impl = XeLinCombDeEltAct< cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, ElementOutput, - ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle + ElementCompute, ElementAux, ElementSource_, ElementScalar, RoundStyle >; using Operation = fusion::LinCombDeEltAct< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + ElementAux, ElementSource_, ElementScalar, AlignmentAux, RoundStyle >; struct Arguments { diff --git a/test/unit/gemm/device/default_gemm_group_configuration.hpp b/test/unit/gemm/device/default_gemm_group_configuration.hpp index 33003e2863..ee98a07ddf 100644 --- a/test/unit/gemm/device/default_gemm_group_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_group_configuration.hpp @@ -87,7 +87,7 @@ struct DefaultGemmGroupConfiguration< using TiledMma = typename CollectiveMainloop::TiledMma; - using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = epilogue::fusion::FusionCallbacks< epilogue::IntelXeXMX16Group, @@ -101,7 +101,7 @@ struct DefaultGemmGroupConfiguration< TileShape, Shape<_1, _1, _1>, epilogue::collective::EpilogueTileAuto, float, float, - float, LayoutC, 1, + ElementOutput, LayoutC, 1, ElementOutput, LayoutC, 1, epilogue::IntelXeXMX16Group, EpilogueOp