Skip to content

Add support for mdspan #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,10 @@ jobs:
if: contains(matrix.os, 'ubuntu') || contains(matrix.os, 'macos')
run: |
export CPATH=$CONDA_PREFIX/include:$CPATH
git clone https://github.com/czgdp1807/mdspan
cd mdspan
cmake . -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
make install
cd ..
./integration_tests/run_tests.py -b gcc llvm wasm c
./integration_tests/run_tests.py -b gcc llvm wasm c -f
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ RUN(NAME switch_case_02.cpp LABELS gcc llvm)
RUN(NAME array_01.cpp LABELS gcc llvm)
RUN(NAME array_02.cpp LABELS gcc llvm)
RUN(NAME array_03.cpp LABELS gcc llvm)
RUN(NAME array_03_mdspan.cpp LABELS gcc llvm)
RUN(NAME array_04.cpp LABELS gcc llvm)
RUN(NAME array_05.cpp LABELS gcc llvm)
RUN(NAME array_06.cpp LABELS gcc llvm)
Expand Down
23 changes: 23 additions & 0 deletions integration_tests/array_03_mdspan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <iostream>
#include <vector>
#include <mdspan/mdspan.hpp>

void print(Kokkos::mdspan<double, Kokkos::dextents<int32_t, 1>>& arr) {
for( int32_t i = 0; i < arr.size(); i++ ) {
std::cout << arr[i] << "\n";
}
}

int main() {

std::vector<double> arr1_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
std::vector<double> arr2_data = {5.0, 6.0, 7.0};

Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{arr1_data.data(), 3, 3};

Kokkos::mdspan<double, Kokkos::dextents<int32_t, 1>> arr2{arr2_data.data(), 3};

print( arr2 );

return 0;
}
150 changes: 141 additions & 9 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum SpecialFunc {
Sqrt,
PushBack,
Clear,
Data,
};

