Skip to content

Commit

Permalink
Merge pull request #111 from czgdp1807/mdspan_01
Browse files Browse the repository at this point in the history
Add support for `mdspan`
  • Loading branch information
czgdp1807 authored Mar 19, 2024
2 parents fd8842f + 632c2db commit ac76a2a
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,10 @@ jobs:
if: contains(matrix.os, 'ubuntu') || contains(matrix.os, 'macos')
run: |
export CPATH=$CONDA_PREFIX/include:$CPATH
git clone https://github.com/czgdp1807/mdspan
cd mdspan
cmake . -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
make install
cd ..
./integration_tests/run_tests.py -b gcc llvm wasm c
./integration_tests/run_tests.py -b gcc llvm wasm c -f
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ RUN(NAME switch_case_02.cpp LABELS gcc llvm)
RUN(NAME array_01.cpp LABELS gcc llvm)
RUN(NAME array_02.cpp LABELS gcc llvm)
RUN(NAME array_03.cpp LABELS gcc llvm)
RUN(NAME array_03_mdspan.cpp LABELS gcc llvm)
RUN(NAME array_04.cpp LABELS gcc llvm)
RUN(NAME array_05.cpp LABELS gcc llvm)
RUN(NAME array_06.cpp LABELS gcc llvm)
Expand Down
23 changes: 23 additions & 0 deletions integration_tests/array_03_mdspan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <iostream>
#include <vector>
#include <mdspan/mdspan.hpp>

void print(Kokkos::mdspan<double, Kokkos::dextents<int32_t, 1>>& arr) {
for( int32_t i = 0; i < arr.size(); i++ ) {
std::cout << arr[i] << "\n";
}
}

int main() {

std::vector<double> arr1_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
std::vector<double> arr2_data = {5.0, 6.0, 7.0};

Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{arr1_data.data(), 3, 3};

Kokkos::mdspan<double, Kokkos::dextents<int32_t, 1>> arr2{arr2_data.data(), 3};

print( arr2 );

return 0;
}
150 changes: 141 additions & 9 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum SpecialFunc {
Sqrt,
PushBack,
Clear,
Data,
};

