@@ -94,6 +94,11 @@ struct BlockUniversalGemmAsBsCr
9494 using ComputeDataType = remove_cvref_t <typename Traits::ComputeDataType>;
9595 using CDataType = remove_cvref_t <typename Traits::CDataType>;
9696
97+ using ATypeToUse =
98+ std::conditional_t <std::is_same_v<ADataType, pk_int4_t >, BDataType, ADataType>;
99+ using BTypeToUse =
100+ std::conditional_t <std::is_same_v<BDataType, pk_int4_t >, ADataType, BDataType>;
101+
97102 using WarpGemm = remove_cvref_t <typename Traits::WarpGemm>;
98103
99104 static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -195,8 +200,8 @@ struct BlockUniversalGemmAsBsCr
195200 static constexpr auto BLdsTileDistr =
196201 decltype (make_static_tile_distribution(MakeBBlockDistributionEncode())){};
197202
198- using ALdsTile = decltype (make_static_distributed_tensor<ADataType >(ALdsTileDistr));
199- using BLdsTile = decltype (make_static_distributed_tensor<BDataType >(BLdsTileDistr));
203+ using ALdsTile = decltype (make_static_distributed_tensor<ATypeToUse >(ALdsTileDistr));
204+ using BLdsTile = decltype (make_static_distributed_tensor<BTypeToUse >(BLdsTileDistr));
200205
201206 ALdsTile a_warp_tile_;
202207 BLdsTile b_warp_tile_;
@@ -221,8 +226,8 @@ struct BlockUniversalGemmAsBsCr
221226 " The ADataType and BDataType as defined in "
222227 " traits should be the same as correspoinding block window data type!" );
223228
224- load_int4_tile<ADataType, ComputeDataType , UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
225- load_int4_tile<BDataType, ComputeDataType , UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
229+ load_int4_tile<ADataType, ATypeToUse , UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
230+ load_int4_tile<BDataType, BTypeToUse , UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
226231 // hot loop:
227232 static_for<0 , GemmTraits::KIterPerWarp, 1 >{}([&](auto kIter ) {
228233 static_for<0 , MIterPerWarp, 1 >{}([&](auto mIter ) {
@@ -270,8 +275,8 @@ struct BlockUniversalGemmAsBsCr
270275 static constexpr auto BLdsTileDistr =
271276 decltype (make_static_tile_distribution(MakeBBlockDistributionEncode())){};
272277
273- using ALdsTile = decltype (make_static_distributed_tensor<ADataType >(ALdsTileDistr));
274- using BLdsTile = decltype (make_static_distributed_tensor<BDataType >(BLdsTileDistr));
278+ using ALdsTile = decltype (make_static_distributed_tensor<ATypeToUse >(ALdsTileDistr));
279+ using BLdsTile = decltype (make_static_distributed_tensor<BTypeToUse >(BLdsTileDistr));
275280
276281 ALdsTile a_warp_tile_;
277282 BLdsTile b_warp_tile_;
@@ -285,8 +290,8 @@ struct BlockUniversalGemmAsBsCr
285290 bool_constant<ALoadTranspose> = {},
286291 bool_constant<BLoadTranspose> = {})
287292 {
288- load_int4_tile<ADataType, ComputeDataType , UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
289- load_int4_tile<BDataType, ComputeDataType , UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
293+ load_int4_tile<ADataType, ATypeToUse , UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
294+ load_int4_tile<BDataType, BTypeToUse , UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
290295 }
291296
292297 // C += A * B
@@ -359,8 +364,8 @@ struct BlockUniversalGemmAsBsCr
359364 static constexpr auto BLdsTileDistr =
360365 make_static_tile_distribution (MakeBBlockDistributionEncode());
361366
362- using ALdsTile = decltype (make_static_distributed_tensor<ADataType >(ALdsTileDistr));
363- using BLdsTile = decltype (make_static_distributed_tensor<BDataType >(BLdsTileDistr));
367+ using ALdsTile = decltype (make_static_distributed_tensor<ATypeToUse >(ALdsTileDistr));
368+ using BLdsTile = decltype (make_static_distributed_tensor<BTypeToUse >(BLdsTileDistr));
364369
365370 ALdsTile a_warp_tile_;
366371 BLdsTile b_warp_tile_;
@@ -414,8 +419,8 @@ struct BlockUniversalGemmAsBsCr
414419 auto b_lds_gemm_window = make_tile_window (
415420 b_block_window.get_bottom_tensor_view (), b_lds_shape, b_offset, b_lds_load_distr);
416421
417- load_int4_tile<ADataType, ComputeDataType , UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
418- load_int4_tile<BDataType, ComputeDataType , UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
422+ load_int4_tile<ADataType, ATypeToUse , UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
423+ load_int4_tile<BDataType, BTypeToUse , UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
419424 }
420425
421426 // C += A * B
0 commit comments