Skip to content

Commit 6dcebb2

Browse files
authored
Merge pull request #115 from czgdp1807/mdspan_03
Support `Kokkos::submdspan` for slicing `Kokkos::mdspan` arrays
2 parents 292d177 + 341debb commit 6dcebb2

File tree

8 files changed

+356
-18
lines changed

8 files changed

+356
-18
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ RUN(NAME array_02.cpp LABELS gcc llvm)
188188
RUN(NAME array_03.cpp LABELS gcc llvm)
189189
RUN(NAME array_03_mdspan.cpp LABELS gcc llvm)
190190
RUN(NAME array_04.cpp LABELS gcc llvm)
191+
RUN(NAME array_04_mdspan.cpp LABELS gcc llvm)
191192
RUN(NAME array_05.cpp LABELS gcc llvm)
192193
RUN(NAME array_06.cpp LABELS gcc llvm)
193194
RUN(NAME array_07.cpp LABELS gcc llvm)

integration_tests/array_04.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,17 @@ int main() {
1313

1414
xt::xtensor<double, 1> arr2 {5.0, 6.0, 7.0};
1515
xt::xtensor<double, 1> res;
16+
xt::xtensor_fixed<double, xt::xshape<3>> res_ = {7.0, 11.0, 14.0};
1617

1718
res = xt::empty<double>({3});
1819
res = xt::view(arr1, 1) + arr2;
19-
std::cout<< arr1 << arr2 <<std::endl;
20-
std::cout << res << std::endl;
20+
std::cout << arr1 << "\n";
21+
std::cout << arr2 << "\n";
22+
std::cout << res << "\n";
23+
24+
if( xt::any(xt::abs(res - res_) > 1e-8) ) {
25+
exit(2);
26+
}
2127

2228
return 0;
2329
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include <iostream>
2+
#include <vector>
3+
#include <mdspan/mdspan.hpp>
4+
5+
void print2(Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>>& arr) {
6+
for( int32_t i = 0; i < arr.extent(0); i++ ) {
7+
for( int32_t j = 0; j < arr.extent(1) ; j++ ) {
8+
std::cout << arr(i, j) << " ";
9+
}
10+
}
11+
}
12+
13+
void print1(Kokkos::mdspan<double, Kokkos::dextents<int32_t, 1>>& arr) {
14+
for( int32_t i = 0; i < arr.size(); i++ ) {
15+
std::cout << arr[i] << " ";
16+
}
17+
std::cout << "\n";
18+
}
19+
20+
int main() {
21+
22+
std::vector<double> arr1_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
23+
Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1(arr1_data.data(), 3, 3);
24+
25+
std::vector<double> arr2_data = {5.0, 6.0, 7.0};
26+
Kokkos::mdspan<double, Kokkos::dextents<int32_t, 1>> arr2(arr2_data.data(), 3);
27+
Kokkos::mdspan<double, Kokkos::dextents<int32_t, 1>> res;
28+
std::vector<double> res_data_correct = {7.0, 11.0, 14.0};
29+
30+
std::vector<double> res_data; res_data.reserve(3);
31+
res = Kokkos::mdspan<double, Kokkos::extents<int32_t, 3>>(res_data.data());
32+
res = Kokkos::submdspan(arr1, 1, Kokkos::full_extent);
33+
for( int32_t i = 0; i < res.size(); i++ ) {
34+
res[i] = res[i] + arr2[i];
35+
}
36+
37+
print2(arr1);
38+
print1(arr2);
39+
print1(res);
40+
41+
42+
for( int32_t i = 0; i < res.size(); i++ ) {
43+
if( res[i] != res_data_correct[i] ) {
44+
exit(2);
45+
}
46+
}
47+
48+
return 0;
49+
}

src/lc/clang_ast_to_asr.cpp

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,20 @@ enum SpecialFunc {
4343
PushBack,
4444
Clear,
4545
Data,
46+
Reserve,
4647
};
4748

4849
std::map<std::string, SpecialFunc> special_function_map = {
4950
{"printf", SpecialFunc::Printf},
5051
{"exit", SpecialFunc::Exit},
5152
{"view", SpecialFunc::View},
53+
{"submdspan", SpecialFunc::View},
5254
{"shape", SpecialFunc::Shape},
55+
{"extent", SpecialFunc::Shape},
5356
{"empty", SpecialFunc::Empty},
5457
{"fill", SpecialFunc::Fill},
5558
{"all", SpecialFunc::All},
59+
{"full_extent", SpecialFunc::All},
5660
{"any", SpecialFunc::Any},
5761
{"not_equal", SpecialFunc::NotEqual},
5862
{"equal", SpecialFunc::Equal},
@@ -72,6 +76,7 @@ std::map<std::string, SpecialFunc> special_function_map = {
7276
{"push_back", SpecialFunc::PushBack},
7377
{"clear", SpecialFunc::Clear},
7478
{"data", SpecialFunc::Data},
79+
{"reserve", SpecialFunc::Reserve},
7580
};
7681

7782
class OneTimeUseString {
@@ -455,6 +460,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
455460
if( array_type && is_third_party_cpp_array ) {
456461
*is_third_party_cpp_array = true;
457462
*array_type = ThirdPartyCPPArrayTypes::MDSpanArray;
463+
xshape_result->from_pointer_n(xtensor_fixed_dims.p, xtensor_fixed_dims.size());
458464
}
459465
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
460466
ClangTypeToASRType(qual_type),
@@ -496,6 +502,34 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
496502
xshape_result->push_back(al, dim);
497503
}
498504
return nullptr;
505+
} else if( template_name == "extents" ) {
506+
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
507+
if( xshape_result == nullptr ) {
508+
throw std::runtime_error("Result Vec<ASR::dimention_t>* not provided.");
509+
}
510+
511+
const clang::QualType& qual_type = template_arguments.at(0).getAsType();
512+
ASR::ttype_t* index_type = ClangTypeToASRType(qual_type);
513+
if( !ASR::is_a<ASR::Integer_t>(*index_type) ||
514+
ASRUtils::extract_kind_from_ttype_t(index_type) != 4 ) {
515+
throw std::runtime_error("Only int32_t should be used for index type in Kokkos::dextents.");
516+
}
517+
ASR::expr_t* zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, l, 0,
518+
ASRUtils::TYPE(ASR::make_Integer_t(al, l, 4))));
519+
for( int i = 1; i < template_arguments.size(); i++ ) {
520+
clang::Expr* clang_rank = template_arguments.at(i).getAsExpr();
521+
TraverseStmt(clang_rank);
522+
int rank = 0;
523+
if( !ASRUtils::extract_value(ASRUtils::EXPR(tmp.get()), rank) ) {
524+
throw std::runtime_error("Rank provided in the xshape must be a constant.");
525+
}
526+
ASR::dimension_t dim; dim.loc = l;
527+
dim.m_length = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, l, rank,
528+
ASRUtils::TYPE(ASR::make_Integer_t(al, l, 4))));
529+
dim.m_start = zero;
530+
xshape_result->push_back(al, dim);
531+
}
532+
return nullptr;
499533
} else if( template_name == "dextents" ) {
500534
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
501535
if( xshape_result == nullptr ) {
@@ -553,7 +587,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
553587
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Record ) {
554588
const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl();
555589
std::string name = record_type->getNameAsString();
556-
if( name == "xtensor_container" || name == "vector" ) {
590+
if( name == "xtensor_container" || name == "vector" || name == "mdspan" ) {
557591
return nullptr;
558592
}
559593
ASR::symbol_t* type_t = current_scope->resolve_symbol(name);
@@ -565,6 +599,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
565599
} else {
566600
type = ASRUtils::TYPE(ASR::make_Struct_t(al, l, type_t));
567601
}
602+
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::SubstTemplateTypeParm ) {
603+
return nullptr;
568604
} else {
569605
throw std::runtime_error("clang::QualType not yet supported " +
570606
std::string(clang_type->getTypeClassName()));
@@ -995,6 +1031,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
9951031
ASRUtils::make_ArrayBroadcast_t_util(al, Lloc(x), obj, value);
9961032
tmp = ASR::make_Assignment_t(al, Lloc(x), obj, value, nullptr);
9971033
is_stmt_created = true;
1034+
} else {
1035+
// This means that right hand side of operator=
1036+
// was a CXXConstructor and there was an allocation
1037+
// statement created for it. So skip and return should
1038+
// done in this case.
1039+
tmp = nullptr;
1040+
is_stmt_created = false;
9981041
}
9991042
assignment_target = nullptr;
10001043
} else if( ASRUtils::is_complex(*ASRUtils::expr_type(obj)) ||
@@ -1393,8 +1436,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
13931436
}
13941437
alloc_arg.m_dims = alloc_dims.p; alloc_arg.n_dims = alloc_dims.size();
13951438
alloc_args.push_back(al, alloc_arg);
1396-
tmp = ASR::make_Allocate_t(al, Lloc(x), alloc_args.p, alloc_args.size(),
1397-
nullptr, nullptr, nullptr);
1439+
current_body->push_back(al, ASRUtils::STMT(ASR::make_Allocate_t(al, Lloc(x),
1440+
alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)));
1441+
tmp = nullptr;
13981442
is_stmt_created = true;
13991443
} else {
14001444
throw std::runtime_error("Only {...} is allowed for supplying shape to xt::empty.");
@@ -1437,6 +1481,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
14371481
tmp = ASRUtils::make_Cast_t_value(al, Lloc(x), callee, ASR::cast_kindType::ListToArray,
14381482
ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x), ASRUtils::get_contained_type(ASRUtils::expr_type(callee)),
14391483
dims.p, dims.size(), ASR::array_physical_typeType::UnboundedPointerToDataArray)));
1484+
} else if (sf == SpecialFunc::Reserve) {
1485+
if( args.size() > 1 ) {
1486+
throw std::runtime_error("std::vector::reserve should be called with only one argument.");
1487+
}
1488+
1489+
tmp = ASR::make_ListReserve_t(al, Lloc(x), callee, args[0]);
1490+
is_stmt_created = true;
14401491
} else {
14411492
throw std::runtime_error("Only printf and exit special functions supported");
14421493
}
@@ -1579,7 +1630,37 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
15791630
ASR::ttype_t* constructor_type = ClangTypeToASRType(x->getType(), &shape_result,
15801631
&third_party_array_type, &is_third_party_array_type);
15811632
if( is_third_party_array_type && third_party_array_type == ThirdPartyCPPArrayTypes::MDSpanArray ) {
1582-
if( x->getNumArgs() >= 1 ) {
1633+
if( x->getNumArgs() == 0 ) {
1634+
tmp = nullptr;
1635+
is_stmt_created = false;
1636+
return true;
1637+
}
1638+
if( x->getNumArgs() == 1 ) {
1639+
if( shape_result.size() != ASRUtils::extract_n_dims_from_ttype(constructor_type) ) {
1640+
throw std::runtime_error("Allocation sizes for all dimensions not provided.");
1641+
}
1642+
if( !ASRUtils::is_fixed_size_array(shape_result.p, shape_result.size()) ) {
1643+
throw std::runtime_error("Allocation size of Kokkos::mdspan array isn't specified correctly.");
1644+
}
1645+
1646+
Vec<ASR::alloc_arg_t> alloc_args; alloc_args.reserve(al, 1);
1647+
ASR::alloc_arg_t alloc_arg;
1648+
alloc_arg.loc = Lloc(x);
1649+
LCOMPILERS_ASSERT(assignment_target != nullptr);
1650+
alloc_arg.m_a = assignment_target;
1651+
alloc_arg.m_dims = shape_result.p; alloc_arg.n_dims = shape_result.size();
1652+
alloc_arg.m_len_expr = nullptr; alloc_arg.m_type = nullptr;
1653+
alloc_args.push_back(al, alloc_arg);
1654+
Vec<ASR::expr_t*> dealloc_args; dealloc_args.reserve(al, 1);
1655+
dealloc_args.push_back(al, assignment_target);
1656+
current_body->push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t(
1657+
al, Lloc(x), dealloc_args.p, dealloc_args.size())));
1658+
current_body->push_back(al, ASRUtils::STMT(ASR::make_Allocate_t(al, Lloc(x),
1659+
alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)));
1660+
tmp = nullptr;
1661+
is_stmt_created = true;
1662+
return true;
1663+
} else if( x->getNumArgs() >= 1 ) {
15831664
clang::Expr* _data = x->getArg(0);
15841665
TraverseStmt(_data);
15851666
ASR::expr_t* list_to_array = ASRUtils::EXPR(tmp.get());
@@ -1602,7 +1683,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
16021683
TraverseStmt(argi);
16031684
ASR::dimension_t alloc_dim;
16041685
alloc_dim.loc = Lloc(x);
1605-
alloc_dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, Lloc(x), 1,
1686+
alloc_dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, Lloc(x), 0,
16061687
ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4))));
16071688
alloc_dim.m_length = ASRUtils::EXPR(tmp.get());
16081689
alloc_dims.push_back(al, alloc_dim);
@@ -1625,9 +1706,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
16251706
return true;
16261707
} else {
16271708
throw std::runtime_error("Kokkos::mdspan constructor is called "
1628-
"with incorrect number of arguments, "
1629-
"expected 3 and found, " + std::to_string(x->getNumArgs()));
1709+
"with incorrect number of arguments, expected " + std::to_string(
1710+
ASRUtils::extract_n_dims_from_ttype(constructor_type) + 1) +
1711+
" and found, " + std::to_string(x->getNumArgs()));
16301712
}
1713+
} else if( constructor_type == nullptr ) {
1714+
clang::Expr* x_ = x->getArg(0);
1715+
return TraverseStmt(x_);
16311716
} else if( x->getNumArgs() >= 0 ) {
16321717
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseCXXConstructExpr(x);
16331718
}
@@ -1751,8 +1836,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
17511836
case clang::APValue::Struct: {
17521837
ASR::ttype_t* v_type = ASRUtils::type_get_past_const(ASRUtils::symbol_type(v));
17531838
if( !ASR::is_a<ASR::Struct_t>(*v_type) ) {
1754-
throw std::runtime_error("Expected ASR::Struct_t type found, " +
1755-
ASRUtils::type_to_str(v_type));
1839+
// throw std::runtime_error("Expected ASR::Struct_t type found, " +
1840+
// ASRUtils::type_to_str(v_type));
1841+
// No error, just return, clang marking something as Struct
1842+
// doesn't mean that it necessarily maps to ASR Struct, it can
1843+
// map to ASR Array type as well
1844+
return ;
17561845
}
17571846
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(v_type);
17581847
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
@@ -1820,6 +1909,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
18201909
tmp = ASR::make_Assignment_t(al, Lloc(x), var, init_val, nullptr);
18211910
is_stmt_created = true;
18221911
}
1912+
} else {
1913+
is_stmt_created = false;
1914+
tmp = nullptr;
18231915
}
18241916

