From 79c437cf219b206a76f15f6419e8a110293cca73 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 15 Jun 2023 20:08:14 +0200 Subject: [PATCH] port batchnorm rrule from Flux --- src/normalization.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/normalization.jl b/src/normalization.jl index 48fc53de6..c06843d38 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -2,3 +2,13 @@ function batchnorm end function ∇batchnorm end + + +function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...) + y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...) + function batchnorm_pullback(Δ) + grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...) + (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent()) + end + y, batchnorm_pullback +end