Skip to content

Commit 7579fa6

Browse files
committed
Replace CK_TILE_PIPELINE macros with a common enum
This change replaces the per-example pipeline macros (CK_TILE_PIPELINE_COMPUTE_V3, CK_TILE_PIPELINE_MEMORY, etc.) in the CK Tile examples with a common enum called ck_tile::GemmPipeline, to reduce code duplication.
1 parent 73f6378 commit 7579fa6

File tree

13 files changed

+220
-238
lines changed

13 files changed

+220
-238
lines changed

example/ck_tile/03_gemm/gemm_basic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
6868
else if(data_type == "pk_int4_t")
6969
{
7070
// TODO: Add support for bhalf_t ADataType
71-
if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
71+
if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
7272
{
7373
return run_gemm_example_prec_type<GemmConfig,
7474
Invoker,

example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
962962
else if(data_type == "pk_int4_t")
963963
{
964964
// TODO: Add support for bhalf_t ADataType
965-
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
965+
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
966966
{
967967
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
968968
ck_tile::half_t,

example/ck_tile/03_gemm/gemm_utils.hpp

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212
#include "ck_tile/ops/gemm.hpp"
1313
#include "ck_tile/utility/json_dump.hpp"
1414

15-
#define CK_TILE_PIPELINE_COMPUTE_V3 1
16-
#define CK_TILE_PIPELINE_MEMORY 2
17-
#define CK_TILE_PIPELINE_COMPUTE_V4 3
18-
#define CK_TILE_PIPELINE_COMPUTE_V5 4
19-
#define CK_TILE_PIPELINE_COMPUTE_V6 5
20-
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
21-
2215
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
2316
constexpr ck_tile::index_t get_k_warp_tile()
2417
{
@@ -69,7 +62,7 @@ struct GemmConfigBase
6962
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
7063
static constexpr ck_tile::index_t TileParitionerM01 = 4;
7164
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
72-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
65+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
7366
static constexpr ck_tile::index_t NumWaveGroups = 1;
7467
static constexpr bool Preshuffle = false;
7568
static constexpr bool TiledMMAPermuteN = false;
@@ -91,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
9184
static constexpr ck_tile::index_t N_Warp_Tile = 32;
9285
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
9386

94-
static constexpr bool DoubleSmemBuffer = false;
95-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
96-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
87+
static constexpr bool DoubleSmemBuffer = false;
88+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
89+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
9790
};
9891

9992
template <typename PrecType>
@@ -111,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
111104
static constexpr ck_tile::index_t N_Warp_Tile = 32;
112105
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
113106

114-
static constexpr bool DoubleSmemBuffer = false;
115-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
107+
static constexpr bool DoubleSmemBuffer = false;
108+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
116109
};
117110

118111
template <typename PrecType>
@@ -131,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
131124
static constexpr ck_tile::index_t N_Warp_Tile = 16;
132125
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
133126

134-
static constexpr bool DoubleSmemBuffer = false;
135-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
127+
static constexpr bool DoubleSmemBuffer = false;
128+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
136129
};
137130

138131
template <typename PrecType>
@@ -150,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
150143
static constexpr ck_tile::index_t N_Warp_Tile = 32;
151144
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
152145

153-
static constexpr bool DoubleSmemBuffer = false;
154-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
146+
static constexpr bool DoubleSmemBuffer = false;
147+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
155148
};
156149

157150
template <typename PrecType>
@@ -169,8 +162,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
169162
static constexpr ck_tile::index_t N_Warp_Tile = 16;
170163
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
171164

172-
static constexpr bool DoubleSmemBuffer = false;
173-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
165+
static constexpr bool DoubleSmemBuffer = false;
166+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
174167

175168
static constexpr int kBlockPerCu = 2;
176169
};
@@ -190,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
190183
static constexpr ck_tile::index_t N_Warp_Tile = 16;
191184
static constexpr ck_tile::index_t K_Warp_Tile = 16;
192185

193-
static constexpr bool DoubleSmemBuffer = false;
194-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
186+
static constexpr bool DoubleSmemBuffer = false;
187+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
195188

196189
static constexpr int kBlockPerCu = 2;
197190
};
@@ -213,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
213206
static constexpr ck_tile::index_t N_Warp_Tile = 32;
214207
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
215208

216-
static constexpr bool DoubleSmemBuffer = true;
217-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
209+
static constexpr bool DoubleSmemBuffer = true;
210+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
218211
};
219212

220213
template <typename PrecType>
@@ -232,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
232225
static constexpr ck_tile::index_t N_Warp_Tile = 32;
233226
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
234227

235-
static constexpr bool DoubleSmemBuffer = true;
236-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
228+
static constexpr bool DoubleSmemBuffer = true;
229+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
237230
};
238231

239232
template <typename PrecType>
@@ -252,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
252245
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
253246

254247
static constexpr bool DoubleSmemBuffer = false;
255-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
248+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
256249
static constexpr ck_tile::index_t NumWaveGroups = 2;
257250
};
258251

@@ -272,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase
272265
static constexpr ck_tile::index_t K_Warp_Tile = 16;
273266

274267
static constexpr bool DoubleSmemBuffer = false;
275-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6;
268+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
276269
static constexpr ck_tile::index_t NumWaveGroups = 1;
277270
};
278271

@@ -291,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
291284
static constexpr ck_tile::index_t N_Warp_Tile = 16;
292285
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
293286

