Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More flexible OptimizationResultHDF5Writer #1528

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 74 additions & 31 deletions pypesto/store/save_to_hdf5.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""Include functions for saving various results to hdf5."""
from __future__ import annotations

import logging
import os
from numbers import Integral
from pathlib import Path
from typing import Union

import h5py
import numpy as np

from .. import OptimizeResult, OptimizerResult
from ..result import ProfilerResult, Result, SampleResult
from .hdf5 import write_array, write_float_array

logger = logging.getLogger(__name__)


def check_overwrite(
f: Union[h5py.File, h5py.Group], overwrite: bool, target: str
):
def check_overwrite(f: h5py.File | h5py.Group, overwrite: bool, target: str):
"""
Check whether target already exists.

Expand All @@ -36,7 +35,7 @@ def check_overwrite(
del f[target]
else:
raise RuntimeError(
f"File `{f.filename}` already exists and contains "
f"File `{f.file.filename}` already exists and contains "
f"information about {target} result. "
f"If you wish to overwrite the file, set "
f"`overwrite=True`."
Expand All @@ -53,7 +52,7 @@ class ProblemHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize writer.

Expand Down Expand Up @@ -106,7 +105,7 @@ class OptimizationResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.

Expand All @@ -117,32 +116,76 @@ def __init__(self, storage_filename: Union[str, Path]):
"""
self.storage_filename = str(storage_filename)

def write(self, result: Result, overwrite=False):
"""Write HDF5 result file from pyPESTO result object."""
# Create destination directory
if isinstance(self.storage_filename, str):
basedir = os.path.dirname(self.storage_filename)
if basedir:
os.makedirs(basedir, exist_ok=True)
def write(
self,
result: Result
| OptimizeResult
| OptimizerResult
| list[OptimizerResult],
overwrite=False,
):
"""Write HDF5 result file from pyPESTO result object.

Parameters
----------
result: Result to be saved.
overwrite: Boolean, whether already existing results should be
overwritten. This applies to the whole list of results, not only to
individual results. See :meth:`write_optimizer_result` for
incrementally writing a sequence of `OptimizerResult`.
"""
Path(self.storage_filename).parent.mkdir(parents=True, exist_ok=True)

if isinstance(result, Result):
results = result.optimize_result.list
elif isinstance(result, OptimizeResult):
results = result.list
elif isinstance(result, list):
results = result
elif isinstance(result, OptimizerResult):
results = [result]
else:
raise ValueError(f"Unsupported type for `result`: {type(result)}.")

with h5py.File(self.storage_filename, "a") as f:
check_overwrite(f, overwrite, "optimization")
optimization_grp = f.require_group("optimization")
# settings =
# optimization_grp.create_dataset("settings", settings, dtype=)
results_grp = optimization_grp.require_group("results")

for start in result.optimize_result.list:
start_id = start["id"]
start_grp = results_grp.require_group(start_id)
for key in start.keys():
if key == "history":
continue
if isinstance(start[key], np.ndarray):
write_array(start_grp, key, start[key])
elif start[key] is not None:
start_grp.attrs[key] = start[key]
f.flush()
for start in results:
self._do_write_optimizer_result(start, results_grp, overwrite)

def write_optimizer_result(
self, result: OptimizerResult, overwrite: bool = False
):
"""Write HDF5 result file from pyPESTO result object.

Parameters
----------
result: Result to be saved.
overwrite: Boolean, whether already existing results with the same ID
should be overwritten.s
"""
Path(self.storage_filename).parent.mkdir(parents=True, exist_ok=True)

with h5py.File(self.storage_filename, "a") as f:
results_grp = f.require_group("optimization/results")
self._do_write_optimizer_result(result, results_grp, overwrite)

def _do_write_optimizer_result(
self, result: OptimizerResult, g: h5py.Group = None, overwrite=False
):
"""Write an OptimizerResult to the given group."""
sub_group_id = result["id"]
check_overwrite(g, overwrite, sub_group_id)
start_grp = g.require_group(sub_group_id)
for key in result.keys():
if key == "history":
continue
if isinstance(result[key], np.ndarray):
write_array(start_grp, key, result[key])
elif result[key] is not None:
start_grp.attrs[key] = result[key]


class SamplingResultHDF5Writer:
Expand All @@ -155,7 +198,7 @@ class SamplingResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.

Expand Down Expand Up @@ -208,7 +251,7 @@ class ProfileResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.

Expand Down Expand Up @@ -241,7 +284,7 @@ def write(self, result: Result, overwrite: bool = False):

@staticmethod
def _write_profiler_result(
parameter_profile: Union[ProfilerResult, None], result_grp: h5py.Group
parameter_profile: ProfilerResult | None, result_grp: h5py.Group
) -> None:
"""Write a single ProfilerResult to hdf5.

Expand All @@ -267,7 +310,7 @@ def _write_profiler_result(

def write_result(
result: Result,
filename: Union[str, Path],
filename: str | Path,
overwrite: bool = False,
problem: bool = True,
optimize: bool = False,
Expand Down
27 changes: 25 additions & 2 deletions test/base/test_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Test the `pypesto.store` module."""

import os
import tempfile
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import scipy.optimize as so

import pypesto
Expand Down Expand Up @@ -52,7 +54,7 @@

def test_storage_opt_result():
minimize_result = create_optimization_result()
with tempfile.TemporaryDirectory(dir=".") as tmpdirname:
with TemporaryDirectory(dir=".") as tmpdirname:
result_file_name = os.path.join(tmpdirname, "a", "b", "result.h5")
opt_result_writer = OptimizationResultHDF5Writer(result_file_name)
opt_result_writer.write(minimize_result)
Expand Down Expand Up @@ -89,6 +91,27 @@ def test_storage_opt_result_update(hdf5_file):
assert opt_res[key] == read_result.optimize_result[i][key]


def test_write_optimizer_results_incrementally():
"""Test writing optimizer results incrementally to the same file."""
res = create_optimization_result()
res1, res2 = res.optimize_result.list[:2]

with TemporaryDirectory() as tmp_dir:
result_path = Path(tmp_dir, "result.h5")
writer = OptimizationResultHDF5Writer(result_path)
writer.write_optimizer_result(res1)
writer.write_optimizer_result(res2)
reader = OptimizationResultHDF5Reader(result_path)
read_res = reader.read()
assert len(read_res.optimize_result) == 2

# overwriting works
writer.write_optimizer_result(res1, overwrite=True)
# overwriting attempt fails without overwrite=True
with pytest.raises(RuntimeError):
writer.write_optimizer_result(res1)


def test_storage_problem(hdf5_file):
problem = create_problem()
problem_writer = ProblemHDF5Writer(hdf5_file)
Expand Down