Skip to content

Commit 70a35de

Browse files
[SYCL] Add ballot_group support to algorithms (#8784)
Enables the following functions to be used with ballot_group arguments: - group_barrier - group_broadcast - any_of_group - all_of_group - none_of_group - reduce_over_group - exclusive_scan_over_group - inclusive_scan_over_group Signed-off-by: John Pennycook <[email protected]> --------- Signed-off-by: John Pennycook <[email protected]> Co-authored-by: aelovikov-intel <[email protected]>
1 parent 917d689 commit 70a35de

20 files changed

+781
-126
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,68 @@ __SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT int
968968
__spirv_GroupNonUniformBallotFindLSB(__spv::Scope::Flag,
969969
__ocl_vec_t<uint32_t, 4>) noexcept;
970970

971+
template <typename ValueT, typename IdT>
972+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
973+
__spirv_GroupNonUniformBroadcast(__spv::Scope::Flag, ValueT, IdT);
974+
975+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool
976+
__spirv_GroupNonUniformAll(__spv::Scope::Flag, bool);
977+
978+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool
979+
__spirv_GroupNonUniformAny(__spv::Scope::Flag, bool);
980+
981+
template <typename ValueT>
982+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
983+
__spirv_GroupNonUniformSMin(__spv::Scope::Flag, unsigned int, ValueT);
984+
985+
template <typename ValueT>
986+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
987+
__spirv_GroupNonUniformUMin(__spv::Scope::Flag, unsigned int, ValueT);
988+
989+
template <typename ValueT>
990+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
991+
__spirv_GroupNonUniformFMin(__spv::Scope::Flag, unsigned int, ValueT);
992+
993+
template <typename ValueT>
994+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
995+
__spirv_GroupNonUniformSMax(__spv::Scope::Flag, unsigned int, ValueT);
996+
997+
template <typename ValueT>
998+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
999+
__spirv_GroupNonUniformUMax(__spv::Scope::Flag, unsigned int, ValueT);
1000+
1001+
template <typename ValueT>
1002+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1003+
__spirv_GroupNonUniformFMax(__spv::Scope::Flag, unsigned int, ValueT);
1004+
1005+
template <typename ValueT>
1006+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1007+
__spirv_GroupNonUniformIAdd(__spv::Scope::Flag, unsigned int, ValueT);
1008+
1009+
template <typename ValueT>
1010+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1011+
__spirv_GroupNonUniformFAdd(__spv::Scope::Flag, unsigned int, ValueT);
1012+
1013+
template <typename ValueT>
1014+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1015+
__spirv_GroupNonUniformIMul(__spv::Scope::Flag, unsigned int, ValueT);
1016+
1017+
template <typename ValueT>
1018+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1019+
__spirv_GroupNonUniformFMul(__spv::Scope::Flag, unsigned int, ValueT);
1020+
1021+
template <typename ValueT>
1022+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1023+
__spirv_GroupNonUniformBitwiseOr(__spv::Scope::Flag, unsigned int, ValueT);
1024+
1025+
template <typename ValueT>
1026+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1027+
__spirv_GroupNonUniformBitwiseXor(__spv::Scope::Flag, unsigned int, ValueT);
1028+
1029+
template <typename ValueT>
1030+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
1031+
__spirv_GroupNonUniformBitwiseAnd(__spv::Scope::Flag, unsigned int, ValueT);
1032+
9711033
extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT void
9721034
__clc_BarrierInitialize(int64_t *state, int32_t expected_count) noexcept;
9731035

sycl/include/sycl/detail/spirv.hpp

Lines changed: 180 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <sycl/detail/generic_type_traits.hpp>
1515
#include <sycl/detail/helpers.hpp>
1616
#include <sycl/detail/type_traits.hpp>
17+
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
1718
#include <sycl/id.hpp>
1819
#include <sycl/memory_enums.hpp>
1920

@@ -23,6 +24,9 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
2324
namespace ext {
2425
namespace oneapi {
2526
struct sub_group;
27+
namespace experimental {
28+
template <typename ParentGroup> class ballot_group;
29+
} // namespace experimental
2630
} // namespace oneapi
2731
} // namespace ext
2832

@@ -56,6 +60,11 @@ template <> struct group_scope<::sycl::ext::oneapi::sub_group> {
5660
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
5761
};
5862

63+
template <typename ParentGroup>
64+
struct group_scope<sycl::ext::oneapi::experimental::ballot_group<ParentGroup>> {
65+
static constexpr __spv::Scope::Flag value = group_scope<ParentGroup>::value;
66+
};
67+
5968
// Generic shuffles and broadcasts may require multiple calls to
6069
// intrinsics, and should use the fewest broadcasts possible
6170
// - Loop over chunks until remaining bytes < chunk size
@@ -94,13 +103,37 @@ void GenericCall(const Functor &ApplyToBytes) {
94103
}
95104
}
96105

