Skip to content

Commit 85d80d1

Browse files
Don't spawn if: only one iteration or no threads or threading disabled (#633)
* don't spawn if only one job * fix what appears to be a typo * revertme: temporary test * fix check * rm test * add NNlib.ALLOW_THREADING control * use ScopedValues * use `@with` to avoid new scope * rename do_work functions * add note * v0.9.29
1 parent ec337e6 commit 85d80d1

9 files changed

+200
-94
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NNlib"
22
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3-
version = "0.9.28"
3+
version = "0.9.29"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -10,6 +10,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1010
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1314
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1415

1516
[weakdeps]
@@ -43,6 +44,7 @@ GPUArraysCore = "0.1, 0.2"
4344
KernelAbstractions = "0.9.2"
4445
LinearAlgebra = "<0.0.1, 1"
4546
Random = "<0.0.1, 1"
47+
ScopedValues = "1.3.0"
4648
SpecialFunctions = "2"
4749
Statistics = "1"
4850
cuDNN = "1"

docs/src/index.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,13 @@ for CUDA support, or
1212
```julia
1313
using NNlib, AMDGPU
1414
```
15-
for AMDGPU support.
15+
for AMDGPU support.
16+
17+
## Threading
18+
19+
Various `NNlib` functions utilize available Julia threads on divisible workloads. To disable this, use
20+
the `ScopedValue`-backed switch `NNlib.@disallow_spawns`
21+
i.e.
22+
```julia
23+
NNlib.@disallow_spawns function_that_uses_nnlib()
24+
```

src/NNlib.jl

+19-4
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,26 @@ using LinearAlgebra
1313
using LinearAlgebra.BLAS: @blasfunc, BlasInt
1414
using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose
1515
using Random
16+
using ScopedValues
1617
using Statistics
1718
using Statistics: mean
1819

1920
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}
2021

22+
# internal. TODO: change to an approach where amount of threading is controlled, not just on/off
23+
const ALLOW_SPAWNS = ScopedValue(true)
24+
should_use_spawn() = Threads.nthreads(:default) > 1 && ALLOW_SPAWNS[]
25+
"""
26+
@disallow_spawns ex
27+
28+
Disallow NNlib from using `@spawn` on divisible workloads, e.g. within `conv` etc.
29+
"""
30+
macro disallow_spawns(ex)
31+
quote
32+
@with ALLOW_SPAWNS => false $(esc(ex))
33+
end
34+
end
35+
2136
# Include APIs
2237
include("dim_helpers.jl")
2338
export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims
@@ -35,7 +50,7 @@ include("dropout.jl")
3550
export dropout, dropout!
3651

3752
include("softmax.jl")
38-
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
53+
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
3954
logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp
4055

4156
include("batched/batchedadjtrans.jl")
@@ -47,9 +62,9 @@ include("gemm.jl")
4762
export grid_sample, ∇grid_sample
4863

4964
include("conv.jl")
50-
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
51-
∇conv_filter!, depthwiseconv, depthwiseconv!,
52-
∇depthwiseconv_data, ∇depthwiseconv_data!,
65+
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
66+
∇conv_filter!, depthwiseconv, depthwiseconv!,
67+
∇depthwiseconv_data, ∇depthwiseconv_data!,
5368
∇depthwiseconv_filter, ∇depthwiseconv_filter!
5469

5570
include("conv_bias_act.jl")

src/conv.jl

+40-10
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ for (front_name, backend, signature) in (
181181
)
182182
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
183183
@eval begin
184-
184+
185185
function $(Symbol("$(front_name)!"))(
186186
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
187187
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
@@ -202,11 +202,21 @@ for (front_name, backend, signature) in (
202202
C_in = channels_in(cdims) ÷ groupcount(cdims),
203203
C_out = channels_out(cdims) ÷ groupcount(cdims))
204204

205-
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
205+
function conv_group(xc, wc)
206206
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
207207
w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
208208
y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
209-
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
209+
$(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
210+
end
211+
212+
if should_use_spawn() && length(x_cs) > 1
213+
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
214+
Threads.@spawn conv_group(xc, wc)
215+
end
216+
else
217+
for (xc, wc) in zip(x_cs, w_cs)
218+
conv_group(xc, wc)
219+
end
210220
end
211221

212222
return out
@@ -246,11 +256,21 @@ for (front_name, backend, signature) in (
246256
C_in = channels_in(cdims) ÷ groupcount(cdims),
247257
C_out = channels_out(cdims) ÷ groupcount(cdims))
248258

249-
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
259+
function ∇conv_data_group(xc, yc, wc)
250260
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
251261
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
252262
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
253-
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...)
263+
$(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...)
264+
end
265+
266+
if should_use_spawn() && length(dx_cs) > 1
267+
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
268+
Threads.@spawn ∇conv_data_group(xc, yc, wc)
269+
end
270+
else
271+
for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
272+
∇conv_data_group(xc, yc, wc)
273+
end
254274
end
255275

256276
return out
@@ -288,11 +308,21 @@ for (front_name, backend, signature) in (
288308
C_in = channels_in(cdims) ÷ groupcount(cdims),
289309
C_out = channels_out(cdims) ÷ groupcount(cdims))
290310

291-
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
311+
function ∇conv_filter_group(wc, xc, yc)
292312
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
293313
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
294-
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
295-
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...)
314+
dw = @view out[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
315+
$(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...)
316+
end
317+
318+
if should_use_spawn() && length(dw_cs) > 1
319+
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
320+
Threads.@spawn ∇conv_filter_group(wc, xc, yc)
321+
end
322+
else
323+
for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
324+
∇conv_filter_group(wc, xc, yc)
325+
end
296326
end
297327

298328
return out
@@ -306,10 +336,10 @@ for (front_name, backend, signature) in (
306336
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
307337
(:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
308338
(:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
309-
339+
310340
(:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
311341
(:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
312-
342+
313343
(:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
314344
(:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
315345
)

src/gemm.jl

+22-10
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,34 @@ for (gemm, elt) in gemm_datatype_mappings
104104

105105
old_threads = get_num_threads()
106106
set_num_threads(1)
107-
Threads.@sync for ks in Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads))
108-
Threads.@spawn for k in ks
107+
108+
parts = Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads))
109+
110+
function gemm!_part(ks)
111+
for k in ks
109112

110113
ptrAk = ptrA + (k-1) * strA * sizeof($elt)
111114
ptrBk = ptrB + (k-1) * strB * sizeof($elt)
112115
ptrCk = ptrC + (k-1) * strC * sizeof($elt)
113116

114117
ccall((@blasfunc($(gemm)), libblas), Nothing,
115-
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
116-
Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},
117-
Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},
118-
Ref{BlasInt}),
119-
transA, transB, m, n,
120-
ka, alpha, ptrAk, max(1,Base.stride(A,2)),
121-
ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,
122-
max(1,Base.stride(C,2)))
118+
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
119+
Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},
120+
Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},
121+
Ref{BlasInt}),
122+
transA, transB, m, n,
123+
ka, alpha, ptrAk, max(1,Base.stride(A,2)),
124+
ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,
125+
max(1,Base.stride(C,2)))
126+
end
127+
end
128+
if should_use_spawn() && length(parts) > 1
129+
Threads.@sync for ks in parts
130+
Threads.@spawn gemm!_part(ks)
131+
end
132+
else
133+
for ks in parts
134+
gemm!_part(ks)
123135
end
124136
end
125137
set_num_threads(old_threads)

src/impl/conv_im2col.jl

+37-22
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,28 @@ function conv_im2col!(
4747

4848
parts = Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks))
4949

50-
@sync for (task_n, part) in enumerate(parts)
51-
Threads.@spawn begin
52-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
53-
for batch_idx in part
54-
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
55-
GC.@preserve col_slice w y begin
56-
col_ptr = pointer(col_slice)
57-
w_ptr = pointer(w)
58-
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
59-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
60-
end
50+
function conv_part(task_n, part)
51+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
52+
for batch_idx in part
53+
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
54+
GC.@preserve col_slice w y begin
55+
col_ptr = pointer(col_slice)
56+
w_ptr = pointer(w)
57+
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
58+
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
6159
end
6260
end
6361
end
62+
63+
if should_use_spawn() && length(parts) > 1
64+
@sync for (task_n, part) in enumerate(parts)
65+
Threads.@spawn conv_part(task_n, part)
66+
end
67+
else
68+
for (task_n, part) in enumerate(parts)
69+
conv_part(task_n, part)
70+
end
71+
end
6472
return y
6573
end
6674

