@@ -21,15 +21,13 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
   SmallVector<int> versionsSupported;
   if (computeCapability < 75) {
     versionsSupported = {1};
-  } else if (computeCapability < 90) {
+  } else if (computeCapability < 90 || computeCapability >= 100) {
     versionsSupported = {2};
   } else if (computeCapability < 100) {
     versionsSupported = {3, 2};
@@ -45,8 +43,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -110,10 +108,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -168,11 +166,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -185,18 +198,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move the anonymous namespace down so that getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates is not currently supported
+    // when lowering from TritonGPU to LLVM for MMA cases. This condition
+    // limits visibility of the original bit-width so that predicates are not
+    // considered, and hence kwidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
             op->getDialect()->getTypeID() ==
                 mlir::TypeID::get<arith::ArithDialect>());
   }
 
+public:
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
@@ -720,6 +747,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
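
For reference, the helpers exported above are intended to be consumed from a separate sparsity extension that links against this file. The following is a minimal sketch of how such an extension might declare and call getWarpsPerTile and getSharedMemMMAOperand under that assumption; the configureSparseDot helper, the chosen MMA version, and the m16n8k16 instruction shape are illustrative assumptions, not part of this change.

// Sketch only: re-declares the helpers made visible by this change and shows
// a hypothetical call site; names marked "assumed" do not come from the patch.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::triton::gpu {

// Declarations matching the helpers exposed in AccelerateMatmul.cpp.
SmallVector<unsigned, 3>
getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                int numWarps, const SmallVector<unsigned, 3> &instrShape);
int computeOrigBitWidth(Value x);
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
                             int opIdx, bool allowTranspose);

// Assumed helper: pick a warp tiling and stage operand A through shared
// memory for an MMA-v2 style lowering of a sparse dot operation.
static void configureSparseDot(Operation *dotOp, PatternRewriter &rewriter,
                               int numWarps) {
  auto resultType =
      mlir::cast<RankedTensorType>(dotOp->getResult(0).getType());
  // Warp tiling for an assumed m16n8k16 instruction shape.
  auto warpsPerTile = getWarpsPerTile(dotOp, resultType.getShape(),
                                      /*version=*/2, numWarps,
                                      /*instrShape=*/{16, 8, 16});
  // Route operand A through shared memory, allowing a transposed layout.
  Value aInSmem = getSharedMemMMAOperand(dotOp->getOperand(0), rewriter,
                                         /*opIdx=*/0, /*allowTranspose=*/true);
  (void)warpsPerTile;
  (void)aInSmem;
}

} // namespace mlir::triton::gpu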