-
Notifications
You must be signed in to change notification settings - Fork 67
Bitonic_Sort #943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Bitonic_Sort #943
Conversation
|
[CI]: Can one of the admins verify this patch? |
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareExchangeWithPartner( | ||
| bool takeLarger, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerLoPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerHiPair, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool loSelfSmaller = comp(loPair.first, partnerLoPair.first); | ||
| const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller; | ||
| loPair.first = takePartnerLo ? partnerLoPair.first : loPair.first; | ||
| loPair.second = takePartnerLo ? partnerLoPair.second : loPair.second; | ||
|
|
||
| const bool hiSelfSmaller = comp(hiPair.first, partnerHiPair.first); | ||
| const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller; | ||
| hiPair.first = takePartnerHi ? partnerHiPair.first : hiPair.first; | ||
| hiPair.second = takePartnerHi ? partnerHiPair.second : hiPair.second; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Define an operator= for pair, then this becomes
if(takePartnerLo)
loPair = pLoPair;
if(takePartnerHi)
hiPair = pHiPair;don't worry about branching, since everything's an assignment it'll just be OpSelects under the hood
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can't define an operator= in HLSL none of the operators which should return references can be defined in HLSL so assignment and array indexing (as well as compound assignment)
However all structs in HLSL are trivial so can be assigned with =
| template<typename KeyType, typename ValueType> | ||
| void swap( | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair) | ||
| { | ||
| pair<KeyType, ValueType> temp = loPair; | ||
| loPair = hiPair; | ||
| hiPair = temp; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't work without a definition for operator= for pair. Your code compiles because you're not using it rn. We want to keep this version and rewrite all the swaps to use this version using pairs.
The definition for pair, the overload for operator= and this swap method all belong in https://github.com/Devsh-Graphics-Programming/Nabla/blob/master/include/nbl/builtin/hlsl/utility.hlsl, mimicking how std::pair is in the <utility> header in cpp. Move them over there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
operator= can only be defined in a #ifndef __HLSL_VERSION macro block, for DXC reasons
| static void mergeStage(NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID, | ||
| NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) loPair, NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) hiPair) | ||
| { | ||
| const uint32_t WorkgroupSize = config_t::WorkgroupSize; | ||
| const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2(); | ||
| comparator_t comp; | ||
|
|
||
| [unroll] | ||
| for (uint32_t pass = 0; pass <= stage; pass++) | ||
| { | ||
| if (pass) | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
|
||
| const uint32_t stridePower = (stage - pass + 1) + subgroupSizeLog2; | ||
| const uint32_t stride = 1u << stridePower; | ||
| const uint32_t threadStride = stride >> 1; | ||
|
|
||
| nbl::hlsl::pair<key_t, value_t> pLoPair = loPair; | ||
| shuffleXor(pLoPair, threadStride, sharedmemAccessor); | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
|
||
| nbl::hlsl::pair<key_t, value_t> pHiPair = hiPair; | ||
| shuffleXor(pHiPair, threadStride, sharedmemAccessor); | ||
|
|
||
| const bool isUpper = (invocationID & threadStride) != 0; | ||
| const bool takeLarger = isUpper == bitonicAscending; | ||
|
|
||
| nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loPair, pLoPair, hiPair, pHiPair, comp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you're using the shared memory accessor here, when you should be using an adaptor. The user only has to pass a generic sharedmem accessor, then the adaptor ensures its accesses are optimal. If changing accessor for adaptors is what breaks your code, we can look into making a specialized sharedmem adaptor for pairs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise, we're passing the burden of writing an optimal accessor to the user, which requires the user knowing the underlying impl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold this one until Matt replies because the adaptor needs to change here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise, we're passing the burden of writing an optimal accessor to the user, which requires the user knowing the underlying impl
yep
Removed commented-out template struct for pair.
| [unroll] | ||
| for (uint32_t strideLog = simpleLog - 1u; strideLog + 1u > 0u; strideLog--) | ||
| { | ||
| const uint32_t stride = 1u << strideLog; | ||
| [unroll] | ||
| for (uint32_t virtualThreadID = threadID; virtualThreadID < ElementsPerSimpleSort / 2; virtualThreadID += WorkgroupSize) | ||
| { | ||
| const uint32_t loIx = (((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u))) + offsetAccessor.offset; | ||
| const uint32_t hiIx = loIx | stride; | ||
|
|
||
| nbl::hlsl::pair<key_t, value_t> lopair, hipair; | ||
| accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair); | ||
| accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair); | ||
|
|
||
| swap(lopair, hipair); | ||
|
|
||
| accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair); | ||
| accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair); | ||
| } | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels overkill to do this. Instead of having to go over half the array to swap it around just make the E=2 struct specialization take an ascending parameter like you did for the subgroup sort. Set it to true by default, and in the calls above you need in this E > 2 struct pass ascending = !(WorkgroupID & 1) to ensure every even workgroup is sorted ascendingly and every odd workgroup is sorted descendingly. It's the same you did for the subgroup case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doesn't seem to be done and is important, but maybe lets try the Subgroup Transpose trick
| template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> > | ||
| struct bitonic_sort_config | ||
| { | ||
| using key_t = KeyType; | ||
| using value_t = ValueType; | ||
| using comparator_t = Comparator; | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where's the SubgroupSizeLog2 ?
| static void __call(bool ascending, NBL_REF_ARG(pair<key_t, value_t>) loPair, NBL_REF_ARG(pair<key_t, value_t>) hiPair) | ||
| { | ||
| const uint32_t invocationID = glsl::gl_SubgroupInvocationID(); | ||
| const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
needs to be compile time constant from config like the FFT
| template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> > | ||
| struct bitonic_sort_config | ||
| { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you also want to handle "multiple items per thread" in the subgroup sort
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareSwap( | ||
| bool ascending, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool shouldSwap = comp(hiPair.first, loPair.first); | ||
| const bool doSwap = (shouldSwap == ascending); | ||
|
|
||
| if (doSwap) | ||
| swap(loPair, hiPair); | ||
| } | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
write a TODO to make a struct which can be specialized by declaring a PoT sorting network within a thread
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
example
template<typename KeyValue, uint32_t Log2N>
struct SortingNetwork
{
void operator()(const bool ascending, NBL_REF_ARG(array<KeyValue,0x1u<<Log2N>) data, NBL_CONST_REF_ARG(Comparator) comp);
};and then you can partially specialize it for all Log2N from 1 to 3 or 4
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerLoPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerHiPair, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't do everything for pair<Key,Value> but a KeyValue on which get<0,KeyValue>(kv) and get<0,KeyValue>(kv) can be called (instead of .first and .second).
And then you can trivially allow pair<K,V> be used as KeyValue
Why? You will have the option of "packing" the key or value into the would-be padding bits of each other (when SoA isn't possible or beneficial) and doing all sorts of other tricks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
furthemore to allow working on more than 2 elements per thread the signature should really just be bool, NBL_REF_ARG(KeyValue) data[ElementsPerThread])
can do NBL_REF_ARG(array<U,ElementsPerThread>) if DXC gives you pain on array types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw its fine to not rewrite the function and static_assert(ElementsPErThread==2) for now
or just provide the partial spec for ElementsPerThread==2 (since static assert gives trouble in DXC SPIR-V backend) but then you need to work it into a struct functor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just talked with @Fletterio and we figured out how to trade subgroup variables to not have to perform redundant comparisons, this means 4-way Key-Value exchanges per thread won't be necessary and this function can disappear
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { | ||
| const bool loSelfSmaller = comp(loPair.first, partnerLoPair.first); | ||
| const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's a faster way to write this loSelfSmaller==takeLarger, look at a T/F table of your current expression
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are you sure this is correct?
#940 (comment)
if I myself am smaller, then comp which is < will return true, but I will swap with partner when takeLarger is true
did you mess up variable naming here?
| loPair = partnerLoPair; | ||
|
|
||
| const bool hiSelfSmaller = comp(hiPair.first, partnerHiPair.first); | ||
| const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's a faster way to write this hiSelfSmaller==takeLarger, look at a T/F table of your current expression
| for (uint32_t pass = 0; pass <= stage; pass++) | ||
| { | ||
| const uint32_t stride = 1u << (stage - pass); // Element stride |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not write the loop as for (uint32_t stride=1u<<stage; stride; stride=stride>>1) and forget the pass ?
| // The SharedMemoryAccessor MUST provide the following methods: | ||
| // * void get(uint32_t index, NBL_REF_ARG(uint32_t) value); | ||
| // * void set(uint32_t index, in uint32_t value); | ||
| // * void workgroupExecutionAndMemoryBarrier(); | ||
| template<typename T, typename V = uint32_t, typename I = uint32_t> | ||
| NBL_BOOL_CONCEPT BitonicSortSharedMemoryAccessor = concepts::accessors::GenericSharedMemoryAccessor<T, V, I>; | ||
|
|
||
| // The Accessor MUST provide the following methods: | ||
| // * void get(uint32_t index, NBL_REF_ARG(pair<KeyType, ValueType>) value); | ||
| // * void set(uint32_t index, in pair<KeyType, ValueType> value); | ||
| template<typename T, typename KeyType, typename ValueType, typename I = uint32_t> | ||
| NBL_BOOL_CONCEPT BitonicSortAccessor = concepts::accessors::GenericDataAccessor<T, pair<KeyType, ValueType>, I>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bitonic_sort::BitonicSort... is a tautology, drop the BitonicSort prefix from the names
| template<typename KeyType, typename ValueType, typename Comparator> | ||
| void compareSwap( | ||
| bool ascending, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) loPair, | ||
| NBL_REF_ARG(pair<KeyType, ValueType>) hiPair, | ||
| NBL_CONST_REF_ARG(Comparator) comp) | ||
| { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
transform into
template<typename KeyValue, uint32_t Log2N>
struct LocalPasses
{
void operator()(const bool ascending, NBL_REF_ARG(array<KeyValue,0x1u<<Log2N>) data, NBL_CONST_REF_ARG(Comparator) comp);
};and make your current implementation a partial spec for LocalPasses<KeyValue,2>
| const uint32_t invocationID = glsl::gl_SubgroupInvocationID(); | ||
| const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2(); | ||
| [unroll] | ||
| for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually I'd ask for an even cooler feature, would be nice if stage<=sortSizeLog2 and __call had a last argument with a default const uint32_t sortSizeLog2=Config::SubgroupSizeLog2
Why? Because it would be super useful for things like sorting arrays which are smaller than what a subgroup can process (so we can pack multiple arrays to be sorted independently into a single subgroup)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the const bool bitonicAscending = (stage == subgroupSizeLog2) ? would have to change to const bool bitonicAscending = (stage == sortSizeLog2 ) ?
| // Shuffle from partner using XOR | ||
| const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loPair.first, threadStride); | ||
| const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loPair.second, threadStride); | ||
| const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiPair.first, threadStride); | ||
| const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiPair.second, threadStride); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here's a problem with subgroupShuffle it only works on scalars and vectors.
Now why would we use a Comparison Based Sort over a Counting Sort?
Because the Key is really long in relation to number of elements to sort OR only supports custom comparison and can't be represented as a ordered number.
Therefore its most useful whenever the Key is a struct.
You need to document and require (conceptualize) that your KeyValue needs to have a shuffle_type typedef which is a vector<fundamental_type,N> and the KeyValue can be static_cast-ed to and from shuffle_type, you won't be able to just subgroupShuffle a pair<Key,Value> or anything similar
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually drop the static_cast if we had a stateful converter then we could avoid doing silly things
Why are we even trading/shuffling a full value type? when all that we need to do is give each Kev a surrogate Value which is just subgroupInvocationIndex which tells us how to shuffle at the very end.
| template<typename T1, typename T2> | ||
| struct pair | ||
| { | ||
| using first_type = T1; | ||
| using second_type = T2; | ||
|
|
||
| first_type first; | ||
| second_type second; | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
want a get and a get_helper struct in this header, plus a particla spec for pair in this header
| { | ||
| const uint32_t stride = 1u << strideLog; | ||
| [unroll] | ||
| for (uint32_t virtualThreadID = threadID; virtualThreadID < TotalElements / 2; virtualThreadID += WorkgroupSize) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this won't unroll, but this will
for (uint32_t baseVirtualThread=0; baseVirtualThread<HalfTotalElements; baseVirtualThread+=WorkgroupSize)
{
const uint32_t virtualThreadID = baseVirtualThread+threadID;the compiler now see that the loop condition is invariant (otherwise you have a dep on threadID and netiher DXC or SPIR-V opt can do the long and ardous analysis to see if your workgroup count is a multiple and the loop condition is actually uniform for whole group always)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
already pointed it out Devsh-Graphics-Programming/Nabla-Examples-and-Tests#209 (comment)
| namespace bitonic_sort | ||
| { | ||
|
|
||
| template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator = less<KeyType> NBL_PRIMARY_REQUIRES(_ElementsPerInvocationLog2 >= 1 && _WorkgroupSizeLog2 >= 5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should just require that _WorkgroupSizeLog2>SubgroupSizeLog2 I see you've hard coded an assumption that Subgroups are Size 32, the subgroup size needs to come into the config too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should also be a general TotalElementsLog2>SubgroupSizeLog2+ElementsPerInvocationLog2 then this tells us our VirtualWorkgroupSizeLog2
|
|
||
| NBL_CONSTEXPR_STATIC_INLINE uint32_t ElementsPerInvocation = 1u << ElementsPerInvocationLog2; | ||
| NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = 1u << WorkgroupSizeLog2; | ||
| NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedmemDWORDs = sizeof(pair<key_t, value_t>) / sizeof(uint32_t) * WorkgroupSize; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SharedMemorySize should come from outside, and it must be as big as sizeof(KeyValue::workgroup_type)<<WorkgroupSizeLog2
the idea is that if we have more shared memory, then we issue less barriers by:
- exchanging more items per barrier (up to
0x1u<<(WorkgroupSizeLog2+ElementPerInvocationLog2)) - cycling through offsets within the shared memory to barrier every so often (every time you wrap around)
| template<typename Config, class device_capabilities = void> | ||
| struct BitonicSort; | ||
|
|
||
| // ==================== ElementsPerThreadLog2 = 1 Specialization (No Virtual Threading) ==================== | ||
| // This handles arrays of size WorkgroupSize * 2 using subgroup + workgroup operations | ||
| template<uint16_t WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator, class device_capabilities> | ||
| struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyType, ValueType, Comparator>, device_capabilities> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Fletterio I think device_capabilities aren't really needed for anything here or the FFT because we baiscally require Shuffle operations, unlike where for the Prefix sum we use them to decide between native subgroupAdd and a thing made with subgroup shuffles
| template<typename Key, typename Value, typename key_adaptor_t, typename value_adaptor_t> | ||
| static void shuffleXor(NBL_REF_ARG(pair<Key, Value>) p, uint32_t ownedIdx, uint32_t mask, NBL_REF_ARG(key_adaptor_t) keyAdaptor, NBL_REF_ARG(value_adaptor_t) valueAdaptor) | ||
| { | ||
| keyAdaptor.template set<Key>(ownedIdx, p.first); | ||
| valueAdaptor.template set<Value>(ownedIdx, p.second); | ||
|
|
||
| // Wait until all writes are done before reading - only barrier on one adaptor here | ||
| keyAdaptor.workgroupExecutionAndMemoryBarrier(); | ||
|
|
||
| keyAdaptor.template get<Key>(ownedIdx ^ mask, p.first); | ||
| valueAdaptor.template get<Value>(ownedIdx ^ mask, p.second); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Fletterio an XOR shuffle of a struct (single accessor) within a workgroup should be commonalized somewhere cause its used in FFT and here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll try making a generic version but memory layouts fuck me up a bit. I'll iterate on a version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yeah with shuffling KeyValue and leaving the SoA/AoS bank-conflict decomposition to userspace Accessor allows us to use the workgroup shuffleXoR #940 (comment)
its probably worth extending that thing to handle Virtual Workgroups (when you have more shared memory and can do multiple shuffles without a barrier in the middle
| } | ||
|
|
||
| // PHASE 3: Global memory bitonic merge | ||
| const uint32_t totalLog = hlsl::findMSB(TotalElements - 1) + 1u; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TotalElements should already be passed as Log2
| if (sub) | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
|
||
| offsetAccessor.offset = sub * ElementsPerSimpleSort; | ||
|
|
||
| // Call E=1 workgroup sort | ||
| BitonicSort<simple_config_t, device_capabilities>::template __call(offsetAccessor, sharedmemAccessor); | ||
| } | ||
| sharedmemAccessor.workgroupExecutionAndMemoryBarrier(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
put the barrier at the end of the loop then you don't need if(sub) and the barrier after the loop
| { | ||
| const uint32_t WorkgroupSize = config_t::WorkgroupSize; | ||
|
|
||
| const uint32_t invocationID = glsl::gl_LocalInvocationID().x; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh btw the Spec doesn't say that LocalInvocationIndex = SubgroupID*SubgroupSize+SubgroupInvocationID
This is why in such shaders you need to always use
| uint16_t SubgroupContiguousIndex() |
instead of gl_LocalInvocationID or gl_LocalInvocationIndex
| const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2(); | ||
| const uint32_t subgroupSize = 1u << subgroupSizeLog2; | ||
| const uint32_t subgroupID = glsl::gl_SubgroupID(); | ||
| const uint32_t numSubgroups = WorkgroupSize / subgroupSize; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's actually a GLSL builtin for this
| // ==================== ElementsPerThreadLog2 > 1 Specialization (Virtual Threading) ==================== | ||
| // This handles larger arrays by combining global memory stages with recursive E=1 workgroup sorts | ||
| template<uint16_t ElementsPerThreadLog2, uint16_t WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator, class device_capabilities> | ||
| struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>, device_capabilities> | ||
| { | ||
| using config_t = bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>; | ||
| using simple_config_t = bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyType, ValueType, Comparator>; | ||
|
|
||
| using key_t = KeyType; | ||
| using value_t = ValueType; | ||
| using comparator_t = Comparator; | ||
|
|
||
| template<typename Accessor, typename SharedMemoryAccessor> | ||
| static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor) | ||
| { | ||
| const uint32_t WorkgroupSize = config_t::WorkgroupSize; | ||
| const uint32_t ElementsPerThread = config_t::ElementsPerInvocation; | ||
| const uint32_t TotalElements = WorkgroupSize * ElementsPerThread; | ||
| const uint32_t ElementsPerSimpleSort = WorkgroupSize * 2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ElementsPerThread should be the number of elements each thread operates on, the multiplier for Virtual Threading (what your current ElementsPerThread acts as) should be derived from TotalElementsLog2-WorkgroupSizeLog2-ElementsPerThreadLog2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in view of #943 (comment) you could actually count VirtualSubgroups
| const uint32_t subgroupSize = 1u << subgroupSizeLog2; | ||
| const uint32_t subgroupID = glsl::gl_SubgroupID(); | ||
| const uint32_t numSubgroups = WorkgroupSize / subgroupSize; | ||
| const uint32_t numSubgroupsLog2 = findMSB(numSubgroups); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be cool derive this as TotalElementsLog2-SubgroupSizeLog2-ElementsPerThreadLog2 this would allow us to do workgroup partitioned sorts, eg:
- sort 4x256 with a 512 workgroup elements per thread 2
| const bool isUpper = bool(invocationID & threadStride); | ||
| const bool takeLarger = isUpper == bitonicAscending; | ||
|
|
||
| nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loPair, partnerLoPair, hiPair, partnerHiPair, comp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comments throughout code needed about this stuff #940 (comment)
| template<typename Config, class device_capabilities = void> | ||
| struct BitonicSort; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
its declaraed in a tautologic namespace, as @Fletterio mentioned we don't want bitonic_sort::BitonicSort
| nbl::hlsl::pair<key_t, value_t> plopair = lopair; | ||
| shuffleXor(plopair, invocationID, threadStride, keyAdaptor, valueAdaptor); | ||
|
|
||
| nbl::hlsl::pair<key_t, value_t> phipair = hipair; | ||
| shuffleXor(phipair, invocationID ^ threadStride, threadStride, keyAdaptor, valueAdaptor); | ||
|
|
||
| const bool isUpper = (invocationID & threadStride) != 0; | ||
| const bool takeLarger = isUpper == bitonicAscending; | ||
|
|
||
| nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, lopair, plopair, hipair, phipair, comp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you're doing double the work... unlike subgroups where an invocation can't write into its partner's register. here you can definitely have one thread read both items, and write both items
| using key_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, uint32_t, uint32_t, 1, WorkgroupSize>; | ||
| using value_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, uint32_t, uint32_t, 1, WorkgroupSize, integral_constant<uint32_t, WorkgroupSize * sizeof(key_t) / sizeof(uint32_t)> >; | ||
|
|
||
| key_adaptor_t keyAdaptor; | ||
| keyAdaptor.accessor = sharedmemAccessor; | ||
| value_adaptor_t valueAdaptor; | ||
| valueAdaptor.accessor = sharedmemAccessor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can skip the StructureOfArrays in this code, more useful to use it in userspace , see #943 (comment)
Description
Testing
TODO list: