Skip to content

Commit

Permalink
SacessOptimizer: More efficient saving of intermediate results
Browse files Browse the repository at this point in the history
Don't write old local optima again and again in every iteration to save time and avoid huge fragmented files.

Some other smaller modifications/fixes:
* make sacess_history plotting more robust.
* Set SacessOptimizer.exit_flag
  • Loading branch information
dweindl committed Nov 28, 2024
1 parent d1c9e52 commit 08dc47f
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pypesto/optimize/ess/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Enhanced Scatter Search."""

from .ess import ESSOptimizer
from .ess import ESSExitFlag, ESSOptimizer
from .function_evaluator import (
FunctionEvaluator,
FunctionEvaluatorMP,
Expand Down
12 changes: 9 additions & 3 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,21 @@


class ESSExitFlag(int, enum.Enum):
"""Exit flags used by :class:`ESSOptimizer`."""
"""Scatter search exit flags.
Exit flags used by :class:`pypesto.ess.ESSOptimizer` and
:class:`pypesto.ess.SacessOptimizer`.
"""

# ESS did not run/finish yet
DID_NOT_RUN = 0
# Exited after reaching maximum number of iterations
# Exited after reaching the maximum number of iterations
MAX_ITER = -1
# Exited after exhausting function evaluation budget
MAX_EVAL = -2
# Exited after exhausting wall-time budget
MAX_TIME = -3
# Termination because for other reason than exit criteria
# Termination because of other reasons than exit criteria
ERROR = -99


Expand Down Expand Up @@ -242,6 +246,8 @@ def _initialize(self):
# Overall best function value found so far
self.fx_best: float = np.inf
# Results from local searches (only those with finite fval)
# (there is potential to save memory here by only keeping the
# parameters in memory and not the full result)
self.local_solutions: list[OptimizerResult] = []
# Index of current iteration
self.n_iter: int = 0
Expand Down
70 changes: 58 additions & 12 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from ... import MemoryHistory
from ...startpoint import StartpointMethod
from ...store.read_from_hdf5 import read_result
from ...store.save_to_hdf5 import write_result
from ...store.save_to_hdf5 import (
OptimizationResultHDF5Writer,
ProblemHDF5Writer,
)
from ..optimize import Problem
from .ess import ESSExitFlag, ESSOptimizer
from .function_evaluator import create_function_evaluator
Expand Down Expand Up @@ -162,7 +165,7 @@ def __init__(
Directory for temporary files. This defaults to a directory in the
current working directory named ``SacessOptimizerTemp-{random suffix}``.
When setting this option, make sure any optimizers running in
parallel have a unique `tmpdir`.
parallel have a unique `tmpdir`. Expected to be empty.
mp_start_method:
The start method for the multiprocessing context.
See :mod:`multiprocessing` for details. Running `SacessOptimizer`
Expand Down Expand Up @@ -326,7 +329,9 @@ def minimize(
self.histories = [
worker_result.history for worker_result in self.worker_results
]

self.exit_flag = min(
worker_result.exit_flag for worker_result in self.worker_results
)
result = self._create_result(problem)

walltime = time.time() - start_time
Expand Down Expand Up @@ -649,6 +654,10 @@ def run(
):
self._start_time = time.time()

# index of the local solution in ESSOptimizer.local_solutions
# that was most recently saved by _autosave
last_saved_local_solution = -1

self._logger.setLevel(self._loglevel)
# Set the manager logger to one created within the current process
self._manager._logger = self._logger
Expand Down Expand Up @@ -683,20 +692,14 @@ def run(
# perform one ESS iteration
ess._do_iteration()

if self._tmp_result_file:
# TODO maybe not in every iteration?
ess_results = ess._create_result()
write_result(
ess_results,
self._tmp_result_file,
overwrite=True,
optimize=True,
)
# check if the best solution of the last local ESS is sufficiently
# better than the sacess-wide best solution
self.maybe_update_best(ess.x_best, ess.fx_best)
self._best_known_fx = min(ess.fx_best, self._best_known_fx)

self._autosave(ess, last_saved_local_solution)
last_saved_local_solution = len(ess.local_solutions) - 1

self._cooperate()
self._maybe_adapt(problem)

Expand Down Expand Up @@ -911,6 +914,49 @@ def abort(self):

self._finalize(None)

def _autosave(self, ess: ESSOptimizer, last_saved_local_solution: int):
"""Save intermediate results.
If a temporary result file is set, save the (part of) the current state
of the ESS to that file.
We save the current best solution and the local optimizer results.
"""
if not self._tmp_result_file:
return

# save problem in first iteration
if ess.n_iter == 1:
pypesto_problem_writer = ProblemHDF5Writer(self._tmp_result_file)
pypesto_problem_writer.write(
ess.evaluator.problem, overwrite=False
)

opt_res_writer = OptimizationResultHDF5Writer(self._tmp_result_file)
for i in range(
last_saved_local_solution + 1, len(ess.local_solutions)
):
optimizer_result = ess.local_solutions[i]
optimizer_result.id = str(i + ess.n_iter)
opt_res_writer.write_optimizer_result(
optimizer_result, overwrite=False
)

# save the current best solution
optimizer_result = pypesto.OptimizerResult(
id=str(len(ess.local_solutions) + ess.n_iter),
x=ess.x_best,
fval=ess.fx_best,
message=f"Global best (iteration {ess.n_iter})",
time=time.time() - ess.starttime,
n_fval=ess.evaluator.n_eval,
optimizer=str(ess),
)
optimizer_result.update_to_full(ess.evaluator.problem)
opt_res_writer.write_optimizer_result(
optimizer_result, overwrite=False
)

@staticmethod
def get_temp_result_filename(worker_idx: int, tmpdir: str | Path) -> str:
return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())
Expand Down
6 changes: 6 additions & 0 deletions pypesto/visualize/optimizer_history.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from collections.abc import Iterable
from typing import Optional, Union

Expand Down Expand Up @@ -501,6 +502,8 @@ def sacess_history(
The plot axes. `ax` or a new axes if `ax` was `None`.
"""
ax = ax or plt.subplot()
if len(histories) == 0:
warnings.warn("No histories to plot.", stacklevel=2)

