Skip to content

Commit aecfb09

Browse files
authored
Gemm Universal unit tests for MainloopIntelW8A8 API (#584)
## Unit tests for MainloopIntelW8A8 FP8 GEMM operations on Intel Xe Comprehensive test coverage for Intel GPU FP8 GEMM kernels including: - **LLM workloads**: LLaMA2-7B, Mistral-7B configurations - **Parallelization**: Tensor/model parallel scenarios - **Batch sizes**: Micro-batch (4x) to large batch (32x) - **Matrix shapes**: Small (64²) to large (2048²), tall/wide matrices - **Edge cases**: Large K/N dimensions Tests validate `MainloopIntelW8A8` dispatch policy with FP8→FP32 precision on Intel Xe XMX architecture.
2 parents 4e2f5f8 + fc4aaf5 commit aecfb09

File tree

4 files changed

+223
-1
lines changed

4 files changed

+223
-1
lines changed

test/unit/cute/intel_xe/xe_copy_2d_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,4 +288,4 @@ TEST(PVC_CuTe_Xe, XE_COPY_2D_SKIPPED) {
288288
GTEST_SKIP() << "XE_COPY_2D tests require IGC version 2.18 or higher. skipped";
289289
}
290290

291-
#endif
291+
#endif

test/unit/gemm/device/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ if(CUTLASS_ENABLE_SYCL)
4343
xe_gemm_f8_f8_fp32_tensor_op_fp32.cpp
4444
xe_gemm_fp16_s8_fp32_tensor_op_fp32.cpp
4545
gemm_universal_bf16n_bf16t_f32n_tensor_op_f32_xe.cpp
46+
gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp
4647
)
4748

4849
cutlass_test_unit_add_executable(

test/unit/gemm/device/gemm_testbed_3x.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4187,6 +4187,48 @@ bool TestXe(
41874187
} // m
41884188
return passed;
41894189
}
4190+
4191+
template <typename Gemm, template <class T> class ActivationFunctor =
4192+
cutlass::epilogue::thread::Identity>
4193+
bool TestXe(
4194+
int m, int n, int k, int l,
4195+
double alpha = 1.0,
4196+
double beta = cute::is_same_v<typename Gemm::GemmKernel::ElementC, void> ? 0.0 : 1.0,
4197+
CheckEquality check_relative_equality = CheckEquality::RELATIVE) {
4198+
using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar;
4199+
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
4200+
4201+
Testbed3x<Gemm, ActivationFunctor, false,
4202+
typename Gemm::GemmKernel::ElementA,
4203+
typename Gemm::GemmKernel::ElementB,
4204+
typename Gemm::GemmKernel::ElementC,
4205+
typename Gemm::GemmKernel::ElementD> testbed(
4206+
check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED);
4207+
4208+
bool passed = true;
4209+
ProblemShapeType problem_size{m, n, k, l};
4210+
try {
4211+
passed = testbed.run(problem_size,
4212+
cutlass::from_real<ElementScalar>(alpha),
4213+
cutlass::from_real<ElementScalar>(beta));
4214+
}
4215+
catch (std::exception const& e) {
4216+
EXPECT_TRUE(false) << "TestXe: testbed.run threw an exception: " << e.what();
4217+
throw;
4218+
}
4219+
catch (...) {
4220+
EXPECT_TRUE(false) << "TestXe: testbed.run threw an unknown exception for MNKL = "
4221+
<< m << " " << n << " " << k << " " << l;
4222+
throw;
4223+
}
4224+
4225+
EXPECT_TRUE(passed) << "TestXe: testbed.run failed for MNKL = "
4226+
<< m << " " << n << " " << k << " " << l
4227+
<< ", alpha: " << alpha << ", beta: " << beta;
4228+
4229+
return passed;
4230+
}
4231+
41904232
#endif
41914233

