Skip to content

Commit 91175e6

Browse files
committed
Format codebase
1 parent 9c458cd commit 91175e6

20 files changed

+394
-422
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ Manifest.toml
66

77
docs/build/
88
docs/site/
9+
10+
.vscode

Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
4-
version = "2.4.1"
4+
version = "2.5.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -13,8 +13,8 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1313
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1414
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1515
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
16+
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1617
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
17-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1818
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1919
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2020
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3030
SparseDiffToolsZygoteExt = "Zygote"
3131

3232
[compat]
33-
ADTypes = "0.1"
33+
ADTypes = "0.2"
3434
Adapt = "3.0"
3535
ArrayInterface = "7.4.2"
3636
Compat = "4"
@@ -39,7 +39,6 @@ FiniteDiff = "2.8.1"
3939
ForwardDiff = "0.10"
4040
Graphs = "1"
4141
Reexport = "1"
42-
Requires = "1"
4342
SciMLOperators = "0.2.11, 0.3"
4443
Setfield = "1"
4544
StaticArrayInterface = "1.3"

docs/make.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ using Documenter, SparseDiffTools
33
include("pages.jl")
44

55
makedocs(sitename = "SparseDiffTools.jl",
6-
authors = "Chris Rackauckas",
7-
modules = [SparseDiffTools],
8-
clean = true,
9-
doctest = false,
10-
format = Documenter.HTML(assets = ["assets/favicon.ico"],
11-
canonical = "https://docs.sciml.ai/SparseDiffTools/stable/"),
12-
pages = pages)
6+
authors = "Chris Rackauckas",
7+
modules = [SparseDiffTools],
8+
clean = true,
9+
doctest = false,
10+
format = Documenter.HTML(assets = ["assets/favicon.ico"],
11+
canonical = "https://docs.sciml.ai/SparseDiffTools/stable/"),
12+
pages = pages)
1313

1414
deploydocs(repo = "github.com/JuliaDiff/SparseDiffTools.jl.git";
15-
push_preview = true)
15+
push_preview = true)

ext/SparseDiffToolsZygoteExt.jl

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import Setfield: @set!
1111
### Jac, Hes products
1212

1313
function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v),
14-
cache2 = similar(v))
14+
cache2 = similar(v))
1515
g = let f = f
1616
(dx, x) -> dx .= first(Zygote.gradient(f, x))
1717
end
@@ -39,20 +39,20 @@ function SparseDiffTools.numback_hesvec(f, x, v)
3939
end
4040

4141
function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
42-
cache1 = Dual{
43-
typeof(ForwardDiff.Tag(DeivVecTag(),
44-
eltype(x))),
45-
eltype(x), 1
46-
}.(x,
47-
ForwardDiff.Partials.(tuple.(reshape(v,
48-
size(x))))),
49-
cache2 = Dual{
50-
typeof(ForwardDiff.Tag(DeivVecTag(),
51-
eltype(x))),
52-
eltype(x), 1
53-
}.(x,
54-
ForwardDiff.Partials.(tuple.(reshape(v,
55-
size(x))))))
42+
cache1 = Dual{
43+
typeof(ForwardDiff.Tag(DeivVecTag(),
44+
eltype(x))),
45+
eltype(x), 1,
46+
}.(x,
47+
ForwardDiff.Partials.(tuple.(reshape(v,
48+
size(x))))),
49+
cache2 = Dual{
50+
typeof(ForwardDiff.Tag(DeivVecTag(),
51+
eltype(x))),
52+
eltype(x), 1,
53+
}.(x,
54+
ForwardDiff.Partials.(tuple.(reshape(v,
55+
size(x))))))
5656
g = let f = f
5757
(dx, x) -> dx .= first(Zygote.gradient(f, x))
5858
end
@@ -64,16 +64,20 @@ end
6464

