Skip to content

Commit 75eee6f

Browse files
authored
DPAS operand A and operand B conversion to LinearLayout (#1746)
This PR adds DPAS -> LinearLayout conversion of operand A and B layouts. Currently, Triton does not use LinearLayout conversion for DotOperand layouts (A and B). I have included operand A/B support in DPAStoLinearLayout function for potential future use. It tested with the unit tests added.
1 parent e542e39 commit 75eee6f

File tree

3 files changed

+254
-32
lines changed

3 files changed

+254
-32
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@
1111

1212
namespace mlir::triton::gpu {
1313

14+
// DPAS operand A: opIdx=0
15+
// DPAS operand B: opIdx=1
16+
// DPAS operand C (default): opIdx=2
17+
// Operand A and B conversion are not used yet
1418
std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
15-
Attribute layout);
19+
Attribute layout,
20+
unsigned opIdx = 2);
1621

1722
} // namespace mlir::triton::gpu
1823

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 175 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,107 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
336336

337337
// The layout example repeat_count=8, systolic_depth=8,
338338
// execution_size=16 and operands_per_chan=2 for warp size 32.
339-
// DPASInst layout of C operand:
339+
// For A operand:
340+
// systolic depth = 8
341+
//<----------------------------------------------------->
342+
// opsPerChan=2
343+
//<--------->
344+
// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 ^
345+
// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
346+
// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
347+
// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 | repeat count <= 8
348+
// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 |
349+
// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
350+
// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
351+
// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 v
352+
// In this case, the LinearLayout bases are:
353+
// Register: {{0,1}, {4,0}}
354+
// Lane: {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}
355+
std::vector<std::vector<int32_t>> DPASRegBasesA(int opsPerChannel,
356+
int repeatCount,
357+
int threadsPerWarp,
358+
int systolicDepth) {
359+
int rowPerWarp = threadsPerWarp / systolicDepth;
360+
int warpRepeats = repeatCount / rowPerWarp;
361+
std::vector<std::vector<int32_t>> regBases;
362+
363+
for (int opc = 1; opc < opsPerChannel; opc *= 2) {
364+
regBases.push_back({0, opc});
365+
}
366+
367+
for (int warp = 1; warp < warpRepeats; warp *= 2) {
368+
regBases.push_back({warp * rowPerWarp, 0});
369+
}
370+
371+
return regBases;
372+
}
373+
374+
std::vector<std::vector<int32_t>>
375+
DPASLaneBasesA(int opsPerChannel, int threadsPerWarp, int systolicDepth) {
376+
std::vector<std::vector<int32_t>> laneBases;
377+
378+
for (int tid = 1; tid < systolicDepth; tid *= 2) {
379+
laneBases.push_back({0, opsPerChannel * tid});
380+
}
381+
for (int tid = systolicDepth; tid < threadsPerWarp; tid *= 2) {
382+
laneBases.push_back({tid / systolicDepth, 0});
383+
}
384+
385+
return laneBases;
386+
}
387+
388+
// For B operand:
389+
// execution size = 16
390+
//<-------------------------------------------------->
391+
// t0 t1 t2 t3 ~ t12 t13 t14 t15 ^ ^
392+
//. . . . . . . . . | opsPerChan=2 |
393+
// t0 t1 t2 t3 ~ t12 t13 t14 t15 v |
394+
// t16 t17 t18 t19 ~ t28 t29 t30 t31 |
395+
//. . . . . . . . . |
396+
// t16 t17 t18 t19 ~ t28 t29 t30 t31 | systolic depth = 8
397+
// t0 t1 t2 t3 ~ t12 t13 t14 t15 |
398+
//. . . . . . . . . |
399+
// t0 t1 t2 t3 ~ t12 t13 t14 t15 |
400+
// t16 t17 t18 t19 ~ t28 t29 t30 t31 |
401+
//. . . . . . . . . |
402+
// t16 t17 t18 t19 ~ t28 t29 t30 t31 v
403+
// In this case, the LinearLayout bases are:
404+
// Register: {{1,0}, {4,0}, {8,0}}
405+
// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {2,0}}
406+
std::vector<std::vector<int32_t>> DPASRegBasesB(int opsPerChannel,
407+
int executionSize,
408+
int threadsPerWarp,
409+
int systolicDepth) {
410+
int rowsPerWarp = threadsPerWarp / executionSize;
411+
int warpRepeats = systolicDepth / rowsPerWarp;
412+
std::vector<std::vector<int32_t>> regBases;
413+
414+
for (int opc = 1; opc < opsPerChannel; opc *= 2) {
415+
regBases.push_back({opc, 0});
416+
}
417+
for (int rid = rowsPerWarp; rid < systolicDepth; rid *= 2) {
418+
regBases.push_back({rid * opsPerChannel, 0});
419+
}
420+
421+
return regBases;
422+
}
423+
424+
std::vector<std::vector<int32_t>>
425+
DPASLaneBasesB(int opsPerChannel, int threadsPerWarp, int executionSize) {
426+
std::vector<std::vector<int32_t>> laneBases;
427+
428+
for (int tid = 1; tid < executionSize; tid *= 2) {
429+
laneBases.push_back({0, tid});
430+
}
431+
int rowsPerWarp = threadsPerWarp / executionSize;
432+
for (int row = 1; row < rowsPerWarp; row *= 2) {
433+
laneBases.push_back({row * opsPerChannel, 0});
434+
}
435+
436+
return laneBases;
437+
}
438+
439+
// For C operand:
340440
// execution size = 16
341441
//<---------------------------------->
342442
// t0 t1 t2 t3 ~ t12 t13 t14 t15 ^
@@ -348,15 +448,13 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
348448
// In this case, the LinearLayout bases are:
349449
// Register: {{2,0}, {4,0}}
350450
// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}}
351-
// Currently, LinearLayout is not supported for DotOperandEncoding
352-
// so only Operand C conversion is implemented.
353451
std::vector<std::vector<int32_t>>
354452
DPASRegBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
355453
int rowsPerWarp = threadsPerWarp / executionSize;
356454

