Skip to content

Commit

Permalink
Merge pull request #114 from czgdp1807/mdspan_02
Browse files Browse the repository at this point in the history
Prohibit modification of `std::vector` objects used by `Kokkos::mdspan` constructor
  • Loading branch information
czgdp1807 authored Mar 20, 2024
2 parents ac76a2a + 166c0c4 commit 292d177
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 5 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ jobs:
if: contains(matrix.os, 'ubuntu') || contains(matrix.os, 'macos')
run: |
export CPATH=$CONDA_PREFIX/include:$CPATH
git clone https://github.com/czgdp1807/mdspan
cd mdspan
cmake . -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
make install
cd ..
lc --show-clang-ast tests/test.cpp
lc --show-clang-ast --parse-all-comments tests/parse_comments_01.cpp > parse_comments_01.stdout
grep "TextComment" parse_comments_01.stdout
Expand All @@ -76,10 +81,5 @@ jobs:
if: contains(matrix.os, 'ubuntu') || contains(matrix.os, 'macos')
run: |
export CPATH=$CONDA_PREFIX/include:$CPATH
git clone https://github.com/czgdp1807/mdspan
cd mdspan
cmake . -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
make install
cd ..
./integration_tests/run_tests.py -b gcc llvm wasm c
./integration_tests/run_tests.py -b gcc llvm wasm c -f
115 changes: 115 additions & 0 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <libasr/pass/intrinsic_array_function_registry.h>
#include <lc/clang_ast_to_asr.h>

#include <set>