6565
function SparseDiffTools.autoback_hesvec(f, x, v)
6666
g = x -> first(Zygote.gradient(f, x))
67-
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))), eltype(x), 1
68-
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
67+
y = Dual{
68+
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
69+
eltype(x),
70+
1,
71+
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
6972
ForwardDiff.partials.(g(y), 1)
7073
end
7174

7275
## VecJac products
7376

7477
# VJP methods
7578
function SparseDiffTools.auto_vecjac!(du, f, x, v)
76-
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = AutoFiniteDiff()")
79+
!hasmethod(f, (typeof(x),)) &&
80+
error("For inplace function use autodiff = AutoFiniteDiff()")
7781
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
7882
end
7983

@@ -84,16 +88,17 @@ end
8488

8589
# overload operator interface
8690
function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote)
87-
8891
cache = ()
8992
pullback = Zygote.pullback(f, u)
9093

9194
AutoDiffVJP(f, u, cache, autodiff, pullback)
9295
end
9396

94-
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
95-
) where{AD <: AutoZygote}
96-
97+
function update_coefficients(L::AutoDiffVJP{AD},
98+
u,
99+
p,
100+
t;
101+
VJP_input = nothing) where {AD <: AutoZygote}
97102
if !isnothing(VJP_input)
98103
@set! L.u = VJP_input
99104
end
@@ -102,9 +107,11 @@ function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
102107
@set! L.pullback = Zygote.pullback(L.f, L.u)
103108
end
104109

105-
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
106-
) where{AD <: AutoZygote}
107-
110+
function update_coefficients!(L::AutoDiffVJP{AD},
111+
u,
112+
p,
113+
t;
114+
VJP_input = nothing) where {AD <: AutoZygote}
108115
if !isnothing(VJP_input)
109116
copy!(L.u, VJP_input)
110117
end
@@ -116,7 +123,7 @@ function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; VJP_input = nothing,
116123
end
117124

118125
# Interpret the call as df/du' * v
119-
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where{AD <: AutoZygote}
126+
function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where {AD <: AutoZygote}
120127
# ignore VJP_input as pullback was computed in update_coefficients(...)
121128

122129
y, back = L.pullback
@@ -126,14 +133,22 @@ function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where{AD <: AutoZygo
126133
end
127134

128135
# prefer non in-place method
129-
function (L::AutoDiffVJP{AD, IIP, true})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote, IIP}
136+
function (L::AutoDiffVJP{AD, IIP, true})(dv,
137+
v,
138+
p,
139+
t;
140+
VJP_input = nothing) where {AD <: AutoZygote, IIP}
130141
# ignore VJP_input as pullback was computed in update_coefficients!(...)
131142

132143
_dv = L(v, p, t; VJP_input = VJP_input)
133144
copy!(dv, _dv)
134145
end
135146

136-
function (L::AutoDiffVJP{AD, true, false})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoZygote}
147+
function (L::AutoDiffVJP{AD, true, false})(dv,
148+
v,
149+
p,
150+
t;
151+
VJP_input = nothing) where {AD <: AutoZygote}
137152
@error("Zygote requires an out of place method with signature f(u).")
138153
end
139154

src/SparseDiffTools.jl

Lines changed: 43 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,31 @@
11
module SparseDiffTools
22

3-
using Compat
4-
using FiniteDiff
5-
using ForwardDiff
6-
using Graphs
7-
using Graphs: SimpleGraph
8-
using VertexSafeGraphs
9-
using Adapt
10-
11-
using Reexport
3+
# QoL/Helper Packages
4+
using Adapt, Compat, Reexport
5+
# Graph Coloring
6+
using Graphs, VertexSafeGraphs
7+
import Graphs: SimpleGraph
8+
# Differentiation
9+
using FiniteDiff, ForwardDiff
1210
@reexport using ADTypes
13-
14-
using LinearAlgebra
15-
using SparseArrays, ArrayInterface
16-
11+
import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
12+
# Array Packages
13+
using ArrayInterface, SparseArrays
14+
import ArrayInterface: matrix_colors
1715
import StaticArrays
18-
19-
using ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
20-
using DataStructures: DisjointSets, find_root!, union!
21-
22-
using ArrayInterface: matrix_colors
23-
16+
# Others
2417
using SciMLOperators
18+
import DataStructures: DisjointSets, find_root!, union!
2519
import SciMLOperators: update_coefficients, update_coefficients!
26-
using Tricks: Tricks, static_hasmethod
27-
using Setfield: @set!
20+
import Setfield: @set!
21+
import Tricks: Tricks, static_hasmethod
2822

