Skip to content

Commit

Permalink
Merge pull request #72 from czgdp1807/structs_01
Browse files Browse the repository at this point in the history
Added support for ``struct``
  • Loading branch information
czgdp1807 authored Jan 24, 2024
2 parents 7deef14 + 638b47a commit 0c22405
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 17 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,6 @@ RUN(NAME array_21.cpp LABELS gcc llvm NOFAST)
RUN(NAME array_22.cpp LABELS gcc llvm NOFAST)
RUN(NAME array_23.cpp LABELS gcc llvm NOFAST)

RUN(NAME struct_01.cpp LABELS gcc llvm NOFAST)

RUN(NAME function_01.cpp LABELS gcc llvm NOFAST)
81 changes: 81 additions & 0 deletions integration_tests/struct_01.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#include <iostream>
#include <complex>

using namespace std::complex_literals;

// module derived_types_01_m_01
struct X {
float r;
int i;
};

struct Y {
std::complex<double> c;
struct X d;
};

void set(struct X& a) {
a.i = 1;
a.r = 1.5;
}
// module derived_types_01_m_01

// module derived_types_01_m_02
struct Z {
std::complex<double> k;
struct Y l;
};
// module derived_types_01_m_02

int main() {

struct X b;
struct Z c;

b.i = 5;
b.r = 3.5;
std::cout << b.i << " " << b.r << std::endl;
if( b.i != 5 ) {
exit(2);
}
if( b.r != 3.5 ) {
exit(2);
}

set(b);
std::cout << b.i << " " << b.r << std::endl;
if( b.i != 1 ) {
exit(2);
}
if( b.r != 1.5 ) {
exit(2);
}

c.l.d.r = 2.0;
c.l.d.i = 2;
c.k = 2.0 + 2i;
std::cout << c.l.d.r << " " << c.l.d.i << " " << c.k << std::endl;
if( c.l.d.r != 2.0 ) {
exit(2);
}
if( c.l.d.i != 2 ) {
exit(2);
}
if( c.k != 2.0 + 2i ) {
exit(2);
}

set(c.l.d);
std::cout << c.l.d.r << " " << c.l.d.i << " " << c.k << std::endl;
if( c.l.d.r != 1.5 ) {
exit(2);
}
if( c.l.d.i != 1.0 ) {
exit(2);
}
if( c.k != 2.0 + 2i ) {
exit(2);
}

return 0;
}
94 changes: 87 additions & 7 deletions src/lc/clang_ast_to_asr.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,14 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
} else {
throw std::runtime_error(std::string("Unrecognized type ") + template_name);
}
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Record ) {
const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl();
std::string struct_name = record_type->getNameAsString();
ASR::symbol_t* struct_t = current_scope->resolve_symbol(struct_name);
if( !struct_t ) {
throw std::runtime_error(struct_name + " not defined.");
}
type = ASRUtils::TYPE(ASR::make_Struct_t(al, l, struct_t));
} else {
throw std::runtime_error("clang::QualType not yet supported " +
std::string(clang_type->getTypeClassName()));
Expand All @@ -486,6 +494,40 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return type;
}

bool TraverseCXXRecordDecl(clang::CXXRecordDecl* x) {
SymbolTable* parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);
std::string struct_name = x->getNameAsString();
Vec<char*> field_names; field_names.reserve(al, 1);
for( auto field = x->field_begin(); field != x->field_end(); field++ ) {
clang::FieldDecl* field_decl = *field;
TraverseFieldDecl(field_decl);
field_names.push_back(al, s2c(al, field_decl->getNameAsString()));
}
ASR::symbol_t* struct_t = ASR::down_cast<ASR::symbol_t>(ASR::make_StructType_t(al, Lloc(x), current_scope,
s2c(al, struct_name), nullptr, 0, field_names.p, field_names.size(), ASR::abiType::Source,
ASR::accessType::Public, false, x->isAbstract(), nullptr, 0, nullptr, nullptr));
parent_scope->add_symbol(struct_name, struct_t);
current_scope = parent_scope;
return true;
}

