@@ -38,19 +38,26 @@ For instance, suppose we want to solve the following optimization problem
which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1`. The
corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
- We can therefore write::
+ We can therefore write:

-   from jaxopt import ProximalGradient
-   from jaxopt.prox import prox_lasso
+ .. doctest::

-   def least_squares(w, data):
-     X, y = data
-     residuals = jnp.dot(X, w) - y
-     return jnp.mean(residuals ** 2)
+   >>> import jax.numpy as jnp
+   >>> from jaxopt import ProximalGradient
+   >>> from jaxopt.prox import prox_lasso
+   >>> from sklearn import datasets
+   >>> X, y = datasets.make_regression()
+   >>> n_features = X.shape[1]
+
+   >>> def least_squares(w, data):
+   ...   inputs, targets = data
+   ...   residuals = jnp.dot(inputs, w) - targets
+   ...   return jnp.mean(residuals ** 2)
+
+   >>> l1reg = 1.0
+   >>> w_init = jnp.zeros(n_features)
+   >>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
+   >>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
-   l1reg = 1.0
-   pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
-   pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

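+ For intuition, :func:`prox_lasso <jaxopt.prox.prox_lasso>` can also be applied directly
+ to a vector. The snippet below is a small sketch that assumes it performs elementwise
+ soft-thresholding and accepts the regularization strength as its second argument:
+
+ .. doctest::
+
+   >>> x = jnp.array([1.5, -0.2, 0.7])
+   >>> shrunk = prox_lasso(x, 1.0)  # soft-thresholding: small entries become exactly zero
+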
Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
``l1reg``, which controls the :math:`L_1` regularization strength. As shown
@@ -65,13 +72,16 @@ Differentiation
In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
- now differentiate the solution w.r.t. ``l1reg``::
+ now differentiate the solution w.r.t. ``l1reg``:

-   def solution(l1reg):
-     pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
-     return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
-
-   print(jax.jacobian(solution)(l1reg))
+ .. doctest::
+
+   >>> import jax
+
+   >>> def solution(l1reg):
+   ...   pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
+   ...   return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+
+   >>> print(jax.jacobian(solution)(l1reg))  # doctest: +SKIP
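+
+   >>> # Illustration only, reusing the definitions above: composing the solution map
+   >>> # with a scalar criterion lets jax.grad return a plain gradient with respect to
+   >>> # l1reg instead of a full Jacobian.
+   >>> def scalar_criterion(l1reg):
+   ...   return jnp.sum(solution(l1reg) ** 2)
+
+   >>> grad_l1reg = jax.grad(scalar_criterion)(l1reg)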

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
@@ -95,15 +105,18 @@ Block coordinate descent
Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
:ref:`composite linear objective functions <composite_linear_functions>`.

- Example::
+ Example:
+
+ .. doctest::

-   from jaxopt import objective
-   from jaxopt import prox
+   >>> import jaxopt
+   >>> from jaxopt import objective
+   >>> from jaxopt import prox
+   >>> import jax.numpy as jnp

-   l1reg = 1.0
-   w_init = jnp.zeros(n_features)
-   bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
-   lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+   >>> l1reg = 1.0
+   >>> w_init = jnp.zeros(n_features)
+   >>> bcd = jaxopt.BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
+   >>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
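+
+ The same pattern applies to other composite linear objectives. The sketch below is an
+ illustration only; it assumes :func:`objective.binary_logreg <jaxopt.objective.binary_logreg>`
+ is available as a composite linear objective and expects binary labels in ``{0, 1}``:
+
+ .. doctest::
+
+   >>> from sklearn import datasets
+   >>> X_bin, y_bin = datasets.make_classification(n_features=10, random_state=0)
+   >>> w_bin_init = jnp.zeros(X_bin.shape[1])
+   >>> bcd_logreg = jaxopt.BlockCoordinateDescent(fun=objective.binary_logreg, block_prox=prox.prox_lasso)
+   >>> logreg_sol = bcd_logreg.run(w_bin_init, hyperparams_prox=1e-2, data=(X_bin, y_bin)).params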

.. topic:: Examples