41924234
template <typename Gemm>
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/***************************************************************************************************
2+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
/*! \file
32+
\brief Tests for the device-wide GEMM interface using the MainloopIntelW8A8 dispatch policy (FP8 inputs, FP32 output) on Intel Xe
33+
34+
*/
35+
#include <gtest/gtest.h>
36+
#include "cutlass/cutlass.h"
37+
#include "cutlass/gemm/collective/collective_mma.hpp"
38+
#include "cutlass/gemm/dispatch_policy.hpp"
39+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
40+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
41+
#include "default_gemm_configuration.hpp"
42+
#include "gemm_testbed_3x.hpp"
43+
44+
using namespace cutlass;
45+
46+
namespace {
47+
48+
template<typename LayoutA, typename LayoutB>
49+
struct MainloopIntelW8A8_GemmConfig {
50+
using ElementA = float_e5m2_t;
51+
using ElementB = float_e5m2_t;
52+
using TileShape = Shape<_256, _256, _32>;
53+
constexpr static int PipelineStages = 2;
54+
using Schedule = gemm::KernelXe;
55+
using TiledMma = typename TiledMMAHelper<
56+
MMA_Atom<XE_8x16x16_F32F16F16F32_TT>,
57+
Layout<TileShape>,
58+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>
59+
>::TiledMMA;
60+
using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
61+
using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;
62+
63+
using DispatchPolicy = gemm::MainloopIntelW8A8<PipelineStages, Schedule>;
64+
65+
using CollectiveMainloop = gemm::collective::CollectiveMma<
66+
DispatchPolicy, TileShape,
67+
ElementA, cutlass::gemm::TagToStrideA_t<LayoutA>,
68+
ElementB, cutlass::gemm::TagToStrideB_t<LayoutB>,
69+
TiledMma,
70+
GmemTiledCopyA, void, void, cute::identity, // A
71+
GmemTiledCopyB, void, void, cute::identity // B
72+
>;
73+
74+
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<
75+
float, float
76+
>;
77+
78+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
79+
cutlass::epilogue::IntelXeXMX16,
80+
EpilogueOp,
81+
TileShape,
82+
decltype(tile_shape(TiledMma()))
83+
>;
84+
85+
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
86+
cutlass::epilogue::IntelXeXMX16,
87+
TileShape,
88+
float, cutlass::gemm::TagToStrideC_t<layout::RowMajor>,
89+
float, cutlass::gemm::TagToStrideC_t<layout::RowMajor>,
90+
FusionCallBacks,
91+
XE_2D_U32x8x16_LD_N, void, void,
92+
XE_2D_U32x8x16_ST_N, void, void
93+
>;
94+
95+
using GemmKernel = gemm::kernel::GemmUniversal<
96+
cute::Shape<int, int, int, int>,
97+
CollectiveMainloop,
98+
CollectiveEpilogue
99+
>;
100+
101+
using Gemm = gemm::device::GemmUniversalAdapter<GemmKernel>;
102+
};
103+
104+
TEST(MainloopIntelW8A8_Special, LargeModel_LLaMA2_7B) {
105+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
106+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 4096, 11008, 1, 1.0, 0.0));
107+
}
108+
109+
TEST(MainloopIntelW8A8_Special, LargeModel_Mistral_7B) {
110+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
111+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 4096, 14336, 1, 1.0, 0.0));
112+
}
113+
114+
TEST(MainloopIntelW8A8_Special, TensorParallel) {
115+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
116+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 1024, 4096, 1, 1.0, 0.0));
117+
}
118+
119+
TEST(MainloopIntelW8A8_Special, ModelParallel) {
120+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
121+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1024, 4096, 4096, 1, 1.0, 0.0));
122+
}
123+
124+
TEST(MainloopIntelW8A8_Special, MicroBatch) {
125+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
126+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(128, 128, 8192, 4, 1.0, 0.0));
127+
}
128+
129+
TEST(MainloopIntelW8A8_Special, LargeBatch) {
130+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
131+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 512, 2048, 32, 1.0, 0.0));
132+
}
133+
134+
TEST(MainloopIntelW8A8_Special, SquareSmall) {
135+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
136+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(64, 64, 64, 1, 1.0, 0.0));
137+
}
138+
139+
TEST(MainloopIntelW8A8_Special, SquareMedium) {
140+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
141+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 512, 512, 1, 1.0, 0.0));
142+
}
143+
144+
TEST(MainloopIntelW8A8_Special, SquareLarge) {
145+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
146+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(2048, 2048, 2048, 1, 1.0, 0.0));
147+
}
148+
149+
TEST(MainloopIntelW8A8_Special, TallMatrix) {
150+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
151+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(4096, 512, 4096, 1, 1.0, 0.0));
152+
}
153+
154+
TEST(MainloopIntelW8A8_Special, WideMatrix) {
155+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
156+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 4096, 4096, 1, 1.0, 0.0));
157+
}
158+
159+
TEST(MainloopIntelW8A8_Special, Batch8) {
160+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
161+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(512, 512, 512, 8, 1.0, 0.0));
162+
}
163+
164+
TEST(MainloopIntelW8A8_Special, Batch16Large) {
165+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
166+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1024, 1024, 1024, 16, 1.0, 0.0));
167+
}
168+
169+
TEST(MainloopIntelW8A8_Special, LargeK) {
170+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
171+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(64, 64, 8192, 1, 1.0, 0.0));
172+
}
173+
174+
TEST(MainloopIntelW8A8_Special, LargeN) {
175+
using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
176+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(64, 8192, 64, 1, 1.0, 0.0));
177+
}
178+
179+
} // namespace

0 commit comments

Comments
 (0)