@@ -30,6 +30,7 @@ template <index_t BlockSize,
3030 index_t MRepeat,
3131 index_t NRepeat,
3232 index_t KPack,
33+ index_t KInner,
3334 bool TransposeC = false >
3435struct BlockwiseGemmWmmaops_pipeline_base
3536{
@@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
3839 static constexpr auto I2 = Number<2 >{};
3940 static constexpr auto I3 = Number<3 >{};
4041 static constexpr auto I5 = Number<5 >{};
42+ static constexpr auto I6 = Number<6 >{};
4143
4244 using ThisThreadBlock = ThisThreadBlock<BlockSize>;
4345
@@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base
5456 static constexpr index_t B_KRow = 1 ;
5557#endif
5658
57- static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
58- static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
59+ static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
60+ ComputeTypeB,
61+ AccDataType,
62+ MPerWmma,
63+ NPerWmma,
64+ KPack / KInner,
65+ TransposeC>{};
66+
67+ static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
68+ static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
69+ static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);
5970
6071 static_assert (KPack % (A_K1 * A_KRow) == 0 , " wrong!" );
6172 static_assert (KPack % (B_K1 * B_KRow) == 0 , " wrong!" );
62-
63- static constexpr auto wmma_gemm =
64- WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
65-
6673 static constexpr index_t KRepeat = KPerBlock / KPack;
6774
6875 static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
@@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
191198 const auto wmma_krow = 0 ;
192199#endif
193200
194- // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
195- return make_tuple (0 , 0 , waveId_m, wmma_krow, wmma_a_idx, 0 );
201+ return make_tuple (0 , 0 , 0 , waveId_m, wmma_krow, wmma_a_idx, 0 );
196202 }
197203
198204 __device__ static auto CalculateBThreadOriginDataIndex ()
@@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
209215 const auto wmma_krow = 0 ;
210216#endif
211217
212- // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
213- return make_tuple (0 , 0 , waveId_n, wmma_krow, wmma_b_idx, 0 );
218+ return make_tuple (0 , 0 , 0 , waveId_n, wmma_krow, wmma_b_idx, 0 );
214219 }
215220
216221 template <index_t m0, index_t n0>
@@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
241246 return make_tuple (c_thread_m, c_thread_n);
242247 }
243248
244- using Tuple6 = decltype (CalculateAThreadOriginDataIndex());
249+ using Tuple7 = decltype (CalculateAThreadOriginDataIndex());
245250
246251 /* *
247252 * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
@@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base
261266 * repeat dimensions.
262267 */
263268 __host__ __device__
264- BlockwiseGemmWmmaops_pipeline_base (Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
265- Tuple6 b_origin = CalculateBThreadOriginDataIndex())
269+ BlockwiseGemmWmmaops_pipeline_base (Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
270+ Tuple7 b_origin = CalculateBThreadOriginDataIndex())
266271 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
267272 {
268273 static_assert (AWmmaTileDesc::IsKnownAtCompileTime () &&
@@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
343348 Number<KRepeat>{},
344349 I1,
345350 I1,
351+ I1,
346352 Number<A_K1>{}),
347353 make_tuple (Number<A_K1>{},
348354 Number<KPack / A_KRow>{},
349355 Number<KPack / A_KRow * MRepeat>{},
350356 I0,
351357 I0,
358+ I0,
352359 I1));
353360
354361 static constexpr auto b_thread_desc_ =
@@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
357364 Number<KRepeat>{},
358365 I1,
359366 I1,
367+ I1,
360368 Number<B_K1>{}),
361369 make_tuple (Number<B_K1>{},
362370 Number<KPack / B_KRow>{},
363371 Number<KPack / B_KRow * NRepeat>{},
364372 I0,
365373 I0,
374+ I0,
366375 I1));
367376
368377 // C[M, N, NumRegWmma]
@@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
374383 ComputeTypeA,
375384 decltype (a_block_desc_k0_m0_m1_m2_k1),
376385 decltype (a_thread_desc_),
377- Sequence<KPack / A_K1 / A_KRow, 1 , 1 , 1 , 1 , A_K1>,
378- Sequence<0 , 1 , 2 , 3 , 4 , 5 >,
379- 5 ,
386+ Sequence<KPack / A_K1 / A_KRow, 1 , 1 , 1 , 1 , 1 , A_K1>,
387+ Sequence<0 , 1 , 2 , 3 , 4 , 5 , 6 >,
388+ 6 ,
380389 A_K1,
381390 A_K1>;
382391
@@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
385394 ComputeTypeB,
386395 decltype (b_block_desc_k0_n0_n1_n2_k1),
387396 decltype (b_thread_desc_),
388- Sequence<KPack / B_K1 / B_KRow, 1 , 1 , 1 , 1 , B_K1>,
389- Sequence<0 , 1 , 2 , 3 , 4 , 5 >,
390- 5 ,
397+ Sequence<KPack / B_K1 / B_KRow, 1 , 1 , 1 , 1 , 1 , B_K1>,
398+ Sequence<0 , 1 , 2 , 3 , 4 , 5 , 6 >,
399+ 6 ,
391400 B_K1,
392401 B_K1>;
393402
0 commit comments