Skip to content

Commit 28c7051

Browse files
1 parent 46a25d7 commit 28c7051

File tree

4 files changed

+21
-21
lines changed

4 files changed

+21
-21
lines changed

externals/llvm-project

lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class AdjustCallingConventionForFunc
8181
}
8282
newResultTypes.push_back(type);
8383
}
84-
rewriter.updateRootInPlace(func, [&] {
84+
rewriter.modifyOpInPlace(func, [&] {
8585
func.setType(FunctionType::get(
8686
getContext(), conversion.getConvertedTypes(), newResultTypes));
8787
// Clear out the type bounds, now that the type incorporates them.
@@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
194194
TypeConverter typeConverter;
195195
typeConverter.addConversion([](Type type) { return type; });
196196
typeConverter.addConversion(
197-
[](Torch::TupleType type,
198-
SmallVectorImpl<Type> &types) -> LogicalResult {
197+
[](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
199198
llvm::append_range(types, type.getContainedTypes());
200199
return success();
201200
});
202201
typeConverter.addConversion(
203-
[](Torch::NoneType type,
204-
SmallVectorImpl<Type> &types) -> LogicalResult {
202+
[](Torch::NoneType type, SmallVectorImpl<Type> &types) -> LogicalResult {
205203
return success();
206204
});
207205

lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
175175

176176
// Replace return type of view-like ops with value-semantics type variant.
177177
for (Operation *viewLikeOp : ops.viewLikeOps) {
178-
rewriter.updateRootInPlace(viewLikeOp, [&] {
178+
rewriter.modifyOpInPlace(viewLikeOp, [&] {
179179
Value result = viewLikeOp->getResult(0);
180180
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
181181
if (resultType)
@@ -337,7 +337,7 @@ class RewriteViewLikeSubgraph
337337
// correctly copy them back to their mlir::func::ReturnOp's expected types.
338338
DenseMap<Value, Type> originalTypes;
339339
for (Operation *op : viewLikeOps) {
340-
rewriter.updateRootInPlace(op, [&]() {
340+
rewriter.modifyOpInPlace(op, [&]() {
341341
if (auto nonValueTensorType =
342342
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
343343
originalTypes[op->getResult(0)] = nonValueTensorType;

lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
#include "PassDetail.h"
1111

12+
#include "ReifyAbstractInterpCalculationsUtils.h"
1213
#include "mlir/Transforms/DialectConversion.h"
1314
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1415
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
15-
#include "ReifyAbstractInterpCalculationsUtils.h"
1616
#include "llvm/ADT/StringExtras.h"
1717

1818
using namespace mlir;
@@ -72,8 +72,8 @@ namespace {
7272
// immutable tensors.
7373
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
7474
public:
75-
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
76-
const std::optional<SymbolTable>& extraLibrary)
75+
ConvertHasValueSemanticsOpsToValueTensors(
76+
MLIRContext *context, const std::optional<SymbolTable> &extraLibrary)
7777
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
7878
this->extraLibrary = extraLibrary;
7979
}
@@ -87,7 +87,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
8787
return rewriter.notifyMatchFailure(op, "does not have value semantics");
8888
}
8989

90-
rewriter.startRootUpdate(op);
90+
rewriter.startOpModification(op);
9191
// Convert all operands.
9292
SmallVector<Value> newOperands;
9393
for (OpOperand &opOperand : op->getOpOperands()) {
@@ -105,7 +105,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
105105
auto listConstruct =
106106
opOperand.get().getDefiningOp<PrimListConstructOp>();
107107
if (!listConstruct) {
108-
rewriter.cancelRootUpdate(op);
108+
rewriter.cancelOpModification(op);
109109
return rewriter.notifyMatchFailure(
110110
op, "unimplemented: list of non vtensor type not constructed "
111111
"from list construct");
@@ -120,7 +120,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
120120
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
121121
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
122122
})) {
123-
rewriter.cancelRootUpdate(op);
123+
rewriter.cancelOpModification(op);
124124
return rewriter.notifyMatchFailure(
125125
op, "unimplemented: list containing optional type is not "
126126
"handled.");
@@ -138,7 +138,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
138138

139139
Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
140140
if (!newListType) {
141-
rewriter.cancelRootUpdate(op);
141+
rewriter.cancelOpModification(op);
142142
return rewriter.notifyMatchFailure(
143143
op, "Unable to convert list type to value semantics.");
144144
}
@@ -154,7 +154,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
154154
// from the non value tensor of the original optional value.
155155
auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
156156
if (!derefine) {
157-
rewriter.cancelRootUpdate(op);
157+
rewriter.cancelOpModification(op);
158158
return rewriter.notifyMatchFailure(
159159
op, "unimplemented: optional of non vtensor type not from "
160160
"derefine");
@@ -180,9 +180,10 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
180180
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
181181
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
182182
}
183-
rewriter.finalizeRootUpdate(op);
183+
rewriter.finalizeOpModification(op);
184184
return success();
185185
}
186+
186187
private:
187188
std::optional<SymbolTable> extraLibrary;
188189
};
@@ -290,17 +291,18 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern {
290291
Operation *newOp = rewriter.create(state);
291292
// Note: need to convert result to first input's dtype because mix precision
292293
// compute would result in different behaviors.
293-
// For example:
294-
// a = torch.randn(3, 3).half() # float16
295-
// b = torch.randn(3, 3) # float32
294+
// For example:
295+
// a = torch.randn(3, 3).half() # float16
296+
// b = torch.randn(3, 3) # float32
296297
// a += b # i.e. torch.ops.aten.add_(a, b), result is float16
297298
// c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
298299
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
299300
Value cstFalse = rewriter.create<ConstantBoolOp>(op->getLoc(), false);
300301
auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
301302
auto toDtype = rewriter.create<AtenToDtypeOp>(
302303
op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
303-
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
304+
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
305+
/*memory_format=*/none);
304306
auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
305307
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
306308
op->getOperand(0));

0 commit comments

Comments
 (0)