Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FXML-4501] Lower arith.andi, arith.ori, arith.shli, arith.shrsi, arith.shrui, arith.xori to emitc #165

Merged
merged 9 commits into from
May 29, 2024
132 changes: 132 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

Expand Down Expand Up @@ -443,6 +444,131 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
}
};

template <typename ArithOp, typename EmitCOp>
class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
type)) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t type, vector/tensor support "
"not yet implemented");
}

// Bitwise ops can be performed directly on booleans
if (type.isInteger(1)) {
rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
adaptor.getRhs());
return success();
}

// Bitwise ops are defined by the C standard on unsigned operands.
Type arithmeticType =
adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);

Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

Value arithmeticResult = rewriter.template create<EmitCOp>(
op.getLoc(), arithmeticType, lhs, rhs);

Value result = adaptValueType(arithmeticResult, rewriter, type);

rewriter.replaceOp(op, result);
return success();
}
};

template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
type)) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t type");
}

if (type.isInteger(1)) {
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}

Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);

Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
// Shift amount interpreted as unsigned per Arith dialect spec.
Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
/*needsUnsigned=*/true);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);

// Add a runtime check for overflow
Value width;
if (isa<emitc::SignedSizeTType, emitc::SizeTType>(type)) {
Value eight = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType, rewriter.getIndexAttr(8));
emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
sizeOfCall.getResult(0));
} else {
width = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType,
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
}

Value excessCheck = rewriter.create<emitc::CmpOp>(
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);

// Any concrete value is a valid refinement of poison.
Value poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));

emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);

Value result = adaptValueType(ternary, rewriter, type);

rewriter.replaceOp(op, result);
return success();
}
};

template <typename ArithOp, typename EmitCOp>
class SignedShiftOpConversion final
: public ShiftOpConversion<ArithOp, EmitCOp, false> {
using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
};

template <typename ArithOp, typename EmitCOp>
class UnsignedShiftOpConversion final
: public ShiftOpConversion<ArithOp, EmitCOp, true> {
using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
};

class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
Expand Down Expand Up @@ -581,6 +707,12 @@ void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns,
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
CmpFOpConversion,
CmpIOpConversion,
SelectOpConversion,
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,27 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
%t = arith.fptoui %arg0 : f32 to i1
return %t: i1
}

// -----

func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.shli'}}
%shli = arith.shli %arg0, %arg1 : i1
return
}

// -----

func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
%shrsi = arith.shrsi %arg0, %arg1 : i1
return
}

// -----

func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
%shrui = arith.shrui %arg0, %arg1 : i1
return
}
148 changes: 148 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,154 @@ func.func @arith_index(%arg0: i32, %arg1: i32) {

// -----

// CHECK-LABEL: arith_bitwise
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func.func @arith_bitwise(%arg0: i32, %arg1: i32) {
// CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
// CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
// CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[AND]] : ui32 to i32
%5 = arith.andi %arg0, %arg1 : i32
// CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
// CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
// CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[OR]] : ui32 to i32
%6 = arith.ori %arg0, %arg1 : i32
// CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
// CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
// CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32
%7 = arith.xori %arg0, %arg1 : i32

return
}

// -----

// CHECK-LABEL: arith_bitwise_bool
func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) {
// CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1
%5 = arith.andi %arg0, %arg1 : i1
// CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1
%6 = arith.ori %arg0, %arg1 : i1
// CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1
%7 = arith.xori %arg0, %arg1 : i1

return
}

// -----

// CHECK-LABEL: arith_shift_left
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
// CHECK: emitc.yield %[[Ternary]] : ui32
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
%1 = arith.shli %arg0, %arg1 : i32
return
}

// -----

// CHECK-LABEL: arith_shift_right
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
// CHECK: emitc.yield %[[Ternary]] : ui32
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
%2 = arith.shrui %arg0, %arg1 : i32

// CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
// CHECK: emitc.yield %[[STernary]] : i32
%3 = arith.shrsi %arg0, %arg1 : i32

return
}

// -----

// CHECK-LABEL: arith_shift_left_index
// CHECK-SAME: %[[AMOUNT:.*]]: i32
func.func @arith_shift_left_index(%amount: i32) {
%cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
%cast1 = arith.index_cast %amount : i32 to index
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
%1 = arith.shli %cst0, %cast1 : index
return
}

// -----

// CHECK-LABEL: arith_shift_right_index
// CHECK-SAME: %[[AMOUNT:.*]]: i32
func.func @arith_shift_right_index(%amount: i32) {
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
%arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
%arg1 = arith.index_cast %amount : i32 to index

// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
%2 = arith.shrui %arg0, %arg1 : index

// CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
// CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
// CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
// CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
// CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
%3 = arith.shrsi %arg0, %arg1 : index

return
}

// -----

func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
Expand Down