Skip to content

Commit a46b725

Browse files
mkumar16-amd, illsilin, ThomasNing, AviralGoelAMD
authored
Added Support for tile_grouped_gemm_preshuffle example (#2993)
* Added Support for tile_grouped_gemm_preshuffle example * Resolved PR comments + Added unit tests for preshuffle with persistent * Fixed CMake Build config error * Fix clang error that caused CI to fail * Fix clang formatting * Fix clang issue * Fix errors causing test cases to fail * Fix grouped_gemm_preshuffle unit test failure * Resolve PR comments * Cleaned code + removed unnecessary changes * Update test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp Co-authored-by: Aviral Goel <[email protected]> * Fix clang formatting * Made changes to improve code readability --------- Co-authored-by: Illia Silin <[email protected]> Co-authored-by: Thomas Ning <[email protected]> Co-authored-by: Aviral Goel <[email protected]>
1 parent 6c2ca12 commit a46b725

File tree

5 files changed

+278
-33
lines changed

5 files changed

+278
-33
lines changed

example/ck_tile/17_grouped_gemm/grouped_gemm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
182182
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
183183
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
184184
static constexpr bool Preshuffle = true;
185+
static constexpr bool Persistent = true;
185186
static constexpr bool DoubleSmemBuffer = true;
186187
};
187188

example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,113 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
167167
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
168168
}
169169

170+
template <typename GemmConfig,
171+
typename ALayout,
172+
typename BLayout,
173+
typename CLayout,
174+
typename ADataType,
175+
typename BDataType,
176+
typename AccDataType,
177+
typename CDataType>
178+
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
179+
const ck_tile::index_t num_groups,
180+
void* kargs_ptr,
181+
bool splitk)
182+
{
183+
using GemmShape = ck_tile::TileGemmShape<
184+
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
185+
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
186+
ck_tile::
187+
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
188+
using TilePartitioner =
189+
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
190+
GemmConfig::TileParitionerGroupNum,
191+
GemmConfig::TileParitionerM01>;
192+
193+
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
194+
GemmConfig::kPadN,
195+
GemmConfig::kPadK,
196+
GemmConfig::DoubleSmemBuffer,
197+
ALayout,
198+
BLayout,
199+
CLayout,
200+
GemmConfig::TransposeC,
201+
GemmConfig::UseStructuredSparsity,
202+
GemmConfig::Persistent,
203+
GemmConfig::NumWaveGroups,
204+
GemmConfig::Preshuffle>;
205+
206+
float ave_time{0};
207+
208+
const auto Run = [&](const auto memory_operation_) {
209+
constexpr auto scheduler = GemmConfig::Scheduler;
210+
constexpr auto memory_operation = memory_operation_.value;
211+
212+
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
213+
BDataType,
214+
AccDataType,
215+
GemmShape,
216+
GemmUniversalTraits,
217+
scheduler>;
218+
219+
using GemmPipeline = typename PipelineTypeTraits<
220+
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
221+
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
222+
ADataType,
223+
BDataType,
224+
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
225+
AccDataType,
226+
CDataType,
227+
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
228+
CLayout,
229+
ck_tile::element_wise::PassThrough,
230+
TilePartitioner::MPerBlock,
231+
TilePartitioner::NPerBlock,
232+
GemmConfig::M_Warp,
233+
GemmConfig::N_Warp,
234+
GemmConfig::M_Warp_Tile,
235+
GemmConfig::N_Warp_Tile,
236+
GemmConfig::K_Warp_Tile,
237+
UniversalGemmProblem::TransposeC,
238+
memory_operation>>;
239+
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
240+
const dim3 blocks = Kernel::BlockSize();
241+
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
242+
243+
if(s.log_level_ > 0)
244+
{
245+
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
246+
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
247+
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
248+
}
249+
250+
ave_time =
251+
ck_tile::launch_kernel(s,
252+
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
253+
Kernel{},
254+
grids,
255+
blocks,
256+
0,
257+
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
258+
num_groups));
259+
260+
return ave_time;
261+
};
262+
263+
if(splitk)
264+
{
265+
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
266+
ck_tile::memory_operation_enum::atomic_add>{});
267+
}
268+
else
269+
{
270+
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
271+
ck_tile::memory_operation_enum::set>{});
272+
}
273+
274+
return ave_time;
275+
}
276+
170277
#include "run_grouped_gemm_example.inc"
171278

