9
9
10
10
#include " PassDetail.h"
11
11
12
+ #include " ReifyAbstractInterpCalculationsUtils.h"
12
13
#include " mlir/Transforms/DialectConversion.h"
13
14
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
14
15
#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
15
- #include " ReifyAbstractInterpCalculationsUtils.h"
16
16
#include " llvm/ADT/StringExtras.h"
17
17
18
18
using namespace mlir ;
@@ -72,8 +72,8 @@ namespace {
72
72
// immutable tensors.
73
73
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
74
74
public:
75
- ConvertHasValueSemanticsOpsToValueTensors (MLIRContext *context,
76
- const std::optional<SymbolTable>& extraLibrary)
75
+ ConvertHasValueSemanticsOpsToValueTensors (
76
+ MLIRContext *context, const std::optional<SymbolTable> & extraLibrary)
77
77
: RewritePattern(MatchAnyOpTypeTag(), /* benefit=*/ 1 , context) {
78
78
this ->extraLibrary = extraLibrary;
79
79
}
@@ -87,7 +87,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
87
87
return rewriter.notifyMatchFailure (op, " does not have value semantics" );
88
88
}
89
89
90
- rewriter.startRootUpdate (op);
90
+ rewriter.startOpModification (op);
91
91
// Convert all operands.
92
92
SmallVector<Value> newOperands;
93
93
for (OpOperand &opOperand : op->getOpOperands ()) {
@@ -105,7 +105,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
105
105
auto listConstruct =
106
106
opOperand.get ().getDefiningOp <PrimListConstructOp>();
107
107
if (!listConstruct) {
108
- rewriter.cancelRootUpdate (op);
108
+ rewriter.cancelOpModification (op);
109
109
return rewriter.notifyMatchFailure (
110
110
op, " unimplemented: list of non vtensor type not constructed "
111
111
" from list construct" );
@@ -120,7 +120,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
120
120
if (!llvm::all_of (listConstruct.getElements (), [](Value val) {
121
121
return val.getType ().isa <NonValueTensorType, Torch::NoneType>();
122
122
})) {
123
- rewriter.cancelRootUpdate (op);
123
+ rewriter.cancelOpModification (op);
124
124
return rewriter.notifyMatchFailure (
125
125
op, " unimplemented: list containing optional type is not "
126
126
" handled." );
@@ -138,7 +138,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
138
138
139
139
Type newListType = getContainerOrTensorTypeWithValueSemantics (listType);
140
140
if (!newListType) {
141
- rewriter.cancelRootUpdate (op);
141
+ rewriter.cancelOpModification (op);
142
142
return rewriter.notifyMatchFailure (
143
143
op, " Unable to convert list type to value semantics." );
144
144
}
@@ -154,7 +154,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
154
154
// from the non value tensor of the original optional value.
155
155
auto derefine = opOperand.get ().getDefiningOp <DerefineOp>();
156
156
if (!derefine) {
157
- rewriter.cancelRootUpdate (op);
157
+ rewriter.cancelOpModification (op);
158
158
return rewriter.notifyMatchFailure (
159
159
op, " unimplemented: optional of non vtensor type not from "
160
160
" derefine" );
@@ -180,9 +180,10 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
180
180
rewriter.create <CopyToNonValueTensorOp>(op->getLoc (), result);
181
181
result.replaceAllUsesExcept (nonValueTensor, nonValueTensor);
182
182
}
183
- rewriter.finalizeRootUpdate (op);
183
+ rewriter.finalizeOpModification (op);
184
184
return success ();
185
185
}
186
+
186
187
private:
187
188
std::optional<SymbolTable> extraLibrary;
188
189
};
@@ -290,17 +291,18 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern {
290
291
Operation *newOp = rewriter.create (state);
291
292
// Note: need to convert result to first input's dtype because mix precision
292
293
// compute would result in different behaviors.
293
- // For example:
294
- // a = torch.randn(3, 3).half() # float16
295
- // b = torch.randn(3, 3) # float32
294
+ // For example:
295
+ // a = torch.randn(3, 3).half() # float16
296
+ // b = torch.randn(3, 3) # float32
296
297
// a += b # i.e. torch.ops.aten.add_(a, b), result is float16
297
298
// c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
298
299
Value none = rewriter.create <ConstantNoneOp>(op->getLoc ());
299
300
Value cstFalse = rewriter.create <ConstantBoolOp>(op->getLoc (), false );
300
301
auto aDtype = rewriter.create <PrimDtypeOp>(op->getLoc (), op->getOperand (0 ));
301
302
auto toDtype = rewriter.create <AtenToDtypeOp>(
302
303
op->getLoc (), newOp->getResult (0 ).getType (), newOp->getResult (0 ),
303
- aDtype, /* non_blocking=*/ cstFalse, /* copy=*/ cstFalse, /* memory_format=*/ none);
304
+ aDtype, /* non_blocking=*/ cstFalse, /* copy=*/ cstFalse,
305
+ /* memory_format=*/ none);
304
306
auto tensor = rewriter.create <CopyToValueTensorOp>(op->getLoc (), toDtype);
305
307
createOverwriteTensorContents (rewriter, op->getLoc (), tensor,
306
308
op->getOperand (0 ));
0 commit comments