Skip to content

Commit 057b7d4

Browse files
authored
fix the compv4 and async pipeline when tile handler is 1 (#3141)
1 parent 2ec57a8 commit 057b7d4

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync
2525

2626
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
2727
{
28+
if(num_loop == 1)
29+
{
30+
return TailNumber::One;
31+
}
2832
if(num_loop % PrefetchStages == 1)
2933
{
3034
return TailNumber::Three;
@@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync
6569
return run_func(bool_constant<false>{},
6670
integral_constant<TailNumber, TailNumber::Two>{});
6771
}
72+
else
73+
{
74+
return (run_func(bool_constant<false>{},
75+
integral_constant<TailNumber, TailNumber::One>{}));
76+
}
6877
}
6978
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
7079
#if defined(__HIP_DEVICE_COMPILE__)
@@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
485494
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
486495
}
487496
}
488-
else
497+
else if(TailNum == TailNumber::Two)
489498
// 2 block gemms remaining
490499
{
491500
{
@@ -500,6 +509,12 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
500509
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
501510
}
502511
}
512+
else if(TailNum == TailNumber::One)
513+
{
514+
block_sync_lds();
515+
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
516+
__builtin_amdgcn_sched_barrier(0);
517+
}
503518
return c_block_tile;
504519
}
505520
};

include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ struct BaseGemmPipelineAgBgCrCompV4
2727

2828
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
2929
{
30+
if(num_loop == 1)
31+
{
32+
return TailNumber::One;
33+
}
3034
if(num_loop % PrefetchStages == 1)
3135
{
3236
return TailNumber::Three;
@@ -67,6 +71,11 @@ struct BaseGemmPipelineAgBgCrCompV4
6771
return run_func(bool_constant<false>{},
6872
integral_constant<TailNumber, TailNumber::Two>{});
6973
}
74+
else
75+
{
76+
return (run_func(bool_constant<false>{},
77+
integral_constant<TailNumber, TailNumber::One>{}));
78+
}
7079
}
7180
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
7281
#if defined(__HIP_DEVICE_COMPILE__)
@@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
621630
__builtin_amdgcn_sched_barrier(0);
622631
}
623632
}
624-
else
633+
else if(TailNum == TailNumber::Two)
625634
{
626635
// 2
627636
{
@@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
641650
__builtin_amdgcn_sched_barrier(0);
642651
}
643652
}
653+
else if(TailNum == TailNumber::One)
654+
{
655+
block_sync_lds();
656+
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
657+
__builtin_amdgcn_sched_barrier(0);
658+
}
644659
return c_block_tile;
645660
}
646661
};

0 commit comments

Comments
 (0)