Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
720 changes: 720 additions & 0 deletions examples/11_bmg_moe_gemm_bf16/11_bmg_moe_gemm_bf16.cpp

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions examples/11_bmg_moe_gemm_bf16/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Intel Corporation. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Register the MoE (Mixture-of-Experts) BF16 GEMM example as a standard
# CUTLASS example executable; TEST_COMMAND_OPTIONS hooks it into ctest.
cutlass_example_add_executable(
11_bmg_moe_gemm_bf16
11_bmg_moe_gemm_bf16.cpp
TEST_COMMAND_OPTIONS
)

# TODO(codeplay): Remove these once the IGC VectorAliasBBThreshold issue is fixed.
# The launcher string is prepended to the compile/link command line; on POSIX
# shells a leading NAME=VALUE token acts as a per-invocation environment
# variable assignment, which is how the IGC threshold is raised here.
# NOTE(review): this launcher-as-env-assignment trick relies on POSIX shell
# semantics — confirm behavior on Windows generators if those are supported.
set_target_properties( 11_bmg_moe_gemm_bf16 PROPERTIES CXX_COMPILER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" )
set_target_properties( 11_bmg_moe_gemm_bf16 PROPERTIES CXX_LINKER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" )
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ if(CUTLASS_ENABLE_SYCL)
08_bmg_gemm_f8
09_bmg_grouped_gemm_f8
10_bmg_grouped_gemm_mixed_dtype
11_bmg_moe_gemm_bf16
)
add_subdirectory(${EXAMPLE})
endforeach()
Expand Down
10 changes: 6 additions & 4 deletions include/cutlass/epilogue/collective/builders/xe_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ namespace detail {
template <
class ElementD,
class ElementCompute,
class ElementC
class ElementC,
cutlass::FloatRoundStyle RoundStyle_,
bool supportSource_
>
struct FusionOpInfo<cutlass::epilogue::fusion::LinearCombination<
ElementD, ElementCompute, ElementC, ElementCompute
ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle_, supportSource_
>> {
constexpr static bool HasBuilder = true;

Expand All @@ -63,7 +65,7 @@ namespace detail {
class>
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<
DispatchPolicy,
cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute>,
cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle_, supportSource_>,
TileShape_MNK,
EpilogueTile
>;
Expand Down Expand Up @@ -193,7 +195,7 @@ template <

//TODO(Codeplay): Should FusionCallbacks use DispatchPolicy IntelXeGroupEpilogue for group gemm? That does not work.
using FusionCallbacks = typename detail::FusionOpInfo<FusionOpOrCallbacks>::template FusionCallbacks<
IntelXeXMX16, TileShape_MNK, TileShape_MNK, CopyOpG2R>;
std::conditional_t<IsGroup, IntelXeXMX16Group, IntelXeXMX16>, TileShape_MNK, TileShape_MNK, CopyOpG2R>;

using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue<
DispatchPolicy,
Expand Down
45 changes: 42 additions & 3 deletions include/cutlass/epilogue/collective/xe_array_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp"
#include "cutlass/detail/layout.hpp"
#include "../tools/util/include/cutlass/util/packed_stride.hpp"

#include "cute/tensor.hpp"

Expand Down Expand Up @@ -114,8 +115,10 @@ class CollectiveEpilogue<
using ElementScalar = typename FusionCallbacks::ElementScalar;
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;

static_assert(cute::is_same_v<typename FusionCallbacks::Operation,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,

static_assert(cute::is_any_of_v<typename FusionCallbacks::Operation,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, false>,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, true>>,
"Only Linear Combination Epilogue is supported for Grouped GEMM at the moment.");

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
Expand All @@ -139,7 +142,7 @@ class CollectiveEpilogue<
Layout<CopyThreadShape>{},
make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))));
private:
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && FusionCallbacks::Operation::IsSourceSupported;
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;

public:
Expand Down Expand Up @@ -475,6 +478,42 @@ class CollectiveEpilogue<
return cute::make_tuple(mC_mnl, mD_mnl);
}

