|
| 1 | +# Copyright 2021 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import jax |
| 16 | +from jax import test_util as jtu |
| 17 | +import jax.numpy as jnp |
| 18 | + |
| 19 | +from jaxopt._src.eq_qp_preconditioned import PseudoInversePreconditionedEqQP |
| 20 | +from jaxopt import EqualityConstrainedQP |
| 21 | +import numpy as onp |
| 22 | + |
| 23 | + |
| 24 | +class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase): |
| 25 | + def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b): |
| 26 | + def fun(Q, c, A, b): |
| 27 | + Q = 0.5 * (Q + Q.T) |
| 28 | + |
| 29 | + hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) |
| 30 | + # reduce the primal variables to a scalar value for test purpose. |
| 31 | + return jnp.sum(solver.run(**hyperparams).params[0]) |
| 32 | + |
| 33 | + # Derivative w.r.t. A. |
| 34 | + rng = onp.random.RandomState(0) |
| 35 | + V = rng.rand(*A.shape) |
| 36 | + V /= onp.sqrt(onp.sum(V ** 2)) |
| 37 | + eps = 1e-4 |
| 38 | + deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b)) |
| 39 | + deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps) |
| 40 | + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) |
| 41 | + |
| 42 | + # Derivative w.r.t. b. |
| 43 | + v = rng.rand(*b.shape) |
| 44 | + v /= onp.sqrt(onp.sum(v ** 2)) |
| 45 | + eps = 1e-4 |
| 46 | + deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b)) |
| 47 | + deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps) |
| 48 | + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) |
| 49 | + |
| 50 | + # Derivative w.r.t. Q |
| 51 | + W = rng.rand(*Q.shape) |
| 52 | + W /= onp.sqrt(onp.sum(W ** 2)) |
| 53 | + eps = 1e-4 |
| 54 | + deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b)) |
| 55 | + deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps) |
| 56 | + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) |
| 57 | + |
| 58 | + # Derivative w.r.t. c |
| 59 | + w = rng.rand(*c.shape) |
| 60 | + w /= onp.sqrt(onp.sum(w ** 2)) |
| 61 | + eps = 1e-4 |
| 62 | + deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b)) |
| 63 | + deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps) |
| 64 | + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) |
| 65 | + |
| 66 | + def test_pseudoinverse_preconditioner(self): |
| 67 | + Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) |
| 68 | + c = jnp.array([1.0, 1.0]) |
| 69 | + A = jnp.array([[1.0, 1.0]]) |
| 70 | + b = jnp.array([1.0]) |
| 71 | + qp = EqualityConstrainedQP(tol=1e-7) |
| 72 | + preconditioned_qp = PseudoInversePreconditionedEqQP(qp) |
| 73 | + params_obj = (Q, c) |
| 74 | + params_eq = (A, b) |
| 75 | + params_precond = preconditioned_qp.init_params(params_obj, params_eq) |
| 76 | + hyperparams = dict( |
| 77 | + params_obj=params_obj, |
| 78 | + params_eq=params_eq, |
| 79 | + ) |
| 80 | + sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params |
| 81 | + self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) |
| 82 | + self._check_derivative_Q_c_A_b(qp, Q, c, A, b) |
0 commit comments