Commit 9d658a4

rdyro authored and OptaxDev committed
Ensure optimizers accept extra args
Most already use optax.chain, which enforces this. Fix for #1131. PiperOrigin-RevId: 731893013
1 parent 6ee6e61 commit 9d658a4

9 files changed: +102 −38 lines changed
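
In user code, the effect is that `update` on the built-in optimizers now tolerates keyword arguments it does not consume, mirroring the new tests below. A minimal sketch of that behavior (the extra keyword name is arbitrary and copied from the test):

import jax
import jax.numpy as jnp
import optax

params = jnp.ones((3,))
opt = optax.adam(learning_rate=1e-3)
state = opt.init(params)

value, grads = jax.value_and_grad(lambda p: jnp.sum(p**2))(params)
# After this change, keyword arguments the optimizer does not use are
# accepted and silently dropped instead of raising a TypeError.
updates, state = opt.update(
    grads, state, params,
    unexpected_extra_args_your_optimizer_doesnt_expect=1,
)
params = optax.apply_updates(params, updates)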

optax/_src/alias.py (+26 −26)

@@ -39,7 +39,7 @@ def adabelief(
     eps_root: float = 1e-16,
     *,
     nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""The AdaBelief optimizer.

   AdaBelief is an adaptive learning rate optimizer that focuses on fast
@@ -141,7 +141,7 @@ def adadelta(
     eps: float = 1e-6,
     weight_decay: float = 0.0,
     weight_decay_mask: MaskOrFn = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The Adadelta optimizer.

   Adadelta is a stochastic gradient descent method that adapts learning rates
@@ -208,7 +208,7 @@ def adafactor(
     eps: float = 1e-30,
     factored: bool = True,
     weight_decay_mask: MaskOrFn = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The Adafactor optimizer.

   Adafactor is an adaptive learning rate optimizer that focuses on fast
@@ -304,7 +304,7 @@ def adagrad(
     learning_rate: base.ScalarOrSchedule,
     initial_accumulator_value: float = 0.1,
     eps: float = 1e-7,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""The Adagrad optimizer.

   AdaGrad is a sub-gradient algorithm for stochastic optimization that adapts
@@ -394,7 +394,7 @@ def adam(
     mu_dtype: Optional[Any] = None,
     *,
     nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""The Adam optimizer.

   Adam is an SGD variant with gradient scaling adaptation. The scaling
@@ -580,7 +580,7 @@ def adamw(
     mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
     *,
     nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""Adam with weight decay regularization.

   AdamW uses weight decay to regularize learning towards small weights, as
@@ -789,7 +789,7 @@ def adan(
     eps_root: float = 1e-8,
     weight_decay: float = 0.0,
     mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""The ADAptive Nesterov momentum algorithm (Adan).

   Adan first reformulates the vanilla Nesterov acceleration to develop a new
@@ -905,7 +905,7 @@ def lion(
     mu_dtype: Optional[Any] = None,
     weight_decay: float = 1e-3,
     mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""The Lion optimizer.

   Lion is discovered by symbolic program search. Unlike most adaptive optimizers
@@ -1001,7 +1001,7 @@ def amsgrad(
     eps: float = 1e-8,
     eps_root: float = 0.0,
     mu_dtype: Optional[Any] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The AMSGrad optimizer.

   The original Adam can fail to converge to the optimal solution in some cases.
@@ -1058,7 +1058,7 @@ def amsgrad(

 def fromage(
     learning_rate: float, min_norm: float = 1e-6
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The Frobenius matched gradient descent (Fromage) optimizer.

   Fromage is a learning algorithm that does not require learning rate tuning.
@@ -1119,7 +1119,7 @@ def lars(
     trust_ratio_mask: MaskOrFn = True,
     momentum: float = 0.9,
     nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The LARS optimizer.

   LARS is a layer-wise adaptive optimizer introduced to help scale SGD to
@@ -1191,7 +1191,7 @@ def lamb(
     eps_root: float = 0.0,
     weight_decay: float = 0.0,
     mask: MaskOrFn = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The LAMB optimizer.

   LAMB is a general purpose layer-wise adaptive large batch optimizer designed
@@ -1257,7 +1257,7 @@ def noisy_sgd(
     eta: float = 0.01,
     gamma: float = 0.55,
     seed: int = 0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""A variant of SGD with added noise.

   Noisy SGD is a variant of :func:`optax.sgd` that incorporates Gaussian noise
@@ -1325,7 +1325,7 @@ def noisy_sgd(

 def sign_sgd(
     learning_rate: base.ScalarOrSchedule,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""A variant of SGD using only the signs of the gradient components.

   SignSGD is a variant of SGD that uses the signs of the gradient components in
@@ -1394,7 +1394,7 @@ def novograd(
     eps: float = 1e-6,
     eps_root: float = 0.0,
     weight_decay: float = 0.0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """NovoGrad optimizer.

   NovoGrad is more robust to the initial learning rate and
@@ -1461,7 +1461,7 @@ def optimistic_gradient_descent(
     learning_rate: base.ScalarOrSchedule,
     alpha: base.ScalarOrSchedule = 1.0,
     beta: base.ScalarOrSchedule = 1.0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """An Optimistic Gradient Descent optimizer.

   Optimistic gradient descent is an approximation of extra-gradient methods
@@ -1523,7 +1523,7 @@ def optimistic_adam(
     mu_dtype: Optional[Any] = None,
     *,
     nesterov: bool = True,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""The Optimistic Adam optimizer.

   This is an optimistic version of the Adam optimizer. It addresses the issue
@@ -1643,7 +1643,7 @@ def radam(
     threshold: float = 5.0,
     *,
     nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The Rectified Adam optimizer.

   The adaptive learning rate in Adam has undesirably large variance in early
@@ -1715,7 +1715,7 @@ def rmsprop(
     momentum: Optional[float] = None,
     nesterov: bool = False,
     bias_correction: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""A flexible RMSProp optimizer.

   RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
@@ -1824,7 +1824,7 @@ def sgd(
     momentum: Optional[float] = None,
     nesterov: bool = False,
     accumulator_dtype: Optional[Any] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""A canonical Stochastic Gradient Descent optimizer.

   This implements stochastic gradient descent. It also includes support for
@@ -1911,7 +1911,7 @@ def sgd(

 def sm3(
     learning_rate: float, momentum: float = 0.9
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""The SM3 optimizer.

   SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a
@@ -2024,7 +2024,7 @@ def yogi(
     b1: float = 0.9,
     b2: float = 0.999,
     eps: float = 1e-3,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   # pylint: disable=line-too-long
   """The Yogi optimizer.

@@ -2083,7 +2083,7 @@ def adamax(
     b1: float = 0.9,
     b2: float = 0.999,
     eps: float = 1e-8,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   r"""A variant of the Adam optimizer that uses the infinity norm.

   AdaMax is a variant of the :func:`optax.adam` optimizer. By generalizing
@@ -2170,7 +2170,7 @@ def adamaxw(
     eps: float = 1e-8,
     weight_decay: float = 1e-4,
     mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """Adamax with weight decay regularization.

   AdamaxW uses weight decay to regularize learning towards small weights, as
@@ -2244,7 +2244,7 @@ def rprop(
     eta_plus: float = 1.2,
     min_step_size: float = 1e-6,
     max_step_size: float = 50.0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """The Rprop optimizer.

   Rprop, short for resillient backpropogation, is a first order variant of
@@ -2405,7 +2405,7 @@ def lbfgs(
     memory_size: int = 10,
     scale_init_precond: bool = True,
     linesearch: Optional[
-        base.GradientTransformationExtraArgs
+        Union[base.GradientTransformationExtraArgs, base.GradientTransformation]
     ] = _linesearch.scale_by_zoom_linesearch(
         max_linesearch_steps=20, initial_guess_strategy='one'
     ),
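
Since every factory above now returns base.GradientTransformationExtraArgs, a generic training step can forward extra information such as the loss value to any of them: optimizers that require it (for example optax.polyak_sgd, which is not part of this diff) consume it, and the rest drop it. A sketch under that assumption:

import jax
import jax.numpy as jnp
import optax

def make_step(opt, loss_fn):
  # Generic train step: always forwards the loss value; optimizers that do
  # not need `value` simply ignore it now that they accept extra args.
  @jax.jit
  def step(params, opt_state):
    value, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = opt.update(grads, opt_state, params, value=value)
    return optax.apply_updates(params, updates), opt_state
  return step

loss_fn = lambda p: jnp.sum(p**2)
params = jnp.ones((4,))
for opt in (optax.sgd(0.1), optax.adam(1e-2), optax.polyak_sgd()):
  step = make_step(opt, loss_fn)
  opt_state = opt.init(params)
  new_params, opt_state = step(params, opt_state)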

optax/_src/alias_test.py (+24)

@@ -187,6 +187,30 @@ def step(params, state):

     chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)

+  @parameterized.product(_OPTIMIZERS_UNDER_TEST)
+  def test_optimizers_accept_extra_args(self, opt_name, opt_kwargs):
+    opt = getattr(alias, opt_name)(**opt_kwargs)
+    # intentionally ommit: opt = base.with_extra_args_support(opt)
+    initial_params, _, objective = _setup_rosenbrock(jnp.float32)
+
+    @jax.jit
+    def step(params, state):
+      value, updates = jax.value_and_grad(objective)(params)
+      update_kwargs = {'unexpected_extra_args_your_optimizer_doesnt_expect': 1}
+      if opt_name in ['polyak_sgd']:
+        update_kwargs = {'value': value}
+      updates, state = opt.update(updates, state, params, **update_kwargs)
+      params = update.apply_updates(params, updates)
+      return params, state
+
+    params = initial_params
+    with self.subTest('Test that init works with extra values'):
+      state = opt.init(params)
+
+    with self.subTest('Test that update works with extra values'):
+      for _ in range(2):
+        params, state = step(params, state)
+
   @chex.all_variants
   @parameterized.product(_OPTIMIZERS_UNDER_TEST)
   def test_optimizers_can_be_wrapped_in_inject_hyperparams(
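
The commented-out line in the new test refers to optax's with_extra_args_support wrapper, which the test deliberately omits so that the aliases must accept extra args on their own. For a transformation that has not been updated, the wrapper can add the same tolerance; a sketch with a made-up scale_by_two transformation:

import jax
import jax.numpy as jnp
import optax

def scale_by_two() -> optax.GradientTransformation:
  # A legacy-style transformation whose update_fn takes no **extra_args.
  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    return jax.tree.map(lambda g: 2.0 * g, updates), state

  return optax.GradientTransformation(init_fn, update_fn)

# The wrapper promotes it so that unexpected keyword arguments are dropped
# instead of raising a TypeError.
opt = optax.with_extra_args_support(scale_by_two())
params = jnp.ones((2,))
state = opt.init(params)
updates, state = opt.update(jnp.ones((2,)), state, params, value=0.0)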

optax/_src/base.py (+2)

@@ -150,6 +150,8 @@ def __call__(
 class GradientTransformation(NamedTuple):
   # pylint: disable=line-too-long
   """A pair of pure functions implementing a gradient transformation.
+
+  Prefer :class:`GradientTransformationExtraArgs` for new optimizers.

   Optax optimizers are all implemented as *gradient transformations*.
   A gradient transformation is defined to be a pair of pure functions, which
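
The pattern the docstring now recommends looks roughly as follows; scale_by_loss_value is a hypothetical transformation, but the **extra_args signature mirrors the updates made elsewhere in this commit:

import jax
import jax.numpy as jnp
import optax

def scale_by_loss_value() -> optax.GradientTransformationExtraArgs:
  # Hypothetical transformation that scales updates by the loss value and
  # ignores any other extra arguments it is handed.
  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None, *, value, **extra_args):
    del params, extra_args
    return jax.tree.map(lambda g: value * g, updates), state

  return optax.GradientTransformationExtraArgs(init_fn, update_fn)

opt = scale_by_loss_value()
state = opt.init(jnp.ones(3))
updates, state = opt.update(jnp.ones(3), state, value=0.5, unused_kwarg=123)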

optax/_src/transform.py (+2 −2)

@@ -975,8 +975,8 @@ def init_fn(params):
     del params
     return ScaleByScheduleState(count=jnp.zeros([], jnp.int32))

-  def update_fn(updates, state, params=None):
-    del params
+  def update_fn(updates, state, params=None, **extra_args):
+    del params, extra_args
     step_size = step_size_fn(state.count)
     updates = jax.tree.map(
         lambda g: jnp.array(step_size, dtype=g.dtype) * g, updates

optax/contrib/_common_test.py (+29)

@@ -213,6 +213,35 @@ def obj_fn(params):

 class ContribTest(chex.TestCase):

+  @parameterized.product(_ALL_OPTIMIZERS_UNDER_TEST, wrap=[True, False])
+  def test_optimizers_accept_extra_args(
+      self, opt_name, opt_kwargs, wrapper_name, wrapper_kwargs, wrap):
+    opt = _get_opt_factory(opt_name)(**opt_kwargs)
+    if wrap and wrapper_name is not None:
+      opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs)
+    # intentionally ommit: opt = base.with_extra_args_support(opt)
+
+    initial_params, _, objective = _setup_rosenbrock(jnp.float32)
+
+    @jax.jit
+    def step(params, state):
+      value, updates = jax.value_and_grad(objective)(params)
+      update_kwargs = {'unexpected_extra_args_your_optimizer_doesnt_expect': 1}
+      if opt_name in ['momo', 'momo_adam', 'sgd']:
+        update_kwargs['value'] = value
+      if opt_name in ['sophia']:
+        update_kwargs['obj_fn'] = objective
+      updates, state = opt.update(updates, state, params, **update_kwargs)
+      params = update.apply_updates(params, updates)
+      return params, state
+
+    params = initial_params
+    state = opt.init(params)
+
+    with self.subTest('Test that update works with extra args'):
+      for _ in range(2):
+        params, state = step(params, state)
+
   @parameterized.product(
       _ALL_OPTIMIZERS_UNDER_TEST,
       target=(

optax/contrib/_dadapt_adamw.py (+4 −2)

@@ -46,7 +46,7 @@ def dadapt_adamw(
     eps: float = 1e-8,
     estim_lr0: float = 1e-6,
     weight_decay: float = 0.0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
   """Learning rate free AdamW by D-Adaptation.

   Adapts the baseline learning rate of AdamW automatically by estimating the
@@ -91,7 +91,9 @@ def update_fn(
       updates: base.Updates,
       state: DAdaptAdamWState,
       params: Optional[base.Params] = None,
+      **extra_args,
   ) -> tuple[base.Updates, DAdaptAdamWState]:
+    del extra_args
     if params is None:
       raise ValueError(base.NO_PARAMS_MSG)
     count = state.count
@@ -141,4 +143,4 @@ def update_fn(
     )
     return p_update, new_state

-  return base.GradientTransformation(init_fn, update_fn)
+  return base.GradientTransformationExtraArgs(init_fn, update_fn)
