@@ -54,11 +54,9 @@ class MLIRToyVisitor : public ToyBaseVisitor {
54
54
// / The builder helps create MLIR operations when traversing the AST.
55
55
mlir::OpBuilder builder;
56
56
// / The Symbol Table
57
- // / [TODO][LOW] make the symbol table support function prototype.
58
57
llvm::ScopedHashTable<llvm::StringRef, mlir::Value> symbolTable;
59
- // / Return Status Flag
60
- // / The syntax supports omitting the return expression.
61
- bool returnFlag = false ;
58
+ llvm::ScopedHashTable<llvm::StringRef, int > funSymbolTable;
59
+ llvm::StringMap<mlir::toy::FuncOp> functionMap;
62
60
// Register the filename for the string attribute in MLIR location object.
63
61
std::string fileName;
64
62
@@ -71,6 +69,15 @@ class MLIRToyVisitor : public ToyBaseVisitor {
71
69
symbolTable.insert (var, value);
72
70
return mlir::success ();
73
71
}
72
+ // Declear a function in the current module
73
+ // / - Check the parameter number of the function.
74
+ mlir::LogicalResult funcDeclare (llvm::StringRef functionName,
75
+ int argsNumber) {
76
+ if (funSymbolTable.count (functionName))
77
+ return mlir::failure ();
78
+ funSymbolTable.insert (functionName, argsNumber);
79
+ return mlir::success ();
80
+ }
74
81
75
82
// / Location
76
83
// / Get the MLIR location object with the current line and row of the toy
@@ -89,24 +96,54 @@ class MLIRToyVisitor : public ToyBaseVisitor {
89
96
90
97
// Get the tensor value from the tensor literal node.
91
98
std::any getTensor (ToyParser::TensorLiteralContext *ctx) {
92
- // [TODO][HIGH] find a better way to define the `dims`.
93
99
std::vector<int64_t > dims;
94
100
// get dimensions.
95
101
dims.push_back (ctx->Comma ().size () + 1 );
96
102
if (ctx->tensorLiteral (0 )->tensorLiteral (0 )) {
97
- dims.push_back (ctx->tensorLiteral (0 )->Comma ().size () + 1 );
103
+ ToyParser::TensorLiteralContext *list = ctx->tensorLiteral (0 );
104
+ while (list) {
105
+ dims.push_back (list->Comma ().size () + 1 );
106
+ if (list->tensorLiteral (0 ) && list->tensorLiteral (0 )->Comma ().size ())
107
+ list = list->tensorLiteral (0 );
108
+ else
109
+ break ;
110
+ }
98
111
}
99
112
mlir::Type elementType = builder.getF64Type ();
100
- auto type = getType (dims);
113
+ mlir::Type type = getType (dims);
101
114
auto dataType = mlir::RankedTensorType::get (dims, elementType);
102
- auto dataAttribute =
115
+ mlir::DenseElementsAttr dataAttribute =
103
116
mlir::DenseElementsAttr::get (dataType, llvm::makeArrayRef (ctx->data ));
104
- auto loaction =
117
+ mlir::Location loaction =
105
118
loc (ctx->start ->getLine (), ctx->start ->getCharPositionInLine ());
106
119
mlir::Value value =
107
120
builder.create <mlir::toy::ConstantOp>(loaction, type, dataAttribute);
108
121
return value;
109
122
}
123
+ // Module Visitor
124
+ // - Visitor all function asts to get the number of function parameter.
125
+ // - Visitor childrens.
126
+ virtual std::any visitModule (ToyParser::ModuleContext *ctx) override {
127
+ llvm::ScopedHashTableScope<llvm::StringRef, int > protoTypeSymbolTable (
128
+ funSymbolTable);
129
+ for (auto &function : ctx->funDefine ()) {
130
+ ToyParser::PrototypeContext *protoType = function->prototype ();
131
+ std::string functionName = protoType->Identifier ()->toString ();
132
+ int declNumber = 0 ;
133
+ if (protoType->declList ()) {
134
+ ToyParser::DeclListContext *list = protoType->declList ();
135
+ while (list) {
136
+ declNumber++;
137
+ if (list->declList ())
138
+ list = list->declList ();
139
+ else
140
+ break ;
141
+ }
142
+ }
143
+ funcDeclare (function->prototype ()->idName , declNumber);
144
+ }
145
+ return visitChildren (ctx);
146
+ }
110
147
111
148
// / Function Definition Visitor
112
149
// / - Register the function name, argument list, and return value into the
@@ -115,35 +152,70 @@ class MLIRToyVisitor : public ToyBaseVisitor {
115
152
// / - Visit fucntion block.
116
153
// / - Process the return operation.
117
154
virtual std::any visitFunDefine (ToyParser::FunDefineContext *ctx) override {
118
- returnFlag = false ;
119
- // [TODO] make the function support argument list and return value.
120
155
llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope (
121
156
symbolTable);
122
157
builder.setInsertionPointToEnd (theModule.getBody ());
123
158
// Visit function prototype.
124
- visit (ctx->prototype ());
159
+ mlir::toy::FuncOp function =
160
+ std::any_cast<mlir::toy::FuncOp>(visit (ctx->prototype ()));
161
+ mlir::Block &entryBlock = function.front ();
162
+
163
+ // Set the insertion point in the builder to the beginning of the function
164
+ // body, it will be used throughout the codegen to create operations in this
165
+ // function.
166
+ builder.setInsertionPointToStart (&entryBlock);
167
+
168
+ std::vector<std::string> args;
169
+ if (ctx->prototype ()->declList ()) {
170
+ ToyParser::DeclListContext *list = ctx->prototype ()->declList ();
171
+ while (list->Identifier ()) {
172
+ args.push_back (list->Identifier ()->toString ());
173
+ if (list->declList ())
174
+ list = list->declList ();
175
+ else
176
+ break ;
177
+ }
178
+ }
179
+ // Declare all the function arguments in the symbol table.
180
+ llvm::ArrayRef<std::string> protoArgs = args;
181
+ for (auto value : llvm::zip (protoArgs, entryBlock.getArguments ())) {
182
+ declare (std::get<0 >(value), std::get<1 >(value));
183
+ }
184
+
125
185
// Visit fucntion block.
126
186
visit (ctx->block ());
127
187
// Check the return status.
128
188
// If there is no return expression at the end of the function, it will
129
189
// generate a return operation automatically.
130
- if (!returnFlag) {
131
- auto location =
190
+ mlir::toy::ReturnOp returnOp;
191
+ if (!entryBlock.empty ())
192
+ returnOp = llvm::dyn_cast<mlir::toy::ReturnOp>(entryBlock.back ());
193
+ if (!returnOp) {
194
+ mlir::Location location =
132
195
loc (ctx->start ->getLine (), ctx->start ->getCharPositionInLine ());
133
- builder.create <mlir::toy::ReturnOp>(location,
134
- llvm::ArrayRef<mlir::Value>());
196
+ builder.create <mlir::toy::ReturnOp>(location);
197
+ } else if (returnOp.hasOperand ()) {
198
+ // Otherwise, if this return operation has an operand then add a result to
199
+ // the function.
200
+ std::vector<int64_t > shape;
201
+ function.setType (builder.getFunctionType (
202
+ function.getFunctionType ().getInputs (), getType (shape)));
135
203
}
204
+ // If this function isn't main, then set the visibility to private.
205
+ if (ctx->prototype ()->Identifier ()->toString () != " main" )
206
+ function.setPrivate ();
207
+ functionMap.insert ({function.getName (), function});
136
208
return 0 ;
137
209
}
138
210
139
211
// / Prototype Visitor
140
212
virtual std::any visitPrototype (ToyParser::PrototypeContext *ctx) override {
141
213
mlir::Location location =
142
214
loc (ctx->start ->getLine (), ctx->start ->getCharPositionInLine ());
143
- auto varNumber = 0 ;
215
+ int varNumber = 0 ;
144
216
// Get the number of arguments.
145
217
if (ctx->declList ()) {
146
- auto list = ctx->declList ();
218
+ ToyParser::DeclListContext * list = ctx->declList ();
147
219
while (list->Identifier ()) {
148
220
varNumber++;
149
221
if (list->declList ())
@@ -152,26 +224,37 @@ class MLIRToyVisitor : public ToyBaseVisitor {
152
224
break ;
153
225
}
154
226
}
155
-
156
227
llvm::SmallVector<mlir::Type, 4 > argTypes (
157
228
varNumber, mlir::UnrankedTensorType::get (builder.getF64Type ()));
158
- auto funType = builder.getFunctionType (argTypes, llvm::None);
229
+ mlir::FunctionType funType = builder.getFunctionType (argTypes, llvm::None);
159
230
auto func = builder.create <mlir::toy::FuncOp>(
160
231
location, ctx->Identifier ()->toString (), funType);
161
- mlir::Block &entryblock = func.front ();
162
- builder.setInsertionPointToStart (&entryblock);
163
- return 0 ;
232
+ return func;
164
233
}
165
234
166
235
// / Expression Visitor
167
236
// / - If the expression is tensor literal, return the tensor MLIR value.
168
237
// / - If the expression is function call or variable, visit the identifier.
238
+ // / - If the expression is add expression or mul expression return add or mul
239
+ // / value.
169
240
virtual std::any visitExpression (ToyParser::ExpressionContext *ctx) override {
170
241
mlir::Value value;
171
242
if (ctx->tensorLiteral ()) {
172
243
return getTensor (ctx->tensorLiteral ());
173
244
} else if (ctx->identifierExpr ()) {
174
245
return visit (ctx->identifierExpr ());
246
+ } else if (ctx->Add () || ctx->Mul ()) {
247
+ // Derive the operation name from the binary operator. At the moment we
248
+ // only support '+' and '*'.
249
+ mlir::Value lhs = std::any_cast<mlir::Value>(visit (ctx->expression (0 )));
250
+ mlir::Value rhs = std::any_cast<mlir::Value>(visit (ctx->expression (1 )));
251
+ mlir::Location loaction =
252
+ loc (ctx->start ->getLine (), ctx->start ->getCharPositionInLine ());
253
+ if (ctx->Add ())
254
+ value = builder.create <mlir::toy::AddOp>(loaction, lhs, rhs);
255
+ else
256
+ value = builder.create <mlir::toy::MulOp>(loaction, lhs, rhs);
257
+ return value;
175
258
}
176
259
return value;
177
260
}
@@ -188,7 +271,7 @@ class MLIRToyVisitor : public ToyBaseVisitor {
188
271
std::vector<int64_t > v0;
189
272
auto v1 = ctx->type ()->Number ();
190
273
for (auto i : v1) {
191
- auto j = atoi (i->toString ().c_str ());
274
+ int64_t j = atoi (i->toString ().c_str ());
192
275
v0.push_back (j);
193
276
}
194
277
mlir::Location location =
@@ -208,28 +291,65 @@ class MLIRToyVisitor : public ToyBaseVisitor {
208
291
virtual std::any
209
292
visitIdentifierExpr (ToyParser::IdentifierExprContext *ctx) override {
210
293
mlir::Value value;
294
+ int argsNumber = 0 ;
295
+ mlir::Location location =
296
+ loc (ctx->start ->getLine (), ctx->start ->getCharPositionInLine ());
211
297
// If the identifier is a function call, visit and register all the
212
298
// arguments. [TODO][LOW] add the semantic check (look up the symbol table)
213
299
// for the function call.
214
300
if (ctx->ParentheseOpen ()) {
215
- auto location =
301
+ mlir::Location location =
216
302
loc (ctx->start ->getLine (), ctx->start ->getCharPositionInLine ());
217
303
llvm::SmallVector<mlir::Value, 4 > oprands;
218
- for (auto i : ctx->expression ()) {
304
+ for (ToyParser::ExpressionContext * i : ctx->expression ()) {
219
305
mlir::Value arg = std::any_cast<mlir::Value>(visit (i));
220
306
oprands.push_back (arg);
307
+ argsNumber++;
221
308
}
222
309
// If function call is a built-in operation, create the corresponding
223
310
// operation.
224
311
if (ctx->Identifier ()->toString () == " print" ) {
225
- auto arg = oprands[0 ];
312
+ if (argsNumber != 1 ) {
313
+ mlir::emitError (location)
314
+ << " mismatch of function parameters 'print'" ;
315
+ return nullptr ;
316
+ }
317
+ mlir::Value arg = oprands[0 ];
226
318
builder.create <mlir::toy::PrintOp>(location, arg);
227
319
return 0 ;
320
+ } else if (ctx->Identifier ()->toString () == " transpose" ) {
321
+ if (argsNumber != 1 ) {
322
+ mlir::emitError (location)
323
+ << " mlismatch of function parameters 'transpose'" ;
324
+ return nullptr ;
325
+ }
326
+ mlir::Value arg = oprands[0 ];
327
+ value = builder.create <mlir::toy::TransposeOp>(location, arg);
328
+ return value;
329
+ }
330
+ // Otherwise this is a call to a user-defined function. Calls to
331
+ // user-defined functions are mapped to a custom call that takes the
332
+ // callee name as an attribute.
333
+ auto callee = functionMap.find (ctx->Identifier ()->toString ());
334
+ if (callee == functionMap.end ()) {
335
+ mlir::emitError (location) << " error: no defined function '"
336
+ << ctx->Identifier ()->toString () << " '" ;
337
+ return nullptr ;
338
+ }
339
+ int numberdecl = funSymbolTable.lookup (ctx->Identifier ()->toString ());
340
+ if (numberdecl != argsNumber) {
341
+ mlir::emitError (location) << " error: mismatch of function parameters '"
342
+ << ctx->Identifier ()->toString () << " '" ;
343
+ return nullptr ;
228
344
}
229
345
// If the function call cannot be mapped to the built-in operation, create
230
346
// the GenericCallOp.
347
+ mlir::toy::FuncOp calledFunc = callee->second ;
231
348
value = builder.create <mlir::toy::GenericCallOp>(
232
- location, ctx->Identifier ()->toString (), oprands);
349
+ location, calledFunc.getFunctionType ().getResult (0 ),
350
+ mlir::SymbolRefAttr::get (builder.getContext (),
351
+ ctx->Identifier ()->toString ()),
352
+ oprands);
233
353
return value;
234
354
} else {
235
355
// If the identifier is a variable, return the MLIR value from the symbol
@@ -241,12 +361,11 @@ class MLIRToyVisitor : public ToyBaseVisitor {
241
361
242
362
// / Return Expression Visitor
243
363
virtual std::any visitReturnExpr (ToyParser::ReturnExprContext *ctx) override {
244
- returnFlag = true ;
245
- auto location =
364
+ mlir::Location location =
246
365
loc (ctx->start ->getLine (), ctx->start ->getCharPositionInLine ());
247
366
mlir::Value expr = nullptr ;
248
367
if (ctx->expression ()) {
249
- expr = std::any_cast<mlir::Value>(ctx->expression ());
368
+ expr = std::any_cast<mlir::Value>(visit ( ctx->expression () ));
250
369
}
251
370
// Generate return operation based on whether the function has the return
252
371
// value.
0 commit comments