std::map<std::string, SpecialFunc> special_function_map = {
Expand Down Expand Up @@ -68,6 +69,7 @@ std::map<std::string, SpecialFunc> special_function_map = {
{"sqrt", SpecialFunc::Sqrt},
{"push_back", SpecialFunc::PushBack},
{"clear", SpecialFunc::Clear},
{"data", SpecialFunc::Data},
};

class OneTimeUseString {
Expand Down Expand Up @@ -143,6 +145,11 @@ class OneTimeUseASRNode {

};

enum ThirdPartyCPPArrayTypes {
XTensorArray,
MDSpanArray,
};

class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor> {

public:
Expand Down Expand Up @@ -329,7 +336,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

ASR::ttype_t* ClangTypeToASRType(const clang::QualType& qual_type,
Vec<ASR::dimension_t>* xshape_result=nullptr) {
Vec<ASR::dimension_t>* xshape_result=nullptr,
ThirdPartyCPPArrayTypes* array_type=nullptr,
bool* is_third_party_cpp_array=nullptr) {
const clang::SplitQualType& split_qual_type = qual_type.split();
const clang::Type* clang_type = split_qual_type.asPair().first;
const clang::Qualifiers qualifiers = split_qual_type.asPair().second;
Expand Down Expand Up @@ -392,11 +401,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
} else if( clang_type->getTypeClass() == clang::Type::LValueReference ) {
const clang::LValueReferenceType* lvalue_reference_type = clang_type->getAs<clang::LValueReferenceType>();
clang::QualType pointee_type = lvalue_reference_type->getPointeeType();
type = ClangTypeToASRType(pointee_type, xshape_result);
type = ClangTypeToASRType(pointee_type, xshape_result, array_type, is_third_party_cpp_array);
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Elaborated ) {
const clang::ElaboratedType* elaborated_type = clang_type->getAs<clang::ElaboratedType>();
clang::QualType desugared_type = elaborated_type->desugar();
type = ClangTypeToASRType(desugared_type, xshape_result);
type = ClangTypeToASRType(desugared_type, xshape_result, array_type, is_third_party_cpp_array);
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::TemplateSpecialization ) {
const clang::TemplateSpecializationType* template_specialization = clang_type->getAs<clang::TemplateSpecializationType>();
std::string template_name = template_specialization->getTemplateName().getAsTemplateDecl()->getNameAsString();
Expand All @@ -421,11 +430,34 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
empty_dim.m_length = nullptr;
empty_dims.push_back(al, empty_dim);
}
if( array_type && is_third_party_cpp_array ) {
*is_third_party_cpp_array = true;
*array_type = ThirdPartyCPPArrayTypes::XTensorArray;
}
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
ClangTypeToASRType(qual_type),
empty_dims.p, empty_dims.size(),
ASR::array_physical_typeType::DescriptorArray));
type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, l, type));
} else if( template_name == "mdspan" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( template_arguments.size() != 2 ) {
throw std::runtime_error("mdspan type must be initialised with element type, index type and rank.");
}
const clang::QualType& qual_type = template_arguments.at(0).getAsType();
const clang::QualType& shape_type = template_arguments.at(1).getAsType();
Vec<ASR::dimension_t> xtensor_fixed_dims; xtensor_fixed_dims.reserve(al, 1);
ClangTypeToASRType(shape_type, &xtensor_fixed_dims, array_type, is_third_party_cpp_array);
// Only allocatables are made for Kokkos::mdspan
if( array_type && is_third_party_cpp_array ) {
*is_third_party_cpp_array = true;
*array_type = ThirdPartyCPPArrayTypes::MDSpanArray;
}
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
ClangTypeToASRType(qual_type),
xtensor_fixed_dims.p, xtensor_fixed_dims.size(),
ASR::array_physical_typeType::DescriptorArray));
type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, l, type));
} else if( template_name == "xtensor_fixed" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( template_arguments.size() != 2 ) {
Expand All @@ -434,7 +466,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
const clang::QualType& qual_type = template_arguments.at(0).getAsType();
const clang::QualType& shape_type = template_arguments.at(1).getAsType();
Vec<ASR::dimension_t> xtensor_fixed_dims; xtensor_fixed_dims.reserve(al, 1);
ClangTypeToASRType(shape_type, &xtensor_fixed_dims);
ClangTypeToASRType(shape_type, &xtensor_fixed_dims, array_type, is_third_party_cpp_array);
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
ClangTypeToASRType(qual_type),
xtensor_fixed_dims.p, xtensor_fixed_dims.size(),
Expand All @@ -461,14 +493,41 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
xshape_result->push_back(al, dim);
}
return nullptr;
} else if( template_name == "dextents" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( xshape_result == nullptr ) {
throw std::runtime_error("Result Vec<ASR::dimention_t>* not provided.");
}

ASR::expr_t* zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, l, 0,
ASRUtils::TYPE(ASR::make_Integer_t(al, l, 4))));
const clang::QualType& index_type = template_arguments.at(0).getAsType();
if( !ASR::is_a<ASR::Integer_t>(*ClangTypeToASRType(index_type)) ||
ASRUtils::extract_kind_from_ttype_t(ClangTypeToASRType(index_type)) != 4 ) {
throw std::runtime_error("Only int32_t should be used for index type in Kokkos::dextents.");
}
clang::Expr* clang_rank = template_arguments.at(1).getAsExpr();
TraverseStmt(clang_rank);
int rank = 0;
if( !ASRUtils::extract_value(ASRUtils::EXPR(tmp.get()), rank) ) {
throw std::runtime_error("Rank provided in the Kokkos::dextents must be a constant.");
}
for( int i = 0; i < rank; i++ ) {
ASR::dimension_t dim; dim.loc = l;
dim.m_length = nullptr;
dim.m_start = zero;
xshape_result->push_back(al, dim);
}
return nullptr;
} else if( template_name == "complex" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( template_arguments.size() > 1 ) {
throw std::runtime_error("std::complex accepts only one argument.");
}

const clang::QualType& qual_type = template_arguments.at(0).getAsType();
ASR::ttype_t* complex_subtype = ClangTypeToASRType(qual_type, xshape_result);
ASR::ttype_t* complex_subtype = ClangTypeToASRType(qual_type, xshape_result,
array_type, is_third_party_cpp_array);
if( !ASRUtils::is_real(*complex_subtype) ) {
throw std::runtime_error(std::string("std::complex accepts only real types, found: ") +
ASRUtils::type_to_str(complex_subtype));
Expand All @@ -482,14 +541,18 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

const clang::QualType& qual_type = template_arguments.at(0).getAsType();
ASR::ttype_t* vector_subtype = ClangTypeToASRType(qual_type, xshape_result);
ASR::ttype_t* vector_subtype = ClangTypeToASRType(qual_type, xshape_result,
array_type, is_third_party_cpp_array);
type = ASRUtils::TYPE(ASR::make_List_t(al, l, vector_subtype));
} else {
throw std::runtime_error(std::string("Unrecognized type ") + template_name);
}
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Record ) {
const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl();
std::string name = record_type->getNameAsString();
if( name == "xtensor_container" || name == "vector" ) {
return nullptr;
}
ASR::symbol_t* type_t = current_scope->resolve_symbol(name);
if( !type_t ) {
throw std::runtime_error(name + " not defined.");
Expand Down Expand Up @@ -787,7 +850,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
member_name_obj.set(member_name);
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
} else {
throw std::runtime_error("clang::MemberExpr only supported for struct and special functions.");
throw std::runtime_error("clang::MemberExpr only supported for struct, union and special functions.");
}
return true;
}
Expand Down Expand Up @@ -1357,6 +1420,20 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

