Skip to content

Commit 53f659f

Browse files
author
JAXopt authors
committed
Merge pull request #416 from zaccharieramzi:broyden
PiperOrigin-RevId: 539672172
2 parents 257b673 + 4c67a7d commit 53f659f

File tree

5 files changed

+592
-0
lines changed

5 files changed

+592
-0
lines changed

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Root finding
106106
:toctree: _autosummary
107107

108108
jaxopt.Bisection
109+
jaxopt.Broyden
109110
jaxopt.ScipyRootFinding
110111

111112
Fixed point resolution

docs/root_finding.rst

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,49 @@ Scipy wrapper
6868
:toctree: _autosummary
6969

7070
jaxopt.ScipyRootFinding
71+
72+
73+
Broyden's method
74+
--------------
75+
76+
.. autosummary::
77+
:toctree: _autosummary
78+
79+
jaxopt.Broyden
80+
81+
Broyden's method is an iterative algorithm suitable for nonlinear root equations in any dimension.
82+
It is a quasi-Newton method (like L-BFGS), meaning that it uses an approximation of the Jacobian matrix
83+
at each iteration.
84+
The approximation is updated at each iteration with a rank-one update.
85+
This makes the approximation easy to invert using the Sherman-Morrison formula, given it does not use too many
86+
updates.
87+
One can control the number of updates with the ``history_size`` argument.
88+
Furthermore, Broyden's method uses a line search to ensure the rank-one updates are stable.
89+
90+
Example::
91+
92+
import jax.numpy as jnp
93+
from jaxopt import Broyden
94+
95+
def F(x):
96+
return x ** 3 - x - 2
97+
98+
broyden = Broyden(fun=F)
99+
print(broyden.run(jnp.array(1.0)).params)
100+
101+
102+
For implicit differentiation::
103+
104+
import jax
105+
import jax.numpy as jnp
106+
from jaxopt import Broyden
107+
108+
def F(x, factor):
109+
return factor * x ** 3 - x - 2
110+
111+
def root(factor):
112+
broyden = Broyden(fun=F)
113+
return broyden.run(jnp.array(1.0), factor=factor).params
114+
115+
# Derivative of root with respect to factor at 2.0.
116+
print(jax.grad(root)(2.0))

jaxopt/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from jaxopt._src.bfgs import BFGS
2828
from jaxopt._src.bisection import Bisection
2929
from jaxopt._src.block_cd import BlockCoordinateDescent
30+
from jaxopt._src.broyden import Broyden
3031
from jaxopt._src.cd_qp import BoxCDQP
3132
from jaxopt._src.cvxpy_wrapper import CvxpyQP
3233
from jaxopt._src.eq_qp import EqualityConstrainedQP

0 commit comments

Comments
 (0)