Skip to content

Commit

Permalink
Support OptimizationResultHDF5Writer.write(OptimizeResult)
Browse files Browse the repository at this point in the history
Previously, only `Result` was supported.

Closes #1526.
  • Loading branch information
dweindl committed Nov 27, 2024
1 parent 1eb61d7 commit 9b562db
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions pypesto/store/save_to_hdf5.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""Include functions for saving various results to hdf5."""
from __future__ import annotations

import logging
import os
from numbers import Integral
from pathlib import Path
from typing import Union

import h5py
import numpy as np

from .. import OptimizeResult
from ..result import ProfilerResult, Result, SampleResult
from .hdf5 import write_array, write_float_array

logger = logging.getLogger(__name__)


def check_overwrite(
f: Union[h5py.File, h5py.Group], overwrite: bool, target: str
):
def check_overwrite(f: h5py.File | h5py.Group, overwrite: bool, target: str):
"""
Check whether target already exists.
Expand Down Expand Up @@ -53,7 +52,7 @@ class ProblemHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize writer.
Expand Down Expand Up @@ -106,7 +105,7 @@ class OptimizationResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.
Expand All @@ -117,22 +116,29 @@ def __init__(self, storage_filename: Union[str, Path]):
"""
self.storage_filename = str(storage_filename)

def write(self, result: Result, overwrite=False):
def write(self, result: Result | OptimizeResult, overwrite=False):
"""Write HDF5 result file from pyPESTO result object."""
# Create destination directory
if isinstance(self.storage_filename, str):
basedir = os.path.dirname(self.storage_filename)
if basedir:
os.makedirs(basedir, exist_ok=True)

if isinstance(result, Result):
result = result.optimize_result
if not isinstance(result, OptimizeResult):
raise ValueError(
"The result object must be of type `OptimizeResult`."
)

with h5py.File(self.storage_filename, "a") as f:
check_overwrite(f, overwrite, "optimization")
optimization_grp = f.require_group("optimization")
# settings =
# optimization_grp.create_dataset("settings", settings, dtype=)
results_grp = optimization_grp.require_group("results")

for start in result.optimize_result.list:
for start in result.list:
start_id = start["id"]
start_grp = results_grp.require_group(start_id)
for key in start.keys():
Expand All @@ -155,7 +161,7 @@ class SamplingResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.
Expand Down Expand Up @@ -208,7 +214,7 @@ class ProfileResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: Union[str, Path]):
def __init__(self, storage_filename: str | Path):
"""
Initialize Writer.
Expand Down Expand Up @@ -241,7 +247,7 @@ def write(self, result: Result, overwrite: bool = False):

@staticmethod
def _write_profiler_result(
parameter_profile: Union[ProfilerResult, None], result_grp: h5py.Group
parameter_profile: ProfilerResult | None, result_grp: h5py.Group
) -> None:
"""Write a single ProfilerResult to hdf5.
Expand All @@ -267,7 +273,7 @@ def _write_profiler_result(

def write_result(
result: Result,
filename: Union[str, Path],
filename: str | Path,
overwrite: bool = False,
problem: bool = True,
optimize: bool = False,
Expand Down

0 comments on commit 9b562db

Please sign in to comment.