Skip to content

Commit 5b49a64

Browse files
committed Sep 30, 2023
Simplify RunningStats, faster var calculation and try fixing 1.6 inference
1 parent a9dc138 commit 5b49a64

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed
 

‎src/normalization.jl

+8 -15
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,8 @@ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
6767
error("both scale and bias must be provided or left as nothing")
6868
end
6969
scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
70-
return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
70+
denom = inv.(sqrt.(σ² .+ ϵ))
71+
return _apply_scale_bias((x .- μ) .* denom, scale′, bias′)
7172
end
7273

7374
"""
@@ -76,12 +77,11 @@ end
7677
Contains running mean and variance estimates for stateful norm functions.
7778
`momentum` controls the strength of the moving average update.
7879
79-
If the parameters are mutable, they will be updated in-place.
80-
Otherwise, they will be replaced wholesale.
80+
Parameters should be mutable and will be updated in-place.
8181
8282
See also [`update_running_stats!`](@ref).
8383
"""
84-
mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
84+
struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
8585
mean::M
8686
variance::V
8787
momentum::MT
@@ -142,16 +142,9 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di
142142
correction = m / (m - one(V))
143143

144144
running_mean, running_var = stats.mean, stats.variance
145-
if ChainRulesCore.is_inplaceable_destination(running_mean)
146-
stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
147-
else
148-
stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
149-
end
150-
if ChainRulesCore.is_inplaceable_destination(running_var)
151-
stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
152-
else
153-
stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
154-
end
145+
stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
146+
stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
147+
return
155148
end
156149

157150
# Convenience functions
@@ -190,7 +183,7 @@ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias =
190183
throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
191184
end
192185
μ, σ² = norm_stats(x, ntuple(identity, S))
193-
return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S])
186+
return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S]::Dims{S})
194187
end
195188

196189
"""

0 commit comments

Comments
 (0)
Please sign in to comment.