diff --git a/pypesto/optimize/ess/ess.py b/pypesto/optimize/ess/ess.py index 130ad2a9d..5e8cdb957 100644 --- a/pypesto/optimize/ess/ess.py +++ b/pypesto/optimize/ess/ess.py @@ -70,13 +70,13 @@ class ESSOptimizer: def __init__( self, *, - max_iter: int = 10**100, + max_iter: int = None, dim_refset: int = None, local_n1: int = 1, local_n2: int = 10, balance: float = 0.5, local_optimizer: 'pypesto.optimize.Optimizer' = None, - max_eval=np.inf, + max_eval=None, n_diverse: int = None, n_procs=None, n_threads=None, @@ -122,6 +122,18 @@ def __init__( Number of parallel threads to use for parallel function evaluation. Mutually exclusive with `n_procs`. """ + if max_eval is None and max_walltime_s is None and max_iter is None: + # in this case, we'd run forever + raise ValueError( + "Either `max_iter`, `max_eval` or `max_walltime_s` have to be provided." + ) + if max_eval is None: + max_eval = np.inf + if max_walltime_s is None: + max_walltime_s = np.inf + if max_iter is None: + max_iter = np.inf + # Hyperparameters self.local_n1: int = local_n1 self.local_n2: int = local_n2 @@ -230,7 +242,6 @@ def minimize( refset = self.refset else: self.refset = refset - problem = refset.evaluator.problem self.evaluator = refset.evaluator self.x_best = np.full( @@ -255,7 +266,7 @@ def minimize( self._go_beyond(x_best_children, fx_best_children) # Maybe perform a local search - if self.local_optimizer is not None: + if self.local_optimizer is not None and self._keep_going(): self._do_local_search(x_best_children, fx_best_children) # Replace RefSet members by best children where an improvement @@ -373,6 +384,9 @@ def _combine_solutions(self) -> Tuple[np.array, np.array]: best_idx = np.argmin(fxs_new) fy[i] = fxs_new[best_idx] y[i] = xs_new[best_idx] + + if not self._keep_going(): + break return y, fy def _combine(self, i, j) -> np.array: @@ -547,6 +561,8 @@ def _go_beyond(self, x_best_children, fx_best_children): self._maybe_update_global_best( x_best_children[i], fx_best_children[i] ) + if not self._keep_going(): + break def _report_iteration(self): """Log the current iteration."""