From 1010dee18905b21d297827ccb80e4ea364e879be Mon Sep 17 00:00:00 2001 From: ThomasNing Date: Sun, 2 Nov 2025 03:42:42 +0000 Subject: [PATCH] fix the compv4 and async pipeline when tile handler is 1 --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 17 ++++++++++++++++- .../pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp | 17 ++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 1d2a3e180b..91da3cd27b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { + if(num_loop == 1) + { + return TailNumber::One; + } if(num_loop % PrefetchStages == 1) { return TailNumber::Three; @@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync return run_func(bool_constant{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } - else + else if(TailNum == TailNumber::Two) { // 2 { @@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } + else if(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + __builtin_amdgcn_sched_barrier(0); + } return c_block_tile; } };