namespace LCompilers {

enum SpecialFunc {
Expand Down Expand Up @@ -176,6 +178,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
std::map<SymbolTable*, std::vector<ASR::symbol_t*>> scope2enums;
clang::ForStmt* for_loop;
bool inside_loop;
std::set<ASR::symbol_t*> read_only_symbols;

explicit ClangASTtoASRVisitor(clang::ASTContext *Context_,
Allocator& al_, ASR::asr_t*& tu_):
Expand Down Expand Up @@ -1585,6 +1588,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
ASR::cast_kindType::ListToArray) {
throw std::runtime_error("First argument of Kokkos::mdspan "
"constructor should be a call to std::vector::data() function.");
} else {
mark_as_read_only(list_to_array);
}
if( ASRUtils::extract_n_dims_from_ttype(constructor_type) != x->getNumArgs() - 1 ) {
throw std::runtime_error("Shape provided in the constructor "
Expand Down Expand Up @@ -2139,13 +2144,123 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
}

void mark_as_read_only(ASR::symbol_t* symbol) {
read_only_symbols.insert(ASRUtils::symbol_get_past_external(symbol));
}

void mark_as_read_only(ASR::expr_t* list_to_array, bool root=true) {
ASR::expr_t* expr = list_to_array;
if( root ) {
if( !ASR::is_a<ASR::Cast_t>(*list_to_array) ) {
return ;
}
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(list_to_array);
if( cast_t->m_kind != ASR::cast_kindType::ListToArray ) {
return ;
}

expr = cast_t->m_arg;
}

switch( expr->type ) {
case ASR::exprType::Var: {
ASR::Var_t* var_t = ASR::down_cast<ASR::Var_t>(expr);
return mark_as_read_only(var_t->m_v);
}
case ASR::exprType::ListItem: {
ASR::ListItem_t* list_item_t = ASR::down_cast<ASR::ListItem_t>(expr);
return mark_as_read_only(list_item_t->m_a, false);
}
case ASR::exprType::ListSection: {
ASR::ListSection_t* list_section_t = ASR::down_cast<ASR::ListSection_t>(expr);
return mark_as_read_only(list_section_t->m_a, false);
}
case ASR::exprType::StructInstanceMember: {
ASR::StructInstanceMember_t* struct_instance_member_t = ASR::down_cast<ASR::StructInstanceMember_t>(expr);
return mark_as_read_only(struct_instance_member_t->m_m);
}
case ASR::exprType::StructStaticMember: {
ASR::StructStaticMember_t* struct_static_member_t = ASR::down_cast<ASR::StructStaticMember_t>(expr);
return mark_as_read_only(struct_static_member_t->m_m);
}
case ASR::exprType::EnumStaticMember: {
ASR::EnumStaticMember_t* enum_static_member_t = ASR::down_cast<ASR::EnumStaticMember_t>(expr);
return mark_as_read_only(enum_static_member_t->m_m);
}
case ASR::exprType::UnionInstanceMember: {
ASR::UnionInstanceMember_t* union_instance_member_t = ASR::down_cast<ASR::UnionInstanceMember_t>(expr);
return mark_as_read_only(union_instance_member_t->m_m);
}
default: {
throw std::runtime_error("ASRUtils::exprType::" + std::to_string(expr->type) +
" is not handled yet in is_read_only method.");
}
}
}

bool is_read_only(ASR::symbol_t* symbol, std::string& symbol_name) {
if( read_only_symbols.find(ASRUtils::symbol_get_past_external(symbol)) !=
read_only_symbols.end() ) {
symbol_name = ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(symbol));
return true;
}
return false;
}

bool is_read_only(ASR::expr_t* expr, std::string& symbol_name) {
switch( expr->type ) {
case ASR::exprType::Var: {
ASR::Var_t* var_t = ASR::down_cast<ASR::Var_t>(expr);
return is_read_only(var_t->m_v, symbol_name);
}
case ASR::exprType::ListItem: {
ASR::ListItem_t* list_item_t = ASR::down_cast<ASR::ListItem_t>(expr);
return is_read_only(list_item_t->m_a, symbol_name);
}
case ASR::exprType::ArrayItem: {
ASR::ArrayItem_t* array_item_t = ASR::down_cast<ASR::ArrayItem_t>(expr);
return is_read_only(array_item_t->m_v, symbol_name);
}
case ASR::exprType::ListSection: {
ASR::ListSection_t* list_section_t = ASR::down_cast<ASR::ListSection_t>(expr);
return is_read_only(list_section_t->m_a, symbol_name);
}
case ASR::exprType::StructInstanceMember: {
ASR::StructInstanceMember_t* struct_instance_member_t = ASR::down_cast<ASR::StructInstanceMember_t>(expr);
return is_read_only(struct_instance_member_t->m_m, symbol_name);
}
case ASR::exprType::StructStaticMember: {
ASR::StructStaticMember_t* struct_static_member_t = ASR::down_cast<ASR::StructStaticMember_t>(expr);
return is_read_only(struct_static_member_t->m_m, symbol_name);
}
case ASR::exprType::EnumStaticMember: {
ASR::EnumStaticMember_t* enum_static_member_t = ASR::down_cast<ASR::EnumStaticMember_t>(expr);
return is_read_only(enum_static_member_t->m_m, symbol_name);
}
case ASR::exprType::UnionInstanceMember: {
ASR::UnionInstanceMember_t* union_instance_member_t = ASR::down_cast<ASR::UnionInstanceMember_t>(expr);
return is_read_only(union_instance_member_t->m_m, symbol_name);
}
default: {
throw std::runtime_error("ASRUtils::exprType::" + std::to_string(expr->type) +
" is not handled yet in is_read_only method.");
}
}
return false;
}

bool TraverseBinaryOperator(clang::BinaryOperator *x) {
clang::BinaryOperatorKind op = x->getOpcode();
TraverseStmt(x->getLHS());
ASR::expr_t* x_lhs = ASRUtils::EXPR(tmp.get());
TraverseStmt(x->getRHS());
ASR::expr_t* x_rhs = ASRUtils::EXPR(tmp.get());
if( op == clang::BO_Assign ) {
std::string symbol_name = "";
if( is_read_only(x_lhs, symbol_name) ) {
std::cerr << symbol_name + " is marked as read only." << std::endl;
exit(EXIT_FAILURE);
}
cast_helper(x_lhs, x_rhs, true);
ASRUtils::make_ArrayBroadcast_t_util(al, Lloc(x), x_lhs, x_rhs);
tmp = ASR::make_Assignment_t(al, Lloc(x), x_lhs, x_rhs, nullptr);
Expand Down
13 changes: 13 additions & 0 deletions tests/array_01_mdspan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <vector>
#include <mdspan/mdspan.hpp>

int main() {

std::vector<double> arr1_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};

Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{arr1_data.data(), 3, 3};

arr1_data[0] = 4.0;

return 0;
}
31 changes: 31 additions & 0 deletions tests/array_02_mdspan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include <vector>
#include <mdspan/mdspan.hpp>

struct A {
std::vector<double> arr1_data;
};

int main1() {

std::vector<double> arr2_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
std::vector<double> arr3_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};

Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{arr2_data.data(), 3, 3};

arr3_data[0] = 4.0;
arr2_data[0] = 4.0;

return 0;
}

int main() {

struct A aobj;
aobj.arr1_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};

Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{aobj.arr1_data.data(), 3, 3};

aobj.arr1_data[0] = 4.0;

return 0;
}
13 changes: 13 additions & 0 deletions tests/reference/asr-array_01_mdspan-3a0babe.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-array_01_mdspan-3a0babe",
"cmd": "lc --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/array_01_mdspan.cpp",
"infile_hash": "89f0c42a5270d3e994901c745aeb9d6b6d005eb466f53abc3c6b0684",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-array_01_mdspan-3a0babe.stderr",
"stderr_hash": "75fae36610809dcd01894bb09490f3b174e656cf490f47aa38aa5caf",
"returncode": 1
}
1 change: 1 addition & 0 deletions tests/reference/asr-array_01_mdspan-3a0babe.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
arr1_data is marked as read only.
13 changes: 13 additions & 0 deletions tests/reference/asr-array_02_mdspan-a54a3c5.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-array_02_mdspan-a54a3c5",
"cmd": "lc --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/array_02_mdspan.cpp",
"infile_hash": "0a58aca0631849d0d9f53e91255e409d79d6e8df0db2ce820f0f5d15",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-array_02_mdspan-a54a3c5.stderr",
"stderr_hash": "c335fe2f02049228933225c455954d7fb45ed1e5b318fafe53932c96",
"returncode": 1
}
1 change: 1 addition & 0 deletions tests/reference/asr-array_02_mdspan-a54a3c5.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
arr2_data is marked as read only.
8 changes: 8 additions & 0 deletions tests/tests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ cpp = true
llvm = true
fortran = true

[[test]]
filename = "array_01_mdspan.cpp"
asr = true

[[test]]
filename = "array_02_mdspan.cpp"
asr = true

[[test]]
filename = "../integration_tests/array_01.cpp"
asr = true
Expand Down

0 comments on commit 292d177

Please sign in to comment.