97-
template <typename Group> bool GroupAll(bool pred) {
106+
template <typename Group> bool GroupAll(Group, bool pred) {
98107
return __spirv_GroupAll(group_scope<Group>::value, pred);
99108
}
109+
template <typename ParentGroup>
110+
bool GroupAll(ext::oneapi::experimental::ballot_group<ParentGroup> g,
111+
bool pred) {
112+
// ballot_group partitions its parent into two groups (0 and 1)
113+
// We have to force each group down different control flow
114+
// Work-items in the "false" group (0) may still be active
115+
if (g.get_group_id() == 1) {
116+
return __spirv_GroupNonUniformAll(group_scope<ParentGroup>::value, pred);
117+
} else {
118+
return __spirv_GroupNonUniformAll(group_scope<ParentGroup>::value, pred);
119+
}
120+
}
100121

101-
template <typename Group> bool GroupAny(bool pred) {
122+
template <typename Group> bool GroupAny(Group, bool pred) {
102123
return __spirv_GroupAny(group_scope<Group>::value, pred);
103124
}
125+
template <typename ParentGroup>
126+
bool GroupAny(ext::oneapi::experimental::ballot_group<ParentGroup> g,
127+
bool pred) {
128+
// ballot_group partitions its parent into two groups (0 and 1)
129+
// We have to force each group down different control flow
130+
// Work-items in the "false" group (0) may still be active
131+
if (g.get_group_id() == 1) {
132+
return __spirv_GroupNonUniformAny(group_scope<ParentGroup>::value, pred);
133+
} else {
134+
return __spirv_GroupNonUniformAny(group_scope<ParentGroup>::value, pred);
135+
}
136+
}
104137

105138
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
106139
// FIXME: Do not special-case for half or vec once all backends support all data
@@ -159,7 +192,7 @@ template <> struct GroupId<::sycl::ext::oneapi::sub_group> {
159192
using type = uint32_t;
160193
};
161194
template <typename Group, typename T, typename IdT>
162-
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
195+
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(Group, T x, IdT local_id) {
163196
using GroupIdT = typename GroupId<Group>::type;
164197
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
165198
using OCLT = detail::ConvertToOpenCLType_t<T>;
@@ -169,23 +202,51 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
169202
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
170203
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
171204
}
205+
template <typename ParentGroup, typename T, typename IdT>
206+
EnableIfNativeBroadcast<T, IdT>
207+
GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
208+
T x, IdT local_id) {
209+
// Remap local_id to its original numbering in ParentGroup.
210+
auto LocalId = detail::IdToMaskPosition(g, local_id);
211+
212+
// TODO: Refactor to avoid duplication after design settles.
213+
using GroupIdT = typename GroupId<ParentGroup>::type;
214+
GroupIdT GroupLocalId = static_cast<GroupIdT>(LocalId);
215+
using OCLT = detail::ConvertToOpenCLType_t<T>;
216+
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
217+
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
218+
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
219+
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
220+
221+
// ballot_group partitions its parent into two groups (0 and 1)
222+
// We have to force each group down different control flow
223+
// Work-items in the "false" group (0) may still be active
224+
if (g.get_group_id() == 1) {
225+
return __spirv_GroupNonUniformBroadcast(group_scope<ParentGroup>::value,
226+
OCLX, OCLId);
227+
} else {
228+
return __spirv_GroupNonUniformBroadcast(group_scope<ParentGroup>::value,
229+
OCLX, OCLId);
230+
}
231+
}
232+
172233
template <typename Group, typename T, typename IdT>
173-
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
234+
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(Group g, T x, IdT local_id) {
174235
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
175236
auto BroadcastX = bit_cast<BroadcastT>(x);
176-
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
237+
BroadcastT Result = GroupBroadcast(g, BroadcastX, local_id);
177238
return bit_cast<T>(Result);
178239
}
179240
template <typename Group, typename T, typename IdT>
180-
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
241+
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(Group g, T x, IdT local_id) {
181242
// Initialize with x to support type T without default constructor
182243
T Result = x;
183244
char *XBytes = reinterpret_cast<char *>(&x);
184245
char *ResultBytes = reinterpret_cast<char *>(&Result);
185246
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
186247
uint64_t BroadcastX, BroadcastResult;
187248
std::memcpy(&BroadcastX, XBytes + Offset, Size);
188-
BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
249+
BroadcastResult = GroupBroadcast(g, BroadcastX, local_id);
189250
std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
190251
};
191252
GenericCall<T>(BroadcastBytes);
@@ -194,9 +255,10 @@ EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
194255

195256
// Broadcast with vector local index
196257
template <typename Group, typename T, int Dimensions>
197-
EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
258+
EnableIfNativeBroadcast<T> GroupBroadcast(Group g, T x,
259+
id<Dimensions> local_id) {
198260
if (Dimensions == 1) {
199-
return GroupBroadcast<Group>(x, local_id[0]);
261+
return GroupBroadcast(g, x, local_id[0]);
200262
}
201263
using IdT = vec<size_t, Dimensions>;
202264
using OCLT = detail::ConvertToOpenCLType_t<T>;
@@ -210,17 +272,26 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
210272
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
211273
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
212274
}
275+
template <typename ParentGroup, typename T>
276+
EnableIfNativeBroadcast<T>
277+
GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
278+
T x, id<1> local_id) {
279+
// Limited to 1D indices for now because ParentGroup must be sub-group.
280+
return GroupBroadcast(g, x, local_id[0]);
281+
}
213282
template <typename Group, typename T, int Dimensions>
214-
EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
283+
EnableIfBitcastBroadcast<T> GroupBroadcast(Group g, T x,
284+
id<Dimensions> local_id) {
215285
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
216286
auto BroadcastX = bit_cast<BroadcastT>(x);
217-
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
287+
BroadcastT Result = GroupBroadcast(g, BroadcastX, local_id);
218288
return bit_cast<T>(Result);
219289
}
220290
template <typename Group, typename T, int Dimensions>
221-
EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
291+
EnableIfGenericBroadcast<T> GroupBroadcast(Group g, T x,
292+
id<Dimensions> local_id) {
222293
if (Dimensions == 1) {
223-
return GroupBroadcast<Group>(x, local_id[0]);
294+
return GroupBroadcast(g, x, local_id[0]);
224295
}
225296
// Initialize with x to support type T without default constructor
226297
T Result = x;
@@ -229,7 +300,7 @@ EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
229300
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
230301
uint64_t BroadcastX, BroadcastResult;
231302
std::memcpy(&BroadcastX, XBytes + Offset, Size);
232-
BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
303+
BroadcastResult = GroupBroadcast(g, BroadcastX, local_id);
233304
std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
234305
};
235306
GenericCall<T>(BroadcastBytes);
@@ -803,6 +874,101 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
803874
return Result;
804875
}
805876