std::map<std::string, SpecialFunc> special_function_map = {
Expand Down Expand Up @@ -68,6 +69,7 @@ std::map<std::string, SpecialFunc> special_function_map = {
{"sqrt", SpecialFunc::Sqrt},
{"push_back", SpecialFunc::PushBack},
{"clear", SpecialFunc::Clear},
{"data", SpecialFunc::Data},
};

class OneTimeUseString {
Expand Down Expand Up @@ -143,6 +145,11 @@ class OneTimeUseASRNode {

};

enum ThirdPartyCPPArrayTypes {
XTensorArray,
MDSpanArray,
};

class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor> {

public:
Expand Down Expand Up @@ -329,7 +336,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

ASR::ttype_t* ClangTypeToASRType(const clang::QualType& qual_type,
Vec<ASR::dimension_t>* xshape_result=nullptr) {
Vec<ASR::dimension_t>* xshape_result=nullptr,
ThirdPartyCPPArrayTypes* array_type=nullptr,
bool* is_third_party_cpp_array=nullptr) {
const clang::SplitQualType& split_qual_type = qual_type.split();
const clang::Type* clang_type = split_qual_type.asPair().first;
const clang::Qualifiers qualifiers = split_qual_type.asPair().second;
Expand Down Expand Up @@ -392,11 +401,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
} else if( clang_type->getTypeClass() == clang::Type::LValueReference ) {
const clang::LValueReferenceType* lvalue_reference_type = clang_type->getAs<clang::LValueReferenceType>();
clang::QualType pointee_type = lvalue_reference_type->getPointeeType();
type = ClangTypeToASRType(pointee_type, xshape_result);
type = ClangTypeToASRType(pointee_type, xshape_result, array_type, is_third_party_cpp_array);
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Elaborated ) {
const clang::ElaboratedType* elaborated_type = clang_type->getAs<clang::ElaboratedType>();
clang::QualType desugared_type = elaborated_type->desugar();
type = ClangTypeToASRType(desugared_type, xshape_result);
type = ClangTypeToASRType(desugared_type, xshape_result, array_type, is_third_party_cpp_array);
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::TemplateSpecialization ) {
const clang::TemplateSpecializationType* template_specialization = clang_type->getAs<clang::TemplateSpecializationType>();
std::string template_name = template_specialization->getTemplateName().getAsTemplateDecl()->getNameAsString();
Expand All @@ -421,11 +430,34 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
empty_dim.m_length = nullptr;
empty_dims.push_back(al, empty_dim);
}
if( array_type && is_third_party_cpp_array ) {
*is_third_party_cpp_array = true;
*array_type = ThirdPartyCPPArrayTypes::XTensorArray;
}
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
ClangTypeToASRType(qual_type),
empty_dims.p, empty_dims.size(),
ASR::array_physical_typeType::DescriptorArray));
type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, l, type));
} else if( template_name == "mdspan" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( template_arguments.size() != 2 ) {
throw std::runtime_error("mdspan type must be initialised with element type, index type and rank.");
}
const clang::QualType& qual_type = template_arguments.at(0).getAsType();
const clang::QualType& shape_type = template_arguments.at(1).getAsType();
Vec<ASR::dimension_t> xtensor_fixed_dims; xtensor_fixed_dims.reserve(al, 1);
ClangTypeToASRType(shape_type, &xtensor_fixed_dims, array_type, is_third_party_cpp_array);
// Only allocatables are made for Kokkos::mdspan
if( array_type && is_third_party_cpp_array ) {
*is_third_party_cpp_array = true;
*array_type = ThirdPartyCPPArrayTypes::MDSpanArray;
}
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
ClangTypeToASRType(qual_type),
xtensor_fixed_dims.p, xtensor_fixed_dims.size(),
ASR::array_physical_typeType::DescriptorArray));
type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, l, type));
} else if( template_name == "xtensor_fixed" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( template_arguments.size() != 2 ) {
Expand All @@ -434,7 +466,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
const clang::QualType& qual_type = template_arguments.at(0).getAsType();
const clang::QualType& shape_type = template_arguments.at(1).getAsType();
Vec<ASR::dimension_t> xtensor_fixed_dims; xtensor_fixed_dims.reserve(al, 1);
ClangTypeToASRType(shape_type, &xtensor_fixed_dims);
ClangTypeToASRType(shape_type, &xtensor_fixed_dims, array_type, is_third_party_cpp_array);
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
ClangTypeToASRType(qual_type),
xtensor_fixed_dims.p, xtensor_fixed_dims.size(),
Expand All @@ -461,14 +493,41 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
xshape_result->push_back(al, dim);
}
return nullptr;
} else if( template_name == "dextents" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( xshape_result == nullptr ) {
throw std::runtime_error("Result Vec<ASR::dimention_t>* not provided.");
}

ASR::expr_t* zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, l, 0,
ASRUtils::TYPE(ASR::make_Integer_t(al, l, 4))));
const clang::QualType& index_type = template_arguments.at(0).getAsType();
if( !ASR::is_a<ASR::Integer_t>(*ClangTypeToASRType(index_type)) ||
ASRUtils::extract_kind_from_ttype_t(ClangTypeToASRType(index_type)) != 4 ) {
throw std::runtime_error("Only int32_t should be used for index type in Kokkos::dextents.");
}
clang::Expr* clang_rank = template_arguments.at(1).getAsExpr();
TraverseStmt(clang_rank);
int rank = 0;
if( !ASRUtils::extract_value(ASRUtils::EXPR(tmp.get()), rank) ) {
throw std::runtime_error("Rank provided in the Kokkos::dextents must be a constant.");
}
for( int i = 0; i < rank; i++ ) {
ASR::dimension_t dim; dim.loc = l;
dim.m_length = nullptr;
dim.m_start = zero;
xshape_result->push_back(al, dim);
}
return nullptr;
} else if( template_name == "complex" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( template_arguments.size() > 1 ) {
throw std::runtime_error("std::complex accepts only one argument.");
}

const clang::QualType& qual_type = template_arguments.at(0).getAsType();
ASR::ttype_t* complex_subtype = ClangTypeToASRType(qual_type, xshape_result);
ASR::ttype_t* complex_subtype = ClangTypeToASRType(qual_type, xshape_result,
array_type, is_third_party_cpp_array);
if( !ASRUtils::is_real(*complex_subtype) ) {
throw std::runtime_error(std::string("std::complex accepts only real types, found: ") +
ASRUtils::type_to_str(complex_subtype));
Expand All @@ -482,14 +541,18 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

const clang::QualType& qual_type = template_arguments.at(0).getAsType();
ASR::ttype_t* vector_subtype = ClangTypeToASRType(qual_type, xshape_result);
ASR::ttype_t* vector_subtype = ClangTypeToASRType(qual_type, xshape_result,
array_type, is_third_party_cpp_array);
type = ASRUtils::TYPE(ASR::make_List_t(al, l, vector_subtype));
} else {
throw std::runtime_error(std::string("Unrecognized type ") + template_name);
}
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Record ) {
const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl();
std::string name = record_type->getNameAsString();
if( name == "xtensor_container" || name == "vector" ) {
return nullptr;
}
ASR::symbol_t* type_t = current_scope->resolve_symbol(name);
if( !type_t ) {
throw std::runtime_error(name + " not defined.");
Expand Down Expand Up @@ -787,7 +850,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
member_name_obj.set(member_name);
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseMemberExpr(x);
} else {
throw std::runtime_error("clang::MemberExpr only supported for struct and special functions.");
throw std::runtime_error("clang::MemberExpr only supported for struct, union and special functions.");
}
return true;
}
Expand Down Expand Up @@ -1357,6 +1420,20 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