/// Rebuilds the C/D output tensors for one expert (group) of a MoE grouped GEMM.
///
/// All experts share one contiguous C/D allocation partitioned along M: expert
/// `g` owns `num_rows_per_expert[g]` rows starting at the running sum of the
/// preceding experts' row counts. This recomputes that offset and returns
/// freshly-shaped gmem tensors for expert `next_group`.
///
/// \param next_group          index of the expert whose tensors are needed
/// \param problem_shape_mnkl  (M, N, K, L); only N and L are used as-is — M is
///                            replaced by this expert's row count
/// \param num_rows_per_expert per-expert row counts, length > next_group
/// \return cute::make_tuple(mC_mnl, mD_mnl); a tensor is left default-constructed
///         when the corresponding side (source C / destination D) is disabled.
template <typename ProblemShape_MNKL>
CUTLASS_DEVICE auto
update_tensor_shape_stride(int32_t const &next_group,
ProblemShape_MNKL const &problem_shape_mnkl,
const int32_t *num_rows_per_expert) {
auto [M, N, K, L] = problem_shape_mnkl;
// Row offset of this expert = sum of all preceding experts' row counts.
int32_t cumulative_M = 0;
for (int i = 0; i < next_group; i++) {
cumulative_M += num_rows_per_expert[i];
}
M = num_rows_per_expert[next_group];

TensorC mC_mnl;
TensorD mD_mnl;
if constexpr (is_source_supported) {
// NOTE(review): cumulative_M * N is a 32-bit product — confirm it cannot
// overflow for the largest supported token count x N before shipping.
ElementC const *ptr_C_curr_batch =
reinterpret_cast<ElementC const *>((void*)(params.ptr_C)) +
cumulative_M * N;
mC_mnl = make_tensor(
make_gmem_ptr(ptr_C_curr_batch),
make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride(
InternalStrideC{}, {M, N, 1})));
}

if constexpr (is_destination_supported) {
// Same offset arithmetic as C: D is partitioned identically along M.
ElementD *ptr_D_curr_batch =
reinterpret_cast<ElementD *>((void*)(params.ptr_D)) +
cumulative_M * N;
mD_mnl = make_tensor(
make_gmem_ptr(ptr_D_curr_batch),
make_layout(make_shape(M, N, L), cutlass::make_cute_packed_stride(
InternalStrideD{}, {M, N, 1})));
}
return cute::make_tuple(mC_mnl, mD_mnl);
}

private:
Params const& params;
FusionCallbacks fusion_callbacks;
Expand Down
6 changes: 4 additions & 2 deletions include/cutlass/epilogue/fusion/operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,16 @@ template<
class ElementCompute_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest,
bool supportSource = true
>
struct LinearCombination
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
using ElementSource = ElementSource_;
static constexpr bool IsSourceSupported = true;
static constexpr bool IsSourceSupported = supportSource;
};


// D = activation(alpha * acc + beta * C)
template<
template <class> class ActivationFn_,
Expand Down
14 changes: 8 additions & 6 deletions include/cutlass/epilogue/fusion/xe_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ template <
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
class EpilogueTile_,
bool supportSource_
>
struct FusionCallbacks<
epilogue::IntelXeXMX16,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_, supportSource_>,
CtaTileShapeMNK_,
EpilogueTile_
> : Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput_>::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
Expand All @@ -77,7 +78,7 @@ struct FusionCallbacks<
using ElementCompute = ElementCompute_;
using ElementSource = ElementSource_;
using ElementScalar = ElementScalar_;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource_, ElementScalar, RoundStyle_>;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource_, ElementScalar, RoundStyle_, supportSource_>;

