Skip to content

Commit bdb918e

Browse files
[mlir][arith] arith-to-apfloat: Bail on unsupported bitwidth (#170994)
Bitwidths greater than 64 are not supported by `arith-to-apfloat`.
1 parent 3b355b2 commit bdb918e

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
102102

103103
LogicalResult matchAndRewrite(OpTy op,
104104
PatternRewriter &rewriter) const override {
105+
if (op.getType().getIntOrFloatBitWidth() > 64)
106+
return rewriter.notifyMatchFailure(op,
107+
"bitwidth > 64 bits is not supported");
108+
105109
// Get APFloat function from runtime library.
106110
FailureOr<FuncOp> fn =
107111
lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
@@ -148,6 +152,11 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
148152

149153
LogicalResult matchAndRewrite(OpTy op,
150154
PatternRewriter &rewriter) const override {
155+
if (op.getType().getIntOrFloatBitWidth() > 64 ||
156+
op.getOperand().getType().getIntOrFloatBitWidth() > 64)
157+
return rewriter.notifyMatchFailure(op,
158+
"bitwidth > 64 bits is not supported");
159+
151160
// Get APFloat function from runtime library.
152161
auto i32Type = IntegerType::get(symTable->getContext(), 32);
153162
auto i64Type = IntegerType::get(symTable->getContext(), 64);
@@ -195,9 +204,10 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
195204

196205
LogicalResult matchAndRewrite(OpTy op,
197206
PatternRewriter &rewriter) const override {
198-
if (op.getType().getIntOrFloatBitWidth() > 64)
199-
return rewriter.notifyMatchFailure(
200-
op, "result type > 64 bits is not supported");
207+
if (op.getType().getIntOrFloatBitWidth() > 64 ||
208+
op.getOperand().getType().getIntOrFloatBitWidth() > 64)
209+
return rewriter.notifyMatchFailure(op,
210+
"bitwidth > 64 bits is not supported");
201211

202212
// Get APFloat function from runtime library.
203213
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -252,11 +262,10 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
252262

253263
LogicalResult matchAndRewrite(OpTy op,
254264
PatternRewriter &rewriter) const override {
255-
Location loc = op.getLoc();
256-
if (op.getIn().getType().getIntOrFloatBitWidth() > 64) {
257-
return rewriter.notifyMatchFailure(
258-
loc, "integer bitwidth > 64 is not supported");
259-
}
265+
if (op.getType().getIntOrFloatBitWidth() > 64 ||
266+
op.getOperand().getType().getIntOrFloatBitWidth() > 64)
267+
return rewriter.notifyMatchFailure(op,
268+
"bitwidth > 64 bits is not supported");
260269

261270
// Get APFloat function from runtime library.
262271
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -270,6 +279,7 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
270279

271280
rewriter.setInsertionPoint(op);
272281
// Cast operands to 64-bit integers.
282+
Location loc = op.getLoc();
273283
auto inIntTy = cast<IntegerType>(op.getOperand().getType());
274284
Value operandBits = op.getOperand();
275285
if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
@@ -317,6 +327,10 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
317327

318328
LogicalResult matchAndRewrite(arith::CmpFOp op,
319329
PatternRewriter &rewriter) const override {
330+
if (op.getLhs().getType().getIntOrFloatBitWidth() > 64)
331+
return rewriter.notifyMatchFailure(op,
332+
"bitwidth > 64 bits is not supported");
333+
320334
// Get APFloat function from runtime library.
321335
auto i1Type = IntegerType::get(symTable->getContext(), 1);
322336
auto i8Type = IntegerType::get(symTable->getContext(), 8);
@@ -456,6 +470,10 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
456470

457471
LogicalResult matchAndRewrite(arith::NegFOp op,
458472
PatternRewriter &rewriter) const override {
473+
if (op.getOperand().getType().getIntOrFloatBitWidth() > 64)
474+
return rewriter.notifyMatchFailure(op,
475+
"bitwidth > 64 bits is not supported");
476+
459477
// Get APFloat function from runtime library.
460478
auto i32Type = IntegerType::get(symTable->getContext(), 32);
461479
auto i64Type = IntegerType::get(symTable->getContext(), 64);

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,28 @@ func.func @maxnumf(%arg0: f32, %arg1: f32) {
263263
%0 = arith.maxnumf %arg0, %arg1 : f32
264264
return
265265
}
266+
267+
// -----
268+
269+
// CHECK-LABEL: func.func @unsupported_bitwidth
270+
// CHECK: arith.addf {{.*}} : f128
271+
// CHECK: arith.negf {{.*}} : f128
272+
// CHECK: arith.cmpf {{.*}} : f128
273+
// CHECK: arith.extf {{.*}} : f32 to f128
274+
// CHECK: arith.truncf {{.*}} : f128 to f32
275+
// CHECK: arith.fptosi {{.*}} : f128 to i32
276+
// CHECK: arith.fptosi {{.*}} : f32 to i92
277+
// CHECK: arith.sitofp {{.*}} : i1 to f128
278+
// CHECK: arith.sitofp {{.*}} : i92 to f32
279+
func.func @unsupported_bitwidth(%arg0: f128, %arg1: f128, %arg2: f32) {
280+
%0 = arith.addf %arg0, %arg1 : f128
281+
%1 = arith.negf %arg0 : f128
282+
%2 = arith.cmpf "ult", %arg0, %arg1 : f128
283+
%3 = arith.extf %arg2 : f32 to f128
284+
%4 = arith.truncf %arg0 : f128 to f32
285+
%5 = arith.fptosi %arg0 : f128 to i32
286+
%6 = arith.fptosi %arg2 : f32 to i92
287+
%7 = arith.sitofp %2 : i1 to f128
288+
%8 = arith.sitofp %6 : i92 to f32
289+
return
290+
}

0 commit comments

Comments
 (0)