18251917
if( x->getEvaluatedValue() ) {
@@ -2376,8 +2468,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
23762468
ASR::symbol_t* sym = resolve_symbol(name);
23772469
if( name == "operator<<" || name == "cout" || name == "endl" ||
23782470
name == "operator()" || name == "operator+" || name == "operator=" ||
2379-
name == "operator*" || name == "view" || name == "empty" ||
2380-
name == "all" || name == "any" || name == "not_equal" ||
2471+
name == "operator*" || name == "view" || name == "submdspan" || name == "empty" ||
2472+
name == "all" || name == "full_extent" || name == "any" || name == "not_equal" ||
23812473
name == "exit" || name == "printf" || name == "exp" ||
23822474
name == "sum" || name == "amax" || name == "abs" ||
23832475
name == "operator-" || name == "operator/" || name == "operator>" ||
@@ -2390,7 +2482,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
23902482
throw std::runtime_error("Special function " + name + " cannot be overshadowed yet.");
23912483
}
23922484
if( sym == nullptr ) {
2393-
cxx_operator_name_obj.set(name);
2485+
if( name == "full_extent" ) {
2486+
is_all_called = true;
2487+
} else {
2488+
cxx_operator_name_obj.set(name);
2489+
}
23942490
tmp = nullptr;
23952491
return true;
23962492
}

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ stmt
229229
| BlockCall(int label, symbol m)
230230
| SetInsert(expr a, expr ele)
231231
| SetRemove(expr a, expr ele)
232+
| ListReserve(expr a, expr size)
232233
| ListInsert(expr a, expr pos, expr ele)
233234
| ListRemove(expr a, expr ele)
234235
| ListClear(expr a)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
10231023
visit_AllocateUtil(x, x.m_stat, false);
10241024
}
10251025

