 import jax.numpy as jnp

 from jaxopt._src import base
-from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch
-from jaxopt._src.zoom_linesearch import zoom_linesearch
 from jaxopt.tree_util import tree_add_scalar_mul
 from jaxopt.tree_util import tree_l2_norm
 from jaxopt.tree_util import tree_sub
 from jaxopt._src.tree_util import tree_single_dtype
 from jaxopt._src.scipy_wrappers import make_onp_to_jnp
 from jaxopt._src.scipy_wrappers import pytree_topology_from_example
+from jaxopt._src.linesearch_util import _reset_stepsize
+from jaxopt._src.linesearch_util import _setup_linesearch


 _dot = partial(jnp.dot, precision=jax.lax.Precision.HIGHEST)
@@ -53,6 +53,7 @@ def pytree_to_flat_array(pytree, dtype):

 class BfgsState(NamedTuple):
   """Named tuple containing state information."""
+
   iter_num: int
   value: float
   grad: Any
@@ -74,10 +75,8 @@ class BFGS(base.IterativeSolver):
     value_and_grad: whether ``fun`` just returns the value (False) or both
       the value and gradient (True).
     has_aux: whether ``fun`` outputs auxiliary data or not.
-      If ``has_aux`` is False, ``fun`` is expected to be
-        scalar-valued.
-      If ``has_aux`` is True, then we have one of the following
-        two cases.
+      If ``has_aux`` is False, ``fun`` is expected to be scalar-valued.
+      If ``has_aux`` is True, then we have one of the following two cases.
       If ``value_and_grad`` is False, the output should be
       ``value, aux = fun(...)``.
       If ``value_and_grad == True``, the output should be
@@ -92,6 +91,8 @@ class BFGS(base.IterativeSolver):
       or a callable specifying the **positive** stepsize to use at each iteration.
     linesearch: the type of line search to use: "backtracking" for backtracking
       line search or "zoom" for zoom line search.
+    condition: condition used to select the stepsize when using the
+      backtracking line search.
     maxls: maximum number of iterations to use in the line search.
     decrease_factor: factor by which to decrease the stepsize during line search
       (default: 0.8).
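Note: the ``has_aux``, ``value_and_grad`` and ``linesearch`` attributes documented above determine what ``fun`` must return and how the solver searches along the descent direction. A minimal usage sketch for the ``has_aux=True``, ``value_and_grad=False`` case (the quadratic objective and the auxiliary dictionary below are illustrative, not part of this change):

```python
import jax.numpy as jnp
from jaxopt import BFGS

def fun(params, targets):
  # With has_aux=True and value_and_grad=False, fun returns ``value, aux``.
  residual = params - targets
  value = jnp.sum(residual ** 2)
  aux = {"residual_norm": jnp.linalg.norm(residual)}
  return value, aux

solver = BFGS(fun=fun, has_aux=True, linesearch="zoom", maxls=15)
params, state = solver.run(jnp.zeros(3), targets=jnp.array([1.0, 2.0, 3.0]))
```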
@@ -151,6 +152,7 @@ def init_state(self,
       init_params: pytree containing the initial parameters.
       *args: additional positional arguments to be passed to ``fun``.
       **kwargs: additional keyword arguments to be passed to ``fun``.
+
     Returns:
       state
     """
@@ -179,6 +181,7 @@ def update(self,
       state: named tuple containing the solver state.
       *args: additional positional arguments to be passed to ``fun``.
       **kwargs: additional keyword arguments to be passed to ``fun``.
+
     Returns:
       (params, state)
     """
@@ -190,65 +193,18 @@ def update(self,

     descent_direction = flat_array_to_pytree(-_dot(state.H, flat_grad))

-    if not isinstance(self.stepsize, Callable) and self.stepsize <= 0:
-      # with line search
-
-      if self.linesearch == "backtracking":
-        ls = BacktrackingLineSearch(fun=self._value_and_grad_with_aux,
-                                    value_and_grad=True,
-                                    maxiter=self.maxls,
-                                    decrease_factor=self.decrease_factor,
-                                    max_stepsize=self.max_stepsize,
-                                    condition=self.condition,
-                                    jit=self.jit,
-                                    unroll=self.unroll,
-                                    has_aux=True)
-        init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
-                                  # If stepsize became too small, we restart it.
-                                  self.max_stepsize,
-                                  # Else, we increase a bit the previous one.
-                                  state.stepsize * self.increase_factor)
-        new_stepsize, ls_state = ls.run(init_stepsize,
-                                        params, value, grad,
-                                        descent_direction,
-                                        *args, **kwargs)
-        new_value = ls_state.value
-        new_aux = ls_state.aux
-        new_params = ls_state.params
-        new_grad = ls_state.grad
-
-      elif self.linesearch == "zoom":
-        ls_state = zoom_linesearch(f=self._value_and_grad_with_aux,
-                                   xk=params, pk=descent_direction,
-                                   old_fval=value, gfk=grad, maxiter=self.maxls,
-                                   value_and_grad=True, has_aux=True, aux=state.aux,
-                                   args=args, kwargs=kwargs)
-        new_value = ls_state.f_k
-        new_aux = ls_state.aux
-        new_stepsize = ls_state.a_k
-        new_grad = ls_state.g_k
-        # FIXME: zoom_linesearch currently doesn't return new_params
-        # so we have to recompute it.
-        t = new_stepsize.astype(tree_single_dtype(params))
-        new_params = tree_add_scalar_mul(params, t, descent_direction)
-        # FIXME: (zaccharieramzi) sometimes the linesearch fails
-        # and therefore its value g_k does not correspond
-        # to the gradient at the new parameters.
-        # with the following conditional loop we have a hot fix that just
-        # recomputes the value, gradient and auxiliary value
-        # at the new parameters. It would be better to understand
-        # what the g_k passed by zoom_linesearch is in this case
-        # and why it is wrong.
-        (new_value, new_aux), new_grad = jax.lax.cond(
-            ls_state.failed,
-            lambda: self._value_and_grad_with_aux(new_params, *args, **kwargs),
-            lambda: ((new_value, new_aux), new_grad),
-        )
-      else:
-        raise ValueError("Invalid name in 'linesearch' option.")
-
+    use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0
+
+    if use_linesearch:
+      init_stepsize = self._reset_stepsize(state.stepsize)
+      new_stepsize, ls_state = self.run_ls(
+          init_stepsize, params, value, grad, descent_direction, *args, **kwargs
+      )
+      new_params = ls_state.params
+      new_value = ls_state.value
+      new_grad = ls_state.grad
+      new_aux = ls_state.aux
     else:
-      # without line search
       if isinstance(self.stepsize, Callable):
         new_stepsize = self.stepsize(state.iter_num)
       else:
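Note: in the new branch above, ``self._reset_stepsize`` replaces the inline ``jnp.where`` restart logic of the deleted backtracking branch. A sketch of the behaviour it is expected to encapsulate, inferred from that removed code (the standalone function below is illustrative, not the actual ``linesearch_util`` implementation):

```python
import jax.numpy as jnp

def reset_stepsize_sketch(stepsize, min_stepsize, max_stepsize, increase_factor):
  # If the previous stepsize became too small, restart from max_stepsize;
  # otherwise slightly increase the previous stepsize.
  return jnp.where(stepsize <= min_stepsize,
                   max_stepsize,
                   stepsize * increase_factor)
```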
@@ -284,7 +240,7 @@ def optimality_fun(self, params, *args, **kwargs):
     return self._value_and_grad_fun(params, *args, **kwargs)[1]

   def _value_and_grad_fun(self, params, *args, **kwargs):
-    (value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
+    (value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
     return value, grad

   def __post_init__(self):
@@ -294,3 +250,31 @@ def __post_init__(self):
        has_aux=self.has_aux)

     self.reference_signature = self.fun
+    jit, unroll = self._get_loop_options()
+    linesearch_solver = _setup_linesearch(
+        linesearch=self.linesearch,
+        fun=self._value_and_grad_with_aux,
+        value_and_grad=True,
+        has_aux=True,
+        maxlsiter=self.maxls,
+        max_stepsize=self.max_stepsize,
+        jit=jit,
+        unroll=unroll,
+        verbose=self.verbose,
+        condition=self.condition,
+        decrease_factor=self.decrease_factor,
+        increase_factor=self.increase_factor,
+    )
+
+    self._reset_stepsize = partial(
+        _reset_stepsize,
+        self.linesearch,
+        self.max_stepsize,
+        self.min_stepsize,
+        self.increase_factor,
+    )
+
+    if jit:
+      self.run_ls = jax.jit(linesearch_solver.run)
+    else:
+      self.run_ls = linesearch_solver.run
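Note: ``_setup_linesearch`` is expected to return a solver-like object whose ``run`` method shares one signature across line-search variants, which is what lets ``update`` call ``self.run_ls`` uniformly. A rough sketch of the dispatch it presumably performs, modelled on the backtracking branch deleted from ``update`` above (names and structure here are assumptions, not the actual ``linesearch_util`` code):

```python
from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch

def setup_linesearch_sketch(linesearch, fun, maxlsiter, max_stepsize,
                            condition, decrease_factor, jit, unroll):
  # Dispatch on the line-search name; the returned object exposes ``.run``.
  if linesearch == "backtracking":
    return BacktrackingLineSearch(fun=fun,
                                  value_and_grad=True,
                                  has_aux=True,
                                  maxiter=maxlsiter,
                                  decrease_factor=decrease_factor,
                                  max_stepsize=max_stepsize,
                                  condition=condition,
                                  jit=jit,
                                  unroll=unroll)
  raise ValueError("Invalid name in 'linesearch' option.")
```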