Skip to content

Commit 4e5e34d

Browse files
authored
[MLIR][ONNX] Add OnnxToTorch support for Slice Op (#2696)
1 parent 3e9bacd commit 4e5e34d

File tree

3 files changed

+294
-3
lines changed

3 files changed

+294
-3
lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ struct OpBinder {
3333

3434
Location getLoc() { return op->getLoc(); }
3535

36+
int getNumOperands() { return op->getNumOperands(); }
37+
3638
// Operand matches of different arities.
3739
ParseResult tensorOperand(Value &value0) {
3840
if (op->getNumOperands() != 1)
@@ -189,7 +191,7 @@ struct OpBinder {
189191
}
190192

191193
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
192-
std::string defaultValue = "") {
194+
std::string defaultValue = "") {
193195
SmallString<64> name("torch.onnx.");
194196
name.append(nameSuffix);
195197
auto attr = op->getAttr(name);

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,8 +643,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
643643
llvm::SmallVector<int64_t> axes;
644644
int64_t keepDims;
645645
int64_t noop_with_empty_axes;
646-
if (binder.tensorOperand(data) ||
647-
binder.tensorResultType(resultType) ||
646+
if (binder.tensorOperand(data) || binder.tensorResultType(resultType) ||
648647
binder.s64IntegerArrayAttr(axes, "axes", 0) ||
649648
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
650649
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
@@ -1092,7 +1091,168 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
10921091
rewriter.replaceOp(binder.op, operand);
10931092
return success();
10941093
});
1094+
patterns.onOp(
1095+
"Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1096+
Torch::ValueTensorType resultTorchType;
1097+
Value operand, starts, ends;
1098+
// Handle if axes are not provided
1099+
1100+
if (binder.tensorOperandAtIndex(operand, 0) ||
1101+
binder.tensorOperandAtIndex(starts, 1) ||
1102+
binder.tensorOperandAtIndex(ends, 2) ||
1103+
binder.tensorResultType(resultTorchType)) {
1104+
return failure();
1105+
}
1106+
1107+
auto context = rewriter.getContext();
1108+
auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
1109+
auto operandTy =
1110+
operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
1111+
1112+
if (!operandTy)
1113+
return rewriter.notifyMatchFailure(
1114+
binder.op,
1115+
"Expected tensor operator argument to be a ranked tensor type");
1116+
1117+
auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
1118+
auto startsTy =
1119+
startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
1120+
int startSize = startsTy.getDimSize(0);
1121+
1122+
auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
1123+
auto endsTy =
1124+
endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
1125+
int endSize = endsTy.getDimSize(0);
1126+
auto resultTy =
1127+
resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
1128+
if (!resultTy)
1129+
return rewriter.notifyMatchFailure(
1130+
binder.op, "Expected result type to be a ranked tensor type");
1131+
1132+
Location loc = binder.getLoc();
1133+
1134+
// Binding `axes` from its arguments or through a default value
1135+
Value axes;
1136+
if (binder.getNumOperands() >= 4) {
1137+
if (binder.tensorOperandAtIndex(axes, 3)) {
1138+
return failure();
1139+
}
1140+
} else {
1141+
// The default axes value is the range from 0 to the number of
1142+
// dimensions
1143+
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
1144+
auto defaultAxesType = Torch::ValueTensorType::get(
1145+
context, ArrayRef<int64_t>{operandTy.getRank()},
1146+
rewriter.getIntegerType(64, /*signed*/ 1));
1147+
Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
1148+
loc, rewriter.getType<Torch::IntType>(),
1149+
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
1150+
operandTy.getRank()));
1151+
axes = rewriter.create<Torch::AtenArangeOp>(
1152+
loc, defaultAxesType, arangeLength, none, none, none, none);
1153+
}
1154+
1155+
// Binding `steps` from its arguments or through a default value
1156+
Value steps;
1157+
if (binder.getNumOperands() >= 5) {
1158+
if (binder.tensorOperandAtIndex(steps, 4)) {
1159+
return failure();
1160+
}
1161+
} else {
1162+
// The default `steps` value is a 1d tensor filled with ones with a
1163+
// size of the dimension of the operand
1164+
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
1165+
auto defaultStepsType = Torch::ValueTensorType::get(
1166+
context, ArrayRef<int64_t>{operandTy.getRank()},
1167+
rewriter.getIntegerType(64, /*signed*/ 1));
1168+
Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
1169+
loc, rewriter.getType<Torch::IntType>(),
1170+
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
1171+
operandTy.getRank()));
1172+
Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
1173+
loc,
1174+
Torch::ListType::get(
1175+
Torch::IntType::get(binder.op->getContext())),
1176+
sizeStepInput);
1177+
steps = rewriter.create<Torch::AtenOnesOp>(
1178+
loc, defaultStepsType, sizeStepsInput, none, none, none, none);
1179+
}
10951180