172279
template <typename GemmConfig, typename PrecType>

example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup,
7070
}
7171
else
7272
{
73-
if(GemmConfig::Preshuffle)
74-
{
75-
// not supported yet
76-
throw std::runtime_error(
77-
"Persistent grouped gemm with preshuffle is not supported yet");
78-
}
79-
80-
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
81-
//
82-
//
83-
// the gemm problems known on the host.
84-
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
85-
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
86-
// In this example however, we generate the `kargs` using the known gemm_descs,
87-
// and copy the gemm descriptions to the device memory.
88-
// The contents of the memory pointed to by `kargs_ptr` pointer could be
89-
// written by e.g. another kernel from earlier stage.
73+
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm
74+
// problems known on the host. Instead, we can just pass the pointer to the kernel and let
75+
// the workgroups figure out which tiles to work on. This is useful when the gemm problems
76+
// are generated dynamically. In this example however, we generate the `kargs` using the
77+
// known gemm_descs, and copy the gemm descriptions to the device memory. The contents of
78+
// the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from
79+
// earlier stage.
9080

9181
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
9282
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();

test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
#include "ck_tile/host.hpp"
99
#include "test_grouped_gemm_preshuffle_util.hpp"
1010

11-
using F16 = ck_tile::half_t;
12-
using F8 = ck_tile::fp8_t;
13-
using F32 = float;
14-
using Row = ck_tile::tensor_layout::gemm::RowMajor;
15-
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
11+
using F16 = ck_tile::half_t;
12+
using F8 = ck_tile::fp8_t;
13+
using F32 = float;
14+
using Row = ck_tile::tensor_layout::gemm::RowMajor;
15+
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
16+
using False = std::false_type;
17+
using True = std::true_type;
1618

1719
// Custom tuple-like structure for kernel configuration
1820
template <typename ALayout_,
@@ -22,6 +24,7 @@ template <typename ALayout_,
2224
typename BDataType_,
2325
typename AccDataType_,
2426
typename CDataType_,
27+
typename Persistent_,
2528
int M_Tile_val_,
2629
int N_Tile_val_,
2730
int K_Tile_val_,
@@ -35,6 +38,7 @@ struct KernelConfig
3538
using BDataType = BDataType_;
3639
using AccDataType = AccDataType_;
3740
using CDataType = CDataType_;
41+
using Persistent = Persistent_;
3842

3943
static constexpr int M_Tile_ = M_Tile_val_;
4044
static constexpr int N_Tile_ = N_Tile_val_;
@@ -44,11 +48,16 @@ struct KernelConfig
4448

4549
// clang-format off
4650
using KernelTypes = ::testing::Types<
47-
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu
48-
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>,
49-
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>,
50-
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>,
51-
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2>
51+
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu
52+
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
53+
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>,
54+
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
55+
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>,
56+
57+
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
58+
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
59+
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
60+
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>
5261
>;
5362
// clang-format on
5463

test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp

Lines changed: 144 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,13 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
3939
using BDataType = typename Tuple::BDataType;
4040
using AccDataType = typename Tuple::AccDataType;
4141
using CDataType = typename Tuple::CDataType;
42-
using PrecType = BDataType;
43-
using DsLayout = ck_tile::tuple<>; // not used
44-
using DsDataType = ck_tile::tuple<>; // not used
42+
43+
using DsLayout = ck_tile::tuple<>; // not used
44+
using DsDataType = ck_tile::tuple<>; // not used
45+
46+
// Get the persistent value from ck_tile::bool_constant
47+
using PersistentType = typename Tuple::Persistent;
48+
static constexpr bool Persistent = PersistentType::value;
4549

4650
static const bool kPadM = false;
4751
static const bool kPadN = false;
@@ -231,6 +235,129 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
231235
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
232236
}
233237

