@@ -21,8 +21,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
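The signature change above generalizes warpsPerTileV2 from the concrete DotOp class to a plain Operation *, so dot-like ops other than tt.dot can reuse the same heuristic. A minimal sketch of a call site under that assumption (pickWarps is a hypothetical name, not part of this patch):

    // Typed handles still work: OpState::getOperation() returns the
    // underlying Operation *.
    SmallVector<unsigned> pickWarps(DotOp dotOp, ArrayRef<int64_t> shape,
                                    int numWarps) {
      return warpsPerTileV2(dotOp.getOperation(), shape, numWarps);
    }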
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
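On a generic Operation * the single result must be fetched by index, hence dotOp->getResult(0); the typed dotOp.getResult() accessor is no longer available at this interface. The forward slice is what detects the chained-dot pattern the comment refers to. A sketch of that detection in isolation (feedsAnotherDot is an illustrative name, assuming the same includes as this file):

    // A chained dot: the result feeds another dot, as in flash attention's
    // QK^T -> softmax -> PV sequence.
    bool feedsAnotherDot(Operation *dotOp) {
      SetVector<Operation *> slices;
      mlir::getForwardSlice(dotOp->getResult(0), &slices);
      return llvm::any_of(slices,
                          [](Operation *op) { return isa<DotOp>(op); });
    }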
@@ -181,6 +179,21 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
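Instead of allocating shared memory directly from a dot-operand-encoded tensor, the new branch first converts back to the parent blocked encoding. The rough effect on the IR, sketched with abbreviated encodings (illustrative, not compiler output):

    // before:  %alloc = local_alloc %a
    //            with %a : tensor<MxKxf16, #dot_operand<parent = #blocked>>
    // after:   %cvt   = convert_layout %a   // #dot_operand -> #blocked
    //          %alloc = local_alloc %cvt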
@@ -204,7 +217,7 @@ getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -218,6 +231,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
 }
 
 static bool bwdFilter(Operation *op) {
+  // Dot operand layout assignment to predicates (i1) is not currently
+  // supported during lowering from TritonGPU to LLVM for MMA cases. This
+  // check stops the backward slice at the upcast so the original bit-width
+  // of a predicate is never considered and kwidth can never be 32.
+  if (isa<arith::UIToFPOp>(op)) {
+    Type srcType = getElementTypeOrSelf(op->getOperand(0));
+    if (srcType.isInteger(1))
+      return false;
+  }
+
   return op->getNumOperands() == 1 &&
          (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
           isPureUnaryInlineAsm(op) ||
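Without the early return, computeOrigBitWidth (below) would walk backward through the upcast, see a 1-bit source, and derive an unsupported kwidth of 32. The pattern being fenced off, sketched as illustrative IR:

    // %mask = arith.cmpf olt, %x, %y : tensor<128x64xf16>
    // %a    = arith.uitofp %mask : tensor<128x64xi1> to tensor<128x64xf16>
    // %d    = tt.dot %a, %b, %acc   // slice from %a must stop at the uitofp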
@@ -237,7 +260,7 @@ static bool bwdFilter(Operation *op) {
 // result, kwidth can be the bitwidth of the lower precision primitive.
 // Conversely, in the downcasting scenario, no reordering is performed,
 // making it directly use the lower precision primitive.
-static int computeOrigBitWidth(Value x) {
+int computeOrigBitWidth(Value x) {
   int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
   int origBitWidth = finalBitWidth;
   SetVector<Operation *> slice;
@@ -257,6 +280,9 @@ static int computeOrigBitWidth(Value x) {
   }
   return origBitWidth;
 }
+// Move the anonymous namespace down so that getWarpsPerTile is visible to
+// the sparsity extension.
+namespace {
 
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
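With getWarpsPerTile and computeOrigBitWidth now defined before the anonymous namespace, and the latter stripped of static, both symbols have external linkage. A sketch of what a consumer might look like, assuming the sparsity extension re-declares the symbols in its own translation unit rather than through a shared header:

    // Hypothetical consumer .cpp in the sparsity extension; signatures
    // copied from the definitions above.
    namespace mlir::triton::gpu {
    SmallVector<unsigned, 3>
    getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape,
                    int version, int numWarps,
                    const SmallVector<unsigned, 3> &instrShape);
    int computeOrigBitWidth(Value x);
    } // namespace mlir::triton::gpu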
@@ -1147,6 +1173,11 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
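getSharedMemMMAOperand is a thin, externally visible wrapper over the file-local getSharedMemoryMMAOperand, exported for the same reason as the namespace move. A hypothetical call from a rewrite pattern in the extension (the getA() accessor and surrounding names are illustrative):

    // Place operand A of a dot-like op into shared memory for the MMA path.
    Value aShared = getSharedMemMMAOperand(dotOp.getA(), rewriter,
                                           /*opIdx=*/0,
                                           /*allowTranspose=*/true);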