Skip to content

Commit 92eb9fe

Browse files
committed
Avoid using unused one/zero values.
1 parent 92e9d74 commit 92eb9fe

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

src/sorting.jl

+12-13
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,6 @@ using ..CUDA: i32
3737
(eq && a′ == b′) || lt(a′, b′)
3838
end
3939

40-
# To allow sorting tuples of numbers:
41-
@inline _zero(x) = Base.zero(x)
42-
@inline _zero(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> zero(T.parameters[i]), N)
43-
44-
@inline _one(x) = Base.one(x)
45-
@inline _one(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> one(T.parameters[i]), N)
46-
4740

4841
# Batch partitioning
4942
"""
@@ -80,7 +73,12 @@ Uses block y index to decide which values to operate on.
8073
sync_threads()
8174
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
8275
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
83-
val = idx0 <= hi ? values[idx0] : _one(eltype(values))
76+
val = if idx0 <= hi
77+
values[idx0]
78+
else
79+
Ref{eltype(values)}()[] # undef
80+
# if idx0 > hi, val, comparison and dest_idx are unused
81+
end
8482
comparison = flex_lt(pivot, val, parity, lt, by)
8583

8684
@inbounds if idx0 <= hi
@@ -190,7 +188,7 @@ Must only run on 1 SM.
190188
swap = if threadIdx().x <= to_move
191189
vals[lo + a + threadIdx().x]
192190
else
193-
_zero(eltype(vals)) # unused value
191+
Ref{eltype(vals)}()[] # undef
194192
end
195193
sync_threads()
196194
if threadIdx().x <= to_move
@@ -222,7 +220,6 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
222220

223221
@inbounds swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
224222
sync_threads()
225-
old_val = _zero(eltype(swap))
226223

227224
log_blockDim = begin
228225
out = 0
@@ -245,8 +242,10 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
245242
to_swap = (i & k) == 0 && bitonic_lt(l, i) || (i & k) != 0 && bitonic_lt(i, l)
246243
to_swap = to_swap == (i < l)
247244

248-
if to_swap
249-
@inbounds old_val = swap[l + 1]
245+
old_val = if to_swap
246+
@inbounds swap[l + 1]
247+
else
248+
Ref{eltype(swap)}()[] # undef
250249
end
251250
sync_threads()
252251
if to_swap
@@ -279,7 +278,7 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
279278
buddy_val = if 1 <= buddy <= L && threadIdx().x <= L
280279
swap[buddy]
281280
else
282-
_zero(eltype(swap)) # unused value
281+
Ref{eltype(swap)}()[] # undef
283282
end
284283
sync_threads()
285284
if 1 <= buddy <= L && threadIdx().x <= L

0 commit comments

Comments
 (0)