357455
std::vector<std::vector<int32_t>> regBases;
358456

359-
for (int rid = rowsPerWarp; rid < repeatCount; rid = rid * 2) {
457+
for (int rid = rowsPerWarp; rid < repeatCount; rid *= 2) {
360458
regBases.push_back({rid, 0});
361459
}
362460

@@ -365,25 +463,24 @@ DPASRegBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
365463

366464
std::vector<std::vector<int32_t>>
367465
DPASLaneBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
368-
369466
std::vector<std::vector<int32_t>> laneBases;
370467

371-
for (int tid = 1; tid < executionSize; tid = tid * 2) {
468+
for (int tid = 1; tid < executionSize; tid *= 2) {
372469
laneBases.push_back({0, tid});
373470
}
374471
int rowsPerWarp = threadsPerWarp / executionSize;
375-
for (int row = 1; row < rowsPerWarp; row = row * 2) {
472+
for (int row = 1; row < rowsPerWarp; row *= 2) {
376473
laneBases.push_back({row, 0});
377474
}
378475

379476
return laneBases;
380477
}
381478

382-
std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
383-
Attribute layout) {
384-
479+
std::optional<LinearLayout>
480+
DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
481+
assert(opIdx < 3 && opIdx >= 0);
385482
auto dpas = dyn_cast<DpasEncodingAttr>(layout);
386-
assert(dpas && "Must be DPAS Operand C layout");
483+
assert(dpas && "Must be DPAS layout");
387484

388485
int rank = shape.size();
389486
assert(rank == dpas.getWarpsPerCTA().size());
@@ -397,33 +494,80 @@ std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
397494

398495
const SmallVector<unsigned> warpsPerCTA = dpas.getWarpsPerCTA();
399496
int threadsPerWarp = triton::gpu::getWarpSize(dpas);
497+
unsigned opsPerChannel = dpas.getOpsPerChannel();
400498
auto repCluster = dpas.getRepCluster();
401-
SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, 2);
499+
SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
402500

403501
auto tileLayout = LinearLayout::empty();
502+
int systolicDepth = dpas.getSystolicDepth();
404503
int repeatCount = dpas.getRepeatCount();
405504
int executionSize = dpas.getExecutionSize();
505+
unsigned KDim = 0;
506+
unsigned nonKDim = 0;
507+
if (opIdx == 0) { // Operand A
508+
auto regBasesA = DPASRegBasesA(opsPerChannel, repeatCount, threadsPerWarp,
509+
systolicDepth);
510+
auto laneBasesA =
511+
DPASLaneBasesA(opsPerChannel, threadsPerWarp, systolicDepth);
512+
tileLayout = LinearLayout({{kRegister, regBasesA}, {kLane, laneBasesA}},
513+
outDimNames);
514+
// A only repeats by repCluster[0]
515+
tileLayout *=
516+
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
517+
nonKDim = 0;
518+
KDim = 1;
519+
} else if (opIdx == 1) { // Operand B
520+
auto regBasesB = DPASRegBasesB(opsPerChannel, executionSize, threadsPerWarp,
521+
systolicDepth);
522+
auto laneBasesB =
523+
DPASLaneBasesB(opsPerChannel, threadsPerWarp, executionSize);
524+
tileLayout = LinearLayout({{kRegister, regBasesB}, {kLane, laneBasesB}},
525+
outDimNames);
526+
// B only repeats by repCluster[1]
527+
tileLayout *=
528+
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
529+
nonKDim = 1;
530+
KDim = 0;
531+
} else { // opIdx=2 -> Operand C
532+
auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
533+
auto laneBasesC =
534+
DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
535+
tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
536+
outDimNames);
537+
// The per-inst layout is repeated at each repCluster.
538+
// Hence, multiply with the identity layouts starting from the
539+
// least significant dimension.
540+
tileLayout *=
541+
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
542+
tileLayout *=
543+
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
544+
nonKDim = 0;
545+
KDim = 1;
546+
}
406547

