Skip to content

Commit

Permalink
Merge pull request #42 from czgdp1807/lc_24
Browse files Browse the repository at this point in the history
Ported ``integration_tests/arrays_op_29.f90`` from LFortran and improve LC to support it
  • Loading branch information
czgdp1807 authored Dec 28, 2023
2 parents 3aa037f + 5afbd51 commit 5901096
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 9 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,5 @@ RUN(NAME array_11.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_12.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
RUN(NAME array_13.cpp LABELS gcc llvm NOFAST
EXTRA_ARGS --extra-arg=-I${CONDA_PREFIX}/include)
30 changes: 30 additions & 0 deletions integration_tests/array_13.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include <iostream>
#include <cmath>
#include <xtensor/xtensor.hpp>
#include <xtensor/xfixed.hpp>
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

xt::xtensor<double, 1> R3(const xt::xtensor<double, 1>& x) {
xt::xtensor<double, 1> R3_ = xt::empty<double>({x.size()});

double a, b;
const double c = 2.0;
a = x(0);
b = x(1);
R3_.fill(std::pow(a, 2.0) + std::pow(b, 2.0) - std::pow(c, 2.0));

return R3_;
}

int main() {

xt::xtensor_fixed<double, xt::xshape<2>> x = {0.5, 0.5};
std::cout << R3(x) << std::endl;
if( xt::any(xt::not_equal(R3(x), -3.5)) ) {
exit(2);
}

return 0;

}
51 changes: 45 additions & 6 deletions src/lc/clang_ast_to_asr.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ enum SpecialFunc {
AMax,
Sum,
Range,
Size,
Pow,
};

std::map<std::string, SpecialFunc> special_function_map = {
Expand All @@ -62,6 +64,8 @@ std::map<std::string, SpecialFunc> special_function_map = {
{"amax", SpecialFunc::AMax},
{"sum", SpecialFunc::Sum},
{"range", SpecialFunc::Range},
{"size", SpecialFunc::Size},
{"pow", SpecialFunc::Pow},
};

class OneTimeUseString {
Expand Down Expand Up @@ -438,6 +442,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
intent_type = ASR::intentType::In;
type = ASRUtils::type_get_past_const(type);
}
if( ASRUtils::is_allocatable(type) ) {
type = ASRUtils::type_get_past_allocatable(type);
}
if( x->getType()->getTypeClass() != clang::Type::LValueReference &&
ASRUtils::is_array(type) ) {
throw std::runtime_error("Array objects should be passed by reference only.");
Expand Down Expand Up @@ -496,7 +503,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
TraverseStmt(callee);
ASR::expr_t* asr_callee = ASRUtils::EXPR(tmp.get());
if( !check_and_handle_special_function(x, asr_callee) ) {
throw std::runtime_error("Only xt::xtensor::shape, xt::xtensor::fill is supported.");
throw std::runtime_error("Only xt::xtensor::shape, "
"xt::xtensor::fill, xt::xtensor::size is supported.");
}

return true;
Expand Down Expand Up @@ -764,6 +772,14 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
tmp = ASR::make_ArraySize_t(al, Lloc(x), callee, dim,
ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4)),
nullptr);
} else if (sf == SpecialFunc::Size) {
if( args.size() != 0 ) {
throw std::runtime_error("xt::xtensor::size should be called with only one argument.");
}

tmp = ASR::make_ArraySize_t(al, Lloc(x), callee, nullptr,
ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4)),
nullptr);
} else if (sf == SpecialFunc::Fill) {
if( args.size() != 1 ) {
throw std::runtime_error("xt::xtensor::fill should be called with only one argument.");
Expand Down Expand Up @@ -804,6 +820,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
tmp = ASRUtils::make_IntrinsicScalarFunction_t_util(al, Lloc(x),
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::Exp),
args.p, args.size(), 0, ASRUtils::expr_type(args.p[0]), nullptr);
} else if( sf == SpecialFunc::Pow ) {
if( args.size() != 2 ) {
throw std::runtime_error("Pow accepts exactly two arguments.");
}

CreateBinOp(args.p[0], args.p[1], ASR::binopType::Pow, Lloc(x));
} else if( sf == SpecialFunc::Abs ) {
tmp = ASRUtils::make_IntrinsicScalarFunction_t_util(al, Lloc(x),
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::Abs),
Expand Down Expand Up @@ -1215,6 +1237,19 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
}

