From d222bf411ea5606deab046548ec0fe182b501ec1 Mon Sep 17 00:00:00 2001 From: shudson Date: Fri, 19 Jan 2024 17:02:45 -0600 Subject: [PATCH] Save persis_info every k (with history) --- docs/data_structures/libE_specs.rst | 3 +++ libensemble/manager.py | 37 +++++++++++++++++++---------- libensemble/specs.py | 3 +++ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/docs/data_structures/libE_specs.rst b/docs/data_structures/libE_specs.rst index d471cf968..2e24cfec9 100644 --- a/docs/data_structures/libE_specs.rst +++ b/docs/data_structures/libE_specs.rst @@ -201,6 +201,9 @@ libEnsemble is primarily customized by setting options within a ``LibeSpecs`` cl **H_file_prefix** Optional[str] = ``"libE_history"`` Prefix for ``H`` filename. + **persis_info_file_prefix** Optional[str] = ``"libE_persis"`` + Prefix for ``persis_info`` filename. + **use_persis_return_gen** [bool] = ``False``: Adds persistent generator output fields to the History array on return. diff --git a/libensemble/manager.py b/libensemble/manager.py index cce7682f8..01b70fe7f 100644 --- a/libensemble/manager.py +++ b/libensemble/manager.py @@ -7,6 +7,7 @@ import glob import logging import os +import pickle import platform import socket import sys @@ -279,42 +280,52 @@ def _get_date_start_str(self) -> str: date_start = "" return date_start - def _save_every_k(self, fname: str, count: int, k: int, complete: bool) -> None: - """Saves history every kth step""" + def _save_every_k(self, persis_info: dict, fname: str, count: int, k: int, complete: bool) -> None: + """Saves history and persis_info every kth step""" if not complete: count = k * (count // k) date_start = self._get_date_start_str() - filename = fname.format(self.libE_specs["H_file_prefix"], date_start, count) + filename = fname.format(self.libE_specs["H_file_prefix"], date_start, count, "npy") if (not os.path.isfile(filename) and count > 0) or complete: - for old_file in glob.glob(fname.format(self.libE_specs["H_file_prefix"], date_start, "*")): + for old_file in glob.glob(fname.format(self.libE_specs["H_file_prefix"], date_start, "*", "npy")): os.remove(old_file) np.save(filename, self.hist.trim_H()) - def _save_every_k_sims(self, complete: bool) -> None: + filename = fname.format(self.libE_specs["persis_info_file_prefix"], date_start, count, "pickle") + for old_file in glob.glob( + fname.format(self.libE_specs["persis_info_file_prefix"], date_start, "*", "pickle") + ): + os.remove(old_file) + with open(filename, "wb") as f: + pickle.dump(persis_info, f) + + def _save_every_k_sims(self, persis_info: dict, complete: bool) -> None: """Saves history every kth sim step""" self._save_every_k( - os.path.join(self.libE_specs["workflow_dir_path"], "{}_{}after_sim_{}.npy"), + persis_info, + os.path.join(self.libE_specs["workflow_dir_path"], "{}_{}after_sim_{}.{}"), self.hist.sim_ended_count, self.libE_specs["save_every_k_sims"], complete, ) - def _save_every_k_gens(self, complete: bool) -> None: + def _save_every_k_gens(self, persis_info: dict, complete: bool) -> None: """Saves history every kth gen step""" self._save_every_k( - os.path.join(self.libE_specs["workflow_dir_path"], "{}_{}after_gen_{}.npy"), + persis_info, + os.path.join(self.libE_specs["workflow_dir_path"], "{}_{}after_gen_{}.{}"), self.hist.index, self.libE_specs["save_every_k_gens"], complete, ) - def _init_every_k_save(self, complete=False) -> None: + def _init_every_k_save(self, persis_info, complete=False) -> None: force_final = complete and not self.libE_specs.get("save_every_k_gens") if self.libE_specs.get("save_every_k_sims") or force_final: - self._save_every_k_sims(complete) + self._save_every_k_sims(persis_info, complete) if self.libE_specs.get("save_every_k_gens"): - self._save_every_k_gens(complete) + self._save_every_k_gens(persis_info, complete) # --- Handle outgoing messages to workers (work orders from alloc) @@ -436,7 +447,7 @@ def _receive_from_workers(self, persis_info: dict) -> dict: new_stuff = True self._handle_msg_from_worker(persis_info, w) - self._init_every_k_save() + self._init_every_k_save(persis_info) return persis_info def _update_state_on_worker_msg(self, persis_info: dict, D_recv: dict, w: int) -> None: @@ -575,7 +586,7 @@ def _final_receive_and_kill(self, persis_info: dict) -> (dict, int, int): if self.WorkerExc: exit_flag = 1 - self._init_every_k_save(complete=self.libE_specs["save_H_on_completion"]) + self._init_every_k_save(persis_info, complete=self.libE_specs["save_H_on_completion"]) self._kill_workers() return persis_info, exit_flag, self.elapsed() diff --git a/libensemble/specs.py b/libensemble/specs.py index f7b7b3ea5..9f8096f51 100644 --- a/libensemble/specs.py +++ b/libensemble/specs.py @@ -195,6 +195,9 @@ class LibeSpecs(BaseModel): H_file_prefix: Optional[str] = "libE_history" """ Prefix for ``H`` filename.""" + persis_info_file_prefix: Optional[str] = "libE_persis" + """ Prefix for ``H`` filename.""" + worker_timeout: Optional[int] = 1 """ On libEnsemble shutdown, number of seconds after which workers considered timed out, then terminated. """