From 2b01f8b7f3cca87c3dc9c75edd91397803e9f6d4 Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Fri, 25 Oct 2024 18:37:19 -0400
Subject: [PATCH] [Tosa] : Add support for negative indices in index.tensor
 and index.Tensor_hacked_twin for TorchToTosa lowering. (#3790)

1. Negative indices for tensor indexing are handled by wrapping the index
   values around at run time: each index is compared against zero, and
   negative values have the size of the indexed dimension added to them.
   Without this fix, negative indices caused a runtime error.
2. Added a lit test to lock down the behavior.
3. Updated the `xfails_set` for the `fx_importer_tosa` config to lock down
   the behavior with an e2e test as well.
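
For reference, a minimal scalar sketch of the wrapping rule that the new
`wrapNegativeIndices` helper emits as TOSA ops (illustrative only; `wrapIndex`
is not a name from this patch):

    // Scalar model of the emitted tosa.add / tosa.greater / tosa.select
    // chain; idx is the runtime index, dimSize the indexed dimension size.
    int32_t wrapIndex(int32_t idx, int32_t dimSize) {
      int32_t wrapped = dimSize + idx;   // tosa.add
      bool isNegative = 0 > idx;         // tosa.greater
      return isNegative ? wrapped : idx; // tosa.select
    }
    // e.g. wrapIndex(-1, 2) == 1 and wrapIndex(0, 2) == 0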
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 83 +++++++++++++---------
 projects/pt1/e2e_testing/xfail_sets.py     |  4 +-
 test/Conversion/TorchToTosa/basic.mlir     | 32 +++++++++
 3 files changed, 81 insertions(+), 38 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index e5f4fea4f46c..b6dbdc2c7b8c 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+Value wrapNegativeIndices(Value index, int maxIndex, Operation *op,
+                          ConversionPatternRewriter &rewriter) {
+
+  auto zeroValue = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
+  auto maxIndexValue =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
+
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+
+  auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
+      rewriter, op->getLoc(), indexType, maxIndexValue, index);
+  auto boolType = indexType.clone(rewriter.getIntegerType(1));
+  auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+      rewriter, op->getLoc(), boolType, zeroValue, index);
+  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
+                                                indexType, isNegativeIndices,
+                                                wrappedIndicesOp, index);
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
@@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
 
   auto outType = getTypeConverter()->convertType(op.getType());
 
+  Operation *indicesTf;
+
   // Support for multiple indexes
   if (indexTensors.size() > 1) {
     // t[i, i]
@@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
             index);
       }
 
+      index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
+                                  rewriter);
       // Expand last dim of index to tf indices [2,3] -> [2,3,1]
       SmallVector<int64_t> indiceShapeOneDim;
       for (auto shape : indexShape) {
@@ -4299,49 +4322,39 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     auto indicesShapeConcat = indexesShape[0];
     uint64_t lastDim = indexesRank[0];
     indicesShapeConcat.push_back(indicesTfConcatTensors.size());
-    auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
+    indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
         rewriter, op->getLoc(),
         GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
         indicesTfConcatTensors, lastDim);
 
-    if (!indicesTf) {
-      return rewriter.notifyMatchFailure(
-          op, "Convert TorchIndex To TfIndices fail.");
-    }
-    // do the tf gathernp algorithm with tf style indices as input.
-    auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
-                                          indicesTf.getResult());
+  } else {
 
-    if (!result) {
-      return rewriter.notifyMatchFailure(
-          op, "Convert GatherNdOp fail for index tensor.");
+    // Single index
+    auto index = indexTensors[0];
+    auto indexType = dyn_cast<RankedTensorType>(index.getType());
+    auto indexShape = indexType.getShape();
+    // index i64 to i32 for tosa compatible
+    if (indexType.getElementType() != rewriter.getIntegerType(32)) {
+      index = rewriter.create<tosa::CastOp>(
+          op->getLoc(),
+          RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
+          index);
     }
-    rewriter.replaceOp(op, {result.value()});
-    return success();
-  }
 
+    index =
+        wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter);
 
-  // Support for multiple index
-  auto index = indexTensors[0];
-  auto indexType = dyn_cast<RankedTensorType>(index.getType());
-  auto indexShape = indexType.getShape();
-  // index i64 to i32 for tosa compatible
-  if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-    index = rewriter.create<tosa::CastOp>(
-        op->getLoc(),
-        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
-  }
-
-  // Expand last dim of index to tf indices [2,3] -> [2,3,1]
-  SmallVector<int64_t> indicesShape;
-  for (auto shape : indexShape) {
-    indicesShape.push_back(shape);
+    // Expand last dim of index to tf indices [2,3] -> [2,3,1]
+    SmallVector<int64_t> indicesShape;
+    for (auto shape : indexShape) {
+      indicesShape.push_back(shape);
+    }
+    indicesShape.push_back(1);
+    indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, op->getLoc(),
+        RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
+        rewriter.getDenseI64ArrayAttr(indicesShape));
   }
-  indicesShape.push_back(1);
-  auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
-      rewriter, op->getLoc(),
-      RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
-      rewriter.getDenseI64ArrayAttr(indicesShape));
 
   if (!indicesTf) {
     return rewriter.notifyMatchFailure(op,
@@ -4349,7 +4362,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
   }
 
   // do the tf gathernp algorithm with tf style indices as input.
   auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
-                                        indicesTf.getResult());
+                                        indicesTf->getResult(0));
 
   if (!result) {
     return rewriter.notifyMatchFailure(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e370a1d8b73d..82ca24443162 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1698,7 +1698,6 @@
     "ArangeStartOutModule_basic",
     "ScatterSrcStaticModule_basic",
     # Runtime op verification: Out of bounds access
-    "IndexTensorNegativeIndexModule_basic",
     "ReduceAllDimEmpty_basic",
 }
 
@@ -1706,7 +1705,6 @@
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
-    "IndexTensorNegativeIndexModule_basic",
     "InterpolateDynamicModule_scales_recompute_bilinear",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
@@ -2162,6 +2160,7 @@
     "HardswishRandomModule_basic",
     "HardtanhBackward_basic",
     "IndexTensorMultiIndexStaticModule_basic",
+    "IndexTensorNegativeIndexModule_basic",
     "IndexTensorStaticModule_basic",
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
@@ -3635,7 +3634,6 @@
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
     "IndexSelectRank0IdxModule_basic",
-    "IndexTensorNegativeIndexModule_basic",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
     "InterpolateStaticModule_scales_bilinear_align_corners",
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index e412bb390c35..ed6f909c4a1b 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -2131,3 +2131,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t
   %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
   return %0 : !torch.vtensor<[2,3,4,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.index.Tensor_hacked_twin(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>,
+// CHECK-SAME:      %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
+// CHECK:           %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
+// CHECK:           %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor<i64>
+// CHECK:           %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<i64>) -> tensor<i32>
+// CHECK:           %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:           %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
+// CHECK:           %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK:           %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+// CHECK:           %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
+// CHECK:           %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK:           %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK:           %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK:           %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
+// CHECK:           %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
+// CHECK:           %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
+// CHECK:           %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
+// CHECK:           %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
+// CHECK:           return %[[RESULT]] : !torch.vtensor<[4,2],si64>
+
+func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
+  %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
+  %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
+  return %1 : !torch.vtensor<[4,2],si64>
+}
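
Note: for intuition, a rough standalone emulation of the single-index case
exercised by the lit test above, which flattens the [2,4,2] input to [1,2,8],
gathers one 8-element row per (wrapped) index, and views that row as [4,2].
The helper name `indexDim0` and the flat std::array layout are illustrative,
not from the patch:

    #include <array>
    #include <cassert>
    #include <cstdint>

    // Gathers the 8-element row selected by a (possibly negative) index
    // into a [2,4,2] tensor stored flat; the row is the [4,2] result.
    std::array<int64_t, 8> indexDim0(const std::array<int64_t, 16> &flat,
                                     int64_t idx) {
      int64_t wrapped = idx < 0 ? idx + 2 : idx; // wrapNegativeIndices, dim 0 has size 2
      std::array<int64_t, 8> row{};
      for (int i = 0; i < 8; ++i)
        row[i] = flat[wrapped * 8 + i]; // tosa.gather over the [1,2,8] view
      return row;
    }

    int main() {
      std::array<int64_t, 16> t{};
      for (int i = 0; i < 16; ++i)
        t[i] = i;
      assert(indexDim0(t, -1) == indexDim0(t, 1)); // t[-1] is the last slice
      return 0;
    }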