tmp = ASR::make_ListClear_t(al, Lloc(x), callee);
is_stmt_created = true;
} else if (sf == SpecialFunc::Data) {
if( args.size() > 0 ) {
throw std::runtime_error("std::vector::data should be called with only one argument.");
}

Vec<ASR::dimension_t> dims; dims.reserve(al, 1);
ASR::dimension_t dim;
dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(
al, Lloc(x), 0, ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4))));
dim.m_length = nullptr;
dims.push_back(al, dim);
tmp = ASRUtils::make_Cast_t_value(al, Lloc(x), callee, ASR::cast_kindType::ListToArray,
ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x), ASRUtils::get_contained_type(ASRUtils::expr_type(callee)),
dims.p, dims.size(), ASR::array_physical_typeType::UnboundedPointerToDataArray)));
} else {
throw std::runtime_error("Only printf and exit special functions supported");
}
Expand Down Expand Up @@ -1493,15 +1570,70 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}

bool TraverseCXXConstructExpr(clang::CXXConstructExpr* x) {
if( x->getNumArgs() >= 0 ) {
ThirdPartyCPPArrayTypes third_party_array_type;
bool is_third_party_array_type = false;
Vec<ASR::dimension_t> shape_result; shape_result.reserve(al, 1);
ASR::ttype_t* constructor_type = ClangTypeToASRType(x->getType(), &shape_result,
&third_party_array_type, &is_third_party_array_type);
if( is_third_party_array_type && third_party_array_type == ThirdPartyCPPArrayTypes::MDSpanArray ) {
if( x->getNumArgs() >= 1 ) {
clang::Expr* _data = x->getArg(0);
TraverseStmt(_data);
ASR::expr_t* list_to_array = ASRUtils::EXPR(tmp.get());
if( !ASR::is_a<ASR::Cast_t>(*list_to_array) ||
ASR::down_cast<ASR::Cast_t>(list_to_array)->m_kind !=
ASR::cast_kindType::ListToArray) {
throw std::runtime_error("First argument of Kokkos::mdspan "
"constructor should be a call to std::vector::data() function.");
}
if( ASRUtils::extract_n_dims_from_ttype(constructor_type) != x->getNumArgs() - 1 ) {
throw std::runtime_error("Shape provided in the constructor "
"doesn't match with the rank of the array.");
}
Vec<ASR::dimension_t> alloc_dims; alloc_dims.reserve(al, x->getNumArgs() - 1);
Vec<ASR::alloc_arg_t> alloc_args; alloc_args.reserve(al, 1);
for( int i = 1; i < x->getNumArgs(); i++ ) {
clang::Expr* argi = x->getArg(i);
TraverseStmt(argi);
ASR::dimension_t alloc_dim;
alloc_dim.loc = Lloc(x);
alloc_dim.m_start = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, Lloc(x), 1,
ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4))));
alloc_dim.m_length = ASRUtils::EXPR(tmp.get());
alloc_dims.push_back(al, alloc_dim);
}
ASR::alloc_arg_t alloc_arg;
alloc_arg.loc = Lloc(x);
LCOMPILERS_ASSERT(assignment_target != nullptr);
alloc_arg.m_a = assignment_target;
alloc_arg.m_dims = alloc_dims.p; alloc_arg.n_dims = alloc_dims.size();
alloc_arg.m_len_expr = nullptr; alloc_arg.m_type = nullptr;
alloc_args.push_back(al, alloc_arg);
Vec<ASR::expr_t*> dealloc_args; dealloc_args.reserve(al, 1);
dealloc_args.push_back(al, assignment_target);
current_body->push_back(al, ASRUtils::STMT(ASR::make_ExplicitDeallocate_t(
al, Lloc(x), dealloc_args.p, dealloc_args.size())));
current_body->push_back(al, ASRUtils::STMT(ASR::make_Allocate_t(al, Lloc(x),
alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)));
tmp = reinterpret_cast<ASR::asr_t*>(list_to_array);
is_stmt_created = false;
return true;
} else {
throw std::runtime_error("Kokkos::mdspan constructor is called "
"with incorrect number of arguments, "
"expected 3 and found, " + std::to_string(x->getNumArgs()));
}
} else if( x->getNumArgs() >= 0 ) {
return clang::RecursiveASTVisitor<ClangASTtoASRVisitor>::TraverseCXXConstructExpr(x);
}
tmp = nullptr;
return true;
}

void add_reshape_if_needed(ASR::expr_t*& expr, ASR::expr_t* target_expr) {
if( ASR::is_a<ASR::List_t>(*ASRUtils::expr_type(target_expr)) ) {
if( ASR::is_a<ASR::List_t>(*ASRUtils::expr_type(target_expr)) ||
((ASR::is_a<ASR::Cast_t>(*expr) &&
ASR::down_cast<ASR::Cast_t>(expr)->m_kind == ASR::cast_kindType::ListToArray)) ) {
return ;
}
ASR::ttype_t* expr_type = ASRUtils::expr_type(expr);
Expand Down
1 change: 1 addition & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ cast_kind
| CPtrToUnsignedInteger
| UnsignedIntegerToCPtr
| IntegerToSymbolicExpression
| ListToArray

dimension = (expr? start, expr? length)

Expand Down
Loading

0 comments on commit ac76a2a

Please sign in to comment.