1
- export DepthwiseConvDims
2
-
3
- """
4
- DepthwiseConvDims
5
-
6
- Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to
7
- characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
8
- DenseConvDims primarily for channel calculation differences.
9
- """
10
- struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
11
- I:: NTuple{N, Int}
12
- K:: NTuple{N, Int}
13
- C_in:: Int
14
- C_mult:: Int
15
- end
16
-
17
- # Getters for the fields
18
- input_size (c:: DepthwiseConvDims ) = c. I
19
- kernel_size (c:: DepthwiseConvDims ) = c. K
20
- channels_in (c:: DepthwiseConvDims ) = c. C_in
21
- channels_out (c:: DepthwiseConvDims ) = c. C_in * channel_multiplier (c)
22
- channel_multiplier (c:: DepthwiseConvDims ) = c. C_mult
23
-
24
-
25
- # Convenience wrapper to create DepthwiseConvDims objects
26
- function DepthwiseConvDims (x_size:: NTuple{M} , w_size:: NTuple{M} ;
27
- stride= 1 , padding= 0 , dilation= 1 , flipkernel:: Bool = false ) where M
28
- # Do common parameter validation
29
- stride, padding, dilation = check_spdf (x_size, w_size, stride, padding, dilation)
30
-
31
- # Ensure channels are equal
32
- if x_size[end - 1 ] != w_size[end ]
33
- xs = x_size[end - 1 ]
34
- ws = w_size[end ]
35
- throw (DimensionMismatch (" Input channels must match! ($xs vs. $ws )" ))
36
- end
37
-
38
- return DepthwiseConvDims{
39
- M - 2 ,
40
- stride,
41
- padding,
42
- dilation,
43
- flipkernel
44
- }(
45
- # Image spatial size
46
- x_size[1 : end - 2 ],
47
-
48
- # Kernel spatial size
49
- w_size[1 : end - 2 ],
50
-
51
- # Input channels
52
- x_size[end - 1 ],
53
-
54
- # Channel multiplier
55
- w_size[end - 1 ],
56
- )
57
- end
58
-
59
- # Auto-extract sizes and just pass those directly in
60
- function DepthwiseConvDims (x:: AbstractArray , w:: AbstractArray ; kwargs... )
61
- if ndims (x) != ndims (w)
62
- throw (DimensionMismatch (" Rank of x and w must match! ($(ndims (x)) vs. $(ndims (w)) )" ))
63
- end
64
- return DepthwiseConvDims (size (x), size (w); kwargs... )
65
- end
66
-
67
- # Useful for constructing a new DepthwiseConvDims that has only a few elements different
68
- # from the original progenitor object.
69
- function DepthwiseConvDims (c:: DepthwiseConvDims ; N= spatial_dims (c), I= input_size (c), K= kernel_size (c),
70
- C_in= channels_in (c), C_m= channel_multiplier (c), S= stride (c),
71
- P= padding (c), D= dilation (c), F= flipkernel (c))
72
- return DepthwiseConvDims {N, S, P, D, F} (I, K, C_in, C_m)
73
- end
74
-
75
- # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
76
- function check_dims (x:: NTuple{M} , w:: NTuple{M} , y:: NTuple{M} , cdims:: DepthwiseConvDims ) where {M}
77
- # First, check that channel counts are all correct:
78
- @assert x[end - 1 ] == channels_in (cdims) DimensionMismatch (" Data input channel count ($(x[end - 1 ]) vs. $(channels_in (cdims)) )" )
79
- @assert y[end - 1 ] == channels_out (cdims) DimensionMismatch (" Data output channel count ($(y[end - 1 ]) vs. $(channels_out (cdims)) )" )
80
- @assert w[end - 1 ] == channel_multiplier (cdims) DimensionMismatch (" Kernel multiplier channel count ($(w[end - 1 ]) vs. $(channel_multiplier (cdims)) " )
81
- @assert w[end ] == channels_in (cdims) DimensionMismatch (" Kernel input channel count ($(w[end ]) vs. $(channels_in (cdims)) )" )
82
-
83
- # Next, check that the spatial dimensions match up
84
- @assert x[1 : end - 2 ] == input_size (cdims) DimensionMismatch (" Data input spatial size ($(x[1 : end - 2 ]) vs. $(input_size (cdims)) )" )
85
- @assert y[1 : end - 2 ] == output_size (cdims) DimensionMismatch (" Data output spatial size ($(y[1 : end - 2 ]) vs. $(output_size (cdims)) )" )
86
- @assert w[1 : end - 2 ] == kernel_size (cdims) DimensionMismatch (" Kernel spatial size ($(w[1 : end - 2 ]) vs. $(kernel_size (cdims)) )" )
87
-
88
- # Finally, check that the batch size matches
89
- @assert x[end ] == y[end ] DimensionMismatch (" Batch size ($(x[end ]) vs. $(y[end ]) )" )
90
- end
1
+ # export DepthwiseConvDims
2
+ #
3
+ # """
4
+ # DepthwiseConvDims
5
+ #
6
+ # Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to
7
+ # characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
8
+ # DenseConvDims primarily for channel calculation differences.
9
+ # """
10
+ # struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
11
+ # I::NTuple{N, Int}
12
+ # K::NTuple{N, Int}
13
+ # C_in::Int
14
+ # C_mult::Int
15
+ # end
16
+ #
17
+ # # Getters for the fields
18
+ # input_size(c::DepthwiseConvDims) = c.I
19
+ # kernel_size(c::DepthwiseConvDims) = c.K
20
+ # channels_in(c::DepthwiseConvDims) = c.C_in
21
+ # channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c)
22
+ # channel_multiplier(c::DepthwiseConvDims) = c.C_mult
23
+ #
24
+ #
25
+ # # Convenience wrapper to create DepthwiseConvDims objects
26
+ # function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
27
+ # stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M
28
+ # # Do common parameter validation
29
+ # stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)
30
+ #
31
+ # # Ensure channels are equal
32
+ # if x_size[end-1] != w_size[end]
33
+ # xs = x_size[end-1]
34
+ # ws = w_size[end]
35
+ # throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
36
+ # end
37
+ #
38
+ # return DepthwiseConvDims{
39
+ # M - 2,
40
+ # stride,
41
+ # padding,
42
+ # dilation,
43
+ # flipkernel
44
+ # }(
45
+ # # Image spatial size
46
+ # x_size[1:end-2],
47
+ #
48
+ # # Kernel spatial size
49
+ # w_size[1:end-2],
50
+ #
51
+ # # Input channels
52
+ # x_size[end-1],
53
+ #
54
+ # # Channel multiplier
55
+ # w_size[end-1],
56
+ # )
57
+ # end
58
+ #
59
+ # # Auto-extract sizes and just pass those directly in
60
+ # function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
61
+ # if ndims(x) != ndims(w)
62
+ # throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
63
+ # end
64
+ # return DepthwiseConvDims(size(x), size(w); kwargs...)
65
+ # end
66
+ #
67
+ # # Useful for constructing a new DepthwiseConvDims that has only a few elements different
68
+ # # from the original progenitor object.
69
+ # function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
70
+ # C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
71
+ # P=padding(c), D=dilation(c), F=flipkernel(c))
72
+ # return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m)
73
+ # end
74
+ #
75
+ # # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
76
+ # function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
77
+ # # First, check that channel counts are all correct:
78
+ # @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
79
+ # @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
80
+ # @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))")
81
+ # @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))")
82
+ #
83
+ # # Next, check that the spatial dimensions match up
84
+ # @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
85
+ # @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
86
+ # @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
87
+ #
88
+ # # Finally, check that the batch size matches
89
+ # @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
90
+ # end
0 commit comments