Skip to content

Commit

Permalink
Merge pull request #71 from czgdp1807/functions_01
Browse files Browse the repository at this point in the history
Added support for pre-declaring functions before defining
  • Loading branch information
czgdp1807 authored Jan 23, 2024
2 parents c6bff74 + 2d996f1 commit 55bb448
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 29 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,5 @@ RUN(NAME array_20.cpp LABELS gcc llvm NOFAST)
RUN(NAME array_21.cpp LABELS gcc llvm NOFAST)
RUN(NAME array_22.cpp LABELS gcc llvm NOFAST)
RUN(NAME array_23.cpp LABELS gcc llvm NOFAST)

RUN(NAME function_01.cpp LABELS gcc llvm NOFAST)
3 changes: 3 additions & 0 deletions integration_tests/array_18.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include <xtensor/xfixed.hpp>
#include "xtensor/xio.hpp"

xt::xtensor<double, 1> solution();
void compare_solutions(const xt::xtensor<double, 1>& y);

xt::xtensor<double, 1> solution() {
xt::xtensor<double, 1> x = xt::empty<double>({2});
x = {1.0, 2.0};
Expand Down
39 changes: 39 additions & 0 deletions integration_tests/function_01.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <iostream>

float f_real(const float);

float f(const float a) {
float b, x;
x = 2;
b = a + f_real(0.0);
return b;
}

float f_real(const float a) {
float b;
if( a == 0.0 ) {
b = 2.0;
} else {
b = a + f(1.0);
}
return b;
}

int main() {

float x = 5, y;
float p = 5, q;
float a, b, c;
y = f(x);
std::cout << y << std::endl;
if( y != 7.0 ) {
exit(2);
}

q = f_real(p);
std::cout << q << std::endl;
if( q != 8.0 ) {
exit(2);
}

}
79 changes: 50 additions & 29 deletions src/lc/clang_ast_to_asr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1104,43 +1104,64 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

bool TraverseFunctionDecl(clang::FunctionDecl *x) {
SymbolTable* parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);

std::string name = x->getName().str();
Vec<ASR::expr_t*> args;
args.reserve(al, 1);
for (auto &p : x->parameters()) {
TraverseDecl(p);
args.push_back(al, ASRUtils::EXPR(tmp.get()));
}
ASR::symbol_t* current_function_symbol = parent_scope->resolve_symbol(name);
ASR::Function_t* current_function = nullptr;

if( current_function_symbol == nullptr ) {
current_scope = al.make_new<SymbolTable>(parent_scope);
Vec<ASR::expr_t*> args;
args.reserve(al, 1);
for (auto &p : x->parameters()) {
TraverseDecl(p);
args.push_back(al, ASRUtils::EXPR(tmp.get()));
}

ASR::ttype_t* return_type = ClangTypeToASRType(x->getReturnType());
ASR::symbol_t* return_sym = nullptr;
ASR::expr_t* return_var = nullptr;
if (return_type != nullptr) {
return_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al, Lloc(x),
current_scope, s2c(al, "__return_var"), nullptr, 0, ASR::intentType::ReturnVar, nullptr, nullptr,
ASR::storage_typeType::Default, return_type, nullptr, ASR::abiType::Source, ASR::accessType::Public,
ASR::presenceType::Required, false));
current_scope->add_symbol("__return_var", return_sym);
}
ASR::ttype_t* return_type = ClangTypeToASRType(x->getReturnType());
ASR::symbol_t* return_sym = nullptr;
ASR::expr_t* return_var = nullptr;
if (return_type != nullptr) {
return_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al, Lloc(x),
current_scope, s2c(al, "__return_var"), nullptr, 0, ASR::intentType::ReturnVar, nullptr, nullptr,
ASR::storage_typeType::Default, return_type, nullptr, ASR::abiType::Source, ASR::accessType::Public,
ASR::presenceType::Required, false));
current_scope->add_symbol("__return_var", return_sym);
}

if (return_type != nullptr) {
return_var = ASRUtils::EXPR(ASR::make_Var_t(al, return_sym->base.loc, return_sym));
}
if (return_type != nullptr) {
return_var = ASRUtils::EXPR(ASR::make_Var_t(al, return_sym->base.loc, return_sym));
}

tmp = ASRUtils::make_Function_t_util(al, Lloc(x), current_scope, s2c(al, name), nullptr, 0,
args.p, args.size(), nullptr, 0, return_var, ASR::abiType::Source, ASR::accessType::Public,
ASR::deftypeType::Implementation, nullptr, false, false, false, false, false, nullptr, 0,
false, false, false);
ASR::symbol_t* current_function_symbol = ASR::down_cast<ASR::symbol_t>(tmp.get());
ASR::Function_t* current_function = ASR::down_cast<ASR::Function_t>(current_function_symbol);
parent_scope->add_symbol(name, current_function_symbol);
tmp = ASRUtils::make_Function_t_util(al, Lloc(x), current_scope, s2c(al, name), nullptr, 0,
args.p, args.size(), nullptr, 0, return_var, ASR::abiType::Source, ASR::accessType::Public,
ASR::deftypeType::Implementation, nullptr, false, false, false, false, false, nullptr, 0,
false, false, false);
current_function_symbol = ASR::down_cast<ASR::symbol_t>(tmp.get());
current_function = ASR::down_cast<ASR::Function_t>(current_function_symbol);
current_scope = current_function->m_symtab;
parent_scope->add_symbol(name, current_function_symbol);
} else {
current_function = ASR::down_cast<ASR::Function_t>(current_function_symbol);
current_scope = current_function->m_symtab;
for( size_t i = 0; i < current_function->n_args; i++ ) {
ASR::Var_t* argi = ASR::down_cast<ASR::Var_t>(current_function->m_args[i]);
if( current_scope->get_symbol(ASRUtils::symbol_name(argi->m_v)) == argi->m_v ) {
current_scope->erase_symbol(ASRUtils::symbol_name(argi->m_v));
}
}

int i = 0;
for (auto &p : x->parameters()) {
TraverseDecl(p);
current_function->m_args[i] = ASRUtils::EXPR(tmp.get());
i++;
}
}

Vec<ASR::stmt_t*>* current_body_copy = current_body;
Vec<ASR::stmt_t*> body; body.reserve(al, 1);
current_body = &body;
if( x->hasBody() ) {
if( x->doesThisDeclarationHaveABody() ) {
TraverseStmt(x->getBody());
}
current_body = current_body_copy;
Expand Down

0 comments on commit 55bb448

Please sign in to comment.