Skip to content

Commit 3414a0c

Browse files
committed
s8s32 UT with new mma and copy atoms
1 parent ffb0d54 commit 3414a0c

File tree

5 files changed

+205
-9
lines changed

5 files changed

+205
-9
lines changed

test/unit/gemm/device/CMakeLists.txt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,27 @@
3030
if(CUTLASS_ENABLE_SYCL)
3131
if(SYCL_INTEL_TARGET)
3232
cutlass_test_unit_add_executable(
33-
cutlass_test_unit_gemm_device_tensorop_xe
33+
cutlass_test_unit_gemm_device_tensorop_xe_legacy
3434
xe_gemm_bf16_bf16_bf16_tensor_op_bf16.cpp
3535
xe_gemm_fp16_fp16_fp16_tensor_op_fp16.cpp
3636
xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp
3737
xe_gemm_bf16_bf16_fp32_tensor_op_bf16.cpp
3838
xe_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
3939
xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp
4040
xe_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
41-
xe_gemm_s8_s8_s32_tensor_op_s32.cpp
41+
xe_gemm_s8_s8_s32_tensor_op_s32_legacy.cpp
4242
xe_gemm_tf32_tf32_fp32_tensor_op_fp32.cpp
4343
xe_gemm_f8_f8_fp32_tensor_op_fp32.cpp
4444
xe_gemm_fp16_s8_fp32_tensor_op_fp32.cpp
4545
gemm_universal_bf16n_bf16t_f32n_tensor_op_f32_xe.cpp
4646
)
4747

48+
# TODO :- Port remaining legacy tests after enabling new atoms
49+
cutlass_test_unit_add_executable(
50+
cutlass_test_unit_gemm_device_tensorop_xe
51+
xe_gemm_s8_s8_s32_tensor_op_s32.cpp
52+
)
53+
4854
cutlass_test_unit_add_executable(
4955
cutlass_test_unit_gemm_device_tensorop_cooperative_xe
5056
xe_gemm_bf16_bf16_fp32_tensor_op_fp32_cooperative.cpp
@@ -93,7 +99,7 @@ if(CUTLASS_ENABLE_SYCL)
9399
cutlass_test_unit_gemm_device_mixed_input_tensorop_xe
94100
cutlass_test_unit_gemm_device_tensorop_xe_group_gemm
95101
cutlass_test_unit_gemm_device_mixed_dtype_tensorop_xe_group_gemm
96-
cutlass_test_unit_gemm_device_tensorop_xe
102+
cutlass_test_unit_gemm_device_tensorop_xe_legacy
97103
)
98104

