Skip to content

Commit 2f3677c

Browse files
Add function to Toy DSL examples (buddy-compiler#61)
1 parent 14170de commit 2f3677c

File tree

4 files changed

+205
-35
lines changed

4 files changed

+205
-35
lines changed

examples/ToyDSL/Toy.g4

+19-4
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@ module
1010

1111
expression
1212
: Number
13-
| tensorLiteral
13+
| tensorLiteral
14+
{
15+
tensorDataBuffer.clear();
16+
}
1417
| identifierExpr
18+
| expression Mul expression
19+
| expression Add expression
1520
;
1621

1722
identifierExpr
@@ -62,14 +67,16 @@ funDefine
6267
: prototype block
6368
;
6469

65-
prototype
66-
: Def Identifier ParentheseOpen declList ParentheseClose
70+
prototype returns [std::string idName]
71+
: Def Identifier ParentheseOpen declList? ParentheseClose
72+
{
73+
$idName = $Identifier.text;
74+
}
6775
;
6876

6977
declList
7078
: Identifier
7179
| Identifier Comma declList
72-
|
7380
;
7481

7582
block
@@ -144,6 +151,14 @@ Comma
144151
: ','
145152
;
146153

154+
Add
155+
: '+'
156+
;
157+
158+
Mul
159+
: '*'
160+
;
161+
147162
WS
148163
: [ \r\n\t] -> skip
149164
;

examples/ToyDSL/function.toy

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
def fun(a) {
2+
print(a + a);
3+
print(a * a);
4+
print(transpose(a));
5+
return a * transpose(a);
6+
}
7+
8+
def main() {
9+
var a = [1, 2, 3, 4];
10+
var b = fun(a);
11+
print(b);
12+
var c<2,2> = [1, 2, 3, 4];
13+
var d = fun(c);
14+
print(d);
15+
}

examples/ToyDSL/include/MLIRToyVisitor.h

+150-31
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,9 @@ class MLIRToyVisitor : public ToyBaseVisitor {
5454
/// The builder helps create MLIR operations when traversing the AST.
5555
mlir::OpBuilder builder;
5656
/// The Symbol Table
57-
/// [TODO][LOW] make the symbol table support function prototype.
5857
llvm::ScopedHashTable<llvm::StringRef, mlir::Value> symbolTable;
59-
/// Return Status Flag
60-
/// The syntax supports omitting the return expression.
61-
bool returnFlag = false;
58+
llvm::ScopedHashTable<llvm::StringRef, int> funSymbolTable;
59+
llvm::StringMap<mlir::toy::FuncOp> functionMap;
6260
// Register the filename for the string attribute in MLIR location object.
6361
std::string fileName;
6462

@@ -71,6 +69,15 @@ class MLIRToyVisitor : public ToyBaseVisitor {
7169
symbolTable.insert(var, value);
7270
return mlir::success();
7371
}
72+
// Declare a function in the current module.
73+
/// - Record the function's parameter count; fail if the name already exists.
74+
mlir::LogicalResult funcDeclare(llvm::StringRef functionName,
75+
int argsNumber) {
76+
if (funSymbolTable.count(functionName))
77+
return mlir::failure();
78+
funSymbolTable.insert(functionName, argsNumber);
79+
return mlir::success();
80+
}
7481

7582
/// Location
7683
/// Get the MLIR location object with the current line and row of the toy
@@ -89,24 +96,54 @@ class MLIRToyVisitor : public ToyBaseVisitor {
8996

9097
// Get the tensor value from the tensor literal node.
9198
std::any getTensor(ToyParser::TensorLiteralContext *ctx) {
92-
// [TODO][HIGH] find a better way to define the `dims`.
9399
std::vector<int64_t> dims;
94100
// get dimensions.
95101
dims.push_back(ctx->Comma().size() + 1);
96102
if (ctx->tensorLiteral(0)->tensorLiteral(0)) {
97-
dims.push_back(ctx->tensorLiteral(0)->Comma().size() + 1);
103+
ToyParser::TensorLiteralContext *list = ctx->tensorLiteral(0);
104+
while (list) {
105+
dims.push_back(list->Comma().size() + 1);
106+
if (list->tensorLiteral(0) && list->tensorLiteral(0)->Comma().size())
107+
list = list->tensorLiteral(0);
108+
else
109+
break;
110+
}
98111
}
99112
mlir::Type elementType = builder.getF64Type();
100-
auto type = getType(dims);
113+
mlir::Type type = getType(dims);
101114
auto dataType = mlir::RankedTensorType::get(dims, elementType);
102-
auto dataAttribute =
115+
mlir::DenseElementsAttr dataAttribute =
103116
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(ctx->data));
104-
auto loaction =
117+
mlir::Location loaction =
105118
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
106119
mlir::Value value =
107120
builder.create<mlir::toy::ConstantOp>(loaction, type, dataAttribute);
108121
return value;
109122
}
123+
// Module Visitor
124+
// - Visit every function definition to record its number of parameters.
125+
// - Visit the children.
126+
virtual std::any visitModule(ToyParser::ModuleContext *ctx) override {
127+
llvm::ScopedHashTableScope<llvm::StringRef, int> protoTypeSymbolTable(
128+
funSymbolTable);
129+
for (auto &function : ctx->funDefine()) {
130+
ToyParser::PrototypeContext *protoType = function->prototype();
131+
std::string functionName = protoType->Identifier()->toString();
132+
int declNumber = 0;
133+
if (protoType->declList()) {
134+
ToyParser::DeclListContext *list = protoType->declList();
135+
while (list) {
136+
declNumber++;
137+
if (list->declList())
138+
list = list->declList();
139+
else
140+
break;
141+
}
142+
}
143+
funcDeclare(function->prototype()->idName, declNumber);
144+
}
145+
return visitChildren(ctx);
146+
}
110147

111148
/// Function Definition Visitor
112149
/// - Register the function name, argument list, and return value into the
@@ -115,35 +152,70 @@ class MLIRToyVisitor : public ToyBaseVisitor {
115152
/// - Visit function block.
116153
/// - Process the return operation.
117154
virtual std::any visitFunDefine(ToyParser::FunDefineContext *ctx) override {
118-
returnFlag = false;
119-
// [TODO] make the function support argument list and return value.
120155
llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(
121156
symbolTable);
122157
builder.setInsertionPointToEnd(theModule.getBody());
123158
// Visit function prototype.
124-
visit(ctx->prototype());
159+
mlir::toy::FuncOp function =
160+
std::any_cast<mlir::toy::FuncOp>(visit(ctx->prototype()));
161+
mlir::Block &entryBlock = function.front();
162+
163+
// Set the insertion point in the builder to the beginning of the function
164+
// body, it will be used throughout the codegen to create operations in this
165+
// function.
166+
builder.setInsertionPointToStart(&entryBlock);
167+
168+
std::vector<std::string> args;
169+
if (ctx->prototype()->declList()) {
170+
ToyParser::DeclListContext *list = ctx->prototype()->declList();
171+
while (list->Identifier()) {
172+
args.push_back(list->Identifier()->toString());
173+
if (list->declList())
174+
list = list->declList();
175+
else
176+
break;
177+
}
178+
}
179+
// Declare all the function arguments in the symbol table.
180+
llvm::ArrayRef<std::string> protoArgs = args;
181+
for (auto value : llvm::zip(protoArgs, entryBlock.getArguments())) {
182+
declare(std::get<0>(value), std::get<1>(value));
183+
}
184+
125185
// Visit function block.
126186
visit(ctx->block());
127187
// Check the return status.
128188
// If there is no return expression at the end of the function, it will
129189
// generate a return operation automatically.
130-
if (!returnFlag) {
131-
auto location =
190+
mlir::toy::ReturnOp returnOp;
191+
if (!entryBlock.empty())
192+
returnOp = llvm::dyn_cast<mlir::toy::ReturnOp>(entryBlock.back());
193+
if (!returnOp) {
194+
mlir::Location location =
132195
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
133-
builder.create<mlir::toy::ReturnOp>(location,
134-
llvm::ArrayRef<mlir::Value>());
196+
builder.create<mlir::toy::ReturnOp>(location);
197+
} else if (returnOp.hasOperand()) {
198+
// Otherwise, if this return operation has an operand then add a result to
199+
// the function.
200+
std::vector<int64_t> shape;
201+
function.setType(builder.getFunctionType(
202+
function.getFunctionType().getInputs(), getType(shape)));
135203
}
204+
// If this function isn't main, then set the visibility to private.
205+
if (ctx->prototype()->Identifier()->toString() != "main")
206+
function.setPrivate();
207+
functionMap.insert({function.getName(), function});
136208
return 0;
137209
}
138210

139211
/// Prototype Visitor
140212
virtual std::any visitPrototype(ToyParser::PrototypeContext *ctx) override {
141213
mlir::Location location =
142214
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
143-
auto varNumber = 0;
215+
int varNumber = 0;
144216
// Get the number of arguments.
145217
if (ctx->declList()) {
146-
auto list = ctx->declList();
218+
ToyParser::DeclListContext *list = ctx->declList();
147219
while (list->Identifier()) {
148220
varNumber++;
149221
if (list->declList())
@@ -152,26 +224,37 @@ class MLIRToyVisitor : public ToyBaseVisitor {
152224
break;
153225
}
154226
}
155-
156227
llvm::SmallVector<mlir::Type, 4> argTypes(
157228
varNumber, mlir::UnrankedTensorType::get(builder.getF64Type()));
158-
auto funType = builder.getFunctionType(argTypes, llvm::None);
229+
mlir::FunctionType funType = builder.getFunctionType(argTypes, llvm::None);
159230
auto func = builder.create<mlir::toy::FuncOp>(
160231
location, ctx->Identifier()->toString(), funType);
161-
mlir::Block &entryblock = func.front();
162-
builder.setInsertionPointToStart(&entryblock);
163-
return 0;
232+
return func;
164233
}
165234

166235
/// Expression Visitor
167236
/// - If the expression is tensor literal, return the tensor MLIR value.
168237
/// - If the expression is function call or variable, visit the identifier.
238+
/// - If the expression is an add or mul expression, return the value of the
239+
/// created add or mul operation.
169240
virtual std::any visitExpression(ToyParser::ExpressionContext *ctx) override {
170241
mlir::Value value;
171242
if (ctx->tensorLiteral()) {
172243
return getTensor(ctx->tensorLiteral());
173244
} else if (ctx->identifierExpr()) {
174245
return visit(ctx->identifierExpr());
246+
} else if (ctx->Add() || ctx->Mul()) {
247+
// Derive the operation name from the binary operator. At the moment we
248+
// only support '+' and '*'.
249+
mlir::Value lhs = std::any_cast<mlir::Value>(visit(ctx->expression(0)));
250+
mlir::Value rhs = std::any_cast<mlir::Value>(visit(ctx->expression(1)));
251+
mlir::Location loaction =
252+
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
253+
if (ctx->Add())
254+
value = builder.create<mlir::toy::AddOp>(loaction, lhs, rhs);
255+
else
256+
value = builder.create<mlir::toy::MulOp>(loaction, lhs, rhs);
257+
return value;
175258
}
176259
return value;
177260
}
@@ -188,7 +271,7 @@ class MLIRToyVisitor : public ToyBaseVisitor {
188271
std::vector<int64_t> v0;
189272
auto v1 = ctx->type()->Number();
190273
for (auto i : v1) {
191-
auto j = atoi(i->toString().c_str());
274+
int64_t j = atoi(i->toString().c_str());
192275
v0.push_back(j);
193276
}
194277
mlir::Location location =
@@ -208,28 +291,65 @@ class MLIRToyVisitor : public ToyBaseVisitor {
208291
virtual std::any
209292
visitIdentifierExpr(ToyParser::IdentifierExprContext *ctx) override {
210293
mlir::Value value;
294+
int argsNumber = 0;
295+
mlir::Location location =
296+
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
211297
// If the identifier is a function call, visit and register all the
212298
// arguments. [TODO][LOW] add the semantic check (look up the symbol table)
213299
// for the function call.
214300
if (ctx->ParentheseOpen()) {
215-
auto location =
301+
mlir::Location location =
216302
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
217303
llvm::SmallVector<mlir::Value, 4> oprands;
218-
for (auto i : ctx->expression()) {
304+
for (ToyParser::ExpressionContext *i : ctx->expression()) {
219305
mlir::Value arg = std::any_cast<mlir::Value>(visit(i));
220306
oprands.push_back(arg);
307+
argsNumber++;
221308
}
222309
// If function call is a built-in operation, create the corresponding
223310
// operation.
224311
if (ctx->Identifier()->toString() == "print") {
225-
auto arg = oprands[0];
312+
if (argsNumber != 1) {
313+
mlir::emitError(location)
314+
<< "mismatch of function parameters 'print'";
315+
return nullptr;
316+
}
317+
mlir::Value arg = oprands[0];
226318
builder.create<mlir::toy::PrintOp>(location, arg);
227319
return 0;
320+
} else if (ctx->Identifier()->toString() == "transpose") {
321+
if (argsNumber != 1) {
322+
mlir::emitError(location)
323+
<< "mlismatch of function parameters 'transpose'";
324+
return nullptr;
325+
}
326+
mlir::Value arg = oprands[0];
327+
value = builder.create<mlir::toy::TransposeOp>(location, arg);
328+
return value;
329+
}
330+
// Otherwise this is a call to a user-defined function. Calls to
331+
// user-defined functions are mapped to a custom call that takes the
332+
// callee name as an attribute.
333+
auto callee = functionMap.find(ctx->Identifier()->toString());
334+
if (callee == functionMap.end()) {
335+
mlir::emitError(location) << "error: no defined function '"
336+
<< ctx->Identifier()->toString() << "'";
337+
return nullptr;
338+
}
339+
int numberdecl = funSymbolTable.lookup(ctx->Identifier()->toString());
340+
if (numberdecl != argsNumber) {
341+
mlir::emitError(location) << "error: mismatch of function parameters '"
342+
<< ctx->Identifier()->toString() << "'";
343+
return nullptr;
228344
}
229345
// If the function call cannot be mapped to the built-in operation, create
230346
// the GenericCallOp.
347+
mlir::toy::FuncOp calledFunc = callee->second;
231348
value = builder.create<mlir::toy::GenericCallOp>(
232-
location, ctx->Identifier()->toString(), oprands);
349+
location, calledFunc.getFunctionType().getResult(0),
350+
mlir::SymbolRefAttr::get(builder.getContext(),
351+
ctx->Identifier()->toString()),
352+
oprands);
233353
return value;
234354
} else {
235355
// If the identifier is a variable, return the MLIR value from the symbol
@@ -241,12 +361,11 @@ class MLIRToyVisitor : public ToyBaseVisitor {
241361

242362
/// Return Expression Visitor
243363
virtual std::any visitReturnExpr(ToyParser::ReturnExprContext *ctx) override {
244-
returnFlag = true;
245-
auto location =
364+
mlir::Location location =
246365
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
247366
mlir::Value expr = nullptr;
248367
if (ctx->expression()) {
249-
expr = std::any_cast<mlir::Value>(ctx->expression());
368+
expr = std::any_cast<mlir::Value>(visit(ctx->expression()));
250369
}
251370
// Generate return operation based on whether the function has the return
252371
// value.

examples/ToyDSL/makefile

+21
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,24 @@ buddy-toy-constant-translate:
3838

3939
buddy-toy-constant-run:
4040
@${BUDDY_TOY_DSL} ./constant.toy -emit=jit
41+
42+
toyc-function-run:
43+
@${MLIR_TOYC} ./function.toy -emit=jit
44+
45+
buddy-toy-function-ast:
46+
@${BUDDY_TOY_DSL} ./function.toy -emit=ast
47+
48+
buddy-toy-function-mlir:
49+
@${BUDDY_TOY_DSL} ./function.toy -emit=mlir
50+
51+
buddy-toy-function-affine:
52+
@${BUDDY_TOY_DSL} ./function.toy -emit=mlir-affine
53+
54+
buddy-toy-function-llvm:
55+
@${BUDDY_TOY_DSL} ./function.toy -emit=mlir-llvm
56+
57+
buddy-toy-function-translate:
58+
@${BUDDY_TOY_DSL} ./function.toy -emit=llvm
59+
60+
buddy-toy-function-run:
61+
@${BUDDY_TOY_DSL} ./function.toy -emit=jit

0 commit comments

Comments
 (0)