Skip to content

Commit

Permalink
Format codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 14, 2023
1 parent 9c458cd commit 91175e6
Show file tree
Hide file tree
Showing 20 changed files with 394 additions and 422 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ Manifest.toml

docs/build/
docs/site/

.vscode
7 changes: 3 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.4.1"
version = "2.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -13,8 +13,8 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
SparseDiffToolsZygoteExt = "Zygote"

[compat]
ADTypes = "0.1"
ADTypes = "0.2"
Adapt = "3.0"
ArrayInterface = "7.4.2"
Compat = "4"
Expand All @@ -39,7 +39,6 @@ FiniteDiff = "2.8.1"
ForwardDiff = "0.10"
Graphs = "1"
Reexport = "1"
Requires = "1"
SciMLOperators = "0.2.11, 0.3"
Setfield = "1"
StaticArrayInterface = "1.3"
Expand Down
16 changes: 8 additions & 8 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ using Documenter, SparseDiffTools
include("pages.jl")

makedocs(sitename = "SparseDiffTools.jl",
authors = "Chris Rackauckas",
modules = [SparseDiffTools],
clean = true,
doctest = false,
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/SparseDiffTools/stable/"),
pages = pages)
authors = "Chris Rackauckas",
modules = [SparseDiffTools],
clean = true,
doctest = false,
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/SparseDiffTools/stable/"),
pages = pages)

deploydocs(repo = "github.com/JuliaDiff/SparseDiffTools.jl.git";
push_preview = true)
push_preview = true)
71 changes: 43 additions & 28 deletions ext/SparseDiffToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Setfield: @set!
### Jac, Hes products

function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v),
cache2 = similar(v))
cache2 = similar(v))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand Down Expand Up @@ -39,20 +39,20 @@ function SparseDiffTools.numback_hesvec(f, x, v)
end

function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),
eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v,
size(x))))),
cache2 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),
eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v,
size(x))))))
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),
eltype(x))),
eltype(x), 1,
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v,
size(x))))),
cache2 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),
eltype(x))),
eltype(x), 1,
}.(x,
ForwardDiff.Partials.(tuple.(reshape(v,
size(x))))))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand All @@ -64,16 +64,20 @@ end

function SparseDiffTools.autoback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f, x))
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
y = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
eltype(x),
1,
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
ForwardDiff.partials.(g(y), 1)
end

## VecJac products

# VJP methods
function SparseDiffTools.auto_vecjac!(du, f, x, v)
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = AutoFiniteDiff()")
!hasmethod(f, (typeof(x),)) &&
error("For inplace function use autodiff = AutoFiniteDiff()")
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
end

Expand All @@ -84,16 +88,17 @@ end

# overload operator interface
function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote)

cache = ()
pullback = Zygote.pullback(f, u)

AutoDiffVJP(f, u, cache, autodiff, pullback)
end

function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
) where{AD <: AutoZygote}

function update_coefficients(L::AutoDiffVJP{AD},
u,
p,
t;
VJP_input = nothing) where {AD <: AutoZygote}
if !isnothing(VJP_input)
@set! L.u = VJP_input
end
Expand All @@ -102,9 +107,11 @@ function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
@set! L.pullback = Zygote.pullback(L.f, L.u)
end

function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
) where{AD <: AutoZygote}

function update_coefficients!(L::AutoDiffVJP{AD},
u,
p,
t;
VJP_input = nothing) where {AD <: AutoZygote}
if !isnothing(VJP_input)
copy!(L.u, VJP_input)
end
Expand All @@ -116,7 +123,7 @@ function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
end

# Interpret the call as df/du' * v
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where{AD <: AutoZygote}
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where {AD <: AutoZygote}
# ignore VJP_input as pullback was computed in update_coefficients(...)

y, back = L.pullback
Expand All @@ -126,14 +133,22 @@ function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where{AD <: AutoZygo
end

# prefer non in-place method
function (L::AutoDiffVJP{AD, IIP, true})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote, IIP}
function (L::AutoDiffVJP{AD, IIP, true})(dv,
v,
p,
t;
VJP_input = nothing) where {AD <: AutoZygote, IIP}
# ignore VJP_input as pullback was computed in update_coefficients!(...)

_dv = L(v, p, t; VJP_input = VJP_input)
copy!(dv, _dv)
end

function (L::AutoDiffVJP{AD, true, false})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote}
function (L::AutoDiffVJP{AD, true, false})(dv,
v,
p,
t;
VJP_input = nothing) where {AD <: AutoZygote}
@error("Zygote requires an out of place method with signature f(u).")
end

