Skip to content

Commit e7188da

Browse files
authored
[FXML-4614] Add EmitC index types, lower arith.index_cast, arith.index_castui (#183)
1 parent 2a3ebb6 commit e7188da

File tree

16 files changed

+301
-119
lines changed

16 files changed

+301
-119
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

+4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ bool isIntegerIndexOrOpaqueType(Type type);
4343

4444
/// Determines whether \p type is a valid floating-point type in EmitC.
4545
bool isSupportedFloatType(mlir::Type type);
46+
47+
/// Determines whether \p type is a emitc.size_t/ssize_t type.
48+
bool isAnySizeTType(mlir::Type type);
49+
4650
} // namespace emitc
4751
} // namespace mlir
4852

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

+4-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
5151
def CExpression : NativeOpTrait<"emitc::CExpression">;
5252

5353
// Types only used in binary arithmetic operations.
54-
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
54+
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index,
55+
EmitC_SignedSizeT, EmitC_SizeT, EmitC_OpaqueType]>;
5556
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
5657

5758
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
@@ -287,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
287288
let arguments = (ins EmitCType:$source);
288289
let results = (outs EmitCType:$dest);
289290
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
291+
let hasFolder = 1;
290292
}
291293

292294
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
@@ -470,7 +472,7 @@ def EmitC_ForOp : EmitC_Op<"for",
470472
upper bound and step respectively, and defines an SSA value for its
471473
induction variable. It has one region capturing the loop body. The induction
472474
variable is represented as an argument of this region. This SSA value is a
473-
signless integer or index. The step is a value of same type.
475+
signless integer, or an index. The step is a value of same type.
474476

475477
This operation has no result. The body region must contain exactly one block
476478
that terminates with `emitc.yield`. Calling ForOp::build will create such a

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

+8
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,12 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
130130
let assemblyFormat = "`<` qualified($pointee) `>`";
131131
}
132132

133+
def EmitC_SignedSizeT : EmitC_Type<"SignedSizeT", "ssize_t"> {
134+
let summary = "EmitC signed size type";
135+
}
136+
137+
def EmitC_SizeT : EmitC_Type<"SizeT", "size_t"> {
138+
let summary = "EmitC unsigned size type";
139+
}
140+
133141
#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//===- TypeConversions.h - Convert signless types into C/C++ types -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Transforms/DialectConversion.h"
10+
11+
namespace mlir {
12+
void populateEmitCSizeTypeConversionPatterns(mlir::TypeConverter &converter);
13+
} // namespace mlir

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

+105-69
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/EmitC/IR/EmitC.h"
18+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1819
#include "mlir/IR/BuiltinAttributes.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include "mlir/Support/LogicalResult.h"
@@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
3637
matchAndRewrite(arith::ConstantOp arithConst,
3738
arith::ConstantOp::Adaptor adaptor,
3839
ConversionPatternRewriter &rewriter) const override {
39-
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
40-
arithConst, arithConst.getType(), adaptor.getValue());
40+
Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
41+
if (!newTy)
42+
return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
43+
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
44+
adaptor.getValue());
4145
return success();
4246
}
4347
};
@@ -201,6 +205,35 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
201205
}
202206
};
203207

208+
/// Check if the signedness of type \p ty matches the expected
209+
/// signedness, and issue a type with the correct signedness if
210+
/// necessary.
211+
Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
212+
if (isa<IntegerType>(ty)) {
213+
// Turns signless integers into signed integers.
214+
if (ty.isUnsignedInteger() != needsUnsigned) {
215+
auto signedness = needsUnsigned
216+
? IntegerType::SignednessSemantics::Unsigned
217+
: IntegerType::SignednessSemantics::Signed;
218+
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
219+
signedness);
220+
}
221+
} else if (emitc::isAnySizeTType(ty)) {
222+
if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
223+
if (needsUnsigned)
224+
return emitc::SizeTType::get(ty.getContext());
225+
return emitc::SignedSizeTType::get(ty.getContext());
226+
}
227+
}
228+
return ty;
229+
}
230+
231+
/// Insert a cast operation to type \p ty if \p val
232+
/// does not have this type.
233+
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
234+
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
235+
}
236+
204237
class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
205238
public:
206239
using OpConversionPattern::OpConversionPattern;
@@ -250,31 +283,25 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250283
ConversionPatternRewriter &rewriter) const override {
251284

252285
Type type = adaptor.getLhs().getType();
253-
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
254-
return rewriter.notifyMatchFailure(op, "expected integer or index type");
286+
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
287+
type)) {
288+
return rewriter.notifyMatchFailure(
289+
op, "expected integer or size_t/ssize_t type");
255290
}
256291

