Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EnzymeRule for conv!/gather!/scatter!/dropout!/pool! #536

Merged
merged 26 commits into from
Sep 28, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Also add gather
wsmoses committed Sep 27, 2023
commit af42451f69e277050a36e48682fee8b9b9945e1c
63 changes: 59 additions & 4 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
@@ -57,16 +57,71 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
end

for (dy, dx, dw) in zip(dys, dxs, dws)
if !(typeof(x) <: EnzymeCore.Const) && dx !== x
# dx += grad wrt x
if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
# dx += grad wrt x.val
NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
end
if !(typeof(w) <: EnzymeCore.Const) && dw !== w
# dw += grad wrt w
if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
# dw += grad wrt w.val
NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
end
dy .= 0
end

return (nothing, nothing, nothing, nothing)
end


# Forward (augmented primal) half of the Enzyme reverse-mode rule for `NNlib.gather!`.
#
# Runs the in-place primal `gather!(dst, src, idx)`, then hands Enzyme the primal
# and shadow it asked for, together with a tape holding a private copy of the
# indices when they may be overwritten before the reverse pass.
#
# Arguments
# - `config`: Enzyme rule configuration (queried for needs_primal / needs_shadow /
#   overwritten).
# - `func`:   `Const` wrapper around `NNlib.gather!`.
# - `dst`:    annotated destination; must not be `Const` because it is mutated.
# - `src`:    annotated source, any activity.
# - `idx`:    `Const` gather indices.
#
# Returns an `AugmentedReturn(primal, shadow, cache_idx)`.
function EnzymeCore.EnzymeRules.augmented_primal(
    config,
    func::EnzymeCore.Const{typeof(NNlib.gather!)},
    ::Type{RT},
    dst::OutType,
    src,
    idx::EnzymeCore.Const,
) where {OutType, RT}

    @assert !(OutType <: EnzymeCore.Const)

    # The primal must also run for batched shadows: previously a `BatchDuplicated`
    # destination skipped this call and `dst.val` was never filled in.
    if OutType <: EnzymeCore.Duplicated ||
       OutType <: EnzymeCore.DuplicatedNoNeed ||
       OutType <: EnzymeCore.BatchDuplicated
        func.val(dst.val, src.val, idx.val)
    end

    primal = if EnzymeCore.EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    # Cache idx only if it is overwritten and a gradient w.r.t. src is required.
    # Entry 4 of `overwritten` is idx: the tuple counts the function itself first
    # (func, dst, src, idx).
    needs_idx_copy = EnzymeCore.EnzymeRules.overwritten(config)[4] &&
                     !(typeof(src) <: EnzymeCore.Const)
    cache_idx = needs_idx_copy ? copy(idx.val) : nothing

    return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

# Reverse half of the Enzyme reverse-mode rule for `NNlib.gather!`.
#
# For each (ddst, dsrc) shadow pair, accumulates the pullback of the gather into
# the source shadow and then zeroes the destination shadow, since `gather!`
# overwrites `dst`.
#
# Arguments
# - `config`:    Enzyme rule configuration (queried for overwritten / width).
# - `func`:      `Const` wrapper around `NNlib.gather!`.
# - `cache_idx`: tape from `augmented_primal`; a copy of the indices, or `nothing`
#   when no copy was taken.
# - `dst`, `src`, `idx`: the annotated arguments of the primal call.
#
# Returns one `nothing` per primal argument (dst, src, idx): all gradients are
# accumulated in place into the shadows, none are returned as active values.
function EnzymeCore.EnzymeRules.reverse(
    config,
    func::EnzymeCore.Const{typeof(NNlib.gather!)},
    ::Type{RT},
    cache_idx,
    dst::OutType,
    src,
    idx::EnzymeCore.Const,
) where {OutType, RT}

    src_is_const = typeof(src) <: EnzymeCore.Const

    # idx was not copied in the forward pass when it is not overwritten,
    # so read it directly from the argument.
    if !src_is_const && !EnzymeCore.EnzymeRules.overwritten(config)[4]
        cache_idx = idx.val
    end

    ddsts = dst.dval
    # A `Const` src carries no `dval`; substitute the dst shadows so that `zip`
    # below still iterates. The substitute is never written to in that case.
    dsrcs = src_is_const ? ddsts : src.dval

    # With width 1 the shadows are single arrays; wrap them so the loop below
    # handles the single and batched cases uniformly.
    if EnzymeCore.EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        # Skip shadows that alias the primal buffer.
        ddst === dst.val && continue

        if !src_is_const
            # dsrc += grad wrt src (the result was previously computed and dropped)
            dsrc .+= NNlib.∇gather_src(ddst, size(src.val), cache_idx)
        end

        # dst is fully overwritten by gather!, so its incoming adjoint is consumed.
        ddst .= 0
    end

    return (nothing, nothing, nothing)
end
2 changes: 1 addition & 1 deletion test/conv.jl
Original file line number Diff line number Diff line change
@@ -870,7 +870,7 @@ end
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
gradtest((x, w) -> conv(x, w, cdims), x, w)
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rrule=true) # https://github.com/FluxML/Flux.jl/issues/1055

y = conv(x, w, cdims)
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
13 changes: 13 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using NNlib: gather, gather!
import EnzymeTestUtils
using EnzymeCore

function gather_testsuite(Backend)
device(x) = adapt(Backend(), x)
@@ -150,6 +152,17 @@ function gather_testsuite(Backend)
Backend == CPU ?
gradtest_fn(xs -> gather(xs, idx), src) :
gradtest_fn((s, i) -> gather(s, i), src, idx)

if Backend == CPU
for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated),
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue

EnzymeTestUtils.test_reverse(fun, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))
end
end
end

@testset "gather gradient for tuple index" begin
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ using NNlib, Test, Statistics, Random
using ChainRulesCore, ChainRulesTestUtils
using Base.Broadcast: broadcasted
import EnzymeTestUtils
using EnzymeCore
import FiniteDifferences
import ForwardDiff
import Zygote
10 changes: 5 additions & 5 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -22,13 +22,13 @@ function gradtest(
end
if check_enzyme_rrule
if len(xs) == 2
for Tret in (Const, Active),
Tx in (Const, Duplicated, BatchDuplicated),
Ty in (Const, Duplicated, BatchDuplicated)
for Tret in (EnzymeCore.Const, EnzymeCore.Active),
Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

are_activities_compatible(Tret, Tx, Ty) || continue
EnzymeTestUtils.are_activities_compatible(Tret, Tx, Ty) || continue

test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol)
EnzymeTestUtils.test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol)
end
else
throw(AssertionError("Unsupported arg count for testing"))