Skip to content

Commit 1d74759

Browse files
committed
Replace CK_TILE_PIPELINE macros with a common enum
This change replaces pipeline macros like CK_TILE_PIPELINE_COMPUTE_V3, CK_TILE_PIPELINE_MEMORY, etc in the CK Tile examples with a common enum called GemmPipeline to reduce code duplication.
1 parent 8f1274d commit 1d74759

File tree

13 files changed

+222
-240
lines changed

13 files changed

+222
-240
lines changed

example/ck_tile/03_gemm/gemm_basic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
6868
else if(data_type == "pk_int4_t")
6969
{
7070
// TODO: Add support for bhalf_t ADataType
71-
if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
71+
if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
7272
{
7373
return run_gemm_example_prec_type<GemmConfig,
7474
Invoker,

example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
962962
else if(data_type == "pk_int4_t")
963963
{
964964
// TODO: Add support for bhalf_t ADataType
965-
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
965+
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
966966
{
967967
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
968968
ck_tile::half_t,

example/ck_tile/03_gemm/gemm_utils.hpp

Lines changed: 42 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,6 @@
1212
#include "ck_tile/ops/gemm.hpp"
1313
#include "ck_tile/utility/json_dump.hpp"
1414

15-
#define CK_TILE_PIPELINE_COMPUTE_V3 1
16-
#define CK_TILE_PIPELINE_MEMORY 2
17-
#define CK_TILE_PIPELINE_COMPUTE_V4 3
18-
#define CK_TILE_PIPELINE_COMPUTE_V5 4
19-
#define CK_TILE_PIPELINE_COMPUTE_V6 5
20-
#define CK_TILE_PIPELINE_PRESHUFFLE_V1 6
21-
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 7
22-
2315
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
2416
constexpr ck_tile::index_t get_k_warp_tile()
2517
{
@@ -70,7 +62,7 @@ struct GemmConfigBase
7062
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
7163
static constexpr ck_tile::index_t TileParitionerM01 = 4;
7264
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
73-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
65+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
7466
static constexpr ck_tile::index_t NumWaveGroups = 1;
7567
static constexpr bool Preshuffle = false;
7668
static constexpr bool TiledMMAPermuteN = false;
@@ -92,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
9284
static constexpr ck_tile::index_t N_Warp_Tile = 32;
9385
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
9486

95-
static constexpr bool DoubleSmemBuffer = false;
96-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
97-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
87+
static constexpr bool DoubleSmemBuffer = false;
88+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
89+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
9890
};
9991

10092
template <typename PrecType>
@@ -112,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
112104
static constexpr ck_tile::index_t N_Warp_Tile = 32;
113105
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
114106

115-
static constexpr bool DoubleSmemBuffer = false;
116-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
107+
static constexpr bool DoubleSmemBuffer = false;
108+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
117109
};
118110

119111
template <typename PrecType>
@@ -132,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
132124
static constexpr ck_tile::index_t N_Warp_Tile = 16;
133125
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
134126

135-
static constexpr bool DoubleSmemBuffer = false;
136-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
127+
static constexpr bool DoubleSmemBuffer = false;
128+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
137129
};
138130

139131
template <typename PrecType>
@@ -151,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
151143
static constexpr ck_tile::index_t N_Warp_Tile = 32;
152144
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
153145

154-
static constexpr bool DoubleSmemBuffer = false;
155-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
146+
static constexpr bool DoubleSmemBuffer = false;
147+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
156148
};
157149

158150
template <typename PrecType>
@@ -170,8 +162,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
170162
static constexpr ck_tile::index_t N_Warp_Tile = 16;
171163
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
172164

173-
static constexpr bool DoubleSmemBuffer = false;
174-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
165+
static constexpr bool DoubleSmemBuffer = false;
166+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
175167

176168
static constexpr int kBlockPerCu = 2;
177169
};
@@ -191,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
191183
static constexpr ck_tile::index_t N_Warp_Tile = 16;
192184
static constexpr ck_tile::index_t K_Warp_Tile = 16;
193185

194-
static constexpr bool DoubleSmemBuffer = false;
195-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
186+
static constexpr bool DoubleSmemBuffer = false;
187+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
196188

197189
static constexpr int kBlockPerCu = 2;
198190
};
@@ -214,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
214206
static constexpr ck_tile::index_t N_Warp_Tile = 32;
215207
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
216208

217-
static constexpr bool DoubleSmemBuffer = true;
218-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
209+
static constexpr bool DoubleSmemBuffer = true;
210+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
219211
};
220212

221213
template <typename PrecType>
@@ -233,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
233225
static constexpr ck_tile::index_t N_Warp_Tile = 32;
234226
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
235227

236-
static constexpr bool DoubleSmemBuffer = true;
237-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
228+
static constexpr bool DoubleSmemBuffer = true;
229+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
238230
};
239231

240232
template <typename PrecType>
@@ -253,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
253245
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
254246

255247
static constexpr bool DoubleSmemBuffer = false;
256-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
248+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
257249
static constexpr ck_tile::index_t NumWaveGroups = 2;
258250
};
259251

@@ -273,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase
273265
static constexpr ck_tile::index_t K_Warp_Tile = 16;
274266

275267
static constexpr bool DoubleSmemBuffer = false;
276-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6;
268+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
277269
static constexpr ck_tile::index_t NumWaveGroups = 1;
278270
};
279271

@@ -292,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
292284
static constexpr ck_tile::index_t N_Warp_Tile = 16;
293285
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
294286

