Skip to content

Commit 51ea141

Browse files
committed
Support more Arith integer binary ops
1 parent 57dd987 commit 51ea141

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,46 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
443443
}
444444
};
445445

446+
template <typename ArithOp, typename EmitCOp, bool booleansLegal>
447+
class BitwiseOpConversion final : public OpConversionPattern<ArithOp> {
448+
public:
449+
using OpConversionPattern<ArithOp>::OpConversionPattern;
450+
451+
LogicalResult
452+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
453+
ConversionPatternRewriter &rewriter) const override {
454+
455+
Type type = this->getTypeConverter()->convertType(op.getType());
456+
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
457+
type)) {
458+
return rewriter.notifyMatchFailure(
459+
op, "expected integer or size_t/ssize_t type");
460+
}
461+
462+
if (type.isInteger(1)) {
463+
if (!booleansLegal)
464+
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
465+
466+
rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
467+
adaptor.getRhs());
468+
return success();
469+
}
470+
471+
Type arithmeticType = adaptIntegralTypeSignedness(type, true);
472+
473+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
474+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
475+
476+
Value arithmeticResult = rewriter.template create<EmitCOp>(
477+
op.getLoc(), arithmeticType, lhs, rhs);
478+
479+
Value result = adaptValueType(arithmeticResult, rewriter, type);
480+
481+
rewriter.replaceOp(op, result);
482+
return success();
483+
}
484+
};
485+
446486
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
447487
public:
448488
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -581,6 +621,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
581621
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
582622
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
583623
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
624+
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp, true>,
625+
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp, true>,
626+
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp, true>,
627+
BitwiseOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp, false>,
628+
BitwiseOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp, false>,
584629
CmpFOpConversion,
585630
CmpIOpConversion,
586631
SelectOpConversion,

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,54 @@ func.func @arith_index(%arg0: i32, %arg1: i32) {
9292
return
9393
}
9494

95+
// -----
96+
97+
// CHECK-LABEL: arith_bitwise
98+
func.func @arith_bitwise(%arg0: i32, %arg1: i32) {
99+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
100+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
101+
// CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
102+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[AND]] : ui32 to i32
103+
%5 = arith.andi %arg0, %arg1 : i32
104+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
105+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
106+
// CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
107+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[OR]] : ui32 to i32
108+
%6 = arith.ori %arg0, %arg1 : i32
109+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
110+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
111+
// CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
112+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32
113+
%7 = arith.xori %arg0, %arg1 : i32
114+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
115+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
116+
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
117+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHL]] : ui32 to i32
118+
%8 = arith.shli %arg0, %arg1 : i32
119+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
120+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
121+
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
122+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SHR]] : ui32 to i32
123+
%9 = arith.shrsi %arg0, %arg1 : i32
124+
125+
return
126+
}
127+
128+
// -----
129+
130+
// CHECK-LABEL: arith_bitwise_bool
131+
func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) {
132+
// CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1
133+
%5 = arith.andi %arg0, %arg1 : i1
134+
// CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1
135+
%6 = arith.ori %arg0, %arg1 : i1
136+
// CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1
137+
%7 = arith.xori %arg0, %arg1 : i1
138+
139+
return
140+
}
141+
142+
95143
// -----
96144

97145
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {

0 commit comments

Comments
 (0)