diff --git a/Project.toml b/Project.toml index ce57272177..3c71524e2f 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ NNlib = "0.7" Reexport = "0.2" StatsBase = "0.33" ZipFile = "0.9" -Zygote = "0.5" +Zygote = "0.6" julia = "1.5" [extras] diff --git a/src/deprecations.jl b/src/deprecations.jl index 9752337153..28546346be 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -31,44 +31,3 @@ function Broadcast.broadcasted(::typeof(logitbinarycrossentropy), ŷ, y) @warn "logitbinarycrossentropy.(ŷ, y) is deprecated, use Losses.logitbinarycrossentropy(ŷ, y, agg=identity) instead" Losses.logitbinarycrossentropy(ŷ, y, agg=identity) end - - -# To move to Zygote - -using Base.Broadcast: broadcasted - - -Zygote.@adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) - y = b === false ? a : a .+ b - y, Δ -> (nothing, Δ, nothing) -end -Zygote.@adjoint function broadcasted(::typeof(+), b::Bool, a::AbstractArray{<:Number}) - y = b === false ? a : b .+ a - y, Δ -> (nothing, nothing, Δ) -end - - -Zygote.@adjoint function broadcasted(::typeof(-), a::AbstractArray{<:Number}, b::Bool) - y = b === false ? a : a .- b - y, Δ -> (nothing, Δ, nothing) -end -Zygote.@adjoint function broadcasted(::typeof(-), b::Bool, a::AbstractArray{<:Number}) - b .- a, Δ -> (nothing, nothing, .-Δ) -end - - -Zygote.@adjoint function broadcasted(::typeof(*), a::AbstractArray{<:Number}, b::Bool) - if b === false - zero(a), Δ -> (nothing, zero(Δ), nothing) - else - a, Δ -> (nothing, Δ, nothing) - end -end -Zygote.@adjoint function broadcasted(::typeof(*), b::Bool, a::AbstractArray{<:Number}) - if b === false - zero(a), Δ -> (nothing, nothing, zero(Δ)) - else - a, Δ -> (nothing, nothing, Δ) - end -end -