
Commit 8f1274d

AviralGoelAMD, kyle-256, ThomasNing, and Copilot authored
test(grouped_gemm): add unit tests for grouped_gemm bquant with preshuffleB true (#3119)
* add tensorwise quant in grouped gemm
* fix example issue
* update test cases
* format codes
* clang format
* use GTEST_FAIL
* add bquant to grouped_gemm
* add tensorwise quant in grouped gemm
* fix example issue
* update test cases
* format codes
* clang format
* use GTEST_FAIL
* fix a bug in test_grouped_gemm_util
* skip test when use wmma on grouped_quant kernel
* change cmake
* fix a bug in test_grouped_gemm_util
* skip test when use wmma on grouped_quant kernel
* change cmake
* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm
* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp (Co-authored-by: Copilot <[email protected]>)
* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp (Co-authored-by: Copilot <[email protected]>)
* feat: add bf8 support
* chore: remove unnecessary decltype usage
* chore: add default quant_mode to function signature as fallback
* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel. Calculate has_hot_loop, num_loop, and tail_number on the device side for each GEMM problem instead of using default values; this fixes incorrect results when different problems in the group have different K dimensions.
* chore: set default quant mode in function signature
* test: add additional test cases to cover edge case of no hotloop
* change code based on comments
* WIP: bquant preshuffle b compiles but gives numerical error
* feat(grouped_gemm_quant): bquant with preshuffleB support added to grouped_gemm example & kernel
* refactor: refactor code after merge commit
* chore: remove print statements
* test(grouped_gemm): split test cases by quant mode to reduce compilation time and add bquant-preshuffleB mode test cases

---------

Co-authored-by: kyle-256 <[email protected]>
Co-authored-by: ThomasNing <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent a33d98f commit 8f1274d

14 files changed: +426 −75 lines changed

example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp

Lines changed: 12 additions & 7 deletions
```diff
@@ -49,7 +49,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
                             GemmConfig::kPadN,
                             GemmConfig::kPadK,
                             false, // PreshuffleQuant
-                            false, // PreshuffleB
+                            GemmConfig::PreshuffleB, // PreshuffleB
                             ALayout,
                             BLayout,
                             CLayout,
@@ -58,7 +58,7 @@
                             BQLayout,
                             GemmConfig::TransposeC,
                             GemmConfig::DoubleSmemBuffer,
-                            true>;
+                            true>; // Persistence
 
     float ave_time{0};
 
@@ -86,10 +86,14 @@
                                          BDataType,
                                          scheduler>>::type;
 
-    using GemmPipeline =
-        typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
-                                  ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
-                                  ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
+    using GemmPipeline = std::conditional_t<
+        QuantMode == ck_tile::QuantType::RowColQuant ||
+            QuantMode == ck_tile::QuantType::TensorQuant,
+        ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
+        std::conditional_t<GemmConfig::PreshuffleB == true,
+                           ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
+                           ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
+
     using GemmEpilogue = ck_tile::CShuffleEpilogue<
         ck_tile::CShuffleEpilogueProblem<ADataType,
                                          BDataType,
@@ -141,5 +145,6 @@
 
 int main(int argc, char* argv[])
 {
-    return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
+    int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
+    return result1;
 }
```
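The interesting change here is the nested `std::conditional_t`: rowcol and tensor quant keep the unquantized-B pipeline, while bquant now forks on `GemmConfig::PreshuffleB` between the new weight-preshuffle pipeline and the existing bquant pipeline. Below is a minimal, self-contained sketch of that selection pattern; the enum and pipeline structs are simplified stand-ins, not the real ck_tile definitions.

```cpp
// Minimal sketch of the nested compile-time pipeline selection above.
// All types here are simplified stand-ins, not the real ck_tile definitions.
#include <iostream>
#include <type_traits>

enum class QuantType { RowColQuant, TensorQuant, BQuantGrouped };

struct GemmPipelineAgBgCrCompV3       { static constexpr const char* name = "CompV3"; };
struct BQuantGemmPipelineAgBgCrCompV3 { static constexpr const char* name = "BQuantCompV3"; };
struct WPQuantBPipelineAgBgCrV2       { static constexpr const char* name = "WPQuantBV2"; };

template <QuantType Mode, bool PreshuffleB>
using SelectPipeline = std::conditional_t<
    Mode == QuantType::RowColQuant || Mode == QuantType::TensorQuant,
    GemmPipelineAgBgCrCompV3,                         // non-bquant modes ignore PreshuffleB
    std::conditional_t<PreshuffleB,
                       WPQuantBPipelineAgBgCrV2,      // bquant + preshuffled B
                       BQuantGemmPipelineAgBgCrCompV3>>; // bquant, plain B layout

int main()
{
    std::cout << SelectPipeline<QuantType::BQuantGrouped, true>::name << '\n';  // WPQuantBV2
    std::cout << SelectPipeline<QuantType::BQuantGrouped, false>::name << '\n'; // BQuantCompV3
    std::cout << SelectPipeline<QuantType::TensorQuant, true>::name << '\n';    // CompV3
}
```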

example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Lines changed: 38 additions & 7 deletions
```diff
@@ -10,9 +10,6 @@
 #include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
 
-#define CK_TILE_PIPELINE_COMPUTE_V3 1
-#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
-
 template <typename PrecType, ck_tile::index_t M_Warp_Tile>
 constexpr ck_tile::index_t get_k_warp_tile()
 {
@@ -31,6 +28,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
 #endif
 }
 
+template <typename PrecType, ck_tile::index_t M_Warp_Tile>
+constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
+{
+#if defined(CK_GFX950_SUPPORT)
+    if constexpr(M_Warp_Tile == 32)
+        return sizeof(PrecType) == 2 ? 16 : 64;
+    else
+        return sizeof(PrecType) == 2 ? 32 : 128;
+#else
+    if constexpr(M_Warp_Tile == 32)
+        return sizeof(PrecType) == 2 ? 16 : 32;
+    else
+        return sizeof(PrecType) == 2 ? 32 : 64;
+#endif
+}
+
 template <typename DataType>
 struct GemmTypeConfig;
 
@@ -67,8 +80,9 @@ struct GemmConfigBase
     static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
     static constexpr ck_tile::index_t TileParitionerM01 = 4;
     static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
-    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
     static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr bool PreshuffleB = false;
 };
 
 template <typename PrecType>
@@ -85,10 +99,26 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
     static constexpr ck_tile::index_t M_Warp_Tile = 32;
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
     static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
+};
+
+template <typename PrecType>
+struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
 
-    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
 
-    static constexpr int kBlockPerCu = 1;
+    static constexpr bool PreshuffleB = true;
+    static constexpr bool DoubleSmemBuffer = true;
 };
 
 using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
@@ -118,7 +148,8 @@ auto create_args(int argc, char* argv[])
         .insert("repeat", "100", "number of iterations to benchmark the kernel.")
         .insert("group_count", "8", "group count.")
         .insert("kbatch", "1", "kbatch for SplitK")
-        .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
+        .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
+        .insert("init", "0", "0. Random, 2. One(s) (Constant)");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
```
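`get_k_from_preshuffled_warp_tile` derives the K extent of a warp tile from the element size and the MFMA tile shape. Here is a standalone copy of the rule with static checks of the values it yields on the non-gfx950 branch; the byte-sized stand-in types below are placeholders for fp8/bf8 (1 byte) and fp16/bf16 (2 bytes), not ck_tile types.

```cpp
// Standalone copy of the K-warp-tile rule above (non-gfx950 branch),
// with static checks of the values it yields.
#include <cstdint>

template <typename PrecType, int M_Warp_Tile>
constexpr int get_k_from_preshuffled_warp_tile()
{
    if constexpr(M_Warp_Tile == 32)
        return sizeof(PrecType) == 2 ? 16 : 32;
    else
        return sizeof(PrecType) == 2 ? 32 : 64;
}

struct fp8_t  { std::uint8_t v; };  // 1-byte stand-in for fp8/bf8
struct half_t { std::uint16_t v; }; // 2-byte stand-in for fp16/bf16

// The 16x16 MFMA tiles used by GemmConfigPreshuffleB_Bquant_prefill take the
// "else" branch: 32 K-elements for 2-byte types, 64 for 1-byte types.
static_assert(get_k_from_preshuffled_warp_tile<half_t, 16>() == 32);
static_assert(get_k_from_preshuffled_warp_tile<fp8_t, 16>() == 64);
static_assert(get_k_from_preshuffled_warp_tile<fp8_t, 32>() == 32);

int main() { return 0; }
```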

example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc

Lines changed: 33 additions & 5 deletions
```diff
@@ -163,6 +163,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
     const int repeat = arg_parser.get_int("repeat");
     const int warmup = arg_parser.get_int("warmup");
     const int kbatch = arg_parser.get_int("kbatch");
+    const int init_method = arg_parser.get_int("init");
     bool validate = arg_parser.get_bool("validate");
     const ck_tile::index_t QuantGroupSize = 128;
 
@@ -203,6 +204,7 @@
 
     for(int i = 0; i < group_count; i++)
     {
+
         Ms.push_back(256 + 256 * i);
         Ns.push_back(256 + 512 * i);
         Ks.push_back(512 + 128 * i);
@@ -280,6 +282,12 @@
             stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
             stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
         }
+        else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
+        {
+            stride_AQs[i] = 0; // No A quantization
+            stride_BQs[i] =
+                ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
+        }
 
         a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
             ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
@@ -313,10 +321,20 @@
                   << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
                   << " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
 
-        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
-        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
-        ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
-        ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
+        if(init_method == 2)
+        {
+            ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
+            ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
+            ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
+            ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
+        }
+        else
+        {
+            ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
+            ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
+            ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
+            ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
+        }
 
         a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
             a_m_k_tensors[i].get_element_space_size_in_bytes()));
@@ -329,8 +347,18 @@
         bq_dev_buf.push_back(
             std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
 
+        if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped)
+        {
+            ck_tile::HostTensor<BDataType> b_shuffle_host =
+                ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
+            b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
+        }
+        else
+        {
+            b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
+        }
+
         a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
-        b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
         aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
         bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
         c_m_n_dev_buf[i]->SetZero();
```
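The new `init` flag (value 2) fills A, B, and both scale tensors with ones, which makes the expected output analytically known: with unit inputs and unit scales, every element of C equals K, so an indexing or shuffle bug surfaces immediately. A plain-C++ sketch of that invariant, independent of ck_tile:

```cpp
// Why a constant-ones init helps debug numerical errors: with
// A == B == scales == 1, every output element of C = A*B is exactly K.
#include <cassert>
#include <vector>

int main()
{
    const int M = 4, N = 4, K = 128;
    std::vector<float> a(M * K, 1.f), b(K * N, 1.f), c(M * N, 0.f);
    const float bq_scale = 1.f; // per-group B scale, also set to one

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * (b[k * N + n] * bq_scale);
            c[m * N + n] = acc;
        }

    for(float v : c)
        assert(v == static_cast<float>(K)); // every element equals K
    return 0;
}
```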

include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,7 +20,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
 
     CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
 
-    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
+    CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }
```
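Per the commit message, `has_hot_loop`, `num_loop`, and `tail_number` are now computed on the device for each GEMM problem, since problems in a group can have different K dimensions; `BlockHasHotloop` therefore has to be callable from device code. A sketch of the idea as ordinary C++: in the real header these functions carry `CK_TILE_HOST_DEVICE`, and `KPerBlock`/`PrefetchStages` below are illustrative constants, not the real values.

```cpp
// Plain constexpr stands in for CK_TILE_HOST_DEVICE so this compiles as
// ordinary C++. Constants are illustrative, not the real ck_tile values.
constexpr int KPerBlock      = 128;
constexpr int PrefetchStages = 2;

constexpr int get_loop_num(int K)
{
    return (K + KPerBlock - 1) / KPerBlock; // ceil-div, as a tile partitioner would
}

constexpr bool block_has_hotloop(int num_loop)
{
    return num_loop > PrefetchStages;
}

// A group mixing K = 128 and K = 1024 classifies differently per problem,
// so the check must run per workgroup on the device, not once on the host:
static_assert(!block_has_hotloop(get_loop_num(128)));  // 1 loop: tail-only path
static_assert(block_has_hotloop(get_loop_num(1024)));  // 8 loops: hot loop runs

int main() { return 0; }
```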

include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -483,6 +483,7 @@ struct QuantGemmKernel
                                          const QuantGemmKernelArgs& kargs,
                                          const SplitKBatchOffset& splitk_batch_offset)
     {
+
         static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
         const auto& a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -790,6 +791,7 @@
         }();
         if constexpr(PreshuffleB)
         {
+
             return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
         }
         else
@@ -802,6 +804,7 @@
     CK_TILE_DEVICE static auto
     MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
     {
+
         const auto& a_pad_view = views.at(I0);
         const auto& aq_pad_view = views.at(I1);
         const auto& b_pad_view = views.at(I2);
@@ -867,6 +870,7 @@
         const auto& b_block_window = [&]() {
             if constexpr(PreshuffleB)
             {
+
                 return make_tile_window(
                     b_pad_view,
                     make_tuple(number<GemmPipeline::flatNPerWarp>{},
```

include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp

Lines changed: 79 additions & 4 deletions
```diff
@@ -317,13 +317,88 @@ struct QuantGroupedGemmKernel
         const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
 
-        static_assert(GemmPipeline::DoubleSmemBuffer == false,
-                      "DoubleSmemBuffer needs to be false");
         // allocate LDS
         __shared__ char smem_ptr_0[GetSmemSize()];
 
-        RunGemmWithPipelineSelection(
-            a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+        // Only for BQuantGrouped DoubleSmemBuffer is supported
+        if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
+                     kQuantType == QuantType::BQuantGrouped)
+        {
+            __shared__ char smem_ptr_1[GetSmemSize()];
+            RunGemmWithPipelineSelection2LDS(a_ptr,
+                                             b_ptr,
+                                             aq_ptr,
+                                             bq_ptr,
+                                             c_ptr,
+                                             smem_ptr_0,
+                                             smem_ptr_1,
+                                             kargs,
+                                             splitk_batch_offset,
+                                             i_m,
+                                             i_n);
+        }
+        else
+        {
+            RunGemmWithPipelineSelection(a_ptr,
+                                         b_ptr,
+                                         aq_ptr,
+                                         bq_ptr,
+                                         c_ptr,
+                                         smem_ptr_0,
+                                         kargs,
+                                         splitk_batch_offset,
+                                         i_m,
+                                         i_n);
+        }
+    }
+
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static void
+    RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
+                                     const BDataType* b_ptr,
+                                     const AQDataType* aq_ptr,
+                                     const BQDataType* bq_ptr,
+                                     CDataType* c_ptr,
+                                     void* smem_ptr_0,
+                                     void* smem_ptr_1,
+                                     const QuantGroupedGemmKernelArgs& kargs,
+                                     const typename Base::SplitKBatchOffset& splitk_batch_offset,
+                                     const index_t block_idx_m,
+                                     const index_t block_idx_n)
+    {
+        static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
+                a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
+
+        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows =
+            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop = __builtin_amdgcn_readfirstlane(
+            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
+        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
+        const auto& b_block_window = gemm_tile_windows.at(Base::I2);
+
+        const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
+        const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
+                                                                      b_block_window,
+                                                                      bq_block_window,
+                                                                      num_loop,
+                                                                      tail_num,
+                                                                      smem_ptr_0,
+                                                                      smem_ptr_1);
+
+        // Run Epilogue Pipeline
+        auto& c_block_window = gemm_tile_windows.at(Base::I4);
+
+        EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
     }
 
     /**
```
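The 2LDS path allocates a second `__shared__` buffer so the pipeline can ping-pong: while the compute stage consumes tile i from one buffer, the next tile streams into the other, and the roles swap each iteration. A host-side conceptual sketch of that scheme follows; the buffer size and stage functions are placeholders, not the ck_tile implementation.

```cpp
// Conceptual sketch of the ping-pong (double LDS buffer) pattern.
#include <cstddef>

constexpr std::size_t kSmemBytes = 1024; // stand-in for GetSmemSize()

void load_tile_to(char* /*smem*/, int /*tile_idx*/) { /* global -> LDS copy stage */ }
void compute_from(const char* /*smem*/) { /* MFMA stage reading LDS */ }

void run_pipeline(int num_loop)
{
    // In the kernel these are two distinct __shared__ allocations
    // (smem_ptr_0 / smem_ptr_1); static host buffers stand in here.
    static char ping[kSmemBytes];
    static char pong[kSmemBytes];

    char* buffers[2] = {ping, pong};
    load_tile_to(buffers[0], 0); // prologue: prefetch the first tile

    for(int i = 0; i < num_loop; ++i)
    {
        char* cur  = buffers[i % 2];
        char* next = buffers[(i + 1) % 2];
        if(i + 1 < num_loop)
            load_tile_to(next, i + 1); // overlaps with compute on `cur`
        compute_from(cur);
    }
}

int main() { run_pipeline(8); }
```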

include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp

Lines changed: 27 additions & 0 deletions
```diff
@@ -458,6 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                                    void* p_smem_ping,
                                    void* p_smem_pong) const
     {
+
         return operator()<TailNum>(
             a_dram_block_window_tmp,
             [](const ADataType& a) { return a; },
@@ -467,5 +468,31 @@
             p_smem_ping,
             p_smem_pong);
     }
+
+    template <typename ADramBlockWindowTmp,
+              typename BFlatBlockWindowTmp,
+              typename BQDramBlockWindowTmp>
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   TailNumber tail_number,
+                                   void* p_smem_ping,
+                                   void* p_smem_pong) const
+    {
+        const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
+            (void)bool_val; // Suppress unused parameter warning
+            constexpr auto tail_num = tail_num_.value;
+            return operator()<tail_num>(
+                a_dram_block_window_tmp,
+                [](const ADataType& a) { return a; },
+                b_flat_dram_block_window_tmp,
+                bq_dram_block_window_tmp,
+                num_loop,
+                p_smem_ping,
+                p_smem_pong);
+        };
+        return Base::TailHandler(RunPipeline, true, tail_number);
+    }
 };
 } // namespace ck_tile
```
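The new `operator()` overload takes `tail_number` as a runtime value and forwards it to `Base::TailHandler`, which re-materializes it as a compile-time constant before instantiating the pipeline body (that is the `tail_num_.value` read inside the lambda). Below is a simplified stand-in showing this runtime-to-compile-time dispatch pattern; `tail_handler` here is illustrative, not the real `Base::TailHandler`.

```cpp
// Sketch: a runtime TailNumber is switched on once, and each branch calls the
// callback with a std::integral_constant, so the callback body sees the value
// as a compile-time constant again.
#include <iostream>
#include <type_traits>

enum class TailNumber { Full, Odd, Even };

template <typename F>
auto tail_handler(F&& run_pipeline, TailNumber tail_number)
{
    switch(tail_number)
    {
    case TailNumber::Full:
        return run_pipeline(std::true_type{},
                            std::integral_constant<TailNumber, TailNumber::Full>{});
    case TailNumber::Odd:
        return run_pipeline(std::true_type{},
                            std::integral_constant<TailNumber, TailNumber::Odd>{});
    default:
        return run_pipeline(std::true_type{},
                            std::integral_constant<TailNumber, TailNumber::Even>{});
    }
}

int main()
{
    auto run = [](auto /*has_hot_loop*/, auto tail_num_) {
        constexpr TailNumber tail_num = tail_num_.value; // compile-time again
        return static_cast<int>(tail_num);
    };
    std::cout << tail_handler(run, TailNumber::Odd) << '\n'; // prints 1
}
```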

test/ck_tile/grouped_gemm_quant/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
```diff
@@ -4,7 +4,14 @@ if(CK_USE_OCP_FP8)
 endif()
 
 if(GPU_TARGETS MATCHES "gfx94|gfx95")
-    add_gtest_executable(test_ck_tile_grouped_gemm_quant test_grouped_gemm_quant.cpp)
-    target_compile_options(test_ck_tile_grouped_gemm_quant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+    # Split into three separate test executables for faster parallel compilation
+    add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
+    target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+
+    add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
+    target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+
+    add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
+    target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
 endif()
```
