Skip to content

Commit 0c22405

Browse files
authored
Merge pull request #72 from czgdp1807/structs_01
Added support for ``struct``
2 parents 7deef14 + 638b47a commit 0c22405

File tree

7 files changed

+199
-17
lines changed

7 files changed

+199
-17
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,4 +207,6 @@ RUN(NAME array_21.cpp LABELS gcc llvm NOFAST)
207207
RUN(NAME array_22.cpp LABELS gcc llvm NOFAST)
208208
RUN(NAME array_23.cpp LABELS gcc llvm NOFAST)
209209

210+
RUN(NAME struct_01.cpp LABELS gcc llvm NOFAST)
211+
210212
RUN(NAME function_01.cpp LABELS gcc llvm NOFAST)

integration_tests/struct_01.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include <iostream>
2+
#include <complex>
3+
4+
using namespace std::complex_literals;
5+
6+
// module derived_types_01_m_01
7+
struct X {
8+
float r;
9+
int i;
10+
};
11+
12+
struct Y {
13+
std::complex<double> c;
14+
struct X d;
15+
};
16+
17+
void set(struct X& a) {
18+
a.i = 1;
19+
a.r = 1.5;
20+
}
21+
// module derived_types_01_m_01
22+
23+
// module derived_types_01_m_02
24+
struct Z {
25+
std::complex<double> k;
26+
struct Y l;
27+
};
28+
// module derived_types_01_m_02
29+
30+
int main() {
31+
32+
struct X b;
33+
struct Z c;
34+
35+
b.i = 5;
36+
b.r = 3.5;
37+
std::cout << b.i << " " << b.r << std::endl;
38+
if( b.i != 5 ) {
39+
exit(2);
40+
}
41+
if( b.r != 3.5 ) {
42+
exit(2);
43+
}
44+
45+
set(b);
46+
std::cout << b.i << " " << b.r << std::endl;
47+
if( b.i != 1 ) {
48+
exit(2);
49+
}
50+
if( b.r != 1.5 ) {
51+
exit(2);
52+
}
53+
54+
c.l.d.r = 2.0;
55+
c.l.d.i = 2;
56+
c.k = 2.0 + 2i;
57+
std::cout << c.l.d.r << " " << c.l.d.i << " " << c.k << std::endl;
58+
if( c.l.d.r != 2.0 ) {
59+
exit(2);
60+
}
61+
if( c.l.d.i != 2 ) {
62+
exit(2);
63+
}
64+
if( c.k != 2.0 + 2i ) {
65+
exit(2);
66+
}
67+
68+
set(c.l.d);
69+
std::cout << c.l.d.r << " " << c.l.d.i << " " << c.k << std::endl;
70+
if( c.l.d.r != 1.5 ) {
71+
exit(2);
72+
}
73+
if( c.l.d.i != 1.0 ) {
74+
exit(2);
75+
}
76+
if( c.k != 2.0 + 2i ) {
77+
exit(2);
78+
}
79+
80+
return 0;
81+
}

src/lc/clang_ast_to_asr.h

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,14 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
475475
} else {
476476
throw std::runtime_error(std::string("Unrecognized type ") + template_name);
477477
}
478+
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Record ) {
479+
const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl();
480+
std::string struct_name = record_type->getNameAsString();
481+
ASR::symbol_t* struct_t = current_scope->resolve_symbol(struct_name);
482+
if( !struct_t ) {
483+
throw std::runtime_error(struct_name + " not defined.");
484+
}
485+
type = ASRUtils::TYPE(ASR::make_Struct_t(al, l, struct_t));
478486
} else {
479487
throw std::runtime_error("clang::QualType not yet supported " +
480488
std::string(clang_type->getTypeClassName()));
@@ -486,6 +494,40 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
486494
return type;
487495
}
488496

