Skip to content

Commit

Permalink
Add support for state save/load, and other minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 21, 2025
1 parent db1c417 commit 6aa1549
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 15 deletions.
8 changes: 5 additions & 3 deletions scico/optimize/_admm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -197,8 +197,10 @@ def _itstat_extra_fields(self):

return itstat_fields, itstat_attrib

def minimizer(self):
"""Return current estimate of the functional mimimizer."""
def _state_variable_names(self) -> List[str]:
return ["z_list", "z_list_old", "u_list"]

def minimizer(self) -> Union[Array, BlockArray]:
return self.x

def objective(
Expand Down
8 changes: 5 additions & 3 deletions scico/optimize/_ladmm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2024 by SCICO Developers
# Copyright (C) 2021-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -142,8 +142,10 @@ def _itstat_extra_fields(self):
itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"]
return itstat_fields, itstat_attrib

def minimizer(self):
"""Return current estimate of the functional mimimizer."""
def _state_variable_names(self) -> List[str]:
return ["x", "z", "z_old", "u"]

def minimizer(self) -> Union[Array, BlockArray]:
return self.x

def objective(
Expand Down
8 changes: 5 additions & 3 deletions scico/optimize/_padmm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2024 by SCICO Developers
# Copyright (C) 2022-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -137,8 +137,10 @@ def _itstat_extra_fields(self):
itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"]
return itstat_fields, itstat_attrib

def minimizer(self):
"""Return current estimate of the functional mimimizer."""
def _state_variable_names(self) -> List[str]:
return ["x", "z", "z_old", "u", "u_old"]

def minimizer(self) -> Union[Array, BlockArray]:
return self.x

def objective(
Expand Down
4 changes: 1 addition & 3 deletions scico/optimize/_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,9 @@ def _itstat_extra_fields(self):
return itstat_fields, itstat_attrib

def _state_variable_names(self) -> List[str]:
"""Get optimizer state variable names."""
return ["x", "L"]

def minimizer(self):
"""Return current estimate of the functional mimimizer."""
def minimizer(self) -> Union[Array, BlockArray]:
return self.x

def objective(self, x: Optional[Union[Array, BlockArray]] = None) -> float:
Expand Down
8 changes: 5 additions & 3 deletions scico/optimize/_primaldual.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# Copyright (C) 2021-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -160,8 +160,10 @@ def _itstat_extra_fields(self):
itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"]
return itstat_fields, itstat_attrib

def minimizer(self):
"""Return current estimate of the functional mimimizer."""
def _state_variable_names(self) -> List[str]:
return ["x", "x_old", "z", "z_old"]

def minimizer(self) -> Union[Array, BlockArray]:
return self.x

def objective(
Expand Down
22 changes: 22 additions & 0 deletions scico/test/optimize/test_admm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import tempfile

import numpy as np

import pytest
Expand Down Expand Up @@ -145,6 +148,25 @@ def test_admm_generic(self):
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3

with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "admm.npz")
admm_.save_state(path)
admm2 = ADMM(
f=f,
g_list=g_list,
C_list=C_list,
rho_list=rho_list,
maxiter=maxiter,
itstat_options={"display": False},
x0=A.adj(self.y),
subproblem_solver=GenericSubproblemSolver(
minimize_kwargs={"options": {"maxiter": 50}}
),
)
admm2.load_state(path)
np.testing.assert_allclose(admm_.z_list[0], admm2.z_list[0], rtol=1e-7)
np.testing.assert_allclose(admm_.u_list[0], admm2.u_list[0], rtol=1e-7)

def test_admm_quadratic_scico(self):
maxiter = 25
ρ = 4e-1
Expand Down

0 comments on commit 6aa1549

Please sign in to comment.