bool TraverseFieldDecl(clang::FieldDecl* x) {
std::string name = x->getName().str();
if( current_scope->get_symbol(name) ) {
throw std::runtime_error(name + std::string(" is already defined."));
}

ASR::ttype_t *asr_type = ClangTypeToASRType(x->getType());
ASR::symbol_t *v = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al, Lloc(x),
current_scope, s2c(al, name), nullptr, 0, ASR::intentType::Local, nullptr, nullptr,
ASR::storage_typeType::Default, asr_type, nullptr, ASR::abiType::Source,
ASR::accessType::Public, ASR::presenceType::Required, false));
current_scope->add_symbol(name, v);
is_stmt_created = false;
return true;
}

bool TraverseParmVarDecl(clang::ParmVarDecl* x) {
std::string name = x->getName().str();
if( name == "" ) {
Expand Down Expand Up @@ -549,8 +591,33 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseMemberExpr(clang::MemberExpr* x) {
member_name_obj.set(x->getMemberDecl()->getNameAsString());
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
TraverseStmt(x->getBase());
ASR::expr_t* base = ASRUtils::EXPR(tmp.get());
std::string member_name = x->getMemberDecl()->getNameAsString();
if( ASR::is_a<ASR::Struct_t>(*ASRUtils::expr_type(base)) ) {
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(ASRUtils::expr_type(base));
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
ASRUtils::symbol_get_past_external(struct_t->m_derived_type));
ASR::symbol_t* member = struct_type_t->m_symtab->resolve_symbol(member_name);
if( !member ) {
throw std::runtime_error(member_name + " not found in the scope of " + struct_type_t->m_name);
}
std::string mangled_name = current_scope->get_unique_name(
member_name + "@" + struct_type_t->m_name);
member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
al, Lloc(x), current_scope, s2c(al, mangled_name), member,
struct_type_t->m_name, nullptr, 0, s2c(al, member_name),
ASR::accessType::Public));
current_scope->add_symbol(mangled_name, member);
tmp = ASR::make_StructInstanceMember_t(al, Lloc(x), base, member,
ASRUtils::symbol_type(member), nullptr);
} else if( special_function_map.find(member_name) != special_function_map.end() ) {
member_name_obj.set(member_name);
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
} else {
throw std::runtime_error("clang::MemberExpr only supported for struct and special functions.");
}
return true;
}

bool TraverseCXXMemberCallExpr(clang::CXXMemberCallExpr* x) {
Expand Down Expand Up @@ -679,6 +746,15 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
is_stmt_created = true;
}
assignment_target = nullptr;
} else if( ASRUtils::is_complex(*ASRUtils::expr_type(obj)) ) {
TraverseStmt(args[1]);
if( !is_stmt_created ) {
ASR::expr_t* value = ASRUtils::EXPR(tmp.get());
cast_helper(obj, value, true);
tmp = ASR::make_Assignment_t(al, Lloc(x), obj, value, nullptr);
is_stmt_created = true;
}
assignment_target = nullptr;
} else {
throw std::runtime_error("operator= is supported only for arrays.");
}
Expand Down Expand Up @@ -1132,10 +1208,10 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return_var = ASRUtils::EXPR(ASR::make_Var_t(al, return_sym->base.loc, return_sym));
}

tmp = ASRUtils::make_Function_t_util(al, Lloc(x), current_scope, s2c(al, name), nullptr, 0,
args.p, args.size(), nullptr, 0, return_var, ASR::abiType::Source, ASR::accessType::Public,
ASR::deftypeType::Implementation, nullptr, false, false, false, false, false, nullptr, 0,
false, false, false);
tmp = ASRUtils::make_Function_t_util(al, Lloc(x), current_scope, s2c(al, name),
nullptr, 0, args.p, args.size(), nullptr, 0, return_var, ASR::abiType::Source,
ASR::accessType::Public, ASR::deftypeType::Implementation, nullptr, false, false,
false, false, false, nullptr, 0, false, false, false);
current_function_symbol = ASR::down_cast<ASR::symbol_t>(tmp.get());
current_function = ASR::down_cast<ASR::Function_t>(current_function_symbol);
current_scope = current_function->m_symtab;
Expand Down Expand Up @@ -2023,11 +2099,15 @@ class FindNamedClassConsumer: public clang::ASTConsumer {

public:

Allocator& al;

explicit FindNamedClassConsumer(clang::ASTContext *Context,
Allocator& al_, ASR::asr_t*& tu_): Visitor(Context, al_, tu_) {}
Allocator& al_, ASR::asr_t*& tu_): al(al_), Visitor(Context, al_, tu_) {}

