Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Implement sycl_ext_oneapi_group_load_store #10694

Open
wants to merge 8 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,13 @@ struct mptr_or_vec_elem_type<const multi_ptr<ElementType, Space, IsDecorated>>
template <typename T>
using mptr_or_vec_elem_type_t = typename mptr_or_vec_elem_type<T>::type;

template <int Size>
using cl_unsigned = std::conditional_t<
Size == 1, opencl::cl_uchar,
std::conditional_t<
Size == 2, opencl::cl_ushort,
std::conditional_t<Size == 4, opencl::cl_uint, opencl::cl_ulong>>>;

// select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on
// sizeof(IN). expected to handle scalar types.
template <typename T, typename T8, typename T16, typename T32, typename T64>
Expand Down
9 changes: 9 additions & 0 deletions sycl/include/sycl/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ void loop_impl(std::integer_sequence<size_t, Inds...>, F &&f) {
template <size_t count, class F> void loop(F &&f) {
loop_impl(std::make_index_sequence<count>{}, std::forward<F>(f));
}
template <size_t count, size_t limit, class F> void loop_unroll_up_to(F &&f) {
if constexpr (count > limit)
for (size_t i = 0; i < count; ++i)
f(i);
else
loop<count>([&](auto i) { f(i); });
}

inline constexpr bool is_power_of_two(int x) { return (x & (x - 1)) == 0; }
} // namespace detail

} // namespace _V1
Expand Down
8 changes: 1 addition & 7 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ inline namespace _V1 {
struct sub_group;
namespace ext {
namespace oneapi {
struct sub_group;
namespace experimental {
template <typename ParentGroup> class ballot_group;
template <size_t PartitionSize, typename ParentGroup> class fixed_size_group;
Expand Down Expand Up @@ -61,9 +60,6 @@ template <int Dimensions> struct group_scope<group<Dimensions>> {
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Workgroup;
};

template <> struct group_scope<::sycl::ext::oneapi::sub_group> {
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
};
template <> struct group_scope<::sycl::sub_group> {
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
};
Expand Down Expand Up @@ -254,12 +250,10 @@ using WidenOpenCLTypeTo32_t = std::conditional_t<
template <typename Group> struct GroupId {
using type = size_t;
};
template <> struct GroupId<::sycl::ext::oneapi::sub_group> {
using type = uint32_t;
};
template <> struct GroupId<::sycl::sub_group> {
using type = uint32_t;
};

template <typename Group, typename T, typename IdT>
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(Group, T x, IdT local_id) {
using GroupIdT = typename GroupId<Group>::type;
Expand Down
26 changes: 16 additions & 10 deletions sycl/include/sycl/detail/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ inline constexpr bool is_fixed_size_group_v = is_fixed_size_group<T>::value;
template <int Dimensions> class group;
struct sub_group;
namespace ext::oneapi {
struct sub_group;

namespace experimental {
template <typename Group, std::size_t Extent> class group_with_scratchpad;
Expand All @@ -49,10 +48,7 @@ template <int Dimensions>
struct is_fixed_topology_group<sycl::group<Dimensions>> : std::true_type {};

template <>
struct is_fixed_topology_group<sycl::ext::oneapi::sub_group> : std::true_type {
};
template <> struct is_fixed_topology_group<sycl::sub_group> : std::true_type {};

struct is_fixed_topology_group<sycl::_V1::sub_group> : std::true_type {};
template <class T> struct is_user_constructed_group : std::false_type {};

template <class T>
Expand All @@ -65,6 +61,8 @@ template <typename T> struct is_group_helper : std::false_type {};
template <typename Group, std::size_t Extent>
struct is_group_helper<group_with_scratchpad<Group, Extent>> : std::true_type {
};
template <typename T>
inline constexpr bool is_group_helper_v = is_group_helper<T>::value;
} // namespace detail
} // namespace experimental
} // namespace ext::oneapi
Expand All @@ -78,13 +76,14 @@ struct is_group<group<Dimensions>> : std::true_type {};

template <typename T> struct is_sub_group : std::false_type {};

template <> struct is_sub_group<ext::oneapi::sub_group> : std::true_type {};
template <> struct is_sub_group<sycl::sub_group> : std::true_type {};

template <typename T>
struct is_generic_group
: std::integral_constant<bool,
is_group<T>::value || is_sub_group<T>::value> {};
template <typename T>
inline constexpr bool is_generic_group_v = is_generic_group<T>::value;

namespace half_impl {
class half;
Expand Down Expand Up @@ -316,6 +315,17 @@ template <typename T>
struct is_bool
: std::bool_constant<is_scalar_bool<vector_element_t<T>>::value> {};

// is_multi_ptr
template <typename T> struct is_multi_ptr_impl : public std::false_type {};

template <typename T, access::address_space Space,
access::decorated DecorateAddress>
struct is_multi_ptr_impl<multi_ptr<T, Space, DecorateAddress>>
: public std::true_type {};

template <typename T>
constexpr bool is_multi_ptr_v = is_multi_ptr_impl<std::remove_cv_t<T>>::value;

// is_pointer
template <typename T> struct is_pointer_impl : std::false_type {};

Expand All @@ -328,7 +338,6 @@ struct is_pointer_impl<multi_ptr<T, Space, DecorateAddress>> : std::true_type {

template <typename T>
struct is_pointer : is_pointer_impl<std::remove_cv_t<T>> {};

// is_multi_ptr
template <typename T> struct is_multi_ptr : std::false_type {};

Expand All @@ -337,9 +346,6 @@ template <typename ElementType, access::address_space Space,
struct is_multi_ptr<multi_ptr<ElementType, Space, IsDecorated>>
: std::true_type {};

template <class T>
inline constexpr bool is_multi_ptr_v = is_multi_ptr<T>::value;

// is_non_legacy_multi_ptr
template <typename T> struct is_non_legacy_multi_ptr : std::false_type {};

Expand Down
Loading