Skip to content

Commit fa51a39

Browse files
don't spawn if only one job
1 parent cf5a07f commit fa51a39

File tree

4 files changed

+126
-71
lines changed

4 files changed

+126
-71
lines changed

src/conv.jl

+34-10
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ for (front_name, backend, signature) in (
181181
)
182182
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
183183
@eval begin
184-
184+
185185
function $(Symbol("$(front_name)!"))(
186186
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
187187
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
@@ -202,11 +202,19 @@ for (front_name, backend, signature) in (
202202
C_in = channels_in(cdims) ÷ groupcount(cdims),
203203
C_out = channels_out(cdims) ÷ groupcount(cdims))
204204

205-
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
205+
function do_work(xc, wc)
206206
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
207207
w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
208208
y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
209-
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
209+
$(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
210+
end
211+
212+
if length(x_cs) > 1
213+
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
214+
Threads.@spawn do_work(xc, wc)
215+
end
216+
else
217+
do_work(first(x_cs), first(w_cs))
210218
end
211219

212220
return out
@@ -246,11 +254,19 @@ for (front_name, backend, signature) in (
246254
C_in = channels_in(cdims) ÷ groupcount(cdims),
247255
C_out = channels_out(cdims) ÷ groupcount(cdims))
248256

249-
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
257+
function do_work(xc, yc, wc)
250258
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
251259
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
252260
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
253-
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...)
261+
$(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...)
262+
end
263+
264+
if length(dx_cs) > 1
265+
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
266+
Threads.@spawn do_work(xc, yc, wc)
267+
end
268+
else
269+
do_work(first(dx_cs), first(dy_cs), first(w_cs))
254270
end
255271

256272
return out
@@ -288,11 +304,19 @@ for (front_name, backend, signature) in (
288304
C_in = channels_in(cdims) ÷ groupcount(cdims),
289305
C_out = channels_out(cdims) ÷ groupcount(cdims))
290306

291-
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
307+
function do_work(wc, xc, yc)
292308
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
293309
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
294-
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
295-
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...)
310+
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...] # TODO: Is this supposed to use wc?
311+
$(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...)
312+
end
313+
314+
if length(dw_cs) > 1
315+
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
316+
Threads.@spawn do_work(wc, xc, yc)
317+
end
318+
else
319+
do_work(first(dw_cs), first(x_cs), first(dy_cs))
296320
end
297321

298322
return out
@@ -306,10 +330,10 @@ for (front_name, backend, signature) in (
306330
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
307331
(:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
308332
(:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
309-
333+
310334
(:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
311335
(:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
312-
336+
313337
(:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
314338
(:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
315339
)

src/gemm.jl

+20-10
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,33 @@ for (gemm, elt) in gemm_datatype_mappings
104104

105105
old_threads = get_num_threads()
106106
set_num_threads(1)
107-
Threads.@sync for ks in Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads))
108-
Threads.@spawn for k in ks
107+
108+
parts = Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads))
109+
110+
function do_work(ks)
111+
for k in ks
109112

110113
ptrAk = ptrA + (k-1) * strA * sizeof($elt)
111114
ptrBk = ptrB + (k-1) * strB * sizeof($elt)
112115
ptrCk = ptrC + (k-1) * strC * sizeof($elt)
113116

114117
ccall((@blasfunc($(gemm)), libblas), Nothing,
115-
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
116-
Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},
117-
Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},
118-
Ref{BlasInt}),
119-
transA, transB, m, n,
120-
ka, alpha, ptrAk, max(1,Base.stride(A,2)),
121-
ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,
122-
max(1,Base.stride(C,2)))
118+
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
119+
Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},
120+
Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},
121+
Ref{BlasInt}),
122+
transA, transB, m, n,
123+
ka, alpha, ptrAk, max(1,Base.stride(A,2)),
124+
ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,
125+
max(1,Base.stride(C,2)))
126+
end
127+
end
128+
if length(parts) > 1
129+
Threads.@sync for ks in parts
130+
Threads.@spawn do_work(ks)
123131
end
132+
else
133+
do_work(first(parts))
124134
end
125135
set_num_threads(old_threads)
126136

src/impl/conv_im2col.jl

+33-22
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,26 @@ function conv_im2col!(
4747

4848
parts = Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks))
4949

