diff --git a/benchmarks/sparse_custom_root.py b/benchmarks/sparse_custom_root.py
new file mode 100644
index 00000000..b10acf48
--- /dev/null
+++ b/benchmarks/sparse_custom_root.py
@@ -0,0 +1,75 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import jax
+import jax.numpy as jnp
+
+from jaxopt import prox
+from jaxopt import implicit_diff as idf
+from jaxopt._src import test_util
+from jaxopt import objective
+
+from sklearn import datasets
+
+
+def lasso_objective(params, lam, X, y):
+  residuals = jnp.dot(X, params) - y
+  return 0.5 * jnp.mean(residuals ** 2) + lam * jnp.sum(
+      jnp.abs(params))
+
+
+def lasso_solver(params, X, y, lam):
+  sol = test_util.lasso_skl(X, y, lam)
+  return sol
+
+
+X, y = datasets.make_regression(
+    n_samples=10, n_features=10_000, random_state=0)
+lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
+lam = lam_max / 2
+L = jax.numpy.linalg.norm(X, ord=2) ** 2
+
+
+def make_restricted_optimality_fun(support):
+  def restricted_optimality_fun(restricted_params, X, y, lam):
+    # Suboptimal: ideally restricted_X would be computed once and for all.
+    restricted_X = X[:, support]
+    return lasso_optimality_fun(restricted_params, restricted_X, y, lam)
+  return restricted_optimality_fun
+
+
+def lasso_optimality_fun(params, X, y, lam, tol=1e-4):
+  n_samples = X.shape[0]
+  return prox.prox_lasso(
+      params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L,
+      lam * len(y) / L) - params
+
+
+t_start = time.time()
+lasso_solver_decorated = idf.custom_root(lasso_optimality_fun)(lasso_solver)
+sol = test_util.lasso_skl(X=X, y=y, lam=lam)
+J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam)
+t_custom = time.time() - t_start
+
+
+t_start = time.time()
+lasso_solver_decorated = idf.sparse_custom_root(
+    lasso_optimality_fun, make_restricted_optimality_fun)(lasso_solver)
+sol = test_util.lasso_skl(X=X, y=y, lam=lam)
+J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam)
+t_custom_sparse = time.time() - t_start
+
+
+print("Time taken to compute the Jacobian: %.3f s" % t_custom)
+print("Time taken to compute the Jacobian with the sparse implementation: %.3f s"
+      % t_custom_sparse)
diff --git a/benchmarks/sparse_vjp.py b/benchmarks/sparse_vjp.py
new file mode 100644
index 00000000..3dfcc872
--- /dev/null
+++ b/benchmarks/sparse_vjp.py
@@ -0,0 +1,104 @@
+import time
+import jax
+
+import jax.numpy as jnp
+import numpy as onp
+
+from jaxopt import prox
+from jaxopt import implicit_diff as idf
+from jaxopt._src import test_util
+from jaxopt import linear_solve
+from jaxopt import objective
+
+from sklearn import datasets
+
+X, y = datasets.make_regression(
+    n_samples=100, n_features=100_000, random_state=0)
+
+L = onp.linalg.norm(X, ord=2) ** 2
+
+
+def optimality_fun(params, X, y, lam):
+  n_samples = X.shape[0]
+  return prox.prox_lasso(
+      params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L,
+      lam * len(y) / L) - params
+
+
+def make_restricted_optimality_fun(support):
+  def restricted_optimality_fun(restricted_params, X, y, lam):
+    # Suboptimal: ideally restricted_X would be computed once and for all.
+    restricted_X = X[:, support]
+    return optimality_fun(restricted_params, restricted_X, y, lam)
+  return restricted_optimality_fun
+
+
+lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
+lam = lam_max / 2
+t_start = time.time()
+sol = test_util.lasso_skl(X, y, lam)
+t_optim = time.time() - t_start
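+
+# Optional sanity check (illustrative addition, not part of the timed
+# benchmark): at the solution returned by scikit-learn, the proximal-gradient
+# fixed-point residual should be close to zero, i.e.
+# optimality_fun(sol, X, y, lam) ~ 0 up to the solver tolerance.
+residual_norm = jnp.linalg.norm(optimality_fun(sol, X, y, lam))
+print("Optimality residual at the scikit-learn solution: %.2e" % residual_norm)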
+
+onp.random.seed(0)
+rand = onp.random.normal(0, 1, len(sol))
+dict_times = {}
+dict_grad = {}
+
+for maxiter in [10, 100, 1000, 2000]:
+  def solve(matvec, b):
+    return linear_solve.solve_normal_cg(
+        matvec, b, None, tol=1e-32, maxiter=maxiter)
+
+  vjp = lambda g: idf.root_vjp(
+      optimality_fun=optimality_fun,
+      sol=sol,
+      args=(X, y, lam),
+      cotangent=g,
+      solve=solve)[2]  # vjp w.r.t. lam
+
+  t_start = time.time()
+  grad = vjp(rand)
+  t_jac = time.time() - t_start
+  dict_times[maxiter] = t_jac
+  dict_grad[maxiter] = grad.copy()
+
+
+def solve_sparse(matvec, b):
+  return linear_solve.solve_cg(
+      matvec, b, None, tol=1e-32, maxiter=(sol != 0).sum())
+
+
+vjp_sparse = lambda g: idf.sparse_root_vjp(
+    optimality_fun=optimality_fun,
+    make_restricted_optimality_fun=make_restricted_optimality_fun,
+    sol=sol,
+    args=(X, y, lam),
+    cotangent=g,
+    solve=solve_sparse)[2]  # vjp w.r.t. lam
+
+vjp_sparse2 = lambda g: idf.sparse_root_vjp2(
+    optimality_fun=optimality_fun,
+    sol=sol,
+    args=(X, y, lam),
+    cotangent=g,
+    solve=solve_sparse)[2]  # vjp w.r.t. lam
+
+t_start = time.time()
+grad_sparse = vjp_sparse(rand)
+t_jac_sparse = time.time() - t_start
+
+t_start = time.time()
+grad_sparse2 = vjp_sparse2(rand)
+t_jac_sparse2 = time.time() - t_start
+
+print("Time taken to solve the Lasso optimization problem: %.3f s" % t_optim)
+for maxiter in dict_times.keys():
+  print("Time taken to compute the gradient with %i iterations: %.3f s | "
+        "relative distance to the sparse gradient: %.2e" % (
+            maxiter, dict_times[maxiter],
+            jnp.linalg.norm(dict_grad[maxiter] - grad_sparse)
+            / jnp.linalg.norm(grad_sparse)))
+print("Time taken to compute the gradient with the sparse implementation: %.3f s"
+      % t_jac_sparse)
+print("Time taken to compute the gradient with the sparse2 implementation: %.3f s"
+      % t_jac_sparse2)
+
+
+# Computation times are roughly the same, which is surprising.
+# However, the gradient computed the sparse way is much closer to the true
+# gradient.
diff --git a/examples/lasso_implicit_diff_sparse.py b/examples/lasso_implicit_diff_sparse.py
new file mode 100644
index 00000000..dca673d0
--- /dev/null
+++ b/examples/lasso_implicit_diff_sparse.py
@@ -0,0 +1,122 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implicit differentiation of the lasso based on a sparse implementation."""
+
+import time
+from absl import app
+import jax
+import jax.numpy as jnp
+import numpy as onp
+from jaxopt import implicit_diff
+from jaxopt import linear_solve
+from jaxopt import OptaxSolver
+from jaxopt import prox
+from jaxopt import objective
+from jaxopt._src import test_util
+import optax
+from sklearn import datasets
+from sklearn import model_selection
+from sklearn import preprocessing
+
+# def main(argv):
+#   del argv
+
+# Prepare data.
+# X, y = datasets.load_boston(return_X_y=True)
+
+X, y = datasets.make_regression(
+    n_samples=30, n_features=10_000, random_state=0)
+
+# X = preprocessing.normalize(X)
+# data = (X_tr, X_val, y_tr, y_val)
+data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0)
+
+L = onp.linalg.norm(X, ord=2) ** 2
+
+
+def optimality_fun(params, lam, data):
+  X, y = data
+  n_samples = X.shape[0]
+  return prox.prox_lasso(
+      params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L,
+      lam * len(y) / L) - params
+
+
+def make_restricted_optimality_fun(support):
+  def restricted_optimality_fun(restricted_params, lam, data):
+    # Suboptimal: ideally restricted_X would be computed once and for all.
+    X, y = data
+    restricted_X = X[:, support]
+    return optimality_fun(restricted_params, lam, (restricted_X, y))
+  return restricted_optimality_fun
+
+
+@implicit_diff.sparse_custom_root(
+    optimality_fun=optimality_fun,
+    make_restricted_optimality_fun=make_restricted_optimality_fun)
+def lasso_solver(init_params, lam, data):
+  """Solve the Lasso on the training split."""
+  X_tr, y_tr = data
+  # TODO: add warm start?
+  sol = test_util.lasso_skl(X_tr, y_tr, lam)
+  return sol
+
+# Dense alternative, kept for comparison:
+# @implicit_diff.custom_root(
+#     optimality_fun=optimality_fun)
+# def lasso_solver(init_params, lam, data):
+#   """Solve the Lasso on the training split."""
+#   X_tr, y_tr = data
+#   # TODO: add warm start?
+#   sol = test_util.lasso_skl(X_tr, y_tr, lam)
+#   return sol
+
+
+# Perhaps confusingly, theta is a parameter of the outer objective,
+# but lam = jnp.exp(theta) is a hyper-parameter of the inner objective.
+def outer_objective(theta, init_inner, data):
+  """Validation loss."""
+  X_tr, X_val, y_tr, y_val = data
+  # We use the bijective mapping lam = jnp.exp(theta)
+  # both to optimize in log-space and to ensure positivity.
+  lam = jnp.exp(theta)
+  w_fit = lasso_solver(init_inner, lam, (X_tr, y_tr))
+  y_pred = jnp.dot(X_val, w_fit)
+  loss_value = jnp.mean((y_pred - y_val) ** 2)
+  # We return w_fit as auxiliary data.
+  # Auxiliary data is stored in the optimizer state (see below).
+  return loss_value, w_fit
+
+
+# Initialize solver.
+solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True)
+lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
+lam = lam_max / 10
+theta_init = jnp.log(lam)
+theta, state = solver.init(theta_init)
+init_w = jnp.zeros(X.shape[1])
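+
+# Illustrative check (optional): with sparse_custom_root in place, the
+# hypergradient of the validation loss w.r.t. theta can be queried directly
+# with jax.value_and_grad; this single call is only a sketch of the API and is
+# not needed by the optimization loop below.
+(init_loss, _), init_hypergrad = jax.value_and_grad(
+    outer_objective, has_aux=True)(theta_init, init_w, data)
+print("Initial validation loss: %.3f, hypergradient w.r.t. theta: %.3e"
+      % (init_loss, init_hypergrad))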
+
+t_start = time.time()
+# Run outer loop.
+for _ in range(10):
+  theta, state = solver.update(
+      params=theta, state=state, init_inner=init_w, data=data)
+  # The auxiliary data returned by the outer loss is stored in the state.
+  init_w = state.aux
+  print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.")
+t_elapsed = time.time() - t_start
+
+# if __name__ == "__main__":
+#   app.run(main)
+print("Time taken for 10 iterations: %.2f s" % t_elapsed)
diff --git a/jaxopt/_src/implicit_diff.py b/jaxopt/_src/implicit_diff.py
index 206e86ed..dcf8923c 100644
--- a/jaxopt/_src/implicit_diff.py
+++ b/jaxopt/_src/implicit_diff.py
@@ -19,6 +19,7 @@ from typing import Callable
 from typing import Tuple
 
+import numpy as np  # to be removed, this is for the first draft
 import jax
 
 from jaxopt._src import linear_solve
@@ -73,6 +74,121 @@ def fun_args(*args):
   return vjp_fun_args(u)
 
 
+def sparse_root_vjp(optimality_fun: Callable,
+                    make_restricted_optimality_fun: Callable,
+                    sol: Any,
+                    args: Tuple,
+                    cotangent: Any,
+                    solve: Callable = linear_solve.solve_cg) -> Any:
+  """Sparse vector-Jacobian product of a root.
+
+  The invariant is ``optimality_fun(sol, *args) == 0``.
+
+  Args:
+    optimality_fun: the optimality function to use (``F`` in the paper).
+    make_restricted_optimality_fun: a function that, given the support (the
+      boolean mask of nonzero coordinates of ``sol``), returns the optimality
+      function restricted to that support.
+    sol: solution / root (pytree).
+    args: tuple containing the arguments with respect to which we wish to
+      differentiate ``sol``.
+    cotangent: vector to left-multiply the Jacobian with
+      (pytree, same structure as ``sol``).
+    solve: a linear solver of the form, ``x = solve(matvec, b)``,
+      where ``matvec(x) = Ax`` and ``Ax=b``.
+  Returns:
+    vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t.
+      each argument. Each ``vjps[i]`` has the same pytree structure as
+      ``args[i]``.
+  """
+  support = sol != 0  # nonzero coefficients of the solution
+  restricted_sol = sol[support]  # solution restricted to the support
+
+  restricted_optimality_fun = make_restricted_optimality_fun(support)
+
+  def fun_sol(restricted_sol):
+    # We close over the arguments.
+    return restricted_optimality_fun(restricted_sol, *args)
+
+  _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol)
+
+  # Compute the multiplication A^T u = (u^T A)^T restricted to the support.
+  def restricted_matvec(restricted_v):
+    return vjp_fun_sol(restricted_v)[0]
+
+  # The solution of A^T u = v, where
+  # A = jacobian(optimality_fun, argnums=0)
+  # v = -cotangent.
+  restricted_v = tree_scalar_mul(-1, cotangent[support])
+  restricted_u = solve(restricted_matvec, restricted_v)
+
+  def fun_args(*args):
+    # We close over the solution.
+    return restricted_optimality_fun(restricted_sol, *args)
+
+  _, vjp_fun_args = jax.vjp(fun_args, *args)
+
+  return vjp_fun_args(restricted_u)
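+
+
+# Rationale for the sparse variants: for Lasso-type problems, coordinates
+# outside the support of ``sol`` remain at zero under small perturbations of
+# the arguments (assuming standard non-degeneracy conditions), so the linear
+# system defining the VJP can be restricted to the rows and columns indexed by
+# ``support``. ``sparse_root_vjp`` (above) delegates the restriction to a
+# user-provided ``make_restricted_optimality_fun``, while ``sparse_root_vjp2``
+# (below) hard-codes the ``(X, y, lam)`` argument structure and restricts the
+# columns of ``X`` directly.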
+
+
+def sparse_root_vjp2(optimality_fun: Callable,
+                     # filter_args: Callable,
+                     sol: Any,
+                     args: Tuple,
+                     cotangent: Any,
+                     solve: Callable = linear_solve.solve_cg) -> Any:
+  """Sparse vector-Jacobian product of a root (Lasso-specific variant).
+
+  The invariant is ``optimality_fun(sol, *args) == 0``.
+
+  Args:
+    optimality_fun: the optimality function to use (``F`` in the paper).
+    sol: solution / root (pytree).
+    args: tuple ``(X, y, lam)`` containing the arguments with respect to which
+      we wish to differentiate ``sol``.
+    cotangent: vector to left-multiply the Jacobian with
+      (pytree, same structure as ``sol``).
+    solve: a linear solver of the form, ``x = solve(matvec, b)``,
+      where ``matvec(x) = Ax`` and ``Ax=b``.
+  Returns:
+    vjps: tuple of the same length as ``len(args)`` containing the vjps w.r.t.
+      each argument. Each ``vjps[i]`` has the same pytree structure as
+      ``args[i]``.
+  """
+  support = sol != 0  # nonzero coefficients of the solution
+  restricted_sol = sol[support]  # solution restricted to the support
+
+  X, y, lam = args
+  new_args = X[:, support], y, lam
+
+  def fun_sol(restricted_sol):
+    # We close over the arguments.
+    return optimality_fun(restricted_sol, *new_args)
+
+  _, vjp_fun_sol = jax.vjp(fun_sol, restricted_sol)
+
+  # Compute the multiplication A^T u = (u^T A)^T restricted to the support.
+  def restricted_matvec(restricted_v):
+    return vjp_fun_sol(restricted_v)[0]
+
+  # The solution of A^T u = v, where
+  # A = jacobian(optimality_fun, argnums=0)
+  # v = -cotangent.
+  restricted_v = tree_scalar_mul(-1, cotangent[support])
+  restricted_u = solve(restricted_matvec, restricted_v)
+
+  def fun_args(*args):
+    # We close over the solution.
+    X, y, lam = args
+    new_args = X[:, support], y, lam
+    return optimality_fun(restricted_sol, *new_args)
+
+  _, vjp_fun_args = jax.vjp(fun_args, *args)
+  # _, vjp_fun_args = jax.vjp(fun_args, *new_args)
+
+  return vjp_fun_args(restricted_u)
+
+
 def _jvp_sol(optimality_fun, sol, args, tangent):
   """JVP in the first argument of optimality_fun."""
   # We close over the arguments.
@@ -144,6 +260,65 @@ def solver_fun_bwd(tup, cotangent):
   return wrapped_solver_fun
 
 
+def _sparse_custom_root(
+    solver_fun, optimality_fun, make_restricted_optimality_fun, solve, has_aux):
+  def solver_fun_fwd(init_params, *args):
+    res = solver_fun(init_params, *args)
+    return res, (res, args)
+
+  def solver_fun_bwd(tup, cotangent):
+    res, args = tup
+
+    # solver_fun can return auxiliary data if has_aux = True.
+    if has_aux:
+      cotangent = cotangent[0]
+      sol = res[0]
+    else:
+      sol = res
+
+    # Compute VJPs w.r.t. args.
+    vjps = sparse_root_vjp(
+        optimality_fun=optimality_fun,
+        make_restricted_optimality_fun=make_restricted_optimality_fun,
+        sol=sol, args=args, cotangent=cotangent, solve=solve)
+    # For init_params, we return None.
+    return (None,) + vjps
+
+  wrapped_solver_fun = jax.custom_vjp(solver_fun)
+  wrapped_solver_fun.defvjp(solver_fun_fwd, solver_fun_bwd)
+
+  return wrapped_solver_fun
+
+
+def _sparse_custom_root2(
+    solver_fun, optimality_fun, solve, has_aux):
+  def solver_fun_fwd(init_params, *args):
+    res = solver_fun(init_params, *args)
+    return res, (res, args)
+
+  def solver_fun_bwd(tup, cotangent):
+    res, args = tup
+
+    # solver_fun can return auxiliary data if has_aux = True.
+    if has_aux:
+      cotangent = cotangent[0]
+      sol = res[0]
+    else:
+      sol = res
+
+    # Compute VJPs w.r.t. args.
+    vjps = sparse_root_vjp2(
+        optimality_fun=optimality_fun,
+        sol=sol, args=args, cotangent=cotangent, solve=solve)
+    # For init_params, we return None.
+    return (None,) + vjps
+
+  wrapped_solver_fun = jax.custom_vjp(solver_fun)
+  wrapped_solver_fun.defvjp(solver_fun_fwd, solver_fun_bwd)
+
+  return wrapped_solver_fun
+
+
 def custom_root(optimality_fun: Callable,
                 has_aux: bool = False,
                 solve: Callable = linear_solve.solve_normal_cg):
@@ -165,6 +340,53 @@ def wrapper(solver_fun):
   return wrapper
 
 
+def sparse_custom_root(optimality_fun: Callable,
+                       make_restricted_optimality_fun: Callable,
+                       has_aux: bool = False,
+                       solve: Callable = linear_solve.solve_normal_cg):
+  """Decorator for adding sparse implicit differentiation to a root solver.
+
+  Args:
+    optimality_fun: an equation function, ``optimality_fun(params, *args)``.
+      The invariant is ``optimality_fun(sol, *args) == 0`` at the
+      solution / root ``sol``.
+    make_restricted_optimality_fun: a function that, given the support of the
+      solution, returns the optimality function restricted to that support.
+    has_aux: whether the decorated solver function returns auxiliary data.
+    solve: a linear solver of the form, ``solve(matvec, b)``.
+
+  Returns:
+    A solver function decorator, i.e.,
+    ``sparse_custom_root(optimality_fun, make_restricted_optimality_fun)(solver_fun)``.
+  """
+  def wrapper(solver_fun):
+    return _sparse_custom_root(
+        solver_fun, optimality_fun, make_restricted_optimality_fun, solve,
+        has_aux)
+  return wrapper
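+
+
+# Illustrative usage sketch (the solver and optimality functions below are
+# placeholders, following the pattern used in tests/implicit_diff_test.py):
+#
+#   @sparse_custom_root(lasso_optimality_fun, make_restricted_optimality_fun)
+#   def lasso_solver(init_params, X, y, lam):
+#     return lasso_skl(X, y, lam)
+#
+#   # Jacobian of the solution w.r.t. lam, via the support-restricted system.
+#   jac_lam = jax.jacrev(lasso_solver, argnums=3)(None, X, y, lam)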
+
+
+def sparse_custom_root2(
+    optimality_fun: Callable, has_aux: bool = False,
+    solve: Callable = linear_solve.solve_normal_cg):
+  """Decorator for adding sparse implicit differentiation to a root solver.
+
+  Unlike ``sparse_custom_root``, this variant does not require
+  ``make_restricted_optimality_fun``: it assumes the ``(X, y, lam)`` argument
+  structure handled by ``sparse_root_vjp2``.
+
+  Args:
+    optimality_fun: an equation function, ``optimality_fun(params, *args)``.
+      The invariant is ``optimality_fun(sol, *args) == 0`` at the
+      solution / root ``sol``.
+    has_aux: whether the decorated solver function returns auxiliary data.
+    solve: a linear solver of the form, ``solve(matvec, b)``.
+
+  Returns:
+    A solver function decorator, i.e.,
+    ``sparse_custom_root2(optimality_fun)(solver_fun)``.
+  """
+  def wrapper(solver_fun):
+    return _sparse_custom_root2(
+        solver_fun, optimality_fun, solve, has_aux)
+
+  return wrapper
+
+
 def custom_fixed_point(fixed_point_fun: Callable,
                        has_aux: bool = False,
                        solve: Callable = linear_solve.solve_normal_cg):
diff --git a/jaxopt/implicit_diff.py b/jaxopt/implicit_diff.py
index f043d428..3981ba7b 100644
--- a/jaxopt/implicit_diff.py
+++ b/jaxopt/implicit_diff.py
@@ -16,3 +16,7 @@
 from jaxopt._src.implicit_diff import custom_fixed_point
 from jaxopt._src.implicit_diff import root_jvp
 from jaxopt._src.implicit_diff import root_vjp
+from jaxopt._src.implicit_diff import sparse_root_vjp
+from jaxopt._src.implicit_diff import sparse_root_vjp2
+from jaxopt._src.implicit_diff import sparse_custom_root
+from jaxopt._src.implicit_diff import sparse_custom_root2
diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py
index cc36a900..7e690a14 100644
--- a/tests/implicit_diff_test.py
+++ b/tests/implicit_diff_test.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 from absl.testing import absltest
 from absl.testing import parameterized
 
+import numpy as np
 import jax
 from jax import test_util as jtu
 import jax.numpy as jnp
 
+from jaxopt import prox
 from jaxopt import implicit_diff as idf
 from jaxopt._src import test_util
+from jaxopt import objective
 
 from sklearn import datasets
 
@@ -30,7 +34,17 @@ def ridge_objective(params, lam, X, y):
   return 0.5 * jnp.mean(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)
 
 
-# def ridge_solver(init_params, lam, X, y):
+def lasso_objective(params, lam, X, y):
+  residuals = jnp.dot(X, params) - y
+  return 0.5 * jnp.mean(residuals ** 2) + lam * jnp.sum(
+      jnp.abs(params))
+
+
+def lasso_solver(params, X, y, lam):
+  sol = test_util.lasso_skl(X, y, lam)
+  return sol
+
+
 def ridge_solver(init_params, lam, X, y):
   del init_params  # not used
   XX = jnp.dot(X.T, X)
@@ -39,6 +53,27 @@ def ridge_solver(init_params, lam, X, y):
   return jnp.linalg.solve(XX + lam * len(y) * I, Xy)
 
 
+X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
+lam_max = jnp.max(jnp.abs(X.T @ y)) / len(y)
+lam = lam_max / 2
+L = jax.numpy.linalg.norm(X, ord=2) ** 2
+
+
+def make_restricted_optimality_fun(support):
+  def restricted_optimality_fun(restricted_params, X, y, lam):
+    # Suboptimal: ideally restricted_X would be computed once and for all.
+    restricted_X = X[:, support]
+    return lasso_optimality_fun(restricted_params, restricted_X, y, lam)
+  return restricted_optimality_fun
+
+
+def lasso_optimality_fun(params, X, y, lam, tol=1e-4):
+  n_samples = X.shape[0]
+  return prox.prox_lasso(
+      params - jax.grad(objective.least_squares)(params, (X, y)) * n_samples / L,
+      lam * len(y) / L) - params
+
+
 class ImplicitDiffTest(jtu.JaxTestCase):
 
   def test_root_vjp(self):
@@ -55,6 +90,38 @@ def test_root_vjp(self):
     J_num = test_util.ridge_solver_jac(X, y, lam, eps=1e-4)
     self.assertArraysAllClose(J, J_num, atol=5e-2)
 
+  def test_lasso_root_vjp(self):
+    sol = test_util.lasso_skl(X, y, lam)
+    vjp = lambda g: idf.root_vjp(optimality_fun=lasso_optimality_fun,
+                                 sol=sol,
+                                 args=(X, y, lam),
+                                 cotangent=g)[2]  # vjp w.r.t. lam
+    I = jnp.eye(len(sol))
+    J = jax.vmap(vjp)(I)
+    J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4)
+    self.assertArraysAllClose(J, J_num, atol=5e-2)
+
+  def test_lasso_sparse_root_vjp(self):
+    sol = test_util.lasso_skl(X, y, lam)
+
+    vjp = lambda g: idf.sparse_root_vjp(
+        optimality_fun=lasso_optimality_fun,
+        make_restricted_optimality_fun=make_restricted_optimality_fun,
+        sol=sol,
+        args=(X, y, lam),
+        cotangent=g)[2]  # vjp w.r.t. lam
+    vjp2 = lambda g: idf.sparse_root_vjp2(
+        optimality_fun=lasso_optimality_fun,
+        sol=sol,
+        args=(X, y, lam),
+        cotangent=g)[2]  # vjp w.r.t. lam
+    I = jnp.eye(len(sol))
+    J = jax.vmap(vjp)(I)
+    J2 = jax.vmap(vjp2)(I)
+    J_num = test_util.lasso_skl_jac(X, y, lam, eps=1e-4)
+    self.assertArraysAllClose(J, J_num, atol=5e-2)
+    self.assertArraysAllClose(J2, J_num, atol=5e-2)
+
   def test_root_jvp(self):
     X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
     optimality_fun = jax.grad(ridge_objective)
@@ -79,6 +146,29 @@ def test_custom_root(self):
     J = jax.jacrev(ridge_solver_decorated, argnums=1)(None, lam, X=X, y=y)
     self.assertArraysAllClose(J, J_num, atol=5e-2)
 
+  def test_custom_root_lasso(self):
+    lasso_solver_decorated = idf.custom_root(
+        lasso_optimality_fun)(lasso_solver)
+    sol = test_util.lasso_skl(X=X, y=y, lam=lam)
+    sol_decorated = lasso_solver_decorated(None, X=X, y=y, lam=lam)
+    self.assertArraysAllClose(sol, sol_decorated, atol=1e-4)
+    J_num = test_util.lasso_skl_jac(X=X, y=y, lam=lam, tol=1e-4)
+    J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam)
+    self.assertArraysAllClose(J, J_num, atol=5e-2)
+
+  def test_sparse_custom_root_lasso(self):
+    lasso_solver_decorated = idf.sparse_custom_root(
+        lasso_optimality_fun, make_restricted_optimality_fun)(lasso_solver)
+    lasso_solver_decorated2 = idf.sparse_custom_root2(
+        lasso_optimality_fun)(lasso_solver)
+    sol = test_util.lasso_skl(X=X, y=y, lam=lam)
+    sol_decorated = lasso_solver_decorated(None, X=X, y=y, lam=lam)
+    self.assertArraysAllClose(sol, sol_decorated, atol=1e-4)
+    J_num = test_util.lasso_skl_jac(X=X, y=y, lam=lam, tol=1e-4)
+    J = jax.jacrev(lasso_solver_decorated, argnums=3)(None, X, y, lam)
+    J2 = jax.jacrev(lasso_solver_decorated2, argnums=3)(None, X, y, lam)
+    self.assertArraysAllClose(J, J_num, atol=5e-2)
+    self.assertArraysAllClose(J2, J_num, atol=5e-2)
+
   def test_custom_root_with_has_aux(self):
 
     def ridge_solver_with_aux(init_params, lam, X, y):
       return ridge_solver(init_params, lam, X, y), None