407-
auto regBases = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
408-
auto laneBases = DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
409-
tileLayout =
410-
LinearLayout({{kRegister, regBases}, {kLane, laneBases}}, outDimNames);
411-
412-
// The per-inst layout is repeated at each repCluster.
413-
// Hence, multiply with the identity layouts starting from the
414-
// least significant dimension.
415-
tileLayout *=
416-
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
548+
// Operand A/B repeats through the K-dimension first then repeats
549+
// through non-K dimension.
417550
tileLayout *=
418-
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
419-
420-
// Then, it is repeated by DPASRepetitions to form per-Warp layout.
421-
tileLayout *= LinearLayout::identity1D(numReps[1], kRegister, outDimNames[1]);
422-
tileLayout *= LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
423-
424-
// Finally, per-warp layout is repeated among the warps in the CTA.
425-
LinearLayout warpLayout =
426-
identityND(S("warp"), dpas.getWarpsPerCTA(), {0, 1}, outDimNames);
551+
LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
552+
tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
553+
outDimNames[nonKDim]);
554+
555+
// For Operand C, warps split the tensor identically.
556+
// For Operand A and B, warps in the K-dimension share the same data.
557+
// In these cases, the warp hops for K-dimensions are zero.
558+
LinearLayout warpLayout = LinearLayout::empty();
559+
StringAttr kWarp = S("warp");
560+
if (opIdx == 0) {
561+
warpLayout =
562+
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
563+
warpLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
564+
} else if (opIdx == 1) {
565+
warpLayout = LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
566+
warpLayout *=
567+
LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
568+
} else { /* opIdx == 2 */
569+
warpLayout = identityND(kWarp, warpsPerCTA, {0, 1}, outDimNames);
570+
}
427571
LinearLayout ctaLayout = tileLayout * warpLayout;
428572

