-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui #91491
Conversation
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Corentin Ferry (cferry-AMD) ChangesThese operations can be lowered to EmitC provided the sign-extension and truncation behavior is respected. Per C++ Reference: when casting to a narrower integer, truncation is guaranteed if unsigned casts are performed, or C++20 is used regardless of the sign. This implementation sticks to unsigned for trunci, so C++20 is not necessary. This implementation is a bit more generic than needed by these three operations to accomodate Full diff: https://github.com/llvm/llvm-project/pull/91491.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 1447b182ccfdb..6216e6ea89b9b 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -112,6 +112,78 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
}
};
+template <typename ArithOp, bool needsUnsigned>
+class CastConversion : public OpConversionPattern<ArithOp> {
+public:
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type opReturnType = this->getTypeConverter()->convertType(op.getType());
+ if (!isa_and_nonnull<IntegerType>(opReturnType)) {
+ return rewriter.notifyMatchFailure(op, "expected integer result type");
+ }
+
+ if (adaptor.getOperands().size() != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "CastConversion only supports unary ops");
+ }
+
+ Type operandType = adaptor.getIn().getType();
+ if (!isa_and_nonnull<IntegerType>(operandType)) {
+ return rewriter.notifyMatchFailure(op, "expected integer operand type");
+ }
+
+ bool isTruncation = operandType.getIntOrFloatBitWidth() >
+ opReturnType.getIntOrFloatBitWidth();
+ bool doUnsigned = needsUnsigned || isTruncation;
+
+ Type castType = opReturnType;
+ // For int conversions: if the op is a ui variant and the type wanted as
+ // return type isn't unsigned, we need to issue an unsigned type to do
+ // the conversion.
+ if (castType.isUnsignedInteger() != doUnsigned) {
+ castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
+ /*isSigned=*/!doUnsigned);
+ }
+
+ Value actualOp = adaptor.getIn();
+ // Fix the signedness of the operand if necessary
+ if (operandType.isUnsignedInteger() != doUnsigned) {
+ Type correctSignednessType =
+ rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
+ /*isSigned=*/!doUnsigned);
+ actualOp = rewriter.template create<emitc::CastOp>(
+ op.getLoc(), correctSignednessType, actualOp);
+ }
+
+ auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
+ actualOp);
+
+ // Fix the signedness of what this operation returns (for integers,
+ // the arith ops want signless results)
+ if (castType != opReturnType) {
+ result = rewriter.template create<emitc::CastOp>(op.getLoc(),
+ opReturnType, result);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+template <typename ArithOp>
+class UnsignedCastConversion : public CastConversion<ArithOp, true> {
+ using CastConversion<ArithOp, true>::CastConversion;
+};
+
+template <typename ArithOp>
+class SignedCastConversion : public CastConversion<ArithOp, false> {
+ using CastConversion<ArithOp, false>::CastConversion;
+};
+
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
@@ -313,6 +385,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpIOpConversion,
SelectOpConversion,
+ // Truncation is guaranteed for unsigned types.
+ UnsignedCastConversion<arith::TruncIOp>,
+ SignedCastConversion<arith::ExtSIOp>,
+ UnsignedCastConversion<arith::ExtUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 66dfa8fa3e157..551c3ba7a77ef 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -63,3 +63,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
return %t: i1
}
+// -----
+
+func.func @index_cast(%arg0: i32) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
+ %idx = arith.index_cast %arg0 : i32 to index
+ %int = arith.index_cast %idx : index to i32
+
+ return %int : i32
+}
+
+// -----
+
+func.func @index_castui(%arg0: i32) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
+ %idx = arith.index_castui %arg0 : i32 to index
+ %int = arith.index_castui %idx : index to i32
+
+ return %int : i32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 79fecd61494d0..80665bacd2a5c 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -177,3 +177,42 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
return
}
+
+// -----
+
+func.func @trunci(%arg0: i32) -> i8 {
+ // CHECK-LABEL: trunci
+ // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+ // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+ // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
+ // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
+ %truncd = arith.trunci %arg0 : i32 to i8
+
+ return %truncd : i8
+}
+
+// -----
+
+func.func @extsi(%arg0: i32) {
+ // CHECK-LABEL: extsi
+ // CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
+ // CHECK: emitc.cast [[Arg0]] : i32 to i64
+
+ %extd = arith.extsi %arg0 : i32 to i64
+
+ return
+}
+
+// -----
+
+func.func @extui(%arg0: i32) {
+ // CHECK-LABEL: extui
+ // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+ // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+ // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
+ // CHECK: emitc.cast %[[Conv1]] : ui64 to i64
+
+ %extd = arith.extui %arg0 : i32 to i64
+
+ return
+}
|
Pool of reviewers: @simon-camp @marbre @TinaAMD @mgehre-amd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat trick with the unsigned interpretation for truncation!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think truncation to i1 needs to be handled specially, as the arith dialect discards the high bits, but a conversion to bool is similar to x != 0
. For the same reason signed extension from i1 should be rejected by this pattern. Unsigned extension from i1 works correctly I think.
mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Outdated
Show resolved
Hide resolved
Yes, the i1 case is special indeed -- thanks for the remark! I added its handling. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than the formatting and missing newline this looks good to me.
Out of curiosity, what would be a good way of implementing the sign extension for i1? Doing unsigned extension, oring with 0b1111...00...
and casting back to signed?
I would do just the unsigned extension: I don't think it makes sense to interpret the only bit of an i1 as a sign bit (then what's its value?)... then the other choice would be to see a 1 as Now assuming we still sign-extend: as an alternative to the |
These operations can be lowered to EmitC provided the sign-extension and truncation behavior is respected.
Per C++ Reference: when casting to a narrower integer, truncation is guaranteed if unsigned casts are performed, or C++20 is used regardless of the sign. This implementation sticks to unsigned for trunci, so C++20 is not necessary.
This implementation is a bit more generic than needed by these three operations to accomodate
index_cast
andindex_castui
at a later point (specificemitc.size_t
andemitc.ssize_t
types are being discussed).