Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
48d10e7
Add missing copyright statements
SamiAario-AMD Oct 16, 2025
d03bfbe
Use ck_tile::host_tensor_descriptor instead of a custom lambda
SamiAario-AMD Oct 16, 2025
6a20c5d
Refactor use of check_data_type in test classes
SamiAario-AMD Oct 20, 2025
db1bd00
Use TEST_SUITE_NAME with TYPED_TEST_SUITE
SamiAario-AMD Oct 20, 2025
3b037bb
Remove an unused namespace
SamiAario-AMD Oct 22, 2025
1160a78
Make dim3 const
SamiAario-AMD Oct 22, 2025
5e812c5
Add BF8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 13, 2025
6e49e20
Add F8 x BF8 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 22, 2025
7ec1d82
Add BF16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 22, 2025
8a4dd13
Add BF16 x BF16 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
23a56b6
Add BF8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
075c94a
Add F8 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
3adb94d
Add F16 x I4 tests for CompV3 in test_gemm_pipeline_kernel_types.hpp
SamiAario-AMD Oct 23, 2025
ccfbb10
Skip failing tests of F16 x I4 for CompV3 with K == 2 * K_Tile
SamiAario-AMD Oct 23, 2025
aeec81b
Add missing precision type combinations to CompV4 from CompV3
SamiAario-AMD Oct 23, 2025
e36febe
Move the INT8 tests around for consistency with KernelTypesCompV3Wmma
SamiAario-AMD Oct 23, 2025
f754140
Add missing precision type combinations to CompV3Wmma from CompV3
SamiAario-AMD Oct 23, 2025
2d19108
Remove the basic and universal tests and their dependencies
SamiAario-AMD Oct 23, 2025
f2d3e54
On __gfx950__, avoid using transposed loading of A with datatype pk_int4_t
SamiAario-AMD Oct 31, 2025
d6ddc22
Use ADataType and BDataType instead of ComputeDataType for WarpGemm
SamiAario-AMD Nov 7, 2025
0802bc9
Explicitly set some return types to void
SamiAario-AMD Oct 29, 2025
33e880b
Use more general typenames in InterleavedPKTypeLoader
SamiAario-AMD Oct 29, 2025
42e0b8d
Add load_interleaved_pk_type.hpp to common.hpp
SamiAario-AMD Oct 9, 2025
1831b0a
Use std::is_same_v in load_int4_tile
SamiAario-AMD Oct 29, 2025
e74e7b9
Add handling of LoadTranspose to load_int4_tile
SamiAario-AMD Oct 29, 2025
bf941f2
Factor out common code in several places using load_int4_tile
SamiAario-AMD Oct 30, 2025
5a1af25
Add support for pk_int4_t using load_int4_tile
SamiAario-AMD Nov 7, 2025
b6c8c6d
Fix formatting
SamiAario-AMD Nov 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/ck_tile/core/tensor/tile_window.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ struct tile_window_with_static_distribution
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
const TileWindow_& tile_window,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
Expand Down Expand Up @@ -283,7 +283,7 @@ struct tile_window_with_static_distribution
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
Expand Down Expand Up @@ -431,7 +431,7 @@ struct tile_window_with_static_distribution
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
CK_TILE_DEVICE void async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
Expand Down Expand Up @@ -515,7 +515,7 @@ struct tile_window_with_static_distribution
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE auto async_load_with_offset(index_t offset,
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
Expand Down Expand Up @@ -605,7 +605,7 @@ struct tile_window_with_static_distribution
typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset,
CK_TILE_DEVICE void load_transpose_with_offset(index_t offset,
DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
Expand Down
29 changes: 13 additions & 16 deletions include/ck_tile/ops/common/load_interleaved_pk_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@

namespace ck_tile {

template <class T>
struct is_pk_int4 : std::false_type
{
};
template <>
struct is_pk_int4<pk_int4_t> : std::true_type
{
};

template <typename ComputeDataType, index_t UnaryOpSize>
template <typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
{
template <typename WarpWindow, typename WarpTile>
Expand All @@ -30,24 +21,30 @@ struct InterleavedPKTypeLoader
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);

using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};

template <typename BDataType,
typename ComputeDataType,
template <typename SrcDataType,
typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(is_pk_int4<std::remove_cv_t<BDataType>>::value)
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
}
else if constexpr(LoadTranspose)
{
InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
dst = load_tile_transpose(src);
}
else
{
Expand Down
94 changes: 23 additions & 71 deletions include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ struct BlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;

using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;

using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
Expand Down Expand Up @@ -196,8 +200,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};

using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));

ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
Expand All @@ -222,22 +226,10 @@ struct BlockUniversalGemmAsBsCr
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");

if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
Expand Down Expand Up @@ -285,8 +277,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};

using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));

ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
Expand All @@ -300,30 +292,10 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}

// C += A * B
Expand Down Expand Up @@ -396,8 +368,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());

using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));

ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
Expand Down Expand Up @@ -451,30 +423,10 @@ struct BlockUniversalGemmAsBsCr
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);

if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
}
else
{
load_tile(a_warp_tile_, a_lds_gemm_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
}
else
{
load_tile(b_warp_tile_, b_lds_gemm_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
}

// C += A * B
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,21 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
#if defined(__gfx950__)
static constexpr bool is_a_load_tr = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_b_load_tr = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
static constexpr bool is_a_load_tr = []() {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
}();

static constexpr bool is_b_load_tr = []() {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
}();
#else
static constexpr bool is_a_load_tr = false;
static constexpr bool is_b_load_tr = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,27 @@ template <typename Derived>
struct UniversalGemmBasePolicy
{
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
template <typename Problem>
static constexpr bool is_a_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_a_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
}();

template <typename Problem>
static constexpr bool is_b_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
static constexpr bool is_b_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
}();
#else
template <typename Problem>
static constexpr bool is_a_load_tr = false;
Expand Down Expand Up @@ -707,8 +722,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;

using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;

using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
Expand All @@ -718,8 +740,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
Problem::UseStructuredSparsity,
wg_attr_num_access>;

using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
Expand Down Expand Up @@ -156,7 +155,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>

using Base = BlockGemmAQuantBase<Problem_>;

using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
Expand Down Expand Up @@ -447,26 +445,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
}

// C += A * B
Expand Down
Loading