99105
add_custom_target(

test/unit/gemm/device/default_gemm_configuration.hpp

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,17 @@ struct DefaultGemmConfigurationToCutlass3Types {
6262
static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists.");
6363
};
6464

65+
template<
66+
class OperatorClass, class ArchTag,
67+
class ElementA, class LayoutA,
68+
class ElementB, class LayoutB,
69+
class ElementC, class LayoutC,
70+
class ElementAccumulator>
71+
struct XeLegacyGemmConfigurationToCutlass3Types {
72+
static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists.");
73+
};
74+
75+
6576
// This type is only intended to demonstrate porting 2.x kernels to 3.0
6677
template<
6778
class OperatorClass, class ArchTag,
@@ -1901,9 +1912,9 @@ struct DefaultGemmConfigurationToCutlass3Types<
19011912

19021913
///////////////////////////////////////////////////////////////////////////////
19031914

1904-
// Intel XE MMA S32S8
1915+
// Intel XE MMA S32S8 Legacy
19051916
template <typename LayoutA, typename LayoutB, typename LayoutC>
1906-
struct DefaultGemmConfigurationToCutlass3Types<
1917+
struct XeLegacyGemmConfigurationToCutlass3Types<
19071918
arch::OpClassTensorOp, arch::IntelXe,
19081919
int8_t, LayoutA,
19091920
int8_t, LayoutB,
@@ -1961,6 +1972,64 @@ struct DefaultGemmConfigurationToCutlass3Types<
19611972

19621973
///////////////////////////////////////////////////////////////////////////////
19631974

1975+
// Intel XE MMA S32S8
1976+
template <typename LayoutA, typename LayoutB, typename LayoutC>
1977+
struct DefaultGemmConfigurationToCutlass3Types<
1978+
arch::OpClassTensorOp, arch::IntelXe,
1979+
int8_t, LayoutA,
1980+
int8_t, LayoutB,
1981+
int32_t, LayoutC,
1982+
int32_t>
1983+
{
1984+
using TileShape = Shape<_256, _256, _32>;
1985+
1986+
using GEMMDispatchPolicy = gemm::MainloopXeL1Staged<3>;
1987+
1988+
using TiledMma =
1989+
typename TiledMMAHelper<
1990+
MMA_Atom<XE_DPAS_TT<8, int32_t, int8_t>>,
1991+
Layout<TileShape>,
1992+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>
1993+
>::TiledMMA;
1994+
1995+
using GmemTiledCopyA = void;
1996+
using GmemTiledCopyB = void;
1997+
1998+
// Mainloop
1999+
using CollectiveMainloop = collective::CollectiveMma<
2000+
GEMMDispatchPolicy, TileShape,
2001+
int8_t, TagToStrideA_t<LayoutA>,
2002+
int8_t, TagToStrideB_t<LayoutB>,
2003+
TiledMma,
2004+
GmemTiledCopyA, void, void, cute::identity, // A
2005+
GmemTiledCopyB, void, void, cute::identity // B
2006+
>;
2007+
2008+
using EpilogueDispatchPolicy = epilogue::IntelXeGeneric;
2009+
using EpilogueOp = epilogue::fusion::LinearCombination<int32_t, int32_t>;
2010+
2011+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
2012+
EpilogueDispatchPolicy,
2013+
EpilogueOp,
2014+
TileShape,
2015+
decltype(tile_shape(TiledMma()))
2016+
>;
2017+
2018+
using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
2019+
using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;
2020+
2021+
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
2022+
EpilogueDispatchPolicy,
2023+
TileShape,
2024+
int32_t, TagToStrideC_t<LayoutC>,
2025+
int32_t, TagToStrideC_t<LayoutC>,
2026+
FusionCallBacks,
2027+
GmemTiledCopyC, void, void,
2028+
GmemTiledCopyD, void, void>;
2029+
};
2030+
2031+
///////////////////////////////////////////////////////////////////////////////
2032+
19642033
namespace detail {
19652034

19662035
//
@@ -2002,7 +2071,7 @@ struct DefaultGemm_TensorOpXe_OperandB<tfloat32_t, layout::ColumnMajor, 32, Size
20022071

20032072
///////////////////////////////////////////////////////////////////////////////
20042073

2005-
// Intel XE MMA S32S8
2074+
// Intel XE MMA F32TF32
20062075
template <typename LayoutA, typename LayoutB, typename LayoutC>
20072076
struct DefaultGemmConfigurationToCutlass3Types<
20082077
arch::OpClassTensorOp, arch::IntelXe,
@@ -2158,6 +2227,7 @@ struct DefaultGemmConfigurationToCutlass3Types<
21582227
XE_2D_U32x8x16_ST_N, void, void>;
21592228
};
21602229

2230+
// Intel XE MMA FP32FP16
21612231
template <typename LayoutA, typename LayoutB, typename LayoutC>
21622232
struct DefaultGemmConfigurationToCutlass3Types<
21632233
arch::OpClassTensorOp, arch::IntelXe,

test/unit/gemm/device/gemm_testbed_3x.hpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@
7070
#include "cute/layout.hpp"
7171
#include "cute/numeric/int.hpp"
7272

73+
#include "cutlass/util/GPU_Clock.hpp"
74+
7375
namespace test {
7476
namespace gemm {
7577
namespace device {
@@ -3041,7 +3043,7 @@ struct TestbedImpl {
30413043

30423044
if (status != cutlass::Status::kSuccess) {
30433045
#if defined(CUTLASS_ENABLE_SYCL)
3044-
std::cerr << "This test is not supported." << "\n";
3046+
std::cerr << "This test is not supported. - gemm_op can_implement failed" << "\n";
30453047
return true;
30463048
#else
30473049
cudaError_t error = cudaGetLastError();
@@ -3069,7 +3071,7 @@ struct TestbedImpl {
30693071
status = gemm_op.initialize(arguments, workspace.get());
30703072
if (status != cutlass::Status::kSuccess) {
30713073
#if defined(CUTLASS_ENABLE_SYCL)
3072-
std::cerr << "This test is not supported." << "\n";
3074+
std::cerr << "This test is not supported. - gemm_op initialize failed" << "\n";
30733075
#else
30743076
cudaError_t error = cudaGetLastError();
30753077
const auto error_str = cudaGetErrorString(error);
@@ -3079,10 +3081,28 @@ struct TestbedImpl {
30793081
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
30803082
CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.run");
30813083
#endif
3084+
GPU_Clock timer;
3085+
if (profiling)
3086+
timer.start();
30823087
status = gemm_op.run();
30833088
#if defined(CUTLASS_ENABLE_SYCL)
30843089
try {
30853090
compat::wait_and_throw();
3091+
if (profiling) {
3092+
double time = timer.seconds();
3093+
auto m = cute::get<0>(problem_size);
3094+
auto n = cute::get<1>(problem_size);
3095+
auto k = cute::get<2>(problem_size);
3096+
auto l = cute::get<3>(problem_size);
3097+
double tops = (2.0 * m * n * k * l) * 1e-12;
3098+
printf(
3099+
"[Perf] M=%d N=%d K=%d L=%d | "
3100+
"-> [%4.3f] Tops/s (%.4f ms)\n",
3101+
m, n, k, l,
3102+
tops / time,
3103+
time * 1000
3104+
);
3105+
}
30863106
} catch (std::exception const &e) {
30873107
ADD_FAILURE() << "Error at Kernel Sync.";
30883108
return false;

test/unit/gemm/device/xe_gemm_s8_s8_s32_tensor_op_s32.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ struct XE_Device_Gemm_s8_s8_s32_tensor_op_s32 {
5858
typename Config::CollectiveMainloop,
5959
typename Config::CollectiveEpilogue>;
6060

61-
using Gemm = gemm::device::GemmUniversalAdapter<GemmKernel>;
61+
struct Gemm : public gemm::device::GemmUniversalAdapter<GemmKernel> {
62+
static constexpr int kAlignmentA = 16;
63+
static constexpr int kAlignmentB = 16;
64+
static constexpr int kAlignmentC = 4;
65+
static constexpr int kAlignmentD = 4;
66+
};
6267
};
6368

6469
TEST(XE_Device_Gemm_s8t_s8t_s32t_tensor_op_s32, 256x256x32) {
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
32+
/*! \file
33+
\brief Tests for Xe s8_s8_s32
34+
*/
35+
36+
37+
#include "cutlass/cutlass.h"
38+
39+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
40+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
41+
#include "default_gemm_configuration.hpp"
42+
43+
#include "gemm_testbed_3x.hpp"
44+
45+
namespace cutlass {
46+
namespace {
47+
template <typename LayoutA, typename LayoutB>
48+
struct XE_Device_Gemm_s8_s8_s32_tensor_op_s32 {
49+
using Config = gemm::device::XeLegacyGemmConfigurationToCutlass3Types<
50+
arch::OpClassTensorOp, arch::IntelXe,
51+
int8_t, LayoutA,
52+
int8_t, LayoutB,
53+
int32_t, layout::RowMajor,
54+
int32_t>;
55+
56+
using GemmKernel = gemm::kernel::GemmUniversal<
57+
cute::Shape<int, int, int, int>,
58+
typename Config::CollectiveMainloop,
59+
typename Config::CollectiveEpilogue>;
60+
61+
using Gemm = gemm::device::GemmUniversalAdapter<GemmKernel>;
62+
};
63+
64+
TEST(XE_Device_Gemm_s8t_s8t_s32t_tensor_op_s32, 256x256x32) {
65+
using LayoutA = layout::RowMajor;
66+
using LayoutB = layout::RowMajor;
67+
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32<LayoutA, LayoutB>::Gemm;
68+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
69+
}
70+
71+
// TODO(Codeplay): Test on XE2 because the copy function is not available in the IGC driver for PVC
72+
TEST(XE2_Device_Gemm_s8n_s8t_s32t_tensor_op_s32, 64x128x32) {
73+
using LayoutA = layout::ColumnMajor;
74+
using LayoutB = layout::RowMajor;
75+
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32<LayoutA, LayoutB>::Gemm;
76+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
77+
}
78+
79+
TEST(XE_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x32) {
80+
using LayoutA = layout::RowMajor;
81+
using LayoutB = layout::ColumnMajor;
82+
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32<LayoutA, LayoutB>::Gemm;
83+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
84+
}
85+
86+
// TODO(Codeplay): Test on XE2 because the copy function is not available in the IGC driver for PVC
87+
TEST(XE2_Device_Gemm_s8n_s8n_s32t_tensor_op_s32, 64x128x32) {
88+
using LayoutA = layout::ColumnMajor;
89+
using LayoutB = layout::ColumnMajor;
90+
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32<LayoutA, LayoutB>::Gemm;
91+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
92+
}
93+
94+
}
95+
} // namespace cutlass

0 commit comments

Comments
 (0)