Skip to content

Commit 6882593

Browse files
committed
Simplify RunningStats, faster var calculation and try fixing 1.6 inference
1 parent 3a52deb commit 6882593

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

src/normalization.jl

+8-15
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
5252
error("both scale and bias must be provided or left as nothing")
5353
end
5454
scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
55-
return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
55+
denom = inv.(sqrt.(σ² .+ ϵ))
56+
return _apply_scale_bias((x .- μ) .* denom, scale′, bias′)
5657
end
5758

5859
"""
@@ -61,12 +62,11 @@ end
6162
Contains running mean and variance estimates for stateful norm functions.
6263
`momentum` controls the strength of the moving average update.
6364
64-
If the parameters are mutable, they will be updated in-place.
65-
Otherwise, they will be replaced wholesale.
65+
Parameters should be mutable and will be updated in-place.
6666
6767
See also [`update_running_stats!`](@ref).
6868
"""
69-
mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
69+
struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
7070
mean::M
7171
variance::V
7272
momentum::MT
@@ -127,16 +127,9 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di
127127
correction = m / (m - one(V))
128128

129129
running_mean, running_var = stats.mean, stats.variance
130-
if ChainRulesCore.is_inplaceable_destination(running_mean)
131-
stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
132-
else
133-
stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
134-
end
135-
if ChainRulesCore.is_inplaceable_destination(running_var)
136-
stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
137-
else
138-
stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
139-
end
130+
stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
131+
stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
132+
return
140133
end
141134

142135
# Convenience functions
@@ -175,7 +168,7 @@ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias =
175168
throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
176169
end
177170
μ, σ² = norm_stats(x, ntuple(identity, S))
178-
return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S])
171+
return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S]::Dims{S})
179172
end
180173

181174
"""

0 commit comments

Comments
 (0)