Skip to content

Commit 292d177

Browse files
authored
Merge pull request #114 from czgdp1807/mdspan_02
Prohibit modification of `std::vector` objects used by `Kokkos::mdspan` constructor
2 parents ac76a2a + 166c0c4 commit 292d177

9 files changed

+200
-5
lines changed

.github/workflows/CI.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ jobs:
5151
if: contains(matrix.os, 'ubuntu') || contains(matrix.os, 'macos')
5252
run: |
5353
export CPATH=$CONDA_PREFIX/include:$CPATH
54+
git clone https://github.com/czgdp1807/mdspan
55+
cd mdspan
56+
cmake . -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
57+
make install
58+
cd ..
5459
lc --show-clang-ast tests/test.cpp
5560
lc --show-clang-ast --parse-all-comments tests/parse_comments_01.cpp > parse_comments_01.stdout
5661
grep "TextComment" parse_comments_01.stdout
@@ -76,10 +81,5 @@ jobs:
7681
if: contains(matrix.os, 'ubuntu') || contains(matrix.os, 'macos')
7782
run: |
7883
export CPATH=$CONDA_PREFIX/include:$CPATH
79-
git clone https://github.com/czgdp1807/mdspan
80-
cd mdspan
81-
cmake . -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
82-
make install
83-
cd ..
8484
./integration_tests/run_tests.py -b gcc llvm wasm c
8585
./integration_tests/run_tests.py -b gcc llvm wasm c -f

src/lc/clang_ast_to_asr.cpp

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <libasr/pass/intrinsic_array_function_registry.h>
1313
#include <lc/clang_ast_to_asr.h>
1414

