11 changes: 7 additions & 4 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
@@ -350,6 +350,10 @@ int main(int argc, const char** argv)
// Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;



// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
@@ -369,9 +373,8 @@ int main(int argc, const char** argv)

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
// For older versions of the copy/MMA atoms, use cutlass::gemm::MainloopIntelXeXMX16 as the dispatch policy
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeL1Staged;

// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
// (D = alpha * (A*B) + beta * C)
@@ -394,9 +397,9 @@ int main(int argc, const char** argv)
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
FusionCallBacks,
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
GmemTiledCopyC, // The copy atom used to load matrix C
void, void,
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
GmemTiledCopyD, // The copy atom used to store matrix D
void, void>;

// GEMM Mainloop - iteration over blocks in K dimension
12 changes: 11 additions & 1 deletion include/cute/atom/copy_traits_xe_2d.hpp
@@ -1143,12 +1143,22 @@ template <class CopyOp, class TiledMMA, class CTensor>
auto get_block_2d_copy_C(TiledMMA const& tiled_mma, CTensor const& c_tensor)
{
if constexpr (!std::is_void_v<CopyOp>) {
return make_block_2d_copy_C(CopyOp{}, tiled_mma, c_tensor);
return make_block_2d_copy_CD(CopyOp{}, tiled_mma, c_tensor);
} else {
return make_block_2d_copy_C(tiled_mma, c_tensor);
}
}

template <class CopyOp, class TiledMMA, class DTensor>
auto get_block_2d_copy_D(TiledMMA const& tiled_mma, DTensor const& d_tensor)
{
if constexpr (!std::is_void_v<CopyOp>) {
return make_block_2d_copy_CD(CopyOp{}, tiled_mma, d_tensor);
} else {
return make_block_2d_copy_D(tiled_mma, d_tensor);
}
}

//
// Display utilities
//
@@ -71,6 +71,7 @@ class CollectiveEpilogue {
#include "sm100_epilogue_array_tma_warpspecialized.hpp"
#if defined (SYCL_INTEL_TARGET)
#include "xe_epilogue.hpp"
#include "xe_epilogue_legacy.hpp"
#include "xe_array_epilogue.hpp"
#endif
//
144 changes: 86 additions & 58 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -68,7 +68,7 @@ template <
class CopyOpR2S_
>
class CollectiveEpilogue<
IntelXeXMX16,
IntelXeL1Staged,
CtaTileMNK_,
ElementC_,
StrideC_,
@@ -86,7 +86,7 @@ class CollectiveEpilogue<
//
// Type Aliases
//
using DispatchPolicy = IntelXeXMX16;
using DispatchPolicy = IntelXeL1Staged;
using CtaTileMNK = CtaTileMNK_;
using FusionCallbacks = FusionCallbacks_;
using ElementC = ElementC_;
@@ -101,9 +101,9 @@ class CollectiveEpilogue<
using CopyOpR2S = CopyOpR2S_;

using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_2D_U32x8x16_LD_N, CopyOpG2R>;
using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_LOAD_2D<32, 8, 16>, CopyOpG2R>;
using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
CopyOpR2G, XE_2D_U32x8x16_ST_N>;
CopyOpR2G, XE_STORE_2D<32, 8, 16>>;
using ElementOutput = ElementD;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementAccumulator = ElementCompute;
@@ -118,16 +118,6 @@ class CollectiveEpilogue<
static_assert(std::is_same_v<SmemLayoutAtomC, void>, "Copy operation to shared memory is not supported");
static_assert(std::is_same_v<SmemLayoutAtomD, void>, "Copy operation to shared memory is not supported");

using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;

using Trait_C = Copy_Traits<GmemTiledCopyC, StrideC>;
using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})));
using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom<Trait_C, ElementC>{}, Layout<CopyThreadShape>{}, val_layout_load_C{}));

using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>;
using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementD>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));

private:
constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && not cute::is_void_v<CopyOpG2R>;
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;
@@ -153,6 +143,15 @@ class CollectiveEpilogue<
};
using TensorStorage = typename SharedStorage::TensorStorage;

// Helper to get tensor types
template<class Element, class Stride>
using TensorTypeC = decltype(make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
make_layout(make_shape(int{}, int{}, int{}), Stride{})));

template<class Element, class Stride>
using TensorTypeD = decltype(make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)),
make_layout(make_shape(int{}, int{}, int{}), Stride{})));

// Host side epilogue arguments
struct Arguments {
typename FusionCallbacks::Arguments thread{};
@@ -165,8 +164,8 @@ class CollectiveEpilogue<
// Device side epilogue params
struct Params {
typename FusionCallbacks::Params thread{};
XE_Copy_C xe_load_c;
XE_Copy_D xe_store_d;
TensorTypeC<ElementC, StrideC> mC;
TensorTypeD<ElementD, StrideD> mD;
};

//
@@ -182,23 +181,13 @@ class CollectiveEpilogue<
// Optionally append 1s until the problem shape is rank-4, in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;

XE_Copy_C xe_load_c = {};
if constexpr (is_source_supported) {
auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC));
xe_load_c = {xe_load_c.with(mC)};
}

XE_Copy_D xe_store_d = {};
if constexpr (is_destination_supported) {
auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD));
xe_store_d = {xe_store_d.with(mD)};
}
auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC));
auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD));

return {
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
xe_load_c,
xe_store_d,
mC,
mD
};
}

@@ -269,6 +258,37 @@ class CollectiveEpilogue<
return fusion_callbacks.is_producer_load_needed();
}

template<typename Tensor>
CUTLASS_DEVICE auto reshape_with_unit_insertion(Tensor&& tensor) {
using namespace cute;

auto orig_layout = tensor.layout();
auto orig_shape = orig_layout.shape();
auto orig_stride = orig_layout.stride();

auto first_dim = get<0>(orig_shape);
auto outer_part = get<0>(first_dim);
auto inner_part = get<1>(first_dim);

auto first_stride = get<0>(orig_stride);
auto outer_stride = get<0>(first_stride);
auto inner_stride = get<1>(first_stride);

auto target_shape = make_shape(
make_shape(outer_part, _1{}),
get<0>(inner_part),
get<1>(inner_part)
);

auto target_stride = make_stride(
make_stride(outer_stride, _0{}),
get<0>(inner_stride),
get<1>(inner_stride)
);

return make_tensor(tensor.data(), make_layout(target_shape, target_stride));
}

template<
class ProblemShapeMNKL,
class TileShapeMNK,
@@ -285,7 +305,6 @@ class CollectiveEpilogue<
TiledMma tiled_mma,
int thread_idx) {

(void) tiled_mma;
using namespace cute;

static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
@@ -296,12 +315,11 @@ class CollectiveEpilogue<
static constexpr auto BLK_M = get<0>(CtaTileMNK{});
static constexpr auto BLK_N = get<1>(CtaTileMNK{});
static constexpr auto BLK_K = get<2>(CtaTileMNK{});
// static_assert(is_same_v<typename TiledMma::ThrLayoutVMNK, int>, "assertation fail");
static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
static_assert(

static_assert(
BLK_M % ATOM_M == 0 &&
BLK_N % ATOM_N == 0 &&
BLK_K % ATOM_K == 0,
@@ -315,46 +333,53 @@ class CollectiveEpilogue<
static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group

static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;

// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
auto m_sg = get_sub_group_id() / ATOM_N;
auto n_sg = get_sub_group_id() % ATOM_N;

auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{});

auto sg_local_m_coord = get_sub_group_id() / ATOM_N;
auto sg_local_n_coord = get_sub_group_id() % ATOM_N;

auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord;
auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord;
auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);


auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord);
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

/*
* NOTE: Automatic selection of load/store operations using make_block_2d_copy_C/make_block_2d_copy_D
* is currently not supported. The current implementation is restricted to specific load/store
* operations with dimensions 16x8, which are tightly coupled to the MMA atom size requirements.
*
* TODO: Future enhancement will include automatic selection of load/store operations
* in collectiveEpilogue to provide more flexible dimension support.
*/
auto batch_idx = get<3>(wg_coord);
auto copy_c = make_block_2d_copy_CD(GmemTiledCopyC{}, tiled_mma, params.mC(_,_,batch_idx));
auto copy_d = make_block_2d_copy_CD(GmemTiledCopyD{}, tiled_mma, params.mD(_,_,batch_idx));

// Represent the full output tensor
Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L));

// Tile the output tensor per WG and select the tile for current WG
Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)

// Tile the output tensor per SG and select tile for the current SG
Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N)
// Tile the output tensor for the current workgroup
Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N)

auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx);
Tensor tCgC = thread_xe_load_c.partition_S(gD);
auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx);
Tensor tCgC = reshape_with_unit_insertion(thread_xe_load_c.partition_S(gD));

Review comment (rolandschulz):
What's the purpose of this reshape and why is it needed?

anamikac-intel (author) replied on Oct 27, 2025:
@rolandschulz - The reshape_with_unit_insertion is needed because of differences between the legacy and new atoms. In the legacy code, after partitioning with the old atoms, we got the tCgC/tCgD layout ArithTuple(0,0,0) o ((_8,_1),_4,_4):((_1@0,_0),_8@0,_16@1), so we could process 8 elements across 4x4 iterations. With the new atoms, however, I was getting ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0) instead.

The legacy code uses trC/trD with layout ptr32b o (_8):(_1), which corresponds to a FragmentSize of 8. This fragment size is calculated as (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize, which equals 8 given an MmaAtomShape of 8x16x16 and a SubgroupSize of 16. When performing copy operations like copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC) or copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)), we process small fragments of 8 elements across FragM x FragN (4x4) iterations.

In my earlier implementation, I was doing a direct bulk copy of 128 elements from tCgC/tCgD (ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)) to trC/trD (ptr32b o ((_8,(_4,_4)),_1,_1):((_1,(_8,_32)),_0,_0)) without reshaping. That worked with all load ops, but it caused register spills that worsen with larger block sizes. The reshape therefore lets me match the legacy behavior by processing 8 elements at a time, eliminating the register spills. I'm unsure whether there's another function associated with the new atoms that could produce this layout without reshaping - waiting for Peter to check and get back. However, the legacy approach is currently restricted to XE_2D_U32x8x16_LD_N/ST_N operations (16x8 dimensions only) due to the code design limitations.
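For reference, a minimal host-side sketch of the layouts quoted above; the shapes and strides come from the comment, but the standalone test program itself is an assumption and is not part of the PR. It uses plain integer strides in place of the ArithTuple coordinates and shows that the regrouped ((_8,_1),_4,_4) layout hands out one 8-element fragment per (epi_m, epi_n) step, matching trC/trD:

#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  int buf[128];   // enough for the 8 x (4x4) = 128 C values of one subgroup tile

  // New-atom partition quoted above: all 4x4 fragments folded into the first mode.
  auto bulk_layout = make_layout(
      make_shape (make_shape (_8{}, make_shape (_4{}, _4{})), _1{}, _1{}),
      make_stride(make_stride(_1{}, make_stride(_8{}, _32{})), _0{}, _0{}));

  // Legacy-style partition produced by reshape_with_unit_insertion.
  auto legacy_layout = make_layout(
      make_shape (make_shape (_8{}, _1{}), _4{}, _4{}),
      make_stride(make_stride(_1{}, _0{}), _8{}, _32{}));

  auto tCgC = make_tensor(&buf[0], legacy_layout);
  auto frag = tCgC(_, 0, 0);               // one (epi_m, epi_n) step

  print(bulk_layout);   print("\n");       // ((_8,(_4,_4)),_1,_1):((_1,(_8,_32)),_0,_0)
  print(tCgC.layout()); print("\n");       // ((_8,_1),_4,_4):((_1,_0),_8,_32)
  print(size(frag));    print("\n");       // _8 -> matches trC/trD = (_8):(_1), i.e. FragmentSize = 8*16/16
  return 0;
}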

petercad commented on Oct 27, 2025:
Copying an offline comment here:

I'm fine with merging the changes here as long as the level of functionality is no worse than what was already available, and performance is on par. But yes, I think the legacy implementation needs to be rewritten, for several reasons:

  • Hardcoded accumulator layouts. It assumes a very specific layout of C blocks within the accumulator tile, and does not verify its assumptions.
  • Lack of flexibility in C/D copy operations. Due to the tie-in with the MMA atom size, as Anamika mentioned above, the code breaks if the load/store atom is larger than the MMA atom.
    • This is a performance problem if you are downconverting C (e.g. float -> half/bf16), i.e. in most cases of interest, because then your store atoms are performing partial cache line accesses, which by default will not be coalesced in L1$.

@anamikac-intel - if you want to merge your changes as-is and address the points above later, let's create a JIRA to track the technical debt. I think we need a JIRA too for the dimensionality assumptions for the CollectiveMMA we discussed on #540 as well.
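As a rough illustration of the partial cache-line point, assuming a 64-byte cache line and a store atom that writes 16 elements per row (both numbers are assumptions, not stated in the thread):

#include <cstdio>

int main() {
  constexpr int cache_line_bytes = 64;
  constexpr int atom_cols        = 16;              // elements written per store-atom row
  constexpr int fp32_row_bytes   = atom_cols * 4;   // float D: 64 B, a full cache line
  constexpr int bf16_row_bytes   = atom_cols * 2;   // half/bf16 D: 32 B, half a cache line
  std::printf("fp32 row: %d B, bf16 row: %d B (cache line: %d B)\n",
              fp32_row_bytes, bf16_row_bytes, cache_line_bytes);
  return 0;
}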

sanchitintel commented on Oct 27, 2025:
> In my earlier implementation, I was doing direct bulk copying of 128 elements tCgC/tCgD (ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)) to trC/trD (ptr32b o ((_8,(_4,_4)),_1,_1):((_1,(_8,_32)),_0,_0)) without reshaping, but this caused register spills which worsen with larger block size but was working with all load ops.

When Anamika manually specified copy atoms in her previous implementation, the performance was still worse than the legacy implementation, so it seems this problem (register spills, as described above) is not inherent to make_block_2d_copy_C or make_block_2d_copy_D, but to the epilogue implementation instead.

petercad replied:
@sanchitintel Yes, that's right. Though as I mentioned above, fundamentally it's a problem with IGC code scheduling and not even the epilogue implementation.

To make code scheduling easier for the compiler, it seems we do need to break up the C/D access into smaller tiles (might as well do one atom at a time). But instead of hard-coding the accumulator access, we should use the regular CuTe machinery to tile C/D/accumulator into smaller tiles.

sanchitintel commented on Oct 27, 2025:
Thanks again for clarifying, @petercad!

Please correct me if I got it wrong - the observed IGC issue comes into play when we do some specific type of compute (such as epilogue computation) on large tiles, and it isn't the R2G/G2R copies with manually specified copy atoms alone that are problematic for IGC. Tiling the SG-tile further into smaller tiles may help make scheduling easier for IGC.

petercad replied:
Yes, that's right. When we manually tile the subgroup tile into smaller tiles, we aren't changing the instructions, just how they are ordered.
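A sketch of that idea under assumed sizes (a 32x64 subgroup accumulator tile and 8x16 atom-sized sub-tiles; neither number is taken from the PR): regular CuTe tiling walks the tile one atom-sized piece at a time instead of hard-coding the accumulator indexing.

#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  float acc_storage[32 * 64] = {};
  auto acc = make_tensor(&acc_storage[0],
                         make_layout(make_shape(Int<32>{}, Int<64>{}),
                                     make_stride(Int<64>{}, Int<1>{})));

  // ((8,16), (4,4)): mode 0 is one atom-sized sub-tile, mode 1 indexes the sub-tiles.
  auto tiled = zipped_divide(acc, Shape<_8, _16>{});

  for (int m = 0; m < size<1, 0>(tiled); ++m) {
    for (int n = 0; n < size<1, 1>(tiled); ++n) {
      auto sub = tiled(make_coord(_, _), make_coord(m, n));  // one 8x16 sub-tile
      // ... convert and store this sub-tile with a single store atom ...
      (void)sub;
    }
  }
  return 0;
}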

Review comment:
The "reshape" here is not a general solution; it only covers a special case, I think.

anamikac-intel (author) replied:
Yes, the code is specifically designed to work only with particular load/store operations (16 x 8 dimensions). Because it's tightly coupled with the MMA atom size, the implementation fails when the load/store atom exceeds the MMA atom size. Even attempts to generalize it for other load/store operations won't work due to this fundamental constraint. After discussing with Peter, we agreed that if performance matches the old atom, we can merge this change. However, we'll address the underlying technical debt in a separate PR: #573 (comment)


auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
Tensor tCgD = thread_xe_store_d.partition_D(gD);
auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx);
Tensor tCgD = reshape_with_unit_insertion(thread_xe_store_d.partition_D(gD));

Tensor trC = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});

// Because Sm90 uses shared memory, they are not tied to using the same accumulator values
// for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
// sure that we are operating on the same values.
ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx);
ThrCopy thread_g2r = copy_c.get_slice(thread_idx);
auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{});

// OOB predication for tile quantization "residue"
// Absolute coordinate tensors (dynamic)
@@ -363,7 +388,7 @@ class CollectiveEpilogue<
Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)

Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)
Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout());

// Get the fusion callbacks
// Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles
@@ -375,7 +400,7 @@ class CollectiveEpilogue<
sg_coord,
tiled_mma,
mn_shape,
params.xe_store_d,
copy_d,
cD,
residue_mn,
tRS_cD,
@@ -400,7 +425,8 @@ class CollectiveEpilogue<
FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );



auto synchronize = [&] () {};
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < FragsN; epi_n++) {
@@ -409,7 +435,7 @@ class CollectiveEpilogue<
cst_callbacks.begin_loop(epi_m, epi_n);

if (is_C_load_needed) {
copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
copy(copy_c, tCgC(_, epi_m, epi_n), trC);
taozha2 commented:
@anamikac-intel can you replace "tCgC(_, epi_m, epi_n)" with "tCgC((_, (epi_m, epi_n)), _, _)", which is what you want from "reshape_with_unit_insertion", although it's not a good fix.

anamikac-intel (author) replied on Oct 28, 2025:
@taozha2 -- this workaround is not working; please see the layouts below:

tCgC                             ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)
tCgC((_, (epi_m, epi_n)), _, _)  ArithTuple(0,0,0) o (_1,_1):(_0,_0)
tCgC(_, epi_m, epi_n)            ArithTuple(0,0,0) o ((_8,_1)):((_1@0,_0)) (after reshape) --> this is expected
tCgD                             ArithTuple(0,0,0) o ((_8,(_4,_4)),_1,_1):((_1@0,(_8@0,_16@1)),_0,_0)
tCgD((_, (epi_m, epi_n)), _, _)  ArithTuple(0,0,0) o (_1,_1):(_0,_0)
tCgD(_, epi_m, epi_n)            ArithTuple(0,0,0) o ((_8,_1)):((_1@0,_0)) (after reshape) --> this is expected
trC                              ptr32b o (_8):(_1)
trD                              ptr32b o (_8):(_1)

}

cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
@@ -421,21 +447,23 @@ class CollectiveEpilogue<
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
}
cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);

if constexpr (is_destination_supported) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(trD_compute_frag); ++i) {
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, RegisterElementD, FragmentSize>{}(trD_compute_frag(i));
}
copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
copy(copy_d, trD, tCgD(_, epi_m, epi_n));
}

cst_callbacks.end_loop(epi_m, epi_n);

}
}

cst_callbacks.end();
}

}

private:
Params const& params;
@@ -449,4 +477,4 @@ class CollectiveEpilogue<
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////