Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] [TOSA] Allow any floating point type #175

Merged
merged 2 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1873,11 +1873,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
}];

let arguments = (ins
Tosa_Tensor_Plus_F64:$input
Tosa_Tensor:$input
);

let results = (outs
Tosa_Tensor_Plus_F64:$output
Tosa_Tensor:$output
);

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
Expand Down Expand Up @@ -1960,7 +1960,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);

let results = (outs
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
);

let hasFolder = 1;
Expand Down
49 changes: 10 additions & 39 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
// Used to express accumulator results or compare results.
//===----------------------------------------------------------------------===//

def Tosa_UInt8 : UI<8>;
def Tosa_UInt16 : UI<16>;

def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
def Tosa_Int48 : I<48>;
def Tosa_Int64 : I<64>;

def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
Tosa_Int16,
Tosa_Int32,
Tosa_Int48,
Tosa_Int64]>;

def Tosa_Bool : I<1>;

def Tosa_Int : AnyTypeOf<[Tosa_Bool,
AnyUnsignedInteger,
AnySignlessInteger,
// TODO: For backwards compatibility, keep Tosa_SignedInt, which is actually
// a set of signless types.
Tosa_SignedInt]>;
// The TOSA dialect allows more types than the TOSA standard to allow for
// experimentation. For historical reasons, signless is used in the place of
// signed.
// The TosaValidation pass can be used to check for standard conformance.
def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
AnySignlessInteger]>;

def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
Tosa_Int64]>;
Expand All @@ -84,28 +71,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;

//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
def Tosa_Float : AnyTypeOf<[
F32,
F16,
BF16]>;

//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;

// Add F64 type support just for tosa::CastOp and tosa::ConstOp
def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
"number_plus_f64">;

// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, Tosa_Float]>;
Tosa_QuantizedInt, AnyFloat]>;

//===----------------------------------------------------------------------===//
// Tensor types
Expand All @@ -114,18 +89,17 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
def Tosa_IntTensor : TensorOf<[Tosa_Int]>;
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;

def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
def Tosa_FloatTensor : TensorOf<[AnyFloat]>;

// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;

// Must be ranked but no further constraints
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;

// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
Tosa_Float.predicate]>, "tosa.dtype">;
AnyFloat.predicate]>, "tosa.dtype">;

class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
Expand Down Expand Up @@ -173,9 +147,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<

def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;

//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
Expand Down
50 changes: 47 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);

bool isValidElementType(Type type);

SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
Expand Down Expand Up @@ -503,15 +505,57 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
return success();
}

bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
return type.isF32() || type.isF16() || type.isBF16();
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
switch (intTy.getWidth()) {
case 8:
case 16:
return true;
default:
return false;
}
} else {
// Signless - treated as signed.
switch (intTy.getWidth()) {
case 1:
case 4:
case 8:
case 16:
case 32:
case 48:
case 64:
return true;
default:
return false;
}
}
return false;
}
return true;
}

void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
if ((profile == TosaProfileEnum::BaseInference) &&
isa<FloatType>(getElementTypeOrSelf(operand))) {
auto elementTy = getElementTypeOrSelf(operand);
if (!isValidElementType(elementTy)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
}
if (getElementTypeOrSelf(operand).isF64()) {
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
if (!isValidElementType(elementTy)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
}
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
// -----

func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,30 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {

// -----

func.func @test_const_i2(%arg0 : tensor<1xi2>) {
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'i2' is not legal}}
%0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
return
}

// -----

func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui32' is not legal}}
%0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
return
}

// -----

func.func @test_const_f64(%arg0 : tensor<1xf64>) {
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'f64' is not legal}}
%0 = "tosa.const"() {value = dense<0.0> : tensor<1xf64>} : () -> tensor<1xf64>
return
}

// -----

func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
Expand Down