From 7579fa697819dce4dbc21fb176a2f9e09b9a400c Mon Sep 17 00:00:00 2001 From: Emily Martins Date: Mon, 27 Oct 2025 17:26:04 +0000 Subject: [PATCH] Replace CK_TILE_PIPELINE macros with a common enum This change replaces pipeline macros like CK_TILE_PIPELINE_COMPUTE_V3, CK_TILE_PIPELINE_MEMORY, etc in the CK Tile examples with a common enum called GemmPipeline to reduce code duplication. --- example/ck_tile/03_gemm/gemm_basic.cpp | 2 +- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 2 +- example/ck_tile/03_gemm/gemm_utils.hpp | 89 +++++++++---------- example/ck_tile/03_gemm/universal_gemm.cpp | 6 +- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 36 ++++---- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 67 +++++++------- .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 46 +++++----- .../19_gemm_multi_d/gemm_multi_d_fp16.hpp | 37 ++++---- .../20_grouped_convolution/conv_configs.hpp | 57 ++++++------ .../22_gemm_multi_abd/gemm_multi_abd_fp16.hpp | 41 ++++----- include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/pipeline/gemm_pipelines.hpp | 21 +++++ .../gemm/test_gemm_pipeline_smoke_util.hpp | 53 +++++------ 13 files changed, 220 insertions(+), 238 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index f92f6ef87a..3c26661c84 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -68,7 +68,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, ck_tile::half_t, diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index dbed40800e..6d833fbd7a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -12,13 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_COMPUTE_V6 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -69,7 +62,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; @@ -91,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -111,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -131,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -150,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -169,8 +162,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -190,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -213,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -232,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -252,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaveGroups = 2; }; @@ -272,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6; static constexpr ck_tile::index_t NumWaveGroups = 1; }; @@ -291,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -315,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -465,11 +458,11 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -478,7 +471,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -487,7 +480,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -496,7 +489,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; @@ -505,7 +498,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; @@ -514,7 +507,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index f9a7263a5f..a8a7288a3d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -58,7 +58,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) else if(data_type == "fp16i4") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, @@ -73,7 +73,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) } else if(data_type == "fp8i4") { - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, @@ -88,7 +88,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) } else if(data_type == "bf8i4") { - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 33da0bf0a5..c0935a0e46 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -11,10 +11,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - struct GemmConfigMemory { // Memory friendly for Interwave scheduler @@ -30,9 +26,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -50,9 +46,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -71,9 +67,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -91,16 +87,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -109,7 +105,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -118,7 +114,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 57d3f224d8..049957cbfd 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -11,11 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -87,7 +82,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool Persistent = true; @@ -109,8 +104,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 1; }; @@ -132,8 +127,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -155,8 +150,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -178,12 +173,12 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr bool kPadK = true; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = true; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; }; template @@ -201,12 +196,12 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr bool kPadK = true; + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr bool kPadK = true; }; template @@ -226,8 +221,8 @@ struct GemmConfigComputeV4_Wmma : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -249,18 +244,18 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase static constexpr bool kPadK = true; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -269,7 +264,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -278,7 +273,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -287,7 +282,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index 12d70eecb6..81c0b654e2 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -11,10 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -44,8 +40,8 @@ struct GemmConfigBase static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet static constexpr bool Persistent = false; // currently persistent == true is not supported yet static constexpr bool DoubleSmemBuffer = @@ -67,10 +63,10 @@ struct GemmConfigMemory : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr bool Persistent = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 : public GemmConfigBase @@ -88,10 +84,10 @@ struct GemmConfigV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 : public GemmConfigBase { @@ -109,10 +105,10 @@ struct GemmConfigV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma : public GemmConfigBase @@ -130,16 +126,16 @@ struct GemmConfigV3_Wmma : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -148,7 +144,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -157,7 +153,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp index a7ae227627..8a621cd4be 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -7,12 +7,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using D0DataType = ck_tile::half_t; @@ -36,9 +33,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -56,9 +53,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -77,9 +74,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -97,16 +94,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -115,7 +112,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -124,7 +121,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 1be6080383..c688215280 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -12,11 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 - struct ConvConfigBase { static constexpr bool kPadM = true; @@ -37,7 +32,7 @@ struct ConvConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; @@ -61,9 +56,9 @@ struct ConvConfigMemoryInterwave : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -81,8 +76,8 @@ struct ConvConfigMemoryIntrawave : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -101,8 +96,8 @@ struct ConvConfigComputeV3 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -120,8 +115,8 @@ struct ConvConfigComputeV3_1 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -139,8 +134,8 @@ struct ConvConfigComputeV3_2 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -160,8 +155,8 @@ struct ConvConfigComputeV3_WMMA : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -183,8 +178,8 @@ struct ConvConfigComputeV4 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -202,8 +197,8 @@ struct ConvConfigComputeV4_1 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -222,7 +217,7 @@ struct ConvConfigComputeV5 : public ConvConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; @@ -245,8 +240,8 @@ struct ConvConfigComputeV3_merged_groups : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumGroupsToMerge = 2; }; @@ -294,11 +289,11 @@ struct DataTypeTraits static constexpr const char* name = "bf16"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -307,7 +302,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -316,7 +311,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -325,7 +320,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp index 35bc232eca..76a2635e5f 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp @@ -7,16 +7,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 -#endif - using A0DataType = ck_tile::half_t; using A1DataType = ck_tile::half_t; @@ -49,9 +42,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -69,9 +62,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -90,9 +83,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -110,16 +103,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -128,7 +121,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -137,7 +130,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 33be18948b..ec2d2488c8 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -55,6 +55,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp new file mode 100644 index 0000000000..9b948626f6 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp @@ -0,0 +1,21 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +enum struct GemmPipeline +{ + COMPUTE_ASYNC, + COMPUTE_V3, + COMPUTE_V4, + COMPUTE_V5, + COMPUTE_V6, + MEMORY, + BASIC_V1, + BASIC_V2, + PRESHUFFLE_V2 +}; + +} // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index 0820be5b30..1f9033cab9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -10,11 +10,6 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 - class ArgumentsNotSupportedException : public std::logic_error { public: @@ -56,7 +51,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; }; @@ -76,9 +71,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -96,8 +91,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -116,8 +111,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -135,8 +130,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -154,8 +149,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -177,8 +172,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -196,8 +191,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -216,7 +211,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; @@ -235,8 +230,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -401,11 +396,11 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -414,7 +409,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -423,7 +418,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -432,7 +427,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5;