void CreateUnaryMinus(ASR::expr_t* op, const Location& loc) {
if( ASRUtils::is_integer(*ASRUtils::expr_type(op)) ) {
tmp = ASR::make_IntegerUnaryMinus_t(al, loc, op,
ASRUtils::expr_type(op), nullptr);
} else if( ASRUtils::is_real(*ASRUtils::expr_type(op)) ) {
tmp = ASR::make_RealUnaryMinus_t(al, loc, op,
ASRUtils::expr_type(op), nullptr);
} else {
throw std::runtime_error("Only integer and real types are supported so "
"far for unary operator, found: " + ASRUtils::type_to_str(ASRUtils::expr_type(op)));
}
}

bool TraverseBinaryOperator(clang::BinaryOperator *x) {
clang::BinaryOperatorKind op = x->getOpcode();
TraverseStmt(x->getLHS());
Expand Down Expand Up @@ -1311,7 +1346,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
name == "exit" || name == "printf" || name == "exp" ||
name == "sum" || name == "amax" || name == "abs" ||
name == "operator-" || name == "operator/" || name == "operator>" ||
name == "range" ) {
name == "range" || name == "pow" ) {
if( sym != nullptr && ASR::is_a<ASR::Function_t>(
*ASRUtils::symbol_get_past_external(sym)) ) {
throw std::runtime_error("Special function " + name + " cannot be overshadowed yet.");
Expand Down Expand Up @@ -1431,11 +1466,11 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

bool TraverseUnaryOperator(clang::UnaryOperator* x) {
clang::UnaryOperatorKind op = x->getOpcode();
clang::Expr* expr = x->getSubExpr();
TraverseStmt(expr);
ASR::expr_t* var = ASRUtils::EXPR(tmp.get());
switch( op ) {
case clang::UnaryOperatorKind::UO_PostInc: {
clang::Expr* expr = x->getSubExpr();
TraverseStmt(expr);
ASR::expr_t* var = ASRUtils::EXPR(tmp.get());
ASR::expr_t* incbyone = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(
al, Lloc(x), var, ASR::binopType::Add,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(
Expand All @@ -1445,8 +1480,12 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
is_stmt_created = true;
break;
}
case clang::UnaryOperatorKind::UO_Minus: {
CreateUnaryMinus(var, Lloc(x));
break;
}
default: {
throw std::runtime_error("Only postfix increment is supported so far");
throw std::runtime_error("Only postfix increment and minus are supported so far.");
}
}
return true;
Expand Down
6 changes: 3 additions & 3 deletions src/libasr/codegen/llvm_array_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ namespace LCompilers {
int n_args, bool check_for_bounds, bool is_unbounded_pointer_to_data) {
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
llvm::Value* idx = llvm::ConstantInt::get(context, llvm::APInt(32, 0));
for( int r = 0, r1 = 0; r < n_args; r++ ) {
for( int r = n_args - 1, r1 = 2*n_args - 2; r >= 0; r-- ) {
llvm::Value* curr_llvm_idx = m_args[r];
llvm::Value* lval = llvm_diminfo[r1];
curr_llvm_idx = builder->CreateSub(curr_llvm_idx, lval);
Expand All @@ -611,10 +611,10 @@ namespace LCompilers {
}
idx = builder->CreateAdd(idx, builder->CreateMul(prod, curr_llvm_idx));
if (is_unbounded_pointer_to_data) {
r1 += 1;
r1 -= 1;
} else {
llvm::Value* dim_size = llvm_diminfo[r1 + 1];
r1 += 2;
r1 -= 2;
prod = builder->CreateMul(prod, dim_size);
}
}
Expand Down

0 comments on commit 5901096

Please sign in to comment.