diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 1680d521d919f..eaf952fba474d 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -321,15 +321,20 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (isa(input) || isa(output)) return true; - // Cast to array is only possible from an array - if (isa(input) != isa(output)) - return false; // Arrays can be casted to arrays by reference. if (isa(input) && isa(output)) return true; // Scalars + if (auto arrayType = dyn_cast(input)) { + if (auto pointerType = dyn_cast(output)) { + return (arrayType.getElementType() == pointerType.getPointee()) && + arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1; + } + return false; + } + return ( (emitc::isIntegerIndexOrOpaqueType(input) || emitc::isSupportedFloatType(input) || isa(input)) && @@ -809,9 +814,9 @@ void IfOp::print(OpAsmPrinter &p) { /// Given the region at `index`, or the parent operation if `index` is None, /// return the successor regions. These are the regions that may be selected -/// during the flow of control. `operands` is a set of optional attributes that -/// correspond to a constant value for each operand, or null if that operand is -/// not a constant. +/// during the flow of control. `operands` is a set of optional attributes +/// that correspond to a constant value for each operand, or null if that +/// operand is not a constant. void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. @@ -1161,8 +1166,8 @@ emitc::ArrayType::cloneWith(std::optional> shape, LogicalResult mlir::emitc::LValueType::verify( llvm::function_ref emitError, mlir::Type value) { - // Check that the wrapped type is valid. This especially forbids nested lvalue - // types. + // Check that the wrapped type is valid. This especially forbids nested + // lvalue types. if (!isSupportedEmitCType(value)) return emitError() << "!emitc.lvalue must wrap supported emitc type, but got " << value; diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index b9d46e6dc5280..0b04325002b00 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -133,6 +133,45 @@ func.func @cast_tensor(%arg : tensor) { func.func @cast_array(%arg : !emitc.array<4xf32>) { // expected-error @+1 {{'emitc.cast' op cast of array must bear a reference}} %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> +} + +// ----- + +func.func @cast_to_array(%arg : f32) { + // expected-error @+1 {{'emitc.cast' op operand type 'f32' and result type '!emitc.array<4xf32>' are cast incompatible}} + %1 = emitc.cast %arg: f32 to !emitc.array<4xf32> + return +} + +// ----- + +func.func @cast_multidimensional_array(%arg : !emitc.array<1x2xi32>) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<1x2xi32>' and result type '!emitc.ptr' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.array<1x2xi32> to !emitc.ptr + return +} + +// ----- + +func.func @cast_array_zero_rank(%arg : !emitc.array<0xi32>) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<0xi32>' and result type '!emitc.ptr' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.array<0xi32> to !emitc.ptr + return +} + +// ----- + +func.func @cast_array_to_pointer_types_mismatch(%arg : !emitc.array<3xi32>) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<3xi32>' and result type '!emitc.ptr' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.array<3xi32> to !emitc.ptr + return +} + +// ----- + +func.func @cast_pointer_to_array(%arg : !emitc.ptr) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr' and result type '!emitc.array<3xi32>' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.ptr to !emitc.array<3xi32> return } diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 80a33b2b9621f..aa61ba6dbfa0a 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -48,6 +48,11 @@ func.func @cast_array(%arg : !emitc.array<4xf32>) { return } +func.func @cast_array_to_pointer(%arg0: !emitc.array<3xi32>) { + %1 = emitc.cast %arg0: !emitc.array<3xi32> to !emitc.ptr + return +} + func.func @c() { %1 = "emitc.constant"(){value = 42 : i32} : () -> i32 %2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t