diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..7a73a41
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,2 @@
+{
+}
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index dac85ae..4ed003e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.2.7"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -13,9 +14,9 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 Adapt = "3.3"
-cuDNN = "1"
 CUDA = "4"
 NNlib = "0.8.15"
+cuDNN = "1"
 julia = "1.6"
 
 [extras]
diff --git a/src/cudnn/batchnorm.jl b/src/cudnn/batchnorm.jl
index 1e20fbc..42cadf6 100644
--- a/src/cudnn/batchnorm.jl
+++ b/src/cudnn/batchnorm.jl
@@ -1,155 +1,214 @@
-using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
-             cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
-             cudnnBatchNormalizationForwardTraining
+using cuDNN: cudnnNormalizationForward!, cudnnNormalizationBackward,
+             CUDNN_NORM_PER_CHANNEL, CUDNN_TENSOR_NCHW
 
-
-# TODO: replace with new cudnn normalization interface
-# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
-
 mutable struct BNCache
-  mean
-  ivar
+    mean
+    ivar
 end
 
 BNCache() = BNCache(nothing, nothing)
 
 @inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
 
-function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
-                   running_mean, running_var, momentum; kws...)
-  affine_sz = _wsize(x)
-  g = fill!(similar(x, affine_sz), 1)
-  b = fill!(similar(x, affine_sz), 0)
-  return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
-end
-
-# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
-# so reshape a 2D Tensor into 4D
-function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
-                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
-  x = reshape(x, 1, 1, size(x, 1), size(x, 2))
-  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
-  return dropdims(y, dims = (1, 2))
-end
-
 function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
-                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
-  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
+                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
+    batchnorm!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
 end
 
-function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
-                         running_mean, running_var, momentum;
-                         cache = nothing,
-                         alpha = T(1), beta = T(0),
-                         eps = T(1e-5),
-                         training = true,
-                         affine = true,
-                         track_stats = true) where T<:CUDNNFloat
-  dims = _wsize(x)
-  if eps < CUDNN_BN_MIN_EPSILON
-    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
-    eps = CUDNN_BN_MIN_EPSILON
-  end
-
-  if running_mean === nothing || running_var === nothing
-    running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
-    if track_stats || !training
-      running_mean = fill!(similar(x, dims), 0)
-      running_var = fill!(similar(x, dims), 1)
+function batchnorm!(y::DenseCuArray{T}, scale::DenseCuArray{T}, bias::DenseCuArray{T}, x::DenseCuArray{T},
+                    running_mean, running_var, momentum;
+                    cache = nothing,
+                    alpha = T(1), beta = T(0),
+                    eps = T(1e-5),
+                    training = true,
+                    affine = true,
+                    track_stats = true,
+
+                    workspace = nothing,
+                    reserveSpace = nothing,
+                    ) where T
+
+    dims = _wsize(x)
+    mode = CUDNN_NORM_PER_CHANNEL
+    format = CUDNN_TENSOR_NCHW
+
+
+    if running_mean === nothing || running_var === nothing
+        running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
+        if track_stats || !training
+            running_mean = fill!(similar(x, dims), 0)
+            running_var = fill!(similar(x, dims), 1)
+        end
     end
-  end
-
-  xd = cudnnTensorDescriptor(x)
-  yd = cudnnTensorDescriptor(y)
-  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))
-
-  if training
-    if !track_stats
-      running_mean = CU_NULL
-      running_var = CU_NULL
-    end
 
+    # default training  => momentum > 0, use_estimates=false, gradients calculated
+    # default inference => momentum = 0, use_estimates=true, gradients not calculated
 
-    if cache !== nothing
-      mean = fill!(similar(x, dims), 0)
-      ivar = fill!(similar(x, dims), 1)
+    kws = (; mode, format, alpha, beta, epsilon=eps, workspace, reserveSpace)
+    if training && cache !== nothing
+        savedMean = fill!(similar(x, dims), 0)
+        savedInvVariance = fill!(similar(x, dims), 1)
     else
-      mean = CU_NULL
-      ivar = CU_NULL
+        savedMean = nothing
+        savedInvVariance = nothing
     end
 
-    cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar)
+    cudnnNormalizationForward!(y, x, running_mean, running_var, bias, scale;
+                               training, exponentialAverageFactor=momentum,
+                               savedMean, savedInvVariance,
+                               kws...)
 
-    if cache !== nothing
-      cache.mean = mean
-      cache.ivar = ivar
+    if training && cache !== nothing
+        cache.mean = savedMean
+        cache.ivar = savedInvVariance
     end
-  else
-    cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
-  end
-  return y
-end
-
-function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
-                    running_mean, running_var, momentum; kws...)
-  affine_sz = _wsize(x)
-  g = fill!(similar(x, affine_sz), 1)
-  b = fill!(similar(x, affine_sz), 0)
-  return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
-end
-
-function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
-                    running_mean, running_var, momentum;
-                    kws...) where T<:CUDNNFloat
-  dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
-                          size(dy, 2)), running_mean, running_var, momentum; kws...)
-  (dg, db, dropdims(dx, dims = (1, 2)))
-end
-
-
-function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                    running_mean, running_var, momentum;
-                    affine=true, kws...) where T<:CUDNNFloat
-  dg = similar(g)
-  db = similar(b)
-  dx = similar(x)
-  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
-  if affine
-    (dg, db, dx)
-  else
-    # cuDNN always calculates dg and db, therefore we just have to drop them
-    (nothing, nothing, dx)
-  end
-end
-
-function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
-                          dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                          running_mean, running_var,
-                          momentum; cache = nothing, eps = T(1e-5),
-                          alpha = T(1), beta = T(0),
-                          dalpha = T(1), dbeta = T(0), training = true,
-                          track_stats = true) where T<:CUDNNFloat
-  if !track_stats
-    running_mean = CU_NULL
-    running_var = CU_NULL
-  end
-
-  xd = cudnnTensorDescriptor(x)
-  dyd = cudnnTensorDescriptor(dy)
-  dxd = cudnnTensorDescriptor(dx)
-  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
-  if cache !== nothing
-    @debug "fetching mean and ivar from the cache"
-    mean, ivar = cache.mean, cache.ivar
-  else
-    mean, ivar = CU_NULL, CU_NULL
-  end
-
-  if eps < CUDNN_BN_MIN_EPSILON
-    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
-    eps = CUDNN_BN_MIN_EPSILON
-  end
-
-  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
-                                  scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
-                                  xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
+    return y
 end
+
+    # if training
+    #     if !use_estimates && momentum == 0
+    #         cudnnNormalizationForward!(y, x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
+    #     elseif !use_estimates && momentum > 0
+    #         cudnnNormalizationForward!(y, x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=momentum, kw...)
+    #     elseif use_estimates && momentum == 0
+    #         ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
+    #     elseif use_estimates && momentum > 0
+    #         update_estimates!(x, mean_estimate, var_estimate, momentum)
+    #         ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
+    #     end
+    # else
+    #     if !use_estimates && momentum == 0
+    #         cudnnNormalizationForward(x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
+    #     elseif !use_estimates && momentum > 0
+    #         cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=momentum, kw...)
+    #     elseif use_estimates && momentum == 0
+    #         cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
+    #     elseif use_estimates && momentum > 0
+    #         update_estimates!(x, mean_estimate, var_estimate, momentum)
+    #         cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
+    #     end
+    # end
+# end
+
+
+# function update_estimates!(x, mean_estimate, var_estimate, update)
+#     (x, mean_estimate, var_estimate, update) = value.((x, mean_estimate, var_estimate, update))
+#     dims = findall(size(mean_estimate) .== 1)
+#     xmean = mean(x; dims)
+#     xvar = var(x; dims, mean=xmean, corrected=false)
+#     update = eltype(x)(update)
+#     mean_estimate .= xmean * update + mean_estimate * (1-update)
+#     var_estimate .= xvar * update + var_estimate * (1-update)
+# end
+
+
+
+# function batchnorm(
+#     x::GPUVal, mean_estimate::GPUVal, var_estimate::GPUVal, bias::GPUVal, scale::GPUVal;
+#     use_estimates = !Knet.training(),
+#     update = training ? 0.1 : 0.0,
+#     epsilon = 1e-5,
+#     mode = nothing,
+#     format = nothing,
+#     savedMean = nothing,
+#     savedVar = nothing,
+#     workspace = nothing,
+#     reserveSpace = nothing,
+#     dx = Ref{Any}(nothing),
+#     dscale = Ref{Any}(nothing),
+#     dbias = Ref{Any}(nothing),
+#     o...)
+#     @assert size(mean_estimate) == size(var_estimate) == size(bias) == size(scale)
+#     n = ndims(x)
+#     if size(mean_estimate) == ntuple(i->(i===n-1 ? size(x,i) : 1), n)
+#         mode === nothing ? mode = CUDNN_NORM_PER_CHANNEL : @assert mode === CUDNN_NORM_PER_CHANNEL
+#         format === nothing ? format = CUDNN_TENSOR_NCHW : @assert format === CUDNN_TENSOR_NCHW
+#     elseif size(mean_estimate) == ntuple(i->(i===1 ? size(x,i) : 1), n)
+#         mode === nothing ? mode = CUDNN_NORM_PER_CHANNEL : @assert mode === CUDNN_NORM_PER_CHANNEL
+#         format === nothing ? format = CUDNN_TENSOR_NHWC : @assert format === CUDNN_TENSOR_NHWC
+#     elseif size(mean_estimate) == ntuple(i->(i===n ? 1 : size(x,i)), n)
+#         mode === nothing ? mode = CUDNN_NORM_PER_ACTIVATION : @assert mode === CUDNN_NORM_PER_ACTIVATION
+#         format === nothing ? format = CUDNN_TENSOR_NCHW : @assert format === CUDNN_TENSOR_NCHW
+#     else
+#         error("Unsupported batchnorm size x=$(size(x)) m=$(size(m))")
+#     end
+#     # default training  => update > 0, use_estimates=false, gradients calculated
+#     # default inference => update = 0, use_estimates=true, gradients not calculated
+#     # Other combinations must be manually implemented
+#     kw = (; mode, format, epsilon, savedMean, savedInvVariance=savedVar, workspace, reserveSpace, dx, dscale, dbias)
+#     if training && !use_estimates && update == 0
+#         cudnnNormalizationForward(x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
+#     elseif training && !use_estimates && update > 0
+#         (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
+#         cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=update, kw...)
+#     elseif training && use_estimates && update == 0
+#         ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
+#     elseif training && use_estimates && update > 0
+#         update_estimates!(x, mean_estimate, var_estimate, update)
+#         ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
+#     elseif !training && !use_estimates && update == 0
+#         cudnnNormalizationForward(x, nothing, nothing, bias, scale; training=true, exponentialAverageFactor=0, kw...)
+#     elseif !training && !use_estimates && update > 0
+#         (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
+#         cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=true, exponentialAverageFactor=update, kw...)
+#     elseif !training && use_estimates && update == 0
+#         (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
+#         cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
+#     elseif !training && use_estimates && update > 0
+#         (mean_estimate, var_estimate) = value.((mean_estimate, var_estimate))
+#         update_estimates!(x, mean_estimate, var_estimate, update)
+#         cudnnNormalizationForward(x, mean_estimate, var_estimate, bias, scale; training=false, kw...)
+#     end
+# end
+
+# function update_estimates!(x, mean_estimate, var_estimate, update)
+#     (x, mean_estimate, var_estimate, update) = value.((x, mean_estimate, var_estimate, update))
+#     dims = findall(size(mean_estimate) .== 1)
+#     xmean = mean(x; dims)
+#     xvar = var(x; dims, mean=xmean, corrected=false)
+#     update = eltype(x)(update)
+#     mean_estimate .= xmean * update + mean_estimate * (1-update)
+#     var_estimate .= xvar * update + var_estimate * (1-update)
+# end
+
+
+
+
+
+
+
+
+
+
+
+
+
+##########################################
+##### FROM Knet
+# function batchnorm(
+#     x, mean_estimate, var_estimate, bias, scale;
+#     epsilon = 1e-5,
+#     update = training ? 0.1 : 0.0,
+#     use_estimates = !Knet.training(),
+#     o...
+# )
+#     update,epsilon = eltype(x).((update,epsilon))
+#     if update > 0 || !use_estimates
+#         dims = findall(size(mean_estimate) .== 1)
+#         xmean = mean(x; dims)
+#         xvar = var(x; dims, mean=xmean, corrected=false)
+#     end
+#     if update > 0
+#         (m, v, xm, xv) = value.((mean_estimate, var_estimate, xmean, xvar))
+#         m .= xm * update + m * (1-update)
+#         v .= xv * update + v * (1-update)
+#     end
+#     if use_estimates
+#         y = ((x .- mean_estimate) ./ sqrt.(epsilon .+ var_estimate)) .* scale .+ bias
+#     else
+#         y = ((x .- xmean) ./ sqrt.(epsilon .+ xvar)) .* scale .+ bias
+#     end
+#     return y
+# end
+############
\ No newline at end of file
diff --git a/src/cudnn/batchnorm_old.jl b/src/cudnn/batchnorm_old.jl
new file mode 100644
index 0000000..1e20fbc
--- /dev/null
+++ b/src/cudnn/batchnorm_old.jl
@@ -0,0 +1,155 @@
+using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
+             cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
+             cudnnBatchNormalizationForwardTraining
+
+
+# TODO: replace with new cudnn normalization interface
+# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
+
+mutable struct BNCache
+  mean
+  ivar
+end
+
+BNCache() = BNCache(nothing, nothing)
+
+@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+
+function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
+                   running_mean, running_var, momentum; kws...)
+  affine_sz = _wsize(x)
+  g = fill!(similar(x, affine_sz), 1)
+  b = fill!(similar(x, affine_sz), 0)
+  return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+end
+
+# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
+# so reshape a 2D Tensor into 4D
+function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
+                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
+  x = reshape(x, 1, 1, size(x, 1), size(x, 2))
+  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+  return dropdims(y, dims = (1, 2))
+end
+
+function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
+                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
+  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
+end
+
+function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
+                         running_mean, running_var, momentum;
+                         cache = nothing,
+                         alpha = T(1), beta = T(0),
+                         eps = T(1e-5),
+                         training = true,
+                         affine = true,
+                         track_stats = true) where T<:CUDNNFloat
+  dims = _wsize(x)
+  if eps < CUDNN_BN_MIN_EPSILON
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
+    eps = CUDNN_BN_MIN_EPSILON
+  end
+
+  if running_mean === nothing || running_var === nothing
+    running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
+    if track_stats || !training
+      running_mean = fill!(similar(x, dims), 0)
+      running_var = fill!(similar(x, dims), 1)
+    end
+  end
+
+  xd = cudnnTensorDescriptor(x)
+  yd = cudnnTensorDescriptor(y)
+  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))
+
+  if training
+    if !track_stats
+      running_mean = CU_NULL
+      running_var = CU_NULL
+    end
+
+    if cache !== nothing
+      mean = fill!(similar(x, dims), 0)
+      ivar = fill!(similar(x, dims), 1)
+    else
+      mean = CU_NULL
+      ivar = CU_NULL
+    end
+
+    cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar)
+
+    if cache !== nothing
+      cache.mean = mean
+      cache.ivar = ivar
+    end
+  else
+    cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
+  end
+  return y
+end
+
+function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
+                    running_mean, running_var, momentum; kws...)
+  affine_sz = _wsize(x)
+  g = fill!(similar(x, affine_sz), 1)
+  b = fill!(similar(x, affine_sz), 0)
+  return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
+end
+
+function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
+                    running_mean, running_var, momentum;
+                    kws...) where T<:CUDNNFloat
+  dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
+                          size(dy, 2)), running_mean, running_var, momentum; kws...)
+  (dg, db, dropdims(dx, dims = (1, 2)))
+end
+
+
+function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
+                    running_mean, running_var, momentum;
+                    affine=true, kws...) where T<:CUDNNFloat
+  dg = similar(g)
+  db = similar(b)
+  dx = similar(x)
+  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
+  if affine
+    (dg, db, dx)
+  else
+    # cuDNN always calculates dg and db, therefore we just have to drop them
+    (nothing, nothing, dx)
+  end
+end
+
+function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
+                          dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
+                          running_mean, running_var,
+                          momentum; cache = nothing, eps = T(1e-5),
+                          alpha = T(1), beta = T(0),
+                          dalpha = T(1), dbeta = T(0), training = true,
+                          track_stats = true) where T<:CUDNNFloat
+  if !track_stats
+    running_mean = CU_NULL
+    running_var = CU_NULL
+  end
+
+  xd = cudnnTensorDescriptor(x)
+  dyd = cudnnTensorDescriptor(dy)
+  dxd = cudnnTensorDescriptor(dx)
+  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
+  if cache !== nothing
+    @debug "fetching mean and ivar from the cache"
+    mean, ivar = cache.mean, cache.ivar
+  else
+    mean, ivar = CU_NULL, CU_NULL
+  end
+
+  if eps < CUDNN_BN_MIN_EPSILON
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
+    eps = CUDNN_BN_MIN_EPSILON
+  end
+
+  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+                                  scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+                                  xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
+end
diff --git a/test.jl b/test.jl
new file mode 100644
index 0000000..a4d5f1e
--- /dev/null
+++ b/test.jl
@@ -0,0 +1,17 @@
+using NNlibCUDA, NNlib, CUDA
+
+@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+
+# Test batchnorm
+x = randn(Float32, 2, 2, 3, 5) |> cu
+affine_sz = _wsize(x)
+g = fill!(similar(x, affine_sz), 1)
+b = fill!(similar(x, affine_sz), 0)
+running_mean = fill!(similar(x, affine_sz), 0)
+running_var = fill!(similar(x, affine_sz), 1)
+# running_mean = CUDA.CU_NULL
+# running_var = CUDA.CU_NULL
+# running_mean = nothing
+# running_var = nothing
+
+y = NNlibCUDA.batchnorm(g, b, x, running_mean, running_var, 0., training=true)
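A quick way to sanity-check the new cudnnNormalizationForward!-based forward path, extending the script in test.jl above: compare the GPU result against a per-channel CPU reference computed with Statistics. This is a sketch under assumptions, not part of the patch — it relies on the batchnorm(g, b, x, running_mean, running_var, momentum; training) method and the default eps = T(1e-5) from the batchnorm.jl hunk; the tolerance and the CPU reference formula are illustrative.

# Sketch: cross-check the cuDNN forward pass against a plain CPU batch norm.
# With scale = 1 and bias = 0, the output should match the input normalized
# with the current batch's per-channel statistics (biased variance, eps = 1e-5).
using NNlibCUDA, CUDA, Statistics, Test

x  = randn(Float32, 2, 2, 3, 5) |> cu
sz = (1, 1, 3, 1)               # per-channel parameter shape, same as _wsize(x)
g  = CUDA.ones(Float32, sz)     # scale
b  = CUDA.zeros(Float32, sz)    # bias
rm = CUDA.zeros(Float32, sz)    # running mean
rv = CUDA.ones(Float32, sz)     # running variance

y = NNlibCUDA.batchnorm(g, b, x, rm, rv, 0.1f0; training = true)

# CPU reference: normalize over the (W, H, N) dims for each channel.
xc   = Array(x)
mu   = mean(xc, dims = (1, 2, 4))
s2   = var(xc, dims = (1, 2, 4), corrected = false)
yref = (xc .- mu) ./ sqrt.(s2 .+ 1f-5)

@test Array(y) ≈ yref atol=1f-3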