294-
static constexpr int kBlockPerCu = 1;
295-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
296-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
297-
static constexpr bool Preshuffle = true;
298-
static constexpr bool DoubleSmemBuffer = true;
299-
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
300-
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
287+
static constexpr int kBlockPerCu = 1;
288+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
289+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
290+
static constexpr bool Preshuffle = true;
291+
static constexpr bool DoubleSmemBuffer = true;
292+
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
293+
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
301294
};
302295

303296
template <typename PrecType>
@@ -315,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
315308
static constexpr ck_tile::index_t N_Warp_Tile = 16;
316309
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
317310

318-
static constexpr int kBlockPerCu = 2;
319-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
320-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
321-
static constexpr bool Preshuffle = true;
322-
static constexpr bool DoubleSmemBuffer = true;
323-
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
324-
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
311+
static constexpr int kBlockPerCu = 2;
312+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
313+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
314+
static constexpr bool Preshuffle = true;
315+
static constexpr bool DoubleSmemBuffer = true;
316+
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
317+
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
325318
};
326319

327320
template <typename PrecType>
@@ -465,11 +458,11 @@ struct DataTypeTraits<ck_tile::int8_t>
465458
static constexpr const char* name = "int8";
466459
};
467460

468-
template <ck_tile::index_t PipelineId>
461+
template <ck_tile::GemmPipeline PipelineId>
469462
struct PipelineTypeTraits;
470463

471464
template <>
472-
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
465+
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
473466
{
474467
template <typename PipelineProblem>
475468
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -478,7 +471,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
478471
};
479472

480473
template <>
481-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
474+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
482475
{
483476
template <typename PipelineProblem>
484477
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -487,7 +480,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
487480
};
488481

489482
template <>
490-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
483+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
491484
{
492485
template <typename PipelineProblem>
493486
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
@@ -496,7 +489,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
496489
};
497490

498491
template <>
499-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
492+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
500493
{
501494
template <typename PipelineProblem>
502495
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
@@ -505,7 +498,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
505498
};
506499

507500
template <>
508-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
501+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
509502
{
510503
template <typename PipelineProblem>
511504
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
@@ -514,7 +507,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
514507
};
515508

516509
template <>
517-
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
510+
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
518511
{
519512
template <typename PipelineProblem>
520513
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;

example/ck_tile/03_gemm/universal_gemm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
5858
else if(data_type == "fp16i4")
5959
{
6060
// TODO: Add support for bhalf_t ADataType
61-
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
61+
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
6262
{
6363
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
6464
Invoker,
@@ -73,7 +73,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
7373
}
7474
else if(data_type == "fp8i4")
7575
{
76-
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
76+
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
7777
{
7878
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
7979
Invoker,
@@ -88,7 +88,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
8888
}
8989
else if(data_type == "bf8i4")
9090
{
91-
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
91+
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
9292
{
9393
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
9494
Invoker,

example/ck_tile/16_batched_gemm/batched_gemm.hpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
1212
#include "ck_tile/utility/json_dump.hpp"
1313

14-
#define CK_TILE_PIPELINE_COMPUTE_V3 1
15-
#define CK_TILE_PIPELINE_MEMORY 2
16-
#define CK_TILE_PIPELINE_COMPUTE_V4 3
17-
1814
struct GemmConfigMemory
1915
{
2016
// Memory friendly for Interwave scheduler
@@ -30,9 +26,9 @@ struct GemmConfigMemory
3026
static constexpr ck_tile::index_t N_Warp_Tile = 32;
3127
static constexpr ck_tile::index_t K_Warp_Tile = 8;
3228

33-
static constexpr bool DoubleSmemBuffer = false;
34-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
35-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
29+
static constexpr bool DoubleSmemBuffer = false;
30+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
31+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
3632
};
3733

3834
struct GemmConfigV3
@@ -50,9 +46,9 @@ struct GemmConfigV3
5046
static constexpr ck_tile::index_t N_Warp_Tile = 32;
5147
static constexpr ck_tile::index_t K_Warp_Tile = 16;
5248

53-
static constexpr bool DoubleSmemBuffer = false;
54-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
55-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
49+
static constexpr bool DoubleSmemBuffer = false;
50+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
51+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
5652
};
5753

5854
struct GemmConfigV4
@@ -71,9 +67,9 @@ struct GemmConfigV4
7167
static constexpr ck_tile::index_t N_Warp_Tile = 32;
7268
static constexpr ck_tile::index_t K_Warp_Tile = 16;
7369

74-
static constexpr bool DoubleSmemBuffer = true;
75-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
76-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
70+
static constexpr bool DoubleSmemBuffer = true;
71+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
72+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
7773
};
7874

7975
struct GemmConfigV3_Wmma
@@ -91,16 +87,16 @@ struct GemmConfigV3_Wmma
9187
static constexpr ck_tile::index_t N_Warp_Tile = 16;
9288
static constexpr ck_tile::index_t K_Warp_Tile = 16;
9389

94-
static constexpr bool DoubleSmemBuffer = false;
95-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
96-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
90+
static constexpr bool DoubleSmemBuffer = false;
91+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
92+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
9793
};
9894

99-
template <ck_tile::index_t PipelineId>
95+
template <ck_tile::GemmPipeline PipelineId>
10096
struct PipelineTypeTraits;
10197

10298
template <>
103-
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
99+
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
104100
{
105101
template <typename PipelineProblem>
106102
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -109,7 +105,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
109105
};
110106

111107
template <>
112-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
108+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
113109
{
114110
template <typename PipelineProblem>
115111
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -118,7 +114,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
118114
};
119115

120116
template <>
121-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
117+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
122118
{
123119
template <typename PipelineProblem>
124120
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;

0 commit comments

Comments
 (0)