@@ -52,7 +52,8 @@ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
52
52
error (" both scale and bias must be provided or left as nothing" )
53
53
end
54
54
scale′, bias′ = _maybe_reshape (scale, affine_size), _maybe_reshape (bias, affine_size)
55
- return _apply_scale_bias ((x .- μ) ./ sqrt .(σ² .+ ϵ), scale′, bias′)
55
+ denom = inv .(sqrt .(σ² .+ ϵ))
56
+ return _apply_scale_bias ((x .- μ) .* denom, scale′, bias′)
56
57
end
57
58
58
59
"""
61
62
Contains running mean and variance estimates for stateful norm functions.
62
63
`momentum` controls the strength of the moving average update.
63
64
64
- If the parameters are mutable, they will be updated in-place.
65
- Otherwise, they will be replaced wholesale.
65
+ Parameters should be mutable and will be updated in-place.
66
66
67
67
See also [`update_running_stats!`](@ref).
68
68
"""
69
- mutable struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
69
+ struct RunningStats{M <: AbstractArray , V <: AbstractArray , MT <: Real }
70
70
mean:: M
71
71
variance:: V
72
72
momentum:: MT
@@ -127,16 +127,9 @@ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Di
127
127
correction = m / (m - one (V))
128
128
129
129
running_mean, running_var = stats. mean, stats. variance
130
- if ChainRulesCore. is_inplaceable_destination (running_mean)
131
- stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
132
- else
133
- stats. mean = res_mtm .* running_mean .+ momentum .* vec (μ)
134
- end
135
- if ChainRulesCore. is_inplaceable_destination (running_var)
136
- stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
137
- else
138
- stats. variance = res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
139
- end
130
+ stats. mean .= res_mtm .* running_mean .+ momentum .* vec (μ)
131
+ stats. variance .= res_mtm .* running_var .+ momentum .* correction .* vec (σ²)
132
+ return
140
133
end
141
134
142
135
# Convenience functions
@@ -175,7 +168,7 @@ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias =
175
168
throw (DimensionMismatch (" got $S reduction dims for $N -dimensional array" ))
176
169
end
177
170
μ, σ² = norm_stats (x, ntuple (identity, S))
178
- return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S])
171
+ return norm_helper (x, μ, σ², scale, bias, ϵ, size (x)[1 : S]:: Dims{S} )
179
172
end
180
173
181
174
"""
0 commit comments