Skip to content

Commit 16151c1

Browse files
committed
Add support for pk_int4_t using load_int4_tile
1 parent 60c4a77 commit 16151c1

File tree

2 files changed

28 additions and 16 deletions (lines changed)

2 files changed

28 additions and 16 deletions (lines changed)

include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,11 @@ struct BlockUniversalGemmAsBsCr
9494
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
9595
using CDataType = remove_cvref_t<typename Traits::CDataType>;
9696

97+
using ATypeToUse =
98+
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
99+
using BTypeToUse =
100+
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
101+
97102
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
98103

99104
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -195,8 +200,8 @@ struct BlockUniversalGemmAsBsCr
195200
static constexpr auto BLdsTileDistr =
196201
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
197202

198-
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
199-
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
203+
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
204+
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
200205

201206
ALdsTile a_warp_tile_;
202207
BLdsTile b_warp_tile_;
@@ -221,8 +226,8 @@ struct BlockUniversalGemmAsBsCr
221226
"The ADataType and BDataType as defined in "
222227
"traits should be the same as correspoinding block window data type!");
223228

224-
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
225-
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
229+
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
230+
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
226231
// hot loop:
227232
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
228233
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -270,8 +275,8 @@ struct BlockUniversalGemmAsBsCr
270275
static constexpr auto BLdsTileDistr =
271276
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
272277

273-
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
274-
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
278+
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
279+
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
275280

276281
ALdsTile a_warp_tile_;
277282
BLdsTile b_warp_tile_;
@@ -285,8 +290,8 @@ struct BlockUniversalGemmAsBsCr
285290
bool_constant<ALoadTranspose> = {},
286291
bool_constant<BLoadTranspose> = {})
287292
{
288-
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
289-
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
293+
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
294+
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
290295
}
291296

292297
// C += A * B
@@ -359,8 +364,8 @@ struct BlockUniversalGemmAsBsCr
359364
static constexpr auto BLdsTileDistr =
360365
make_static_tile_distribution(MakeBBlockDistributionEncode());
361366

362-
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
363-
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
367+
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
368+
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
364369

365370
ALdsTile a_warp_tile_;
366371
BLdsTile b_warp_tile_;
@@ -414,8 +419,8 @@ struct BlockUniversalGemmAsBsCr
414419
auto b_lds_gemm_window = make_tile_window(
415420
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
416421

417-
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
418-
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
422+
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
423+
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
419424
}
420425

421426
// C += A * B

include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -722,8 +722,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
722722
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
723723
: WGAttrNumAccessEnum::Invalid;
724724

725-
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
726-
typename Problem::BDataType,
725+
using ADataType = remove_cvref_t<typename Problem::ADataType>;
726+
using BDataType = remove_cvref_t<typename Problem::BDataType>;
727+
using ATypeToUse =
728+
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
729+
using BTypeToUse =
730+
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
731+
732+
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
733+
BTypeToUse,
727734
typename Problem::CDataType,
728735
WarpTile::at(I0),
729736
WarpTile::at(I1),
@@ -733,8 +740,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
733740
Problem::UseStructuredSparsity,
734741
wg_attr_num_access>;
735742

736-
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
737-
typename Problem::BDataType,
743+
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
744+
BTypeToUse,
738745
typename Problem::CDataType,
739746
BlockWarps,
740747
WarpGemm>;

0 commit comments

Comments
 (0)