From 51ea14167c8f789cf6d52ca0b698e27178a370c0 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 16 Apr 2024 10:44:05 +0100 Subject: [PATCH 1/8] Support more Arith integer binary ops --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 45 +++++++++++++++++ .../ArithToEmitC/arith-to-emitc.mlir | 48 +++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 62067c5e25644..31a937f7c4c93 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -443,6 +443,46 @@ class IntegerOpConversion final : public OpConversionPattern { } }; +template +class BitwiseOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type type = this->getTypeConverter()->convertType(op.getType()); + if (!isa_and_nonnull( + type)) { + return rewriter.notifyMatchFailure( + op, "expected integer or size_t/ssize_t type"); + } + + if (type.isInteger(1)) { + if (!booleansLegal) + return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); + + rewriter.replaceOpWithNewOp(op, type, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + Type arithmeticType = adaptIntegralTypeSignedness(type, true); + + Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); + Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); + + Value arithmeticResult = rewriter.template create( + op.getLoc(), arithmeticType, lhs, rhs); + + Value result = adaptValueType(arithmeticResult, rewriter, type); + + rewriter.replaceOp(op, result); + return success(); + } +}; + class SelectOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -581,6 +621,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, + BitwiseOpConversion, + BitwiseOpConversion, + BitwiseOpConversion, + BitwiseOpConversion, + BitwiseOpConversion, CmpFOpConversion, CmpIOpConversion, SelectOpConversion, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index d3f2dcafd180a..f84164217983f 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -92,6 +92,54 @@ func.func @arith_index(%arg0: i32, %arg1: i32) { return } +// ----- + +// CHECK-LABEL: arith_bitwise +func.func @arith_bitwise(%arg0: i32, %arg1: i32) { + // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 + // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[AND]] : ui32 to i32 + %5 = arith.andi %arg0, %arg1 : i32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 + // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[OR]] : ui32 to i32 + %6 = arith.ori %arg0, %arg1 : i32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 + // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32 + %7 = arith.xori %arg0, %arg1 : i32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 + // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHL]] : ui32 to i32 + %8 = arith.shli %arg0, %arg1 : i32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 + // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHR]] : ui32 to i32 + %9 = arith.shrsi %arg0, %arg1 : i32 + + return +} + +// ----- + +// CHECK-LABEL: arith_bitwise_bool +func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) { + // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1 + %5 = arith.andi %arg0, %arg1 : i1 + // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1 + %6 = arith.ori %arg0, %arg1 : i1 + // CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1 + %7 = arith.xori %arg0, %arg1 : i1 + + return +} + + // ----- func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () { From 581c1ab745a207c9879ca42a8f35143f5c657e5c Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 21 May 2024 07:42:18 +0100 Subject: [PATCH 2/8] Support shrsi --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 32 ++++++++++++++----- .../ArithToEmitC/arith-to-emitc.mlir | 25 ++++++++------- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 31a937f7c4c93..b30eeba767065 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -443,8 +443,9 @@ class IntegerOpConversion final : public OpConversionPattern { } }; -template -class BitwiseOpConversion final : public OpConversionPattern { +template +class BitwiseOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -468,7 +469,7 @@ class BitwiseOpConversion final : public OpConversionPattern { return success(); } - Type arithmeticType = adaptIntegralTypeSignedness(type, true); + Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp); Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); @@ -483,6 +484,20 @@ class BitwiseOpConversion final : public OpConversionPattern { } }; +template +class SignedBitwiseOpConversion final + : public BitwiseOpConversion { + using BitwiseOpConversion::BitwiseOpConversion; +}; + +template +class UnsignedBitwiseOpConversion final + : public BitwiseOpConversion { + using BitwiseOpConversion::BitwiseOpConversion; +}; + class SelectOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -621,11 +636,12 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, - BitwiseOpConversion, - BitwiseOpConversion, - BitwiseOpConversion, - BitwiseOpConversion, - BitwiseOpConversion, + UnsignedBitwiseOpConversion, + UnsignedBitwiseOpConversion, + UnsignedBitwiseOpConversion, + UnsignedBitwiseOpConversion, + SignedBitwiseOpConversion, + UnsignedBitwiseOpConversion, CmpFOpConversion, CmpIOpConversion, SelectOpConversion, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index f84164217983f..d885aa9215527 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -95,32 +95,35 @@ func.func @arith_index(%arg0: i32, %arg1: i32) { // ----- // CHECK-LABEL: arith_bitwise +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 func.func @arith_bitwise(%arg0: i32, %arg1: i32) { - // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 - // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[AND]] : ui32 to i32 %5 = arith.andi %arg0, %arg1 : i32 - // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 - // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[OR]] : ui32 to i32 %6 = arith.ori %arg0, %arg1 : i32 - // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 - // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 // CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32 %7 = arith.xori %arg0, %arg1 : i32 - // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 - // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHL]] : ui32 to i32 %8 = arith.shli %arg0, %arg1 : i32 - // CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32 - // CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32 + // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHR]] : ui32 to i32 - %9 = arith.shrsi %arg0, %arg1 : i32 + %9 = arith.shrui %arg0, %arg1 : i32 + // CHECK: %[[SHRS:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 + %10 = arith.shrsi %arg0, %arg1 : i32 return } From adc0ab1192b78dd22567ee6f1ddaf992e738eeba Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 21 May 2024 07:49:15 +0100 Subject: [PATCH 3/8] Also add checks for illegal ops --- .../arith-to-emitc-unsupported.mlir | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 6d52c73af37e4..38abad1b22985 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -78,3 +78,27 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 { %t = arith.fptoui %arg0 : f32 to i1 return %t: i1 } + +// ----- + +func.func @arith_shli_i1(%arg0: i1, %arg1: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.shli'}} + %shli = arith.shli %arg0, %arg1 : i1 + return +} + +// ----- + +func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.shrsi'}} + %shrsi = arith.shrsi %arg0, %arg1 : i1 + return +} + +// ----- + +func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.shrui'}} + %shrui = arith.shrui %arg0, %arg1 : i1 + return +} From 18cc6c01092b0c5d0a8e82dcedd92a63ed39bb15 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 27 May 2024 09:55:18 +0100 Subject: [PATCH 4/8] SHL/SHR check to put poison instead of UB --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 128 ++++++++++++++--- .../ArithToEmitC/arith-to-emitc.mlir | 129 ++++++++++++++++-- 2 files changed, 225 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index b30eeba767065..a0416d4420fcb 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Region.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" @@ -443,8 +444,7 @@ class IntegerOpConversion final : public OpConversionPattern { } }; -template +template class BitwiseOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -460,16 +460,15 @@ class BitwiseOpConversion : public OpConversionPattern { op, "expected integer or size_t/ssize_t type"); } + // There is no unsigned i1 type, bitwise ops can be performed directly + // on booleans. if (type.isInteger(1)) { - if (!booleansLegal) - return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); - rewriter.replaceOpWithNewOp(op, type, adaptor.getLhs(), adaptor.getRhs()); return success(); } - Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp); + Type arithmeticType = adaptIntegralTypeSignedness(type, true); Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); @@ -484,18 +483,103 @@ class BitwiseOpConversion : public OpConversionPattern { } }; -template -class SignedBitwiseOpConversion final - : public BitwiseOpConversion { - using BitwiseOpConversion::BitwiseOpConversion; +template +class ShiftOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type type = this->getTypeConverter()->convertType(op.getType()); + if (!isa_and_nonnull( + type)) { + return rewriter.notifyMatchFailure( + op, "expected integer or size_t/ssize_t type"); + } + + if (type.isInteger(1)) { + return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); + } + + Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp); + + Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); + Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); + + // Add a runtime check for overflow + // This is an abuse of the size_t type since we're potentially using values + // below -1. + Value width; + Type sizeTType = (isa(type)) ? arithmeticType + : (isUnsignedOp) + ? (Type)(emitc::SizeTType::get(op.getContext())) + : (Type)(emitc::SignedSizeTType::get(op.getContext())); + if (isa(type)) { + Value eight = rewriter.create( + op.getLoc(), sizeTType, rewriter.getIndexAttr(8)); + emitc::CallOpaqueOp sizeOfCall = rewriter.create( + op.getLoc(), sizeTType, "sizeof", SmallVector({eight})); + width = rewriter.create(op.getLoc(), sizeTType, eight, + sizeOfCall.getResult(0)); + } else { + width = rewriter.create( + op.getLoc(), sizeTType, + rewriter.getIntegerAttr(sizeTType, type.getIntOrFloatBitWidth())); + } + + Value oobTest; + Value excessCheck = rewriter.create( + op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); + oobTest = excessCheck; + if (!isUnsignedOp) { + Value zero = rewriter.create( + op.getLoc(), sizeTType, + (isa(sizeTType) ? rewriter.getIntegerAttr(sizeTType, 0) + : rewriter.getIndexAttr(0))); + Value defaultCheck = + rewriter.create(op.getLoc(), rewriter.getI1Type(), + emitc::CmpPredicate::ge, rhs, zero); + oobTest = rewriter.create( + op.getLoc(), rewriter.getI1Type(), excessCheck, defaultCheck); + } + + Value poison = rewriter.create( + op.getLoc(), arithmeticType, + (isa(arithmeticType) + ? rewriter.getIntegerAttr(arithmeticType, 0) + : rewriter.getIndexAttr(0))); + + emitc::ExpressionOp ternary = rewriter.create( + op.getLoc(), arithmeticType, /*do_not_inline=*/false); + Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); + auto currentPoint = rewriter.getInsertionPoint(); + rewriter.setInsertionPointToStart(&bodyBlock); + Value arithmeticResult = + rewriter.create(op.getLoc(), arithmeticType, lhs, rhs); + Value resultOrPoison = rewriter.create( + op.getLoc(), arithmeticType, oobTest, arithmeticResult, poison); + rewriter.create(op.getLoc(), resultOrPoison); + rewriter.setInsertionPoint(op->getBlock(), currentPoint); + + Value result = adaptValueType(ternary, rewriter, type); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +template +class SignedShiftOpConversion final + : public ShiftOpConversion { + using ShiftOpConversion::ShiftOpConversion; }; -template -class UnsignedBitwiseOpConversion final - : public BitwiseOpConversion { - using BitwiseOpConversion::BitwiseOpConversion; +template +class UnsignedShiftOpConversion final + : public ShiftOpConversion { + using ShiftOpConversion::ShiftOpConversion; }; class SelectOpConversion : public OpConversionPattern { @@ -636,12 +720,12 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, - UnsignedBitwiseOpConversion, - UnsignedBitwiseOpConversion, - UnsignedBitwiseOpConversion, - UnsignedBitwiseOpConversion, - SignedBitwiseOpConversion, - UnsignedBitwiseOpConversion, + BitwiseOpConversion, + BitwiseOpConversion, + BitwiseOpConversion, + UnsignedShiftOpConversion, + SignedShiftOpConversion, + UnsignedShiftOpConversion, CmpFOpConversion, CmpIOpConversion, SelectOpConversion, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index d885aa9215527..cf95d375ba1e5 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -112,18 +112,127 @@ func.func @arith_bitwise(%arg0: i32, %arg1: i32) { // CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32 %7 = arith.xori %arg0, %arg1 : i32 - // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 - // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + + return +} + +// ----- + +// CHECK-LABEL: arith_shift_left +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func.func @arith_shift_left(%arg0: i32, %arg1: i32) { + // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32 + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0 + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 { // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 - // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHL]] : ui32 to i32 - %8 = arith.shli %arg0, %arg1 : i32 - // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 - // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32 + // CHECK: emitc.yield %[[Ternary]] : ui32 + // CHECK: } + // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32 + %1 = arith.shli %arg0, %arg1 : i32 + return +} + +// ----- + +// CHECK-LABEL: arith_shift_right +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func.func @arith_shift_right(%arg0: i32, %arg1: i32) { + // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32 + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32 + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 { // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 - // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHR]] : ui32 to i32 - %9 = arith.shrui %arg0, %arg1 : i32 - // CHECK: %[[SHRS:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 - %10 = arith.shrsi %arg0, %arg1 : i32 + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32 + // CHECK: emitc.yield %[[Ternary]] : ui32 + // CHECK: } + // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32 + %2 = arith.shrui %arg0, %arg1 : i32 + + // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}i32 + // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[ARG1]], %[[SSizeConstant]] : (i32, i32) -> i1 + // CHECK-DAG: %[[SZeroConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32 + // CHECK-DAG: %[[SCmpPositive:[^ ]*]] = emitc.cmp ge, %[[ARG1]], %[[SZeroConstant]] : (i32, i32) -> i1 + // CHECK-DAG: %[[SCmpAnd:[^ ]*]] = emitc.logical_and %[[SCmpNoExcess]], %[[SCmpPositive]] : i1, i1 + // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32 + // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32 { + // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 + // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpAnd]], %[[SHRSI]], %[[SZero]] : i32 + // CHECK: emitc.yield %[[STernary]] : i32 + // CHECK: } + %3 = arith.shrsi %arg0, %arg1 : i32 + + return +} + +// ----- + +// CHECK-LABEL: arith_shift_left_index +// CHECK-SAME: %[[AMOUNT:.*]]: i32 +func.func @arith_shift_left_index(%amount: i32) { + %cst0 = "arith.constant"() {value = 42 : index} : () -> (index) + %cast1 = arith.index_cast %amount : i32 to index + // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t + // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t + // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t + // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index + // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0 + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t { + // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t + // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t + // CHECK: } + %1 = arith.shli %cst0, %cast1 : index + return +} + +// ----- + +// CHECK-LABEL: arith_shift_right_index +// CHECK-SAME: %[[AMOUNT:.*]]: i32 +func.func @arith_shift_right_index(%amount: i32) { + %arg0 = "arith.constant"() {value = 42 : index} : () -> (index) + %arg1 = arith.index_cast %amount : i32 to index + // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t + // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t + // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t + // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index + // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t { + // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t + // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t + // CHECK: } + %2 = arith.shrui %arg0, %arg1 : index + + // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t + // CHECK-DAG: %[[SAmountIdx:[^ ]*]] = emitc.cast %[[AmountIdx]] : !emitc.size_t to !emitc.ssize_t + // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.ssize_t + // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.ssize_t) -> !emitc.ssize_t + // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.ssize_t, !emitc.ssize_t) -> !emitc.ssize_t + // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SAmountIdx]], %[[SSizeConstant]] : (!emitc.ssize_t, !emitc.ssize_t) -> i1 + // CHECK-DAG: %[[SZeroConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t + // CHECK-DAG: %[[SCmpPositive:[^ ]*]] = emitc.cmp ge, %[[SAmountIdx]], %[[SZeroConstant]] : (!emitc.ssize_t, !emitc.ssize_t) -> i1 + // CHECK-DAG: %[[SCmpAnd:[^ ]*]] = emitc.logical_and %[[SCmpNoExcess]], %[[SCmpPositive]] : i1, i1 + // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t + // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t { + // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[SAmountIdx]] : (!emitc.ssize_t, !emitc.ssize_t) -> !emitc.ssize_t + // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpAnd]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t + // CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t + // CHECK: } + // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t + %3 = arith.shrsi %arg0, %arg1 : index return } From 5889384cbf91e1d8c52699813b302adda89e0b81 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 27 May 2024 13:29:33 +0100 Subject: [PATCH 5/8] Review comments --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 45 +++++++------------ .../ArithToEmitC/arith-to-emitc.mlir | 33 ++++++-------- 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index a0416d4420fcb..7d01bea72821a 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -460,15 +460,16 @@ class BitwiseOpConversion : public OpConversionPattern { op, "expected integer or size_t/ssize_t type"); } - // There is no unsigned i1 type, bitwise ops can be performed directly - // on booleans. + // Bitwise ops can be performed directly on booleans (avoid converting to + // ui1) if (type.isInteger(1)) { rewriter.replaceOpWithNewOp(op, type, adaptor.getLhs(), adaptor.getRhs()); return success(); } - Type arithmeticType = adaptIntegralTypeSignedness(type, true); + Type arithmeticType = + adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true); Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); @@ -506,45 +507,31 @@ class ShiftOpConversion : public OpConversionPattern { Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp); Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); - Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); + // Shift amount interpreted as unsigned per Arith dialect spec. + Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(), + /*needsUnsigned=*/true); + Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType); // Add a runtime check for overflow - // This is an abuse of the size_t type since we're potentially using values - // below -1. Value width; - Type sizeTType = (isa(type)) ? arithmeticType - : (isUnsignedOp) - ? (Type)(emitc::SizeTType::get(op.getContext())) - : (Type)(emitc::SignedSizeTType::get(op.getContext())); + if (isa(type)) { Value eight = rewriter.create( - op.getLoc(), sizeTType, rewriter.getIndexAttr(8)); + op.getLoc(), rhsType, rewriter.getIndexAttr(8)); emitc::CallOpaqueOp sizeOfCall = rewriter.create( - op.getLoc(), sizeTType, "sizeof", SmallVector({eight})); - width = rewriter.create(op.getLoc(), sizeTType, eight, + op.getLoc(), rhsType, "sizeof", SmallVector({eight})); + width = rewriter.create(op.getLoc(), rhsType, eight, sizeOfCall.getResult(0)); } else { width = rewriter.create( - op.getLoc(), sizeTType, - rewriter.getIntegerAttr(sizeTType, type.getIntOrFloatBitWidth())); + op.getLoc(), rhsType, + rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); } - Value oobTest; Value excessCheck = rewriter.create( op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); - oobTest = excessCheck; - if (!isUnsignedOp) { - Value zero = rewriter.create( - op.getLoc(), sizeTType, - (isa(sizeTType) ? rewriter.getIntegerAttr(sizeTType, 0) - : rewriter.getIndexAttr(0))); - Value defaultCheck = - rewriter.create(op.getLoc(), rewriter.getI1Type(), - emitc::CmpPredicate::ge, rhs, zero); - oobTest = rewriter.create( - op.getLoc(), rewriter.getI1Type(), excessCheck, defaultCheck); - } + // Any concrete value is a valid refinement of poison. Value poison = rewriter.create( op.getLoc(), arithmeticType, (isa(arithmeticType) @@ -559,7 +546,7 @@ class ShiftOpConversion : public OpConversionPattern { Value arithmeticResult = rewriter.create(op.getLoc(), arithmeticType, lhs, rhs); Value resultOrPoison = rewriter.create( - op.getLoc(), arithmeticType, oobTest, arithmeticResult, poison); + op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); rewriter.create(op.getLoc(), resultOrPoison); rewriter.setInsertionPoint(op->getBlock(), currentPoint); diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index cf95d375ba1e5..420096e747ec6 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -154,15 +154,13 @@ func.func @arith_shift_right(%arg0: i32, %arg1: i32) { // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32 %2 = arith.shrui %arg0, %arg1 : i32 - // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}i32 - // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[ARG1]], %[[SSizeConstant]] : (i32, i32) -> i1 - // CHECK-DAG: %[[SZeroConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32 - // CHECK-DAG: %[[SCmpPositive:[^ ]*]] = emitc.cmp ge, %[[ARG1]], %[[SZeroConstant]] : (i32, i32) -> i1 - // CHECK-DAG: %[[SCmpAnd:[^ ]*]] = emitc.logical_and %[[SCmpNoExcess]], %[[SCmpPositive]] : i1, i1 + // CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32 + // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1 // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32 // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32 { - // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 - // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpAnd]], %[[SHRSI]], %[[SZero]] : i32 + // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32 + // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32 // CHECK: emitc.yield %[[STernary]] : i32 // CHECK: } %3 = arith.shrsi %arg0, %arg1 : i32 @@ -199,11 +197,12 @@ func.func @arith_shift_left_index(%amount: i32) { // CHECK-LABEL: arith_shift_right_index // CHECK-SAME: %[[AMOUNT:.*]]: i32 func.func @arith_shift_right_index(%amount: i32) { - %arg0 = "arith.constant"() {value = 42 : index} : () -> (index) - %arg1 = arith.index_cast %amount : i32 to index // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t + %arg0 = "arith.constant"() {value = 42 : index} : () -> (index) + %arg1 = arith.index_cast %amount : i32 to index + // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t @@ -217,18 +216,14 @@ func.func @arith_shift_right_index(%amount: i32) { %2 = arith.shrui %arg0, %arg1 : index // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t - // CHECK-DAG: %[[SAmountIdx:[^ ]*]] = emitc.cast %[[AmountIdx]] : !emitc.size_t to !emitc.ssize_t - // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.ssize_t - // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.ssize_t) -> !emitc.ssize_t - // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.ssize_t, !emitc.ssize_t) -> !emitc.ssize_t - // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SAmountIdx]], %[[SSizeConstant]] : (!emitc.ssize_t, !emitc.ssize_t) -> i1 - // CHECK-DAG: %[[SZeroConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t - // CHECK-DAG: %[[SCmpPositive:[^ ]*]] = emitc.cmp ge, %[[SAmountIdx]], %[[SZeroConstant]] : (!emitc.ssize_t, !emitc.ssize_t) -> i1 - // CHECK-DAG: %[[SCmpAnd:[^ ]*]] = emitc.logical_and %[[SCmpNoExcess]], %[[SCmpPositive]] : i1, i1 + // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t + // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t { - // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[SAmountIdx]] : (!emitc.ssize_t, !emitc.ssize_t) -> !emitc.ssize_t - // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpAnd]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t + // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t + // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t // CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t // CHECK: } // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t From e3778e9e7306e169d7e416c54a8256622cb5f898 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 28 May 2024 16:03:28 +0100 Subject: [PATCH 6/8] Review comments --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 5 +- .../ArithToEmitC/arith-to-emitc.mlir | 47 ++++++++----------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 7d01bea72821a..cbf7f388586a3 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -453,6 +453,7 @@ class BitwiseOpConversion : public OpConversionPattern { matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Vectors and tensors are not lowered currently. Type type = this->getTypeConverter()->convertType(op.getType()); if (!isa_and_nonnull( type)) { @@ -460,8 +461,7 @@ class BitwiseOpConversion : public OpConversionPattern { op, "expected integer or size_t/ssize_t type"); } - // Bitwise ops can be performed directly on booleans (avoid converting to - // ui1) + // Bitwise ops can be performed directly on booleans if (type.isInteger(1)) { rewriter.replaceOpWithNewOp(op, type, adaptor.getLhs(), adaptor.getRhs()); @@ -514,7 +514,6 @@ class ShiftOpConversion : public OpConversionPattern { // Add a runtime check for overflow Value width; - if (isa(type)) { Value eight = rewriter.create( op.getLoc(), rhsType, rewriter.getIndexAttr(8)); diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 420096e747ec6..52ee6dbeb303e 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -118,6 +118,20 @@ func.func @arith_bitwise(%arg0: i32, %arg1: i32) { // ----- +// CHECK-LABEL: arith_bitwise_bool +func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) { + // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1 + %5 = arith.andi %arg0, %arg1 : i1 + // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1 + %6 = arith.ori %arg0, %arg1 : i1 + // CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1 + %7 = arith.xori %arg0, %arg1 : i1 + + return +} + +// ----- + // CHECK-LABEL: arith_shift_left // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 func.func @arith_shift_left(%arg0: i32, %arg1: i32) { @@ -126,11 +140,10 @@ func.func @arith_shift_left(%arg0: i32, %arg1: i32) { // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32 // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1 // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0 - // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 { + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32 // CHECK: emitc.yield %[[Ternary]] : ui32 - // CHECK: } // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32 %1 = arith.shli %arg0, %arg1 : i32 return @@ -146,11 +159,10 @@ func.func @arith_shift_right(%arg0: i32, %arg1: i32) { // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32 // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1 // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32 - // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 { + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32 // CHECK: emitc.yield %[[Ternary]] : ui32 - // CHECK: } // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32 %2 = arith.shrui %arg0, %arg1 : i32 @@ -158,11 +170,10 @@ func.func @arith_shift_right(%arg0: i32, %arg1: i32) { // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32 // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1 // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32 - // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32 { + // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32 // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32 // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32 // CHECK: emitc.yield %[[STernary]] : i32 - // CHECK: } %3 = arith.shrsi %arg0, %arg1 : i32 return @@ -183,11 +194,10 @@ func.func @arith_shift_left_index(%amount: i32) { // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0 - // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t { + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t - // CHECK: } %1 = arith.shli %cst0, %cast1 : index return } @@ -208,11 +218,10 @@ func.func @arith_shift_right_index(%amount: i32) { // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t - // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t { + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t - // CHECK: } %2 = arith.shrui %arg0, %arg1 : index // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t @@ -221,32 +230,16 @@ func.func @arith_shift_right_index(%amount: i32) { // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t - // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t { + // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t // CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t - // CHECK: } // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t %3 = arith.shrsi %arg0, %arg1 : index return } -// ----- - -// CHECK-LABEL: arith_bitwise_bool -func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) { - // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1 - %5 = arith.andi %arg0, %arg1 : i1 - // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1 - %6 = arith.ori %arg0, %arg1 : i1 - // CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1 - %7 = arith.xori %arg0, %arg1 : i1 - - return -} - - // ----- func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () { From 860ff628d281687fc3206b53ac932cc73af0c6f9 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 28 May 2024 16:10:18 +0100 Subject: [PATCH 7/8] add missing comment on unsigned --- mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index f2ae9d67c595f..10cca6bed2f8d 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -468,6 +468,7 @@ class BitwiseOpConversion : public OpConversionPattern { return success(); } + // Bitwise ops are defined by the C standard on unsigned operands. Type arithmeticType = adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true); From 192958aaceff7c59201de6be6aea8d56a5d1be23 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 29 May 2024 09:39:23 +0100 Subject: [PATCH 8/8] Vector/tensor comment fix --- mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 10cca6bed2f8d..93637d02b1f0e 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -453,12 +453,12 @@ class BitwiseOpConversion : public OpConversionPattern { matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Vectors and tensors are not lowered currently. Type type = this->getTypeConverter()->convertType(op.getType()); if (!isa_and_nonnull( type)) { return rewriter.notifyMatchFailure( - op, "expected integer or size_t/ssize_t type"); + op, "expected integer or size_t/ssize_t type, vector/tensor support " + "not yet implemented"); } // Bitwise ops can be performed directly on booleans