rework batchnorm #66

Draft · wants to merge 2 commits into master

Changes from all commits

2 changes: 2 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,2 @@
+{
+}

3 changes: 2 additions & 1 deletion Project.toml

@@ -5,6 +5,7 @@ version = "0.2.7"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -13,9 +14,9 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
Adapt = "3.3"
-cuDNN = "1"
CUDA = "4"
NNlib = "0.8.15"
+cuDNN = "1"
julia = "1.6"

[extras]
…
325 changes: 192 additions & 133 deletions src/cudnn/batchnorm.jl
@@ -1,155 +1,214 @@
-using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
-             cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
-             cudnnBatchNormalizationForwardTraining
+using cuDNN: cudnnNormalizationForward!, cudnnNormalizationBackward,
+             CUDNN_NORM_PER_CHANNEL, CUDNN_TENSOR_NCHW


# TODO: replace with new cudnn normalization interface
# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl

# Cache for the saved batch mean and inverse variance from the training
# forward pass, so the backward pass can reuse them.
mutable struct BNCache
    mean
    ivar
end

BNCache() = BNCache(nothing, nothing)

@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
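
# A quick illustration of `_wsize` (an editor's sketch, not part of this diff):
# it keeps the channel dimension (second to last, per cuDNN's NCHW layout) and
# sets every other dimension to a singleton.
let x = rand(Float32, 8, 8, 3, 16)   # W × H × C × N
    @assert _wsize(x) == (1, 1, 3, 1)
end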

function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
running_mean, running_var, momentum; kws...)
affine_sz = _wsize(x)
g = fill!(similar(x, affine_sz), 1)
b = fill!(similar(x, affine_sz), 0)
return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
end

# NOTE: cuDNN supports only 4D and 5D tensors for batchnorm operations,
# so reshape a 2D tensor into a 4D one.
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
x = reshape(x, 1, 1, size(x, 1), size(x, 2))
y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
return dropdims(y, dims = (1, 2))
end
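
# For example (a sketch, not from this diff), a (features, batch) matrix
# round-trips through the 4D layout like this:
let xmat = rand(Float32, 10, 32)        # (features, batch)
    x4d = reshape(xmat, 1, 1, 10, 32)   # (1, 1, C, N), the layout cuDNN accepts
    @assert size(dropdims(x4d; dims = (1, 2))) == (10, 32)
end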

function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
-    cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
+    batchnorm!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
end

-function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
-                         running_mean, running_var, momentum;
-                         cache = nothing,
-                         alpha = T(1), beta = T(0),
-                         eps = T(1e-5),
-                         training = true,
-                         affine = true,
-                         track_stats = true) where T<:CUDNNFloat
-  dims = _wsize(x)
-  if eps < CUDNN_BN_MIN_EPSILON
-    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
-    eps = CUDNN_BN_MIN_EPSILON
-  end
-
-  if running_mean === nothing || running_var === nothing
-    running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
-    if track_stats || !training
-      running_mean = fill!(similar(x, dims), 0)
-      running_var = fill!(similar(x, dims), 1)
+function batchnorm!(y::DenseCuArray{T}, scale::DenseCuArray{T}, bias::DenseCuArray{T}, x::DenseCuArray{T},
+                    running_mean, running_var, momentum;
+                    cache = nothing,
+                    alpha = T(1), beta = T(0),
+                    eps = T(1e-5),
+                    training = true,
+                    affine = true,
+                    track_stats = true,
+                    workspace = nothing,
+                    reserveSpace = nothing,
+                    ) where T
+
+    dims = _wsize(x)
+    mode = CUDNN_NORM_PER_CHANNEL
+    format = CUDNN_TENSOR_NCHW
+
+    if running_mean === nothing || running_var === nothing
+        running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
+        if track_stats || !training
+            running_mean = fill!(similar(x, dims), 0)
+            running_var = fill!(similar(x, dims), 1)
+        end
+    end
-    end
-  end

-  xd = cudnnTensorDescriptor(x)
-  yd = cudnnTensorDescriptor(y)
-  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))

-  if training
-    if !track_stats
-      running_mean = CU_NULL
-      running_var = CU_NULL
-    end
+    # default training => momentum > 0, use_estimates=false, gradients calculated
+    # default inference => momentum = 0, use_estimates=true, gradients not calculated

-    if cache !== nothing
-      mean = fill!(similar(x, dims), 0)
-      ivar = fill!(similar(x, dims), 1)
+    kws = (; mode, format, alpha, beta, epsilon=eps, workspace, reserveSpace)
+    if training && cache !== nothing
+        savedMean = fill!(similar(x, dims), 0)
+        savedInvVariance = fill!(similar(x, dims), 1)
else
-      mean = CU_NULL
-      ivar = CU_NULL
+        savedMean = nothing
+        savedInvVariance = nothing
end

-    cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar)
+    cudnnNormalizationForward!(y, x, running_mean, running_var, bias, scale;
+                               training, exponentialAverageFactor=momentum,
+                               savedMean, savedInvVariance,
+                               kws...)

-    if cache !== nothing
-      cache.mean = mean
-      cache.ivar = ivar
+    if training && cache !== nothing
+        cache.mean = savedMean
+        cache.ivar = savedInvVariance
end
-  else
-    cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
-  end
return y
end
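
# Minimal forward-pass usage sketch (an illustration, not part of this diff;
# assumes a working CUDA device). With the defaults, training mode normalizes
# with batch statistics and folds them into the running estimates with weight
# `momentum`, while `training = false` uses the stored estimates directly.
# using CUDA
# x  = CUDA.rand(Float32, 8, 8, 3, 16)   # W × H × C × N
# g  = CUDA.ones(Float32, 1, 1, 3, 1)    # per-channel scale, shaped like _wsize(x)
# b  = CUDA.zeros(Float32, 1, 1, 3, 1)   # per-channel bias
# rm = CUDA.zeros(Float32, 1, 1, 3, 1)   # running mean, updated in place
# rv = CUDA.ones(Float32, 1, 1, 3, 1)    # running variance, updated in place
# y  = batchnorm(g, b, x, rm, rv, 0.1f0; training = true)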

function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
running_mean, running_var, momentum; kws...)
affine_sz = _wsize(x)
g = fill!(similar(x, affine_sz), 1)
b = fill!(similar(x, affine_sz), 0)
return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
end

function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
running_mean, running_var, momentum;
kws...) where T<:CUDNNFloat
dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
size(dy, 2)), running_mean, running_var, momentum; kws...)
(dg, db, dropdims(dx, dims = (1, 2)))
end


function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
running_mean, running_var, momentum;
affine=true, kws...) where T<:CUDNNFloat
dg = similar(g)
db = similar(b)
dx = similar(x)
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
if affine
(dg, db, dx)
else
# cuDNN always calculates dg and db, therefore we just have to drop them
(nothing, nothing, dx)
end
end
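
# Gradient-side continuation of the sketch above (not part of this diff):
# ∇batchnorm returns (dg, db, dx), or (nothing, nothing, dx) when affine = false.
# dy = CUDA.ones(Float32, size(y)...)   # upstream gradient, all ones for illustration
# dg, db, dx = ∇batchnorm(g, b, x, dy, rm, rv, 0.1f0; training = true)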

function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
running_mean, running_var,
momentum; cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0),
dalpha = T(1), dbeta = T(0), training = true,
track_stats = true) where T<:CUDNNFloat
if !track_stats
running_mean = CU_NULL
running_var = CU_NULL
end

xd = cudnnTensorDescriptor(x)
dyd = cudnnTensorDescriptor(dy)
dxd = cudnnTensorDescriptor(dx)
gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
if cache !== nothing
@debug "fetching mean and ivar from the cache"
mean, ivar = cache.mean, cache.ivar
else
mean, ivar = CU_NULL, CU_NULL
end

if eps < CUDNN_BN_MIN_EPSILON
@warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
eps = CUDNN_BN_MIN_EPSILON
end

cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
    return dx   # gradients are written in place; return the input gradient
end
# if training
# if !use_estimates && momentum == 0
# cudnnNormalizationForward!(y, x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
# elseif !use_estimates && momentum > 0
# cudnnNormalizationForward!(y, x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=momentum, kw...)
# elseif use_estimates && momentum == 0
# ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
# elseif use_estimates && momentum > 0
# update_estimates!(x, mean_estimate, var_estimate, momentum)
# ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
# end
# else
# if !use_estimates && momentum == 0
# cudnnNormalizationForward(x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
# elseif !use_estimates && momentum > 0
# cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=momentum, kw...)
# elseif use_estimates && momentum == 0
# cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
# elseif use_estimates && momentum > 0
# update_estimates!(x, mean_estimate, var_estimate, momentum)
# cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
# end
# end
# end
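
# For reference (editor's sketch, not from this diff): on CPU arrays the
# `use_estimates` branches above reduce to plain per-channel standardization.
using Statistics: mean, var
let x = rand(Float32, 4, 4, 3, 2), ϵ = 1f-5
    μ = mean(x; dims = (1, 2, 4))
    σ² = var(x; dims = (1, 2, 4), mean = μ, corrected = false)
    y = (x .- μ) ./ sqrt.(ϵ .+ σ²)   # scale = 1, bias = 0
end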


# function batchnorm(
# x::GPUVal, mean_estimate::GPUVal, var_estimate::GPUVal, bias::GPUVal, scale::GPUVal;
# use_estimates = !Knet.training(),
# update = training ? 0.1 : 0.0,
# epsilon = 1e-5,
# mode = nothing,
# format = nothing,
# savedMean = nothing,
# savedVar = nothing,
# workspace = nothing,
# reserveSpace = nothing,
# dx = Ref{Any}(nothing),
# dscale = Ref{Any}(nothing),
# dbias = Ref{Any}(nothing),
# o...)
# @assert size(mean_estimate) == size(var_estimate) == size(bias) == size(scale)
# n = ndims(x)
# if size(mean_estimate) == ntuple(i->(i===n-1 ? size(x,i) : 1), n)
# mode === nothing ? mode = CUDNN_NORM_PER_CHANNEL : @assert mode === CUDNN_NORM_PER_CHANNEL
# format === nothing ? format = CUDNN_TENSOR_NCHW : @assert format === CUDNN_TENSOR_NCHW
# elseif size(mean_estimate) == ntuple(i->(i===1 ? size(x,i) : 1), n)
# mode === nothing ? mode = CUDNN_NORM_PER_CHANNEL : @assert mode === CUDNN_NORM_PER_CHANNEL
# format === nothing ? format = CUDNN_TENSOR_NHWC : @assert format === CUDNN_TENSOR_NHWC
# elseif size(mean_estimate) == ntuple(i->(i===n ? 1 : size(x,i)), n)
# mode === nothing ? mode = CUDNN_NORM_PER_ACTIVATION : @assert mode === CUDNN_NORM_PER_ACTIVATION
# format === nothing ? format = CUDNN_TENSOR_NCHW : @assert format === CUDNN_TENSOR_NCHW
# else
# error("Unsupported batchnorm size x=$(size(x)) m=$(size(m))")
# end
# # default training => update > 0, use_estimates=false, gradients calculated
# # default inference => update = 0, use_estimates=true, gradients not calculated
# # Other combinations must be manually implemented
# kw = (; mode, format, epsilon, savedMean, savedInvVariance=savedVar, workspace, reserveSpace, dx, dscale, dbias)
# if training && !use_estimates && update == 0
# cudnnNormalizationForward(x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
# elseif training && !use_estimates && update > 0
# (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
# cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=update, kw...)
# elseif training && use_estimates && update == 0
# ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
# elseif training && use_estimates && update > 0
# update_estimates!(x, mean_estimate, var_estimate, update)
# ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
# elseif !training && !use_estimates && update == 0
# cudnnNormalizationForward(x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
# elseif !training && !use_estimates && update > 0
# (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
# cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=update, kw...)
# elseif !training && use_estimates && update == 0
# (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
# cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
# elseif !training && use_estimates && update > 0
# (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
# update_estimates!(x, mean_estimate, var_estimate, update)
# cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
# end
# end

# function update_estimates!(x, mean_estimate, var_estimate, update)
# (x, mean_estimate, var_estimate, update) = value.((x, mean_estimate, var_estimate, update))
# dims = findall(size(mean_estimate) .== 1)
# xmean = mean(x; dims)
# xvar = var(x; dims, mean=xmean, corrected=false)
# update = eltype(x)(update)
# mean_estimate .= xmean * update + mean_estimate * (1-update)
# var_estimate .= xvar * update + var_estimate * (1-update)
# end

##########################################
##### FROM Knet
# function batchnorm(
# x, mean_estimate, var_estimate, bias, scale;
# epsilon = 1e-5,
# update = training ? 0.1 : 0.0,
# use_estimates = !Knet.training(),
# o...
# )
# update,epsilon = eltype(x).((update,epsilon))
# if update > 0 || !use_estimates
# dims = findall(size(mean_estimate) .== 1)
# xmean = mean(x; dims)
# xvar = var(x; dims, mean=xmean, corrected=false)
# end
# if update > 0
# (m, v, xm, xv) = value.((mean_estimate, var_estimate, xmean, xvar))
# m .= xm * update + m * (1-update)
# v .= xv * update + v * (1-update)
# end
# if use_estimates
# y = ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
# else
# y = ((x .- xmean) ./ sqrt.(epsilon .+ xvar)) .* scale .+ bias
# end
# return y
# end
############