src/plaid/containers/dataset.py (96 changes: 80 additions & 16 deletions)
@@ -27,8 +27,13 @@

import numpy as np
import yaml
from packaging.specifiers import SpecifierSet
from packaging.version import Version
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from tqdm import tqdm

import plaid
from plaid.constants import AUTHORIZED_INFO_KEYS
from plaid.containers.sample import Sample
from plaid.containers.utils import check_features_size_homogeneity
@@ -62,6 +67,62 @@ def process_sample(path: Union[str, Path]) -> tuple:  # pragma: no cover
# %% Classes


class DataclassToDict:
def __getitem__(self, item):
return getattr(self, item)

def __setitem__(self, key, value):
assert hasattr(self, key), (
f"{self.__class__.__name__} has no attribute '{key}'."
)
return setattr(self, key, value)


@dataclass()
class LegalDescription(DataclassToDict):
"""Container for legal information."""

owner: Optional[str] = None
license: Optional[str] = None


@dataclass()
class DataProductionDescription(DataclassToDict):
"""Container for data production information."""

owner: Optional[str] = None
license: Optional[str] = None
type: Optional[str] = None
physics: Optional[str] = None
simulator: Optional[str] = None
hardware: Optional[str] = None
computation_duration: Optional[str] = None
script: Optional[str] = None
contact: Optional[str] = None
location: Optional[str] = None


@dataclass()
class DataDescription(DataclassToDict):
"""Container for data information."""

number_of_samples: Optional[str] = None
number_of_splits: Optional[str] = None
DOE: Optional[str] = None
inputs: Optional[str] = None
outputs: Optional[str] = None


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DatasetInfos(DataclassToDict):
"""Container for dataset information."""

version: Union[Version, SpecifierSet] = Version(plaid.__version__)
legal: Optional[LegalDescription] = None
data_production: Optional[DataProductionDescription] = None
data_description: Optional[DataDescription] = None


class Dataset(object):
"""A set of samples, and optionnaly some other informations about the Dataset."""

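For reference, a minimal usage sketch (not part of the diff itself) of the new info containers. The owner, license, and version strings below are illustrative placeholders; the import path follows the module shown in this diff.

from packaging.version import Version
from plaid.containers.dataset import DatasetInfos, LegalDescription

# Build a typed infos object; nested descriptions are plain dataclasses.
legal = LegalDescription(owner="ACME", license="CC-BY-4.0")   # placeholder values
infos = DatasetInfos(version=Version("1.0.0"), legal=legal)   # placeholder version

# DataclassToDict keeps dict-style access working on top of attribute access.
assert infos["legal"]["owner"] == "ACME"
infos["legal"]["license"] = "MIT"   # setattr under the hood
# infos["unknown"] = 1              # would fail the hasattr assertion

This keeps backward compatibility with code that treated self._infos as nested dictionaries while giving the fields explicit names and types.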
@@ -113,7 +174,7 @@ def __init__(
"""
self._samples: dict[int, Sample] = {} # sample_id -> sample
# info_name -> description
self._infos: dict[str, dict[str, str]] = {}
self._infos = DatasetInfos()

if directory_path is not None:
if path is not None:
@@ -1470,29 +1531,32 @@ def _save_to_dir_(self, path: Union[str, Path], verbose: bool = False) -> None:
if verbose: # pragma: no cover
print(f"Saving database to: {path}")

# Save infos
assert "plaid" in self._infos
assert "version" in self._infos["plaid"]
plaid_version = Version(plaid.__version__)
if (
isinstance(self._infos["plaid"]["version"], SpecifierSet)
or self._infos["plaid"]["version"] != plaid_version
):
logger.warning(
f"Version mismatch: Dataset was loaded from version: {self._infos['plaid']['version']}, and will be saved with version: {plaid_version}"
)
self._infos["plaid"]["old_version"] = str(self._infos["plaid"]["version"])
self._infos["plaid"]["version"] = str(plaid_version)
infos_fname = path / "infos.yaml"
with open(infos_fname, "w") as file:
yaml.dump(self._infos, file, default_flow_style=False, sort_keys=False)

# Save samples
samples_dir = path / "samples"
if not (samples_dir.is_dir()):
samples_dir.mkdir(parents=True)

# ---# save samples
for i_sample, sample in tqdm(self._samples.items(), disable=not (verbose)):
sample_fname = samples_dir / f"sample_{i_sample:09d}"
sample.save(sample_fname)

# ---# save infos
if len(self._infos) > 0:
infos_fname = path / "infos.yaml"
with open(infos_fname, "w") as file:
yaml.dump(self._infos, file, default_flow_style=False, sort_keys=False)

# #---# save stats
# stats_fname = path / 'stats.yaml'
# self._stats.save(stats_fname)

# #---# save flags
# flags_fname = path / 'flags.yaml'
# self._flags.save(flags_fname)

def _load_from_dir_(
self,
path: Union[str, Path],
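The version check in _save_to_dir_ relies on packaging's Version and SpecifierSet types. A small standalone sketch of the comparison logic, with illustrative version strings:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

current = Version("1.2.0")             # stands in for Version(plaid.__version__)
loaded = SpecifierSet(">=1.0,<2.0")    # what a previously saved dataset might declare

# The isinstance check catches datasets whose infos carry a version range; the
# inequality catches datasets saved with a different concrete version. In both
# cases the save path keeps the old value under "old_version" and rewrites
# "version" to the exact current release.
if isinstance(loaded, SpecifierSet) or loaded != current:
    old_version, new_version = str(loaded), str(current)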