diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0ecded75c5d8b..306e4a4395208 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1942,7 +1942,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure, ); let results = (outs - TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output + TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output ); let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 5a4d6ff464f19..cff3de0a69af9 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -38,29 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed> // Used to express accumulator results or compare results. //===----------------------------------------------------------------------===// -def Tosa_UInt8 : UI<8>; -def Tosa_UInt16 : UI<16>; - def Tosa_Int4 : I<4>; def Tosa_Int8 : I<8>; -def Tosa_Int16 : I<16>; def Tosa_Int32 : I<32>; -def Tosa_Int48 : I<48>; def Tosa_Int64 : I<64>; -def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8, - Tosa_Int16, - Tosa_Int32, - Tosa_Int48, - Tosa_Int64]>; - -def Tosa_Bool : I<1>; - -// No unsigned unquantized int types. -def Tosa_Int : AnyTypeOf<[Tosa_Bool, - Tosa_UInt8, - Tosa_UInt16, - Tosa_SignedInt]>; +// The TOSA dialect allows more types than the TOSA standard to allow for +// experimentation. For historical reasons, signless is used in the place of +// signed. +// The TosaValidation pass can be used to check for standard conformance. +def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger, + AnySignlessInteger]>; def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, Tosa_Int64]>; @@ -172,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint< def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">; def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">; -def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">; -def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">; -def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">; //===----------------------------------------------------------------------===// // Attribute predicates and classes. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 967775281ad91..b669b7362e943 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> { bool CheckVariable(Operation *op); bool CheckVariableReadOrWrite(Operation *op); + bool isValidElementType(Type type); + SmallVector<std::function<LogicalResult(Operation *)>> constCheckers; TosaLevel tosaLevel; DenseMap<StringAttr, mlir::Type> variablesMap; @@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) { return success(); } +bool TosaValidation::isValidElementType(Type type) { + if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) { + return false; + } + if (type.isF64()) { + return false; + } + if (auto intTy = dyn_cast<IntegerType>(type)) { + if (intTy.isUnsigned()) { + switch (intTy.getWidth()) { + case 8: + case 16: + return true; + default: + return false; + } + } else { + // Signless - treated as signed. + switch (intTy.getWidth()) { + case 1: + case 4: + case 8: + case 16: + case 32: + case 48: + case 64: + return true; + default: + return false; + } + } + return false; + } + return true; +} + void TosaValidation::runOnOperation() { configLevelAndProfile(); getOperation().walk([&](Operation *op) { for (Value operand : op->getOperands()) { - if ((profile == TosaProfileEnum::BaseInference) && - isa<FloatType>(getElementTypeOrSelf(operand))) { + auto elementTy = getElementTypeOrSelf(operand); + if (!isValidElementType(elementTy)) { + op->emitOpError() << "failed level check: element type " << elementTy + << " is not legal"; return signalPassFailure(); } - if (getElementTypeOrSelf(operand).isF64()) { + } + for (Type resultTy : op->getResultTypes()) { + auto elementTy = getElementTypeOrSelf(resultTy); + if (!isValidElementType(elementTy)) { + op->emitOpError() << "failed level check: element type " << elementTy + << " is not legal"; return signalPassFailure(); } } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 35ecbcc799e3d..1d3ef28283670 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> { // ----- +func.func @test_const_i2(%arg0 : tensor<1xi2>) { + // expected-error@+1 {{'tosa.const' op failed level check: element type 'i2' is not legal}} + %0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2> + return +} + +// ----- + +func.func @test_const_ui32(%arg0 : tensor<1xui32>) { + // expected-error@+1 {{'tosa.const' op failed level check: element type 'ui32' is not legal}} + %0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32> + return +} + +// ----- + func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}} %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :