Skip to content

Commit

Permalink
remove tricks
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith authored Oct 9, 2024
1 parent f1b7fe1 commit edb18e9
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions ext/SparseDiffToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import SparseDiffTools: SparseDiffTools, DeivVecTag, AutoDiffVJP, __test_backend
import ForwardDiff: ForwardDiff, Dual, partials
import SciMLOperators: update_coefficients, update_coefficients!
import Setfield: @set!
import Tricks: static_hasmethod

import SparseDiffTools: numback_hesvec!,
numback_hesvec, autoback_hesvec!, autoback_hesvec, auto_vecjac!,
Expand Down Expand Up @@ -101,7 +100,7 @@ end

# VJP methods
function auto_vecjac!(du, f::F, x, v) where {F}
!static_hasmethod(f, typeof((x,))) &&
!hasmethod(f, typeof((x,))) &&
error("For inplace function use autodiff = AutoFiniteDiff()")
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
end
Expand All @@ -113,7 +112,7 @@ end

# overload operator interface
function SparseDiffTools._vecjac(f::F, _, u, autodiff::AutoZygote) where {F}
!static_hasmethod(f, typeof((u,))) &&
!hasmethod(f, typeof((u,))) &&
error("For inplace function use autodiff = AutoFiniteDiff()")
pullback = Zygote.pullback(f, u)
return AutoDiffVJP(f, u, (), autodiff, pullback)
Expand Down

0 comments on commit edb18e9

Please sign in to comment.