Skip to content

Commit eb29023

Browse files
committed
Extend AK1 / BK1 support:
- Add support for AK1 != BK1 - Add support for AK1, BK1 > 8 - Introduce KInner template parameter for pipelines when loading multiple tiles with one instruction
1 parent 99f38e4 commit eb29023

24 files changed

+632
-351
lines changed

example/01_gemm/gemm_wmma_fp8_v3.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t;
1313
using ComputeTypeA = ck::f8_t;
1414
using ComputeTypeB = ck::f8_t;
1515

16-
using ALayout = Row;
16+
using ALayout = Col;
1717
using BLayout = Col;
1818
using CLayout = Row;
1919

@@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
3030
PassThrough, PassThrough, PassThrough, GemmDefault,
3131
128,
3232
128, 64, 64,
33-
8, 8,
33+
16, 16, // AK1, BK1
3434
16, 16,
3535
4, 2,
36+
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
37+
1, 4, 16, 0,
3638
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
37-
2, 8, 8, 0,
38-
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
39-
2, 8, 8, 0,
39+
2, 16, 16, 0,
4040
1, 1, S<1, 32, 1, 4>, 8,
4141
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
4242
ComputeTypeA, ComputeTypeB>;

include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
2828
index_t MRepeat,
2929
index_t NRepeat,
3030
index_t KPack,
31+
index_t KInner,
3132
bool TransposeC = false>
3233
constexpr auto BlockGemmPipeline_Selector()
3334
{
@@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector()
5253
MRepeat,
5354
NRepeat,
5455
KPack,
56+
KInner,
5557
TransposeC>{};
5658
}
5759
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
@@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector()
7577
MRepeat,
7678
NRepeat,
7779
KPack,
80+
KInner,
7881
TransposeC>{};
7982
}
8083
else

include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ template <index_t BlockSize,
3030
index_t MRepeat,
3131
index_t NRepeat,
3232
index_t KPack,
33+
index_t KInner,
3334
bool TransposeC = false>
3435
struct BlockwiseGemmWmmaops_pipeline_base
3536
{
@@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
3839
static constexpr auto I2 = Number<2>{};
3940
static constexpr auto I3 = Number<3>{};
4041
static constexpr auto I5 = Number<5>{};
42+
static constexpr auto I6 = Number<6>{};
4143

4244
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
4345

@@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base
5456
static constexpr index_t B_KRow = 1;
5557
#endif
5658

57-
static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
58-
static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
59+
static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
60+
ComputeTypeB,
61+
AccDataType,
62+
MPerWmma,
63+
NPerWmma,
64+
KPack / KInner,
65+
TransposeC>{};
66+
67+
static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
68+
static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
69+
static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);
5970

6071
static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
6172
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
62-
63-
static constexpr auto wmma_gemm =
64-
WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
65-
6673
static constexpr index_t KRepeat = KPerBlock / KPack;
6774

6875
static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
@@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
191198
const auto wmma_krow = 0;
192199
#endif
193200

194-
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
195-
return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
201+
return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
196202
}
197203

198204
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
209215
const auto wmma_krow = 0;
210216
#endif
211217

212-
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
213-
return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
218+
return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
214219
}
215220

216221
template <index_t m0, index_t n0>
@@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
241246
return make_tuple(c_thread_m, c_thread_n);
242247
}
243248

244-
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
249+
using Tuple7 = decltype(CalculateAThreadOriginDataIndex());
245250

246251
/**
247252
* @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
@@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base
261266
* repeat dimensions.
262267
*/
263268
__host__ __device__
264-
BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
265-
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
269+
BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
270+
Tuple7 b_origin = CalculateBThreadOriginDataIndex())
266271
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
267272
{
268273
static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
@@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
343348
Number<KRepeat>{},
344349
I1,
345350
I1,
351+
I1,
346352
Number<A_K1>{}),
347353
make_tuple(Number<A_K1>{},
348354
Number<KPack / A_KRow>{},
349355
Number<KPack / A_KRow * MRepeat>{},
350356
I0,
351357
I0,
358+
I0,
352359
I1));
353360

354361
static constexpr auto b_thread_desc_ =
@@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
357364
Number<KRepeat>{},
358365
I1,
359366
I1,
367+
I1,
360368
Number<B_K1>{}),
361369
make_tuple(Number<B_K1>{},
362370
Number<KPack / B_KRow>{},
363371
Number<KPack / B_KRow * NRepeat>{},
364372
I0,
365373
I0,
374+
I0,
366375
I1));
367376

368377
// C[M, N, NumRegWmma]
@@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
374383
ComputeTypeA,
375384
decltype(a_block_desc_k0_m0_m1_m2_k1),
376385
decltype(a_thread_desc_),
377-
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
378-
Sequence<0, 1, 2, 3, 4, 5>,
379-
5,
386+
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
387+
Sequence<0, 1, 2, 3, 4, 5, 6>,
388+
6,
380389
A_K1,
381390
A_K1>;
382391

@@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
385394
ComputeTypeB,
386395
decltype(b_block_desc_k0_n0_n1_n2_k1),
387396
decltype(b_thread_desc_),
388-
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
389-
Sequence<0, 1, 2, 3, 4, 5>,
390-
5,
397+
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
398+
Sequence<0, 1, 2, 3, 4, 5, 6>,
399+
6,
391400
B_K1,
392401
B_K1>;
393402

0 commit comments

Comments
 (0)