@@ -336,7 +336,107 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
336336
337337// The layout example repeat_count=8, systolic_depth=8,
338338// execution_size=16 and operands_per_chan=2 for warp size 32.
339- // DPASInst layout of C operand:
339+ // For A operand:
340+ // systolic depth = 8
341+ // <----------------------------------------------------->
342+ // opsPerChan=2
343+ // <--------->
344+ // t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 ^
345+ // t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
346+ // t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
347+ // t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 | repeat count <= 8
348+ // t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 |
349+ // t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
350+ // t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
351+ // t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 v
352+ // In this case, the LinearLayout bases are:
353+ // Register: {{0,1}, {4,0}}
354+ // Lane: {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}
// Builds the register bases of the DPAS operand-A tile layout.
//
// Every basis is a {dim0 offset, dim1 offset} pair. Two groups are emitted,
// in this order:
//   1. {0, opc} for each power of two opc < opsPerChannel.
//   2. {warp * rowsPerWarp, 0} for each power of two warp < warpRepeats,
//      where rowsPerWarp = threadsPerWarp / systolicDepth and
//      warpRepeats = repeatCount / rowsPerWarp.
//
// Example: opsPerChannel=2, repeatCount=8, threadsPerWarp=32, systolicDepth=8
// produces {{0,1}, {4,0}}.
//
// Precondition: 0 < systolicDepth <= threadsPerWarp. Otherwise rowsPerWarp
// would be zero (or the first division undefined) and computing warpRepeats
// would divide by zero.
std::vector<std::vector<int32_t>> DPASRegBasesA(int opsPerChannel,
                                                int repeatCount,
                                                int threadsPerWarp,
                                                int systolicDepth) {
  assert(systolicDepth > 0 && threadsPerWarp >= systolicDepth &&
         "threadsPerWarp must cover at least one systolicDepth-wide row");

  const int rowsPerWarp = threadsPerWarp / systolicDepth;
  const int warpRepeats = repeatCount / rowsPerWarp;
  std::vector<std::vector<int32_t>> regBases;

  // Registers holding the values packed into one channel step along dim1.
  for (int opc = 1; opc < opsPerChannel; opc *= 2) {
    regBases.push_back({0, opc});
  }

  // Remaining registers step along dim0 in strides of rowsPerWarp, i.e. past
  // the rows already addressed by the lanes of one warp.
  for (int warp = 1; warp < warpRepeats; warp *= 2) {
    regBases.push_back({warp * rowsPerWarp, 0});
  }

  return regBases;
}
373+
// Builds the lane bases of the DPAS operand-A tile layout.
//
// Every basis is a {dim0 offset, dim1 offset} pair. Powers of two below
// systolicDepth map to dim1 offsets scaled by opsPerChannel; powers of two
// from systolicDepth up to (but excluding) threadsPerWarp map to dim0
// offsets of lane / systolicDepth.
//
// Example: opsPerChannel=2, threadsPerWarp=32, systolicDepth=8 produces
// {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}.
std::vector<std::vector<int32_t>>
DPASLaneBasesA(int opsPerChannel, int threadsPerWarp, int systolicDepth) {
  std::vector<std::vector<int32_t>> bases;

  // Lanes within one row: each lane owns opsPerChannel consecutive dim1
  // elements, so the stride between neighbouring lanes is opsPerChannel.
  int lane = 1;
  while (lane < systolicDepth) {
    const int colOffset = opsPerChannel * lane;
    bases.push_back({0, colOffset});
    lane *= 2;
  }

  // Lanes beyond one row advance along dim0.
  lane = systolicDepth;
  while (lane < threadsPerWarp) {
    const int rowOffset = lane / systolicDepth;
    bases.push_back({rowOffset, 0});
    lane *= 2;
  }

  return bases;
}
387+
388+ // For B operand:
389+ // execution size = 16
390+ // <-------------------------------------------------->
391+ // t0 t1 t2 t3 ~ t12 t13 t14 t15 ^ ^
392+ // . . . . . . . . . | opsPerChan=2 |
393+ // t0 t1 t2 t3 ~ t12 t13 t14 t15 v |
394+ // t16 t17 t18 t19 ~ t28 t29 t30 t31 |
395+ // . . . . . . . . . |
396+ // t16 t17 t18 t19 ~ t28 t29 t30 t31 | systolic depth = 8
397+ // t0 t1 t2 t3 ~ t12 t13 t14 t15 |
398+ // . . . . . . . . . |
399+ // t0 t1 t2 t3 ~ t12 t13 t14 t15 |
400+ // t16 t17 t18 t19 ~ t28 t29 t30 t31 |
401+ // . . . . . . . . . |
402+ // t16 t17 t18 t19 ~ t28 t29 t30 t31 v
403+ // In this case, the LinearLayout bases are:
404+ // Register: {{1,0}, {4,0}, {8,0}}
405+ // Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {2,0}}
// Builds the register bases of the DPAS operand-B tile layout.
//
// Every basis is a {dim0 offset, dim1 offset} pair. Two groups are emitted,
// in this order:
//   1. {opc, 0} for each power of two opc < opsPerChannel.
//   2. {rid * opsPerChannel, 0} for rid = rowsPerWarp, 2*rowsPerWarp, ...
//      while rid < systolicDepth, where
//      rowsPerWarp = threadsPerWarp / executionSize.
//
// Example: opsPerChannel=2, executionSize=16, threadsPerWarp=32,
// systolicDepth=8 produces {{1,0}, {4,0}, {8,0}}.
//
// Precondition: 0 < executionSize <= threadsPerWarp. Otherwise rowsPerWarp
// would be zero and the second loop (which doubles rid starting from
// rowsPerWarp) would never terminate.
std::vector<std::vector<int32_t>> DPASRegBasesB(int opsPerChannel,
                                                int executionSize,
                                                int threadsPerWarp,
                                                int systolicDepth) {
  assert(executionSize > 0 && threadsPerWarp >= executionSize &&
         "threadsPerWarp must cover at least one executionSize-wide row");

  const int rowsPerWarp = threadsPerWarp / executionSize;
  std::vector<std::vector<int32_t>> regBases;

  // Registers holding the values packed into one channel step along dim0.
  for (int opc = 1; opc < opsPerChannel; opc *= 2) {
    regBases.push_back({opc, 0});
  }

  // Remaining registers cover the systolic-depth rows not addressed by the
  // lanes of one warp; each such row spans opsPerChannel dim0 elements.
  for (int rid = rowsPerWarp; rid < systolicDepth; rid *= 2) {
    regBases.push_back({rid * opsPerChannel, 0});
  }

  return regBases;
}
423+
// Builds the lane bases of the DPAS operand-B tile layout.
//
// Every basis is a {dim0 offset, dim1 offset} pair. Powers of two below
// executionSize map to dim1 offsets; the remaining lane bits (powers of two
// below threadsPerWarp / executionSize) map to dim0 offsets scaled by
// opsPerChannel.
//
// Example: opsPerChannel=2, threadsPerWarp=32, executionSize=16 produces
// {{0,1}, {0,2}, {0,4}, {0,8}, {2,0}}.
std::vector<std::vector<int32_t>>
DPASLaneBasesB(int opsPerChannel, int threadsPerWarp, int executionSize) {
  std::vector<std::vector<int32_t>> bases;

  // Lanes within one row advance one dim1 element at a time.
  int col = 1;
  while (col < executionSize) {
    bases.push_back({0, col});
    col *= 2;
  }

  // Lanes beyond one row advance along dim0; each row of lanes spans
  // opsPerChannel dim0 elements.
  const int laneRows = threadsPerWarp / executionSize;
  int laneRow = 1;
  while (laneRow < laneRows) {
    const int rowOffset = laneRow * opsPerChannel;
    bases.push_back({rowOffset, 0});
    laneRow *= 2;
  }

  return bases;
}
438+
439+ // For C operand:
340440// execution size = 16
341441// <---------------------------------->
342442// t0 t1 t2 t3 ~ t12 t13 t14 t15 ^
@@ -348,15 +448,13 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
348448// In this case, the LinearLayout bases are:
349449// Register: {{2,0}, {4,0}}
350450// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}}
351- // Currently, LinearLayout is not supported for DotOperandEncoding
352- // so only Operand C conversion is implemented.
353451std::vector<std::vector<int32_t >>
354452DPASRegBasesC (int repeatCount, int executionSize, int threadsPerWarp) {
355453 int rowsPerWarp = threadsPerWarp / executionSize;
356454
357455 std::vector<std::vector<int32_t >> regBases;
358456
359- for (int rid = rowsPerWarp; rid < repeatCount; rid = rid * 2 ) {
457+ for (int rid = rowsPerWarp; rid < repeatCount; rid *= 2 ) {
360458 regBases.push_back ({rid, 0 });
361459 }
362460
@@ -365,25 +463,24 @@ DPASRegBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
365463
// Builds the lane bases of the DPAS operand-C tile layout.
//
// Every basis is a {dim0 offset, dim1 offset} pair. Powers of two below
// executionSize map to dim1 offsets; the remaining lane bits (powers of two
// below threadsPerWarp / executionSize) map to dim0 offsets.
//
// Example: executionSize=16, threadsPerWarp=32 produces
// {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}}.
//
// repeatCount is not read by this function.
std::vector<std::vector<int32_t>>
DPASLaneBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
  std::vector<std::vector<int32_t>> bases;

  // Lanes within one row advance one dim1 element at a time.
  int col = 1;
  while (col < executionSize) {
    bases.push_back({0, col});
    col *= 2;
  }

  // Lanes beyond one row advance one dim0 element at a time.
  const int laneRows = threadsPerWarp / executionSize;
  int laneRow = 1;
  while (laneRow < laneRows) {
    bases.push_back({laneRow, 0});
    laneRow *= 2;
  }

  return bases;
}
381478
382- std::optional<LinearLayout> DPAStoLinearLayout (ArrayRef< int64_t > shape,
383- Attribute layout) {
384-
479+ std::optional<LinearLayout>
480+ DPAStoLinearLayout (ArrayRef< int64_t > shape, Attribute layout, unsigned opIdx ) {
481+ assert (opIdx < 3 && opIdx >= 0 );
385482 auto dpas = dyn_cast<DpasEncodingAttr>(layout);
386- assert (dpas && " Must be DPAS Operand C layout" );
483+ assert (dpas && " Must be DPAS layout" );
387484
388485 int rank = shape.size ();
389486 assert (rank == dpas.getWarpsPerCTA ().size ());
@@ -397,33 +494,80 @@ std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
397494
398495 const SmallVector<unsigned > warpsPerCTA = dpas.getWarpsPerCTA ();
399496 int threadsPerWarp = triton::gpu::getWarpSize (dpas);
497+ unsigned opsPerChannel = dpas.getOpsPerChannel ();
400498 auto repCluster = dpas.getRepCluster ();
401- SmallVector<int64_t > numReps = dpas.getDPASRepetitions (shape, 2 );
499+ SmallVector<int64_t > numReps = dpas.getDPASRepetitions (shape, opIdx );
402500
403501 auto tileLayout = LinearLayout::empty ();
502+ int systolicDepth = dpas.getSystolicDepth ();
404503 int repeatCount = dpas.getRepeatCount ();
405504 int executionSize = dpas.getExecutionSize ();
505+ unsigned KDim = 0 ;
506+ unsigned nonKDim = 0 ;
507+ if (opIdx == 0 ) { // Operand A
508+ auto regBasesA = DPASRegBasesA (opsPerChannel, repeatCount, threadsPerWarp,
509+ systolicDepth);
510+ auto laneBasesA =
511+ DPASLaneBasesA (opsPerChannel, threadsPerWarp, systolicDepth);
512+ tileLayout = LinearLayout ({{kRegister , regBasesA}, {kLane , laneBasesA}},
513+ outDimNames);
514+ // A only repeats by repCluster[0]
515+ tileLayout *=
516+ LinearLayout::identity1D (repCluster[0 ], kRegister , outDimNames[0 ]);
517+ nonKDim = 0 ;
518+ KDim = 1 ;
519+ } else if (opIdx == 1 ) { // Operand B
520+ auto regBasesB = DPASRegBasesB (opsPerChannel, executionSize, threadsPerWarp,
521+ systolicDepth);
522+ auto laneBasesB =
523+ DPASLaneBasesB (opsPerChannel, threadsPerWarp, executionSize);
524+ tileLayout = LinearLayout ({{kRegister , regBasesB}, {kLane , laneBasesB}},
525+ outDimNames);
526+ // B only repeats by repCluster[1]
527+ tileLayout *=
528+ LinearLayout::identity1D (repCluster[1 ], kRegister , outDimNames[1 ]);
529+ nonKDim = 1 ;
530+ KDim = 0 ;
531+ } else { // opIdx=2 -> Operand C
532+ auto regBasesC = DPASRegBasesC (repeatCount, executionSize, threadsPerWarp);
533+ auto laneBasesC =
534+ DPASLaneBasesC (repeatCount, executionSize, threadsPerWarp);
535+ tileLayout = LinearLayout ({{kRegister , regBasesC}, {kLane , laneBasesC}},
536+ outDimNames);
537+ // The per-inst layout is repeated at each repCluster.
538+ // Hence, multiply with the identity layouts starting from the
539+ // least significant dimension.
540+ tileLayout *=
541+ LinearLayout::identity1D (repCluster[1 ], kRegister , outDimNames[1 ]);
542+ tileLayout *=
543+ LinearLayout::identity1D (repCluster[0 ], kRegister , outDimNames[0 ]);
544+ nonKDim = 0 ;
545+ KDim = 1 ;
546+ }
406547
407- auto regBases = DPASRegBasesC (repeatCount, executionSize, threadsPerWarp);
408- auto laneBases = DPASLaneBasesC (repeatCount, executionSize, threadsPerWarp);
409- tileLayout =
410- LinearLayout ({{kRegister , regBases}, {kLane , laneBases}}, outDimNames);
411-
412- // The per-inst layout is repeated at each repCluster.
413- // Hence, multiply with the identity layouts starting from the
414- // least significant dimension.
415- tileLayout *=
416- LinearLayout::identity1D (repCluster[1 ], kRegister , outDimNames[1 ]);
548+ // Operand A/B repeats through the K-dimension first then repeats
549+ // through non-K dimension.
417550 tileLayout *=
418- LinearLayout::identity1D (repCluster[0 ], kRegister , outDimNames[0 ]);
419-
420- // Then, it is repeated by DPASRepetitions to form per-Warp layout.
421- tileLayout *= LinearLayout::identity1D (numReps[1 ], kRegister , outDimNames[1 ]);
422- tileLayout *= LinearLayout::identity1D (numReps[0 ], kRegister , outDimNames[0 ]);
423-
424- // Finally, per-warp layout is repeated among the warps in the CTA.
425- LinearLayout warpLayout =
426- identityND (S (" warp" ), dpas.getWarpsPerCTA (), {0 , 1 }, outDimNames);
551+ LinearLayout::identity1D (numReps[KDim], kRegister , outDimNames[KDim]);
552+ tileLayout *= LinearLayout::identity1D (numReps[nonKDim], kRegister ,
553+ outDimNames[nonKDim]);
554+
555+ // For Operand C, warps split the tensor identically.
556+ // For Operand A and B, warps in the K-dimension share the same data.
557+ // In these cases, the warp hops for K-dimensions are zero.
558+ LinearLayout warpLayout = LinearLayout::empty ();
559+ StringAttr kWarp = S (" warp" );
560+ if (opIdx == 0 ) {
561+ warpLayout =
562+ LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
563+ warpLayout *= LinearLayout::zeros1D (warpsPerCTA[1 ], kWarp , outDimNames[1 ]);
564+ } else if (opIdx == 1 ) {
565+ warpLayout = LinearLayout::zeros1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
566+ warpLayout *=
567+ LinearLayout::identity1D (warpsPerCTA[1 ], kWarp , outDimNames[1 ]);
568+ } else { /* opIdx == 2 */
569+ warpLayout = identityND (kWarp , warpsPerCTA, {0 , 1 }, outDimNames);
570+ }
427571 LinearLayout ctaLayout = tileLayout * warpLayout;
428572
429573 return combineCtaCgaWithShape (ctaLayout, CTALayoutAttr::getDefault (ctx, rank),
0 commit comments