Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,12 @@ struct ExampleRunner {

using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementAcc = typename Gemm::ElementAccumulator;
using ElementAccumulator = typename Gemm::ElementAccumulator;

using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementC = typename Gemm::ElementC;
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;

using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

Expand Down Expand Up @@ -343,17 +342,13 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

// [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI
// (applicable only to matrix B), or XE_LOAD_2D_TRANSPOSE.
// [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI,
// or XE_LOAD_2D_TRANSPOSE.
// Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;



// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
Expand All @@ -373,6 +368,7 @@ int main(int argc, const char** argv)

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
// For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;

Expand All @@ -385,22 +381,21 @@ int main(int argc, const char** argv)

// FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
// policy/architecture) and defines the epilogue arguments.
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;
// GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
// auxiliary data required
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
TiledMma,
void, // Epilogue tile (void = automatic)
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
FusionCallBacks,
GmemTiledCopyC, // The copy atom used to load matrix C
void, void,
GmemTiledCopyD, // The copy atom used to store matrix D
void, void>;
FusionCallbacks,
void, // The copy atom used to load matrix C (void = automatic)
void>; // The copy atom used to store matrix D (void = automatic)

// GEMM Mainloop - iteration over blocks in K dimension
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
Expand All @@ -417,9 +412,9 @@ int main(int argc, const char** argv)

// Define the whole kernel (mainloop and epilogue)
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Defer global problem shape definition to runtime
CollectiveMainloop,
CollectiveEpilogue
Shape<int, int, int, int>, // Defer global problem shape definition to runtime
CollectiveMainloop,
CollectiveEpilogue
>;

// The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g.
Expand Down
19 changes: 16 additions & 3 deletions include/cute/tensor_sg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,25 @@ template <class Engine,
class SubgroupTVLayout,
__CUTE_REQUIRES(is_layout<SubgroupTVLayout>::value)>
CUTE_HOST_DEVICE
constexpr auto
make_subgroup_tensor(Tensor<Engine, Layout> const& tensor, SubgroupTVLayout const&)
constexpr decltype(auto)
make_subgroup_tensor(Tensor<Engine,Layout>& tensor, SubgroupTVLayout const& tv_layout)
{
static_assert(is_static_v<SubgroupTVLayout>, "Subgroup TV layout must be static");
static_assert(is_rmem_v<Engine>, "Expected an rmem tensor");
return make_subgroup_tensor(make_tensor(tensor.data(), tensor.layout()), tv_layout);
}

template <class Engine,
class Layout,
class SubgroupTVLayout,
__CUTE_REQUIRES(is_layout<SubgroupTVLayout>::value)>
CUTE_HOST_DEVICE
constexpr decltype(auto)
make_subgroup_tensor(Tensor<Engine,Layout>&& tensor, SubgroupTVLayout const&)
{
static_assert(is_static_v<SubgroupTVLayout>, "Subgroup TV layout must be static");
static_assert(is_rmem_v<Engine>, "Expected an rmem tensor");
return static_cast<SubgroupTensor<Engine,Layout,SubgroupTVLayout> const&>(tensor);
return static_cast<SubgroupTensor<Engine,Layout,SubgroupTVLayout>&&>(tensor);
}

// Create a new owning SubgroupTensor with the given subgroup-level layout.
Expand Down
6 changes: 6 additions & 0 deletions include/cute/util/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,10 @@ struct is_any_of {
template <class T, class... Us>
inline constexpr bool is_any_of_v = is_any_of<T, Us...>::value;

//
// replace_void_t
//
template <class T, class ReplacementTypeIfVoid>
using replace_void_t = conditional_t<is_void_v<T>, ReplacementTypeIfVoid, T>;

} // end namespace cute
Loading