50-
@sync for (task_n, part) in enumerate(parts)
51-
Threads.@spawn begin
52-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
53-
for batch_idx in part
54-
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
55-
GC.@preserve col_slice w y begin
56-
col_ptr = pointer(col_slice)
57-
w_ptr = pointer(w)
58-
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
59-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
60-
end
50+
function do_work(task_n, part)
51+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
52+
for batch_idx in part
53+
im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
54+
GC.@preserve col_slice w y begin
55+
col_ptr = pointer(col_slice)
56+
w_ptr = pointer(w)
57+
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
58+
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
6159
end
6260
end
6361
end
62+
63+
if length(parts) > 1
64+
@sync for (task_n, part) in enumerate(parts)
65+
Threads.@spawn do_work(task_n, part)
66+
end
67+
else
68+
do_work(1, first(parts))
69+
end
6470
return y
6571
end
6672

@@ -152,19 +158,24 @@ function ∇conv_data_im2col!(
152158

153159
parts = Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks))
154160

155-
@sync for (task_n, part) in enumerate(parts)
156-
Threads.@spawn begin
157-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
158-
for batch_idx in part
159-
GC.@preserve col_slice w dy begin
160-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
161-
w_ptr = pointer(w)
162-
col_ptr = pointer(col_slice)
163-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
164-
end
165-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
161+
function do_work(task_n, part)
162+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
163+
for batch_idx in part
164+
GC.@preserve col_slice w dy begin
165+
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
166+
w_ptr = pointer(w)
167+
col_ptr = pointer(col_slice)
168+
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
166169
end
170+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
171+
end
172+
end
173+
if length(parts) > 1
174+
@sync for (task_n, part) in enumerate(parts)
175+
Threads.@spawn do_work(task_n, part)
167176
end
177+
else
178+
do_work(1, first(parts))
168179
end
169180
return dx
170181
end

src/impl/depthwiseconv_im2col.jl

+39-29
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,30 @@ function depthwiseconv_im2col!(
3030

3131
dcdims = DenseConvDims(cdims)
3232

33-
@sync for (task_n, part) in enumerate(parts)
34-
Threads.@spawn begin
35-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
36-
for batch_idx in part
37-
im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
38-
39-
# We do a separate convolution for each channel in x, as we must
40-
for c_in in 1:channels_in(cdims)
41-
# Walk each pointer forward as we process each input channel
42-
GC.@preserve col_slice w y begin
43-
col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
44-
w_ptr = pointer(w, (c_in-1)*K*N+1)
45-
y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
46-
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
47-
end
33+
function do_work(task_n, part)
34+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
35+
for batch_idx in part
36+
im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)
37+
38+
# We do a separate convolution for each channel in x, as we must
39+
for c_in in 1:channels_in(cdims)
40+
# Walk each pointer forward as we process each input channel
41+
GC.@preserve col_slice w y begin
42+
col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
43+
w_ptr = pointer(w, (c_in-1)*K*N+1)
44+
y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
45+
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
4846
end
4947
end
5048
end
5149
end
50+
if length(parts) > 1
51+
@sync for (task_n, part) in enumerate(parts)
52+
Threads.@spawn do_work(task_n, part)
53+
end
54+
else
55+
do_work(1, first(parts))
56+
end
5257
return y
5358
end
5459

@@ -117,23 +122,28 @@ function ∇depthwiseconv_data_im2col!(
117122

118123
parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks))
119124

120-
@sync for (task_n, part) in enumerate(parts)
121-
Threads.@spawn begin
122-
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
123-
for batch_idx in part
124-
# We do a separate convolution for each channel in x, as we must
125-
for cidx in 1:channels_in(cdims)
126-
GC.@preserve col_slice w dy begin
127-
# Walk each pointer forward as we process each input channel
128-
dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
129-
w_ptr = pointer(w, (cidx - 1)*K*N + 1)
130-
col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
131-
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
132-
end
125+
function do_work(task_n, part)
126+
col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace
127+
for batch_idx in part
128+
# We do a separate convolution for each channel in x, as we must
129+
for cidx in 1:channels_in(cdims)
130+
GC.@preserve col_slice w dy begin
131+
# Walk each pointer forward as we process each input channel
132+
dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
133+
w_ptr = pointer(w, (cidx - 1)*K*N + 1)
134+
col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
135+
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
133136
end
134-
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
135137
end
138+
col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
139+
end
140+
end
141+
if length(parts) > 1
142+
@sync for (task_n, part) in enumerate(parts)
143+
Threads.@spawn do_work(task_n, part)
136144
end
145+
else
146+
do_work(1, first(parts))
137147
end
138148
return dx
139149
end

0 commit comments

Comments
 (0)