1181+
if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
1182+
startSize == endSize))
1183+
return rewriter.notifyMatchFailure(
1184+
binder.op, "Expected the rank of starts and ends tensors to be 1 "
1185+
"and their dimensions to match");
1186+
1187+
auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
1188+
auto axesTy =
1189+
axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
1190+
int64_t numAxes = axesTy.getDimSize(0);
1191+
1192+
if (!(axesTy && numAxes == endSize))
1193+
return rewriter.notifyMatchFailure(
1194+
binder.op, "Axes should be the same size of starts and ends");
1195+
1196+
auto stepsTy = steps.getType()
1197+
.cast<Torch::ValueTensorType>()
1198+
.toBuiltinTensor()
1199+
.dyn_cast<RankedTensorType>();
1200+
1201+
if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
1202+
return rewriter.notifyMatchFailure(
1203+
binder.op, "Steps should be the same size of starts and ends");
1204+
1205+
Value zero = rewriter.create<Torch::ConstantIntOp>(
1206+
loc, rewriter.getType<Torch::IntType>(),
1207+
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
1208+
1209+
auto select = [&](Value v, Value k) -> Value {
1210+
auto ty = v.getType().cast<Torch::ValueTensorType>();
1211+
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
1212+
loc,
1213+
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
1214+
ty.getOptionalDtype()),
1215+
v, zero, k);
1216+
Value item = rewriter.create<Torch::AtenItemOp>(
1217+
loc, rewriter.getType<Torch::IntType>(), sel);
1218+
return item;
1219+
};
1220+
1221+
llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
1222+
for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
1223+
if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) {
1224+
intermediateShape[i] = -1;
1225+
}
1226+
}
1227+
auto intermediateType = Torch::ValueTensorType::get(
1228+
context, intermediateShape, resultTorchType.getOptionalDtype());
1229+
for (int i = 0; i < numAxes; ++i) {
1230+
1231+
Value k = rewriter.create<Torch::ConstantIntOp>(
1232+
loc, rewriter.getType<Torch::IntType>(),
1233+
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
1234+
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
1235+
loc,
1236+
Torch::ValueTensorType::get(
1237+
context, ArrayRef<int64_t>{1},
1238+
rewriter.getIntegerType(64, /*signed*/ 1)),
1239+
k);
1240+
1241+
Value start = select(starts, kTensor);
1242+
Value end = select(ends, kTensor);
1243+
Value axis = select(axes, kTensor);
1244+
Value step = select(steps, kTensor);
1245+
1246+
auto sliceType = intermediateType;
1247+
if (i == numAxes - 1)
1248+
sliceType = resultTorchType;
1249+
operand = rewriter.create<Torch::AtenSliceTensorOp>(
1250+
loc, sliceType, operand, axis, start, end, step);
1251+
}
1252+
1253+
rewriter.replaceOp(binder.op, operand);
1254+
return success();
1255+
});
10961256
patterns.onOp(
10971257
"Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
10981258
Torch::ValueTensorType resultType;

0 commit comments

Comments
 (0)