Skip to content

Commit 3f2a337

Browse files
authored
Rearchitecture: Xe epilogue (#621)
This PR builds on #573, adding a `CollectiveEpilogue` with support for the new block 2D copy atoms. The existing epilogue implementation was mostly rewritten, as it had many hardcoded assumptions and limitations:

* Subgroups own a contiguous tile within the workgroup tile
* Subgroup tiles are laid out n-major within the workgroup tile
* C/D atoms have the same block size
* One copy atom of data is processed at a time
* C/D atoms must bring data in the exact same layout as the accumulator

The new implementation removes all these restrictions. Its API is also somewhat different, mostly in ways that more closely match the SM90 epilogues:

* Configurable EpilogueTile template parameter controls the block size for epilogue computation.
* Fusion callbacks receive workgroup-scope tiling information, not subgroup-scope tiling information (because CuTe's TiledMMA is very flexible -- the subgroup "tile" may not be contiguous).
* Vectorization for the epilogue compute operations is configurable via the `ComputeVectorLen` constexpr variable. Currently this is set to operate on one MMA atom's worth of accumulator data at a time, but if we want to make it user-configurable like the NV epilogues (where it's a template parameter for the dispatch policy), that's possible.
* It receives the TiledMMA as a template parameter rather than an argument to `operator()`.
* The S2R/R2S copy operation parameters are omitted (a difference vs. SM90) as they are irrelevant to both the old and new epilogue implementation.

The new implementation glues together C/D loads and compute with reorders, so it can support efficient data type and layout conversions outside of the epilogue computation.
1 parent 92785e4 commit 3f2a337

File tree

7 files changed

+221
-282
lines changed

7 files changed

+221
-282
lines changed

examples/00_bmg_gemm/00_bmg_gemm.cpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,12 @@ struct ExampleRunner {
154154

155155
using ElementA = typename Gemm::ElementA;
156156
using ElementB = typename Gemm::ElementB;
157-
using ElementAcc = typename Gemm::ElementAccumulator;
157+
using ElementAccumulator = typename Gemm::ElementAccumulator;
158158

159159
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
160160
using ElementC = typename Gemm::ElementC;
161161
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
162162
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
163-
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
164163

165164
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
166165

@@ -343,17 +342,13 @@ int main(int argc, const char** argv)
343342
using LayoutC = cutlass::layout::RowMajor;
344343
using LayoutD = cutlass::layout::RowMajor;
345344

346-
// [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects
347-
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
348-
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI
349-
// (applicable only to matrix B), or XE_LOAD_2D_TRANSPOSE.
345+
// [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects
346+
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
347+
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI,
348+
// or XE_LOAD_2D_TRANSPOSE.
350349
// Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
351350
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
352351
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
353-
using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
354-
using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;
355-
356-
357352

358353
// Workgroup-level tile
359354
using TileShape = Shape<_256, _256, _32>;
@@ -373,6 +368,7 @@ int main(int argc, const char** argv)
373368

374369
// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
375370
constexpr int PipelineStages = 2;
371+
// For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy
376372
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
377373
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
378374

@@ -385,22 +381,21 @@ int main(int argc, const char** argv)
385381

386382
// FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
387383
// policy/architecture) and defines the epilogue arguments.
388-
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
384+
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
389385
decltype(tile_shape(TiledMma()))>;
390386
// GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
391387
// auxiliary data required
392388
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
393389
EpilogueDispatchPolicy,
394-
TileShape,
390+
TiledMma,
391+
void, // Epilogue tile (void = automatic)
395392
ElementAccumulator,
396393
cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
397394
ElementOutput,
398395
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
399-
FusionCallBacks,
400-
GmemTiledCopyC, // The copy atom used to load matrix C
401-
void, void,
402-
GmemTiledCopyD, // The copy atom used to store matrix D
403-
void, void>;
396+
FusionCallbacks,
397+
void, // The copy atom used to load matrix C (void = automatic)
398+
void>; // The copy atom used to store matrix D (void = automatic)
404399

405400
// GEMM Mainloop - iteration over blocks in K dimension
406401
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
@@ -417,9 +412,9 @@ int main(int argc, const char** argv)
417412

418413
// Define the whole kernel (mainloop and epilogue)
419414
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
420-
Shape<int, int, int, int>, // Defer global problem shape definition to runtime
421-
CollectiveMainloop,
422-
CollectiveEpilogue
415+
Shape<int, int, int, int>, // Defer global problem shape definition to runtime
416+
CollectiveMainloop,
417+
CollectiveEpilogue
423418
>;
424419

425420
// The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g.

include/cute/tensor_sg.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,25 @@ template <class Engine,
105105
class SubgroupTVLayout,
106106
__CUTE_REQUIRES(is_layout<SubgroupTVLayout>::value)>
107107
CUTE_HOST_DEVICE
108-
constexpr auto
109-
make_subgroup_tensor(Tensor<Engine, Layout> const& tensor, SubgroupTVLayout const&)
108+
constexpr decltype(auto)
109+
make_subgroup_tensor(Tensor<Engine,Layout>& tensor, SubgroupTVLayout const& tv_layout)
110+
{
111+
static_assert(is_static_v<SubgroupTVLayout>, "Subgroup TV layout must be static");
112+
static_assert(is_rmem_v<Engine>, "Expected an rmem tensor");
113+
return make_subgroup_tensor(make_tensor(tensor.data(), tensor.layout()), tv_layout);
114+
}
115+
116+
template <class Engine,
117+
class Layout,
118+
class SubgroupTVLayout,
119+
__CUTE_REQUIRES(is_layout<SubgroupTVLayout>::value)>
120+
CUTE_HOST_DEVICE
121+
constexpr decltype(auto)
122+
make_subgroup_tensor(Tensor<Engine,Layout>&& tensor, SubgroupTVLayout const&)
110123
{
111124
static_assert(is_static_v<SubgroupTVLayout>, "Subgroup TV layout must be static");
112125
static_assert(is_rmem_v<Engine>, "Expected an rmem tensor");
113-
return static_cast<SubgroupTensor<Engine,Layout,SubgroupTVLayout> const&>(tensor);
126+
return static_cast<SubgroupTensor<Engine,Layout,SubgroupTVLayout>&&>(tensor);
114127
}
115128

116129
// Create a new owning SubgroupTensor with the given subgroup-level layout.

include/cute/util/type_traits.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,4 +327,10 @@ struct is_any_of {
327327
template <class T, class... Us>
328328
inline constexpr bool is_any_of_v = is_any_of<T, Us...>::value;
329329

330+
//
331+
// replace_void_t
332+
//
333+
template <class T, class ReplacementTypeIfVoid>
334+
using replace_void_t = conditional_t<is_void_v<T>, ReplacementTypeIfVoid, T>;
335+
330336
} // end namespace cute

0 commit comments

Comments
 (0)