tmp = ASR::make_ListClear_t(al, Lloc(x), callee);
is_stmt_created = true;
} else if (sf == SpecialFunc::Data) {
if( args.size() > 0 ) {
throw std::runtime_error("std::vector::data should be called with only one argument.");
}

Vec<ASR::dimension_t> dims; dims.reserve(al, 1);
ASR::dimension_t dim;
dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(
al, Lloc(x), 0, ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4))));
dim.m_length = nullptr;
dims.push_back(al, dim);
tmp = ASRUtils::make_Cast_t_value(al, Lloc(x), callee, ASR::cast_kindType::ListToArray,
ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x), ASRUtils::get_contained_type(ASRUtils::expr_type(callee)),
dims.p, dims.size(), ASR::array_physical_typeType::UnboundedPointerToDataArray)));
} else {
throw std::runtime_error("Only printf and exit special functions supported");
}
Expand Down Expand Up @@ -1493,15 +1570,70 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseCXXConstructExpr(clang::CXXConstructExpr* x) {
if( x->getNumArgs() >= 0 ) {
ThirdPartyCPPArrayTypes third_party_array_type;
bool is_third_party_array_type = false;
Vec<ASR::dimension_t> shape_result; shape_result.reserve(al, 1);
ASR::ttype_t* constructor_type = ClangTypeToASRType(x->getType(), &shape_result,
&third_party_array_type, &is_third_party_array_type);
if( is_third_party_array_type && third_party_array_type == ThirdPartyCPPArrayTypes::MDSpanArray ) {
if( x->getNumArgs() >= 1 ) {
clang::Expr* _data = x->getArg(0);
TraverseStmt(_data);
ASR::expr_t* list_to_array = ASRUtils::EXPR(tmp.get());
if( !ASR::is_a<ASR::Cast_t>(*list_to_array) ||
ASR::down_cast<ASR::Cast_t>(list_to_array)->m_kind !=
ASR::cast_kindType::ListToArray) {
throw std::runtime_error("First argument of Kokkos::mdspan "
"constructor should be a call to std::vector::data() function.");
}
if( ASRUtils::extract_n_dims_from_ttype(constructor_type) != x->getNumArgs() - 1 ) {
throw std::runtime_error("Shape provided in the constructor "
"doesn't match with the rank of the array.");
}
Vec<ASR::dimension_t> alloc_dims; alloc_dims.reserve(al, x->getNumArgs() - 1);
Vec<ASR::alloc_arg_t> alloc_args; alloc_args.reserve(al, 1);
for( int i = 1; i < x->getNumArgs(); i++ ) {
clang::Expr* argi = x->getArg(i);
TraverseStmt(argi);
ASR::dimension_t alloc_dim;
alloc_dim.loc = Lloc(x);
alloc_dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, Lloc(x), 1,
ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4))));
alloc_dim.m_length = ASRUtils::EXPR(tmp.get());
alloc_dims.push_back(al, alloc_dim);
}
ASR::alloc_arg_t alloc_arg;
alloc_arg.loc = Lloc(x);
LCOMPILERS_ASSERT(assignment_target != nullptr);
alloc_arg.m_a = assignment_target;
alloc_arg.m_dims = alloc_dims.p; alloc_arg.n_dims = alloc_dims.size();
alloc_arg.m_len_expr = nullptr; alloc_arg.m_type = nullptr;
alloc_args.push_back(al, alloc_arg);
Vec<ASR::expr_t*> dealloc_args; dealloc_args.reserve(al, 1);
dealloc_args.push_back(al, assignment_target);
current_body->push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t(
al, Lloc(x), dealloc_args.p, dealloc_args.size())));
current_body->push_back(al, ASRUtils::STMT(ASR::make_Allocate_t(al, Lloc(x),
alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)));
tmp = reinterpret_cast<ASR::asr_t*>(list_to_array);
is_stmt_created = false;
return true;
} else {
throw std::runtime_error("Kokkos::mdspan constructor is called "
"with incorrect number of arguments, "
"expected 3 and found, " + std::to_string(x->getNumArgs()));
}
} else if( x->getNumArgs() >= 0 ) {
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseCXXConstructExpr(x);
}
tmp = nullptr;
return true;
}

void add_reshape_if_needed(ASR::expr_t*& expr, ASR::expr_t* target_expr) {
if( ASR::is_a<ASR::List_t>(*ASRUtils::expr_type(target_expr)) ) {
if( ASR::is_a<ASR::List_t>(*ASRUtils::expr_type(target_expr)) ||
((ASR::is_a<ASR::Cast_t>(*expr) &&
ASR::down_cast<ASR::Cast_t>(expr)->m_kind == ASR::cast_kindType::ListToArray)) ) {
return ;
}
ASR::ttype_t* expr_type = ASRUtils::expr_type(expr);
Expand Down
1 change: 1 addition & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ cast_kind
| CPtrToUnsignedInteger
| UnsignedIntegerToCPtr
| IntegerToSymbolicExpression
| ListToArray

dimension = (expr? start, expr? length)

Expand Down
Loading