
Commit 14da5b5

Adapting ConvDims to accommodate Groupwise and Depthwise Convolutions and removing the separate implementations of Depthwise and Groupwise.
1 parent 0afd23a commit 14da5b5

12 files changed: +326 −239 lines changed

src/conv.jl

+5 −5
@@ -17,7 +17,7 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
 #
 # All methods require a `ConvDims` object to define the dimensions and optional
 # elements of the convolution (padding, stride, dilation, kernel-flipping, etc...),
-# which is easily constructable through something like `DenseConvDims(x, w)`. All
+# which is easily constructable through something like `ConvDims(x, w)`. All
 # methods take in the `ConvDims` of the associated normal, forward-pass convolution,
 # that is, the following is legal:
 #
@@ -123,7 +123,7 @@ for backend in (Symbol(), :_direct, :_im2col)
         end
     end
 
-    # This filter back prop covers dense/depthwise/groupwise conv filter backprops, as groupcount alone
+    # This filter back prop covers dense/depthwise/groupwise conv filter backprops, as groupcount alone
     # is a deciding factor from cudnn's perspective. For backends im2col and direct needs to be handled.
     @eval begin
         function $(Symbol("∇conv_filter$(backend)"))(
@@ -140,7 +140,7 @@ end
 # Use NNPACK if it is available and the operation is supported
 if is_nnpack_available()
     function conv(x::Array{xT, 4}, w::Array{wT, 4},
-                  cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
+                  cdims::ConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
                   kwargs...) where {xT, wT, K, C_in, C_out, S, P, F}
         return conv_nnpack(x, w, cdims; kwargs...)
     end
@@ -150,14 +150,14 @@ function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flip
     stride = expand(Val(N-2), stride)
     pad = expand(Val(N-2), pad)
     dilation = expand(Val(N-2), dilation)
-    cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
+    cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
     return conv(x, w, cdims)
 end
 
 function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groupcount) where {T, N}
     stride = expand(Val(N-2), stride)
     pad = expand(Val(N-2), pad)
     dilation = expand(Val(N-2), dilation)
-    cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped, groupcount=groupcount)
+    cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped, groupcount=groupcount)
     return depthwiseconv(x, w, cdims)
 end
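
For reference, a minimal usage sketch of the two keyword entry points touched in this hunk (shapes are made up; this assumes the surrounding NNlib package context, with data laid out as (spatial..., channels, batch)):

    x  = rand(Float32, 32, 32, 16, 1)          # 16 input channels, batch of 1
    w  = rand(Float32, 3, 3, 16, 8)            # dense kernel: 16 in, 8 out
    y  = conv(x, w; stride = 1, pad = 1)       # builds ConvDims(x, w; ...) internally

    wd = rand(Float32, 3, 3, 1, 16)            # one input channel per group
    yd = depthwiseconv(x, wd; groupcount = 16) # builds ConvDims(...; groupcount = 16)

With groupcount = 1 (the ConvDims default) the dense path is unchanged; the depthwise case is expressed purely through groupcount rather than a separate dims type.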

src/dim_helpers.jl

+2 −3
@@ -1,7 +1,6 @@
 # Various helper functions to calculate dimensions for operations
+include("dim_helpers/AbstractDims.jl")
 include("dim_helpers/ConvDims.jl")
-include("dim_helpers/DenseConvDims.jl")
-include("dim_helpers/DepthwiseConvDims.jl")
 include("dim_helpers/PoolDims.jl")
 
 
@@ -45,7 +44,7 @@ function transpose_pad(cdims::ConvDims)
 end
 
 """
-    insert_singleton_spatial_dimension(cdims::DenseConvDims)
+    insert_singleton_spatial_dimension(cdims::ConvDims)
 
 When converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton
 spatial dimension at the end of the spatial dimensions. This does so for a ConvDims.

src/dim_helpers/AbstractDims.jl

+120
@@ -0,0 +1,120 @@
+export AbstractDims
+
+"""
+    AbstractDims
+
+Type system-level information about convolution dimensions. Critical for things like
+`im2col!()` to generate efficient code, and helpful to reduce the number of kwargs
+getting passed around.
+
+We don't want to specialize on things like image size/channel count, so we generally
+store those as fields, just for convenience, and to allow for non-breaking changes when
+we decide we _do_ want to specialize on those values. We always want to specialize on
+things like stride, padding, dilation, and kernel flipping though.
+"""
+abstract type AbstractDims{N, S, P, D, F} end
+
+# Hack to get rid of type parameters
+function basetype(::Type{C}) where {C <: AbstractDims}
+    if C <: ConvDims
+        return ConvDims
+    elseif C <: PoolDims
+        return PoolDims
+    else
+        return nothing
+    end
+end
+
+# Obvious getter definitions for the type system-level definitions
+spatial_dims(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = N
+stride(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = S
+padding(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = P
+dilation(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = D
+flipkernel(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = F
+
+"""
+    im2col_dims(c::AbstractDims)
+
+im2col calculates, for each output pixel, the "convolution" of N kernels where N is the
+number of output channels, by doing a matrix multiply. The dimensions of that matrix
+are given by this function.
+"""
+im2col_dims(c::AbstractDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c))
+
+# Protect your skin, kids. Also do common validation of stride, padding, etc...
+function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}
+    # Number of spatial dimensions in `x` and `w`.
+    nd = N - 2
+
+    # Given a number, duplicate it out to have `nd` length. If it's already a collection,
+    # just splat it out into a tuple so it's always a tuple. We'll lint length later.
+    expand_size(p::Number) = ntuple(_ -> Int(p), nd)
+    expand_size(p) = tuple(p...)
+
+    # Convert stride, padding, dilation, etc.. to fully-specified tuples
+    pstride = expand_size(stride)
+    pdilation = expand_size(dilation)
+    ppadding = expand_size(padding)
+
+    if length(pstride) != nd
+        throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!"))
+    end
+    if length(pdilation) != nd
+        throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!"))
+    end
+
+    # padding is kind of a special case; we allow it to be either 2-length or 4-length,
+    # since we support asymmetrical padding
+    if length(ppadding) != 2*nd
+        if length(ppadding) == nd
+            # Do this repeat dance so that we get lo/hi symmetrical padding
+            ppadding = tuple(repeat(collect(ppadding), inner=2)...)
+        else
+            throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!"))
+        end
+    end
+
+    # Assert that kernel size * dilation is <= padded input size
+    for idx in 1:nd
+        Is = x_size[idx]
+        Pl = ppadding[(idx - 1)*2 + 1]
+        Ph = ppadding[(idx - 1)*2 + 2]
+        Ks = w_size[idx]
+        Ds = pdilation[idx]
+        if Is + Pl + Ph < (Ks - 1)*Ds + 1
+            throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!"))
+        end
+    end
+
+    return pstride, ppadding, pdilation
+end
+
+"""
+    output_size(c::AbstractDims)
+
+Calculate the output (spatial) dimensions of the convolution. Get channel count via
+`channels_out(c)`, and batch count is unknowable.
+"""
+function output_size(c::AbstractDims)
+    I = input_size(c)
+    K = kernel_size(c)
+    S = stride(c)
+    P = padding(c)
+    D = dilation(c)
+
+    return ntuple(spatial_dims(c)) do i
+        return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1
+    end
+end
+
+# Override show() for these beauties
+function Base.show(io::IO, cdims::C) where {C <: AbstractDims}
+    I = (input_size(cdims)..., channels_in(cdims))
+    O = (output_size(cdims)..., channels_out(cdims))
+    K = kernel_size(cdims)
+    S = stride(cdims)
+    P = padding(cdims)
+    D = dilation(cdims)
+    F = flipkernel(cdims)
+    print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F")
+end
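
To make the output_size arithmetic concrete, a worked example with hypothetical sizes (the ConvDims constructor used here is the one introduced in the next file):

    # Per spatial dimension: I = 32, pad = (1, 1), K = 3, D = 1, S = 2
    # div(32 + 1 + 1 - (3 - 1)*1 - 1, 2) + 1 == div(31, 2) + 1 == 16
    cdims = ConvDims((32, 32, 16, 1), (3, 3, 16, 8); stride = 2, padding = 1)
    output_size(cdims)    # (16, 16)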

src/dim_helpers/ConvDims.jl

+60 −92
@@ -12,109 +12,77 @@ store those as fields, just for convenience, and to allow for non-breaking chang
 we decide we _do_ want to specialize on those values. We always want to specialize on
 things like stride, padding, dilation, and kernel flipping though.
 """
-abstract type ConvDims{N, S, P, D, F} end
 
-# Hack to get rid of type parameters
-function basetype(::Type{C}) where {C <: ConvDims}
-    if C <: DenseConvDims
-        return DenseConvDims
-    elseif C <: PoolDims
-        return PoolDims
-    else
-        return nothing
-    end
+struct ConvDims{N,K,C_in,C_out,S,P,D,F,G} <: AbstractDims{N,S,P,D,F}
+    I::NTuple{N,Int}
 end
 
-# Obvious getter definitions for the type system-level definitions
-spatial_dims(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = N
-stride(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = S
-padding(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = P
-dilation(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = D
-flipkernel(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = F
-
-"""
-    im2col_dims(c::ConvDims)
-
-im2col calculates, for each output pixel, the "convolution" of N kernels where N is the
-number of output channels, by doing a matrix multiply. The dimensions of that matrix
-are given by this function.
-"""
-im2col_dims(c::ConvDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c))
 
-# Protect your skin, kids. Also do common validation of stride, padding, etc...
-function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}
-    # Number of spatial dimensions in `x` and `w`.
-    nd = N - 2
-
-    # Given a number, duplicate it out to have `nd` length. If it's already a collection,
-    # just splat it out into a tuple so it's always a tuple. We'll lint length later.
-    expand_size(p::Number) = ntuple(_ -> Int(p), nd)
-    expand_size(p) = tuple(p...)
-
-    # Convert stride, padding, dilation, etc.. to fully-specified tuples
-    pstride = expand_size(stride)
-    pdilation = expand_size(dilation)
-    ppadding = expand_size(padding)
-
-    if length(pstride) != nd
-        throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!"))
-    end
-    if length(pdilation) != nd
-        throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!"))
+# Getters for the fields
+input_size(c::ConvDims) = c.I
+kernel_size(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = K
+channels_in(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_in
+channels_out(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_out
+group_count(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = G
+
+# Convenience wrapper to create ConvDims objects
+function ConvDims(x_size::NTuple{M}, w_size::NTuple{M};
+                  stride=1, padding=0, dilation=1, flipkernel::Bool=false, groupcount=1) where M
+    # Do common parameter validation
+    stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)
+
+    # Ensure channels are equal
+    if x_size[M-1] != w_size[M-1]*groupcount
+        xs = x_size[M-1]
+        ws = w_size[M-1]*groupcount
+        throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
     end
 
-    # padding is kind of a special case; we allow it to be either 2-length or 4-length,
-    # since we support asymmetrical padding
-    if length(ppadding) != 2*nd
-        if length(ppadding) == nd
-            # Do this repeat dance so that we get lo/hi symmetrical padding
-            ppadding = tuple(repeat(collect(ppadding), inner=2)...)
-        else
-            throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!"))
-        end
-    end
+    # The type parameters are what
+    return ConvDims{
+        M - 2,
+        w_size[1:M-2],
+        x_size[M-1],
+        w_size[M],
+        stride,
+        padding,
+        dilation,
+        flipkernel,
+        groupcount
+    }(
+        # Input spatial size
+        x_size[1:M-2],
+    )
+end
 
-    # Assert that kernel size * dilation is <= padded input size
-    for idx in 1:nd
-        Is = x_size[idx]
-        Pl = ppadding[(idx - 1)*2 + 1]
-        Ph = ppadding[(idx - 1)*2 + 2]
-        Ks = w_size[idx]
-        Ds = pdilation[idx]
-        if Is + Pl + Ph < (Ks - 1)*Ds + 1
-            throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!"))
-        end
+# Auto-extract sizes and sub out to big brother above
+function ConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
+    if ndims(x) != ndims(w)
+        throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
     end
-
-    return pstride, ppadding, pdilation
+    return ConvDims(size(x), size(w); kwargs...)
 end
 
-"""
-    output_size(c::ConvDims)
+# Useful for constructing a new ConvDims that has only a few elements different
+# from the original progenitor object that it inherits shapes from.
+function ConvDims(c::AbstractDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
+                  C_in=channels_in(c), C_out=channels_out(c), S=stride(c),
+                  P=padding(c), D=dilation(c), F=flipkernel(c), G=group_count(c))
+    return ConvDims{N, K, C_in, C_out, S, P, D, F, G}(I)
+end
 
-Calculate the output (spatial) dimensions of the convolution. Get channel count via
-`channels_out(c)`, and batch count is unknowable.
-"""
-function output_size(c::ConvDims)
-    I = input_size(c)
-    K = kernel_size(c)
-    S = stride(c)
-    P = padding(c)
-    D = dilation(c)
+function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::ConvDims) where {M}
+    # First, check that channel counts are all correct:
+    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
+    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
+    @assert w[M-1] == channels_in(cdims)/group_count(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)/group_count(cdims)))")
+    @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")
 
-    return ntuple(spatial_dims(c)) do i
-        return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1
-    end
-end
+    # Next, check that the spatial dimensions match up
+    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
+    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
+    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")
 
-# Override show() for these beauties
-function Base.show(io::IO, cdims::C) where {C <: ConvDims}
-    I = (input_size(cdims)..., channels_in(cdims))
-    O = (output_size(cdims)..., channels_out(cdims))
-    K = kernel_size(cdims)
-    S = stride(cdims)
-    P = padding(cdims)
-    D = dilation(cdims)
-    F = flipkernel(cdims)
-    print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F")
+    # Finally, check that the batch size matches
+    @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
 end
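
As a quick sanity check of the unified type, a sketch with made-up sizes (groupcount = 1 reproduces the dense case, while depthwise-style convolutions are expressed through groupcount rather than a separate dims type):

    x = (32, 32, 16, 1)               # (spatial..., C_in, batch) size tuple
    w = (3, 3, 4, 8)                  # kernel: C_in/groupcount = 4 per group, C_out = 8
    cdims = ConvDims(x, w; groupcount = 4)

    channels_in(cdims)                # 16 == w[3] * groupcount
    channels_out(cdims)               # 8
    group_count(cdims)                # 4
    output_size(cdims)                # (30, 30) with default stride/padding/dilation

    # check_dims validates (x, w, y) size tuples against these dims:
    check_dims(x, w, (30, 30, 8, 1), cdims)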