429573
return combineCtaCgaWithShape(ctaLayout, CTALayoutAttr::getDefault(ctx, rank),

unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class DPAStoLinearLayoutTest : public ::testing::Test {
3939
};
4040

4141
TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
42+
// Default: Operand C
4243
EXPECT_EQ(DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32)),
4344
LinearLayout(
4445
{
@@ -57,6 +58,28 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
5758
{S("block"), {}},
5859
},
5960
{S("dim0"), S("dim1")}));
61+
// Test Operand A (opIdx=0)
62+
EXPECT_EQ(
63+
DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 0),
64+
LinearLayout(
65+
{
66+
{S("register"), {{0, 1}, {4, 0}}},
67+
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
68+
{S("warp"), {}},
69+
{S("block"), {}},
70+
},
71+
{S("dim0"), S("dim1")}));
72+
// Test Operand B (opIdx=1)
73+
EXPECT_EQ(
74+
DPAStoLinearLayout({16, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 1),
75+
LinearLayout(
76+
{
77+
{S("register"), {{1, 0}, {4, 0}, {8, 0}}},
78+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
79+
{S("warp"), {}},
80+
{S("block"), {}},
81+
},
82+
{S("dim0"), S("dim1")}));
6083
}
6184

6285
TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) {
@@ -70,6 +93,28 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) {
7093
{S("block"), {}},
7194
},
7295
{S("dim0"), S("dim1")}));
96+
// Test Operand A (opIdx=0)
97+
EXPECT_EQ(
98+
DPAStoLinearLayout({32, 16}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 0),
99+
LinearLayout(
100+
{
101+
{S("register"), {{0, 1}, {4, 0}, {8, 0}, {16, 0}}},
102+
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
103+
{S("warp"), {}},
104+
{S("block"), {}},
105+
},
106+
{S("dim0"), S("dim1")}));
107+
// Test Operand B (opIdx=1)
108+
EXPECT_EQ(
109+
DPAStoLinearLayout({16, 32}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 1),
110+
LinearLayout(
111+
{
112+
{S("register"), {{1, 0}, {4, 0}, {8, 0}, {0, 16}}},
113+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
114+
{S("warp"), {}},
115+
{S("block"), {}},
116+
},
117+
{S("dim0"), S("dim1")}));
73118
EXPECT_EQ(DPAStoLinearLayout({32, 32}, dpas({1, 1}, 8, 8, 16, 1, {4, 2}, 16)),
74119
LinearLayout(
75120
{
@@ -103,6 +148,34 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarp) {
103148
{S("dim0"), S("dim1")}));
104149
}
105150

151+
TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandA) {
152+
EXPECT_EQ(
153+
DPAStoLinearLayout({64, 64}, dpas({2, 2}, 8, 8, 16, 2, {4, 2}, 32), 0),
154+
LinearLayout(
155+
{
156+
{S("register"),
157+
{{0, 1}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}},
158+
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
159+
{S("warp"), {{32, 0}, {0, 0}}},
160+
{S("block"), {}},
161+
},
162+
{S("dim0"), S("dim1")}));
163+
}
164+
165+
TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandB) {
166+
EXPECT_EQ(
167+
DPAStoLinearLayout({64, 64}, dpas({2, 2}, 8, 8, 16, 2, {4, 2}, 32), 1),
168+
LinearLayout(
169+
{
170+
{S("register"),
171+
{{1, 0}, {4, 0}, {8, 0}, {0, 16}, {16, 0}, {32, 0}}},
172+
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
173+
{S("warp"), {{0, 0}, {0, 32}}},
174+
{S("block"), {}},
175+
},
176+
{S("dim0"), S("dim1")}));
177+
}
178+
106179
TEST_F(DPAStoLinearLayoutTest, DPAS_withDPASRepetitions) {
107180
EXPECT_EQ(DPAStoLinearLayout({64, 64}, dpas({2, 1}, 8, 8, 16, 2, {4, 2}, 32)),
108181
LinearLayout(

0 commit comments

Comments
 (0)