struct Arguments {
ElementScalar alpha = ElementScalar(1);
Expand Down Expand Up @@ -730,11 +731,12 @@ template <
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
class EpilogueTile_,
bool supportSource_
>
struct FusionCallbacks<
epilogue::IntelXeXMX16Group,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_, supportSource_>,
CtaTileShapeMNK_,
EpilogueTile_
> : Sm90LinearCombinationPtrArray<typename cutlass::detail::get_unpacked_element_type<ElementOutput_>::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
Expand All @@ -744,7 +746,7 @@ struct FusionCallbacks<
using ElementCompute = ElementCompute_;
using ElementSource = ElementSource_;
using ElementScalar = ElementScalar_;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle_>;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle_, supportSource_>;

struct Arguments {
ElementScalar alpha = ElementScalar(1);
Expand Down
4 changes: 2 additions & 2 deletions include/cutlass/gemm/collective/builders/xe_mma_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ struct CollectiveBuilder<
cutlass::gemm::collective::StageCountAuto,
KernelScheduleType,
cute::enable_if_t<
cute::is_any_of_v<KernelScheduleType, KernelScheduleAuto, KernelXe, KernelXeCooperative, KernelXePtrArrayCooperative> &&
cute::is_any_of_v<KernelScheduleType, KernelScheduleAuto, KernelXe, KernelXeCooperative, KernelXePtrArrayCooperative, KernelXeMoEGEMM> &&
cute::is_any_of_v<ElementA, bfloat16_t, half_t, cute::float_e5m2_t, cute::float_e4m3_t, cute::int8_t> &&
cute::is_any_of_v<ElementB, bfloat16_t, half_t, cute::float_e5m2_t, cute::float_e4m3_t, cute::int8_t, cute::uint4_t>
>
Expand Down Expand Up @@ -190,7 +190,7 @@ struct CollectiveBuilder<
Layout<TileShape_MNK>,
Layout<Shape<atoms_M, atoms_N, _1>, Stride<atoms_N, _1, _0>>>::TiledMMA;

static constexpr bool IsGroup = cute::is_same_v<KernelScheduleType, KernelXePtrArrayCooperative>;
static constexpr bool IsGroup = cute::is_any_of_v<KernelScheduleType, KernelXePtrArrayCooperative, KernelXeMoEGEMM>;

using KernelSchedule = std::conditional_t<cute::is_same_v<KernelScheduleType, KernelScheduleAuto>, KernelXe, KernelScheduleType>;
static constexpr int PipelineStages = IsGroup ? 2 : 3;
Expand Down
33 changes: 33 additions & 0 deletions include/cutlass/gemm/collective/xe_array_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "../tools/util/include/cutlass/util/packed_stride.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -292,6 +293,38 @@ struct CollectiveMma<MainloopIntelXeXMX16Group<Stages, Schedule>, TileShape_, El

return cute::make_tuple(mA, mB);
}

/// Rebuilds the A/B input tensors for one expert (group) of a MoE grouped GEMM.
///
/// A holds all experts' activation rows contiguously (expert g starts at the
/// running sum of preceding experts' row counts, times K elements per row);
/// B holds one K x N weight matrix per expert at stride K * N.
///
/// \param mainloop_params     kernel params carrying the base ptr_A / ptr_B
/// \param next_group          index of the expert whose tensors are needed
/// \param problem_shape_mnkl  (M, N, K, L); N and K are taken as-is, M is
///                            replaced by this expert's row count
/// \param num_rows_per_expert per-expert row counts, length > next_group
///                            (NOTE(review): declared `const int *` here but
///                            `const int32_t *` in the epilogue's counterpart —
///                            unify the two signatures.)
/// \return cute::make_tuple(mA, mB) gmem tensors for expert next_group.
template <typename ProblemShape_MNKL>
CUTLASS_DEVICE auto
update_tensor_shape_stride(Params const &mainloop_params,
int32_t const &next_group,
ProblemShape_MNKL const &problem_shape_mnkl,
const int *num_rows_per_expert) {
// Row offset of this expert = sum of all preceding experts' row counts.
int32_t cumulative_M = 0;
for (int i = 0; i < next_group; i++) {
cumulative_M += num_rows_per_expert[i];
}

const int32_t M = num_rows_per_expert[next_group];
const int32_t N = get<1>(problem_shape_mnkl);
const int32_t K = get<2>(problem_shape_mnkl);

// NOTE(review): cumulative_M * K and next_group * K * N are 32-bit products —
// confirm they cannot overflow for the largest supported problem sizes.
ElementA const *ptr_A_curr_batch =
reinterpret_cast<ElementA const *>((void*)(mainloop_params.ptr_A)) +
cumulative_M * K;
ElementB const *ptr_B_curr_batch =
reinterpret_cast<ElementB const *>((void*)(mainloop_params.ptr_B)) +
next_group * K * N;

// Batch dimension is fixed to 1: each call addresses a single expert.
Tensor mA = make_tensor(
make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K, (int32_t)1),
cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1}));
Tensor mB = make_tensor(
make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K, (int32_t)1),
cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1}));

