Commit 23f0976

Combining dense, depthwise and groupwise convolutions through common interface `DenseConvDims`

1 parent: e4fc929

6 files changed, +395 -413 lines
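As a point of orientation (not part of the commit itself), the unified interface is meant to cover dense, grouped and depthwise convolutions with a single `DenseConvDims` type. The sketch below only exercises the channel bookkeeping introduced in the diffs that follow (the `groupcount` keyword and the `w_size[end-1] * groupcount` channel rule); the array sizes are invented, and actual grouped execution on the im2col/direct backends is explicitly left as follow-up work by the commit.

using NNlib

# Input in WHCN layout: 16×16 spatial, 4 channels, batch of 2.
x = rand(Float32, 16, 16, 4, 2)

# Dense convolution: groupcount defaults to 1, kernel is (kw, kh, C_in, C_out).
w_dense = rand(Float32, 3, 3, 4, 8)
cd_dense = DenseConvDims(x, w_dense)

# Grouped convolution: the kernel's next-to-last dim is C_in ÷ groupcount.
w_grouped = rand(Float32, 3, 3, 2, 8)              # 4 input channels in 2 groups
cd_grouped = DenseConvDims(x, w_grouped; groupcount = 2)

# Depthwise convolution is the groupcount == C_in special case.
w_dw = rand(Float32, 3, 3, 1, 4)
cd_dw = DenseConvDims(x, w_dw; groupcount = 4)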

src/conv.jl (+10 -28)
@@ -1,6 +1,4 @@
-export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv,
-       depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter,
-       ∇depthwiseconv_filter!
+export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
 
 ## Convolution API
 #
@@ -36,9 +34,6 @@ for (front_name, backend) in (
         :conv => :im2col,
         :∇conv_data => :im2col,
         :∇conv_filter => :im2col,
-        :depthwiseconv => :im2col,
-        :∇depthwiseconv_data => :im2col,
-        :∇depthwiseconv_filter => :im2col,
     )
 
 # These are the GEMM types we will accelerate with `im2col`
@@ -58,8 +53,7 @@ end
 # Our strategy for 1d and 2d convolution is to reshape to 3d convolutions, which
 # makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
 # since we can specialize on sizes.
-for front_name in (:conv, :∇conv_data, :∇conv_filter,
-                   :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
+for front_name in (:conv, :∇conv_data, :∇conv_filter)
     for backend in (Symbol(), :_direct, :_im2col)
         for N in (3, 4)
             @eval begin
@@ -87,8 +81,7 @@ end
 # We always support a fallback, non-accelerated path, where we use the direct, but
 # slow, implementations. These should not typically be used, hence the `@debug`,
 # but let's ggo ahead and define them first:
-for front_name in (:conv, :∇conv_data, :∇conv_filter,
-                   :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
+for front_name in (:conv, :∇conv_data, :∇conv_filter)
     @eval begin
         function $(Symbol("$(front_name)!"))(
                         y::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
@@ -106,7 +99,7 @@ end
 # allocation. :P
 for backend in (Symbol(), :_direct, :_im2col)
     # First make auto-allocating versions of the conv()-like calls:
-    for name in (:conv, :depthwiseconv)
+    for name in (:conv,)
         @eval begin
             function $(Symbol("$(name)$(backend)"))(
                             x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
@@ -118,7 +111,7 @@ for backend in (Symbol(), :_direct, :_im2col)
         end
     end
 
-    for name in (:∇conv_data, :∇depthwiseconv_data)
+    for name in (:∇conv_data,)
         @eval begin
             function $(Symbol("$(name)$(backend)"))(
                             dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
@@ -130,28 +123,17 @@ for backend in (Symbol(), :_direct, :_im2col)
         end
     end
 
-    # We do the conv/depthwiseconv filter backprops separately, as the shape calculation
-    # for `w` is slightly different for depthwise than for normal dense convolution.
+    # This filter back prop covers dense/depthwise/groupwise conv filter backprops, as groupcount alone
+    # is a deciding factor from cudnn's perspective. For backends im2col and direct needs to be handled.
     @eval begin
         function $(Symbol("∇conv_filter$(backend)"))(
                         x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
                         cdims::ConvDims; kwargs...) where {xT, yT, N}
-            dw = similar(dy, kernel_size(cdims)..., channels_in(cdims),
+            dw = similar(dy, kernel_size(cdims)..., div(channels_in(cdims),group_count(cdims)),
                          channels_out(cdims))
            return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
         end
     end
-
-    @eval begin
-        function $(Symbol("∇depthwiseconv_filter$(backend)"))(
-                        x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
-                        cdims::ConvDims; kwargs...) where {xT, yT, N}
-            dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims),
-                         channels_in(cdims))
-            return $(Symbol("∇depthwiseconv_filter$(backend)!"))(dw, x, dy, cdims;
-                                                                 kwargs...)
-        end
-    end
 end
 
 
@@ -172,10 +154,10 @@ function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flip
     return conv(x, w, cdims)
 end
 
-function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false) where {T, N}
+function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groupcount) where {T, N}
     stride = expand(Val(N-2), stride)
     pad = expand(Val(N-2), pad)
     dilation = expand(Val(N-2), dilation)
-    cdims = DepthwiseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
+    cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped, groupcount=groupcount)
     return depthwiseconv(x, w, cdims)
 end
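To make the shape change in the auto-allocating `∇conv_filter` concrete: the gradient buffer's input-channel dimension is now `channels_in ÷ group_count`, which reduces to the old dense behaviour when the group count is 1. A small arithmetic sketch with invented sizes, mirroring the `similar` call above:

# Hypothetical sizes, purely for illustration of the allocation above.
kernel = (3, 3)
C_in, groups, C_out = 8, 4, 16
dw_size = (kernel..., div(C_in, groups), C_out)    # (3, 3, 2, 16)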

src/dim_helpers/ConvDims.jl (+1 -3)
@@ -16,9 +16,7 @@ abstract type ConvDims{N, S, P, D, F} end
 
 # Hack to get rid of type parameters
 function basetype(::Type{C}) where {C <: ConvDims}
-    if C <: DepthwiseConvDims
-        return DepthwiseConvDims
-    elseif C <: DenseConvDims
+    if C <: DenseConvDims
         return DenseConvDims
     elseif C <: PoolDims
         return PoolDims
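With `DepthwiseConvDims` gone, `basetype` only has to distinguish `DenseConvDims` from `PoolDims`. Its job is simply to strip type parameters; a minimal sketch (`basetype` is an internal helper, hence the qualified call; sizes are invented):

using NNlib

cdims = DenseConvDims((16, 16, 4, 2), (3, 3, 2, 8); groupcount = 2)
NNlib.basetype(typeof(cdims)) === DenseConvDims    # true: type parameters stripped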

src/dim_helpers/DenseConvDims.jl (+16 -14)
@@ -5,30 +5,31 @@ export DenseConvDims
 
 Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d.
 """
-struct DenseConvDims{N,K,C_in,C_out,S,P,D,F} <: ConvDims{N,S,P,D,F}
+struct DenseConvDims{N,K,C_in,C_out,S,P,D,F,G} <: ConvDims{N,S,P,D,F}
     I::NTuple{N,Int}
 end
 
 # Getters for the fields
 input_size(c::DenseConvDims) = c.I
-kernel_size(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = K
-channels_in(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_in
-channels_out(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_out
+kernel_size(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = K
+channels_in(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_in
+channels_out(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_out
+group_count(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = G
 
 # Convenience wrapper to create DenseConvDims objects
 function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
-                       stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M
+                       stride=1, padding=0, dilation=1, flipkernel::Bool=false, groupcount=1) where M
     # Do common parameter validation
     stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)
 
     # Ensure channels are equal
-    if x_size[end-1] != w_size[end-1]
+    if x_size[end-1] != w_size[end-1]*groupcount
         xs = x_size[end-1]
-        ws = w_size[end-1]
+        ws = w_size[end-1]*groupcount
         throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
     end
-
-    # The type parameters are what
+
+    # The type parameters are what
     return DenseConvDims{
         M - 2,
         w_size[1:end-2],
@@ -37,7 +38,8 @@ function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
         stride,
         padding,
         dilation,
-        flipkernel
+        flipkernel,
+        groupcount
     }(
         # Input spatial size
         x_size[1:end-2],
@@ -56,17 +58,17 @@ end
 # from the original progenitor object that it inherits shapes from.
 function DenseConvDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
                        C_in=channels_in(c), C_out=channels_out(c), S=stride(c),
-                       P=padding(c), D=dilation(c), F=flipkernel(c))
-    return DenseConvDims{N, K, C_in, C_out, S, P, D, F}(I)
+                       P=padding(c), D=dilation(c), F=flipkernel(c), G=group_count(c))
+    return DenseConvDims{N, K, C_in, C_out, S, P, D, F, G}(I)
 end
 
 function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
     # First, check that channel counts are all correct:
     @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
     @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
-    @assert w[end-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end-1]) vs. $(channels_in(cdims)))")
+    @assert w[end-1] == channels_in(cdims)/group_count(cdims) DimensionMismatch("Kernel input channel count ($(w[end-1]) vs. $(channels_in(cdims)/group_count(cdims)))")
     @assert w[end] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[end]) vs. $(channels_out(cdims)))")
-
+
     # Next, check that the spatial dimensions match up
     @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
     @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")

src/dim_helpers/DepthwiseConvDims.jl (+90 -90)
@@ -1,90 +1,90 @@
-export DepthwiseConvDims
-
-"""
-    DepthwiseConvDims
-
-Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to
-characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
-DenseConvDims primarily for channel calculation differences.
-"""
-struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
-    I::NTuple{N, Int}
-    K::NTuple{N, Int}
-    C_in::Int
-    C_mult::Int
-end
-
-# Getters for the fields
-input_size(c::DepthwiseConvDims) = c.I
-kernel_size(c::DepthwiseConvDims) = c.K
-channels_in(c::DepthwiseConvDims) = c.C_in
-channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c)
-channel_multiplier(c::DepthwiseConvDims) = c.C_mult
-
-
-# Convenience wrapper to create DepthwiseConvDims objects
-function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
-                           stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M
-    # Do common parameter validation
-    stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)
-
-    # Ensure channels are equal
-    if x_size[end-1] != w_size[end]
-        xs = x_size[end-1]
-        ws = w_size[end]
-        throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
-    end
-
-    return DepthwiseConvDims{
-        M - 2,
-        stride,
-        padding,
-        dilation,
-        flipkernel
-    }(
-        # Image spatial size
-        x_size[1:end-2],
-
-        # Kernel spatial size
-        w_size[1:end-2],
-
-        # Input channels
-        x_size[end-1],
-
-        # Channel multiplier
-        w_size[end-1],
-    )
-end
-
-# Auto-extract sizes and just pass those directly in
-function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
-    if ndims(x) != ndims(w)
-        throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
-    end
-    return DepthwiseConvDims(size(x), size(w); kwargs...)
-end
-
-# Useful for constructing a new DepthwiseConvDims that has only a few elements different
-# from the original progenitor object.
-function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
-                           C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
-                           P=padding(c), D=dilation(c), F=flipkernel(c))
-    return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m)
-end
-
-# This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
-function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
-    # First, check that channel counts are all correct:
-    @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
-    @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
-    @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))")
-    @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))")
-
-    # Next, check that the spatial dimensions match up
-    @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
-    @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
-    @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
-
-    # Finally, check that the batch size matches
-    @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
-end
+# export DepthwiseConvDims
+#
+# """
+#     DepthwiseConvDims
+#
+# Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to
+# characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
+# DenseConvDims primarily for channel calculation differences.
+# """
+# struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
+#     I::NTuple{N, Int}
+#     K::NTuple{N, Int}
+#     C_in::Int
+#     C_mult::Int
+# end
+#
+# # Getters for the fields
+# input_size(c::DepthwiseConvDims) = c.I
+# kernel_size(c::DepthwiseConvDims) = c.K
+# channels_in(c::DepthwiseConvDims) = c.C_in
+# channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c)
+# channel_multiplier(c::DepthwiseConvDims) = c.C_mult
+#
+#
+# # Convenience wrapper to create DepthwiseConvDims objects
+# function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
+#                            stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M
+#     # Do common parameter validation
+#     stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)
+#
+#     # Ensure channels are equal
+#     if x_size[end-1] != w_size[end]
+#         xs = x_size[end-1]
+#         ws = w_size[end]
+#         throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
+#     end
+#
+#     return DepthwiseConvDims{
+#         M - 2,
+#         stride,
+#         padding,
+#         dilation,
+#         flipkernel
+#     }(
+#         # Image spatial size
+#         x_size[1:end-2],
+#
+#         # Kernel spatial size
+#         w_size[1:end-2],
+#
+#         # Input channels
+#         x_size[end-1],
+#
+#         # Channel multiplier
+#         w_size[end-1],
+#     )
+# end
+#
+# # Auto-extract sizes and just pass those directly in
+# function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
+#     if ndims(x) != ndims(w)
+#         throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
+#     end
+#     return DepthwiseConvDims(size(x), size(w); kwargs...)
+# end
+#
+# # Useful for constructing a new DepthwiseConvDims that has only a few elements different
+# # from the original progenitor object.
+# function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
+#                            C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
+#                            P=padding(c), D=dilation(c), F=flipkernel(c))
+#     return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m)
+# end
+#
+# # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
+# function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
+#     # First, check that channel counts are all correct:
+#     @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
+#     @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
+#     @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))")
+#     @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))")
+#
+#     # Next, check that the spatial dimensions match up
+#     @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
+#     @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
+#     @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
+#
+#     # Finally, check that the batch size matches
+#     @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
+# end
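The file is commented out rather than deleted. The removed type characterized a depthwise kernel as (spatial..., C_mult, C_in) with channels_out = C_in * C_mult; under the unified interface the plain depthwise case (multiplier 1) corresponds to a `DenseConvDims` whose `groupcount` equals the input channel count. A hedged sketch of that correspondence with invented sizes (a channel multiplier larger than 1 does not map across this directly, since the kernel layouts differ):

using NNlib

x = rand(Float32, 28, 28, 8, 1)          # 8 input channels

# Old characterization (removed above): kernel (k, k, C_mult, C_in) with C_mult = 1.
# New characterization: kernel (k, k, C_in ÷ groupcount, C_out) with groupcount = C_in.
w = rand(Float32, 3, 3, 1, 8)

cdims = DenseConvDims(x, w; groupcount = size(x, 3))
NNlib.channels_in(cdims), NNlib.channels_out(cdims), NNlib.group_count(cdims)   # (8, 8, 8)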