497+
bool TraverseCXXRecordDecl(clang::CXXRecordDecl* x) {
498+
SymbolTable* parent_scope = current_scope;
499+
current_scope = al.make_new<SymbolTable>(parent_scope);
500+
std::string struct_name = x->getNameAsString();
501+
Vec<char*> field_names; field_names.reserve(al, 1);
502+
for( auto field = x->field_begin(); field != x->field_end(); field++ ) {
503+
clang::FieldDecl* field_decl = *field;
504+
TraverseFieldDecl(field_decl);
505+
field_names.push_back(al, s2c(al, field_decl->getNameAsString()));
506+
}
507+
ASR::symbol_t* struct_t = ASR::down_cast<ASR::symbol_t>(ASR::make_StructType_t(al, Lloc(x), current_scope,
508+
s2c(al, struct_name), nullptr, 0, field_names.p, field_names.size(), ASR::abiType::Source,
509+
ASR::accessType::Public, false, x->isAbstract(), nullptr, 0, nullptr, nullptr));
510+
parent_scope->add_symbol(struct_name, struct_t);
511+
current_scope = parent_scope;
512+
return true;
513+
}
514+
515+
bool TraverseFieldDecl(clang::FieldDecl* x) {
516+
std::string name = x->getName().str();
517+
if( current_scope->get_symbol(name) ) {
518+
throw std::runtime_error(name + std::string(" is already defined."));
519+
}
520+
521+
ASR::ttype_t *asr_type = ClangTypeToASRType(x->getType());
522+
ASR::symbol_t *v = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al, Lloc(x),
523+
current_scope, s2c(al, name), nullptr, 0, ASR::intentType::Local, nullptr, nullptr,
524+
ASR::storage_typeType::Default, asr_type, nullptr, ASR::abiType::Source,
525+
ASR::accessType::Public, ASR::presenceType::Required, false));
526+
current_scope->add_symbol(name, v);
527+
is_stmt_created = false;
528+
return true;
529+
}
530+
489531
bool TraverseParmVarDecl(clang::ParmVarDecl* x) {
490532
std::string name = x->getName().str();
491533
if( name == "" ) {
@@ -549,8 +591,33 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
549591
}
550592

551593
bool TraverseMemberExpr(clang::MemberExpr* x) {
552-
member_name_obj.set(x->getMemberDecl()->getNameAsString());
553-
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
594+
TraverseStmt(x->getBase());
595+
ASR::expr_t* base = ASRUtils::EXPR(tmp.get());
596+
std::string member_name = x->getMemberDecl()->getNameAsString();
597+
if( ASR::is_a<ASR::Struct_t>(*ASRUtils::expr_type(base)) ) {
598+
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(ASRUtils::expr_type(base));
599+
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
600+
ASRUtils::symbol_get_past_external(struct_t->m_derived_type));
601+
ASR::symbol_t* member = struct_type_t->m_symtab->resolve_symbol(member_name);
602+
if( !member ) {
603+
throw std::runtime_error(member_name + " not found in the scope of " + struct_type_t->m_name);
604+
}
605+
std::string mangled_name = current_scope->get_unique_name(
606+
member_name + "@" + struct_type_t->m_name);
607+
member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
608+
al, Lloc(x), current_scope, s2c(al, mangled_name), member,
609+
struct_type_t->m_name, nullptr, 0, s2c(al, member_name),
610+
ASR::accessType::Public));
611+
current_scope->add_symbol(mangled_name, member);
612+
tmp = ASR::make_StructInstanceMember_t(al, Lloc(x), base, member,
613+
ASRUtils::symbol_type(member), nullptr);
614+
} else if( special_function_map.find(member_name) != special_function_map.end() ) {
615+
member_name_obj.set(member_name);
616+
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
617+
} else {
618+
throw std::runtime_error("clang::MemberExpr only supported for struct and special functions.");
619+
}
620+
return true;
554621
}
555622

556623
bool TraverseCXXMemberCallExpr(clang::CXXMemberCallExpr* x) {
@@ -679,6 +746,15 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
679746
is_stmt_created = true;
680747
}
681748
assignment_target = nullptr;
749+
} else if( ASRUtils::is_complex(*ASRUtils::expr_type(obj)) ) {
750+
TraverseStmt(args[1]);
751+
if( !is_stmt_created ) {
752+
ASR::expr_t* value = ASRUtils::EXPR(tmp.get());
753+
cast_helper(obj, value, true);
754+
tmp = ASR::make_Assignment_t(al, Lloc(x), obj, value, nullptr);
755+
is_stmt_created = true;
756+
}
757+
assignment_target = nullptr;
682758
} else {
683759
throw std::runtime_error("operator= is supported only for arrays.");
684760
}
@@ -1132,10 +1208,10 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
11321208
return_var = ASRUtils::EXPR(ASR::make_Var_t(al, return_sym->base.loc, return_sym));
11331209
}
11341210

1135-
tmp = ASRUtils::make_Function_t_util(al, Lloc(x), current_scope, s2c(al, name), nullptr, 0,
1136-
args.p, args.size(), nullptr, 0, return_var, ASR::abiType::Source, ASR::accessType::Public,
1137-
ASR::deftypeType::Implementation, nullptr, false, false, false, false, false, nullptr, 0,
1138-
false, false, false);
1211+
tmp = ASRUtils::make_Function_t_util(al, Lloc(x), current_scope, s2c(al, name),
1212+
nullptr, 0, args.p, args.size(), nullptr, 0, return_var, ASR::abiType::Source,
1213+
ASR::accessType::Public, ASR::deftypeType::Implementation, nullptr, false, false,
1214+
false, false, false, nullptr, 0, false, false, false);
11391215
current_function_symbol = ASR::down_cast<ASR::symbol_t>(tmp.get());
11401216
current_function = ASR::down_cast<ASR::Function_t>(current_function_symbol);
11411217
current_scope = current_function->m_symtab;
@@ -2023,11 +2099,15 @@ class FindNamedClassConsumer: public clang::ASTConsumer {
20232099

20242100
public:
20252101

2102+
Allocator& al;
2103+
20262104
explicit FindNamedClassConsumer(clang::ASTContext *Context,
2027-
Allocator& al_, ASR::asr_t*& tu_): Visitor(Context, al_, tu_) {}
2105+
Allocator& al_, ASR::asr_t*& tu_): al(al_), Visitor(Context, al_, tu_) {}
20282106

