@@ -43,16 +43,20 @@ enum SpecialFunc {
4343 PushBack,
4444 Clear,
4545 Data,
46+ Reserve,
4647};
4748
4849std::map<std::string, SpecialFunc> special_function_map = {
4950 {" printf" , SpecialFunc::Printf},
5051 {" exit" , SpecialFunc::Exit},
5152 {" view" , SpecialFunc::View},
53+ {" submdspan" , SpecialFunc::View},
5254 {" shape" , SpecialFunc::Shape},
55+ {" extent" , SpecialFunc::Shape},
5356 {" empty" , SpecialFunc::Empty},
5457 {" fill" , SpecialFunc::Fill},
5558 {" all" , SpecialFunc::All},
59+ {" full_extent" , SpecialFunc::All},
5660 {" any" , SpecialFunc::Any},
5761 {" not_equal" , SpecialFunc::NotEqual},
5862 {" equal" , SpecialFunc::Equal},
@@ -72,6 +76,7 @@ std::map<std::string, SpecialFunc> special_function_map = {
7276 {" push_back" , SpecialFunc::PushBack},
7377 {" clear" , SpecialFunc::Clear},
7478 {" data" , SpecialFunc::Data},
79+ {" reserve" , SpecialFunc::Reserve},
7580};
7681
7782class OneTimeUseString {
@@ -455,6 +460,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
455460 if ( array_type && is_third_party_cpp_array ) {
456461 *is_third_party_cpp_array = true ;
457462 *array_type = ThirdPartyCPPArrayTypes::MDSpanArray;
463+ xshape_result->from_pointer_n (xtensor_fixed_dims.p , xtensor_fixed_dims.size ());
458464 }
459465 type = ASRUtils::TYPE (ASR::make_Array_t (al, l,
460466 ClangTypeToASRType (qual_type),
@@ -496,6 +502,34 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
496502 xshape_result->push_back (al, dim);
497503 }
498504 return nullptr ;
505+ } else if ( template_name == " extents" ) {
506+ const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments ();
507+ if ( xshape_result == nullptr ) {
508+ throw std::runtime_error (" Result Vec<ASR::dimention_t>* not provided." );
509+ }
510+
511+ const clang::QualType& qual_type = template_arguments.at (0 ).getAsType ();
512+ ASR::ttype_t * index_type = ClangTypeToASRType (qual_type);
513+ if ( !ASR::is_a<ASR::Integer_t>(*index_type) ||
514+ ASRUtils::extract_kind_from_ttype_t (index_type) != 4 ) {
515+ throw std::runtime_error (" Only int32_t should be used for index type in Kokkos::dextents." );
516+ }
517+ ASR::expr_t * zero = ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, l, 0 ,
518+ ASRUtils::TYPE (ASR::make_Integer_t (al, l, 4 ))));
519+ for ( int i = 1 ; i < template_arguments.size (); i++ ) {
520+ clang::Expr* clang_rank = template_arguments.at (i).getAsExpr ();
521+ TraverseStmt (clang_rank);
522+ int rank = 0 ;
523+ if ( !ASRUtils::extract_value (ASRUtils::EXPR (tmp.get ()), rank) ) {
524+ throw std::runtime_error (" Rank provided in the xshape must be a constant." );
525+ }
526+ ASR::dimension_t dim; dim.loc = l;
527+ dim.m_length = ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, l, rank,
528+ ASRUtils::TYPE (ASR::make_Integer_t (al, l, 4 ))));
529+ dim.m_start = zero;
530+ xshape_result->push_back (al, dim);
531+ }
532+ return nullptr ;
499533 } else if ( template_name == " dextents" ) {
500534 const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments ();
501535 if ( xshape_result == nullptr ) {
@@ -553,7 +587,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
553587 } else if ( clang_type->getTypeClass () == clang::Type::TypeClass::Record ) {
554588 const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl ();
555589 std::string name = record_type->getNameAsString ();
556- if ( name == " xtensor_container" || name == " vector" ) {
590+ if ( name == " xtensor_container" || name == " vector" || name == " mdspan " ) {
557591 return nullptr ;
558592 }
559593 ASR::symbol_t * type_t = current_scope->resolve_symbol (name);
@@ -565,6 +599,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
565599 } else {
566600 type = ASRUtils::TYPE (ASR::make_Struct_t (al, l, type_t ));
567601 }
602+ } else if ( clang_type->getTypeClass () == clang::Type::TypeClass::SubstTemplateTypeParm ) {
603+ return nullptr ;
568604 } else {
569605 throw std::runtime_error (" clang::QualType not yet supported " +
570606 std::string (clang_type->getTypeClassName ()));
@@ -995,6 +1031,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
9951031 ASRUtils::make_ArrayBroadcast_t_util (al, Lloc (x), obj, value);
9961032 tmp = ASR::make_Assignment_t (al, Lloc (x), obj, value, nullptr );
9971033 is_stmt_created = true ;
1034+ } else {
1035+ // This means that right hand side of operator=
1036+ // was a CXXConstructor and there was an allocation
1037+ // statement created for it. So skip and return should
1038+ // done in this case.
1039+ tmp = nullptr ;
1040+ is_stmt_created = false ;
9981041 }
9991042 assignment_target = nullptr ;
10001043 } else if ( ASRUtils::is_complex (*ASRUtils::expr_type (obj)) ||
@@ -1393,8 +1436,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
13931436 }
13941437 alloc_arg.m_dims = alloc_dims.p ; alloc_arg.n_dims = alloc_dims.size ();
13951438 alloc_args.push_back (al, alloc_arg);
1396- tmp = ASR::make_Allocate_t (al, Lloc (x), alloc_args.p , alloc_args.size (),
1397- nullptr , nullptr , nullptr );
1439+ current_body->push_back (al, ASRUtils::STMT (ASR::make_Allocate_t (al, Lloc (x),
1440+ alloc_args.p , alloc_args.size (), nullptr , nullptr , nullptr )));
1441+ tmp = nullptr ;
13981442 is_stmt_created = true ;
13991443 } else {
14001444 throw std::runtime_error (" Only {...} is allowed for supplying shape to xt::empty." );
@@ -1437,6 +1481,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
14371481 tmp = ASRUtils::make_Cast_t_value (al, Lloc (x), callee, ASR::cast_kindType::ListToArray,
14381482 ASRUtils::TYPE (ASR::make_Array_t (al, Lloc (x), ASRUtils::get_contained_type (ASRUtils::expr_type (callee)),
14391483 dims.p , dims.size (), ASR::array_physical_typeType::UnboundedPointerToDataArray)));
1484+ } else if (sf == SpecialFunc::Reserve) {
1485+ if ( args.size () > 1 ) {
1486+ throw std::runtime_error (" std::vector::reserve should be called with only one argument." );
1487+ }
1488+
1489+ tmp = ASR::make_ListReserve_t (al, Lloc (x), callee, args[0 ]);
1490+ is_stmt_created = true ;
14401491 } else {
14411492 throw std::runtime_error (" Only printf and exit special functions supported" );
14421493 }
@@ -1579,7 +1630,37 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
15791630 ASR::ttype_t * constructor_type = ClangTypeToASRType (x->getType (), &shape_result,
15801631 &third_party_array_type, &is_third_party_array_type);
15811632 if ( is_third_party_array_type && third_party_array_type == ThirdPartyCPPArrayTypes::MDSpanArray ) {
1582- if ( x->getNumArgs () >= 1 ) {
1633+ if ( x->getNumArgs () == 0 ) {
1634+ tmp = nullptr ;
1635+ is_stmt_created = false ;
1636+ return true ;
1637+ }
1638+ if ( x->getNumArgs () == 1 ) {
1639+ if ( shape_result.size () != ASRUtils::extract_n_dims_from_ttype (constructor_type) ) {
1640+ throw std::runtime_error (" Allocation sizes for all dimensions not provided." );
1641+ }
1642+ if ( !ASRUtils::is_fixed_size_array (shape_result.p , shape_result.size ()) ) {
1643+ throw std::runtime_error (" Allocation size of Kokkos::mdspan array isn't specified correctly." );
1644+ }
1645+
1646+ Vec<ASR::alloc_arg_t > alloc_args; alloc_args.reserve (al, 1 );
1647+ ASR::alloc_arg_t alloc_arg;
1648+ alloc_arg.loc = Lloc (x);
1649+ LCOMPILERS_ASSERT (assignment_target != nullptr );
1650+ alloc_arg.m_a = assignment_target;
1651+ alloc_arg.m_dims = shape_result.p ; alloc_arg.n_dims = shape_result.size ();
1652+ alloc_arg.m_len_expr = nullptr ; alloc_arg.m_type = nullptr ;
1653+ alloc_args.push_back (al, alloc_arg);
1654+ Vec<ASR::expr_t *> dealloc_args; dealloc_args.reserve (al, 1 );
1655+ dealloc_args.push_back (al, assignment_target);
1656+ current_body->push_back (al, ASRUtils::STMT (ASR::make_ExplicitDeallocate_t (
1657+ al, Lloc (x), dealloc_args.p , dealloc_args.size ())));
1658+ current_body->push_back (al, ASRUtils::STMT (ASR::make_Allocate_t (al, Lloc (x),
1659+ alloc_args.p , alloc_args.size (), nullptr , nullptr , nullptr )));
1660+ tmp = nullptr ;
1661+ is_stmt_created = true ;
1662+ return true ;
1663+ } else if ( x->getNumArgs () >= 1 ) {
15831664 clang::Expr* _data = x->getArg (0 );
15841665 TraverseStmt (_data);
15851666 ASR::expr_t * list_to_array = ASRUtils::EXPR (tmp.get ());
@@ -1602,7 +1683,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
16021683 TraverseStmt (argi);
16031684 ASR::dimension_t alloc_dim;
16041685 alloc_dim.loc = Lloc (x);
1605- alloc_dim.m_start = ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, Lloc (x), 1 ,
1686+ alloc_dim.m_start = ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, Lloc (x), 0 ,
16061687 ASRUtils::TYPE (ASR::make_Integer_t (al, Lloc (x), 4 ))));
16071688 alloc_dim.m_length = ASRUtils::EXPR (tmp.get ());
16081689 alloc_dims.push_back (al, alloc_dim);
@@ -1625,9 +1706,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
16251706 return true ;
16261707 } else {
16271708 throw std::runtime_error (" Kokkos::mdspan constructor is called "
1628- " with incorrect number of arguments, "
1629- " expected 3 and found, " + std::to_string (x->getNumArgs ()));
1709+ " with incorrect number of arguments, expected " + std::to_string (
1710+ ASRUtils::extract_n_dims_from_ttype (constructor_type) + 1 ) +
1711+ " and found, " + std::to_string (x->getNumArgs ()));
16301712 }
1713+ } else if ( constructor_type == nullptr ) {
1714+ clang::Expr* x_ = x->getArg (0 );
1715+ return TraverseStmt (x_);
16311716 } else if ( x->getNumArgs () >= 0 ) {
16321717 return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseCXXConstructExpr (x);
16331718 }
@@ -1751,8 +1836,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
17511836 case clang::APValue::Struct: {
17521837 ASR::ttype_t * v_type = ASRUtils::type_get_past_const (ASRUtils::symbol_type (v));
17531838 if ( !ASR::is_a<ASR::Struct_t>(*v_type) ) {
1754- throw std::runtime_error (" Expected ASR::Struct_t type found, " +
1755- ASRUtils::type_to_str (v_type));
1839+ // throw std::runtime_error("Expected ASR::Struct_t type found, " +
1840+ // ASRUtils::type_to_str(v_type));
1841+ // No error, just return, clang marking something as Struct
1842+ // doesn't mean that it necessarily maps to ASR Struct, it can
1843+ // map to ASR Array type as well
1844+ return ;
17561845 }
17571846 ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(v_type);
17581847 ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
@@ -1820,6 +1909,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
18201909 tmp = ASR::make_Assignment_t (al, Lloc (x), var, init_val, nullptr );
18211910 is_stmt_created = true ;
18221911 }
1912+ } else {
1913+ is_stmt_created = false ;
1914+ tmp = nullptr ;
18231915 }
18241916
18251917 if ( x->getEvaluatedValue () ) {
@@ -2376,8 +2468,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
23762468 ASR::symbol_t * sym = resolve_symbol (name);
23772469 if ( name == " operator<<" || name == " cout" || name == " endl" ||
23782470 name == " operator()" || name == " operator+" || name == " operator=" ||
2379- name == " operator*" || name == " view" || name == " empty" ||
2380- name == " all" || name == " any" || name == " not_equal" ||
2471+ name == " operator*" || name == " view" || name == " submdspan " || name == " empty" ||
2472+ name == " all" || name == " full_extent " || name == " any" || name == " not_equal" ||
23812473 name == " exit" || name == " printf" || name == " exp" ||
23822474 name == " sum" || name == " amax" || name == " abs" ||
23832475 name == " operator-" || name == " operator/" || name == " operator>" ||
@@ -2390,7 +2482,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
23902482 throw std::runtime_error (" Special function " + name + " cannot be overshadowed yet." );
23912483 }
23922484 if ( sym == nullptr ) {
2393- cxx_operator_name_obj.set (name);
2485+ if ( name == " full_extent" ) {
2486+ is_all_called = true ;
2487+ } else {
2488+ cxx_operator_name_obj.set (name);
2489+ }
23942490 tmp = nullptr ;
23952491 return true ;
23962492 }
0 commit comments