|
17 | 17 | from functools import partial
|
18 | 18 | from typing import Any
|
19 | 19 | from typing import Callable
|
| 20 | +from typing import List |
20 | 21 | from typing import Tuple
|
21 | 22 |
|
22 | 23 | import jax
|
23 | 24 | import jax.numpy as jnp
|
24 | 25 | from jax.scipy.special import logsumexp
|
25 | 26 |
|
| 27 | +from jaxopt._src.fixed_point_iteration import FixedPointIteration |
26 | 28 | from jaxopt._src.bisection import Bisection
|
27 | 29 | from jaxopt._src.eq_qp import EqualityConstrainedQP
|
28 | 30 | from jaxopt._src.lbfgs import LBFGS
|
29 | 31 | from jaxopt._src.osqp import OSQP, BoxOSQP
|
30 | 32 | from jaxopt._src import tree_util
|
31 | 33 |
|
32 | 34 |
|
| 35 | +def alternating_projections(initial_guess: Any, |
| 36 | + projections: List, |
| 37 | + hyperparams: List, |
| 38 | + **fixed_point_params) -> Any: |
| 39 | + """Alternating projections algorithm. |
| 40 | +
|
| 41 | + This algorithm returns a point in the intersection of convex sets |
| 42 | + by projecting onto each set in turn. |
| 43 | +
|
| 44 | + If the sets are not convex, or if their intersection is empty, |
| 45 | + this algorithm may not converge. |
| 46 | + |
| 47 | + If the sets are convex and their intersection is non empty, |
| 48 | + the algorithm converges to a point `p*` in the intersection of the sets. |
| 49 | + However this point `p*` is not necessarily the closest to the initial guess, |
| 50 | + i.e alternating_projections is not a valid projection itself. |
| 51 | + |
| 52 | + If the inittial guess lies in the intersection of the sets, then |
| 53 | + the algorithm converges to this point. Hence this algorithm is a retraction. |
| 54 | + If the initial guess lies outside the intersection, and if the intersection |
| 55 | + contains more than one point, then the algorithm converges to an arbitrary |
| 56 | + point in the intersection. |
| 57 | +
|
| 58 | + Implicit differentiation will measure the sensitivity of `p*` |
| 59 | + to perturbations in the `hyperparams`, but not to perturbations |
| 60 | + in the initial guess. |
| 61 | +
|
| 62 | + Args: |
| 63 | + projections: a sequence of projections, each of which is a function that |
| 64 | + with signature ``x, hyperparams -> x``. |
| 65 | + hyperparams: a list of hyperparameters for each projection, each being a |
| 66 | + pytree. |
| 67 | + **fixed_point_params: parameters for the fixed point solver. |
| 68 | + Returns: |
| 69 | + A Pytree lying in the intersection of the sets. |
| 70 | +
|
| 71 | + References: |
| 72 | + Escalante, R. and Raydan, M., 2011. Alternating projection methods. |
| 73 | + Society for Industrial and Applied Mathematics. |
| 74 | + """ |
| 75 | + assert len(projections) == len(hyperparams) |
| 76 | + |
| 77 | + def composed_projections(x, hyperparams): |
| 78 | + for proj, hparam in zip(projections, hyperparams): |
| 79 | + x = proj(x, hparam) |
| 80 | + return x |
| 81 | + |
| 82 | + if 'maxiter' not in fixed_point_params: |
| 83 | + fixed_point_params["maxiter"] = 100 |
| 84 | + if 'tol' not in fixed_point_params: |
| 85 | + fixed_point_params["tol"] = 1e-5 |
| 86 | + |
| 87 | + # look for a fixed point of this operator |
| 88 | + solver = FixedPointIteration(fixed_point_fun=composed_projections, |
| 89 | + **fixed_point_params) |
| 90 | + fixed_point = solver.run(initial_guess, hyperparams).params |
| 91 | + return fixed_point |
| 92 | + |
| 93 | + |
33 | 94 | def projection_non_negative(x: Any, hyperparams=None) -> Any:
|
34 | 95 | r"""Projection onto the non-negative orthant:
|
35 | 96 |
|
|
0 commit comments