20292107
virtual void HandleTranslationUnit(clang::ASTContext &Context) {
20302108
Visitor.TraverseDecl(Context.getTranslationUnitDecl());
2109+
PassUtils::UpdateDependenciesVisitor dependency_visitor(al);
2110+
dependency_visitor.visit_TranslationUnit(*((ASR::TranslationUnit_t*) Visitor.tu));
20312111
}
20322112

20332113
private:

src/libasr/asr_utils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_i
190190
}
191191
return sym;
192192
}
193-
193+
194194
void handle_Var(ASR::expr_t* arg_expr, ASR::expr_t** expr_to_replace) {
195195
if (ASR::is_a<ASR::Var_t>(*arg_expr)) {
196196
ASR::Var_t* arg_var = ASR::down_cast<ASR::Var_t>(arg_expr);
@@ -1461,8 +1461,10 @@ void make_ArrayBroadcast_t_util(Allocator& al, const Location& loc,
14611461
size_t expr1_ndims) {
14621462
ASR::ttype_t* expr1_type = ASRUtils::expr_type(expr1);
14631463
Vec<ASR::expr_t*> shape_args;
1464-
shape_args.reserve(al, 1);
1464+
shape_args.reserve(al, 2);
14651465
shape_args.push_back(al, expr1);
1466+
shape_args.push_back(al,
1467+
make_ConstantWithKind(make_IntegerConstant_t, make_Integer_t, 4, 4, loc));
14661468

14671469
Vec<ASR::dimension_t> dims;
14681470
dims.reserve(al, 1);

src/libasr/pass/pass_utils.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ namespace LCompilers {
335335
}
336336

337337
void visit_FunctionCall(const ASR::FunctionCall_t& x) {
338-
if (fill_function_dependencies) {
338+
if (fill_function_dependencies) {
339339
ASR::symbol_t* asr_owner_sym = nullptr;
340340
if (current_scope->asr_owner && ASR::is_a<ASR::symbol_t>(*current_scope->asr_owner)) {
341341
asr_owner_sym = ASR::down_cast<ASR::symbol_t>(current_scope->asr_owner);
@@ -352,7 +352,7 @@ namespace LCompilers {
352352
}
353353
} else {
354354
function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name));
355-
}
355+
}
356356
}
357357
}
358358
if( ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) &&
@@ -373,7 +373,7 @@ namespace LCompilers {
373373
}
374374

375375
SymbolTable* temp_scope = current_scope;
376-
376+
377377
if (asr_owner_sym && temp_scope->get_counter() != ASRUtils::symbol_parent_symtab(x.m_name)->get_counter() &&
378378
!ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) && !ASR::is_a<ASR::Variable_t>(*x.m_name)) {
379379
if (ASR::is_a<ASR::AssociateBlock_t>(*asr_owner_sym) || ASR::is_a<ASR::Block_t>(*asr_owner_sym)) {
@@ -383,7 +383,7 @@ namespace LCompilers {
383383
}
384384
} else {
385385
function_dependencies.push_back(al, ASRUtils::symbol_name(x.m_name));
386-
}
386+
}
387387
}
388388
}
389389

@@ -430,6 +430,21 @@ namespace LCompilers {
430430
}
431431
current_scope = parent_symtab;
432432
}
433+
434+
void visit_StructType(const ASR::StructType_t& x) {
435+
ASR::StructType_t& xx = const_cast<ASR::StructType_t&>(x);
436+
SetChar vec; vec.reserve(al, 1);
437+
for( auto itr: x.m_symtab->get_scope() ) {
438+
ASR::ttype_t* type = ASRUtils::extract_type(
439+
ASRUtils::symbol_type(itr.second));
440+
if( ASR::is_a<ASR::Struct_t>(*type) ) {
441+
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(type);
442+
vec.push_back(al, ASRUtils::symbol_name(struct_t->m_derived_type));
443+
}
444+
}
445+
xx.m_dependencies = vec.p;
446+
xx.n_dependencies = vec.size();
447+
}
433448
};
434449

435450
namespace ReplacerUtils {

tests/reference/asr-array_02-ec70729.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-array_02-ec70729.stdout",
9-
"stdout_hash": "be1e2e3174f008918af1f48265f78ff96e7fe5bd39cb3331be787e64",
9+
"stdout_hash": "f2ab5c25354dd2f1f779bae37cf83ab5101f4a335f967764b63153a9",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/asr-array_02-ec70729.stdout

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
[]
5656
.false.
5757
)
58-
[]
58+
[matmul_test]
5959
[]
6060
[(=
6161
(Var 3 sum)
@@ -155,7 +155,7 @@
155155
(Variable
156156
2
157157
arr1d
158-
[]
158+
[n]
159159
Local
160160
()
161161
()
@@ -176,7 +176,9 @@
176176
(Variable
177177
2
178178
arr2d
179-
[]
179+
[nums
180+
m
181+
n]
180182
Local
181183
()
182184
()

0 commit comments

Comments
 (0)