Skip to content

Commit 0c47b6a

Browse files
committed
Add EmitC index type converter to FuncToEmitC
1 parent 57dd987 commit 0c47b6a

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1718
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1820
#include "mlir/Pass/Pass.h"
1921
#include "mlir/Transforms/DialectConversion.h"
22+
#include <algorithm>
2023

2124
namespace mlir {
2225
#define GEN_PASS_DEF_CONVERTFUNCTOEMITC
@@ -33,6 +36,51 @@ struct ConvertFuncToEmitC
3336
} // namespace
3437

3538
void ConvertFuncToEmitC::runOnOperation() {
39+
// Convert function interface types within the func dialect first to supported
40+
// EmitC types
41+
ConversionTarget interfaceConversionTarget(getContext());
42+
interfaceConversionTarget.addDynamicallyLegalOp<func::CallOp>(
43+
[](func::CallOp op) {
44+
auto operandTypes = op->getOperandTypes();
45+
if (std::any_of(operandTypes.begin(), operandTypes.end(),
46+
[](Type t) { return isa<IndexType>(t); }))
47+
return false;
48+
auto resultTypes = op.getResultTypes();
49+
return !(std::any_of(resultTypes.begin(), resultTypes.end(),
50+
[](Type t) { return isa<IndexType>(t); }));
51+
});
52+
interfaceConversionTarget.addDynamicallyLegalOp<func::FuncOp>(
53+
[](func::FuncOp op) {
54+
auto operandTypes = op->getOperandTypes();
55+
if (std::any_of(operandTypes.begin(), operandTypes.end(),
56+
[](Type t) { return isa<IndexType>(t); }))
57+
return false;
58+
auto resultTypes = op.getResultTypes();
59+
return !(std::any_of(resultTypes.begin(), resultTypes.end(),
60+
[](Type t) { return isa<IndexType>(t); }));
61+
});
62+
interfaceConversionTarget.addDynamicallyLegalOp<func::ReturnOp>(
63+
[](func::ReturnOp op) {
64+
auto operandTypes = op->getOperandTypes();
65+
return !(std::any_of(operandTypes.begin(), operandTypes.end(),
66+
[](Type t) { return isa<IndexType>(t); }));
67+
});
68+
69+
RewritePatternSet interfaceRewritePatterns(&getContext());
70+
TypeConverter typeConverter;
71+
typeConverter.addConversion([](Type type) { return type; });
72+
populateEmitCSizeTypeConversionPatterns(typeConverter);
73+
populateReturnOpTypeConversionPattern(interfaceRewritePatterns,
74+
typeConverter);
75+
populateCallOpTypeConversionPattern(interfaceRewritePatterns, typeConverter);
76+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
77+
interfaceRewritePatterns, typeConverter);
78+
79+
if (failed(applyPartialConversion(getOperation(), interfaceConversionTarget,
80+
std::move(interfaceRewritePatterns))))
81+
signalPassFailure();
82+
83+
// Then convert the func ops themselves to EmitC
3684
ConversionTarget target(getContext());
3785

3886
target.addLegalDialect<emitc::EmitCDialect>();

mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,25 @@ func.func @call(%arg0: i32) -> i32 {
5858

5959
// CHECK-LABEL: emitc.func private @return_i32(i32) -> i32 attributes {specifiers = ["extern"]}
6060
func.func private @return_i32(%arg0: i32) -> i32
61+
62+
// -----
63+
64+
// CHECK-LABEL: emitc.func @use_index
65+
// CHECK-SAME: (%[[Arg0:.*]]: !emitc.size_t) -> !emitc.size_t
66+
// CHECK: emitc.return %[[Arg0]] : !emitc.size_t
67+
func.func @use_index(%arg0: index) -> index {
68+
return %arg0 : index
69+
}
70+
71+
// -----
72+
73+
// CHECK-LABEL: emitc.func private @prototype_index(!emitc.size_t) -> !emitc.size_t attributes {specifiers = ["extern"]}
74+
func.func private @prototype_index(%arg0: index) -> index
75+
76+
// CHECK-LABEL: emitc.func @call(%arg0: !emitc.size_t) -> !emitc.size_t
77+
// CHECK-NEXT: %0 = emitc.call @prototype_index(%arg0) : (!emitc.size_t) -> !emitc.size_t
78+
// CHECK-NEXT: emitc.return %0 : !emitc.size_t
79+
func.func @call(%arg0: index) -> index {
80+
%0 = call @prototype_index(%arg0) : (index) -> (index)
81+
return %0 : index
82+
}

0 commit comments

Comments
 (0)