virtual void HandleTranslationUnit(clang::ASTContext &Context) {
Visitor.TraverseDecl(Context.getTranslationUnitDecl());
PassUtils::UpdateDependenciesVisitor dependency_visitor(al);
dependency_visitor.visit_TranslationUnit(*((ASR::TranslationUnit_t*) Visitor.tu));
}

private:
Expand Down
6 changes: 4 additions & 2 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_i
}
return sym;
}

void handle_Var(ASR::expr_t* arg_expr, ASR::expr_t** expr_to_replace) {
if (ASR::is_a<ASR::Var_t>(*arg_expr)) {
ASR::Var_t* arg_var = ASR::down_cast<ASR::Var_t>(arg_expr);
Expand Down Expand Up @@ -1461,8 +1461,10 @@ void make_ArrayBroadcast_t_util(Allocator& al, const Location& loc,
size_t expr1_ndims) {
ASR::ttype_t* expr1_type = ASRUtils::expr_type(expr1);
Vec<ASR::expr_t*> shape_args;
shape_args.reserve(al, 1);
shape_args.reserve(al, 2);
shape_args.push_back(al, expr1);
shape_args.push_back(al,
make_ConstantWithKind(make_IntegerConstant_t, make_Integer_t, 4, 4, loc));

Vec<ASR::dimension_t> dims;
dims.reserve(al, 1);
Expand Down
23 changes: 19 additions & 4 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ namespace LCompilers {
}

void visit_FunctionCall(const ASR::FunctionCall_t& x) {
if (fill_function_dependencies) {
if (fill_function_dependencies) {
ASR::symbol_t* asr_owner_sym = nullptr;
if (current_scope->asr_owner && ASR::is_a<ASR::symbol_t>(*current_scope->asr_owner)) {
asr_owner_sym = ASR::down_cast<ASR::symbol_t>(current_scope->asr_owner);
Expand All @@ -352,7 +352,7 @@ namespace LCompilers {
}
} else {
function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name));
}
}
}
}
if( ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) &&
Expand All @@ -373,7 +373,7 @@ namespace LCompilers {
}

SymbolTable* temp_scope = current_scope;

if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() &&
!ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) && !ASR::is_a<ASR::Variable_t>(*x.m_name)) {
if (ASR::is_a<ASR::AssociateBlock_t>(*asr_owner_sym) || ASR::is_a<ASR::Block_t>(*asr_owner_sym)) {
Expand All @@ -383,7 +383,7 @@ namespace LCompilers {
}
} else {
function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name));
}
}
}
}

Expand Down Expand Up @@ -430,6 +430,21 @@ namespace LCompilers {
}
current_scope = parent_symtab;
}

void visit_StructType(const ASR::StructType_t& x) {
ASR::StructType_t& xx = const_cast<ASR::StructType_t&>(x);
SetChar vec; vec.reserve(al, 1);
for( auto itr: x.m_symtab->get_scope() ) {
ASR::ttype_t* type = ASRUtils::extract_type(
ASRUtils::symbol_type(itr.second));
if( ASR::is_a<ASR::Struct_t>(*type) ) {
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(type);
vec.push_back(al, ASRUtils::symbol_name(struct_t->m_derived_type));
}
}
xx.m_dependencies = vec.p;
xx.n_dependencies = vec.size();
}
};

namespace ReplacerUtils {
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-array_02-ec70729.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-array_02-ec70729.stdout",
"stdout_hash": "be1e2e3174f008918af1f48265f78ff96e7fe5bd39cb3331be787e64",
"stdout_hash": "f2ab5c25354dd2f1f779bae37cf83ab5101f4a335f967764b63153a9",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
8 changes: 5 additions & 3 deletions tests/reference/asr-array_02-ec70729.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
[]
.false.
)
[]
[matmul_test]
[]
[(=
(Var 3 sum)
Expand Down Expand Up @@ -155,7 +155,7 @@
(Variable
2
arr1d
[]
[n]
Local
()
()
Expand All @@ -176,7 +176,9 @@
(Variable
2
arr2d
[]
[nums
m
n]
Local
()
()
Expand Down

0 comments on commit 0c22405

Please sign in to comment.