257292
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
258293
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
259-
Type arithmeticType = type;
260-
if (type.isUnsignedInteger() != needsUnsigned) {
261-
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
262-
/*isSigned=*/!needsUnsigned);
263-
}
264-
Value lhs = adaptor.getLhs();
265-
Value rhs = adaptor.getRhs();
266-
if (arithmeticType != type) {
267-
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
268-
lhs);
269-
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
270-
rhs);
271-
}
294+
295+
Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
296+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
297+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
298+
272299
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
273300
return success();
274301
}
275302
};
276303

277-
template <typename ArithOp, bool needsUnsigned>
304+
template <typename ArithOp, bool castToUnsigned>
278305
class CastConversion : public OpConversionPattern<ArithOp> {
279306
public:
280307
using OpConversionPattern<ArithOp>::OpConversionPattern;
@@ -284,52 +311,58 @@ class CastConversion : public OpConversionPattern<ArithOp> {
284311
ConversionPatternRewriter &rewriter) const override {
285312

286313
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
287-
if (!isa_and_nonnull<IntegerType>(opReturnType)) {
288-
return rewriter.notifyMatchFailure(op, "expected integer result type");
289-
}
314+
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
315+
opReturnType))
316+
return rewriter.notifyMatchFailure(
317+
op, "expected integer or size_t/ssize_t result type");
290318

291319
if (adaptor.getOperands().size() != 1) {
292320
return rewriter.notifyMatchFailure(
293321
op, "CastConversion only supports unary ops");
294322
}
295323

296324
Type operandType = adaptor.getIn().getType();
297-
if (!isa_and_nonnull<IntegerType>(operandType)) {
298-
return rewriter.notifyMatchFailure(op, "expected integer operand type");
325+
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
326+
operandType))
327+
return rewriter.notifyMatchFailure(
328+
op, "expected integer or size_t/ssize_t operand type");
329+
330+
// to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
331+
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
332+
// truncation.
333+
if (opReturnType.isInteger(1)) {
334+
Type attrType = (emitc::isAnySizeTType(operandType))
335+
? rewriter.getIndexType()
336+
: operandType;
337+
auto constOne = rewriter.create<emitc::ConstantOp>(
338+
op.getLoc(), operandType, rewriter.getIntegerAttr(attrType, 1));
339+
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
340+
op.getLoc(), operandType, adaptor.getIn(), constOne);
341+
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
342+
oneAndOperand);
343+
return success();
299344
}
300345

301-
bool isTruncation = operandType.getIntOrFloatBitWidth() >
302-
opReturnType.getIntOrFloatBitWidth();
303-
bool doUnsigned = needsUnsigned || isTruncation;
304-
305-
Type castType = opReturnType;
306-
// For int conversions: if the op is a ui variant and the type wanted as
307-
// return type isn't unsigned, we need to issue an unsigned type to do
308-
// the conversion.
309-
if (castType.isUnsignedInteger() != doUnsigned) {
310-
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
311-
/*isSigned=*/!doUnsigned);
312-
}
346+
bool isTruncation =
347+
(isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
348+
operandType.getIntOrFloatBitWidth() >
349+
opReturnType.getIntOrFloatBitWidth());
350+
bool doUnsigned = castToUnsigned || isTruncation;
313351

