Add EnzymeRule for conv!/gather!/scatter!/dropout!/pool! #536

Merged: 26 commits, Sep 28, 2023
Changes from 1 commit
Cleanup rule
wsmoses committed Sep 27, 2023
commit 06dbbaad27401567a12d02bd7978d27e44a9b2d4
14 changes: 7 additions & 7 deletions src/conv.jl
@@ -47,7 +47,7 @@
Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types.
"""
-function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
+@inline function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
Review comment (Member): maybe better to factor out this change to another PR

stride = expand(Val(N - 2), stride)
padding = expand(Val(N - 2), pad)
dilation = expand(Val(N - 2), dilation)
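For orientation, a minimal usage sketch of the public `conv` entry point (shapes follow NNlib's `(spatial..., channels, batch)` convention; the concrete sizes are illustrative):

```julia
using NNlib

x = rand(Float32, 28, 28, 3, 1)   # (width, height, channels_in, batch)
w = rand(Float32, 3, 3, 3, 8)     # (kw, kh, channels_in, channels_out)
y = conv(x, w; stride = 1, pad = 1, dilation = 1)
size(y)                           # (28, 28, 8, 1): pad = 1 preserves the spatial size
```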
@@ -62,7 +62,7 @@ end
Depthwise convolution operation with filter `w` on input `x`. `x` and `w`
are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
"""
-function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}
+@inline function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}
stride = expand(Val(N-2), stride)
pad = expand(Val(N-2), pad)
dilation = expand(Val(N-2), dilation)
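And a similar sketch for `depthwiseconv`, where the filter layout is `(spatial..., channel_multiplier, channels_in)` and each input channel is convolved independently:

```julia
using NNlib

x = rand(Float32, 28, 28, 3, 1)   # 3 input channels
w = rand(Float32, 3, 3, 2, 3)     # channel multiplier 2 per input channel
y = depthwiseconv(x, w; pad = 1)
size(y)                           # (28, 28, 6, 1): 3 channels × multiplier 2
```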
@@ -80,7 +80,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
# First make auto-allocating versions of the conv()-like calls:
for name in (:conv, :depthwiseconv)
@eval begin
-function $(Symbol("$(name)$(backend)"))(
+@inline function $(Symbol("$(name)$(backend)"))(
x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
cdims::ConvDims; kwargs...) where {xT, wT, N}
y = similar(x, promote_type(xT, wT), output_size(cdims)...,
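The `@eval` loop above stamps out one auto-allocating wrapper per front-end name and backend: allocate the output, then defer to the in-place kernel. Roughly, the generated method for `conv` looks like the following (an illustrative expansion, not the literal generated code):

```julia
function conv(x::AbstractArray{xT, N}, w::AbstractArray{wT, N},
              cdims::ConvDims; kwargs...) where {xT, wT, N}
    # Allocate the output in the promoted element type and hand off to conv!.
    y = similar(x, promote_type(xT, wT), output_size(cdims)...,
                channels_out(cdims), size(x, N))
    return conv!(y, x, w, cdims; kwargs...)
end
```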
@@ -92,7 +92,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)

for name in (:∇conv_data, :∇depthwiseconv_data)
@eval begin
-function $(Symbol("$(name)$(backend)"))(
+@inline function $(Symbol("$(name)$(backend)"))(
dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))
@@ -104,7 +104,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
# We do the conv/depthwiseconv filter backprops separately, as the shape calculation
# for `w` is slightly different for depthwise than for normal dense convolution.
@eval begin
function $(Symbol("∇conv_filter$(backend)"))(
@inline function $(Symbol("∇conv_filter$(backend)"))(
x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
cdims::ConvDims; kwargs...) where {xT, yT, N}
dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
@@ -114,7 +114,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
end

@eval begin
function $(Symbol("∇depthwiseconv_filter$(backend)"))(
@inline function $(Symbol("∇depthwiseconv_filter$(backend)"))(
x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
cdims::ConvDims; kwargs...) where {xT, yT, N}
dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims),
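The allocating gradient wrappers generated here can be exercised directly; a short sketch, with shapes chosen to match the `conv` example above:

```julia
using NNlib

x, w  = rand(Float32, 28, 28, 3, 1), rand(Float32, 3, 3, 3, 8)
cdims = DenseConvDims(x, w; padding = 1)
y     = conv(x, w, cdims)
dy    = ones(Float32, size(y)...)  # incoming adjoint

dx = ∇conv_data(dy, w, cdims)    # gradient wrt the input, same shape as x
dw = ∇conv_filter(x, dy, cdims)  # gradient wrt the filter, same shape as w
```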
@@ -137,7 +137,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
for backend in (Symbol(), :_direct, :_im2col) ## NNPACK is only for 2d conv
for N in (3, 4)
@eval begin
function $(Symbol("$(front_name)$(backend)!"))(
@inline function $(Symbol("$(front_name)$(backend)!"))(
y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N},
w::AbstractArray{wT,$N}, cdims::ConvDims;
kwargs...) where {yT, xT, wT}
20 changes: 14 additions & 6 deletions src/enzyme.jl
@@ -3,9 +3,13 @@ import EnzymeCore
for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!))
@eval begin

-function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}
+function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT},
+        y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
+        x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
+        w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
+        cdims; kwargs...) where {RT, yT, xT, wT, N}

-if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
+if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
end

@@ -37,7 +41,11 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

-function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
+function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache,
+        y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
+        x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
+        w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
+        cdims; kwargs...) where {RT, yT, xT, wT, N}
cache_x, cache_w = cache

# Don't cache x if not overwritten and w is active (and thus required)
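For context on the cache handed over from `augmented_primal`: `x` only needs to be preserved when the in-place call may overwrite it *and* `w` is active, since `∇conv_filter!` reads `x`; symmetrically for `w`. A conceptual sketch, with the indices into `overwritten` treated as assumptions rather than NNlib's exact code:

```julia
# Inside augmented_primal (conceptual; argument positions in `overwritten`
# are assumed for illustration):
overwritten = EnzymeCore.EnzymeRules.overwritten(config)  # one Bool per argument
cache_x = (overwritten[3] && !(typeof(w) <: EnzymeCore.Const)) ? copy(x.val) : nothing
cache_w = (overwritten[4] && !(typeof(x) <: EnzymeCore.Const)) ? copy(w.val) : nothing
cache   = (cache_x, cache_w)
```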
@@ -65,15 +73,15 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
end

for (dy, dx, dw) in zip(dys, dxs, dws)
-if !(typeof(y) <: EnzymeCore.Const) && dy !== w.val
+if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
# dx += grad wrt x.val
-NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
+NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...)
end
if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
# dw += grad wrt w.val
-NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
+NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...)
end

dy .= 0
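Putting the rule to work end to end, a minimal sketch (mirroring the usual Enzyme pattern for mutating functions, with an explicit `Const` return activity):

```julia
using Enzyme, NNlib

x, w   = rand(Float32, 8, 8, 1, 1), rand(Float32, 3, 3, 1, 1)
cdims  = DenseConvDims(x, w)
y      = similar(x, output_size(cdims)..., 1, 1)
dx, dw = zero(x), zero(w)
dy     = ones(Float32, size(y)...)   # seed: adjoint of sum(y)

Enzyme.autodiff(Reverse, conv!, Const,
                Duplicated(y, dy), Duplicated(x, dx),
                Duplicated(w, dw), Const(cdims))
# dx and dw now hold ∂sum(y)/∂x and ∂sum(y)/∂w; note the rule zeroes dy afterwards.
```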