[mlir][toy] Update dialect conversion example #150826
@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes: The Toy tutorial used outdated API. Update the example to:

* Use the `OpAdaptor` in all places.
* Do not mix `RewritePattern` and `ConversionPattern`. This cannot always be done safely and should not be advertised in the example code.
Patch is 41.73 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/150826.diff

7 Files Affected:
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md
index c750c07ddfc04..39ec02d04c8ce 100644
--- a/mlir/docs/Tutorials/Toy/Ch-5.md
+++ b/mlir/docs/Tutorials/Toy/Ch-5.md
@@ -91,53 +91,37 @@ doesn't matter. See `ConversionTarget::getOpInfo` for the details.
After the conversion target has been defined, we can define how to convert the
*illegal* operations into *legal* ones. Similarly to the canonicalization
framework introduced in [chapter 3](Ch-3.md), the
-[`DialectConversion` framework](../../DialectConversion.md) also uses
-[RewritePatterns](../QuickstartRewrites.md) to perform the conversion logic.
-These patterns may be the `RewritePatterns` seen before or a new type of pattern
-specific to the conversion framework `ConversionPattern`. `ConversionPatterns`
-are different from traditional `RewritePatterns` in that they accept an
-additional `operands` parameter containing operands that have been
-remapped/replaced. This is used when dealing with type conversions, as the
-pattern will want to operate on values of the new type but match against the
-old. For our lowering, this invariant will be useful as it translates from the
-[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being
-operated on to the [MemRefType](../../Dialects/Builtin.md/#memreftype). Let's
-look at a snippet of lowering the `toy.transpose` operation:
+[`DialectConversion` framework](../../DialectConversion.md) also uses a special
+kind of `ConversionPattern` to perform the conversion logic.
+`ConversionPatterns` are different from traditional `RewritePatterns` in that
+they accept an additional `operands` (or `adaptor`) parameter containing
+operands that have been remapped/replaced. This is used when dealing with type
+conversions, as the pattern will want to operate on values of the new type but
+match against the old. For our lowering, this invariant will be useful as it
+translates from the [TensorType](../../Dialects/Builtin.md/#rankedtensortype)
+currently being operated on to the
+[MemRefType](../../Dialects/Builtin.md/#memreftype). Let's look at a snippet of
+lowering the `toy.transpose` operation:
```c++
/// Lower the `toy.transpose` operation to an affine loop nest.
-struct TransposeOpLowering : public mlir::ConversionPattern {
- TransposeOpLowering(mlir::MLIRContext *ctx)
- : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
-
- /// Match and rewrite the given `toy.transpose` operation, with the given
- /// operands that have been remapped from `tensor<...>` to `memref<...>`.
- llvm::LogicalResult
- matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
- mlir::ConversionPatternRewriter &rewriter) const final {
- auto loc = op->getLoc();
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
- // Call to a helper function that will lower the current operation to a set
- // of affine loops. We provide a functor that operates on the remapped
- // operands, as well as the loop induction variables for the inner most
- // loop body.
- lowerOpToLoops(
- op, operands, rewriter,
- [loc](mlir::PatternRewriter &rewriter,
- ArrayRef<mlir::Value> memRefOperands,
- ArrayRef<mlir::Value> loopIvs) {
- // Generate an adaptor for the remapped operands of the TransposeOp.
- // This allows for using the nice named accessors that are generated
- // by the ODS. This adaptor is automatically provided by the ODS
- // framework.
- TransposeOpAdaptor transposeAdaptor(memRefOperands);
- mlir::Value input = transposeAdaptor.input();
-
- // Transpose the elements by generating a load from the reverse
- // indices.
- SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return mlir::AffineLoadOp::create(rewriter, loc, input, reverseIvs);
- });
+ LogicalResult
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ auto loc = op->getLoc();
+ lowerOpToLoops(op, rewriter,
+ [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input,
+ reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 529de55304206..1ef9351422a59 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -81,6 +81,14 @@ enough for our use case.
LLVMTypeConverter typeConverter(&getContext());
```
+For the `toy.print` lowering, we need a special type converter to ensure that
+the pattern receives a `memref` value in its adaptor. If we were to use the
+LLVM type converter, it would receive an `llvm.struct`, which is the normal
+lowering of a `memref` type to LLVM. If we were to use no type converter at
+all, it would receive a value with the original tensor type. (Note: The dialect
+conversion driver currently passes the "most recently mapped value", i.e., a
+value of unspecified type. This is a bug in the conversion driver.)
+
### Conversion Patterns
Now that the conversion target has been defined, we need to provide the patterns
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..2969d3a795779 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
- // and the loop induction variables. This function will return the value
- // to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ // Call the processing function with the rewriter and the loop
+ // induction variables. This function will return the value to store at
+ // the current index.
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- // that are generated by the ODS.
- typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
- // Generate loads for the element of 'lhs' and 'rhs' at the
- // inner loop.
- auto loadedLhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getLhs(), loopIvs);
- auto loadedRhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
- // Create the binary operation performed on the loaded
- // values.
- return LoweredBinaryOp::create(builder, loc, loadedLhs,
- loadedRhs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ // Generate loads for the element of 'lhs' and 'rhs' at the
+ // inner loop.
+ auto loadedLhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+ auto loadedRhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+ // Create the binary operation performed on the loaded
+ // values.
+ return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+ });
return success();
}
};
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
//===----------------------------------------------------------------------===//
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
- using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+ using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
//===----------------------------------------------------------------------===//
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
//===----------------------------------------------------------------------===//
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
- using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+ using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
//===----------------------------------------------------------------------===//
-struct TransposeOpLowering : public ConversionPattern {
- TransposeOpLowering(MLIRContext *ctx)
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // TransposeOp. This allows for using the nice named
- // accessors that are generated by the ODS.
- toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
- Value input = transposeAdaptor.getInput();
-
- // Transpose the elements by generating a load from the
- // reverse indices.
- SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return affine::AffineLoadOp::create(builder, loc, input,
- reverseIvs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..2969d3a795779 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
- // and the loop induction variables. This function will return the value
- // to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ // Call the processing function with the rewriter and the loop
+ // induction variables. This function will return the value to store at
+ // the current index.
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- ...
[truncated]
Force-pushed from 20bc42e to a039b2a (Compare)
```c++
// PrintOp. An identity converter is needed because the PrintOp lowering
// operates on MemRefType instead of the lowered LLVM struct type.
TypeConverter identityConverter;
identityConverter.addConversion([](Type type) { return type; });
patterns.add<PrintOpLowering>(identityConverter, &getContext());
```
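To illustrate the effect of the identity converter (the values and shapes below are hypothetical, not taken from the patch), the `PrintOp` pattern's adaptor then hands over the `memref` that replaced the original tensor operand, so the lowering can emit plain loads:

```mlir
// Hypothetical input IR: toy.print consumes a tensor value.
toy.print %0 : tensor<2x3xf64>

// Inside the lowered pattern, the adaptor provides the memref that
// replaced the tensor, so per-element printing can use memref.load:
%elem = memref.load %alloc[%i, %j] : memref<2x3xf64>
```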
Was the previous code using `printOp.getInput()` not okay? My understanding was that it was fine to use the old operands of `op` as long as you are aware of it being of the old type. This is what is happening here with the use of `memref.load`. This is very useful IMO to drive recursive legalization.
Good point. I always thought only taking the type of the original operand would be safe. But using the operand to construct new IR seems to work fine in the conversion driver. (And it will also work fine in the new One-Shot driver.)
How is it safe? Is the driver automatically remapping it when constructing a new operation? Or am I misunderstanding what you mean by "using the operand to construct new IR"?
> using the operand to construct new IR

I meant: inside of a `ConversionPattern`, retrieve an operand from the matched op, and use that SSA value to construct a new op.

Inside of a `ConversionPattern`, the main difference between picking (a) a value from the adaptor or (b) the respective operand from the matched op is the type of the SSA value that you get. (b) always has the original type. (a) may be a type-converted "version" of (b).

I was worried that (b) is not safe because the defining op of that value may have been scheduled for erasure, leaving us with an operand pointing to an erased op. But I think that cannot happen. When an op is replaced (or erased), its results are mapped to something else in the `ConversionValueMapping`. During the "commit" phase (that is, after the op from the example above has been created), those mappings are materialized in IR via `replaceAllUsesWith`. At that point, the operand of the newly-constructed op in the example above will be switched to an SSA value that survives the conversion and has the correct type. (A materialization may be created at that point of time.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
The Toy tutorial used outdated API. Update the example to:

* Use the `OpAdaptor` in all places.
* Do not mix `RewritePattern` and `ConversionPattern`. This cannot always be done safely and should not be advertised in the example code.