314-
Value actualOp = adaptor.getIn();
315-
// Fix the signedness of the operand if necessary
316-
if (operandType.isUnsignedInteger() != doUnsigned) {
317-
Type correctSignednessType =
318-
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
319-
/*isSigned=*/!doUnsigned);
320-
actualOp = rewriter.template create<emitc::CastOp>(
321-
op.getLoc(), correctSignednessType, actualOp);
322-
}
352+
// Adapt the signedness of the result (bitwidth-preserving cast)
353+
// This is needed e.g., if the return type is signless.
354+
Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
323355

324-
auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
325-
actualOp);
356+
// Adapt the signedness of the operand (bitwidth-preserving cast)
357+
Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
358+
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
326359

327-
// Fix the signedness of what this operation returns (for integers,
328-
// the arith ops want signless results)
329-
if (castType != opReturnType) {
330-
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
331-
opReturnType, result);
332-
}
360+
// Actual cast (may change bitwidth)
361+
auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
362+
castDestType, actualOp);
363+
364+
// Cast to the expected output type
365+
auto result = adaptValueType(cast, rewriter, opReturnType);
333366

334367
rewriter.replaceOp(op, result);
335368
return success();
@@ -355,7 +388,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
355388
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
356389
ConversionPatternRewriter &rewriter) const override {
357390

358-
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
391+
Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
392+
if (!newTy)
393+
return rewriter.notifyMatchFailure(arithOp,
394+
"converting result type failed");
395+
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
359396
adaptor.getOperands());
360397

361398
return success();
@@ -372,17 +409,17 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
372409
ConversionPatternRewriter &rewriter) const override {
373410

374411
Type type = this->getTypeConverter()->convertType(op.getType());
375-
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
376-
return rewriter.notifyMatchFailure(op, "expected integer type");
412+
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
413+
type)) {
414+
return rewriter.notifyMatchFailure(
415+
op, "expected integer or size_t/ssize_t type");
377416
}
378417

379418
if (type.isInteger(1)) {
380419
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
381420
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
382421
}
383422

384-
Value lhs = adaptor.getLhs();
385-
Value rhs = adaptor.getRhs();
386423
Type arithmeticType = type;
387424
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
388425
!bitEnumContainsAll(op.getOverflowFlags(),
@@ -392,20 +429,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
392429
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
393430
/*isSigned=*/false);
394431
}
395-
if (arithmeticType != type) {
396-
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
397-
lhs);
398-
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
399-
rhs);
400-
}
401432

402-
Value result = rewriter.template create<EmitCOp>(op.getLoc(),
403-
arithmeticType, lhs, rhs);
433+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
434+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
435+
436+
Value arithmeticResult = rewriter.template create<EmitCOp>(
437+
op.getLoc(), arithmeticType, lhs, rhs);
438+
439+
Value result = adaptValueType(arithmeticResult, rewriter, type);
404440

405-
if (arithmeticType != type) {
406-
result =
407-
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
408-
}
409441
rewriter.replaceOp(op, result);
410442
return success();
411443
}
@@ -535,6 +567,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
535567
RewritePatternSet &patterns) {
536568
MLIRContext *ctx = patterns.getContext();
537569

570+
mlir::populateEmitCSizeTypeConversionPatterns(typeConverter);
571+
538572
// clang-format off
539573
patterns.add<
540574
ArithConstantOpConversionPattern,
@@ -554,6 +588,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
554588
UnsignedCastConversion<arith::TruncIOp>,
555589
SignedCastConversion<arith::ExtSIOp>,
556590
UnsignedCastConversion<arith::ExtUIOp>,
591+
SignedCastConversion<arith::IndexCastOp>,
592+
UnsignedCastConversion<arith::IndexCastUIOp>,
557593
ItoFCastOpConversion<arith::SIToFPOp>,
558594
ItoFCastOpConversion<arith::UIToFPOp>,
559595
FtoICastOpConversion<arith::FPToSIOp>,

mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
1111
LINK_LIBS PUBLIC
1212
MLIRArithDialect
1313
MLIREmitCDialect
14+
MLIREmitCTransforms
1415
MLIRPass
1516
MLIRTransformUtils
1617
)

0 commit comments

Comments
 (0)