
Commit 74f89d2: test documentation docstrings

Parent: c488cd9

10 files changed, +400 -283 lines

.github/workflows/tests.yml (+2 -2)

@@ -64,8 +64,8 @@ jobs:
         set -xe
         pip install --upgrade pip setuptools wheel
         pip install -r docs/requirements.txt
-    - name: Build documentation
+    - name: Test examples and docstrings
       run: |
         set -xe
         python -VV
-        cd docs && make clean && make html
+        make doctest

docs/Makefile (+8 -3)

@@ -3,16 +3,16 @@
 
 # You can set these variables from the command line, and also
 # from the environment for the first two.
-SPHINXOPTS ?=
 SPHINXBUILD ?= sphinx-build
 SOURCEDIR = .
 BUILDDIR = _build
+SPHINXOPTS = -d $(BUILDDIR)/doctrees -T
 
 # Put it first so that "make" without argument is like "make help".
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
-.PHONY: help Makefile
+.PHONY: help Makefile doctest
 
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
@@ -25,6 +25,11 @@ clean:
 	rm -rf _autosummary/
 
 html-noplot:
-	$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
+	$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(SPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
 	@echo
 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
+doctest:
+	$(SPHINXBUILD) -b doctest $(SPHINXOPTS) . $(BUILDDIR)/doctest
+	@echo "Testing of doctests in the sources finished, look at the " \
+	      "results in $(BUILDDIR)/doctest/output.txt."

docs/conf.py (+7 -1)

@@ -50,13 +50,14 @@
     'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings
     'sphinx.ext.autodoc',
     'sphinx.ext.autosummary',
+    'sphinx.ext.doctest',
     'sphinx.ext.intersphinx',
     'sphinx.ext.mathjax',
     'sphinx.ext.viewcode',
     'matplotlib.sphinxext.plot_directive',
     'sphinx_autodoc_typehints',
     'myst_nb',
-    "sphinx_remove_toctrees",
+    'sphinx_remove_toctrees',
     'sphinx_rtd_theme',
     'sphinx_gallery.gen_gallery',
     'sphinx_copybutton',
@@ -70,7 +71,12 @@
     "backreferences_dir": os.path.join("modules", "generated"),
 }
 
+# Specify how to identify the prompt when copying code snippets
+copybutton_prompt_text = r">>> |\.\.\. "
+copybutton_prompt_is_regexp = True
+copybutton_exclude = "style"
 
+trim_doctests_flags = True
 source_suffix = ['.rst', '.ipynb', '.md']
 
 autosummary_generate = True
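The copybutton settings tell sphinx-copybutton to strip doctest prompts before a snippet is copied to the clipboard: the registered regular expression matches the ">>> " and "... " prompts at the start of each line. A small Python sketch of that behaviour (illustration only; the extension itself applies the pattern in the browser):

    import re

    # Same pattern as copybutton_prompt_text above.
    prompt = re.compile(r">>> |\.\.\. ")

    lines = [">>> x = 1", "... y = x + 1", "plain output line"]
    # Remove the leading prompt from each line, as the copy button would.
    print([prompt.sub("", line, count=1) for line in lines])
    # ['x = 1', 'y = x + 1', 'plain output line']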

docs/constrained.rst (+23 -15)

@@ -29,13 +29,16 @@ To solve constrained optimization problems, we can use projected gradient
 descent, which is gradient descent with an additional projection onto the
 constraint set. Constraints are specified by setting the ``projection``
 argument. For instance, non-negativity constraints can be specified using
-:func:`projection_non_negative <jaxopt.projection.projection_non_negative>`::
+:func:`projection_non_negative <jaxopt.projection.projection_non_negative>`:
 
-  from jaxopt import ProjectedGradient
-  from jaxopt.projection import projection_non_negative
+.. doctest::
+
+  >>> from jaxopt import ProjectedGradient
+  >>> from jaxopt.projection import projection_non_negative
+
+  >>> pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
+  >>> pg_sol = pg.run(w_init, data=(X, y)).params
 
-  pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
-  pg_sol = pg.run(w_init, data=(X, y)).params
 
 Numerous projections are available, see below.
 
@@ -45,13 +48,15 @@ Specifying projection parameters
 Some projections have a hyperparameter that can be specified. For
 instance, the hyperparameter of :func:`projection_l2_ball
 <jaxopt.projection.projection_l2_ball>` is the radius of the :math:`L_2` ball.
-This can be passed using the ``hyperparams_proj`` argument of ``run``::
+This can be passed using the ``hyperparams_proj`` argument of ``run``:
+
+.. doctest::
 
-  from jaxopt.projection import projection_l2_ball
+  >>> from jaxopt.projection import projection_l2_ball
 
-  radius = 1.0
-  pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
-  pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
+  >>> radius = 1.0
+  >>> pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
+  >>> pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
 
 .. topic:: Examples
 
@@ -62,13 +67,16 @@ Differentiation
 
 In some applications, it is useful to differentiate the solution of the solver
 with respect to some hyperparameters. Continuing the previous example, we can
-now differentiate the solution w.r.t. ``radius``::
+now differentiate the solution w.r.t. ``radius``:
+
+.. doctest::
+
+  >>> def solution(radius):
+  ...   pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True)
+  ...   return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
 
-  def solution(radius):
-    pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True)
-    return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
+  >>> print(jax.jacobian(solution)(radius))
 
-  print(jax.jacobian(solution)(radius))
 
 Under the hood, we use the implicit function theorem if ``implicit_diff=True``
 and autodiff of unrolled iterations if ``implicit_diff=False``. See the
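Note that these doctests rely on names defined earlier on the page (fun, w_init, X, y, and the jax import). A minimal setup consistent with the snippets could look like the following; the data and objective here are placeholders, not part of the commit:

    import jax
    import jax.numpy as jnp
    from sklearn import datasets

    # Toy regression data so the ProjectedGradient snippets have inputs.
    X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)

    def fun(w, data):  # least-squares objective assumed by the examples
        inputs, targets = data
        residuals = jnp.dot(inputs, w) - targets
        return jnp.mean(residuals ** 2)

    w_init = jnp.zeros(X.shape[1])
    radius = 1.0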

docs/non_smooth.rst (+35 -22)

@@ -38,19 +38,26 @@ For instance, suppose we want to solve the following optimization problem
 
 which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1`. The
 corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
-We can therefore write::
+We can therefore write:
 
-  from jaxopt import ProximalGradient
-  from jaxopt.prox import prox_lasso
+.. doctest::
 
-  def least_squares(w, data):
-    X, y = data
-    residuals = jnp.dot(X, w) - y
-    return jnp.mean(residuals ** 2)
+  >>> import jax.numpy as jnp
+  >>> from jaxopt import ProximalGradient
+  >>> from jaxopt.prox import prox_lasso
+  >>> from sklearn import datasets
+  >>> X, y = datasets.make_regression()
+
+  >>> def least_squares(w, data):
+  ...   inputs, targets = data
+  ...   residuals = jnp.dot(inputs, w) - targets
+  ...   return jnp.mean(residuals ** 2)
+
+  >>> l1reg = 1.0
+  >>> w_init = jnp.zeros(n_features)
+  >>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
+  >>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
 
-  l1reg = 1.0
-  pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
-  pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
 
 Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
 ``l1reg``, which controls the :math:`L_1` regularization strength. As shown
@@ -65,13 +72,16 @@ Differentiation
 
 In some applications, it is useful to differentiate the solution of the solver
 with respect to some hyperparameters. Continuing the previous example, we can
-now differentiate the solution w.r.t. ``l1reg``::
+now differentiate the solution w.r.t. ``l1reg``:
+
 
-  def solution(l1reg):
-    pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
-    return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+.. doctest::
 
-  print(jax.jacobian(solution)(l1reg))
+  >>> def solution(l1reg):
+  ...   pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
+  ...   return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+
+  >>> print(jax.jacobian(solution)(l1reg))
 
 Under the hood, we use the implicit function theorem if ``implicit_diff=True``
 and autodiff of unrolled iterations if ``implicit_diff=False``. See the
@@ -95,15 +105,18 @@ Block coordinate descent
 Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
 :ref:`composite linear objective functions <composite_linear_functions>`.
 
-Example::
+Example:
+
+.. doctest::
 
-  from jaxopt import objective
-  from jaxopt import prox
+  >>> from jaxopt import objective
+  >>> from jaxopt import prox
+  >>> import jax.numpy as jnp
 
-  l1reg = 1.0
-  w_init = jnp.zeros(n_features)
-  bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
-  lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+  >>> l1reg = 1.0
+  >>> w_init = jnp.zeros(n_features)
+  >>> bcd = jaxopt.BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
+  >>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
 
 .. topic:: Examples
 
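One caveat with the block coordinate descent doctest above: it references n_features and calls jaxopt.BlockCoordinateDescent without importing jaxopt itself, so it presumably depends on definitions made earlier on the page. A self-contained version might look like this sketch (toy data chosen for illustration, not part of the commit):

    import jax.numpy as jnp
    import jaxopt
    from jaxopt import objective, prox
    from sklearn import datasets

    # Toy data; defines the n_features used for the initial parameters.
    X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
    n_features = X.shape[1]

    l1reg = 1.0
    w_init = jnp.zeros(n_features)
    bcd = jaxopt.BlockCoordinateDescent(fun=objective.least_squares,
                                        block_prox=prox.prox_lasso)
    lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params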
