15
15
16
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
17
#include " mlir/Dialect/EmitC/IR/EmitC.h"
18
+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
18
19
#include " mlir/IR/BuiltinAttributes.h"
19
20
#include " mlir/IR/BuiltinTypes.h"
20
21
#include " mlir/Support/LogicalResult.h"
@@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
36
37
matchAndRewrite (arith::ConstantOp arithConst,
37
38
arith::ConstantOp::Adaptor adaptor,
38
39
ConversionPatternRewriter &rewriter) const override {
39
- rewriter.replaceOpWithNewOp <emitc::ConstantOp>(
40
- arithConst, arithConst.getType (), adaptor.getValue ());
40
+ Type newTy = this ->getTypeConverter ()->convertType (arithConst.getType ());
41
+ if (!newTy)
42
+ return rewriter.notifyMatchFailure (arithConst, " type conversion failed" );
43
+ rewriter.replaceOpWithNewOp <emitc::ConstantOp>(arithConst, newTy,
44
+ adaptor.getValue ());
41
45
return success ();
42
46
}
43
47
};
@@ -201,6 +205,35 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
201
205
}
202
206
};
203
207
208
+ // / Check if the signedness of type \p ty matches the expected
209
+ // / signedness, and issue a type with the correct signedness if
210
+ // / necessary.
211
+ Type adaptIntegralTypeSignedness (Type ty, bool needsUnsigned) {
212
+ if (isa<IntegerType>(ty)) {
213
+ // Turns signless integers into signed integers.
214
+ if (ty.isUnsignedInteger () != needsUnsigned) {
215
+ auto signedness = needsUnsigned
216
+ ? IntegerType::SignednessSemantics::Unsigned
217
+ : IntegerType::SignednessSemantics::Signed;
218
+ return IntegerType::get (ty.getContext (), ty.getIntOrFloatBitWidth (),
219
+ signedness);
220
+ }
221
+ } else if (emitc::isAnySizeTType (ty)) {
222
+ if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
223
+ if (needsUnsigned)
224
+ return emitc::SizeTType::get (ty.getContext ());
225
+ return emitc::SignedSizeTType::get (ty.getContext ());
226
+ }
227
+ }
228
+ return ty;
229
+ }
230
+
231
+ // / Insert a cast operation to type \p ty if \p val
232
+ // / does not have this type.
233
+ Value adaptValueType (Value val, ConversionPatternRewriter &rewriter, Type ty) {
234
+ return rewriter.createOrFold <emitc::CastOp>(val.getLoc (), ty, val);
235
+ }
236
+
204
237
class CmpIOpConversion : public OpConversionPattern <arith::CmpIOp> {
205
238
public:
206
239
using OpConversionPattern::OpConversionPattern;
@@ -250,31 +283,25 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250
283
ConversionPatternRewriter &rewriter) const override {
251
284
252
285
Type type = adaptor.getLhs ().getType ();
253
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
254
- return rewriter.notifyMatchFailure (op, " expected integer or index type" );
286
+ if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
287
+ type)) {
288
+ return rewriter.notifyMatchFailure (
289
+ op, " expected integer or size_t/ssize_t type" );
255
290
}
256
291
257
292
bool needsUnsigned = needsUnsignedCmp (op.getPredicate ());
258
293
emitc::CmpPredicate pred = toEmitCPred (op.getPredicate ());
259
- Type arithmeticType = type;
260
- if (type.isUnsignedInteger () != needsUnsigned) {
261
- arithmeticType = rewriter.getIntegerType (type.getIntOrFloatBitWidth (),
262
- /* isSigned=*/ !needsUnsigned);
263
- }
264
- Value lhs = adaptor.getLhs ();
265
- Value rhs = adaptor.getRhs ();
266
- if (arithmeticType != type) {
267
- lhs = rewriter.template create <emitc::CastOp>(op.getLoc (), arithmeticType,
268
- lhs);
269
- rhs = rewriter.template create <emitc::CastOp>(op.getLoc (), arithmeticType,
270
- rhs);
271
- }
294
+
295
+ Type arithmeticType = adaptIntegralTypeSignedness (type, needsUnsigned);
296
+ Value lhs = adaptValueType (adaptor.getLhs (), rewriter, arithmeticType);
297
+ Value rhs = adaptValueType (adaptor.getRhs (), rewriter, arithmeticType);
298
+
272
299
rewriter.replaceOpWithNewOp <emitc::CmpOp>(op, op.getType (), pred, lhs, rhs);
273
300
return success ();
274
301
}
275
302
};
276
303
277
- template <typename ArithOp, bool needsUnsigned >
304
+ template <typename ArithOp, bool castToUnsigned >
278
305
class CastConversion : public OpConversionPattern <ArithOp> {
279
306
public:
280
307
using OpConversionPattern<ArithOp>::OpConversionPattern;
@@ -284,52 +311,58 @@ class CastConversion : public OpConversionPattern<ArithOp> {
284
311
ConversionPatternRewriter &rewriter) const override {
285
312
286
313
Type opReturnType = this ->getTypeConverter ()->convertType (op.getType ());
287
- if (!isa_and_nonnull<IntegerType>(opReturnType)) {
288
- return rewriter.notifyMatchFailure (op, " expected integer result type" );
289
- }
314
+ if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
315
+ opReturnType))
316
+ return rewriter.notifyMatchFailure (
317
+ op, " expected integer or size_t/ssize_t result type" );
290
318
291
319
if (adaptor.getOperands ().size () != 1 ) {
292
320
return rewriter.notifyMatchFailure (
293
321
op, " CastConversion only supports unary ops" );
294
322
}
295
323
296
324
Type operandType = adaptor.getIn ().getType ();
297
- if (!isa_and_nonnull<IntegerType>(operandType)) {
298
- return rewriter.notifyMatchFailure (op, " expected integer operand type" );
325
+ if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
326
+ operandType))
327
+ return rewriter.notifyMatchFailure (
328
+ op, " expected integer or size_t/ssize_t operand type" );
329
+
330
+ // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
331
+ // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
332
+ // truncation.
333
+ if (opReturnType.isInteger (1 )) {
334
+ Type attrType = (emitc::isAnySizeTType (operandType))
335
+ ? rewriter.getIndexType ()
336
+ : operandType;
337
+ auto constOne = rewriter.create <emitc::ConstantOp>(
338
+ op.getLoc (), operandType, rewriter.getIntegerAttr (attrType, 1 ));
339
+ auto oneAndOperand = rewriter.create <emitc::BitwiseAndOp>(
340
+ op.getLoc (), operandType, adaptor.getIn (), constOne);
341
+ rewriter.replaceOpWithNewOp <emitc::CastOp>(op, opReturnType,
342
+ oneAndOperand);
343
+ return success ();
299
344
}
300
345
301
- bool isTruncation = operandType.getIntOrFloatBitWidth () >
302
- opReturnType.getIntOrFloatBitWidth ();
303
- bool doUnsigned = needsUnsigned || isTruncation;
304
-
305
- Type castType = opReturnType;
306
- // For int conversions: if the op is a ui variant and the type wanted as
307
- // return type isn't unsigned, we need to issue an unsigned type to do
308
- // the conversion.
309
- if (castType.isUnsignedInteger () != doUnsigned) {
310
- castType = rewriter.getIntegerType (opReturnType.getIntOrFloatBitWidth (),
311
- /* isSigned=*/ !doUnsigned);
312
- }
346
+ bool isTruncation =
347
+ (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
348
+ operandType.getIntOrFloatBitWidth () >
349
+ opReturnType.getIntOrFloatBitWidth ());
350
+ bool doUnsigned = castToUnsigned || isTruncation;
313
351
314
- Value actualOp = adaptor.getIn ();
315
- // Fix the signedness of the operand if necessary
316
- if (operandType.isUnsignedInteger () != doUnsigned) {
317
- Type correctSignednessType =
318
- rewriter.getIntegerType (operandType.getIntOrFloatBitWidth (),
319
- /* isSigned=*/ !doUnsigned);
320
- actualOp = rewriter.template create <emitc::CastOp>(
321
- op.getLoc (), correctSignednessType, actualOp);
322
- }
352
+ // Adapt the signedness of the result (bitwidth-preserving cast)
353
+ // This is needed e.g., if the return type is signless.
354
+ Type castDestType = adaptIntegralTypeSignedness (opReturnType, doUnsigned);
323
355
324
- auto result = rewriter.template create <emitc::CastOp>(op.getLoc (), castType,
325
- actualOp);
356
+ // Adapt the signedness of the operand (bitwidth-preserving cast)
357
+ Type castSrcType = adaptIntegralTypeSignedness (operandType, doUnsigned);
358
+ Value actualOp = adaptValueType (adaptor.getIn (), rewriter, castSrcType);
326
359
327
- // Fix the signedness of what this operation returns (for integers,
328
- // the arith ops want signless results)
329
- if (castType != opReturnType) {
330
- result = rewriter. template create <emitc::CastOp>(op. getLoc (),
331
- opReturnType, result);
332
- }
360
+ // Actual cast (may change bitwidth)
361
+ auto cast = rewriter. template create <emitc::CastOp>(op. getLoc (),
362
+ castDestType, actualOp);
363
+
364
+ // Cast to the expected output type
365
+ auto result = adaptValueType (cast, rewriter, opReturnType);
333
366
334
367
rewriter.replaceOp (op, result);
335
368
return success ();
@@ -355,7 +388,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
355
388
matchAndRewrite (ArithOp arithOp, typename ArithOp::Adaptor adaptor,
356
389
ConversionPatternRewriter &rewriter) const override {
357
390
358
- rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, arithOp.getType (),
391
+ Type newTy = this ->getTypeConverter ()->convertType (arithOp.getType ());
392
+ if (!newTy)
393
+ return rewriter.notifyMatchFailure (arithOp,
394
+ " converting result type failed" );
395
+ rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, newTy,
359
396
adaptor.getOperands ());
360
397
361
398
return success ();
@@ -372,17 +409,17 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
372
409
ConversionPatternRewriter &rewriter) const override {
373
410
374
411
Type type = this ->getTypeConverter ()->convertType (op.getType ());
375
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
376
- return rewriter.notifyMatchFailure (op, " expected integer type" );
412
+ if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
413
+ type)) {
414
+ return rewriter.notifyMatchFailure (
415
+ op, " expected integer or size_t/ssize_t type" );
377
416
}
378
417
379
418
if (type.isInteger (1 )) {
380
419
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
381
420
return rewriter.notifyMatchFailure (op, " i1 type is not implemented" );
382
421
}
383
422
384
- Value lhs = adaptor.getLhs ();
385
- Value rhs = adaptor.getRhs ();
386
423
Type arithmeticType = type;
387
424
if ((type.isSignlessInteger () || type.isSignedInteger ()) &&
388
425
!bitEnumContainsAll (op.getOverflowFlags (),
@@ -392,20 +429,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
392
429
arithmeticType = rewriter.getIntegerType (type.getIntOrFloatBitWidth (),
393
430
/* isSigned=*/ false );
394
431
}
395
- if (arithmeticType != type) {
396
- lhs = rewriter.template create <emitc::CastOp>(op.getLoc (), arithmeticType,
397
- lhs);
398
- rhs = rewriter.template create <emitc::CastOp>(op.getLoc (), arithmeticType,
399
- rhs);
400
- }
401
432
402
- Value result = rewriter.template create <EmitCOp>(op.getLoc (),
403
- arithmeticType, lhs, rhs);
433
+ Value lhs = adaptValueType (adaptor.getLhs (), rewriter, arithmeticType);
434
+ Value rhs = adaptValueType (adaptor.getRhs (), rewriter, arithmeticType);
435
+
436
+ Value arithmeticResult = rewriter.template create <EmitCOp>(
437
+ op.getLoc (), arithmeticType, lhs, rhs);
438
+
439
+ Value result = adaptValueType (arithmeticResult, rewriter, type);
404
440
405
- if (arithmeticType != type) {
406
- result =
407
- rewriter.template create <emitc::CastOp>(op.getLoc (), type, result);
408
- }
409
441
rewriter.replaceOp (op, result);
410
442
return success ();
411
443
}
@@ -535,6 +567,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
535
567
RewritePatternSet &patterns) {
536
568
MLIRContext *ctx = patterns.getContext ();
537
569
570
+ mlir::populateEmitCSizeTypeConversionPatterns (typeConverter);
571
+
538
572
// clang-format off
539
573
patterns.add <
540
574
ArithConstantOpConversionPattern,
@@ -554,6 +588,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
554
588
UnsignedCastConversion<arith::TruncIOp>,
555
589
SignedCastConversion<arith::ExtSIOp>,
556
590
UnsignedCastConversion<arith::ExtUIOp>,
591
+ SignedCastConversion<arith::IndexCastOp>,
592
+ UnsignedCastConversion<arith::IndexCastUIOp>,
557
593
ItoFCastOpConversion<arith::SIToFPOp>,
558
594
ItoFCastOpConversion<arith::UIToFPOp>,
559
595
FtoICastOpConversion<arith::FPToSIOp>,
0 commit comments