Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FXML-4281] Remove flag for rounding mode of casting ops #168

Merged
merged 4 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ class RewritePatternSet;
class TypeConverter;

void populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns,
bool optionFloatToIntTruncates);
RewritePatternSet &patterns);
} // namespace mlir

#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
17 changes: 0 additions & 17 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,24 +139,7 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {

def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
let summary = "Convert Arith dialect to EmitC dialect";
let description = [{
This pass converts `arith` dialect operations to `emitc`.

The semantics of floating-point to integer conversions `arith.fptosi`,
`arith.fptoui` require rounding towards zero. Typical C++ implementations
use this behavior for float-to-integer casts, but that is not mandated by
C++ and there are implementation-defined means to change the default behavior.

If casts can be guaranteed to use round-to-zero, use the
`float-to-int-truncates` flag to allow conversion of `arith.fptosi` and
`arith.fptoui` operations.
}];
let dependentDialects = ["emitc::EmitCDialect"];
let options = [
Option<"floatToIntTruncates", "float-to-int-truncates", "bool",
/*default=*/"false",
"Whether the behavior of float-to-int cast in emitc is truncation">,
];
}

//===----------------------------------------------------------------------===//
Expand Down
27 changes: 8 additions & 19 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,9 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
// Floating-point to integer conversions.
template <typename CastOp>
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
private:
bool floatToIntTruncates;

public:
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
bool optionFloatToIntTruncates)
: OpConversionPattern<CastOp>(typeConverter, context),
floatToIntTruncates(optionFloatToIntTruncates) {}
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<CastOp>(typeConverter, context) {}

LogicalResult
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
Expand All @@ -384,16 +379,13 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");

if (!floatToIntTruncates)
return rewriter.notifyMatchFailure(
castOp, "conversion currently requires EmitC casts to use truncation "
"as rounding mode");

Type dstType = this->getTypeConverter()->convertType(castOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

if (!emitc::isSupportedIntegerType(dstType))
// Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
// truncated to 0, whereas a boolean conversion would return true.
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

Expand Down Expand Up @@ -468,8 +460,7 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
//===----------------------------------------------------------------------===//

void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns,
bool optionFloatToIntTruncates) {
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();

// clang-format off
Expand All @@ -488,11 +479,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
CmpIOpConversion,
SelectOpConversion,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>
>(typeConverter, ctx)
.add<
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
FtoICastOpConversion<arith::FPToUIOp>
>(typeConverter, ctx, optionFloatToIntTruncates);
>(typeConverter, ctx);
// clang-format on
}
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void ConvertArithToEmitC::runOnOperation() {
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

populateArithToEmitCPatterns(typeConverter, patterns, floatToIntTruncates);
populateArithToEmitCPatterns(typeConverter, patterns);

if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand Down

This file was deleted.

This file was deleted.

62 changes: 60 additions & 2 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,66 @@ func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vect

// -----

func.func @arith_cast_f32(%arg0: f32) -> i32 {
func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f32 to i32
%t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
return %t: tensor<5xi32>
}

// -----

func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
return %t: vector<5xi32>
}

// -----

func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : bf16 to i32
return %t: i32
}

// -----

func.func @arith_cast_f16(%arg0: f16) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f16 to i32
return %t: i32
}


// -----

func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
%t = arith.sitofp %arg0 : i32 to bf16
return %t: bf16
}

// -----

func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
%t = arith.sitofp %arg0 : i32 to f16
return %t: f16
}

// -----

func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f32 to i1
return %t: i1
}

// -----

func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
// expected-error @+1 {{failed to legalize operation 'arith.fptoui'}}
%t = arith.fptoui %arg0 : f32 to i1
return %t: i1
}

52 changes: 36 additions & 16 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -309,22 +309,6 @@ func.func @arith_cmpf_true(%arg0: f32, %arg1: f32) -> i1 {

// -----

func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
// CHECK: emitc.cast %arg0 : i8 to f32
%0 = arith.sitofp %arg0 : i8 to f32

// CHECK: emitc.cast %arg1 : i64 to f32
%1 = arith.sitofp %arg1 : i64 to f32

// CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
// CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
%2 = arith.uitofp %arg0 : i8 to f32

return
}

// -----

func.func @arith_cmpi_eq(%arg0: i32, %arg1: i32) -> i1 {
// CHECK-LABEL: arith_cmpi_eq
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32)
Expand Down Expand Up @@ -370,3 +354,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {

return
}

// -----

func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
// CHECK: emitc.cast %arg0 : f32 to i32
%0 = arith.fptosi %arg0 : f32 to i32

// CHECK: emitc.cast %arg1 : f64 to i32
%1 = arith.fptosi %arg1 : f64 to i32

// CHECK: emitc.cast %arg0 : f32 to i16
%2 = arith.fptosi %arg0 : f32 to i16

// CHECK: emitc.cast %arg1 : f64 to i16
%3 = arith.fptosi %arg1 : f64 to i16

// CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
// CHECK: emitc.cast %[[CAST0]] : ui32 to i32
%4 = arith.fptoui %arg0 : f32 to i32

return
}

func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
// CHECK: emitc.cast %arg0 : i8 to f32
%0 = arith.sitofp %arg0 : i8 to f32

// CHECK: emitc.cast %arg1 : i64 to f32
%1 = arith.sitofp %arg1 : i64 to f32

// CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
// CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
%2 = arith.uitofp %arg0 : i8 to f32

return
}