Skip to content

Commit c3dcdf8

Browse files
committed
Add Alternating Projections retraction
1 parent 1572796 commit c3dcdf8

File tree

4 files changed

+96
-0
lines changed

4 files changed

+96
-0
lines changed

docs/constrained.rst

+8
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,14 @@ The following operators are available.
109109
Projections always have two arguments: the input to be projected and the
110110
parameters of the convex set.
111111

112+
Note that a retraction is also provided, that allows to retrieve
113+
an arbitrary point lying in the intersection of convex sets.
114+
115+
.. autosummary::
116+
:toctree: _autosummary
117+
118+
jaxopt.projection.alternating_projections
119+
112120
Mirror descent
113121
--------------
114122

jaxopt/_src/projection.py

+61
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,80 @@
1717
from functools import partial
1818
from typing import Any
1919
from typing import Callable
20+
from typing import List
2021
from typing import Tuple
2122

2223
import jax
2324
import jax.numpy as jnp
2425
from jax.scipy.special import logsumexp
2526

27+
from jaxopt._src.fixed_point_iteration import FixedPointIteration
2628
from jaxopt._src.bisection import Bisection
2729
from jaxopt._src.eq_qp import EqualityConstrainedQP
2830
from jaxopt._src.lbfgs import LBFGS
2931
from jaxopt._src.osqp import OSQP, BoxOSQP
3032
from jaxopt._src import tree_util
3133

3234

35+
def alternating_projections(initial_guess: Any,
36+
projections: List,
37+
hyperparams: List,
38+
**fixed_point_params) -> Any:
39+
"""Alternating projections algorithm.
40+
41+
This algorithm returns a point in the intersection of convex sets
42+
by projecting onto each set in turn.
43+
44+
If the sets are not convex, or if their intersection is empty,
45+
this algorithm may not converge.
46+
47+
If the sets are convex and their intersection is non empty,
48+
the algorithm converges to a point `p*` in the intersection of the sets.
49+
However this point `p*` is not necessarily the closest to the initial guess,
50+
i.e alternating_projections is not a valid projection itself.
51+
52+
If the inittial guess lies in the intersection of the sets, then
53+
the algorithm converges to this point. Hence this algorithm is a retraction.
54+
If the initial guess lies outside the intersection, and if the intersection
55+
contains more than one point, then the algorithm converges to an arbitrary
56+
point in the intersection.
57+
58+
Implicit differentiation will measure the sensitivity of `p*`
59+
to perturbations in the `hyperparams`, but not to perturbations
60+
in the initial guess.
61+
62+
Args:
63+
projections: a sequence of projections, each of which is a function that
64+
with signature ``x, hyperparams -> x``.
65+
hyperparams: a list of hyperparameters for each projection, each being a
66+
pytree.
67+
**fixed_point_params: parameters for the fixed point solver.
68+
Returns:
69+
A Pytree lying in the intersection of the sets.
70+
71+
References:
72+
Escalante, R. and Raydan, M., 2011. Alternating projection methods.
73+
Society for Industrial and Applied Mathematics.
74+
"""
75+
assert len(projections) == len(hyperparams)
76+
77+
def composed_projections(x, hyperparams):
78+
for proj, hparam in zip(projections, hyperparams):
79+
x = proj(x, hparam)
80+
return x
81+
82+
if 'maxiter' not in fixed_point_params:
83+
fixed_point_params["maxiter"] = 100
84+
if 'tol' not in fixed_point_params:
85+
fixed_point_params["tol"] = 1e-5
86+
87+
# look for a fixed point of this operator
88+
solver = FixedPointIteration(fixed_point_fun=composed_projections,
89+
**fixed_point_params)
90+
fixed_point = solver.run(initial_guess, hyperparams).params
91+
return fixed_point
92+
93+
3394
def projection_non_negative(x: Any, hyperparams=None) -> Any:
3495
r"""Projection onto the non-negative orthant:
3596

jaxopt/projection.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from jaxopt._src.projection import alternating_projections
1516
from jaxopt._src.projection import projection_non_negative
1617
from jaxopt._src.projection import projection_box
1718
from jaxopt._src.projection import projection_hypercube

tests/projection_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,32 @@ def test_projection_birkhoff(self):
439439
solution1 = projection.projection_birkhoff(doubly_stochastic_matrix)
440440
self.assertArraysAllClose(doubly_stochastic_matrix, solution1)
441441

442+
def test_alternating_projections(self):
443+
# x1 + x2 = 1
444+
x = jnp.array([-2.0, 1.0, 3.0])
445+
a = jnp.array([ 1.0, 1.0, 0.])
446+
b = jnp.array(1.0)
447+
448+
# l2 ball of radius 1.5
449+
radius = jnp.array(1.5)
450+
451+
def retract_on_disk_intercept(b):
452+
# The intersection of a ball with an hyperplane is a disk.
453+
retract_on_disk = [projection.projection_l2_ball,
454+
projection.projection_hyperplane]
455+
hyper_params = [radius, (a, b)]
456+
in_disk = projection.alternating_projections(x, retract_on_disk, hyper_params)
457+
458+
return in_disk
459+
460+
in_disk = retract_on_disk_intercept(b)
461+
atol = 1e-5
462+
self.assertLessEqual(jnp.linalg.norm(in_disk), radius + atol)
463+
self.assertArraysAllClose(jnp.dot(a, in_disk), jnp.array(b), atol=atol)
464+
465+
# test that there is no error.
466+
unused_jac = jax.jacrev(retract_on_disk_intercept)(b)
467+
442468
def test_projection_sparse_simplex(self):
443469
def top_k(x, k):
444470
"""Preserve the top-k entries of the vector x and put -inf values elsewhere.

0 commit comments

Comments
 (0)