@@ -643,8 +643,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
643
643
llvm::SmallVector<int64_t > axes;
644
644
int64_t keepDims;
645
645
int64_t noop_with_empty_axes;
646
- if (binder.tensorOperand (data) ||
647
- binder.tensorResultType (resultType) ||
646
+ if (binder.tensorOperand (data) || binder.tensorResultType (resultType) ||
648
647
binder.s64IntegerArrayAttr (axes, " axes" , 0 ) ||
649
648
binder.s64IntegerAttr (keepDims, " keepdims" , 1 ) ||
650
649
binder.s64IntegerAttr (noop_with_empty_axes, " noop_with_empty_axes" ,
@@ -1092,7 +1091,168 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
1092
1091
rewriter.replaceOp (binder.op , operand);
1093
1092
return success ();
1094
1093
});
1094
+ patterns.onOp (
1095
+ " Slice" , 13 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1096
+ Torch::ValueTensorType resultTorchType;
1097
+ Value operand, starts, ends;
1098
+ // Handle if axes are not provided
1099
+
1100
+ if (binder.tensorOperandAtIndex (operand, 0 ) ||
1101
+ binder.tensorOperandAtIndex (starts, 1 ) ||
1102
+ binder.tensorOperandAtIndex (ends, 2 ) ||
1103
+ binder.tensorResultType (resultTorchType)) {
1104
+ return failure ();
1105
+ }
1106
+
1107
+ auto context = rewriter.getContext ();
1108
+ auto operandTorchTy = operand.getType ().cast <Torch::ValueTensorType>();
1109
+ auto operandTy =
1110
+ operandTorchTy.toBuiltinTensor ().dyn_cast <RankedTensorType>();
1111
+
1112
+ if (!operandTy)
1113
+ return rewriter.notifyMatchFailure (
1114
+ binder.op ,
1115
+ " Expected tensor operator argument to be a ranked tensor type" );
1116
+
1117
+ auto startsTorchTy = starts.getType ().cast <Torch::ValueTensorType>();
1118
+ auto startsTy =
1119
+ startsTorchTy.toBuiltinTensor ().dyn_cast <RankedTensorType>();
1120
+ int startSize = startsTy.getDimSize (0 );
1121
+
1122
+ auto endsTorchTy = ends.getType ().cast <Torch::ValueTensorType>();
1123
+ auto endsTy =
1124
+ endsTorchTy.toBuiltinTensor ().dyn_cast <RankedTensorType>();
1125
+ int endSize = endsTy.getDimSize (0 );
1126
+ auto resultTy =
1127
+ resultTorchType.toBuiltinTensor ().dyn_cast <RankedTensorType>();
1128
+ if (!resultTy)
1129
+ return rewriter.notifyMatchFailure (
1130
+ binder.op , " Expected result type to be a ranked tensor type" );
1131
+
1132
+ Location loc = binder.getLoc ();
1133
+
1134
+ // Binding `axes` from its arguments or through a default value
1135
+ Value axes;
1136
+ if (binder.getNumOperands () >= 4 ) {
1137
+ if (binder.tensorOperandAtIndex (axes, 3 )) {
1138
+ return failure ();
1139
+ }
1140
+ } else {
1141
+ // The default axes value is the range from 0 to the number of
1142
+ // dimensions
1143
+ Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
1144
+ auto defaultAxesType = Torch::ValueTensorType::get (
1145
+ context, ArrayRef<int64_t >{operandTy.getRank ()},
1146
+ rewriter.getIntegerType (64 , /* signed*/ 1 ));
1147
+ Value arangeLength = rewriter.create <Torch::ConstantIntOp>(
1148
+ loc, rewriter.getType <Torch::IntType>(),
1149
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
1150
+ operandTy.getRank ()));
1151
+ axes = rewriter.create <Torch::AtenArangeOp>(
1152
+ loc, defaultAxesType, arangeLength, none, none, none, none);
1153
+ }
1154
+
1155
+ // Binding `steps` from its arguments or through a default value
1156
+ Value steps;
1157
+ if (binder.getNumOperands () >= 5 ) {
1158
+ if (binder.tensorOperandAtIndex (steps, 4 )) {
1159
+ return failure ();
1160
+ }
1161
+ } else {
1162
+ // The default `steps` value is a 1d tensor filled with ones with a
1163
+ // size of the dimension of the operand
1164
+ Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
1165
+ auto defaultStepsType = Torch::ValueTensorType::get (
1166
+ context, ArrayRef<int64_t >{operandTy.getRank ()},
1167
+ rewriter.getIntegerType (64 , /* signed*/ 1 ));
1168
+ Value sizeStepInput = rewriter.create <Torch::ConstantIntOp>(
1169
+ loc, rewriter.getType <Torch::IntType>(),
1170
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
1171
+ operandTy.getRank ()));
1172
+ Value sizeStepsInput = rewriter.create <Torch::PrimListConstructOp>(
1173
+ loc,
1174
+ Torch::ListType::get (
1175
+ Torch::IntType::get (binder.op ->getContext ())),
1176
+ sizeStepInput);
1177
+ steps = rewriter.create <Torch::AtenOnesOp>(
1178
+ loc, defaultStepsType, sizeStepsInput, none, none, none, none);
1179
+ }
1095
1180
1181
+ if (!(endsTy.getRank () == 1 && startsTy.getRank () == 1 &&
1182
+ startSize == endSize))
1183
+ return rewriter.notifyMatchFailure (
1184
+ binder.op , " Expected the rank of starts and ends tensors to be 1 "
1185
+ " and their dimensions to match" );
1186
+
1187
+ auto axesTorchTy = axes.getType ().cast <Torch::ValueTensorType>();
1188
+ auto axesTy =
1189
+ axesTorchTy.toBuiltinTensor ().dyn_cast <RankedTensorType>();
1190
+ int64_t numAxes = axesTy.getDimSize (0 );
1191
+
1192
+ if (!(axesTy && numAxes == endSize))
1193
+ return rewriter.notifyMatchFailure (
1194
+ binder.op , " Axes should be the same size of starts and ends" );
1195
+
1196
+ auto stepsTy = steps.getType ()
1197
+ .cast <Torch::ValueTensorType>()
1198
+ .toBuiltinTensor ()
1199
+ .dyn_cast <RankedTensorType>();
1200
+
1201
+ if (!(stepsTy && stepsTy.getDimSize (0 ) == endsTy.getDimSize (0 )))
1202
+ return rewriter.notifyMatchFailure (
1203
+ binder.op , " Steps should be the same size of starts and ends" );
1204
+
1205
+ Value zero = rewriter.create <Torch::ConstantIntOp>(
1206
+ loc, rewriter.getType <Torch::IntType>(),
1207
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
1208
+
1209
+ auto select = [&](Value v, Value k) -> Value {
1210
+ auto ty = v.getType ().cast <Torch::ValueTensorType>();
1211
+ auto sel = rewriter.create <Torch::AtenIndexSelectOp>(
1212
+ loc,
1213
+ Torch::ValueTensorType::get (ty.getContext (), ArrayRef<int64_t >{1 },
1214
+ ty.getOptionalDtype ()),
1215
+ v, zero, k);
1216
+ Value item = rewriter.create <Torch::AtenItemOp>(
1217
+ loc, rewriter.getType <Torch::IntType>(), sel);
1218
+ return item;
1219
+ };
1220
+
1221
+ llvm::SmallVector<int64_t > intermediateShape (operandTy.getShape ());
1222
+ for (int i = 0 , s = operandTy.getRank (); i < s; ++i) {
1223
+ if (operandTy.getDimSize (i) != resultTy.getDimSize (i)) {
1224
+ intermediateShape[i] = -1 ;
1225
+ }
1226
+ }
1227
+ auto intermediateType = Torch::ValueTensorType::get (
1228
+ context, intermediateShape, resultTorchType.getOptionalDtype ());
1229
+ for (int i = 0 ; i < numAxes; ++i) {
1230
+
1231
+ Value k = rewriter.create <Torch::ConstantIntOp>(
1232
+ loc, rewriter.getType <Torch::IntType>(),
1233
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), i));
1234
+ Value kTensor = rewriter.create <Torch::PrimNumToTensorScalarOp>(
1235
+ loc,
1236
+ Torch::ValueTensorType::get (
1237
+ context, ArrayRef<int64_t >{1 },
1238
+ rewriter.getIntegerType (64 , /* signed*/ 1 )),
1239
+ k);
1240
+
1241
+ Value start = select (starts, kTensor );
1242
+ Value end = select (ends, kTensor );
1243
+ Value axis = select (axes, kTensor );
1244
+ Value step = select (steps, kTensor );
1245
+
1246
+ auto sliceType = intermediateType;
1247
+ if (i == numAxes - 1 )
1248
+ sliceType = resultTorchType;
1249
+ operand = rewriter.create <Torch::AtenSliceTensorOp>(
1250
+ loc, sliceType, operand, axis, start, end, step);
1251
+ }
1252
+
1253
+ rewriter.replaceOp (binder.op , operand);
1254
+ return success ();
1255
+ });
1096
1256
patterns.onOp (
1097
1257
" Reshape" , 5 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1098
1258
Torch::ValueTensorType resultType;
0 commit comments