[AutoBump] Merge with fixes of 2b01f8b (Oct 26)
mgehre-amd committed Jan 28, 2025
2 parents 8e6a9e0 + 2b01f8b commit 93265ca
Showing 3 changed files with 82 additions and 38 deletions.
lib/Conversion/TorchToTosa/TorchToTosa.cpp (48 additions, 35 deletions)
@@ -4639,6 +4639,25 @@ class ConvertAtenIndexTensorOpNone
   }
 };
 
+Value wrapNegativeIndices(Value index, int maxIndex, Operation *op,
+                          ConversionPatternRewriter &rewriter) {
+
+  auto zeroValue = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
+  auto maxIndexValue =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
+
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+
+  auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
+      rewriter, op->getLoc(), indexType, maxIndexValue, index);
+  auto boolType = indexType.clone(rewriter.getIntegerType(1));
+  auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+      rewriter, op->getLoc(), boolType, zeroValue, index);
+  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
+                                                indexType, isNegativeIndices,
+                                                wrappedIndicesOp, index);
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
@@ -4677,6 +4696,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
 
   auto outType = getTypeConverter()->convertType(op.getType());
 
+  Operation *indicesTf;
+
   // Support for multiple indexes
   if (indexTensors.size() > 1) {
     // t[i, i]
@@ -4710,6 +4731,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
             index);
       }
 
+      index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
+                                  rewriter);
       // Expand last dim of index to tf indices [2,3] -> [2,3,1]
       SmallVector<int64_t> indiceShapeOneDim;
       for (auto shape : indexShape) {
@@ -4852,57 +4875,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     auto indicesShapeConcat = indexesShape[0];
     uint64_t lastDim = indexesRank[0];
     indicesShapeConcat.push_back(indicesTfConcatTensors.size());
-    auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
+    indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
         rewriter, op->getLoc(),
         GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
         indicesTfConcatTensors, lastDim);
 
-    if (!indicesTf) {
-      return rewriter.notifyMatchFailure(
-          op, "Convert TorchIndex To TfIndices fail.");
-    }
-    // do the tf gathernp algorithm with tf style indices as input.
-    auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
-                                          indicesTf.getResult());
+  } else {
 
-    if (!result) {
-      return rewriter.notifyMatchFailure(
-          op, "Convert GatherNdOp fail for index tensor.");
+    // Single index
+    auto index = indexTensors[0];
+    auto indexType = dyn_cast<RankedTensorType>(index.getType());
+    auto indexShape = indexType.getShape();
+    // index i64 to i32 for tosa compatible
+    if (indexType.getElementType() != rewriter.getIntegerType(32)) {
+      index = rewriter.create<tosa::CastOp>(
+          op->getLoc(),
+          RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
+          index);
     }
-    rewriter.replaceOp(op, {result.value()});
 
-    return success();
-  }
+    index =
+        wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter);
 
-  // Support for multiple index
-  auto index = indexTensors[0];
-  auto indexType = dyn_cast<RankedTensorType>(index.getType());
-  auto indexShape = indexType.getShape();
-  // index i64 to i32 for tosa compatible
-  if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-    index = rewriter.create<tosa::CastOp>(
-        op->getLoc(),
-        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
-  }
-
-  // Expand last dim of index to tf indices [2,3] -> [2,3,1]
-  SmallVector<int64_t> indicesShape;
-  for (auto shape : indexShape) {
-    indicesShape.push_back(shape);
+    // Expand last dim of index to tf indices [2,3] -> [2,3,1]
+    SmallVector<int64_t> indicesShape;
+    for (auto shape : indexShape) {
+      indicesShape.push_back(shape);
+    }
+    indicesShape.push_back(1);
+    indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, op->getLoc(),
+        RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
+        rewriter.getDenseI64ArrayAttr(indicesShape));
   }
-  indicesShape.push_back(1);
-  auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
-      rewriter, op->getLoc(),
-      RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
-      rewriter.getDenseI64ArrayAttr(indicesShape));
 
   if (!indicesTf) {
     return rewriter.notifyMatchFailure(op,
                                        "Convert TorchIndex To TfIndices fail.");
   }
   // do the tf gathernp algorithm with tf style indices as input.
   auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
-                                        indicesTf.getResult());
+                                        indicesTf->getResult(0));
 
   if (!result) {
     return rewriter.notifyMatchFailure(
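Note on the change: wrapNegativeIndices materializes PyTorch-style negative indexing as three TOSA ops — tosa.add computes index + dimSize, tosa.greater tests 0 > index, and tosa.select keeps the wrapped value only for negative entries. A minimal scalar sketch of the same semantics (hypothetical standalone C++ for illustration, not part of this commit):

#include <cassert>

// Scalar model of the three TOSA ops emitted by wrapNegativeIndices:
// add (index + dimSize), greater (0 > index), select.
int wrapNegativeIndex(int index, int dimSize) {
  int wrapped = index + dimSize;       // tosa.add
  bool isNegative = 0 > index;         // tosa.greater
  return isNegative ? wrapped : index; // tosa.select
}

int main() {
  assert(wrapNegativeIndex(-1, 4) == 3); // t[-1] on a size-4 dim -> t[3]
  assert(wrapNegativeIndex(-4, 4) == 0); // most-negative valid index
  assert(wrapNegativeIndex(2, 4) == 2);  // non-negative passes through
  return 0;
}

As in the helper itself, indices are wrapped but not bounds-checked.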
projects/pt1/e2e_testing/xfail_sets.py (2 additions, 3 deletions)
@@ -1747,15 +1747,13 @@
     "ArangeStartOutModule_basic",
     "ScatterSrcStaticModule_basic",
     # Runtime op verification: Out of bounds access
-    "IndexTensorNegativeIndexModule_basic",
     "ReduceAllDimEmpty_basic",
 }
 
 FX_IMPORTER_TOSA_CRASHING_SET = {
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
-    "IndexTensorNegativeIndexModule_basic",
     "InterpolateDynamicModule_scales_recompute_bilinear",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
@@ -2217,6 +2215,7 @@
     "HardswishRandomModule_basic",
     "HardtanhBackward_basic",
     "IndexTensorMultiIndexStaticModule_basic",
+    "IndexTensorNegativeIndexModule_basic",
     "IndexTensorStaticModule_basic",
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
@@ -3670,7 +3669,7 @@
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
-    "IndexTensorNegativeIndexModule_basic",
+    "IndexSelectRank0IdxModule_basic",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
     "InterpolateStaticModule_scales_bilinear_align_corners",
test/Conversion/TorchToTosa/basic.mlir (32 additions, 0 deletions)
@@ -2373,3 +2373,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> {
   %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
   return %0 : !torch.vtensor<[2,3,4,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
+// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
+// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor<i64>
+// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<i64>) -> tensor<i32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
+// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
+// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
+// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64>
+
+func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
+  %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
+  %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
+  return %1 : !torch.vtensor<[4,2],si64>
+}
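For reference, the CHECK lines above trace the gather-nd lowering that convertGatherNdOp produces: the 2x4x2 input is viewed as [1, 2, 8], the (wrapped) index is multiplied by per-dimension strides and summed into a flat row offset, tosa.gather picks the row, and a final reshape yields the [4, 2] result. A rough scalar model of that flattening (hypothetical C++ for illustration, not part of this commit):

#include <cstdio>
#include <vector>

int main() {
  // Input of shape [2,4,2] filled with 0..15, viewed as [1, 2, 8]:
  // one batch, two gatherable rows of 8 contiguous elements each.
  std::vector<long> input(2 * 4 * 2);
  for (size_t i = 0; i < input.size(); ++i)
    input[i] = static_cast<long>(i);

  long index = -1;           // index into dim 0 (size 2)
  if (index < 0) index += 2; // what wrapNegativeIndices emits via add/greater/select

  // Flat offset = sum(index_d * stride_d); with a single indexed dim the
  // stride is 1 (the dense<1> constant in the CHECK lines above).
  long flat = index * 1;     // tosa.mul + tosa.reduce_sum

  // tosa.gather on the [1,2,8] view, then reshape to [4,2]:
  for (int j = 0; j < 8; ++j)
    printf("%ld ", input[flat * 8 + j]); // prints 8 9 10 11 12 13 14 15
  printf("\n");
  return 0;
}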
