@@ -12,109 +12,77 @@ store those as fields, just for convenience, and to allow for non-breaking chang
12
12
we decide we _do_ want to specialize on those values. We always want to specialize on
13
13
things like stride, padding, dilation, and kernel flipping though.
14
14
"""
15
- abstract type ConvDims{N, S, P, D, F} end
16
15
17
- # Hack to get rid of type parameters
18
- function basetype (:: Type{C} ) where {C <: ConvDims }
19
- if C <: DenseConvDims
20
- return DenseConvDims
21
- elseif C <: PoolDims
22
- return PoolDims
23
- else
24
- return nothing
25
- end
16
+ struct ConvDims{N,K,C_in,C_out,S,P,D,F,G} <: AbstractDims{N,S,P,D,F}
17
+ I:: NTuple{N,Int}
26
18
end
27
19
28
- # Obvious getter definitions for the type system-level definitions
29
- spatial_dims (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = N
30
- stride (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = S
31
- padding (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = P
32
- dilation (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = D
33
- flipkernel (c:: ConvDims{N,S,P,D,F} ) where {N, S, P, D, F} = F
34
-
35
- """
36
- im2col_dims(c::ConvDims)
37
-
38
- im2col calculates, for each output pixel, the "convolution" of N kernels where N is the
39
- number of output channels, by doing a matrix multiply. The dimensions of that matrix
40
- are given by this function.
41
- """
42
- im2col_dims (c:: ConvDims ) = (prod (output_size (c)), prod (kernel_size (c))* channels_in (c))
43
20
44
- # Protect your skin, kids. Also do common validation of stride, padding, etc...
45
- function check_spdf (x_size:: NTuple{N} , w_size:: NTuple{N} , stride, padding, dilation) where {N}
46
- # Number of spatial dimensions in `x` and `w`.
47
- nd = N - 2
48
-
49
- # Given a number, duplicate it out to have `nd` length. If it's already a collection,
50
- # just splat it out into a tuple so it's always a tuple. We'll lint length later.
51
- expand_size (p:: Number ) = ntuple (_ -> Int (p), nd)
52
- expand_size (p) = tuple (p... )
53
-
54
- # Convert stride, padding, dilation, etc.. to fully-specified tuples
55
- pstride = expand_size (stride)
56
- pdilation = expand_size (dilation)
57
- ppadding = expand_size (padding)
58
-
59
- if length (pstride) != nd
60
- throw (DimensionMismatch (" Stride $(length (stride)) d, should be $(nd) d!" ))
61
- end
62
- if length (pdilation) != nd
63
- throw (DimensionMismatch (" Dilation $(length (pdilation)) d, should be $(nd) d!" ))
21
+ # Getters for the fields
22
+ input_size (c:: ConvDims ) = c. I
23
+ kernel_size (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = K
24
+ channels_in (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = C_in
25
+ channels_out (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = C_out
26
+ group_count (c:: ConvDims{N,K,C_in,C_out,S,P,D,F,G} ) where {N,K,C_in,C_out,S,P,D,F,G} = G
27
+
28
+ # Convenience wrapper to create ConvDims objects
29
+ function ConvDims (x_size:: NTuple{M} , w_size:: NTuple{M} ;
30
+ stride= 1 , padding= 0 , dilation= 1 , flipkernel:: Bool = false , groupcount= 1 ) where M
31
+ # Do common parameter validation
32
+ stride, padding, dilation = check_spdf (x_size, w_size, stride, padding, dilation)
33
+
34
+ # Ensure channels are equal
35
+ if x_size[M- 1 ] != w_size[M- 1 ]* groupcount
36
+ xs = x_size[M- 1 ]
37
+ ws = w_size[M- 1 ]* groupcount
38
+ throw (DimensionMismatch (" Input channels must match! ($xs vs. $ws )" ))
64
39
end
65
40
66
- # padding is kind of a special case; we allow it to be either 2-length or 4-length,
67
- # since we support asymmetrical padding
68
- if length (ppadding) != 2 * nd
69
- if length (ppadding) == nd
70
- # Do this repeat dance so that we get lo/hi symmetrical padding
71
- ppadding = tuple (repeat (collect (ppadding), inner= 2 )... )
72
- else
73
- throw (DimensionMismatch (" Padding $(length (ppadding)) d, should be either $(nd) d or $(2 * nd) d!" ))
74
- end
75
- end
41
+ # The type parameters are what
42
+ return ConvDims{
43
+ M - 2 ,
44
+ w_size[1 : M- 2 ],
45
+ x_size[M- 1 ],
46
+ w_size[M],
47
+ stride,
48
+ padding,
49
+ dilation,
50
+ flipkernel,
51
+ groupcount
52
+ }(
53
+ # Input spatial size
54
+ x_size[1 : M- 2 ],
55
+ )
56
+ end
76
57
77
- # Assert that kernel size * dilation is <= padded input size
78
- for idx in 1 : nd
79
- Is = x_size[idx]
80
- Pl = ppadding[(idx - 1 )* 2 + 1 ]
81
- Ph = ppadding[(idx - 1 )* 2 + 2 ]
82
- Ks = w_size[idx]
83
- Ds = pdilation[idx]
84
- if Is + Pl + Ph < (Ks - 1 )* Ds + 1
85
- throw (DimensionMismatch (" Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph )!" ))
86
- end
58
+ # Auto-extract sizes and sub out to big brother above
59
+ function ConvDims (x:: AbstractArray , w:: AbstractArray ; kwargs... )
60
+ if ndims (x) != ndims (w)
61
+ throw (DimensionMismatch (" Rank of x and w must match! ($(ndims (x)) vs. $(ndims (w)) )" ))
87
62
end
88
-
89
- return pstride, ppadding, pdilation
63
+ return ConvDims (size (x), size (w); kwargs... )
90
64
end
91
65
92
- """
93
- output_size(c::ConvDims)
66
+ # Useful for constructing a new ConvDims that has only a few elements different
67
+ # from the original progenitor object that it inherits shapes from.
68
+ function ConvDims (c:: AbstractDims ; N= spatial_dims (c), I= input_size (c), K= kernel_size (c),
69
+ C_in= channels_in (c), C_out= channels_out (c), S= stride (c),
70
+ P= padding (c), D= dilation (c), F= flipkernel (c), G= group_count (c))
71
+ return ConvDims {N, K, C_in, C_out, S, P, D, F, G} (I)
72
+ end
94
73
95
- Calculate the output (spatial) dimensions of the convolution. Get channel count via
96
- `channels_out(c)`, and batch count is unknowable.
97
- """
98
- function output_size (c:: ConvDims )
99
- I = input_size (c)
100
- K = kernel_size (c)
101
- S = stride (c)
102
- P = padding (c)
103
- D = dilation (c)
74
+ function check_dims (x:: NTuple{M} , w:: NTuple{M} , y:: NTuple{M} , cdims:: ConvDims ) where {M}
75
+ # First, check that channel counts are all correct:
76
+ @assert x[M- 1 ] == channels_in (cdims) DimensionMismatch (" Data input channel count ($(x[M- 1 ]) vs. $(channels_in (cdims)) )" )
77
+ @assert y[M- 1 ] == channels_out (cdims) DimensionMismatch (" Data output channel count ($(y[M- 1 ]) vs. $(channels_out (cdims)) )" )
78
+ @assert w[M- 1 ] == channels_in (cdims)/ group_count (cdims) DimensionMismatch (" Kernel input channel count ($(w[M- 1 ]) vs. $(channels_in (cdims)/ group_count (cdims)) )" )
79
+ @assert w[M] == channels_out (cdims) DimensionMismatch (" Kernel output channel count ($(w[M]) vs. $(channels_out (cdims)) )" )
104
80
105
- return ntuple ( spatial_dims (c)) do i
106
- return div (I[i] + P[(i - 1 ) * 2 + 1 ] + P[(i - 1 ) * 2 + 2 ] - (K[i] - 1 ) * D[i] - 1 , S[i ]) + 1
107
- end
108
- end
81
+ # Next, check that the spatial dimensions match up
82
+ @assert x[ 1 : M - 2 ] == input_size (cdims) DimensionMismatch ( " Data input spatial size ( $(x[ 1 : M - 2 ]) vs. $( input_size (cdims)) ) " )
83
+ @assert y[ 1 : M - 2 ] == output_size (cdims) DimensionMismatch ( " Data output spatial size ( $(y[ 1 : M - 2 ]) vs. $( output_size (cdims)) ) " )
84
+ @assert w[ 1 : M - 2 ] == kernel_size (cdims) DimensionMismatch ( " Kernel spatial size ( $(w[ 1 : M - 2 ]) vs. $( kernel_size (cdims)) ) " )
109
85
110
- # Override show() for these beauties
111
- function Base. show (io:: IO , cdims:: C ) where {C <: ConvDims }
112
- I = (input_size (cdims)... , channels_in (cdims))
113
- O = (output_size (cdims)... , channels_out (cdims))
114
- K = kernel_size (cdims)
115
- S = stride (cdims)
116
- P = padding (cdims)
117
- D = dilation (cdims)
118
- F = flipkernel (cdims)
119
- print (io, " $(basetype (C)) : $I * $K -> $O , stride: $S pad: $P , dil: $D , flip: $F " )
86
+ # Finally, check that the batch size matches
87
+ @assert x[M] == y[M] DimensionMismatch (" Batch size ($(x[M]) vs. $(y[M]) )" )
120
88
end
0 commit comments