diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 0337314ce7f34..5fe4fc8695017 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -274,6 +274,78 @@ class CmpIOpConversion : public OpConversionPattern { } }; +template +class CastConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type opReturnType = this->getTypeConverter()->convertType(op.getType()); + if (!isa_and_nonnull(opReturnType)) { + return rewriter.notifyMatchFailure(op, "expected integer result type"); + } + + if (adaptor.getOperands().size() != 1) { + return rewriter.notifyMatchFailure( + op, "CastConversion only supports unary ops"); + } + + Type operandType = adaptor.getIn().getType(); + if (!isa_and_nonnull(operandType)) { + return rewriter.notifyMatchFailure(op, "expected integer operand type"); + } + + bool isTruncation = operandType.getIntOrFloatBitWidth() > + opReturnType.getIntOrFloatBitWidth(); + bool doUnsigned = needsUnsigned || isTruncation; + + Type castType = opReturnType; + // For int conversions: if the op is a ui variant and the type wanted as + // return type isn't unsigned, we need to issue an unsigned type to do + // the conversion. + if (castType.isUnsignedInteger() != doUnsigned) { + castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(), + /*isSigned=*/!doUnsigned); + } + + Value actualOp = adaptor.getIn(); + // Fix the signedness of the operand if necessary + if (operandType.isUnsignedInteger() != doUnsigned) { + Type correctSignednessType = + rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), + /*isSigned=*/!doUnsigned); + actualOp = rewriter.template create( + op.getLoc(), correctSignednessType, actualOp); + } + + auto result = rewriter.template create(op.getLoc(), castType, + actualOp); + + // Fix the signedness of what this operation returns (for integers, + // the arith ops want signless results) + if (castType != opReturnType) { + result = rewriter.template create(op.getLoc(), + opReturnType, result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +template +class UnsignedCastConversion : public CastConversion { + using CastConversion::CastConversion; +}; + +template +class SignedCastConversion : public CastConversion { + using CastConversion::CastConversion; +}; + template class ArithOpConversion final : public OpConversionPattern { public: @@ -478,6 +550,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, CmpFOpConversion, CmpIOpConversion, SelectOpConversion, + // Truncation is guaranteed for unsigned types. + UnsignedCastConversion, + SignedCastConversion, + UnsignedCastConversion, ItoFCastOpConversion, ItoFCastOpConversion, FtoICastOpConversion, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 32c0c0381d326..5fcb2b3a553e5 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -79,3 +79,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 { return %t: i1 } +// ----- + +func.func @index_cast(%arg0: i32) -> i32 { + // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}} + %idx = arith.index_cast %arg0 : i32 to index + %int = arith.index_cast %idx : index to i32 + + return %int : i32 +} + +// ----- + +func.func @index_castui(%arg0: i32) -> i32 { + // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}} + %idx = arith.index_castui %arg0 : i32 to index + %int = arith.index_castui %idx : index to i32 + + return %int : i32 +} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index ed63d40808973..bda1180282142 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -390,3 +390,42 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) { return } + +// ----- + +func.func @trunci(%arg0: i32) -> i8 { + // CHECK-LABEL: trunci + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 + // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8 + // CHECK: emitc.cast %[[Trunc]] : ui8 to i8 + %truncd = arith.trunci %arg0 : i32 to i8 + + return %truncd : i8 +} + +// ----- + +func.func @extsi(%arg0: i32) { + // CHECK-LABEL: extsi + // CHECK-SAME: ([[Arg0:[^ ]*]]: i32) + // CHECK: emitc.cast [[Arg0]] : i32 to i64 + + %extd = arith.extsi %arg0 : i32 to i64 + + return +} + +// ----- + +func.func @extui(%arg0: i32) { + // CHECK-LABEL: extui + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 + // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64 + // CHECK: emitc.cast %[[Conv1]] : ui64 to i64 + + %extd = arith.extui %arg0 : i32 to i64 + + return +}