@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
153
153
auto newType = MemDescType::get (argType.getShape (), argType.getElementType (),
154
154
newLayout, SharedMemorySpace);
155
155
rewriter.setInsertionPointAfterValue (arg);
156
+
157
+ // LocalAllocOp lowering doesn't support going from DotOperandEncoding
158
+ // to SharedEncoding.
159
+ if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
160
+ argType.getEncoding ())) {
161
+ // Create a layout conversion from DotOperandEncoding to BlockedEncoding
162
+ // then pass it to the LocalAllocOp.
163
+ auto newArgType = RankedTensorType::get (
164
+ argType.getShape (), argType.getElementType (), dotOpEnc.getParent ());
165
+ auto dotOperandToBlockedCvt =
166
+ rewriter.create <ConvertLayoutOp>(arg.getLoc (), newArgType, arg);
167
+ return rewriter.create <LocalAllocOp>(arg.getLoc (), newType,
168
+ dotOperandToBlockedCvt);
169
+ }
170
+
156
171
return rewriter.create <LocalAllocOp>(arg.getLoc (), newType, arg);
157
172
}
158
173
@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
162
177
mutable llvm::DenseMap<Operation *, unsigned > dotOpInstNs;
163
178
164
179
static bool bwdFilter (Operation *op) {
180
+ // Dot operand layout assignment to predicates is not currently supported
181
+ // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
182
+ // condition limits visibility of the original bit-width so that predicates
183
+ // are not considered, hence, kwidth can never be = 32.
184
+ if (isa<arith::UIToFPOp>(op)) {
185
+ Type srcType = getElementTypeOrSelf (op->getOperand (0 ));
186
+ if (srcType.isInteger (1 ))
187
+ return false ;
188
+ }
165
189
return op->getNumOperands () == 1 &&
166
190
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
167
191
isPureUnaryInlineAsm (op) ||
@@ -357,7 +381,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
357
381
NvidiaMmaEncodingAttr mmaLayout =
358
382
dyn_cast<NvidiaMmaEncodingAttr>(D.getType ().getEncoding ());
359
383
if (mmaLayout) {
360
- bool isNativeFP8 = AElType.isFloat8E5M2 () || AElType.isFloat8E4M3FNUZ ();
384
+ bool isNativeFP8 = AElType.isFloat8E5M2 () || AElType.isFloat8E4M3FN ();
361
385
// promote operands for sm < 89 since fp8 mma is not natively supported
362
386
// promote operands for sm >= 90 when mma is not v3
363
387
if (!isNativeFP8 ||
0 commit comments