238+
private:
239+
template <typename ALayout, typename BLayout, typename CLayout>
240+
void invoke_grouped_gemm_persistent(const std::vector<grouped_gemm_kargs>& gemm_descs,
241+
const ck_tile::stream_config& s,
242+
void* kargs_ptr)
243+
{
244+
using GemmShape =
245+
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
246+
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
247+
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
248+
using TilePartitioner = ck_tile::
249+
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
250+
251+
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
252+
253+
// Enable persistent mode for preshuffle
254+
using GemmUniversalTraits =
255+
ck_tile::TileGemmUniversalTraits</*kPadM*/ true,
256+
/*kPadN*/ true,
257+
/*kPadK*/ true,
258+
DoubleSmemBuffer,
259+
ALayout,
260+
BLayout,
261+
CLayout,
262+
TransposeC,
263+
/*UseStructuredSparsity*/ false,
264+
/*Persistent*/ true, // Enable persistent mode
265+
/*NumWaveGroups*/ 1,
266+
/*Preshuffle*/ true>;
267+
using GemmPipelineProblem =
268+
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
269+
270+
using BaseGemmPipeline =
271+
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
272+
273+
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
274+
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
275+
const ck_tile::index_t num_loop =
276+
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
277+
TileParitionerGroupNum,
278+
TileParitionerM01>::GetLoopNum(K_split);
279+
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
280+
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
281+
282+
float ave_time{0};
283+
284+
const auto Run = [&](const auto has_hot_loop_,
285+
const auto tail_number_,
286+
const auto memory_operation_) {
287+
constexpr bool has_hot_loop_v = has_hot_loop_.value;
288+
constexpr auto tail_number_v = tail_number_.value;
289+
constexpr auto memory_operation = memory_operation_.value;
290+
using UniversalGemmProblem =
291+
ck_tile::UniversalGemmPipelineProblem<ADataType,
292+
BDataType,
293+
AccDataType,
294+
GemmShape,
295+
GemmUniversalTraits,
296+
ck_tile::GemmPipelineScheduler::Default,
297+
has_hot_loop_v,
298+
tail_number_v>;
299+
using GemmPipeline =
300+
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
301+
using GemmEpilogue = ck_tile::CShuffleEpilogue<
302+
ck_tile::CShuffleEpilogueProblem<ADataType,
303+
BDataType,
304+
DsDataType,
305+
AccDataType,
306+
CDataType,
307+
DsLayout,
308+
CLayout,
309+
ck_tile::element_wise::PassThrough,
310+
TilePartitioner::MPerBlock,
311+
TilePartitioner::NPerBlock,
312+
M_Warp,
313+
N_Warp,
314+
M_Warp_Tile,
315+
N_Warp_Tile,
316+
K_Warp_Tile,
317+
UniversalGemmProblem::TransposeC,
318+
memory_operation>>;
319+
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
320+
auto kargs = Kernel::MakeKargs(gemm_descs);
321+
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
322+
const dim3 grids = Kernel::GridSize(gemm_descs);
323+
const dim3 blocks = Kernel::BlockSize();
324+
325+
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
326+
kargs.data(),
327+
get_workspace_size(gemm_descs),
328+
hipMemcpyHostToDevice,
329+
s.stream_id_));
330+
331+
ave_time = ck_tile::launch_kernel(
332+
s,
333+
ck_tile::make_kernel<kBlockPerCu>(
334+
Kernel{},
335+
grids,
336+
blocks,
337+
0,
338+
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
339+
gemm_descs.size()));
340+
return ave_time;
341+
};
342+
343+
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
344+
if(gemm_descs[0].k_batch == 1)
345+
{
346+
Run(has_hot_loop_,
347+
tail_number_,
348+
ck_tile::integral_constant<ck_tile::memory_operation_enum,
349+
ck_tile::memory_operation_enum::set>{});
350+
}
351+
else
352+
{
353+
// EXPECT TO FAIL because splitk is not supported
354+
EXPECT_FALSE(true);
355+
}
356+
};
357+
358+
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
359+
}
360+
234361
public:
235362
void Run(const std::vector<int>& Ms,
236363
const std::vector<int>& Ns,
@@ -350,9 +477,20 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
350477
ck_tile::DeviceMem gemm_workspace;
351478
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
352479

353-
invoke_grouped_gemm<ALayout, BLayout, CLayout>(gemm_descs,
354-
ck_tile::stream_config{nullptr, false, 1},
355-
gemm_workspace.GetDeviceBuffer());
480+
if constexpr(Persistent)
481+
{
482+
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
483+
gemm_descs,
484+
ck_tile::stream_config{nullptr, false, 1},
485+
gemm_workspace.GetDeviceBuffer());
486+
}
487+
else
488+
{
489+
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
490+
gemm_descs,
491+
ck_tile::stream_config{nullptr, false, 1},
492+
gemm_workspace.GetDeviceBuffer());
493+
}
356494

357495
// Copy results back to host for validation
358496
for(int i = 0; i < group_count; i++)

0 commit comments

Comments
 (0)