11module CUDAKernels
22
33using .. CUDA
4- using .. CUDA: @device_override , CUSPARSE, default_memory, UnifiedMemory
4+ using .. CUDA: @device_override , CUSPARSE, default_memory, UnifiedMemory, cufunction, cudaconvert
55
66import KernelAbstractions as KA
7+ import KernelAbstractions: KI
78
89import StaticArrays
910import SparseArrays: AbstractSparseArray
@@ -157,34 +158,59 @@ function (obj::KA.Kernel{CUDABackend})(args...; ndrange=nothing, workgroupsize=n
157158 return nothing
158159end
159160
# Convert a host-side argument to its device-side representation before launch.
KI.argconvert(::CUDABackend, arg) = cudaconvert(arg)
162+
"""
    KI.kernel_function(::CUDABackend, f, tt=Tuple{}; name=nothing, kwargs...)

Compile `f` for the argument-type tuple `tt` via `cufunction` and wrap the
compiled kernel in a `KI.Kernel` tagged with the CUDA backend.
Keyword arguments are forwarded to `cufunction`.
"""
function KI.kernel_function(::CUDABackend, f::F, tt::TT=Tuple{}; name=nothing,
                            kwargs...) where {F, TT}
    kern = cufunction(f, tt; name, kwargs...)
    return KI.Kernel{CUDABackend, typeof(kern)}(CUDABackend(), kern)
end
167+
# Launch a compiled KI kernel: `numworkgroups` maps to CUDA blocks and
# `workgroupsize` maps to CUDA threads-per-block.
function (obj::KI.Kernel{CUDABackend})(args...; numworkgroups=1, workgroupsize=1)
    # Validate launch dimensions before handing them to the driver.
    KI.check_launch_args(numworkgroups, workgroupsize)

    obj.kern(args...; threads=workgroupsize, blocks=numworkgroups)
    return nothing
end
174+
175+
# Largest workgroup (threads per block) this specific compiled kernel supports,
# as reported by the CUDA occupancy API, optionally capped by `max_work_items`.
function KI.kernel_max_work_group_size(::CUDABackend, kikern::KI.Kernel{<:CUDABackend};
                                       max_work_items::Int=typemax(Int))::Int
    kernel_config = launch_configuration(kikern.kern.fun)
    return Int(min(kernel_config.threads, max_work_items))
end
# Device-wide upper bound on threads per block for the current device.
function KI.max_work_group_size(::CUDABackend)::Int
    return Int(attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK))
end
# Number of streaming multiprocessors (SMs) on the current device.
function KI.multiprocessor_count(::CUDABackend)::Int
    return Int(attribute(device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
end
187+
160188# # indexing
161189
162190# # COV_EXCL_START
# Thread index within the workgroup (1-based), per dimension.
@device_override @inline function KI.get_local_id()
    return (; x = Int(threadIdx().x), y = Int(threadIdx().y), z = Int(threadIdx().z))
end
166194
167195
# Workgroup (block) index within the grid (1-based), per dimension.
@device_override @inline function KI.get_group_id()
    return (; x = Int(blockIdx().x), y = Int(blockIdx().y), z = Int(blockIdx().z))
end
171199
# Global thread index across the whole launch (1-based), per dimension:
# (block - 1) * blockDim + thread.
@device_override @inline function KI.get_global_id()
    return (; x = Int((blockIdx().x - 1) * blockDim().x + threadIdx().x),
              y = Int((blockIdx().y - 1) * blockDim().y + threadIdx().y),
              z = Int((blockIdx().z - 1) * blockDim().z + threadIdx().z))
end
177203
# Workgroup size (threads per block), per dimension.
@device_override @inline function KI.get_local_size()
    return (; x = Int(blockDim().x), y = Int(blockDim().y), z = Int(blockDim().z))
end
181207
# Number of workgroups (blocks) in the grid, per dimension.
@device_override @inline function KI.get_num_groups()
    return (; x = Int(gridDim().x), y = Int(gridDim().y), z = Int(gridDim().z))
end
185211
# Total number of threads in the launch, per dimension: blockDim * gridDim.
@device_override @inline function KI.get_global_size()
    return (; x = Int(blockDim().x * gridDim().x),
              y = Int(blockDim().y * gridDim().y),
              z = Int(blockDim().z * gridDim().z))
end
189215
190216@device_override @inline function KA. __validindex (ctx)
198224
199225# # shared and scratch memory
200226
# Statically-sized workgroup-local (CUDA shared) memory allocation.
# NOTE(review): the old KA.SharedMemory hook also took a ::Val{Id} parameter;
# the KI interface drops it — confirm no caller still passes an Id.
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
    return CuStaticSharedArray(T, Dims)
end
204231
@@ -208,11 +235,11 @@ end
208235
209236# # synchronization and printing
210237
# Workgroup-wide barrier: synchronize all threads in the block.
@device_override @inline function KI.barrier()
    sync_threads()
end
214241
# Device-side printing, routed through CUDA's printf shim.
@device_override @inline function KI._print(args...)
    CUDA._cuprint(args...)
end
218245
0 commit comments