@@ -37,13 +37,6 @@ using ..CUDA: i32
37
37
(eq && a′ == b′) || lt (a′, b′)
38
38
end
39
39
40
- # To allow sorting tuples of numbers:
41
- @inline _zero (x) = Base. zero (x)
42
- @inline _zero (:: Type{T} ) where {T<: Tuple{Vararg{Any,N}} } where {N} = ntuple (i -> zero (T. parameters[i]), N)
43
-
44
- @inline _one (x) = Base. one (x)
45
- @inline _one (:: Type{T} ) where {T<: Tuple{Vararg{Any,N}} } where {N} = ntuple (i -> one (T. parameters[i]), N)
46
-
47
40
48
41
# Batch partitioning
49
42
"""
@@ -80,7 +73,12 @@ Uses block y index to decide which values to operate on.
80
73
sync_threads ()
81
74
blockIdx_yz = (blockIdx (). z - 1 i32) * gridDim (). y + blockIdx (). y
82
75
idx0 = lo + (blockIdx_yz - 1 i32) * blockDim (). x + threadIdx (). x
83
- val = idx0 <= hi ? values[idx0] : _one (eltype (values))
76
+ val = if idx0 <= hi
77
+ values[idx0]
78
+ else
79
+ Ref {eltype(values)} ()[] # undef
80
+ # if idx0 > hi, val is an uninitialized placeholder: it is still passed to flex_lt below, but val, comparison and dest_idx are not used afterwards
81
+ end
84
82
comparison = flex_lt (pivot, val, parity, lt, by)
85
83
86
84
@inbounds if idx0 <= hi
@@ -190,7 +188,7 @@ Must only run on 1 SM.
190
188
swap = if threadIdx (). x <= to_move
191
189
vals[lo + a + threadIdx (). x]
192
190
else
193
- _zero ( eltype (vals)) # unused value
191
+ Ref { eltype(vals)} ()[] # undef
194
192
end
195
193
sync_threads ()
196
194
if threadIdx (). x <= to_move
@@ -222,7 +220,6 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
222
220
223
221
@inbounds swap[threadIdx (). x] = vals[lo + threadIdx (). x * stride]
224
222
sync_threads ()
225
- old_val = _zero (eltype (swap))
226
223
227
224
log_blockDim = begin
228
225
out = 0
@@ -245,8 +242,10 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
245
242
to_swap = (i & k) == 0 && bitonic_lt (l, i) || (i & k) != 0 && bitonic_lt (i, l)
246
243
to_swap = to_swap == (i < l)
247
244
248
- if to_swap
249
- @inbounds old_val = swap[l + 1 ]
245
+ old_val = if to_swap
246
+ @inbounds swap[l + 1 ]
247
+ else
248
+ Ref {eltype(swap)} ()[] # undef
250
249
end
251
250
sync_threads ()
252
251
if to_swap
@@ -279,7 +278,7 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
279
278
buddy_val = if 1 <= buddy <= L && threadIdx (). x <= L
280
279
swap[buddy]
281
280
else
282
- _zero ( eltype (swap)) # unused value
281
+ Ref { eltype(swap)} ()[] # undef
283
282
end
284
283
sync_threads ()
285
284
if 1 <= buddy <= L && threadIdx (). x <= L
0 commit comments