Skip to content

Commit 7630379

Browse files
authored
[mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui
This commit adds conversion to EmitC for arith dialect casts between integer types (trunc, extsi, extui), excluding indexes for now.
1 parent 267de85 commit 7630379

File tree

3 files changed

+162
-0
lines changed

3 files changed

+162
-0
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

+92
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/EmitC/IR/EmitC.h"
18+
#include "mlir/Tools/PDLL/AST/Types.h"
1819
#include "mlir/Transforms/DialectConversion.h"
1920

2021
using namespace mlir;
@@ -112,6 +113,93 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
112113
}
113114
};
114115

116+
template <typename ArithOp, bool castToUnsigned>
117+
class CastConversion : public OpConversionPattern<ArithOp> {
118+
public:
119+
using OpConversionPattern<ArithOp>::OpConversionPattern;
120+
121+
LogicalResult
122+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
123+
ConversionPatternRewriter &rewriter) const override {
124+
125+
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
126+
if (!isa_and_nonnull<IntegerType>(opReturnType))
127+
return rewriter.notifyMatchFailure(op, "expected integer result type");
128+
129+
if (adaptor.getOperands().size() != 1) {
130+
return rewriter.notifyMatchFailure(
131+
op, "CastConversion only supports unary ops");
132+
}
133+
134+
Type operandType = adaptor.getIn().getType();
135+
if (!isa_and_nonnull<IntegerType>(operandType))
136+
return rewriter.notifyMatchFailure(op, "expected integer operand type");
137+
138+
// Signed (sign-extending) casts from i1 are not supported.
139+
if (operandType.isInteger(1) && !castToUnsigned)
140+
return rewriter.notifyMatchFailure(op,
141+
"operation not supported on i1 type");
142+
143+
// to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
144+
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
145+
// truncation.
146+
if (opReturnType.isInteger(1)) {
147+
auto constOne = rewriter.create<emitc::ConstantOp>(
148+
op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
149+
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
150+
op.getLoc(), operandType, adaptor.getIn(), constOne);
151+
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
152+
oneAndOperand);
153+
return success();
154+
}
155+
156+
bool isTruncation = operandType.getIntOrFloatBitWidth() >
157+
opReturnType.getIntOrFloatBitWidth();
158+
bool doUnsigned = castToUnsigned || isTruncation;
159+
160+
Type castType = opReturnType;
161+
// If the op is a ui variant and the type wanted as
162+
// return type isn't unsigned, we need to issue an unsigned type to do
163+
// the conversion.
164+
if (castType.isUnsignedInteger() != doUnsigned) {
165+
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
166+
/*isSigned=*/!doUnsigned);
167+
}
168+
169+
Value actualOp = adaptor.getIn();
170+
// Adapt the signedness of the operand if necessary
171+
if (operandType.isUnsignedInteger() != doUnsigned) {
172+
Type correctSignednessType =
173+
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
174+
/*isSigned=*/!doUnsigned);
175+
actualOp = rewriter.template create<emitc::CastOp>(
176+
op.getLoc(), correctSignednessType, actualOp);
177+
}
178+
179+
auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
180+
actualOp);
181+
182+
// Cast to the expected output type
183+
if (castType != opReturnType) {
184+
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
185+
opReturnType, result);
186+
}
187+
188+
rewriter.replaceOp(op, result);
189+
return success();
190+
}
191+
};
192+
193+
template <typename ArithOp>
194+
class UnsignedCastConversion : public CastConversion<ArithOp, true> {
195+
using CastConversion<ArithOp, true>::CastConversion;
196+
};
197+
198+
template <typename ArithOp>
199+
class SignedCastConversion : public CastConversion<ArithOp, false> {
200+
using CastConversion<ArithOp, false>::CastConversion;
201+
};
202+
115203
template <typename ArithOp, typename EmitCOp>
116204
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
117205
public:
@@ -313,6 +401,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
313401
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
314402
CmpIOpConversion,
315403
SelectOpConversion,
404+
// Truncation is guaranteed for unsigned types.
405+
UnsignedCastConversion<arith::TruncIOp>,
406+
SignedCastConversion<arith::ExtSIOp>,
407+
UnsignedCastConversion<arith::ExtUIOp>,
316408
ItoFCastOpConversion<arith::SIToFPOp>,
317409
ItoFCastOpConversion<arith::UIToFPOp>,
318410
FtoICastOpConversion<arith::FPToSIOp>,

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

+7
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,10 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
6363
return %t: i1
6464
}
6565

66+
// -----
67+
68+
func.func @arith_extsi_i1_to_i32(%arg0: i1) {
69+
// expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
70+
%idx = arith.extsi %arg0 : i1 to i32
71+
return
72+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

+63
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,66 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
177177

178178
return
179179
}
180+
181+
// -----
182+
183+
func.func @arith_trunci(%arg0: i32) -> i8 {
184+
// CHECK-LABEL: arith_trunci
185+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
186+
// CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
187+
// CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
188+
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
189+
%truncd = arith.trunci %arg0 : i32 to i8
190+
191+
return %truncd : i8
192+
}
193+
194+
// -----
195+
196+
func.func @arith_trunci_to_i1(%arg0: i32) -> i1 {
197+
// CHECK-LABEL: arith_trunci_to_i1
198+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
199+
// CHECK: %[[Const:.*]] = "emitc.constant"
200+
// CHECK-SAME: value = 1
201+
// CHECK: %[[And:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
202+
// CHECK: emitc.cast %[[And]] : i32 to i1
203+
%truncd = arith.trunci %arg0 : i32 to i1
204+
205+
return %truncd : i1
206+
}
207+
208+
// -----
209+
210+
func.func @arith_extsi(%arg0: i32) {
211+
// CHECK-LABEL: arith_extsi
212+
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
213+
// CHECK: emitc.cast [[Arg0]] : i32 to i64
214+
%extd = arith.extsi %arg0 : i32 to i64
215+
216+
return
217+
}
218+
219+
// -----
220+
221+
func.func @arith_extui(%arg0: i32) {
222+
// CHECK-LABEL: arith_extui
223+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
224+
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
225+
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
226+
// CHECK: emitc.cast %[[Conv1]] : ui64 to i64
227+
%extd = arith.extui %arg0 : i32 to i64
228+
229+
return
230+
}
231+
232+
// -----
233+
234+
func.func @arith_extui_i1_to_i32(%arg0: i1) {
235+
// CHECK-LABEL: arith_extui_i1_to_i32
236+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i1)
237+
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i1 to ui1
238+
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui1 to ui32
239+
// CHECK: emitc.cast %[[Conv1]] : ui32 to i32
240+
%idx = arith.extui %arg0 : i1 to i32
241+
return
242+
}

0 commit comments

Comments
 (0)