return cute::make_tuple(mA, mB);
}
};

} // namespace cutlass::gemm::collective
Expand Down
13 changes: 10 additions & 3 deletions include/cutlass/gemm/dispatch_policy.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2025 Intel Corporation
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -140,6 +141,7 @@ struct KernelTmaWarpSpecializedCooperativeMixedInput: KernelTmaWarpSpecializedCo
// Empty tag types used as KernelScheduleType template arguments to select an
// Intel Xe kernel schedule at compile time.
struct KernelXe { };
struct KernelXeCooperative { };
struct KernelXePtrArrayCooperative { };
// Schedule tag for the MoE (Mixture-of-Experts) grouped GEMM kernel path.
struct KernelXeMoEGEMM { };
//////////////////////////////////////////////////////////////////////////////

//
Expand Down Expand Up @@ -1214,9 +1216,14 @@ struct MainloopIntelXeXMX16 {
using ClusterShape = Shape<_1,_1,_1>;
};

template<int Stages_, class KernelScheduler = KernelXePtrArrayCooperative>
struct MainloopIntelXeXMX16Group : MainloopIntelXeXMX16<Stages_, KernelScheduler> {
};
template <int Stages_, class KernelScheduler = KernelXePtrArrayCooperative>
struct MainloopIntelXeXMX16Group
: MainloopIntelXeXMX16<Stages_, KernelScheduler> {};

// NOTE: No explicit partial specialization of MainloopIntelXeXMX16Group for
// KernelXeMoEGEMM is needed. The primary template above already inherits from
// MainloopIntelXeXMX16<Stages_, KernelScheduler>, so instantiating it with
// KernelXeMoEGEMM yields exactly MainloopIntelXeXMX16<Stages_, KernelXeMoEGEMM>.
// The previous specialization duplicated that instantiation verbatim (same
// base, empty body) and has been removed as dead redundancy.

template<int Stages_, class KernelScheduler = KernelXePtrArrayCooperative>
struct MainloopIntelXeXMX16GroupMixedPrecision : MainloopIntelXeXMX16<Stages_, KernelScheduler> {
Expand Down
1 change: 1 addition & 0 deletions include/cutlass/gemm/kernel/gemm_universal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U

#if defined(SYCL_INTEL_TARGET)
#include "cutlass/gemm/kernel/xe_gemm.hpp"
#include "cutlass/gemm/kernel/xe_moe_gemm.hpp"
#include "cutlass/gemm/kernel/xe_gemm_cooperative.hpp"
#include "cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp"
#endif
Expand Down
1 change: 1 addition & 0 deletions include/cutlass/gemm/kernel/tile_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct StaticPersistentScheduler { };
#if defined (SYCL_INTEL_TARGET)
#include "cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp"
#include "cutlass/gemm/kernel/xe_tile_scheduler_group.hpp"
#include "cutlass/gemm/kernel/xe_tile_scheduler_moe.hpp"
#endif
////////////////////////////////////////////////////////////////////////////////

Expand Down
Loading
Loading