Skip to content

Commit 3ba29f3

Browse files
JackAKirkAlcpzjoeatodd
authored
[SYCL][COMPAT][cuda] Add "ptr_to_integer" syclcompat functions. (#14283)
Add "ptr_to_integer" (generic address space to .shared) syclcompat functions. These functions are commonly required in optimized libraries that use inline ptx. The standard naming convention of removing "__" from corresponding cuda builtins has been applied. See the readme and accompanying test-e2e for example usage. --------- Signed-off-by: JackAKirk <[email protected]> Co-authored-by: Alberto Cabrera Pérez <[email protected]> Co-authored-by: Joe Todd <[email protected]>
1 parent 3a1c3cb commit 3ba29f3

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,41 @@ public:
855855
} // syclcompat
856856
```
857857
858+
### ptr_to_int
859+
860+
The following CUDA backend-specific function is introduced in order to
861+
translate from local memory pointers to `uint32_t` or `size_t` variables that
862+
contain a byte address to the local (local refers to `.shared` in nvptx) memory
863+
state space.
864+
865+
``` c++
866+
namespace syclcompat {
867+
template <typename T>
868+
__syclcompat_inline__
869+
std::enable_if_t<std::is_same_v<T, uint32_t> || std::is_same_v<T, size_t>,
870+
T>
871+
ptr_to_int(void *ptr);
872+
} // namespace syclcompat
873+
```
874+
875+
These variables can be used in inline PTX instructions that take address
876+
operands. Such inline PTX instructions are commonly used in optimized
877+
libraries. A simplified example usage of the above functions is as follows:
878+
879+
``` c++
880+
half *data = syclcompat::local_mem<half[NUM_ELEMENTS]>();
881+
// ...
882+
// ...
883+
T addr =
884+
syclcompat::ptr_to_int<T>(reinterpret_cast<char *>(data) + (id % 8) * 16);
885+
uint32_t fragment;
886+
#if defined(__NVPTX__)
887+
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
888+
: "=r"(fragment)
889+
: "r"(addr));
890+
#endif
891+
```
892+
858893
### Device Information
859894

860895
`sycl::device` properties are encapsulated using the `device_info` helper class.

sycl/include/syclcompat/memory.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
#include <syclcompat/device.hpp>
5555
#include <syclcompat/traits.hpp>
56+
#include <syclcompat/defs.hpp>
5657

5758
#if defined(__linux__)
5859
#include <sys/mman.h>
@@ -86,6 +87,23 @@ enum memcpy_direction {
8687
};
8788
}
8889

90+
// Converts a pointer into local memory (the ".shared" state space in NVPTX)
// to an integer byte address of type T, where T is uint32_t or size_t. The
// result is intended to be used as an address operand in inline PTX.
// Only meaningful in device code compiled for the CUDA (NVPTX) backend; on
// any other target this throws a sycl::exception at runtime.
template <typename T>
91+
__syclcompat_inline__
92+
std::enable_if_t<std::is_same_v<T, uint32_t> || std::is_same_v<T, size_t>,
93+
T>
94+
ptr_to_int(void *ptr) {
95+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
96+
if constexpr (std::is_same_v<T, uint32_t>) {
97+
// Decorate the pointer with the local address space, then narrow the
// address from intptr_t to the 32-bit return type. The narrowing is
// presumably intentional (shared-memory addresses fit in 32 bits on
// NVPTX) — TODO confirm against the CUDA __cvta_generic_to_shared
// convention this mirrors.
return (intptr_t)(sycl::decorated_local_ptr<const void>::pointer)ptr;
98+
} else {
99+
return (size_t)(sycl::decorated_local_ptr<const void>::pointer)ptr;
100+
}
101+
#else
102+
// Host / non-NVPTX compilation path: fail loudly rather than return a
// bogus address.
throw sycl::exception(make_error_code(sycl::errc::runtime),
103+
"ptr_to_int is only supported on Nvidia devices.");
104+
#endif
105+
}
106+
89107
enum class memory_region {
90108
global = 0, // device global memory
91109
constant, // device read-only memory
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// REQUIRES: cuda
2+
// RUN: %{build} -Xsycl-target-backend --cuda-gpu-arch=sm_75 -o %t.out
3+
// RUN: %{run} %t.out
4+
#include <sycl/detail/core.hpp>
5+
#include <sycl/group_barrier.hpp>
6+
#include <syclcompat/memory.hpp>
7+
8+
using namespace sycl;
9+
#define NUM_ELEMENTS 64
10+
11+
// E2E check for syclcompat::ptr_to_int<T>: each of 32 work-items writes two
// half values (id, id + 0.5) into local memory, converts a local pointer to
// an integer address of type T, feeds that address to an inline-PTX
// ldmatrix load, and copies the loaded fragment back to shared USM so the
// host can verify the round trip.
template <class T> void test(queue stream) {
12+
half *res = malloc_shared<half>(NUM_ELEMENTS, stream);
13+
14+
// Pre-fill with a sentinel value so a kernel that silently does nothing
// would fail the assertions below.
for (int i = 0; i < NUM_ELEMENTS; ++i) {
15+
res[i] = 0.5;
16+
}
17+
18+
// Single work-group of 32 work-items (one full warp on the CUDA backend).
sycl::nd_range<1> global_range{sycl::range{32}, sycl::range{32}};
19+
20+
stream
21+
.submit([&](handler &h) {
22+
h.parallel_for<T>(global_range, [=](nd_item<1> item) {
23+
sycl::group work_group = item.get_group();
24+
int id = item.get_global_linear_id();
25+
half *data = syclcompat::local_mem<half[NUM_ELEMENTS]>();
26+
27+
// Each work-item publishes two consecutive halves into local memory.
data[id * 2] = id;
28+
data[id * 2 + 1] = id + 0.5;
29+
30+
// (id % 8) * 16: each of 8 row addresses is 16 bytes apart,
// presumably matching the m8n8 .b16 ldmatrix operand layout —
// TODO confirm against the PTX ISA ldmatrix description.
T addr =
31+
syclcompat::ptr_to_int<T>(reinterpret_cast<char *>(data) + (id % 8) * 16);
32+
33+
uint32_t fragment;
34+
// Guarded so the source still parses when not compiling for NVPTX.
#if defined(__NVPTX__)
35+
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
36+
: "=r"(fragment)
37+
: "r"(addr));
38+
#endif
39+
sycl::group_barrier(work_group);
40+
41+
// The 32-bit fragment packs two half values; unpack and write them
// back to shared USM for host-side checking.
half *data_ptr = reinterpret_cast<half *>(&fragment);
42+
res[id * 2] = data_ptr[0];
43+
res[id * 2 + 1] = data_ptr[1];
44+
});
45+
})
46+
.wait();
47+
48+
// Expected pattern: res = {0, 0.5, 1, 1.5, ...}, i.e. res[i] == i / 2.0.
for (int i = 0; i < NUM_ELEMENTS; i++) {
49+
assert(res[i] == static_cast<half>(i / 2.0));
50+
}
51+
52+
free(res, stream);
53+
};
54+
55+
// Runs the ptr_to_int round-trip test for both supported address types
// (size_t and uint32_t) on an in-order queue, printing PASS on success.
int main() {
56+
57+
queue stream{property::queue::in_order{}};
58+
test<size_t>(stream);
59+
test<uint32_t>(stream);
60+
61+
std::cout << "PASS" << std::endl;
62+
return 0;
63+
}

0 commit comments

Comments
 (0)