Skip to content

Commit

Permalink
Merge pull request #24 from czgdp1807/lc_15
Browse files Browse the repository at this point in the history
Add support for `xt::xtensor_fixed` and elemental access
  • Loading branch information
czgdp1807 authored Dec 19, 2023
2 parents d7e1c43 + 95d8eea commit 3f2b43d
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 20 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,5 @@ RUN(NAME array_02.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CMAKE_CURRENT_SOURCE_DIR}/../src/runtime/include)
RUN(NAME array_03.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_05.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
23 changes: 23 additions & 0 deletions integration_tests/array_05.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <iostream>
#include <xtensor/xfixed.hpp>
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

int main() {

xt::xtensor_fixed<double, xt::xshape<3, 3>> arr1;
xt::xtensor<double, 1> arr2 {5.0, 6.0, 7.0};
xt::xtensor_fixed<double, xt::xshape<3>> result;

for( int i = 0; i < 3; i++ ) {
for( int j = 0; j < 3; j++ ) {
arr1(i, j) = i*3 + j;
}
}

std::cout << arr1 << arr2 << std::endl;
result = xt::view(arr1, 1) + arr2;
std::cout << result << std::endl;

return 0;
}
197 changes: 177 additions & 20 deletions src/lc/clang_ast_to_asr.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ namespace LCompilers {

enum SpecialFunc {
Printf,
Exit
Exit,
View
};

std::map<std::string, SpecialFunc> special_function_map = {
{"printf", SpecialFunc::Printf},
{"exit", SpecialFunc::Exit}
{"exit", SpecialFunc::Exit},
{"view", SpecialFunc::View}
};

class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor> {
Expand All @@ -55,13 +57,14 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
std::vector<std::map<std::string, std::string>> scopes;
std::string cxx_operator_name;
ASR::expr_t* assignment_target;
Vec<ASR::expr_t*>* print_args;

explicit ClangASTtoASRVisitor(clang::ASTContext *Context_,
Allocator& al_, ASR::asr_t*& tu_):
Context(Context_), al{al_}, tu{tu_},
tmp{nullptr}, current_body{nullptr},
is_stmt_created{true}, cxx_operator_name{""},
assignment_target{nullptr} {}
assignment_target{nullptr}, print_args{nullptr} {}

template <typename T>
Location Lloc(T *x) {
Expand Down Expand Up @@ -212,7 +215,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return array;
}

ASR::ttype_t* ClangTypeToASRType(const clang::QualType& qual_type) {
ASR::ttype_t* ClangTypeToASRType(const clang::QualType& qual_type,
Vec<ASR::dimension_t>* xshape_result=nullptr) {
const clang::SplitQualType& split_qual_type = qual_type.split();
const clang::Type* clang_type = split_qual_type.asPair().first;
const clang::Qualifiers qualifiers = split_qual_type.asPair().second;
Expand Down Expand Up @@ -266,7 +270,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::Elaborated ) {
const clang::ElaboratedType* elaborated_type = clang_type->getAs<clang::ElaboratedType>();
clang::QualType desugared_type = elaborated_type->desugar();
type = ClangTypeToASRType(desugared_type);
type = ClangTypeToASRType(desugared_type, xshape_result);
} else if( clang_type->getTypeClass() == clang::Type::TypeClass::TemplateSpecialization ) {
const clang::TemplateSpecializationType* template_specialization = clang_type->getAs<clang::TemplateSpecializationType>();
std::string template_name = template_specialization->getTemplateName().getAsTemplateDecl()->getNameAsString();
Expand Down Expand Up @@ -296,6 +300,41 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
empty_dims.p, empty_dims.size(),
ASR::array_physical_typeType::DescriptorArray));
type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, l, type));
} else if( template_name == "xtensor_fixed" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( template_arguments.size() != 2 ) {
throw std::runtime_error("xtensor_fixed type must be initialised with element type and shape.");
}
const clang::QualType& qual_type = template_arguments.at(0).getAsType();
const clang::QualType& shape_type = template_arguments.at(1).getAsType();
Vec<ASR::dimension_t> xtensor_fixed_dims; xtensor_fixed_dims.reserve(al, 1);
ClangTypeToASRType(shape_type, &xtensor_fixed_dims);
type = ASRUtils::TYPE(ASR::make_Array_t(al, l,
ClangTypeToASRType(qual_type),
xtensor_fixed_dims.p, xtensor_fixed_dims.size(),
ASR::array_physical_typeType::FixedSizeArray));
} else if( template_name == "xshape" ) {
const std::vector<clang::TemplateArgument>& template_arguments = template_specialization->template_arguments();
if( xshape_result == nullptr ) {
throw std::runtime_error("Result Vec<ASR::dimention_t>* not provided.");
}

ASR::expr_t* zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, l, 0,
ASRUtils::TYPE(ASR::make_Integer_t(al, l, 4))));
for( int i = 0; i < template_arguments.size(); i++ ) {
clang::Expr* clang_rank = template_arguments.at(i).getAsExpr();
TraverseStmt(clang_rank);
int rank = 0;
if( !ASRUtils::extract_value(ASRUtils::EXPR(tmp), rank) ) {
throw std::runtime_error("Rank provided in the xshape must be a constant.");
}
ASR::dimension_t dim; dim.loc = l;
dim.m_length = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, l, rank,
ASRUtils::TYPE(ASR::make_Integer_t(al, l, 4))));
dim.m_start = zero;
xshape_result->push_back(al, dim);
}
return nullptr;
} else {
throw std::runtime_error(std::string("Unrecognized type ") + template_name);
}
Expand Down Expand Up @@ -360,33 +399,111 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

