@inline function reduce_group(op, val::T, neutral, shuffle::Val{true}, ::Val{maxthreads}) where {T, maxthreads}
    # shared mem for partial sums
    assume(threads_per_simdgroup() == 32)
-   shared = MtlThreadGroupArray(T, 32)
+   shared = KI.localmemory(T, 32)

    wid = simdgroup_index_in_threadgroup()
    lane = thread_index_in_simdgroup()
@@ -34,10 +34,10 @@
    end

    # wait for all partial reductions
-   threadgroup_barrier(MemoryFlagThreadGroup)
+   KI.barrier()

    # read from shared memory only if that warp existed
-   val = if thread_index_in_threadgroup() <= fld1(threads_per_threadgroup().x, 32)
+   val = if KI.get_local_id().x <= fld1(KI.get_local_size().x, 32)
        @inbounds shared[lane]
    else
        neutral
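The guard above only reads a partial result for the first `fld1(local size, 32)` items, i.e. one shared-memory slot per simdgroup that actually ran. A host-side check of that arithmetic in plain Julia (no GPU required; the thread counts are illustrative):

# fld1 is 1-based flooring division, so for n >= 1 it equals cld(n, 32):
# the number of 32-wide simdgroups needed to cover n items.
for threads in (1, 32, 33, 100, 256)
    nsimdgroups = fld1(threads, 32)
    @assert nsimdgroups == cld(threads, 32)
    println(threads, " threads -> ", nsimdgroups, " partial value(s) in shared memory")
end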
@@ -52,17 +52,17 @@

# Reduce a value across a group, using local memory for communication
@inline function reduce_group(op, val::T, neutral, shuffle::Val{false}, ::Val{maxthreads}) where {T, maxthreads}
-   threads = threads_per_threadgroup().x
-   thread = thread_position_in_threadgroup().x
+   threads = KI.get_local_size().x
+   thread = KI.get_local_id().x

    # local mem for a complete reduction
-   shared = MtlThreadGroupArray(T, (maxthreads,))
+   shared = KI.localmemory(T, (maxthreads,))
    @inbounds shared[thread] = val

    # perform a reduction
    d = 1
    while d < threads
-       threadgroup_barrier(MemoryFlagThreadGroup)
+       KI.barrier()
        index = 2 * d * (thread-1) + 1
        @inbounds if index <= threads
            other_val = if index + d <= threads
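The `while d < threads` loop is a stride-doubling tree reduction in local memory. A serial, host-side model of the same indexing; the kernel's continuation is outside this hunk, so the write-back line is an assumption based on the visible pattern, and `tree_reduce!` with all its values is purely illustrative:

# Host-side model of the stride-doubling reduction: `shared` stands in for the
# local-memory buffer and every "thread" is just a loop index here.
function tree_reduce!(op, neutral, shared::Vector)
    threads = length(shared)
    d = 1
    while d < threads
        # KI.barrier() separates passes in the kernel; a serial loop needs no synchronization
        for thread in 1:threads
            index = 2 * d * (thread - 1) + 1
            if index <= threads
                other_val = index + d <= threads ? shared[index + d] : neutral
                shared[index] = op(shared[index], other_val)
            end
        end
        d *= 2
    end
    return shared[1]
end

@assert tree_reduce!(+, 0.0, collect(1.0:7.0)) == 28.0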
@@ -94,9 +94,9 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
                                  ::Val{Rother}, ::Val{Rlen}, ::Val{grain}, shuffle, R, As...) where {Rreduce, Rother, Rlen, grain}
    # decompose the 1D hardware indices into separate ones for reduction (across items
    # and possibly groups if it doesn't fit) and other elements (remaining groups)
-   localIdx_reduce = thread_position_in_threadgroup().x
-   localDim_reduce = threads_per_threadgroup().x * grain
-   groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid().x, Rlen)
+   localIdx_reduce = KI.get_local_id().x
+   localDim_reduce = KI.get_local_size().x * grain
+   groupIdx_reduce, groupIdx_other = fldmod1(KI.get_group_id().x, Rlen)

    # group-based indexing into the values outside of the reduction dimension
    # (that means we can safely synchronize items within this group)
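`fldmod1` turns the flat group index into a pair: which reduction group this is, and which slice outside the reduction dimension it works on. A quick host-side check with an illustrative `Rlen` (the number of "other" slices):

# Mirrors `fldmod1(KI.get_group_id().x, Rlen)` above, with made-up numbers.
Rlen = 4
for group in 1:8
    groupIdx_reduce, groupIdx_other = fldmod1(group, Rlen)
    println("group ", group, " -> reduce group ", groupIdx_reduce, ", other slice ", groupIdx_other)
end
# groups 1:4 all belong to reduce group 1 and sweep slices 1:4; groups 5:8 repeat that for reduce group 2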
@@ -141,7 +141,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
end

function serial_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
-   grid_idx = thread_position_in_grid().x
+   grid_idx = KI.get_global_id().x

    @inbounds if grid_idx <= length(Rother)
        Iother = Rother[grid_idx]
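In this serial strategy each global work item owns one slice outside the reduction dimensions and loops over the whole reduction range itself, with no barriers or shared memory. A CPU model of that mapping; the function name, `neutral` handling, and index arithmetic are illustrative rather than the exact kernel body:

# CPU model of the serial strategy: one "work item" per output slice,
# each reducing its slice sequentially and independently.
function serial_mapreduce_model(f, op, neutral, A; dims)
    R = fill(neutral, ntuple(d -> d in dims ? 1 : size(A, d), ndims(A)))
    Rother  = CartesianIndices(axes(R))                                            # one entry per work item
    Rreduce = CartesianIndices(ntuple(d -> d in dims ? axes(A, d) : Base.OneTo(1), ndims(A)))
    for Iother in Rother                              # plays the role of grid_idx
        val = neutral
        for Ireduce in Rreduce
            val = op(val, f(A[max(Iother, Ireduce)])) # combine the two index parts elementwise
        end
        R[Iother] = val
    end
    return R
end

A = reshape(1:12, 3, 4)
@assert serial_mapreduce_model(identity, +, 0, A; dims=(2,)) == sum(A; dims=2)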
@@ -166,11 +166,12 @@ end

## COV_EXCL_STOP

-serial_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
+serial_mapreduce_threshold(dev) = KI.max_work_group_size(MetalBackend()) * KI.multiprocessor_count(MetalBackend())

function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
                                 A::Union{AbstractArray,Broadcast.Broadcasted};
                                 init=nothing) where {F, OP, T}
+   backend = MetalBackend()
    Base.check_reducedims(R, A)
    length(A) == 0 && return R # isempty(::Broadcasted) iterates

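The threshold multiplies the backend's maximum work-group size by its multiprocessor count, so the serial kernel is preferred once there are enough independent output slices to fill every core with a full group. With made-up device numbers:

# Illustrative values only; the real ones come from KI.max_work_group_size(MetalBackend())
# and KI.multiprocessor_count(MetalBackend()).
max_work_group_size  = 1024
multiprocessor_count = 8
threshold = max_work_group_size * multiprocessor_count
@assert threshold == 8192   # take the serial path when length(Rother) >= this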
@@ -195,10 +196,10 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},

    # If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
    if length(Rother) >= serial_mapreduce_threshold(device(R))
-       kernel = @metal launch=false serial_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
-       threads = min(length(Rother), kernel.pipeline.maxTotalThreadsPerThreadgroup)
+       kernel = KI.KIKernel(backend, serial_mapreduce_kernel, f, op, init, Val(Rreduce), Val(Rother), R, A)
+       threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=length(Rother))
        groups = cld(length(Rother), threads)
-       kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; threads, groups)
+       kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; numworkgroups=groups, workgroupsize=threads)
        return R
    end

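For reference, the launch pattern the serial branch now uses is: build a `KI.KIKernel` from a kernel function, ask `KI.kernel_max_work_group_size` how many items a group may hold, then call the kernel object with `workgroupsize`/`numworkgroups` keywords. A minimal sketch with a made-up `fill_kernel`, assuming `KI` and `MetalBackend` resolve exactly as they do in this file; only KI calls already visible in this diff are used:

using Metal

function fill_kernel(out, value)
    i = KI.get_global_id().x
    if i <= length(out)
        @inbounds out[i] = value
    end
    return
end

backend = MetalBackend()
out     = MtlArray{Float32}(undef, 4096)
kernel  = KI.KIKernel(backend, fill_kernel, out, 1f0)
threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=length(out))
groups  = cld(length(out), threads)
kernel(out, 1f0; numworkgroups=groups, workgroupsize=threads)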
@@ -223,17 +224,17 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
    # we might not be able to launch all those threads to reduce each slice in one go.
    # that's why each thread also loops across their inputs, processing multiple values
    # so that we can span the entire reduction dimension using a single item group.
-   kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+   kernel = KI.KIKernel(backend, partial_mapreduce_device, f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                                 Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)

    # how many threads do we want?
    #
    # threads in a group work together to reduce values across the reduction dimensions;
    # we want as many as possible to improve algorithm efficiency and execution occupancy.
-   wanted_threads = shuffle ? nextwarp(kernel.pipeline, length(Rreduce)) : length(Rreduce)
+   wanted_threads = shuffle ? nextwarp(kernel.kern.pipeline, length(Rreduce)) : length(Rreduce)
    function compute_threads(max_threads)
        if wanted_threads > max_threads
-           shuffle ? prevwarp(kernel.pipeline, max_threads) : max_threads
+           shuffle ? prevwarp(kernel.kern.pipeline, max_threads) : max_threads
        else
            wanted_threads
        end
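`nextwarp`/`prevwarp` round a thread count up or down to a multiple of the pipeline's SIMD width so the shuffle path always sees full simdgroups. A plain-integer model of the clamping `compute_threads` performs, with the warp width fixed at 32 and the helper names invented for illustration:

nextwarp_32(n) = 32 * cld(n, 32)   # round up to a full simdgroup
prevwarp_32(n) = 32 * fld(n, 32)   # round down to a full simdgroup

function compute_threads_model(wanted_threads, max_threads; shuffle::Bool)
    if wanted_threads > max_threads
        shuffle ? prevwarp_32(max_threads) : max_threads
    else
        wanted_threads
    end
end

@assert compute_threads_model(2000, 1024; shuffle=true) == 1024   # already a multiple of 32
@assert compute_threads_model(2000, 1000; shuffle=true) == 992    # clamped down to 31 * 32
@assert compute_threads_model(nextwarp_32(100), 1024; shuffle=true) == 128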
@@ -243,7 +244,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
    # kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
    # kernel below, causing errors
    # reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
-   reduce_threads = compute_threads(512)
+   reduce_threads = compute_threads(KI.kernel_max_work_group_size(backend, kernel))

    # how many groups should we launch?
    #
@@ -262,7 +263,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
        # we can cover the dimensions to reduce using a single group
        kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
               Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
-              threads, groups)
+              numworkgroups=groups, workgroupsize=threads)
    else
        # we need multiple steps to cover all values to reduce
        partial = similar(R, (size(R)..., reduce_groups))
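In the multi-step branch each group writes its result into an extra trailing dimension of size `reduce_groups`, and the follow-up `mapreducedim!` at the end of the function folds that dimension back into `R`. Host-side stand-ins for the shapes involved (all sizes made up):

R             = zeros(Float32, 3, 1)                       # final output, e.g. a row reduction of a 3x4 input
reduce_groups = 5
partial       = rand(Float32, size(R)..., reduce_groups)   # stand-in for similar(R, (size(R)..., reduce_groups))
R .= dropdims(sum(partial; dims=3); dims=3)                # role played by the final reduction over `partial`
@assert size(R) == (3, 1)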
@@ -273,9 +274,12 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
        end
        # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
        # might not match the original output container (e.g. if that was a view).
-       @metal threads groups partial_mapreduce_device(
+       KI.KIKernel(backend, partial_mapreduce_device,
            f, op, init, Val(threads), Val(Rreduce), Val(Rother),
-           Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)
+           Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)(
+           f, op, init, Val(threads), Val(Rreduce), Val(Rother),
+           Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A;
+           numworkgroups=groups, workgroupsize=threads)

        GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
    end
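None of this changes the user-facing entry points: reductions on `MtlArray`s still reach `GPUArrays.mapreducedim!` through the usual Base API. A quick smoke test; these sizes exercise the partial-reduction path (a much larger number of output slices would take the serial path), and `isapprox` is used because the reduction order can differ from the CPU:

using Metal

A = MtlArray(rand(Float32, 1024, 16))
@assert Array(sum(A; dims=1)) ≈ sum(Array(A); dims=1)   # reduce along the first dimension
@assert Array(sum(A; dims=2)) ≈ sum(Array(A); dims=2)   # reduce along the second dimension
@assert sum(A) ≈ sum(Array(A))                          # full reduction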