@@ -39,7 +39,7 @@ def adabelief(
    eps_root: float = 1e-16,
    *,
    nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""The AdaBelief optimizer.

  AdaBelief is an adaptive learning rate optimizer that focuses on fast
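The substantive change repeated across the hunks below is the declared return type: `base.GradientTransformationExtraArgs` extends `base.GradientTransformation` with an `update` function that also accepts extra keyword arguments. A minimal sketch of what that means for callers, assuming the public `optax` API (including the existing `optax.with_extra_args_support` wrapper):

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3)}
grads = {'w': jnp.full(3, 0.1)}

# With this change, adabelief() advertises the extra-args signature; its
# update() still works exactly as before ...
opt = optax.adabelief(learning_rate=1e-3)
state = opt.init(params)
updates, state = opt.update(grads, state, params)

# ... and a plain GradientTransformation can still be upgraded explicitly
# when a caller needs the wider signature.
opt_extra = optax.with_extra_args_support(optax.sgd(1e-2))
```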
@@ -141,7 +141,7 @@ def adadelta(
    eps: float = 1e-6,
    weight_decay: float = 0.0,
    weight_decay_mask: MaskOrFn = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The Adadelta optimizer.

  Adadelta is a stochastic gradient descent method that adapts learning rates
@@ -208,7 +208,7 @@ def adafactor(
    eps: float = 1e-30,
    factored: bool = True,
    weight_decay_mask: MaskOrFn = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The Adafactor optimizer.

  Adafactor is an adaptive learning rate optimizer that focuses on fast
@@ -304,7 +304,7 @@ def adagrad(
    learning_rate: base.ScalarOrSchedule,
    initial_accumulator_value: float = 0.1,
    eps: float = 1e-7,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""The Adagrad optimizer.

  AdaGrad is a sub-gradient algorithm for stochastic optimization that adapts
@@ -394,7 +394,7 @@ def adam(
    mu_dtype: Optional[Any] = None,
    *,
    nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""The Adam optimizer.

  Adam is an SGD variant with gradient scaling adaptation. The scaling
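For aliases like `adam` that do not consume any extra arguments themselves, the practical effect of the wider signature is that unused keyword arguments passed to `update` are simply ignored, so a generic training loop can forward values such as the current loss without special-casing the optimizer. A rough sketch under that assumption:

```python
import jax
import jax.numpy as jnp
import optax

def loss(params, x):
  return jnp.sum((params['w'] * x - 1.0) ** 2)

params = {'w': jnp.ones(3)}
x = jnp.arange(3.0)

opt = optax.adam(1e-3)
state = opt.init(params)

value, grads = jax.value_and_grad(loss)(params, x)
# adam does not use `value`; under the extra-args signature it is ignored
# rather than raising a TypeError.
updates, state = opt.update(grads, state, params, value=value)
params = optax.apply_updates(params, updates)
```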
@@ -580,7 +580,7 @@ def adamw(
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    *,
    nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""Adam with weight decay regularization.

  AdamW uses weight decay to regularize learning towards small weights, as
@@ -789,7 +789,7 @@ def adan(
    eps_root: float = 1e-8,
    weight_decay: float = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""The ADAptive Nesterov momentum algorithm (Adan).

  Adan first reformulates the vanilla Nesterov acceleration to develop a new
@@ -905,7 +905,7 @@ def lion(
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-3,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""The Lion optimizer.

  Lion is discovered by symbolic program search. Unlike most adaptive optimizers
@@ -1001,7 +1001,7 @@ def amsgrad(
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The AMSGrad optimizer.

  The original Adam can fail to converge to the optimal solution in some cases.
@@ -1058,7 +1058,7 @@ def amsgrad(

def fromage(
    learning_rate: float, min_norm: float = 1e-6
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The Frobenius matched gradient descent (Fromage) optimizer.

  Fromage is a learning algorithm that does not require learning rate tuning.
@@ -1119,7 +1119,7 @@ def lars(
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The LARS optimizer.

  LARS is a layer-wise adaptive optimizer introduced to help scale SGD to
@@ -1191,7 +1191,7 @@ def lamb(
    eps_root: float = 0.0,
    weight_decay: float = 0.0,
    mask: MaskOrFn = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The LAMB optimizer.

  LAMB is a general purpose layer-wise adaptive large batch optimizer designed
@@ -1257,7 +1257,7 @@ def noisy_sgd(
    eta: float = 0.01,
    gamma: float = 0.55,
    seed: int = 0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""A variant of SGD with added noise.

  Noisy SGD is a variant of :func:`optax.sgd` that incorporates Gaussian noise
@@ -1325,7 +1325,7 @@ def noisy_sgd(

def sign_sgd(
    learning_rate: base.ScalarOrSchedule,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""A variant of SGD using only the signs of the gradient components.

  SignSGD is a variant of SGD that uses the signs of the gradient components in
@@ -1394,7 +1394,7 @@ def novograd(
    eps: float = 1e-6,
    eps_root: float = 0.0,
    weight_decay: float = 0.0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """NovoGrad optimizer.

  NovoGrad is more robust to the initial learning rate and
@@ -1461,7 +1461,7 @@ def optimistic_gradient_descent(
    learning_rate: base.ScalarOrSchedule,
    alpha: base.ScalarOrSchedule = 1.0,
    beta: base.ScalarOrSchedule = 1.0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """An Optimistic Gradient Descent optimizer.

  Optimistic gradient descent is an approximation of extra-gradient methods
@@ -1523,7 +1523,7 @@ def optimistic_adam(
    mu_dtype: Optional[Any] = None,
    *,
    nesterov: bool = True,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""The Optimistic Adam optimizer.

  This is an optimistic version of the Adam optimizer. It addresses the issue
@@ -1643,7 +1643,7 @@ def radam(
    threshold: float = 5.0,
    *,
    nesterov: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The Rectified Adam optimizer.

  The adaptive learning rate in Adam has undesirably large variance in early
@@ -1715,7 +1715,7 @@ def rmsprop(
    momentum: Optional[float] = None,
    nesterov: bool = False,
    bias_correction: bool = False,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""A flexible RMSProp optimizer.

  RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
@@ -1824,7 +1824,7 @@ def sgd(
    momentum: Optional[float] = None,
    nesterov: bool = False,
    accumulator_dtype: Optional[Any] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""A canonical Stochastic Gradient Descent optimizer.

  This implements stochastic gradient descent. It also includes support for
@@ -1911,7 +1911,7 @@ def sgd(

def sm3(
    learning_rate: float, momentum: float = 0.9
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""The SM3 optimizer.

  SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a
@@ -2024,7 +2024,7 @@ def yogi(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-3,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  # pylint: disable=line-too-long
  """The Yogi optimizer.

@@ -2083,7 +2083,7 @@ def adamax(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  r"""A variant of the Adam optimizer that uses the infinity norm.

  AdaMax is a variant of the :func:`optax.adam` optimizer. By generalizing
@@ -2170,7 +2170,7 @@ def adamaxw(
    eps: float = 1e-8,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """Adamax with weight decay regularization.

  AdamaxW uses weight decay to regularize learning towards small weights, as
@@ -2244,7 +2244,7 @@ def rprop(
    eta_plus: float = 1.2,
    min_step_size: float = 1e-6,
    max_step_size: float = 50.0,
-) -> base.GradientTransformation:
+) -> base.GradientTransformationExtraArgs:
  """The Rprop optimizer.

  Rprop, short for resilient backpropagation, is a first order variant of
@@ -2405,7 +2405,7 @@ def lbfgs(
    memory_size: int = 10,
    scale_init_precond: bool = True,
    linesearch: Optional[
-        base.GradientTransformationExtraArgs
+        Union[base.GradientTransformationExtraArgs, base.GradientTransformation]
    ] = _linesearch.scale_by_zoom_linesearch(
        max_linesearch_steps=20, initial_guess_strategy='one'
    ),
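`lbfgs` is the alias that actually consumes extra arguments: the zoom linesearch reads `value`, `grad`, and `value_fn` from the keyword arguments of `update`, and the widened `Union` above appears to just additionally allow a plain `GradientTransformation` to be supplied as the `linesearch`. A sketch of the documented L-BFGS loop, assuming `optax.value_and_grad_from_state` to reuse values cached in the optimizer state:

```python
import jax.numpy as jnp
import optax

def f(x):
  # Simple quadratic with its minimum at x = 1.
  return jnp.sum((x - 1.0) ** 2)

x = jnp.zeros(4)
opt = optax.lbfgs()
state = opt.init(x)
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(10):
  value, grad = value_and_grad(x, state=state)
  # The extra keyword arguments feed the zoom linesearch.
  updates, state = opt.update(
      grad, state, x, value=value, grad=grad, value_fn=f
  )
  x = optax.apply_updates(x, updates)
```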