1212#include " ck_tile/ops/gemm.hpp"
1313#include " ck_tile/utility/json_dump.hpp"
1414
15- #define CK_TILE_PIPELINE_COMPUTE_V3 1
16- #define CK_TILE_PIPELINE_MEMORY 2
17- #define CK_TILE_PIPELINE_COMPUTE_V4 3
18- #define CK_TILE_PIPELINE_COMPUTE_V5 4
19- #define CK_TILE_PIPELINE_COMPUTE_V6 5
20- #define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
21-
2215template <typename PrecType, ck_tile::index_t M_Warp_Tile>
2316constexpr ck_tile::index_t get_k_warp_tile ()
2417{
@@ -69,7 +62,7 @@ struct GemmConfigBase
6962 static constexpr ck_tile::index_t TileParitionerGroupNum = 8 ;
7063 static constexpr ck_tile::index_t TileParitionerM01 = 4 ;
7164 static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
72- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
65+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
7366 static constexpr ck_tile::index_t NumWaveGroups = 1 ;
7467 static constexpr bool Preshuffle = false ;
7568 static constexpr bool TiledMMAPermuteN = false ;
@@ -91,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
9184 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
9285 static constexpr ck_tile::index_t K_Warp_Tile = sizeof (PrecType) == 2 ? 8 : 16 ;
9386
94- static constexpr bool DoubleSmemBuffer = false ;
95- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY ;
96- static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
87+ static constexpr bool DoubleSmemBuffer = false ;
88+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY ;
89+ static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
9790};
9891
9992template <typename PrecType>
@@ -111,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
111104 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
112105 static constexpr ck_tile::index_t K_Warp_Tile = sizeof (PrecType) == 2 ? 8 : 16 ;
113106
114- static constexpr bool DoubleSmemBuffer = false ;
115- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY ;
107+ static constexpr bool DoubleSmemBuffer = false ;
108+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY ;
116109};
117110
118111template <typename PrecType>
@@ -131,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
131124 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
132125 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
133126
134- static constexpr bool DoubleSmemBuffer = false ;
135- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
127+ static constexpr bool DoubleSmemBuffer = false ;
128+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
136129};
137130
138131template <typename PrecType>
@@ -150,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
150143 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
151144 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
152145
153- static constexpr bool DoubleSmemBuffer = false ;
154- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
146+ static constexpr bool DoubleSmemBuffer = false ;
147+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
155148};
156149
157150template <typename PrecType>
@@ -169,8 +162,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
169162 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
170163 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
171164
172- static constexpr bool DoubleSmemBuffer = false ;
173- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
165+ static constexpr bool DoubleSmemBuffer = false ;
166+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
174167
175168 static constexpr int kBlockPerCu = 2 ;
176169};
@@ -190,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
190183 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
191184 static constexpr ck_tile::index_t K_Warp_Tile = 16 ;
192185
193- static constexpr bool DoubleSmemBuffer = false ;
194- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
186+ static constexpr bool DoubleSmemBuffer = false ;
187+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
195188
196189 static constexpr int kBlockPerCu = 2 ;
197190};
@@ -213,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
213206 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
214207 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
215208
216- static constexpr bool DoubleSmemBuffer = true ;
217- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4 ;
209+ static constexpr bool DoubleSmemBuffer = true ;
210+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4 ;
218211};
219212
220213template <typename PrecType>
@@ -232,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
232225 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
233226 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
234227
235- static constexpr bool DoubleSmemBuffer = true ;
236- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4 ;
228+ static constexpr bool DoubleSmemBuffer = true ;
229+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4 ;
237230};
238231
239232template <typename PrecType>
@@ -252,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
252245 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
253246
254247 static constexpr bool DoubleSmemBuffer = false ;
255- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5 ;
248+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5 ;
256249 static constexpr ck_tile::index_t NumWaveGroups = 2 ;
257250};
258251
@@ -272,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase
272265 static constexpr ck_tile::index_t K_Warp_Tile = 16 ;
273266
274267 static constexpr bool DoubleSmemBuffer = false ;
275- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6 ;
268+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6 ;
276269 static constexpr ck_tile::index_t NumWaveGroups = 1 ;
277270};
278271
@@ -291,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
291284 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
292285 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
293286
294- static constexpr int kBlockPerCu = 1 ;
295- static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
296- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2 ;
297- static constexpr bool Preshuffle = true ;
298- static constexpr bool DoubleSmemBuffer = true ;
299- static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
300- static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
287+ static constexpr int kBlockPerCu = 1 ;
288+ static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
289+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2 ;
290+ static constexpr bool Preshuffle = true ;
291+ static constexpr bool DoubleSmemBuffer = true ;
292+ static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
293+ static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
301294};
302295
303296template <typename PrecType>
@@ -315,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
315308 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
316309 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
317310
318- static constexpr int kBlockPerCu = 2 ;
319- static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
320- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2 ;
321- static constexpr bool Preshuffle = true ;
322- static constexpr bool DoubleSmemBuffer = true ;
323- static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
324- static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
311+ static constexpr int kBlockPerCu = 2 ;
312+ static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
313+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2 ;
314+ static constexpr bool Preshuffle = true ;
315+ static constexpr bool DoubleSmemBuffer = true ;
316+ static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
317+ static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
325318};
326319
327320template <typename PrecType>
@@ -465,11 +458,11 @@ struct DataTypeTraits<ck_tile::int8_t>
465458 static constexpr const char * name = " int8" ;
466459};
467460
468- template <ck_tile::index_t PipelineId>
461+ template <ck_tile::GemmPipeline PipelineId>
469462struct PipelineTypeTraits ;
470463
471464template <>
472- struct PipelineTypeTraits <CK_TILE_PIPELINE_MEMORY >
465+ struct PipelineTypeTraits <ck_tile::GemmPipeline::MEMORY >
473466{
474467 template <typename PipelineProblem>
475468 using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -478,7 +471,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
478471};
479472
480473template <>
481- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V3 >
474+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V3 >
482475{
483476 template <typename PipelineProblem>
484477 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -487,7 +480,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
487480};
488481
489482template <>
490- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V4 >
483+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V4 >
491484{
492485 template <typename PipelineProblem>
493486 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
@@ -496,7 +489,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
496489};
497490
498491template <>
499- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V5 >
492+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V5 >
500493{
501494 template <typename PipelineProblem>
502495 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
@@ -505,7 +498,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
505498};
506499
507500template <>
508- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V6 >
501+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V6 >
509502{
510503 template <typename PipelineProblem>
511504 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
@@ -514,7 +507,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
514507};
515508
516509template <>
517- struct PipelineTypeTraits <CK_TILE_PIPELINE_PRESHUFFLE_V2 >
510+ struct PipelineTypeTraits <ck_tile::GemmPipeline::PRESHUFFLE_V2 >
518511{
519512 template <typename PipelineProblem>
520513 using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
0 commit comments