1212#include " ck_tile/ops/gemm.hpp"
1313#include " ck_tile/utility/json_dump.hpp"
1414
15- #define CK_TILE_PIPELINE_COMPUTE_V3 1
16- #define CK_TILE_PIPELINE_MEMORY 2
17- #define CK_TILE_PIPELINE_COMPUTE_V4 3
18- #define CK_TILE_PIPELINE_COMPUTE_V5 4
19- #define CK_TILE_PIPELINE_COMPUTE_V6 5
20- #define CK_TILE_PIPELINE_PRESHUFFLE_V1 6
21- #define CK_TILE_PIPELINE_PRESHUFFLE_V2 7
22-
2315template <typename PrecType, ck_tile::index_t M_Warp_Tile>
2416constexpr ck_tile::index_t get_k_warp_tile ()
2517{
@@ -70,7 +62,7 @@ struct GemmConfigBase
7062 static constexpr ck_tile::index_t TileParitionerGroupNum = 8 ;
7163 static constexpr ck_tile::index_t TileParitionerM01 = 4 ;
7264 static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
73- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
65+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
7466 static constexpr ck_tile::index_t NumWaveGroups = 1 ;
7567 static constexpr bool Preshuffle = false ;
7668 static constexpr bool TiledMMAPermuteN = false ;
@@ -92,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
9284 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
9385 static constexpr ck_tile::index_t K_Warp_Tile = sizeof (PrecType) == 2 ? 8 : 16 ;
9486
95- static constexpr bool DoubleSmemBuffer = false ;
96- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY ;
97- static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
87+ static constexpr bool DoubleSmemBuffer = false ;
88+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY ;
89+ static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
9890};
9991
10092template <typename PrecType>
@@ -112,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
112104 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
113105 static constexpr ck_tile::index_t K_Warp_Tile = sizeof (PrecType) == 2 ? 8 : 16 ;
114106
115- static constexpr bool DoubleSmemBuffer = false ;
116- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY ;
107+ static constexpr bool DoubleSmemBuffer = false ;
108+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY ;
117109};
118110
119111template <typename PrecType>
@@ -132,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
132124 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
133125 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
134126
135- static constexpr bool DoubleSmemBuffer = false ;
136- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
127+ static constexpr bool DoubleSmemBuffer = false ;
128+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
137129};
138130
139131template <typename PrecType>
@@ -151,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
151143 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
152144 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
153145
154- static constexpr bool DoubleSmemBuffer = false ;
155- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
146+ static constexpr bool DoubleSmemBuffer = false ;
147+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
156148};
157149
158150template <typename PrecType>
@@ -170,8 +162,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
170162 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
171163 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
172164
173- static constexpr bool DoubleSmemBuffer = false ;
174- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
165+ static constexpr bool DoubleSmemBuffer = false ;
166+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
175167
176168 static constexpr int kBlockPerCu = 2 ;
177169};
@@ -191,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
191183 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
192184 static constexpr ck_tile::index_t K_Warp_Tile = 16 ;
193185
194- static constexpr bool DoubleSmemBuffer = false ;
195- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3 ;
186+ static constexpr bool DoubleSmemBuffer = false ;
187+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3 ;
196188
197189 static constexpr int kBlockPerCu = 2 ;
198190};
@@ -214,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
214206 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
215207 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
216208
217- static constexpr bool DoubleSmemBuffer = true ;
218- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4 ;
209+ static constexpr bool DoubleSmemBuffer = true ;
210+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4 ;
219211};
220212
221213template <typename PrecType>
@@ -233,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
233225 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
234226 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
235227
236- static constexpr bool DoubleSmemBuffer = true ;
237- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4 ;
228+ static constexpr bool DoubleSmemBuffer = true ;
229+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4 ;
238230};
239231
240232template <typename PrecType>
@@ -253,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
253245 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
254246
255247 static constexpr bool DoubleSmemBuffer = false ;
256- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5 ;
248+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5 ;
257249 static constexpr ck_tile::index_t NumWaveGroups = 2 ;
258250};
259251
@@ -273,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase
273265 static constexpr ck_tile::index_t K_Warp_Tile = 16 ;
274266
275267 static constexpr bool DoubleSmemBuffer = false ;
276- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6 ;
268+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6 ;
277269 static constexpr ck_tile::index_t NumWaveGroups = 1 ;
278270};
279271
@@ -292,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
292284 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
293285 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
294286
295- static constexpr int kBlockPerCu = 1 ;
296- static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
297- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2 ;
298- static constexpr bool Preshuffle = true ;
299- static constexpr bool DoubleSmemBuffer = true ;
300- static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
301- static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
287+ static constexpr int kBlockPerCu = 1 ;
288+ static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
289+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2 ;
290+ static constexpr bool Preshuffle = true ;
291+ static constexpr bool DoubleSmemBuffer = true ;
292+ static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
293+ static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
302294};
303295
304296template <typename PrecType>
@@ -316,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
316308 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
317309 static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
318310
319- static constexpr int kBlockPerCu = 2 ;
320- static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
321- static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2 ;
322- static constexpr bool Preshuffle = true ;
323- static constexpr bool DoubleSmemBuffer = true ;
324- static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
325- static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
311+ static constexpr int kBlockPerCu = 2 ;
312+ static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
313+ static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2 ;
314+ static constexpr bool Preshuffle = true ;
315+ static constexpr bool DoubleSmemBuffer = true ;
316+ static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
317+ static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0 ;
326318};
327319
328320template <typename PrecType>
@@ -466,11 +458,11 @@ struct DataTypeTraits<ck_tile::int8_t>
466458 static constexpr const char * name = " int8" ;
467459};
468460
469- template <ck_tile::index_t PipelineId>
461+ template <ck_tile::GemmPipeline PipelineId>
470462struct PipelineTypeTraits ;
471463
472464template <>
473- struct PipelineTypeTraits <CK_TILE_PIPELINE_MEMORY >
465+ struct PipelineTypeTraits <ck_tile::GemmPipeline::MEMORY >
474466{
475467 template <typename PipelineProblem>
476468 using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -479,7 +471,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
479471};
480472
481473template <>
482- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V3 >
474+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V3 >
483475{
484476 template <typename PipelineProblem>
485477 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -488,7 +480,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
488480};
489481
490482template <>
491- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V4 >
483+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V4 >
492484{
493485 template <typename PipelineProblem>
494486 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
@@ -497,7 +489,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
497489};
498490
499491template <>
500- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V5 >
492+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V5 >
501493{
502494 template <typename PipelineProblem>
503495 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
@@ -506,7 +498,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
506498};
507499
508500template <>
509- struct PipelineTypeTraits <CK_TILE_PIPELINE_COMPUTE_V6 >
501+ struct PipelineTypeTraits <ck_tile::GemmPipeline::COMPUTE_V6 >
510502{
511503 template <typename PipelineProblem>
512504 using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
@@ -515,7 +507,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
515507};
516508
517509template <>
518- struct PipelineTypeTraits <CK_TILE_PIPELINE_PRESHUFFLE_V1 >
510+ struct PipelineTypeTraits <ck_tile::GemmPipeline::PRESHUFFLE_V1 >
519511{
520512 template <typename PipelineProblem>
521513 using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
@@ -525,7 +517,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
525517};
526518
527519template <>
528- struct PipelineTypeTraits <CK_TILE_PIPELINE_PRESHUFFLE_V2 >
520+ struct PipelineTypeTraits <ck_tile::GemmPipeline::PRESHUFFLE_V2 >
529521{
530522 template <typename PipelineProblem>
531523 using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
0 commit comments