295-
static constexpr int kBlockPerCu = 1;
296-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
297-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
298-
static constexpr bool Preshuffle = true;
299-
static constexpr bool DoubleSmemBuffer = true;
300-
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
301-
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
287+
static constexpr int kBlockPerCu = 1;
288+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
289+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
290+
static constexpr bool Preshuffle = true;
291+
static constexpr bool DoubleSmemBuffer = true;
292+
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
293+
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
302294
};
303295

304296
template <typename PrecType>
@@ -316,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
316308
static constexpr ck_tile::index_t N_Warp_Tile = 16;
317309
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
318310

319-
static constexpr int kBlockPerCu = 2;
320-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
321-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
322-
static constexpr bool Preshuffle = true;
323-
static constexpr bool DoubleSmemBuffer = true;
324-
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
325-
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
311+
static constexpr int kBlockPerCu = 2;
312+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
313+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
314+
static constexpr bool Preshuffle = true;
315+
static constexpr bool DoubleSmemBuffer = true;
316+
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
317+
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
326318
};
327319

328320
template <typename PrecType>
@@ -466,11 +458,11 @@ struct DataTypeTraits<ck_tile::int8_t>
466458
static constexpr const char* name = "int8";
467459
};
468460

469-
template <ck_tile::index_t PipelineId>
461+
template <ck_tile::GemmPipeline PipelineId>
470462
struct PipelineTypeTraits;
471463

472464
template <>
473-
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
465+
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
474466
{
475467
template <typename PipelineProblem>
476468
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -479,7 +471,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
479471
};
480472

481473
template <>
482-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
474+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
483475
{
484476
template <typename PipelineProblem>
485477
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -488,7 +480,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
488480
};
489481

490482
template <>
491-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
483+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
492484
{
493485
template <typename PipelineProblem>
494486
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
@@ -497,7 +489,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
497489
};
498490

499491
template <>
500-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
492+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
501493
{
502494
template <typename PipelineProblem>
503495
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
@@ -506,7 +498,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
506498
};
507499

508500
template <>
509-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
501+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
510502
{
511503
template <typename PipelineProblem>
512504
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
@@ -515,7 +507,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
515507
};
516508

517509
template <>
518-
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
510+
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V1>
519511
{
520512
template <typename PipelineProblem>
521513
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
@@ -525,7 +517,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
525517
};
526518

527519
template <>
528-
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
520+
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
529521
{
530522
template <typename PipelineProblem>
531523
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;

example/ck_tile/03_gemm/universal_gemm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
5858
else if(data_type == "fp16i4")
5959
{
6060
// TODO: Add support for bhalf_t ADataType
61-
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
61+
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
6262
{
6363
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
6464
Invoker,
@@ -73,7 +73,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
7373
}
7474
else if(data_type == "fp8i4")
7575
{
76-
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
76+
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
7777
{
7878
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
7979
Invoker,
@@ -88,7 +88,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
8888
}
8989
else if(data_type == "bf8i4")
9090
{
91-
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
91+
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
9292
{
9393
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
9494
Invoker,

example/ck_tile/16_batched_gemm/batched_gemm.hpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
1212
#include "ck_tile/utility/json_dump.hpp"
1313

14-
#define CK_TILE_PIPELINE_COMPUTE_V3 1
15-
#define CK_TILE_PIPELINE_MEMORY 2
16-
#define CK_TILE_PIPELINE_COMPUTE_V4 3
17-
1814
struct GemmConfigMemory
1915
{
2016
// Memory friendly for Interwave scheduler
@@ -30,9 +26,9 @@ struct GemmConfigMemory
3026
static constexpr ck_tile::index_t N_Warp_Tile = 32;
3127
static constexpr ck_tile::index_t K_Warp_Tile = 8;
3228

33-
static constexpr bool DoubleSmemBuffer = false;
34-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
35-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
29+
static constexpr bool DoubleSmemBuffer = false;
30+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
31+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
3632
};
3733

3834
struct GemmConfigV3
@@ -50,9 +46,9 @@ struct GemmConfigV3
5046
static constexpr ck_tile::index_t N_Warp_Tile = 32;
5147
static constexpr ck_tile::index_t K_Warp_Tile = 16;
5248

53-
static constexpr bool DoubleSmemBuffer = false;
54-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
55-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
49+
static constexpr bool DoubleSmemBuffer = false;
50+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
51+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
5652
};
5753

5854
struct GemmConfigV4
@@ -71,9 +67,9 @@ struct GemmConfigV4
7167
static constexpr ck_tile::index_t N_Warp_Tile = 32;
7268
static constexpr ck_tile::index_t K_Warp_Tile = 16;
7369

74-
static constexpr bool DoubleSmemBuffer = true;
75-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
76-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
70+
static constexpr bool DoubleSmemBuffer = true;
71+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
72+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
7773
};
7874

7975
struct GemmConfigV3_Wmma
@@ -91,16 +87,16 @@ struct GemmConfigV3_Wmma
9187
static constexpr ck_tile::index_t N_Warp_Tile = 16;
9288
static constexpr ck_tile::index_t K_Warp_Tile = 16;
9389

94-
static constexpr bool DoubleSmemBuffer = false;
95-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
96-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
90+
static constexpr bool DoubleSmemBuffer = false;
91+
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
92+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
9793
};
9894

99-
template <ck_tile::index_t PipelineId>
95+
template <ck_tile::GemmPipeline PipelineId>
10096
struct PipelineTypeTraits;
10197

10298
template <>
103-
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
99+
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
104100
{
105101
template <typename PipelineProblem>
106102
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -109,7 +105,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
109105
};
110106

111107
template <>
112-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
108+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
113109
{
114110
template <typename PipelineProblem>
115111
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -118,7 +114,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
118114
};
119115

120116
template <>
121-
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
117+
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
122118
{
123119
template <typename PipelineProblem>
124120
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;

0 commit comments

Comments
 (0)