Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync

CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
return TailNumber::One;
}
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
Expand Down Expand Up @@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
else
{
return (run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::One>{}));
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
Expand Down Expand Up @@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
}
else
else if(TailNum == TailNumber::Two)
// 2 block gemms remaining
{
{
Expand All @@ -500,6 +509,12 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
}
else if(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
return c_block_tile;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ struct BaseGemmPipelineAgBgCrCompV4

CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
return TailNumber::One;
}
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
Expand Down Expand Up @@ -67,6 +71,11 @@ struct BaseGemmPipelineAgBgCrCompV4
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
else
{
return (run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::One>{}));
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
Expand Down Expand Up @@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
__builtin_amdgcn_sched_barrier(0);
}
}
else
else if(TailNum == TailNumber::Two)
{
// 2
{
Expand All @@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
__builtin_amdgcn_sched_barrier(0);
}
}
else if(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
return c_block_tile;
}
};
Expand Down