@@ -21,8 +21,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
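With warpsPerTileV2/V3 now taking a plain Operation *, callers reach the dot result through getResult(0) instead of the typed DotOp::getResult(). A minimal sketch of such a caller, assuming it sits in this file after the definitions above; the wrapper name warpsForAnyDotV3 is illustrative and not part of the patch:

    // Illustrative only: route any dot-like op (DotOp or a sparse variant)
    // through the generic Operation interface to warpsPerTileV3.
    static SmallVector<unsigned, 2>
    warpsForAnyDotV3(Operation *dotLikeOp, int numWarps,
                     const SmallVector<unsigned, 3> &instrShape) {
      auto shape =
          cast<RankedTensorType>(dotLikeOp->getResult(0).getType()).getShape();
      return warpsPerTileV3(dotLikeOp, shape, numWarps, instrShape);
    }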
@@ -170,11 +168,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -188,6 +201,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
 }
 
 static bool bwdFilter(Operation *op) {
+  // Dot operand layout assignment to predicates is not currently supported
+  // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+  // condition limits visibility of the original bit-width so that predicates
+  // are not considered, hence kwidth can never be 32.
+  if (isa<arith::UIToFPOp>(op)) {
+    Type srcType = getElementTypeOrSelf(op->getOperand(0));
+    if (srcType.isInteger(1))
+      return false;
+  }
+
   return op->getNumOperands() == 1 &&
          (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
           isPureUnaryInlineAsm(op) ||
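bwdFilter is the predicate that computeOrigBitWidth hands to MLIR's backward-slice query, so rejecting uitofp-from-i1 here keeps predicate producers out of the slice and the reported bit-width at the float width. A rough sketch of that usage pattern, assuming the standard SliceAnalysis API; the helper name sliceThroughCasts is made up for illustration:

    // Illustrative only: collect the producers of `x` that pass bwdFilter,
    // the same query computeOrigBitWidth performs to find the original width.
    static SetVector<Operation *> sliceThroughCasts(Value x) {
      SetVector<Operation *> slice;
      BackwardSliceOptions opt;
      opt.omitBlockArguments = true;
      opt.filter = bwdFilter; // i1 -> float conversions are now filtered out
      getBackwardSlice(x, &slice, opt);
      return slice;
    }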
@@ -207,7 +230,7 @@ static bool bwdFilter(Operation *op) {
 // result, kwidth can be the bitwidth of the lower precision primitive.
 // Conversely, in the downcasting scenario, no reordering is performed,
 // making it directory use the lower precision primitive.
-static int computeOrigBitWidth(Value x) {
+int computeOrigBitWidth(Value x) {
   int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
   int origBitWidth = finalBitWidth;
   SetVector<Operation *> slice;
@@ -227,6 +250,9 @@ static int computeOrigBitWidth(Value x) {
   }
   return origBitWidth;
 }
+// Move the anonymous namespace down so that getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
 
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
@@ -632,7 +658,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 =
+          llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
@@ -1018,6 +1045,11 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
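getSharedMemMMAOperand simply re-exports the file-local, static getSharedMemoryMMAOperand under a name reachable from outside this file, in the same spirit as the namespace move above. A minimal sketch of how a separate translation unit (for example a sparsity extension) might consume it; the forward declaration and the caller below are assumptions for illustration, not part of this patch:

    // Hypothetical external caller in another translation unit.
    #include "mlir/IR/PatternMatch.h" // mlir::PatternRewriter, mlir::Value

    namespace mlir::triton::gpu {
    // Forward declaration of the wrapper added by this patch.
    Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
                                 int opIdx, bool allowTranspose);
    } // namespace mlir::triton::gpu

    // Illustrative use: place operand A of a dot-like op into shared memory,
    // allowing a transposed shared-memory layout.
    static mlir::Value promoteToSharedMemOperand(mlir::Value a,
                                                 mlir::PatternRewriter &rewriter) {
      return mlir::triton::gpu::getSharedMemMMAOperand(a, rewriter, /*opIdx=*/0,
                                                       /*allowTranspose=*/true);
    }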