877+
template <typename Group>
878+
typename std::enable_if_t<
879+
ext::oneapi::experimental::is_fixed_topology_group_v<Group>>
880+
ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
881+
__spirv_ControlBarrier(group_scope<Group>::value, getScope(FenceScope),
882+
getMemorySemanticsMask(Order) |
883+
__spv::MemorySemanticsMask::SubgroupMemory |
884+
__spv::MemorySemanticsMask::WorkgroupMemory |
885+
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
886+
}
887+
888+
template <typename Group>
889+
typename std::enable_if_t<
890+
ext::oneapi::experimental::is_user_constructed_group_v<Group>>
891+
ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
892+
#if defined(__SPIR__)
893+
// SPIR-V does not define an instruction to synchronize partial groups.
894+
// However, most (possibly all?) of the current SPIR-V targets execute
895+
// work-items in lockstep, so we can probably get away with a MemoryBarrier.
896+
// TODO: Replace this if SPIR-V defines a NonUniformControlBarrier
897+
__spirv_MemoryBarrier(getScope(FenceScope),
898+
getMemorySemanticsMask(Order) |
899+
__spv::MemorySemanticsMask::SubgroupMemory |
900+
__spv::MemorySemanticsMask::WorkgroupMemory |
901+
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
902+
#elif defined(__NVPTX__)
903+
// TODO: Call syncwarp with appropriate mask extracted from the group
904+
#endif
905+
}
906+
907+
// TODO: Refactor to avoid duplication after design settles
908+
#define __SYCL_GROUP_COLLECTIVE_OVERLOAD(Instruction) \
909+
template <__spv::GroupOperation Op, typename Group, typename T> \
910+
inline typename std::enable_if_t< \
911+
ext::oneapi::experimental::is_fixed_topology_group_v<Group>, T> \
912+
Group##Instruction(Group G, T x) { \
913+
using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
914+
\
915+
using OCLT = \
916+
conditional_t<std::is_same<ConvertedT, cl_char>() || \
917+
std::is_same<ConvertedT, cl_short>(), \
918+
cl_int, \
919+
conditional_t<std::is_same<ConvertedT, cl_uchar>() || \
920+
std::is_same<ConvertedT, cl_ushort>(), \
921+
cl_uint, ConvertedT>>; \
922+
OCLT Arg = x; \
923+
OCLT Ret = __spirv_Group##Instruction(group_scope<Group>::value, \
924+
static_cast<unsigned int>(Op), Arg); \
925+
return Ret; \
926+
} \
927+
\
928+
template <__spv::GroupOperation Op, typename ParentGroup, typename T> \
929+
inline T Group##Instruction( \
930+
ext::oneapi::experimental::ballot_group<ParentGroup> g, T x) { \
931+
using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
932+
\
933+
using OCLT = \
934+
conditional_t<std::is_same<ConvertedT, cl_char>() || \
935+
std::is_same<ConvertedT, cl_short>(), \
936+
cl_int, \
937+
conditional_t<std::is_same<ConvertedT, cl_uchar>() || \
938+
std::is_same<ConvertedT, cl_ushort>(), \
939+
cl_uint, ConvertedT>>; \
940+
OCLT Arg = x; \
941+
/* ballot_group partitions its parent into two groups (0 and 1) */ \
942+
/* We have to force each group down different control flow */ \
943+
/* Work-items in the "false" group (0) may still be active */ \
944+
constexpr auto Scope = group_scope<ParentGroup>::value; \
945+
constexpr auto OpInt = static_cast<unsigned int>(Op); \
946+
if (g.get_group_id() == 1) { \
947+
return __spirv_GroupNonUniform##Instruction(Scope, OpInt, Arg); \
948+
} else { \
949+
return __spirv_GroupNonUniform##Instruction(Scope, OpInt, Arg); \
950+
} \
951+
}
952+
953+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(SMin)
954+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(UMin)
955+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FMin)
956+
957+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(SMax)
958+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(UMax)
959+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FMax)
960+
961+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(IAdd)
962+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FAdd)
963+
964+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(IMulKHR)
965+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FMulKHR)
966+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(CMulINTEL)
967+
968+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseOrKHR)
969+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseXorKHR)
970+
__SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseAndKHR)
971+
806972
} // namespace spirv
807973
} // namespace detail
808974
} // __SYCL_INLINE_VER_NAMESPACE(_V1)

sycl/include/sycl/detail/type_traits.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,29 @@ struct sub_group;
2727
namespace experimental {
2828
template <typename Group, std::size_t Extent> class group_with_scratchpad;
2929

30+
template <class T> struct is_fixed_topology_group : std::false_type {};
31+
32+
template <class T>
33+
inline constexpr bool is_fixed_topology_group_v =
34+
is_fixed_topology_group<T>::value;
35+
36+
#ifdef SYCL_EXT_ONEAPI_ROOT_GROUP
37+
template <> struct is_fixed_topology_group<root_group> : std::true_type {};
38+
#endif
39+
40+
template <int Dimensions>
41+
struct is_fixed_topology_group<sycl::group<Dimensions>> : std::true_type {};
42+
43+
template <>
44+
struct is_fixed_topology_group<sycl::ext::oneapi::sub_group> : std::true_type {
45+
};
46+
47+
template <class T> struct is_user_constructed_group : std::false_type {};
48+
49+
template <class T>
50+
inline constexpr bool is_user_constructed_group_v =
51+
is_user_constructed_group<T>::value;
52+
3053
namespace detail {
3154
template <typename T> struct is_group_helper : std::false_type {};
3255

0 commit comments

Comments
 (0)