@@ -231,10 +231,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
231231 #
232232 # threads in a group work together to reduce values across the reduction dimensions;
233233 # we want as many as possible to improve algorithm efficiency and execution occupancy.
234- wanted_threads = shuffle ? nextwarp (kernel. kern. pipeline, length (Rreduce)) : length (Rreduce)
235- function compute_threads (max_threads)
234+ function compute_threads (kern)
235+ max_threads = KI. kernel_max_work_group_size (backend, kern)
236+ wanted_threads = shuffle ? nextwarp (kern. kern. pipeline, length (Rreduce)) : length (Rreduce)
236237 if wanted_threads > max_threads
237- shuffle ? prevwarp (kernel . kern. pipeline, max_threads) : max_threads
238+ shuffle ? prevwarp (kern . kern. pipeline, max_threads) : max_threads
238239 else
239240 wanted_threads
240241 end
@@ -244,7 +245,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
244245 # kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
245246 # kernel below, causing errors
246247 # reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
247- reduce_threads = compute_threads (KI . kernel_max_work_group_size (backend, kernel) )
248+ reduce_threads = compute_threads (kernel)
248249
249250 # how many groups should we launch?
250251 #
@@ -265,21 +266,33 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
265266 Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), R, A;
266267 numworkgroups= groups, workgroupsize= threads)
267268 else
268- # we need multiple steps to cover all values to reduce
269- partial = similar (R, (size (R)... , reduce_groups))
269+ # temporary empty array whose type will match the final partial array
270+ partial = similar (R, ntuple (_ -> 0 , Val (ndims (R)+ 1 )))
271+
272+ # NOTE: we can't use the previously-compiled kernel, or its launch configuration,
273+ # since the type of `partial` might not match the original output container
274+ # (e.g. if that was a view).
275+ partial_kernel = KI. KIKernel (backend, partial_mapreduce_device,
276+ f, op, init, Val (threads), Val (Rreduce),
277+ Val (Rother), Val (UInt64 (length (Rother))),
278+ Val (grain), Val (shuffle), partial, A)
279+ partial_reduce_threads = compute_threads (partial_kernel)
280+ partial_reduce_groups = cld (length (Rreduce), partial_reduce_threads * grain)
281+
282+ partial_threads = partial_reduce_threads
283+ partial_groups = partial_reduce_groups* other_groups
284+
285+ partial = similar (R, (size (R)... , partial_reduce_groups))
270286 if init === nothing
271287 # without an explicit initializer we need to copy from the output container
272- # use broadcasting to extend singleton dimensions
273288 partial .= R
274289 end
275- # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
276- # might not match the original output container (e.g. if that was a view).
277- KI. KIKernel (backend, partial_mapreduce_device,
278- f, op, init, Val (threads), Val (Rreduce), Val (Rother),
279- Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), partial, A)(
280- f, op, init, Val (threads), Val (Rreduce), Val (Rother),
281- Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), partial, A;
282- numworkgroups= groups, workgroupsize= threads)
290+
291+ partial_kernel (f, op, init, Val (threads), Val (Rreduce),
292+ Val (Rother), Val (UInt64 (length (Rother))),
293+ Val (grain), Val (shuffle), partial, A;
294+ numworkgroups= partial_groups, workgroupsize= partial_threads)
295+
283296
284297 GPUArrays. mapreducedim! (identity, op, R, partial; init= init)
285298 end
0 commit comments