Skip to content

Commit

Permalink
ESSOptimizer: Check inputs and check stopping criteria more frequently (
Browse files Browse the repository at this point in the history
#1176)

* Ensure we have some stopping criterion to not run infinitely long
* Avoid exceeding stopping criteria by checking them more frequently
  • Loading branch information
dweindl authored Nov 17, 2023
1 parent f366ddc commit 4507747
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ class ESSOptimizer:
def __init__(
self,
*,
max_iter: int = 10**100,
max_iter: int = None,
dim_refset: int = None,
local_n1: int = 1,
local_n2: int = 10,
balance: float = 0.5,
local_optimizer: 'pypesto.optimize.Optimizer' = None,
max_eval=np.inf,
max_eval=None,
n_diverse: int = None,
n_procs=None,
n_threads=None,
Expand Down Expand Up @@ -122,6 +122,18 @@ def __init__(
Number of parallel threads to use for parallel function evaluation.
Mutually exclusive with `n_procs`.
"""
if max_eval is None and max_walltime_s is None and max_iter is None:
# in this case, we'd run forever
raise ValueError(
"Either `max_iter`, `max_eval` or `max_walltime_s` have to be provided."
)
if max_eval is None:
max_eval = np.inf
if max_walltime_s is None:
max_walltime_s = np.inf
if max_iter is None:
max_iter = np.inf

# Hyperparameters
self.local_n1: int = local_n1
self.local_n2: int = local_n2
Expand Down Expand Up @@ -230,7 +242,6 @@ def minimize(
refset = self.refset
else:
self.refset = refset
problem = refset.evaluator.problem

self.evaluator = refset.evaluator
self.x_best = np.full(
Expand All @@ -255,7 +266,7 @@ def minimize(
self._go_beyond(x_best_children, fx_best_children)

# Maybe perform a local search
if self.local_optimizer is not None:
if self.local_optimizer is not None and self._keep_going():
self._do_local_search(x_best_children, fx_best_children)

# Replace RefSet members by best children where an improvement
Expand Down Expand Up @@ -373,6 +384,9 @@ def _combine_solutions(self) -> Tuple[np.array, np.array]:
best_idx = np.argmin(fxs_new)
fy[i] = fxs_new[best_idx]
y[i] = xs_new[best_idx]

if not self._keep_going():
break
return y, fy

def _combine(self, i, j) -> np.array:
Expand Down Expand Up @@ -547,6 +561,8 @@ def _go_beyond(self, x_best_children, fx_best_children):
self._maybe_update_global_best(
x_best_children[i], fx_best_children[i]
)
if not self._keep_going():
break

def _report_iteration(self):
"""Log the current iteration."""
Expand Down

0 comments on commit 4507747

Please sign in to comment.