bool TraverseCXXOperatorCallExpr(clang::CXXOperatorCallExpr* x) {
clang::Expr* callee = x->getCallee();
cxx_operator_name.clear();
TraverseStmt(callee);
if( cxx_operator_name == "operator<<" ) {
if( print_args == nullptr ) {
print_args = al.make_new<Vec<ASR::expr_t*>>();
print_args->reserve(al, 1);
}
clang::Expr** args = x->getArgs();
cxx_operator_name.clear();
TraverseStmt(args[1]);
if( cxx_operator_name.size() == 0 && print_args != nullptr && tmp != nullptr ) {
ASR::expr_t* arg = ASRUtils::EXPR(tmp);
print_args->push_back(al, arg);
}
cxx_operator_name.clear();
TraverseStmt(args[0]);
if( cxx_operator_name != "cout" ) {
throw std::runtime_error("Only cout is supported in operator<<.");
if( cxx_operator_name == "cout" ) {
Vec<ASR::expr_t*> print_args_vec;
print_args_vec.reserve(al, print_args->size());
for( int i = print_args->size() - 1; i >= 0; i-- ) {
print_args_vec.push_back(al, print_args->p[i]);
}
tmp = ASR::make_Print_t(al, Lloc(x), print_args_vec.p,
print_args_vec.size(), nullptr, nullptr);
print_args = nullptr;
is_stmt_created = true;
}
size_t n_args = x->getNumArgs();
Vec<ASR::expr_t*> print_args;
print_args.reserve(al, n_args);
for( size_t i = 1; i < n_args; i++ ) {
TraverseStmt(args[i]);
ASR::expr_t* arg = ASRUtils::EXPR(tmp);
print_args.push_back(al, arg);
} else if( cxx_operator_name == "operator()" ) {
clang::Expr** args = x->getArgs();
if( x->getNumArgs() == 0 ) {
throw std::runtime_error("operator() needs at least the callee to be present.");
}

TraverseStmt(args[0]);
ASR::expr_t* obj = ASRUtils::EXPR(tmp);
if( ASRUtils::is_array(ASRUtils::expr_type(obj)) ) {
Vec<ASR::array_index_t> array_indices;
array_indices.reserve(al, x->getNumArgs() - 1);
for( size_t i = 1; i < x->getNumArgs(); i++ ) {
TraverseStmt(args[i]);
ASR::expr_t* index = ASRUtils::EXPR(tmp);
ASR::array_index_t array_index;
array_index.loc = index->base.loc;
array_index.m_left = nullptr;
array_index.m_right = index;
array_index.m_step = nullptr;
array_indices.push_back(al, array_index);
}
tmp = ASR::make_ArrayItem_t(al, Lloc(x), obj,
array_indices.p, array_indices.size(),
ASRUtils::type_get_past_allocatable(
ASRUtils::type_get_past_pointer(
ASRUtils::type_get_past_array(
ASRUtils::expr_type(obj)))),
ASR::arraystorageType::RowMajor, nullptr);
} else {
throw std::runtime_error("Only indexing arrays is supported for now with operator().");
}
} else if( cxx_operator_name == "operator=" ) {
if( x->getNumArgs() != 2 ) {
throw std::runtime_error("operator= accepts two arguments, found " + std::to_string(x->getNumArgs()));
}

clang::Expr** args = x->getArgs();
TraverseStmt(args[0]);
ASR::expr_t* obj = ASRUtils::EXPR(tmp);
if( ASRUtils::is_array(ASRUtils::expr_type(obj)) ) {
TraverseStmt(args[1]);
ASR::expr_t* value = ASRUtils::EXPR(tmp);
tmp = ASR::make_Assignment_t(al, Lloc(x), obj, value, nullptr);
is_stmt_created = true;
} else {
throw std::runtime_error("operator= is supported only for arrays.");
}
} else if( cxx_operator_name == "operator+" ) {
if( x->getNumArgs() != 2 ) {
throw std::runtime_error("operator+ accepts two arguments, found " + std::to_string(x->getNumArgs()));
}

clang::Expr** args = x->getArgs();
TraverseStmt(args[0]);
ASR::expr_t* obj = ASRUtils::EXPR(tmp);
if( ASRUtils::is_array(ASRUtils::expr_type(obj)) ) {
TraverseStmt(args[1]);
ASR::expr_t* value = ASRUtils::EXPR(tmp);
CreateBinOp(obj, value, ASR::binopType::Add, Lloc(x));
} else {
throw std::runtime_error("operator= is supported only for arrays.");
}
tmp = ASR::make_Print_t(al, Lloc(x), print_args.p, print_args.size(), nullptr, nullptr);
is_stmt_created = true;
} else {
throw std::runtime_error("Only std::ostream operator is supported.");
throw std::runtime_error("Only std::ostream and operator() are supported, found " + cxx_operator_name);
}
cxx_operator_name.clear();
return true;
}

bool check_and_handle_special_function(
clang::CallExpr *x, ASR::expr_t* callee) {
ASR::Var_t* callee_Var = ASR::down_cast<ASR::Var_t>(callee);
std::string func_name = std::string(ASRUtils::symbol_name(callee_Var->m_v));
std::string func_name;
if( cxx_operator_name.size() > 0 ) {
func_name = cxx_operator_name;
} else {
ASR::Var_t* callee_Var = ASR::down_cast<ASR::Var_t>(callee);
func_name = std::string(ASRUtils::symbol_name(callee_Var->m_v));
}
if (special_function_map.find(func_name) == special_function_map.end()) {
return false;
}
Expand All @@ -404,6 +521,43 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
if (sf == SpecialFunc::Printf) {
tmp = ASR::make_Print_t(al, Lloc(x), args.p, args.size(), nullptr, nullptr);
} else if (sf == SpecialFunc::View) {
ASR::expr_t* array = args.p[0];
size_t rank = ASRUtils::extract_n_dims_from_ttype(ASRUtils::expr_type(array));
Vec<ASR::array_index_t> array_section_indices;
array_section_indices.reserve(al, rank);
size_t i, j, result_dims = 0;
for( i = 0, j = 1; j < args.size(); j++, i++ ) {
ASR::array_index_t index;
index.loc = args.p[j]->base.loc;
index.m_left = nullptr;
index.m_right = args.p[j];
index.m_step = nullptr;
array_section_indices.push_back(al, index);
}
for( ; i < rank; i++ ) {
ASR::array_index_t index;
index.loc = callee->base.loc;
index.m_left = ASRUtils::get_bound<SemanticError>(array, i + 1, "lbound", al);
index.m_right = ASRUtils::get_bound<SemanticError>(array, i + 1, "ubound", al);
index.m_step = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, index.loc, 1,
ASRUtils::TYPE(ASR::make_Integer_t(al, index.loc, 4))));
array_section_indices.push_back(al, index);
result_dims += 1;
}
ASR::ttype_t* element_type = ASRUtils::extract_type(ASRUtils::expr_type(array));
Vec<ASR::dimension_t> empty_dims; empty_dims.reserve(al, result_dims);
for( size_t i = 0; i < result_dims; i++ ) {
ASR::dimension_t dim;
dim.loc = Lloc(x);
dim.m_length = nullptr;
dim.m_start = nullptr;
empty_dims.push_back(al, dim);
}
ASR::ttype_t* array_section_type = ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x),
element_type, empty_dims.p, empty_dims.size(), ASR::array_physical_typeType::DescriptorArray));
tmp = ASR::make_ArraySection_t(al, Lloc(x), array, array_section_indices.p,
array_section_indices.size(), array_section_type, nullptr);
} else if (sf == SpecialFunc::Exit) {
LCOMPILERS_ASSERT(args.size() == 1);
tmp = ASR::make_Stop_t(al, Lloc(x), args[0]);
Expand Down Expand Up @@ -802,7 +956,10 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

bool TraverseDeclRefExpr(clang::DeclRefExpr* x) {
std::string name = x->getNameInfo().getAsString();
if( name == "operator<<" || name == "cout" ) {
if( name == "operator<<" || name == "cout" ||
name == "endl" || name == "operator()" ||
name == "operator+" || name == "operator=" ||
name == "view" ) {
cxx_operator_name = name;
return true;
}
Expand Down

0 comments on commit 3f2b43d

Please sign in to comment.