14
14
#include < sycl/detail/generic_type_traits.hpp>
15
15
#include < sycl/detail/helpers.hpp>
16
16
#include < sycl/detail/type_traits.hpp>
17
+ #include < sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
17
18
#include < sycl/id.hpp>
18
19
#include < sycl/memory_enums.hpp>
19
20
@@ -23,6 +24,9 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
23
24
namespace ext {
24
25
namespace oneapi {
25
26
struct sub_group ;
27
+ namespace experimental {
28
+ template <typename ParentGroup> class ballot_group ;
29
+ } // namespace experimental
26
30
} // namespace oneapi
27
31
} // namespace ext
28
32
@@ -56,6 +60,11 @@ template <> struct group_scope<::sycl::ext::oneapi::sub_group> {
56
60
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
57
61
};
58
62
63
+ template <typename ParentGroup>
64
+ struct group_scope <sycl::ext::oneapi::experimental::ballot_group<ParentGroup>> {
65
+ static constexpr __spv::Scope::Flag value = group_scope<ParentGroup>::value;
66
+ };
67
+
59
68
// Generic shuffles and broadcasts may require multiple calls to
60
69
// intrinsics, and should use the fewest broadcasts possible
61
70
// - Loop over chunks until remaining bytes < chunk size
@@ -94,13 +103,37 @@ void GenericCall(const Functor &ApplyToBytes) {
94
103
}
95
104
}
96
105
97
- template <typename Group> bool GroupAll (bool pred) {
106
+ template <typename Group> bool GroupAll (Group, bool pred) {
98
107
return __spirv_GroupAll (group_scope<Group>::value, pred);
99
108
}
109
+ template <typename ParentGroup>
110
+ bool GroupAll (ext::oneapi::experimental::ballot_group<ParentGroup> g,
111
+ bool pred) {
112
+ // ballot_group partitions its parent into two groups (0 and 1)
113
+ // We have to force each group down different control flow
114
+ // Work-items in the "false" group (0) may still be active
115
+ if (g.get_group_id () == 1 ) {
116
+ return __spirv_GroupNonUniformAll (group_scope<ParentGroup>::value, pred);
117
+ } else {
118
+ return __spirv_GroupNonUniformAll (group_scope<ParentGroup>::value, pred);
119
+ }
120
+ }
100
121
101
- template <typename Group> bool GroupAny (bool pred) {
122
+ template <typename Group> bool GroupAny (Group, bool pred) {
102
123
return __spirv_GroupAny (group_scope<Group>::value, pred);
103
124
}
125
+ template <typename ParentGroup>
126
+ bool GroupAny (ext::oneapi::experimental::ballot_group<ParentGroup> g,
127
+ bool pred) {
128
+ // ballot_group partitions its parent into two groups (0 and 1)
129
+ // We have to force each group down different control flow
130
+ // Work-items in the "false" group (0) may still be active
131
+ if (g.get_group_id () == 1 ) {
132
+ return __spirv_GroupNonUniformAny (group_scope<ParentGroup>::value, pred);
133
+ } else {
134
+ return __spirv_GroupNonUniformAny (group_scope<ParentGroup>::value, pred);
135
+ }
136
+ }
104
137
105
138
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
106
139
// FIXME: Do not special-case for half or vec once all backends support all data
@@ -159,7 +192,7 @@ template <> struct GroupId<::sycl::ext::oneapi::sub_group> {
159
192
using type = uint32_t ;
160
193
};
161
194
template <typename Group, typename T, typename IdT>
162
- EnableIfNativeBroadcast<T, IdT> GroupBroadcast (T x, IdT local_id) {
195
+ EnableIfNativeBroadcast<T, IdT> GroupBroadcast (Group, T x, IdT local_id) {
163
196
using GroupIdT = typename GroupId<Group>::type;
164
197
GroupIdT GroupLocalId = static_cast <GroupIdT>(local_id);
165
198
using OCLT = detail::ConvertToOpenCLType_t<T>;
@@ -169,23 +202,51 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
169
202
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
170
203
return __spirv_GroupBroadcast (group_scope<Group>::value, OCLX, OCLId);
171
204
}
205
+ template <typename ParentGroup, typename T, typename IdT>
206
+ EnableIfNativeBroadcast<T, IdT>
207
+ GroupBroadcast (sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
208
+ T x, IdT local_id) {
209
+ // Remap local_id to its original numbering in ParentGroup.
210
+ auto LocalId = detail::IdToMaskPosition (g, local_id);
211
+
212
+ // TODO: Refactor to avoid duplication after design settles.
213
+ using GroupIdT = typename GroupId<ParentGroup>::type;
214
+ GroupIdT GroupLocalId = static_cast <GroupIdT>(LocalId);
215
+ using OCLT = detail::ConvertToOpenCLType_t<T>;
216
+ using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
217
+ using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
218
+ WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
219
+ OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
220
+
221
+ // ballot_group partitions its parent into two groups (0 and 1)
222
+ // We have to force each group down different control flow
223
+ // Work-items in the "false" group (0) may still be active
224
+ if (g.get_group_id () == 1 ) {
225
+ return __spirv_GroupNonUniformBroadcast (group_scope<ParentGroup>::value,
226
+ OCLX, OCLId);
227
+ } else {
228
+ return __spirv_GroupNonUniformBroadcast (group_scope<ParentGroup>::value,
229
+ OCLX, OCLId);
230
+ }
231
+ }
232
+
172
233
template <typename Group, typename T, typename IdT>
173
- EnableIfBitcastBroadcast<T, IdT> GroupBroadcast (T x, IdT local_id) {
234
+ EnableIfBitcastBroadcast<T, IdT> GroupBroadcast (Group g, T x, IdT local_id) {
174
235
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
175
236
auto BroadcastX = bit_cast<BroadcastT>(x);
176
- BroadcastT Result = GroupBroadcast<Group>( BroadcastX, local_id);
237
+ BroadcastT Result = GroupBroadcast (g, BroadcastX, local_id);
177
238
return bit_cast<T>(Result);
178
239
}
179
240
template <typename Group, typename T, typename IdT>
180
- EnableIfGenericBroadcast<T, IdT> GroupBroadcast (T x, IdT local_id) {
241
+ EnableIfGenericBroadcast<T, IdT> GroupBroadcast (Group g, T x, IdT local_id) {
181
242
// Initialize with x to support type T without default constructor
182
243
T Result = x;
183
244
char *XBytes = reinterpret_cast <char *>(&x);
184
245
char *ResultBytes = reinterpret_cast <char *>(&Result);
185
246
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
186
247
uint64_t BroadcastX, BroadcastResult;
187
248
std::memcpy (&BroadcastX, XBytes + Offset, Size);
188
- BroadcastResult = GroupBroadcast<Group>( BroadcastX, local_id);
249
+ BroadcastResult = GroupBroadcast (g, BroadcastX, local_id);
189
250
std::memcpy (ResultBytes + Offset, &BroadcastResult, Size);
190
251
};
191
252
GenericCall<T>(BroadcastBytes);
@@ -194,9 +255,10 @@ EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
194
255
195
256
// Broadcast with vector local index
196
257
template <typename Group, typename T, int Dimensions>
197
- EnableIfNativeBroadcast<T> GroupBroadcast (T x, id<Dimensions> local_id) {
258
+ EnableIfNativeBroadcast<T> GroupBroadcast (Group g, T x,
259
+ id<Dimensions> local_id) {
198
260
if (Dimensions == 1 ) {
199
- return GroupBroadcast<Group>( x, local_id[0 ]);
261
+ return GroupBroadcast (g, x, local_id[0 ]);
200
262
}
201
263
using IdT = vec<size_t , Dimensions>;
202
264
using OCLT = detail::ConvertToOpenCLType_t<T>;
@@ -210,17 +272,26 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
210
272
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
211
273
return __spirv_GroupBroadcast (group_scope<Group>::value, OCLX, OCLId);
212
274
}
275
+ template <typename ParentGroup, typename T>
276
+ EnableIfNativeBroadcast<T>
277
+ GroupBroadcast (sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
278
+ T x, id<1 > local_id) {
279
+ // Limited to 1D indices for now because ParentGroup must be sub-group.
280
+ return GroupBroadcast (g, x, local_id[0 ]);
281
+ }
213
282
template <typename Group, typename T, int Dimensions>
214
- EnableIfBitcastBroadcast<T> GroupBroadcast (T x, id<Dimensions> local_id) {
283
+ EnableIfBitcastBroadcast<T> GroupBroadcast (Group g, T x,
284
+ id<Dimensions> local_id) {
215
285
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
216
286
auto BroadcastX = bit_cast<BroadcastT>(x);
217
- BroadcastT Result = GroupBroadcast<Group>( BroadcastX, local_id);
287
+ BroadcastT Result = GroupBroadcast (g, BroadcastX, local_id);
218
288
return bit_cast<T>(Result);
219
289
}
220
290
template <typename Group, typename T, int Dimensions>
221
- EnableIfGenericBroadcast<T> GroupBroadcast (T x, id<Dimensions> local_id) {
291
+ EnableIfGenericBroadcast<T> GroupBroadcast (Group g, T x,
292
+ id<Dimensions> local_id) {
222
293
if (Dimensions == 1 ) {
223
- return GroupBroadcast<Group>( x, local_id[0 ]);
294
+ return GroupBroadcast (g, x, local_id[0 ]);
224
295
}
225
296
// Initialize with x to support type T without default constructor
226
297
T Result = x;
@@ -229,7 +300,7 @@ EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
229
300
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
230
301
uint64_t BroadcastX, BroadcastResult;
231
302
std::memcpy (&BroadcastX, XBytes + Offset, Size);
232
- BroadcastResult = GroupBroadcast<Group>( BroadcastX, local_id);
303
+ BroadcastResult = GroupBroadcast (g, BroadcastX, local_id);
233
304
std::memcpy (ResultBytes + Offset, &BroadcastResult, Size);
234
305
};
235
306
GenericCall<T>(BroadcastBytes);
@@ -803,6 +874,101 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
803
874
return Result;
804
875
}
805
876
877
+ template <typename Group>
878
+ typename std::enable_if_t <
879
+ ext::oneapi::experimental::is_fixed_topology_group_v<Group>>
880
+ ControlBarrier (Group, memory_scope FenceScope, memory_order Order) {
881
+ __spirv_ControlBarrier (group_scope<Group>::value, getScope (FenceScope),
882
+ getMemorySemanticsMask (Order) |
883
+ __spv::MemorySemanticsMask::SubgroupMemory |
884
+ __spv::MemorySemanticsMask::WorkgroupMemory |
885
+ __spv::MemorySemanticsMask::CrossWorkgroupMemory);
886
+ }
887
+
888
+ template <typename Group>
889
+ typename std::enable_if_t <
890
+ ext::oneapi::experimental::is_user_constructed_group_v<Group>>
891
+ ControlBarrier (Group, memory_scope FenceScope, memory_order Order) {
892
+ #if defined(__SPIR__)
893
+ // SPIR-V does not define an instruction to synchronize partial groups.
894
+ // However, most (possibly all?) of the current SPIR-V targets execute
895
+ // work-items in lockstep, so we can probably get away with a MemoryBarrier.
896
+ // TODO: Replace this if SPIR-V defines a NonUniformControlBarrier
897
+ __spirv_MemoryBarrier (getScope (FenceScope),
898
+ getMemorySemanticsMask (Order) |
899
+ __spv::MemorySemanticsMask::SubgroupMemory |
900
+ __spv::MemorySemanticsMask::WorkgroupMemory |
901
+ __spv::MemorySemanticsMask::CrossWorkgroupMemory);
902
+ #elif defined(__NVPTX__)
903
+ // TODO: Call syncwarp with appropriate mask extracted from the group
904
+ #endif
905
+ }
906
+
907
+ // TODO: Refactor to avoid duplication after design settles
908
+ #define __SYCL_GROUP_COLLECTIVE_OVERLOAD (Instruction ) \
909
+ template <__spv::GroupOperation Op, typename Group, typename T> \
910
+ inline typename std::enable_if_t < \
911
+ ext::oneapi::experimental::is_fixed_topology_group_v<Group>, T> \
912
+ Group##Instruction(Group G, T x) { \
913
+ using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
914
+ \
915
+ using OCLT = \
916
+ conditional_t <std::is_same<ConvertedT, cl_char>() || \
917
+ std::is_same<ConvertedT, cl_short>(), \
918
+ cl_int, \
919
+ conditional_t <std::is_same<ConvertedT, cl_uchar>() || \
920
+ std::is_same<ConvertedT, cl_ushort>(), \
921
+ cl_uint, ConvertedT>>; \
922
+ OCLT Arg = x; \
923
+ OCLT Ret = __spirv_Group##Instruction (group_scope<Group>::value, \
924
+ static_cast <unsigned int >(Op), Arg); \
925
+ return Ret; \
926
+ } \
927
+ \
928
+ template <__spv::GroupOperation Op, typename ParentGroup, typename T> \
929
+ inline T Group##Instruction( \
930
+ ext::oneapi::experimental::ballot_group<ParentGroup> g, T x) { \
931
+ using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
932
+ \
933
+ using OCLT = \
934
+ conditional_t <std::is_same<ConvertedT, cl_char>() || \
935
+ std::is_same<ConvertedT, cl_short>(), \
936
+ cl_int, \
937
+ conditional_t <std::is_same<ConvertedT, cl_uchar>() || \
938
+ std::is_same<ConvertedT, cl_ushort>(), \
939
+ cl_uint, ConvertedT>>; \
940
+ OCLT Arg = x; \
941
+ /* ballot_group partitions its parent into two groups (0 and 1) */ \
942
+ /* We have to force each group down different control flow */ \
943
+ /* Work-items in the "false" group (0) may still be active */ \
944
+ constexpr auto Scope = group_scope<ParentGroup>::value; \
945
+ constexpr auto OpInt = static_cast <unsigned int >(Op); \
946
+ if (g.get_group_id () == 1 ) { \
947
+ return __spirv_GroupNonUniform##Instruction (Scope, OpInt, Arg); \
948
+ } else { \
949
+ return __spirv_GroupNonUniform##Instruction (Scope, OpInt, Arg); \
950
+ } \
951
+ }
952
+
953
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (SMin)
954
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (UMin)
955
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (FMin)
956
+
957
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (SMax)
958
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (UMax)
959
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (FMax)
960
+
961
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (IAdd)
962
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (FAdd)
963
+
964
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (IMulKHR)
965
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (FMulKHR)
966
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (CMulINTEL)
967
+
968
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (BitwiseOrKHR)
969
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (BitwiseXorKHR)
970
+ __SYCL_GROUP_COLLECTIVE_OVERLOAD (BitwiseAndKHR)
971
+
806
972
} // namespace spirv
807
973
} // namespace detail
808
974
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
0 commit comments