Expand Down
106 changes: 43 additions & 63 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
@@ -1,59 +1,31 @@
module SparseDiffTools

using Compat
using FiniteDiff
using ForwardDiff
using Graphs
using Graphs: SimpleGraph
using VertexSafeGraphs
using Adapt

using Reexport
# QoL/Helper Packages
using Adapt, Compat, Reexport
# Graph Coloring
using Graphs, VertexSafeGraphs
import Graphs: SimpleGraph
# Differentiation
using FiniteDiff, ForwardDiff
@reexport using ADTypes

using LinearAlgebra
using SparseArrays, ArrayInterface

import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
# Array Packages
using ArrayInterface, SparseArrays
import ArrayInterface: matrix_colors
import StaticArrays

using ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
using DataStructures: DisjointSets, find_root!, union!

using ArrayInterface: matrix_colors

# Others
using SciMLOperators
import DataStructures: DisjointSets, find_root!, union!
import SciMLOperators: update_coefficients, update_coefficients!
using Tricks: Tricks, static_hasmethod
using Setfield: @set!
import Setfield: @set!
import Tricks: Tricks, static_hasmethod

abstract type AbstractAutoDiffVecProd end
import PackageExtensionCompat: @require_extensions
function __init__()
@require_extensions
end

export contract_color,
greedy_d1,
greedy_star1_coloring,
greedy_star2_coloring,
matrix2graph,
matrix_colors,
forwarddiff_color_jacobian!,
forwarddiff_color_jacobian,
ForwardColorJacCache,
numauto_color_hessian!,
numauto_color_hessian,
autoauto_color_hessian!,
autoauto_color_hessian,
ForwardColorHesCache,
ForwardAutoColorHesCache,
auto_jacvec, auto_jacvec!,
num_jacvec, num_jacvec!,
num_vecjac, num_vecjac!,
num_hesvec, num_hesvec!,
numauto_hesvec, numauto_hesvec!,
autonum_hesvec, autonum_hesvec!,
num_hesvecgrad, num_hesvecgrad!,
auto_hesvecgrad, auto_hesvecgrad!,
JacVec, HesVec, HesVecGrad, VecJac,
update_coefficients, update_coefficients!,
value!
abstract type AbstractAutoDiffVecProd end

include("coloring/high_level.jl")
include("coloring/backtracking_coloring.jl")
Expand All @@ -63,6 +35,7 @@ include("coloring/acyclic_coloring.jl")
include("coloring/greedy_star1_coloring.jl")
include("coloring/greedy_star2_coloring.jl")
include("coloring/matrix2graph.jl")

include("differentiation/compute_jacobian_ad.jl")
include("differentiation/compute_hessian_ad.jl")
include("differentiation/jaches_products.jl")
Expand All @@ -72,27 +45,34 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

import Requires

function numback_hesvec end
function numback_hesvec! end
function autoback_hesvec end
function autoback_hesvec! end
function auto_vecjac end
function auto_vecjac! end

@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/SparseDiffToolsZygoteExt.jl")
@reexport using .SparseDiffToolsZygoteExt
end
end
end

export
numback_hesvec, numback_hesvec!,
autoback_hesvec, autoback_hesvec!,
auto_vecjac, auto_vecjac!
# Coloring Algorithms
export AcyclicColoring,
BacktrackingColor, ContractionColor, GreedyD1Color, GreedyStar1Color, GreedyStar2Color
export matrix2graph, matrix_colors
# Sparse Jacobian Computation
export ForwardColorJacCache, forwarddiff_color_jacobian, forwarddiff_color_jacobian!
# Sparse Hessian Computation
export numauto_color_hessian, numauto_color_hessian!, autoauto_color_hessian,
autoauto_color_hessian!, ForwardAutoColorHesCache, ForwardColorHesCache
# JacVec Products
export auto_jacvec, auto_jacvec!, num_jacvec, num_jacvec!
# VecJac Products
export num_vecjac, num_vecjac!, auto_vecjac, auto_vecjac!
# HesVec Products
export numauto_hesvec,
numauto_hesvec!, autonum_hesvec, autonum_hesvec!, numback_hesvec,
numback_hesvec!
# HesVecGrad Products
export num_hesvecgrad, num_hesvecgrad!, auto_hesvecgrad, auto_hesvecgrad!
# Operators
export JacVec, HesVec, HesVecGrad, VecJac
export update_coefficients, update_coefficients!, value!

end # module
Loading

0 comments on commit 91175e6

Please sign in to comment.