
Commit 670a99a

zezhang (Ze Zhang) authored
Handle torch.none type in tosa.clamp op (#2739)
This PR updates the torch-to-tosa conversion with the following changes:
- Support torch.none as the min/max input argument for the tosa.clamp op
- Support negative values as the start index for the tosa.slice op
- Add tosa.logical_or lowering support

e2e test: python -m e2e_testing.main --config=tosa
LIT tests: cmake --build build --target tools/torch-mlir/all

---------

Co-authored-by: Ze Zhang <[email protected]>
1 parent 47ffc90 commit 670a99a
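
For orientation, here is a minimal PyTorch-level sketch of the three behaviors this conversion now handles. The module and its name are illustrative only and are not part of the e2e suite:

import torch

class ClampSliceLogicalOrSketch(torch.nn.Module):
    # Hypothetical module exercising the newly lowered patterns.
    # x is assumed 3-D (e.g. shape (4, 65, 256)); a and b are bool tensors.
    def forward(self, x, a, b):
        # clamp with a None bound: only max is constrained, so the lowered
        # tosa.clamp keeps the numeric-limits default for min.
        clamped = torch.clamp(x, min=None, max=0)
        # negative start index: -16 on a dimension of size 65 becomes
        # start 49 in the generated tosa.slice.
        sliced = x[:, -16:, :]
        # elementwise logical_or now maps to tosa.logical_or.
        ored = torch.logical_or(a, b)
        return clamped, sliced, ored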

3 files changed (+129, -32 lines)

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 46 additions & 32 deletions
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -3336,9 +3337,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
     return rewriter.notifyMatchFailure(op, "start must be a Scalar constant");
 
-  if (start < 0)
-    return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0");
-
+  if (start < 0) {
+    start = toPositiveDim(start, selfType.getShape()[dim]);
+    if (!isValidDim(start, selfType.getShape()[dim]))
+      return rewriter.notifyMatchFailure(op, "start is not a valid index");
+  }
   start = std::min(selfType.getShape()[dim], start);
 
   int64_t end;
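
The hunk above defers negative starts to toPositiveDim and then clamps to the dimension size. A rough Python sketch of that arithmetic, with a helper name made up for illustration:

def normalize_slice_start(start: int, dim_size: int) -> int:
    # Mirrors the conversion logic: wrap negative starts, reject
    # out-of-range results, then clamp to the dimension size.
    if start < 0:
        start += dim_size            # e.g. -16 on a dim of size 65 -> 49
        if not 0 <= start < dim_size:
            raise ValueError("start is not a valid index")
    return min(start, dim_size)

assert normalize_slice_start(-16, 65) == 49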
@@ -3984,36 +3987,46 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only tensor types input are currently supported");
 
-  IntegerAttr min_int, max_int;
-  FloatAttr min_fp, max_fp;
-  if (op.getMin().getType().isa<Torch::FloatType>()) {
-    double fp_min, fp_max;
-    if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `fp_min` should be a torch constant float");
-
-    if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `fp_max` should be a torch constant float");
-
-    min_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_min));
-    max_int = rewriter.getI64IntegerAttr(static_cast<int64_t>(fp_max));
-    min_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_min));
-    max_fp = rewriter.getF32FloatAttr(static_cast<float>(fp_max));
-  } else {
-    int64_t int_min, int_max;
-    if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `int_min` should be a torch constant int");
-
-    if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: value `int_max` should be a torch constant int");
+  IntegerAttr min_int =
+      rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min());
+  IntegerAttr max_int =
+      rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::max());
+  FloatAttr min_fp =
+      rewriter.getF32FloatAttr(std::numeric_limits<float>::lowest());
+  FloatAttr max_fp =
+      rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+
+  auto getValAttr = [&](Value operand, IntegerAttr &intAttr,
+                        FloatAttr &fpAttr) -> LogicalResult {
+    double valFloat;
+    int64_t valInt;
+    if (matchPattern(operand, m_TorchConstantFloat(&valFloat))) {
+      intAttr = rewriter.getI64IntegerAttr(static_cast<int64_t>(valFloat));
+      fpAttr = rewriter.getF32FloatAttr(static_cast<float>(valFloat));
+    } else if (matchPattern(operand, m_TorchConstantInt(&valInt))) {
+      intAttr = rewriter.getI64IntegerAttr(valInt);
+      fpAttr = rewriter.getF32FloatAttr(static_cast<float>(valInt));
+    } else {
+      return failure();
+    }
+    return success();
+  };
 
-    min_int = rewriter.getI64IntegerAttr(int_min);
-    max_int = rewriter.getI64IntegerAttr(int_max);
-    min_fp = rewriter.getF32FloatAttr(static_cast<float>(int_min));
-    max_fp = rewriter.getF32FloatAttr(static_cast<float>(int_max));
+  LogicalResult minAttrResult = getValAttr(op.getMin(), min_int, min_fp);
+  LogicalResult maxAttrResult = getValAttr(op.getMax(), max_int, max_fp);
+  if (failed(minAttrResult) && failed(maxAttrResult)) {
+    return rewriter.notifyMatchFailure(
+        op, "either `min` or `max` should be a torch constant");
+  }
+  if (failed(minAttrResult) &&
+      succeeded(checkNotNone(rewriter, op, op.getMin()))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "min attr should be a torch constant");
+  }
+  if (failed(maxAttrResult) &&
+      succeeded(checkNotNone(rewriter, op, op.getMax()))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "max attr should be a torch constant");
   }
 
   auto outType = getTypeConverter()->convertType(op.getType());
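
The net effect of this hunk is that a torch.none bound leaves the corresponding tosa.clamp attribute at its type-wide default, while a constant bound overrides both the integer and float attributes. A rough Python model of that selection (function name and structure are illustrative, not the conversion's API):

FLOAT32_MAX = 3.4028234663852886e+38   # std::numeric_limits<float>::max()

def clamp_attrs(min_val=None, max_val=None):
    # Defaults mirror the C++ numeric-limits initialization above.
    min_int, max_int = -2**63, 2**63 - 1
    min_fp, max_fp = -FLOAT32_MAX, FLOAT32_MAX
    if min_val is not None:
        min_int, min_fp = int(min_val), float(min_val)
    if max_val is not None:
        max_int, max_fp = int(max_val), float(max_val)
    return {"min_int": min_int, "max_int": max_int,
            "min_fp": min_fp, "max_fp": max_fp}

# torch.aten.clamp %x, %none, %int0: max pinned to 0, min left wide open,
# matching the @torch.aten.clamp.min_none test further below.
print(clamp_attrs(max_val=0))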
@@ -5025,6 +5038,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
   patterns.add<ConvertAtenBinaryOp<AtenOp, TosaOp>>(typeConverter, context);
     INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
     INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
+    INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
 #undef INSERT_BINARY_PATTERN
 
 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 12 additions & 0 deletions
@@ -1035,6 +1035,15 @@
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseAtenDivIntScalarModule_basic",
     "ElementwiseAtenIsinfOpModule_basic",
+    "ElementwiseAtenLogicalOrOpBrodcastModule_basic",
+    "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
+    "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
+    "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic",
+    "ElementwiseAtenLogicalOrOpModule_basic",
+    "ElementwiseAtenLogicalOrOpNegativeModule_basic",
+    "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic",
+    "ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
+    "ElementwiseAtenLogicalOrOpRandomModule_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseBinaryModule_basic",
     "ElementwiseBinaryStaticShapeModule_basic",
@@ -1047,6 +1056,9 @@
     "ElementwiseBitwiseXorModule_basic",
     "ElementwiseBitwiseXorStaticShapeModule_basic",
     "ElementwiseCeilModule_basic",
+    "ElementwiseClampMaxModule_basic",
+    "ElementwiseClampMinModule_basic",
+    "ElementwiseClampModule_basic",
     "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
     "ElementwiseCloneContiguousModule_basic",
     "ElementwiseCloneModule_basic",

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 71 additions & 0 deletions
@@ -645,6 +645,22 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.logical_or$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[VAL_4:.*]] = tosa.logical_or %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK: }
+func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @forward(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
@@ -1055,6 +1071,61 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) ->
   return %0 : !torch.vtensor<[1,1,128,128],si64>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.slice.negative_start(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 100
+// CHECK: %[[VAL_5:.*]] = torch.constant.int -16
+// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 4, 16, 256>, start = array<i64: 0, 49, 0>} : (tensor<4x65x256xf32>) -> tensor<4x16x256xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32>
+// CHECK: }
+func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int100 = torch.constant.int 100
+  %int-16 = torch.constant.int -16
+  %0 = torch.aten.slice.Tensor %arg0, %int1, %int-16, %int100, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,16,256],f32>
+  return %0 : !torch.vtensor<[4,16,256],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.clamp.min_none(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.none
+// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.000000e+00 : f32, max_int = 0 : i64, min_fp = -3.40282347E+38 : f32, min_int = -9223372036854775808 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64>
+// CHECK: }
+func.func @torch.aten.clamp.min_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %0 = torch.aten.clamp %arg0, %none, %int0 : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.int -> !torch.vtensor<[1,1,128,128],si64>
+  return %0 : !torch.vtensor<[1,1,128,128],si64>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.clamp.max_none(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.none
+// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64>
+// CHECK: }
+func.func @torch.aten.clamp.max_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %0 = torch.aten.clamp %arg0, %int0, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,1,128,128],si64>
+  return %0 : !torch.vtensor<[1,1,128,128],si64>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.clamp(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
