@@ -36,10 +36,17 @@ class CallOpConversion final : public OpConversionPattern<func::CallOp> {
36
36
return rewriter.notifyMatchFailure (
37
37
callOp, " only functions with zero or one result can be converted" );
38
38
39
+ // Convert the original function results.
40
+ Type resultTy = nullptr ;
41
+ if (callOp.getNumResults ()) {
42
+ resultTy = typeConverter->convertType (callOp.getResult (0 ).getType ());
43
+ if (!resultTy)
44
+ return rewriter.notifyMatchFailure (
45
+ callOp, " function return type conversion failed" );
46
+ }
47
+
39
48
rewriter.replaceOpWithNewOp <emitc::CallOp>(
40
- callOp,
41
- callOp.getNumResults () ? callOp.getResult (0 ).getType () : nullptr ,
42
- adaptor.getOperands (), callOp->getAttrs ());
49
+ callOp, resultTy, adaptor.getOperands (), callOp->getAttrs ());
43
50
44
51
return success ();
45
52
}
@@ -53,13 +60,34 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
53
60
matchAndRewrite (func::FuncOp funcOp, OpAdaptor adaptor,
54
61
ConversionPatternRewriter &rewriter) const override {
55
62
56
- if (funcOp.getFunctionType ().getNumResults () > 1 )
63
+ FunctionType type = funcOp.getFunctionType ();
64
+ if (!type)
65
+ return failure ();
66
+
67
+ if (type.getNumResults () > 1 )
57
68
return rewriter.notifyMatchFailure (
58
69
funcOp, " only functions with zero or one result can be converted" );
59
70
71
+ const TypeConverter *converter = getTypeConverter ();
72
+
73
+ // Convert function signature
74
+ TypeConverter::SignatureConversion signatureConversion (type.getNumInputs ());
75
+ SmallVector<Type, 1 > convertedResults;
76
+ if (failed (converter->convertSignatureArgs (type.getInputs (),
77
+ signatureConversion)) ||
78
+ failed (converter->convertTypes (type.getResults (), convertedResults)) ||
79
+ failed (rewriter.convertRegionTypes (&funcOp.getFunctionBody (),
80
+ *converter, &signatureConversion)))
81
+ return rewriter.notifyMatchFailure (funcOp, " signature conversion failed" );
82
+
83
+ // Convert the function type
84
+ auto convertedFunctionType = FunctionType::get (
85
+ rewriter.getContext (), signatureConversion.getConvertedTypes (),
86
+ convertedResults);
87
+
60
88
// Create the converted `emitc.func` op.
61
89
emitc::FuncOp newFuncOp = rewriter.create <emitc::FuncOp>(
62
- funcOp.getLoc (), funcOp.getName (), funcOp. getFunctionType () );
90
+ funcOp.getLoc (), funcOp.getName (), convertedFunctionType );
63
91
64
92
// Copy over all attributes other than the function name and type.
65
93
for (const auto &namedAttr : funcOp->getAttrs ()) {
@@ -113,8 +141,10 @@ class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
113
141
// Pattern population
114
142
// ===----------------------------------------------------------------------===//
115
143
116
- void mlir::populateFuncToEmitCPatterns (RewritePatternSet &patterns) {
144
+ void mlir::populateFuncToEmitCPatterns (RewritePatternSet &patterns,
145
+ TypeConverter &typeConverter) {
117
146
MLIRContext *ctx = patterns.getContext ();
118
147
119
- patterns.add <CallOpConversion, FuncOpConversion, ReturnOpConversion>(ctx);
148
+ patterns.add <CallOpConversion, FuncOpConversion, ReturnOpConversion>(
149
+ typeConverter, ctx);
120
150
}
0 commit comments