15+
#include <set>
16+
1517
namespace LCompilers {
1618

1719
enum SpecialFunc {
@@ -176,6 +178,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
176178
std::map<SymbolTable*, std::vector<ASR::symbol_t*>> scope2enums;
177179
clang::ForStmt* for_loop;
178180
bool inside_loop;
181+
std::set<ASR::symbol_t*> read_only_symbols;
179182

180183
explicit ClangASTtoASRVisitor(clang::ASTContext *Context_,
181184
Allocator& al_, ASR::asr_t*& tu_):
@@ -1585,6 +1588,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
15851588
ASR::cast_kindType::ListToArray) {
15861589
throw std::runtime_error("First argument of Kokkos::mdspan "
15871590
"constructor should be a call to std::vector::data() function.");
1591+
} else {
1592+
mark_as_read_only(list_to_array);
15881593
}
15891594
if( ASRUtils::extract_n_dims_from_ttype(constructor_type) != x->getNumArgs() - 1 ) {
15901595
throw std::runtime_error("Shape provided in the constructor "
@@ -2139,13 +2144,123 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
21392144
}
21402145
}
21412146

2147+
void mark_as_read_only(ASR::symbol_t* symbol) {
2148+
read_only_symbols.insert(ASRUtils::symbol_get_past_external(symbol));
2149+
}
2150+
2151+
void mark_as_read_only(ASR::expr_t* list_to_array, bool root=true) {
2152+
ASR::expr_t* expr = list_to_array;
2153+
if( root ) {
2154+
if( !ASR::is_a<ASR::Cast_t>(*list_to_array) ) {
2155+
return ;
2156+
}
2157+
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(list_to_array);
2158+
if( cast_t->m_kind != ASR::cast_kindType::ListToArray ) {
2159+
return ;
2160+
}
2161+
2162+
expr = cast_t->m_arg;
2163+
}
2164+
2165+
switch( expr->type ) {
2166+
case ASR::exprType::Var: {
2167+
ASR::Var_t* var_t = ASR::down_cast<ASR::Var_t>(expr);
2168+
return mark_as_read_only(var_t->m_v);
2169+
}
2170+
case ASR::exprType::ListItem: {
2171+
ASR::ListItem_t* list_item_t = ASR::down_cast<ASR::ListItem_t>(expr);
2172+
return mark_as_read_only(list_item_t->m_a, false);
2173+
}
2174+
case ASR::exprType::ListSection: {
2175+
ASR::ListSection_t* list_section_t = ASR::down_cast<ASR::ListSection_t>(expr);
2176+
return mark_as_read_only(list_section_t->m_a, false);
2177+
}
2178+
case ASR::exprType::StructInstanceMember: {
2179+
ASR::StructInstanceMember_t* struct_instance_member_t = ASR::down_cast<ASR::StructInstanceMember_t>(expr);
2180+
return mark_as_read_only(struct_instance_member_t->m_m);
2181+
}
2182+
case ASR::exprType::StructStaticMember: {
2183+
ASR::StructStaticMember_t* struct_static_member_t = ASR::down_cast<ASR::StructStaticMember_t>(expr);
2184+
return mark_as_read_only(struct_static_member_t->m_m);
2185+
}
2186+
case ASR::exprType::EnumStaticMember: {
2187+
ASR::EnumStaticMember_t* enum_static_member_t = ASR::down_cast<ASR::EnumStaticMember_t>(expr);
2188+
return mark_as_read_only(enum_static_member_t->m_m);
2189+
}
2190+
case ASR::exprType::UnionInstanceMember: {
2191+
ASR::UnionInstanceMember_t* union_instance_member_t = ASR::down_cast<ASR::UnionInstanceMember_t>(expr);
2192+
return mark_as_read_only(union_instance_member_t->m_m);
2193+
}
2194+
default: {
2195+
throw std::runtime_error("ASRUtils::exprType::" + std::to_string(expr->type) +
2196+
" is not handled yet in is_read_only method.");
2197+
}
2198+
}
2199+
}
2200+
2201+
bool is_read_only(ASR::symbol_t* symbol, std::string& symbol_name) {
2202+
if( read_only_symbols.find(ASRUtils::symbol_get_past_external(symbol)) !=
2203+
read_only_symbols.end() ) {
2204+
symbol_name = ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(symbol));
2205+
return true;
2206+
}
2207+
return false;
2208+
}
2209+
2210+
bool is_read_only(ASR::expr_t* expr, std::string& symbol_name) {
2211+
switch( expr->type ) {
2212+
case ASR::exprType::Var: {
2213+
ASR::Var_t* var_t = ASR::down_cast<ASR::Var_t>(expr);
2214+
return is_read_only(var_t->m_v, symbol_name);
2215+
}
2216+
case ASR::exprType::ListItem: {
2217+
ASR::ListItem_t* list_item_t = ASR::down_cast<ASR::ListItem_t>(expr);
2218+
return is_read_only(list_item_t->m_a, symbol_name);
2219+
}
2220+
case ASR::exprType::ArrayItem: {
2221+
ASR::ArrayItem_t* array_item_t = ASR::down_cast<ASR::ArrayItem_t>(expr);
2222+
return is_read_only(array_item_t->m_v, symbol_name);
2223+
}
2224+
case ASR::exprType::ListSection: {
2225+
ASR::ListSection_t* list_section_t = ASR::down_cast<ASR::ListSection_t>(expr);
2226+
return is_read_only(list_section_t->m_a, symbol_name);
2227+
}
2228+
case ASR::exprType::StructInstanceMember: {
2229+
ASR::StructInstanceMember_t* struct_instance_member_t = ASR::down_cast<ASR::StructInstanceMember_t>(expr);
2230+
return is_read_only(struct_instance_member_t->m_m, symbol_name);
2231+
}
2232+
case ASR::exprType::StructStaticMember: {
2233+
ASR::StructStaticMember_t* struct_static_member_t = ASR::down_cast<ASR::StructStaticMember_t>(expr);
2234+
return is_read_only(struct_static_member_t->m_m, symbol_name);
2235+
}
2236+
case ASR::exprType::EnumStaticMember: {
2237+
ASR::EnumStaticMember_t* enum_static_member_t = ASR::down_cast<ASR::EnumStaticMember_t>(expr);
2238+
return is_read_only(enum_static_member_t->m_m, symbol_name);
2239+
}
2240+
case ASR::exprType::UnionInstanceMember: {
2241+
ASR::UnionInstanceMember_t* union_instance_member_t = ASR::down_cast<ASR::UnionInstanceMember_t>(expr);
2242+
return is_read_only(union_instance_member_t->m_m, symbol_name);
2243+
}
2244+
default: {
2245+
throw std::runtime_error("ASRUtils::exprType::" + std::to_string(expr->type) +
2246+
" is not handled yet in is_read_only method.");
2247+
}
2248+
}
2249+
return false;
2250+
}
2251+
21422252
bool TraverseBinaryOperator(clang::BinaryOperator *x) {
21432253
clang::BinaryOperatorKind op = x->getOpcode();
21442254
TraverseStmt(x->getLHS());
21452255
ASR::expr_t* x_lhs = ASRUtils::EXPR(tmp.get());
21462256
TraverseStmt(x->getRHS());
21472257
ASR::expr_t* x_rhs = ASRUtils::EXPR(tmp.get());
21482258
if( op == clang::BO_Assign ) {
2259+
std::string symbol_name = "";
2260+
if( is_read_only(x_lhs, symbol_name) ) {
2261+
std::cerr << symbol_name + " is marked as read only." << std::endl;
2262+
exit(EXIT_FAILURE);
2263+
}
21492264
cast_helper(x_lhs, x_rhs, true);
21502265
ASRUtils::make_ArrayBroadcast_t_util(al, Lloc(x), x_lhs, x_rhs);
21512266
tmp = ASR::make_Assignment_t(al, Lloc(x), x_lhs, x_rhs, nullptr);

tests/array_01_mdspan.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include <vector>
2+
#include <mdspan/mdspan.hpp>
3+
4+
int main() {
5+
6+
std::vector<double> arr1_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
7+
8+
Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{arr1_data.data(), 3, 3};
9+
10+
arr1_data[0] = 4.0;
11+
12+
return 0;
13+
}

tests/array_02_mdspan.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include <vector>
2+
#include <mdspan/mdspan.hpp>
3+
4+
struct A {
5+
std::vector<double> arr1_data;
6+
};
7+
8+
int main1() {
9+
10+
std::vector<double> arr2_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
11+
std::vector<double> arr3_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
12+
13+
Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{arr2_data.data(), 3, 3};
14+
15+
arr3_data[0] = 4.0;
16+
arr2_data[0] = 4.0;
17+
18+
return 0;
19+
}
20+
21+
int main() {
22+
23+
struct A aobj;
24+
aobj.arr1_data = {1.0, 2.0, 3.0, 2.0, 5.0, 7.0, 2.0, 5.0, 7.0};
25+
26+
Kokkos::mdspan<double, Kokkos::dextents<int32_t, 2>> arr1{aobj.arr1_data.data(), 3, 3};
27+
28+
aobj.arr1_data[0] = 4.0;
29+
30+
return 0;
31+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-array_01_mdspan-3a0babe",
3+
"cmd": "lc --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/array_01_mdspan.cpp",
5+
"infile_hash": "89f0c42a5270d3e994901c745aeb9d6b6d005eb466f53abc3c6b0684",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-array_01_mdspan-3a0babe.stderr",
11+
"stderr_hash": "75fae36610809dcd01894bb09490f3b174e656cf490f47aa38aa5caf",
12+
"returncode": 1
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
arr1_data is marked as read only.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-array_02_mdspan-a54a3c5",
3+
"cmd": "lc --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/array_02_mdspan.cpp",
5+
"infile_hash": "0a58aca0631849d0d9f53e91255e409d79d6e8df0db2ce820f0f5d15",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-array_02_mdspan-a54a3c5.stderr",
11+
"stderr_hash": "c335fe2f02049228933225c455954d7fb45ed1e5b318fafe53932c96",
12+
"returncode": 1
13+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
arr2_data is marked as read only.

tests/tests.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ cpp = true
2222
llvm = true
2323
fortran = true
2424

25+
[[test]]
26+
filename = "array_01_mdspan.cpp"
27+
asr = true
28+
29+
[[test]]
30+
filename = "array_02_mdspan.cpp"
31+
asr = true
32+
2533
[[test]]
2634
filename = "../integration_tests/array_01.cpp"
2735
asr = true

0 commit comments

Comments
 (0)