1026+
void visit_ListReserve(const ASR::ListReserve_t& x) {
1027+
ASR::ttype_t* el_type = ASRUtils::get_contained_type(
1028+
ASRUtils::expr_type(x.m_a));
1029+
int64_t ptr_loads_copy = ptr_loads;
1030+
ptr_loads = 0;
1031+
this->visit_expr(*x.m_a);
1032+
llvm::Value* plist = tmp;
1033+
1034+
ptr_loads = 1;
1035+
this->visit_expr_wrapper(x.m_size, true);
1036+
ptr_loads = ptr_loads_copy;
1037+
llvm::Value *pos = tmp;
1038+
1039+
llvm::Value* list_data_ptr = list_api->get_pointer_to_list_data(plist);
1040+
llvm::Value* size = list_api->len(plist);
1041+
llvm::DataLayout data_layout(module.get());
1042+
llvm::Type* llvm_el_type = llvm_utils->get_type_from_ttype_t_util(el_type, module.get());
1043+
size_t el_struct_size = data_layout.getTypeAllocSize(llvm_el_type);
1044+
size = builder->CreateMul(size, llvm::ConstantInt::get(
1045+
llvm::Type::getInt32Ty(context), llvm::APInt(32, el_struct_size)));
1046+
llvm::Value* alloc_ptr = LLVM::lfortran_malloc(context, *module, *builder, size);
1047+
builder->CreateStore(builder->CreateBitCast(alloc_ptr, llvm_el_type->getPointerTo()), list_data_ptr);
1048+
}
1049+
10261050
void visit_ReAlloc(const ASR::ReAlloc_t& x) {
10271051
LCOMPILERS_ASSERT(x.n_args == 1);
10281052
handle_allocated(x.m_args[0].m_a);

tests/reference/asr-array_04-f95b8eb.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
"basename": "asr-array_04-f95b8eb",
33
"cmd": "lc --show-asr --no-color {infile} -o {outfile}",
44
"infile": "tests/../integration_tests/array_04.cpp",
5-
"infile_hash": "dd38786061d4a19743e4c32c2a3ecb08193777f81a2a0d3e87af0f0a",
5+
"infile_hash": "86bab3bcf300c07e9861e92ce21215829ae2893db1d378776b8ab332",
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-array_04-f95b8eb.stdout",
9-
"stdout_hash": "ecd1c3d056e6792b7c6ec0baa0730c74d60745d34d459ffb6122aee1",
9+
"stdout_hash": "b72419a0a8bd62f822377669e6e98ca64a95400f6c7ce7b7ad33a68a",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)