diff --git a/sycl/include/sycl/khr/group_interface.hpp b/sycl/include/sycl/khr/group_interface.hpp
new file mode 100644
index 0000000000000..3bbe38ebb88c3
--- /dev/null
+++ b/sycl/include/sycl/khr/group_interface.hpp
@@ -0,0 +1,271 @@
+//==----- group_interface.hpp --- sycl_khr_group_interface extension -------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#pragma once
+#ifdef __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
+
+#include <sycl/ext/oneapi/free_function_queries.hpp>
+#include <sycl/id.hpp>
+#include <sycl/range.hpp>
+
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+
+#if __cplusplus >= 202302L && defined(__has_include)
+#if __has_include(<mdspan>)
+#include <mdspan>
+#endif
+#endif
+
+namespace sycl {
+inline namespace _V1 {
+
+namespace khr {
+
+// Forward declarations for traits.
+template <int Dimensions> class work_group;
+class sub_group;
+template <typename ParentGroup> class work_item;
+
+} // namespace khr
+
+namespace detail {
+#if defined(__cpp_lib_mdspan)
+// Maps a dimensionality to a std::extents type whose every static extent is
+// 1; khr::work_item uses it to describe a single work-item. Only compiled
+// when __cpp_lib_mdspan is defined.
+template <int Dimensions> struct single_extents;
+
+template <> struct single_extents<1> {
+  using type = std::extents<size_t, 1>;
+};
+
+template <> struct single_extents<2> {
+  using type = std::extents<size_t, 1, 1>;
+};
+
+template <> struct single_extents<3> {
+  using type = std::extents<size_t, 1, 1, 1>;
+};
+#endif
+
+// Trait that is true only for khr::work_group<N> and khr::sub_group; every
+// other type yields false. khr::get_item is constrained on it.
+template <typename T> struct is_khr_group : public std::false_type {};
+
+template <int Dimensions>
+struct is_khr_group<khr::work_group<Dimensions>> : public std::true_type {};
+
+template <> struct is_khr_group<khr::sub_group> : public std::true_type {};
+
+} // namespace detail
+
+namespace khr {
+
+// Forward declaration for friend function.
+template <typename ParentGroup>
+std::enable_if_t<detail::is_khr_group<ParentGroup>::value,
+                 work_item<ParentGroup>>
+get_item(ParentGroup g);
+
+/// Wrapper exposing the work-group of the calling work-item through the
+/// sycl_khr_group_interface API. It stores no state: every query is forwarded
+/// to the group returned by ext::oneapi::this_work_item::get_work_group, so the
+/// group passed to the constructor is not retained.
+template <int Dimensions = 1> class work_group {
+  static_assert(Dimensions >= 1 && Dimensions <= 3,
+                "khr::work_group supports 1, 2 or 3 dimensions");
+
+public:
+  // id/range are qualified because the members id() and range() declared
+  // below would otherwise change the meaning of these names in class scope.
+  using id_type = sycl::id<Dimensions>;
+  using linear_id_type = size_t;
+  using range_type = sycl::range<Dimensions>;
+#if defined(__cpp_lib_mdspan)
+  using extents_type = std::dextents<size_t, Dimensions>;
+#endif
+  using size_type = size_t;
+  static constexpr int dimensions = Dimensions;
+  static constexpr memory_scope fence_scope = memory_scope::work_group;
+
+  // Implicit by design so that a sycl::group converts to its wrapper.
+  work_group(group<Dimensions> /*g*/) noexcept {}
+
+  operator group<Dimensions>() const noexcept { return legacy(); }
+
+  // Forwards to group::get_group_id().
+  id_type id() const noexcept { return legacy().get_group_id(); }
+
+  // Forwards to group::get_group_linear_id().
+  linear_id_type linear_id() const noexcept {
+    return legacy().get_group_linear_id();
+  }
+
+  // Forwards to group::get_group_range().
+  range_type range() const noexcept { return legacy().get_group_range(); }
+
+#if defined(__cpp_lib_mdspan)
+  // Builds dynamic extents from group::get_local_range(), one per dimension.
+  constexpr extents_type extents() const noexcept {
+    const auto LocalRange = legacy().get_local_range();
+    if constexpr (dimensions == 1) {
+      return extents_type(LocalRange[0]);
+    } else if constexpr (dimensions == 2) {
+      return extents_type(LocalRange[0], LocalRange[1]);
+    } else {
+      return extents_type(LocalRange[0], LocalRange[1], LocalRange[2]);
+    }
+  }
+
+  // index_type and rank_type are members of extents_type, not of this class.
+  constexpr typename extents_type::index_type
+  extent(typename extents_type::rank_type r) const noexcept {
+    return extents().extent(r);
+  }
+#endif
+
+  // Total element count of group::get_local_range().
+  constexpr size_type size() const noexcept {
+    return legacy().get_local_range().size();
+  }
+
+private:
+  // Core SYCL group of the calling work-item, fetched on every query.
+  group<Dimensions> legacy() const noexcept {
+    return ext::oneapi::this_work_item::get_work_group<Dimensions>();
+  }
+};
+
+/// Wrapper exposing the sub-group of the calling work-item through the
+/// sycl_khr_group_interface API. Like work_group it is stateless and forwards
+/// every query to ext::oneapi::this_work_item::get_sub_group.
+class sub_group {
+public:
+  using id_type = sycl::id<1>;
+  using linear_id_type = uint32_t;
+  using range_type = sycl::range<1>;
+#if defined(__cpp_lib_mdspan)
+  using extents_type = std::dextents<uint32_t, 1>;
+#endif
+  using size_type = uint32_t;
+  static constexpr int dimensions = 1;
+  static constexpr memory_scope fence_scope = memory_scope::sub_group;
+
+  // Implicit by design so that a sycl::sub_group converts to its wrapper.
+  sub_group(sycl::sub_group /*g*/) noexcept {}
+
+  operator sycl::sub_group() const noexcept { return legacy(); }
+
+  id_type id() const noexcept { return legacy().get_group_id(); }
+
+  linear_id_type linear_id() const noexcept {
+    return legacy().get_group_linear_id();
+  }
+
+  range_type range() const noexcept { return legacy().get_group_range(); }
+
+#if defined(__cpp_lib_mdspan)
+  constexpr extents_type extents() const noexcept {
+    return extents_type(legacy().get_local_range()[0]);
+  }
+
+  constexpr extents_type::index_type
+  extent(extents_type::rank_type r) const noexcept {
+    return extents().extent(r);
+  }
+#endif
+
+  constexpr size_type size() const noexcept {
+    return legacy().get_local_range()[0];
+  }
+
+  constexpr size_type max_size() const noexcept {
+    return legacy().get_max_local_range()[0];
+  }
+
+private:
+  sycl::sub_group legacy() const noexcept {
+    return ext::oneapi::this_work_item::get_sub_group();
+  }
+};
+
+/// A single work-item viewed as a member of ParentGroup (khr::work_group or
+/// khr::sub_group). Instances are obtained from get_item(); the constructor is
+/// protected and get_item is befriended for that purpose.
+template <typename ParentGroup> class work_item {
+public:
+  using id_type = typename ParentGroup::id_type;
+  using linear_id_type = typename ParentGroup::linear_id_type;
+  using range_type = typename ParentGroup::range_type;
+#if defined(__cpp_lib_mdspan)
+  // single_extents is a trait; the extents type itself is its nested ::type.
+  using extents_type =
+      typename detail::single_extents<ParentGroup::dimensions>::type;
+#endif
+  using size_type = typename ParentGroup::size_type;
+  static constexpr int dimensions = ParentGroup::dimensions;
+  static constexpr memory_scope fence_scope = memory_scope::work_item;
+
+  // The queries below forward to the get_local_* members of the core group.
+  id_type id() const noexcept { return legacy().get_local_id(); }
+
+  linear_id_type linear_id() const noexcept {
+    return legacy().get_local_linear_id();
+  }
+
+  range_type range() const noexcept { return legacy().get_local_range(); }
+
+#if defined(__cpp_lib_mdspan)
+  constexpr extents_type extents() const noexcept { return extents_type(); }
+
+  constexpr typename extents_type::index_type
+  extent(typename extents_type::rank_type r) const noexcept {
+    return extents().extent(r);
+  }
+#endif
+
+  constexpr size_type size() const noexcept { return 1; }
+
+private:
+  // Core SYCL group object matching ParentGroup for the calling work-item.
+  auto legacy() const noexcept {
+    if constexpr (std::is_same_v<ParentGroup, sub_group>) {
+      return ext::oneapi::this_work_item::get_sub_group();
+    } else {
+      return ext::oneapi::this_work_item::get_work_group<
+          ParentGroup::dimensions>();
+    }
+  }
+
+protected:
+  work_item() {}
+
+  friend work_item<ParentGroup> get_item<ParentGroup>(ParentGroup);
+};
+
+/// Returns a work_item for the calling work-item viewed inside ParentGroup.
+/// Participates in overload resolution only for the khr group wrappers; the
+/// argument is used solely to select the group type.
+template <typename ParentGroup>
+std::enable_if_t<detail::is_khr_group<ParentGroup>::value,
+                 work_item<ParentGroup>>
+get_item(ParentGroup /*g*/) {
+  return work_item<ParentGroup>{};
+}
+
+/// True for the work-item whose linear id within the group is 0.
+template <typename Group> bool leader_of(Group g) {
+  return get_item(g).linear_id() == 0;
+}
+
+} // namespace khr
+} // namespace _V1
+} // namespace sycl
+
+#endif // __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
diff --git a/sycl/test-e2e/GroupInterface/leader_of.cpp b/sycl/test-e2e/GroupInterface/leader_of.cpp
new file mode 100644
index 0000000000000..638b5f21bd4e9
--- /dev/null
+++ b/sycl/test-e2e/GroupInterface/leader_of.cpp
@@ -0,0 +1,46 @@
+// RUN: %{build} -o %t.out
+// RUN: %{run} %t.out
+
+// Checks that khr::leader_of is true for exactly one work-item of a
+// work-group.
+
+#include <cassert>
+#include <cstddef>
+#include <iostream>
+#include <sycl/detail/core.hpp>
+#define __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
+#include <sycl/khr/group_interface.hpp>
+
+using namespace sycl;
+
+// Launches a single GxG work-group in which only the leader increments the
+// counter, so the counter must end up at exactly 1.
+void test(queue q) {
+  int out = 0;
+  size_t G = 4;
+
+  range<2> R(G, G);
+  {
+    buffer<int> out_buf(&out, 1);
+
+    q.submit([&](handler &cgh) {
+      // Named differently from the host variable it would otherwise shadow.
+      auto out_acc = out_buf.get_access<access::mode::read_write>(cgh);
+      cgh.parallel_for(nd_range<2>(R, R), [=](nd_item<2> it) {
+        khr::work_group<2> g = it.get_group();
+        if (khr::leader_of(g)) {
+          out_acc[0] += 1;
+        }
+      });
+    });
+  }
+  assert(out == 1);
+}
+
+int main() {
+  queue q;
+  test(q);
+
+  std::cout << "Test passed." << std::endl;
+  return 0;
+}
diff --git a/sycl/test-e2e/GroupInterface/sub_group.cpp b/sycl/test-e2e/GroupInterface/sub_group.cpp
new file mode 100644
index 0000000000000..828a0e5c48d48
--- /dev/null
+++ b/sycl/test-e2e/GroupInterface/sub_group.cpp
@@ -0,0 +1,72 @@
+// REQUIRES: cpu
+
+// RUN: %{build} %cxx_std_optionc++23 -o %t.out
+// RUN: %{run} %t.out
+
+#include <sycl/detail/core.hpp>
+#define __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
+#include <sycl/khr/group_interface.hpp>
+
+#include <cassert>
+#include <cstdint>
+#include <type_traits>
+
+using namespace sycl;
+
+static_assert(std::is_same_v<khr::sub_group::id_type, id<1>>);
+static_assert(std::is_same_v<khr::sub_group::linear_id_type, uint32_t>);
+static_assert(std::is_same_v<khr::sub_group::range_type, range<1>>);
+#if defined(__cpp_lib_mdspan)
+static_assert(std::is_same_v<khr::sub_group::extents_type,
+                             std::dextents<uint32_t, 1>>);
+#endif
+static_assert(std::is_same_v<khr::sub_group::size_type, uint32_t>);
+static_assert(khr::sub_group::dimensions == 1);
+static_assert(khr::sub_group::fence_scope == memory_scope::sub_group);
+
+// Device-side check shared by the 1-, 2- and 3-dimensional launches: every
+// khr::sub_group and khr::work_item query must agree with the corresponding
+// query on the core sycl::sub_group.
+template <int Dims> void check_sub_group(const nd_item<Dims> &item) {
+  sycl::sub_group g = item.get_sub_group();
+
+  khr::sub_group sg = g;
+  assert(sg.id() == g.get_group_id());
+  assert(sg.linear_id() == g.get_group_linear_id());
+  assert(sg.range() == g.get_group_range());
+#if defined(__cpp_lib_mdspan)
+  assert(sg.extents().rank() == 1);
+  assert(sg.extent(0) == g.get_local_range()[0]);
+#endif
+  assert(sg.size() == g.get_local_linear_range());
+  assert(sg.max_size() == g.get_max_local_range()[0]);
+
+  khr::work_item wi = khr::get_item(sg);
+  assert(wi.id() == g.get_local_id());
+  assert(wi.linear_id() == g.get_local_linear_id());
+  assert(wi.range() == g.get_local_range());
+#if defined(__cpp_lib_mdspan)
+  assert(wi.extents().rank() == 1);
+  assert(wi.extent(0) == 1);
+#endif
+  assert(wi.size() == 1);
+}
+
+int main() {
+  queue q(cpu_selector_v);
+
+  const int sz = 16;
+  q.submit([&](handler &h) {
+    h.parallel_for(nd_range<1>{sz, sz},
+                   [=](nd_item<1> item) { check_sub_group(item); });
+  });
+  q.submit([&](handler &h) {
+    h.parallel_for(nd_range<2>{range<2>{sz, sz}, range<2>{sz, sz}},
+                   [=](nd_item<2> item) { check_sub_group(item); });
+  });
+  q.submit([&](handler &h) {
+    h.parallel_for(nd_range<3>{range<3>{sz, sz, sz}, range<3>{sz, sz, sz}},
+                   [=](nd_item<3> item) { check_sub_group(item); });
+  });
+  q.wait();
+}
diff --git a/sycl/test-e2e/GroupInterface/work_group.cpp b/sycl/test-e2e/GroupInterface/work_group.cpp
new file mode 100644
index 0000000000000..a0580d19966df
--- /dev/null
+++ b/sycl/test-e2e/GroupInterface/work_group.cpp
@@ -0,0 +1,97 @@
+// REQUIRES: cpu
+
+// RUN: %{build} %cxx_std_optionc++23 -o %t.out
+// RUN: %{run} %t.out
+
+#include <sycl/detail/core.hpp>
+#define __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
+#include <sycl/khr/group_interface.hpp>
+
+#include <cassert>
+#include <cstddef>
+#include <type_traits>
+
+using namespace sycl;
+
+static_assert(std::is_same_v<khr::work_group<1>::id_type, id<1>>);
+static_assert(std::is_same_v<khr::work_group<1>::linear_id_type, size_t>);
+static_assert(std::is_same_v<khr::work_group<1>::range_type, range<1>>);
+#if defined(__cpp_lib_mdspan)
+static_assert(std::is_same_v<khr::work_group<1>::extents_type,
+                             std::dextents<size_t, 1>>);
+#endif
+static_assert(std::is_same_v<khr::work_group<1>::size_type, size_t>);
+static_assert(khr::work_group<1>::dimensions == 1);
+static_assert(khr::work_group<1>::fence_scope == memory_scope::work_group);
+
+static_assert(std::is_same_v<khr::work_group<2>::id_type, id<2>>);
+static_assert(std::is_same_v<khr::work_group<2>::linear_id_type, size_t>);
+static_assert(std::is_same_v<khr::work_group<2>::range_type, range<2>>);
+#if defined(__cpp_lib_mdspan)
+static_assert(std::is_same_v<khr::work_group<2>::extents_type,
+                             std::dextents<size_t, 2>>);
+#endif
+static_assert(std::is_same_v<khr::work_group<2>::size_type, size_t>);
+static_assert(khr::work_group<2>::dimensions == 2);
+static_assert(khr::work_group<2>::fence_scope == memory_scope::work_group);
+
+static_assert(std::is_same_v<khr::work_group<3>::id_type, id<3>>);
+static_assert(std::is_same_v<khr::work_group<3>::linear_id_type, size_t>);
+static_assert(std::is_same_v<khr::work_group<3>::range_type, range<3>>);
+#if defined(__cpp_lib_mdspan)
+static_assert(std::is_same_v<khr::work_group<3>::extents_type,
+                             std::dextents<size_t, 3>>);
+#endif
+static_assert(std::is_same_v<khr::work_group<3>::size_type, size_t>);
+static_assert(khr::work_group<3>::dimensions == 3);
+static_assert(khr::work_group<3>::fence_scope == memory_scope::work_group);
+
+// Device-side check shared by the 1-, 2- and 3-dimensional launches: every
+// khr::work_group and khr::work_item query must agree with the corresponding
+// query on the core sycl::group.
+template <int Dims> void check_work_group(const nd_item<Dims> &item) {
+  group<Dims> g = item.get_group();
+
+  khr::work_group<Dims> wg = g;
+  assert(wg.id() == g.get_group_id());
+  assert(wg.linear_id() == g.get_group_linear_id());
+  assert(wg.range() == g.get_group_range());
+#if defined(__cpp_lib_mdspan)
+  assert(wg.extents().rank() == static_cast<size_t>(Dims));
+  for (int d = 0; d < Dims; ++d) {
+    assert(wg.extent(d) == g.get_local_range()[d]);
+  }
+#endif
+  assert(wg.size() == g.get_local_linear_range());
+
+  khr::work_item wi = khr::get_item(wg);
+  assert(wi.id() == g.get_local_id());
+  assert(wi.linear_id() == g.get_local_linear_id());
+  assert(wi.range() == g.get_local_range());
+#if defined(__cpp_lib_mdspan)
+  assert(wi.extents().rank() == static_cast<size_t>(Dims));
+  for (int d = 0; d < Dims; ++d) {
+    assert(wi.extent(d) == 1);
+  }
+#endif
+  assert(wi.size() == 1);
+}
+
+int main() {
+  queue q(cpu_selector_v);
+
+  const int sz = 16;
+  q.submit([&](handler &h) {
+    h.parallel_for(nd_range<1>{sz, sz},
+                   [=](nd_item<1> item) { check_work_group(item); });
+  });
+  q.submit([&](handler &h) {
+    h.parallel_for(nd_range<2>{range<2>{sz, sz}, range<2>{sz, sz}},
+                   [=](nd_item<2> item) { check_work_group(item); });
+  });
+  q.submit([&](handler &h) {
+    h.parallel_for(nd_range<3>{range<3>{sz, sz, sz}, range<3>{sz, sz, sz}},
+                   [=](nd_item<3> item) { check_work_group(item); });
+  });
+  q.wait();
+}