1+ /* **************************************************************************************************
2+ * Copyright (C) 2025 Intel Corporation, All rights reserved.
3+ * SPDX-License-Identifier: BSD-3-Clause
4+ *
5+ * Redistribution and use in source and binary forms, with or without
6+ * modification, are permitted provided that the following conditions are met:
7+ *
8+ * 1. Redistributions of source code must retain the above copyright notice, this
9+ * list of conditions and the following disclaimer.
10+ *
11+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12+ * this list of conditions and the following disclaimer in the documentation
13+ * and/or other materials provided with the distribution.
14+ *
15+ * 3. Neither the name of the copyright holder nor the names of its
16+ * contributors may be used to endorse or promote products derived from
17+ * this software without specific prior written permission.
18+ *
19+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+ *
30+ **************************************************************************************************/
/*! \file
    \brief Unit tests for the device-wide GEMM interface using the Intel Xe
           W8A8 mainloop with FP8 (e5m2) A and B operands.
*/
35+ #include < gtest/gtest.h>
36+ #include " cutlass/cutlass.h"
37+ #include " cutlass/gemm/collective/collective_mma.hpp"
38+ #include " cutlass/gemm/dispatch_policy.hpp"
39+ #include " cutlass/gemm/device/gemm_universal_adapter.h"
40+ #include " cutlass/gemm/kernel/gemm_universal.hpp"
41+ #include " default_gemm_configuration.hpp"
42+ #include " gemm_testbed_3x.hpp"
43+
44+ using namespace cutlass ;
45+
46+ namespace {
47+
48+ template <typename LayoutA, typename LayoutB>
49+ struct MainloopIntelW8A8_GemmConfig {
50+ using ElementA = float_e5m2_t ;
51+ using ElementB = float_e5m2_t ;
52+ using TileShape = Shape<_256, _256, _32>;
53+ constexpr static int PipelineStages = 2 ;
54+ using Schedule = gemm::KernelXe;
55+ using TiledMma = typename TiledMMAHelper<
56+ MMA_Atom<XE_8x16x16_F32F16F16F32_TT>,
57+ Layout<TileShape>,
58+ Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>
59+ >::TiledMMA;
60+ using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
61+ using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;
62+
63+ using DispatchPolicy = gemm::MainloopIntelW8A8<PipelineStages, Schedule>;
64+
65+ using CollectiveMainloop = gemm::collective::CollectiveMma<
66+ DispatchPolicy, TileShape,
67+ ElementA, cutlass::gemm::TagToStrideA_t<LayoutA>,
68+ ElementB, cutlass::gemm::TagToStrideB_t<LayoutB>,
69+ TiledMma,
70+ GmemTiledCopyA, void , void , cute::identity, // A
71+ GmemTiledCopyB, void , void , cute::identity // B
72+ >;
73+
74+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<
75+ float , float
76+ >;
77+
78+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
79+ cutlass::epilogue::IntelXeXMX16,
80+ EpilogueOp,
81+ TileShape,
82+ decltype (tile_shape(TiledMma()))
83+ >;
84+
85+ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
86+ cutlass::epilogue::IntelXeXMX16,
87+ TileShape,
88+ float , cutlass::gemm::TagToStrideC_t<layout::RowMajor>,
89+ float , cutlass::gemm::TagToStrideC_t<layout::RowMajor>,
90+ FusionCallBacks,
91+ XE_2D_U32x8x16_LD_N, void , void ,
92+ XE_2D_U32x8x16_ST_N, void , void
93+ >;
94+
95+ using GemmKernel = gemm::kernel::GemmUniversal<
96+ cute::Shape<int , int , int , int >,
97+ CollectiveMainloop,
98+ CollectiveEpilogue
99+ >;
100+
101+ using Gemm = gemm::device::GemmUniversalAdapter<GemmKernel>;
102+ };
103+
104+ TEST (MainloopIntelW8A8_Special, LargeModel_LLaMA2_7B) {
105+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
106+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(4096 , 4096 , 11008 , 1 , 1.0 , 0.0 ));
107+ }
108+
109+ TEST (MainloopIntelW8A8_Special, LargeModel_Mistral_7B) {
110+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
111+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(4096 , 4096 , 14336 , 1 , 1.0 , 0.0 ));
112+ }
113+
114+ TEST (MainloopIntelW8A8_Special, TensorParallel) {
115+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
116+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(4096 , 1024 , 4096 , 1 , 1.0 , 0.0 ));
117+ }
118+
119+ TEST (MainloopIntelW8A8_Special, ModelParallel) {
120+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
121+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(1024 , 4096 , 4096 , 1 , 1.0 , 0.0 ));
122+ }
123+
124+ TEST (MainloopIntelW8A8_Special, MicroBatch) {
125+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
126+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(128 , 128 , 8192 , 4 , 1.0 , 0.0 ));
127+ }
128+
129+ TEST (MainloopIntelW8A8_Special, LargeBatch) {
130+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
131+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(512 , 512 , 2048 , 32 , 1.0 , 0.0 ));
132+ }
133+
134+ TEST (MainloopIntelW8A8_Special, SquareSmall) {
135+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
136+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(64 , 64 , 64 , 1 , 1.0 , 0.0 ));
137+ }
138+
139+ TEST (MainloopIntelW8A8_Special, SquareMedium) {
140+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
141+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(512 , 512 , 512 , 1 , 1.0 , 0.0 ));
142+ }
143+
144+ TEST (MainloopIntelW8A8_Special, SquareLarge) {
145+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
146+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(2048 , 2048 , 2048 , 1 , 1.0 , 0.0 ));
147+ }
148+
149+ TEST (MainloopIntelW8A8_Special, TallMatrix) {
150+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
151+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(4096 , 512 , 4096 , 1 , 1.0 , 0.0 ));
152+ }
153+
154+ TEST (MainloopIntelW8A8_Special, WideMatrix) {
155+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
156+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(512 , 4096 , 4096 , 1 , 1.0 , 0.0 ));
157+ }
158+
159+ TEST (MainloopIntelW8A8_Special, Batch8) {
160+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
161+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(512 , 512 , 512 , 8 , 1.0 , 0.0 ));
162+ }
163+
164+ TEST (MainloopIntelW8A8_Special, Batch16Large) {
165+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
166+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(1024 , 1024 , 1024 , 16 , 1.0 , 0.0 ));
167+ }
168+
169+ TEST (MainloopIntelW8A8_Special, LargeK) {
170+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
171+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(64 , 64 , 8192 , 1 , 1.0 , 0.0 ));
172+ }
173+
174+ TEST (MainloopIntelW8A8_Special, LargeN) {
175+ using Gemm = typename MainloopIntelW8A8_GemmConfig<layout::RowMajor, layout::RowMajor>::Gemm;
176+ EXPECT_TRUE (test::gemm::device::TestXe<Gemm>(64 , 8192 , 64 , 1 , 1.0 , 0.0 ));
177+ }
178+
179+ } // namespace
0 commit comments