@@ -67,7 +67,8 @@ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
67
67
error (" both scale and bias must be provided or left as nothing" )
68
68
end
69
69
scale′, bias′ = _maybe_reshape (scale, affine_size), _maybe_reshape (bias, affine_size)
70
- return _apply_scale_bias ((x .- μ) ./ sqrt .(σ² .+ ϵ), scale′, bias′)
70
+ denom = inv .(sqrt .(σ² .+ ϵ))
71
+ return _apply_scale_bias ((x .- μ) .* denom, scale′, bias′)
71
72
end
72
73
73
74
"""
76
77
Contains running mean and variance estimates for stateful norm functions.
77
78
`momentum` controls the strength of the moving average update.
78
79
79
- If the parameters are mutable, they will be updated in-place.
80
- Otherwise, they will be replaced wholesale.
80
+ The `mean` and `variance` arrays must be mutable; they are updated in-place.
81
81
82
82
See also [`update_running_stats!`](@ref).
83
83
"""
84
- mutable struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
84
+ struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
85
85
mean:: M
86
86
variance:: V
87
87
momentum:: MT
@@ -142,16 +142,9 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di
142
142
correction = m / (m - one (V))
143
143
144
144
running_mean, running_var = stats. mean, stats. variance
145
- if ChainRulesCore. is_inplaceable_destination (running_mean)
146
- stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
147
- else
148
- stats. mean = res_mtm .* running_mean .+ momentum .* vec (μ)
149
- end
150
- if ChainRulesCore. is_inplaceable_destination (running_var)
151
- stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
152
- else
153
- stats. variance = res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
154
- end
145
+ stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
146
+ stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
147
+ return
155
148
end
156
149
157
150
# Convenience functions
@@ -190,7 +183,7 @@ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias =
190
183
throw (DimensionMismatch (" got $S reduction dims for $N -dimensional array" ))
191
184
end
192
185
μ, σ² = norm_stats (x, ntuple (identity, S))
193
- return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S])
186
+ return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S]:: Dims{S} )
194
187
end
195
188
196
189
"""
0 commit comments