Skip to content

Commit 0e76668

Browse files
committed
Implement that fix again
1 parent 510bb2c commit 0e76668

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

src/mapreduce.jl

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
231231
#
232232
# threads in a group work together to reduce values across the reduction dimensions;
233233
# we want as many as possible to improve algorithm efficiency and execution occupancy.
234-
wanted_threads = shuffle ? nextwarp(kernel.kern.pipeline, length(Rreduce)) : length(Rreduce)
235-
function compute_threads(max_threads)
234+
function compute_threads(kern)
235+
max_threads = KI.kernel_max_work_group_size(backend, kern)
236+
wanted_threads = shuffle ? nextwarp(kern.kern.pipeline, length(Rreduce)) : length(Rreduce)
236237
if wanted_threads > max_threads
237-
shuffle ? prevwarp(kernel.kern.pipeline, max_threads) : max_threads
238+
shuffle ? prevwarp(kern.kern.pipeline, max_threads) : max_threads
238239
else
239240
wanted_threads
240241
end
@@ -244,7 +245,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
244245
# kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
245246
# kernel below, causing errors
246247
# reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
247-
reduce_threads = compute_threads(KI.kernel_max_work_group_size(backend, kernel))
248+
reduce_threads = compute_threads(kernel)
248249

249250
# how many groups should we launch?
250251
#
@@ -265,21 +266,33 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
265266
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
266267
numworkgroups=groups, workgroupsize=threads)
267268
else
268-
# we need multiple steps to cover all values to reduce
269-
partial = similar(R, (size(R)..., reduce_groups))
269+
# temporary empty array whose type will match the final partial array
270+
partial = similar(R, ntuple(_ -> 0, Val(ndims(R)+1)))
271+
272+
# NOTE: we can't use the previously-compiled kernel, or its launch configuration,
273+
# since the type of `partial` might not match the original output container
274+
# (e.g. if that was a view).
275+
partial_kernel = KI.KIKernel(backend, partial_mapreduce_device,
276+
f, op, init, Val(threads), Val(Rreduce),
277+
Val(Rother), Val(UInt64(length(Rother))),
278+
Val(grain), Val(shuffle), partial, A)
279+
partial_reduce_threads = compute_threads(partial_kernel)
280+
partial_reduce_groups = cld(length(Rreduce), partial_reduce_threads * grain)
281+
282+
partial_threads = partial_reduce_threads
283+
partial_groups = partial_reduce_groups*other_groups
284+
285+
partial = similar(R, (size(R)..., partial_reduce_groups))
270286
if init === nothing
271287
# without an explicit initializer we need to copy from the output container
272-
# use broadcasting to extend singleton dimensions
273288
partial .= R
274289
end
275-
# NOTE: we can't use the previously-compiled kernel, since the type of `partial`
276-
# might not match the original output container (e.g. if that was a view).
277-
KI.KIKernel(backend, partial_mapreduce_device,
278-
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
279-
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)(
280-
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
281-
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A;
282-
numworkgroups=groups, workgroupsize=threads)
290+
291+
partial_kernel(f, op, init, Val(threads), Val(Rreduce),
292+
Val(Rother), Val(UInt64(length(Rother))),
293+
Val(grain), Val(shuffle), partial, A;
294+
numworkgroups=partial_groups, workgroupsize=partial_threads)
295+
283296

284297
GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
285298
end

0 commit comments

Comments
 (0)