
Commit 1572796

Author: JAXopt authors
Merge pull request #442 from vroulet:zoom_linesearch_revamp
PiperOrigin-RevId: 543992823
2 parents 7760823 + ef25b9f, commit 1572796

9 files changed: +1338 -870 lines

jaxopt/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -52,3 +52,4 @@
 from jaxopt._src.scipy_wrappers import ScipyLeastSquares
 from jaxopt._src.scipy_wrappers import ScipyMinimize
 from jaxopt._src.scipy_wrappers import ScipyRootFinding
+from jaxopt._src.zoom_linesearch import ZoomLineSearch
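
This one-line change makes the revamped zoom line search part of the public namespace, importable as ``jaxopt.ZoomLineSearch``. Below is a minimal, hedged sketch of standalone use; the constructor and ``run`` keywords are assumed to follow the common jaxopt line-search interface that the BFGS refactor further down calls through ``run_ls``, and the quadratic objective is purely illustrative.

import jax
import jax.numpy as jnp
import jaxopt

def fun(x):
  # Illustrative quadratic objective with minimum at x = 1.
  return jnp.sum((x - 1.0) ** 2)

x0 = jnp.zeros(3)
value, grad = jax.value_and_grad(fun)(x0)

# Hypothetical standalone call; keyword names assumed, not confirmed by this diff.
ls = jaxopt.ZoomLineSearch(fun=fun, maxiter=20)
stepsize, ls_state = ls.run(
    init_stepsize=1.0,
    params=x0,
    value=value,
    grad=grad,
    descent_direction=-grad,
)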

jaxopt/_src/bfgs.py

Lines changed: 49 additions & 65 deletions
@@ -28,14 +28,14 @@
 import jax.numpy as jnp

 from jaxopt._src import base
-from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch
-from jaxopt._src.zoom_linesearch import zoom_linesearch
 from jaxopt.tree_util import tree_add_scalar_mul
 from jaxopt.tree_util import tree_l2_norm
 from jaxopt.tree_util import tree_sub
 from jaxopt._src.tree_util import tree_single_dtype
 from jaxopt._src.scipy_wrappers import make_onp_to_jnp
 from jaxopt._src.scipy_wrappers import pytree_topology_from_example
+from jaxopt._src.linesearch_util import _reset_stepsize
+from jaxopt._src.linesearch_util import _setup_linesearch


 _dot = partial(jnp.dot, precision=jax.lax.Precision.HIGHEST)
@@ -53,6 +53,7 @@ def pytree_to_flat_array(pytree, dtype):

 class BfgsState(NamedTuple):
   """Named tuple containing state information."""
+
   iter_num: int
   value: float
   grad: Any
@@ -74,10 +75,8 @@ class BFGS(base.IterativeSolver):
     value_and_grad: whether ``fun`` just returns the value (False) or both
       the value and gradient (True).
     has_aux: whether ``fun`` outputs auxiliary data or not.
-      If ``has_aux`` is False, ``fun`` is expected to be
-      scalar-valued.
-      If ``has_aux`` is True, then we have one of the following
-      two cases.
+      If ``has_aux`` is False, ``fun`` is expected to be scalar-valued.
+      If ``has_aux`` is True, then we have one of the following two cases.
       If ``value_and_grad`` is False, the output should be
       ``value, aux = fun(...)``.
       If ``value_and_grad == True``, the output should be
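
The reflowed docstring lines above spell out the two accepted conventions when ``has_aux=True``. A minimal sketch of both cases, using a made-up least-squares objective (only the solver arguments come from the docstring; the objective and data names are illustrative):

import jax
import jax.numpy as jnp
import jaxopt

def fun_aux(w, X, y):
  # Scalar loss plus auxiliary outputs returned alongside it.
  residuals = X @ w - y
  return jnp.mean(residuals ** 2), {"residuals": residuals}

# Case 1: value_and_grad=False, so fun returns ``value, aux``.
bfgs = jaxopt.BFGS(fun=fun_aux, has_aux=True, maxiter=50)

# Case 2: value_and_grad=True, so fun returns ``(value, aux), grad``.
fun_vag = jax.value_and_grad(fun_aux, has_aux=True)
bfgs_vag = jaxopt.BFGS(fun=fun_vag, value_and_grad=True, has_aux=True, maxiter=50)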
@@ -92,6 +91,8 @@ class BFGS(base.IterativeSolver):
       or a callable specifying the **positive** stepsize to use at each iteration.
     linesearch: the type of line search to use: "backtracking" for backtracking
       line search or "zoom" for zoom line search.
+    condition: condition used to select the stepsize when using backtracking
+      linesearch
     maxls: maximum number of iterations to use in the line search.
     decrease_factor: factor by which to decrease the stepsize during line search
       (default: 0.8).
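
The new ``condition`` entry documents an attribute that only affects the backtracking branch; the zoom line search applies its acceptance criteria internally. A hedged sketch of picking between the two variants (``"strong-wolfe"`` is assumed to be an accepted value, matching BacktrackingLineSearch's default):

import jax.numpy as jnp
import jaxopt

def fun(x):
  # Illustrative quadratic objective.
  return jnp.sum(x ** 2)

# Zoom line search: ``condition`` is not used by this branch.
solver_zoom = jaxopt.BFGS(fun=fun, linesearch="zoom")

# Backtracking line search: ``condition`` selects the stepsize acceptance test.
solver_bt = jaxopt.BFGS(fun=fun, linesearch="backtracking", condition="strong-wolfe")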
@@ -151,6 +152,7 @@ def init_state(self,
       init_params: pytree containing the initial parameters.
       *args: additional positional arguments to be passed to ``fun``.
       **kwargs: additional keyword arguments to be passed to ``fun``.
+
     Returns:
       state
     """
@@ -179,6 +181,7 @@ def update(self,
       state: named tuple containing the solver state.
       *args: additional positional arguments to be passed to ``fun``.
       **kwargs: additional keyword arguments to be passed to ``fun``.
+
     Returns:
       (params, state)
     """
@@ -190,65 +193,18 @@ def update(self,

     descent_direction = flat_array_to_pytree(-_dot(state.H, flat_grad))

-    if not isinstance(self.stepsize, Callable) and self.stepsize <= 0:
-      # with line search
-
-      if self.linesearch == "backtracking":
-        ls = BacktrackingLineSearch(fun=self._value_and_grad_with_aux,
-                                    value_and_grad=True,
-                                    maxiter=self.maxls,
-                                    decrease_factor=self.decrease_factor,
-                                    max_stepsize=self.max_stepsize,
-                                    condition=self.condition,
-                                    jit=self.jit,
-                                    unroll=self.unroll,
-                                    has_aux=True)
-        init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
-                                  # If stepsize became too small, we restart it.
-                                  self.max_stepsize,
-                                  # Else, we increase a bit the previous one.
-                                  state.stepsize * self.increase_factor)
-        new_stepsize, ls_state = ls.run(init_stepsize,
-                                        params, value, grad,
-                                        descent_direction,
-                                        *args, **kwargs)
-        new_value = ls_state.value
-        new_aux = ls_state.aux
-        new_params = ls_state.params
-        new_grad = ls_state.grad
-
-      elif self.linesearch == "zoom":
-        ls_state = zoom_linesearch(f=self._value_and_grad_with_aux,
-                                   xk=params, pk=descent_direction,
-                                   old_fval=value, gfk=grad, maxiter=self.maxls,
-                                   value_and_grad=True, has_aux=True, aux=state.aux,
-                                   args=args, kwargs=kwargs)
-        new_value = ls_state.f_k
-        new_aux = ls_state.aux
-        new_stepsize = ls_state.a_k
-        new_grad = ls_state.g_k
-        # FIXME: zoom_linesearch currently doesn't return new_params
-        # so we have to recompute it.
-        t = new_stepsize.astype(tree_single_dtype(params))
-        new_params = tree_add_scalar_mul(params, t, descent_direction)
-        # FIXME: (zaccharieramzi) sometimes the linesearch fails
-        # and therefore its value g_k does not correspond
-        # to the gradient at the new parameters.
-        # with the following conditional loop we have a hot fix that just
-        # recomputes the value, gradient and auxiliary value
-        # at the new parameters. It would be better to understand
-        # what the g_k passed by zoom_linesearch is in this case
-        # and why it is wrong.
-        (new_value, new_aux), new_grad = jax.lax.cond(
-            ls_state.failed,
-            lambda: self._value_and_grad_with_aux(new_params, *args, **kwargs),
-            lambda: ((new_value, new_aux), new_grad),
-        )
-      else:
-        raise ValueError("Invalid name in 'linesearch' option.")
-
+    use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0
+
+    if use_linesearch:
+      init_stepsize = self._reset_stepsize(state.stepsize)
+      new_stepsize, ls_state = self.run_ls(
+          init_stepsize, params, value, grad, descent_direction, *args, **kwargs
+      )
+      new_params = ls_state.params
+      new_value = ls_state.value
+      new_grad = ls_state.grad
+      new_aux = ls_state.aux
     else:
-      # without line search
       if isinstance(self.stepsize, Callable):
         new_stepsize = self.stepsize(state.iter_num)
       else:
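
The per-branch line-search construction inside ``update`` is replaced by two helpers prepared once in ``__post_init__`` (last hunk below): ``self._reset_stepsize`` restarts or grows the previous stepsize, and ``self.run_ls`` runs whichever line search was configured. The deleted inline ``jnp.where`` shows the reset rule being factored out; a standalone sketch of that rule, with illustrative parameter names:

import jax.numpy as jnp

def reset_stepsize(stepsize, min_stepsize, max_stepsize, increase_factor):
  # If the previous stepsize collapsed below min_stepsize, restart from
  # max_stepsize; otherwise grow it a bit so the next line search may
  # accept a larger step.
  return jnp.where(stepsize <= min_stepsize,
                   max_stepsize,
                   stepsize * increase_factor)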
@@ -284,7 +240,7 @@ def optimality_fun(self, params, *args, **kwargs):
     return self._value_and_grad_fun(params, *args, **kwargs)[1]

   def _value_and_grad_fun(self, params, *args, **kwargs):
-    (value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
+    (value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
     return value, grad

   def __post_init__(self):
@@ -294,3 +250,31 @@ def __post_init__(self):
         has_aux=self.has_aux)

     self.reference_signature = self.fun
+    jit, unroll = self._get_loop_options()
+    linesearch_solver = _setup_linesearch(
+        linesearch=self.linesearch,
+        fun=self._value_and_grad_with_aux,
+        value_and_grad=True,
+        has_aux=True,
+        maxlsiter=self.maxls,
+        max_stepsize=self.max_stepsize,
+        jit=jit,
+        unroll=unroll,
+        verbose=self.verbose,
+        condition=self.condition,
+        decrease_factor=self.decrease_factor,
+        increase_factor=self.increase_factor,
+    )
+
+    self._reset_stepsize = partial(
+        _reset_stepsize,
+        self.linesearch,
+        self.max_stepsize,
+        self.min_stepsize,
+        self.increase_factor,
+    )
+
+    if jit:
+      self.run_ls = jax.jit(linesearch_solver.run)
+    else:
+      self.run_ls = linesearch_solver.run
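
Because the line-search solver is now built once here and reused by every ``update`` call, nothing changes for callers: the public ``run`` interface is untouched. A small end-to-end sketch on a made-up least-squares problem (data, shapes, and the printed fields are illustrative; ``iter_num`` and ``value`` appear in ``BfgsState`` above):

import jax
import jax.numpy as jnp
import jaxopt

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (32, 4))
y = X @ jnp.ones(4)

def fun(w, X, y):
  # Plain scalar objective; no aux in this sketch.
  return jnp.mean((X @ w - y) ** 2)

solver = jaxopt.BFGS(fun=fun, linesearch="zoom", maxiter=100)
w_opt, state = solver.run(jnp.zeros(4), X=X, y=y)
print(state.iter_num, state.value)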