# plot overall minimum
# merge results
Expand Down Expand Up @@ -530,6 +533,9 @@ def sacess_history(
# plot steps of individual workers
for worker_idx, history in enumerate(histories):
x, y = history.get_time_trace(), history.get_fval_trace()
if len(x) == 0:
warnings.warn(f"No trace for worker #{worker_idx}.", stacklevel=2)
continue
# extend from last decrease to last timepoint
x = np.append(x, [np.max(t)])
y = np.append(y, [np.min(y)])
Expand Down
3 changes: 3 additions & 0 deletions test/optimize/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pypesto.optimize as optimize
from pypesto import Objective
from pypesto.optimize.ess import (
ESSExitFlag,
ESSOptimizer,
FunctionEvaluatorMP,
RefSet,
Expand Down Expand Up @@ -501,12 +502,14 @@ def test_ess(problem, local_optimizer, ess_type, request):
adaptation_sent_coeff=5,
),
)

else:
raise ValueError(f"Unsupported ESS type {ess_type}.")

res = ess.minimize(
problem=problem,
)
assert ess.exit_flag in (ESSExitFlag.MAX_TIME, ESSExitFlag.MAX_ITER)
print("ESS result: ", res.summary())

# best values roughly: cr: 4.701; rosen 7.592e-10
Expand Down
3 changes: 2 additions & 1 deletion test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,14 +1219,15 @@ def test_time_trajectory_model():
@close_fig
def test_sacess_history():
"""Test pypesto.visualize.optimizer_history.sacess_history"""
from pypesto.optimize.ess.sacess import SacessOptimizer
from pypesto.optimize.ess.sacess import ESSExitFlag, SacessOptimizer
from pypesto.visualize.optimizer_history import sacess_history

problem = create_problem()
sacess = SacessOptimizer(
max_walltime_s=1, num_workers=2, sacess_loglevel=logging.WARNING
)
sacess.minimize(problem)
assert sacess.exit_flag == ESSExitFlag.MAX_TIME
sacess_history(sacess.histories)


Expand Down

0 comments on commit 08dc47f

Please sign in to comment.