29-
abstract type AbstractAutoDiffVecProd end
23+
import PackageExtensionCompat: @require_extensions
24+
function __init__()
25+
@require_extensions
26+
end
3027

31-
export contract_color,
32-
greedy_d1,
33-
greedy_star1_coloring,
34-
greedy_star2_coloring,
35-
matrix2graph,
36-
matrix_colors,
37-
forwarddiff_color_jacobian!,
38-
forwarddiff_color_jacobian,
39-
ForwardColorJacCache,
40-
numauto_color_hessian!,
41-
numauto_color_hessian,
42-
autoauto_color_hessian!,
43-
autoauto_color_hessian,
44-
ForwardColorHesCache,
45-
ForwardAutoColorHesCache,
46-
auto_jacvec, auto_jacvec!,
47-
num_jacvec, num_jacvec!,
48-
num_vecjac, num_vecjac!,
49-
num_hesvec, num_hesvec!,
50-
numauto_hesvec, numauto_hesvec!,
51-
autonum_hesvec, autonum_hesvec!,
52-
num_hesvecgrad, num_hesvecgrad!,
53-
auto_hesvecgrad, auto_hesvecgrad!,
54-
JacVec, HesVec, HesVecGrad, VecJac,
55-
update_coefficients, update_coefficients!,
56-
value!
28+
abstract type AbstractAutoDiffVecProd end
5729

5830
include("coloring/high_level.jl")
5931
include("coloring/backtracking_coloring.jl")
@@ -63,6 +35,7 @@ include("coloring/acyclic_coloring.jl")
6335
include("coloring/greedy_star1_coloring.jl")
6436
include("coloring/greedy_star2_coloring.jl")
6537
include("coloring/matrix2graph.jl")
38+
6639
include("differentiation/compute_jacobian_ad.jl")
6740
include("differentiation/compute_hessian_ad.jl")
6841
include("differentiation/jaches_products.jl")
@@ -72,27 +45,34 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
7245
parameterless_type(x) = parameterless_type(typeof(x))
7346
parameterless_type(x::Type) = __parameterless_type(x)
7447

75-
import Requires
76-
7748
function numback_hesvec end
7849
function numback_hesvec! end
7950
function autoback_hesvec end
8051
function autoback_hesvec! end
8152
function auto_vecjac end
8253
function auto_vecjac! end
8354

84-
@static if !isdefined(Base, :get_extension)
85-
function __init__()
86-
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
87-
include("../ext/SparseDiffToolsZygoteExt.jl")
88-
@reexport using .SparseDiffToolsZygoteExt
89-
end
90-
end
91-
end
92-
93-
export
94-
numback_hesvec, numback_hesvec!,
95-
autoback_hesvec, autoback_hesvec!,
96-
auto_vecjac, auto_vecjac!
55+
# Coloring Algorithms
56+
export AcyclicColoring,
57+
BacktrackingColor, ContractionColor, GreedyD1Color, GreedyStar1Color, GreedyStar2Color
58+
export matrix2graph, matrix_colors
59+
# Sparse Jacobian Computation
60+
export ForwardColorJacCache, forwarddiff_color_jacobian, forwarddiff_color_jacobian!
61+
# Sparse Hessian Computation
62+
export numauto_color_hessian, numauto_color_hessian!, autoauto_color_hessian,
63+
autoauto_color_hessian!, ForwardAutoColorHesCache, ForwardColorHesCache
64+
# JacVec Products
65+
export auto_jacvec, auto_jacvec!, num_jacvec, num_jacvec!
66+
# VecJac Products
67+
export num_vecjac, num_vecjac!, auto_vecjac, auto_vecjac!
68+
# HesVec Products
69+
export numauto_hesvec,
70+
numauto_hesvec!, autonum_hesvec, autonum_hesvec!, numback_hesvec,
71+
numback_hesvec!
72+
# HesVecGrad Products
73+
export num_hesvecgrad, num_hesvecgrad!, auto_hesvecgrad, auto_hesvecgrad!
74+
# Operators
75+
export JacVec, HesVec, HesVecGrad, VecJac
76+
export update_coefficients, update_coefficients!, value!
9777

9878
end # module

0 commit comments

Comments
 (0)