@@ -152,18 +160,25 @@ function ∇conv_data_im2col!(
152160

153161
parts = Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks))
154162

155-
@sync for (task_n, part) in enumerate(parts)
156-
Threads.@spawn begin
157-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
158-
for batch_idx in part
159-
GC.@preserve col_slice w dy begin
160-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
161-
w_ptr = pointer(w)
162-
col_ptr = pointer(col_slice)
163-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
164-
end
165-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
163+
function ∇conv_data_part(task_n, part)
164+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
165+
for batch_idx in part
166+
GC.@preserve col_slice w dy begin
167+
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
168+
w_ptr = pointer(w)
169+
col_ptr = pointer(col_slice)
170+
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
166171
end
172+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
173+
end
174+
end
175+
if should_use_spawn() && length(parts) > 1
176+
@sync for (task_n, part) in enumerate(parts)
177+
Threads.@spawn ∇conv_data_part(task_n, part)
178+
end
179+
else
180+
for (task_n, part) in enumerate(parts)
181+
∇conv_data_part(task_n, part)
167182
end
168183
end
169184
return dx

src/impl/depthwiseconv_im2col.jl

+43-29
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,32 @@ function depthwiseconv_im2col!(
3030

3131
dcdims = DenseConvDims(cdims)
3232

33-
@sync for (task_n, part) in enumerate(parts)
34-
Threads.@spawn begin
35-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
36-
for batch_idx in part
37-
im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
38-
39-
# We do a separate convolution for each channel in x, as we must
40-
for c_in in 1:channels_in(cdims)
41-
# Walk each pointer forward as we process each input channel
42-
GC.@preserve col_slice w y begin
43-
col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
44-
w_ptr = pointer(w, (c_in-1)*K*N+1)
45-
y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
46-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
47-
end
33+
function depthwiseconv_part(task_n, part)
34+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
35+
for batch_idx in part
36+
im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
37+
38+
# We do a separate convolution for each channel in x, as we must
39+
for c_in in 1:channels_in(cdims)
40+
# Walk each pointer forward as we process each input channel
41+
GC.@preserve col_slice w y begin
42+
col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
43+
w_ptr = pointer(w, (c_in-1)*K*N+1)
44+
y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
45+
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
4846
end
4947
end
5048
end
5149
end
50+
if should_use_spawn() && length(parts) > 1
51+
@sync for (task_n, part) in enumerate(parts)
52+
Threads.@spawn depthwiseconv_part(task_n, part)
53+
end
54+
else
55+
for (task_n, part) in enumerate(parts)
56+
depthwiseconv_part(task_n, part)
57+
end
58+
end
5259
return y
5360
end
5461

@@ -117,22 +124,29 @@ function ∇depthwiseconv_data_im2col!(
117124

118125
parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks))
119126

120-
@sync for (task_n, part) in enumerate(parts)
121-
Threads.@spawn begin
122-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
123-
for batch_idx in part
124-
# We do a separate convolution for each channel in x, as we must
125-
for cidx in 1:channels_in(cdims)
126-
GC.@preserve col_slice w dy begin
127-
# Walk each pointer forward as we process each input channel
128-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
129-
w_ptr = pointer(w, (cidx - 1)*K*N + 1)
130-
col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
131-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
132-
end
127+
function ∇depthwiseconv_data_part(task_n, part)
128+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
129+
for batch_idx in part
130+
# We do a separate convolution for each channel in x, as we must
131+
for cidx in 1:channels_in(cdims)
132+
GC.@preserve col_slice w dy begin
133+
# Walk each pointer forward as we process each input channel
134+
dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
135+
w_ptr = pointer(w, (cidx - 1)*K*N + 1)
136+
col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
137+
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
133138
end
134-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
135139
end
140+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
141+
end
142+
end
143+
if should_use_spawn() && length(parts) > 1
144+
@sync for (task_n, part) in enumerate(parts)
145+
Threads.@spawn ∇depthwiseconv_data_part(task_n, part)
146+
end
147+
else
148+
for (task_n, part) in enumerate(parts)
149+
∇depthwiseconv_data_part(task_n, part)
136150
end
137151
end
138152
return dx

0 commit comments

Comments
 (0)