@@ -12,7 +12,8 @@ namespace mlir::triton::gpu {
 
 namespace {
 
-template <typename T> bool hasEncoding(Value value) {
+template <typename T>
+bool hasEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = dyn_cast<TensorOrMemDesc>(type)) {
     auto encoding = tensorType.getEncoding();
@@ -25,7 +26,7 @@ bool hasDotOperandEncoding(Value value) {
   return hasEncoding<triton::gpu::DotOperandEncodingAttr>(value);
 }
 
-} // namespace
+}  // namespace
 
 //===----------------------------------------------------------------------===//
 // Canonicalizer
@@ -36,16 +37,13 @@ struct CanonicalizeConvertFromReshape
     : public mlir::OpRewritePattern<triton::ReshapeOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(triton::ReshapeOp op,
-                  PatternRewriter &rewriter) const override {
+  mlir::LogicalResult matchAndRewrite(
+      triton::ReshapeOp op, PatternRewriter &rewriter) const override {
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
-    if (!convert)
-      return failure();
+    if (!convert) return failure();
     if (isExpensiveView(convert.getSrc().getType(), op.getType()))
       return failure();
-    if (!op.getAllowReorder() || op.getEfficientLayout())
-      return failure();
+    if (!op.getAllowReorder() || op.getEfficientLayout()) return failure();
 
     rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
         op, op.getType(), convert.getSrc(), op.getAllowReorder());
@@ -58,12 +56,10 @@ struct CanonicalizeConvertFromHistogram
     : public mlir::OpRewritePattern<triton::HistogramOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(triton::HistogramOp op,
-                  PatternRewriter &rewriter) const override {
+  mlir::LogicalResult matchAndRewrite(
+      triton::HistogramOp op, PatternRewriter &rewriter) const override {
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
-    if (!convert)
-      return failure();
+    if (!convert) return failure();
     rewriter.replaceOpWithNewOp<triton::HistogramOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
@@ -79,15 +75,13 @@ struct CanonicalizeConvertFromHistogram
 struct CanonicalizeConvertFromGatherSource : public OpRewritePattern<GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override {
+  mlir::LogicalResult matchAndRewrite(
+      GatherOp op, PatternRewriter &rewriter) const override {
     // Don't do this if the compiler picked an optimized layout.
-    if (op.getEfficientLayout())
-      return failure();
+    if (op.getEfficientLayout()) return failure();
 
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
-    if (!convert)
-      return failure();
+    if (!convert) return failure();
 
     rewriter.replaceOpWithNewOp<GatherOp>(op, convert.getSrc(), op.getIndices(),
                                           op.getAxis());
@@ -100,13 +94,15 @@ struct CanonicalizeConvertFromAlloc
     : public mlir::OpRewritePattern<triton::gpu::LocalAllocOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(triton::gpu::LocalAllocOp op,
-                  PatternRewriter &rewriter) const override {
-    if (!op.getSrc())
-      return failure();
+  mlir::LogicalResult matchAndRewrite(
+      triton::gpu::LocalAllocOp op, PatternRewriter &rewriter) const override {
+    if (!op.getSrc()) return failure();
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
-    if (!convert)
+    if (!convert) return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
       return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
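
Note on the new guard in `CanonicalizeConvertFromAlloc`: the pattern now refuses to fold the `ConvertLayoutOp` away when its source tensor carries a `DotOperandEncodingAttr`, because the `LocalAllocOp` lowering cannot move that layout into shared memory directly. A standalone sketch of the same check (the `keepConversion` helper and include paths are assumptions for illustration, not part of this file):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Hypothetical helper mirroring the guard above: returns true when folding the
// convert away would hand LocalAllocOp a dot-operand layout it cannot lower.
static bool keepConversion(mlir::Value convertSrc) {
  auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(convertSrc.getType());
  if (!tensorTy)
    return false;
  return mlir::isa_and_nonnull<mlir::triton::gpu::DotOperandEncodingAttr>(
      tensorTy.getEncoding());
}
```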
@@ -119,12 +115,10 @@ struct CanonicalizeConvertFromLocalStore
     : public mlir::OpRewritePattern<triton::gpu::LocalStoreOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(triton::gpu::LocalStoreOp op,
-                  PatternRewriter &rewriter) const override {
+  mlir::LogicalResult matchAndRewrite(
+      triton::gpu::LocalStoreOp op, PatternRewriter &rewriter) const override {
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
-    if (!convert)
-      return failure();
+    if (!convert) return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalStoreOp>(op, convert.getSrc(),
                                                            op.getDst());
     return mlir::success();
@@ -135,19 +129,16 @@ struct CanonicalizeConvertFromSplit
     : public mlir::OpRewritePattern<triton::SplitOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(triton::SplitOp op,
-                  PatternRewriter &rewriter) const override {
+  mlir::LogicalResult matchAndRewrite(
+      triton::SplitOp op, PatternRewriter &rewriter) const override {
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
-    if (!convert)
-      return failure();
+    if (!convert) return failure();
     auto srcEncoding = convert.getSrc().getType().getEncoding();
     // Multiple source layout can give the same output layout, if the source
     // layout of the convert gives the same destination layout we can skip the
     // convert.
     auto dstEncoding = inferDstEncoding(op, srcEncoding);
-    if (dstEncoding != op.getOutLHS().getType().getEncoding())
-      return failure();
+    if (dstEncoding != op.getOutLHS().getType().getEncoding()) return failure();
     rewriter.replaceOpWithNewOp<triton::SplitOp>(op, convert.getSrc());
     return mlir::success();
   }
@@ -157,9 +148,8 @@ struct CanonicalizeConvertFromConvert
     : public OpRewritePattern<ConvertLayoutOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(ConvertLayoutOp op,
-                  PatternRewriter &rewriter) const override {
+  mlir::LogicalResult matchAndRewrite(
+      ConvertLayoutOp op, PatternRewriter &rewriter) const override {
     // Convert to the same layout is redundant.
     if (op->getResultTypes() == op->getOperandTypes()) {
       rewriter.replaceOp(op, op->getOperands());
@@ -170,22 +160,21 @@ struct CanonicalizeConvertFromConvert
     // heuristic to accommodate fused attention.
     auto srcType = op.getSrc().getType();
     auto dstType = op.getType();
-    if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
-        mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+    if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+        mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
       return failure();
 
     // for hopper MMAv3
-    if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
-        mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+    if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+        mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
         llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
           return dot->hasTrait<OpTrait::DotLike>();
         })) {
       return failure();
     }
 
     Operation *arg = op.getSrc().getDefiningOp();
-    if (!arg)
-      return failure();
+    if (!arg) return failure();
 
     // cvt(reshape) -> reshape
     if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
@@ -233,8 +222,7 @@ struct CanonicalizeConvertFromConvert
 
     // cvt(cat) -> cat
     if (auto cat = dyn_cast<CatOp>(arg)) {
-      if (isExpensiveCat(cat, op.getType().getEncoding()))
-        return failure();
+      if (isExpensiveCat(cat, op.getType().getEncoding())) return failure();
 
       rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
                                          cat.getOperands());
@@ -291,15 +279,14 @@ LogicalResult UpcastMXFPOp::verify() {
 
   auto xTy = getSrc().getType();
   auto scaleTy = getScale().getType();
-  Builder b(getContext());
-  if (xTy.getElementType() != b.getBF16Type() &&
-      xTy.getElementType() != b.getF16Type() &&
-      xTy.getElementType() != b.getI8Type()) {
-    return emitOpError(
-        "element type of the first operand must be bf16/fp16 or i8");
+
+  if (xTy.getElementType() != BFloat16Type::get(getContext()) &&
+      xTy.getElementType() != Float16Type::get(getContext()) &&
+      xTy.getElementType() != IntegerType::get(getContext(), 8)) {
+    return emitOpError("element type of the first operand must be bf16 or i8");
   }
 
-  if (scaleTy.getElementType() != b.getI8Type()) {
+  if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
     return emitOpError("element type of the second operand must be uint8");
   }
 
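
A side note on the `UpcastMXFPOp::verify` hunk above: it swaps the `Builder` convenience getters for direct type constructors. MLIR types are uniqued per context, so both spellings return the same type objects and the comparisons behave identically. A minimal sketch of that equivalence (the `isAllowedElementType` helper is hypothetical, for illustration only):

```cpp
#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Hypothetical helper: the same element-type check written both ways.
// Builder::getBF16Type()/getI8Type() and the direct ::get() constructors
// return the same context-owned singletons.
static bool isAllowedElementType(mlir::Type elemTy, mlir::MLIRContext *ctx) {
  mlir::Builder b(ctx);
  assert(b.getBF16Type() == mlir::BFloat16Type::get(ctx));
  assert(b.getI8Type() == mlir::IntegerType::get(ctx, 8));
  return elemTy == mlir::BFloat16Type::get(ctx) ||
         elemTy == mlir::Float16Type::get(ctx) ||
         elemTy == mlir::IntegerType::get(ctx, 8);
}
```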
@@ -373,14 +360,12 @@ LogicalResult UpcastMXFPOp::verify() {
   return success();
 }
 
-RankedTensorType
-UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
-                               ScaleDotElemType inputElemType,
-                               Type outputElemType) {
+RankedTensorType UpcastMXFPOp::deduceOutputType(
+    TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType,
+    Type outputElemType) {
   MLIRContext *ctx = inputTensor.getContext();
   auto xTy = inputTensor.getType();
-  if (inputElemType != ScaleDotElemType::E2M1)
-    return xTy;
+  if (inputElemType != ScaleDotElemType::E2M1) return xTy;
 
   auto xShape = xTy.getShape();
   auto newShape = llvm::to_vector(xShape);
@@ -466,17 +451,13 @@ void LocalAllocOp::getEffects(
 }
 
 OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) {
-  if (getType().getMutableMemory())
-    return {};
+  if (getType().getMutableMemory()) return {};
   auto src = getSrc();
-  if (!src)
-    return {};
+  if (!src) return {};
   auto localLoadOp = src.getDefiningOp<LocalLoadOp>();
-  if (!localLoadOp)
-    return {};
+  if (!localLoadOp) return {};
   auto loadSrc = localLoadOp.getSrc();
-  if (loadSrc.getType() != getType())
-    return {};
+  if (loadSrc.getType() != getType()) return {};
   return loadSrc;
 }
 
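
For context on how rewrites such as the `CanonicalizeConvertFrom*` patterns in this diff take effect: MLIR's canonicalizer gathers every registered op's canonicalization patterns and applies them greedily to a fixed point. A minimal sketch of that driver wiring, assuming a `ModuleOp` is already loaded (none of this is part of the diff; it only illustrates the standard MLIR mechanism):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch only: collect canonicalization patterns from all registered ops
// (which is where patterns like CanonicalizeConvertFromReshape get hooked in)
// and run them greedily over the module.
static mlir::LogicalResult runCanonicalizers(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  for (const mlir::RegisteredOperationName &opName :
       ctx->getRegisteredOperations())
    opName.getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```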