From ee128d7e129c6ad706c6e75a66c3c9281a056db3 Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Wed, 25 Mar 2026 09:27:24 -0400 Subject: [PATCH 1/7] More dataloader options including support for different multiprocessing methods. LMDB refactor to allow for forkserver/spawn serialization + resolve issues that required pinning lmdb for training. --- openfold3/core/data/framework/data_module.py | 24 +++++++-- .../framework/single_datasets/base_of3.py | 31 ++++++++--- .../single_datasets/dataset_utils.py | 19 +++++++ .../core/data/primitives/caches/format.py | 23 ++++++-- openfold3/core/data/primitives/caches/lmdb.py | 52 +++++++++++++++---- openfold3/entry_points/validator.py | 4 ++ pyproject.toml | 2 +- 7 files changed, 130 insertions(+), 25 deletions(-) diff --git a/openfold3/core/data/framework/data_module.py b/openfold3/core/data/framework/data_module.py index e0b858918..950a8d3d6 100644 --- a/openfold3/core/data/framework/data_module.py +++ b/openfold3/core/data/framework/data_module.py @@ -152,7 +152,11 @@ class DataModuleConfig(BaseModel): datasets: list[SerializeAsAny[BaseModel]] batch_size: int = 1 num_workers: int = 0 + prefetch_factor: int | None = None num_workers_validation: int = 0 + prefetch_factor_validation: int | None = None + persistent_workers: bool = False + multiprocessing_context: str | None = None data_seed: int = 42 epoch_len: int = 1 @@ -165,8 +169,14 @@ def __init__(self, data_module_config: DataModuleConfig) -> None: # Possibly initialize directly from DataModuleConfig self.batch_size = data_module_config.batch_size + self.num_workers = data_module_config.num_workers + self.prefetch_factor = data_module_config.prefetch_factor self.num_workers_validation = data_module_config.num_workers_validation + self.prefetch_factor_validation = data_module_config.prefetch_factor_validation + self.persistent_workers = data_module_config.persistent_workers + self.multiprocessing_context = data_module_config.multiprocessing_context + self.data_seed = data_module_config.data_seed self.next_data_seed = data_module_config.data_seed self.epoch_len = data_module_config.epoch_len @@ -408,17 +418,20 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None) Returns: DataLoader: DataLoader object. """ - - # TODO: Val does not need this many workers. Due to memory leak issue, - # reduce workers here to run with more workers overall in training - # as temporary quick fix. if ( mode == DatasetMode.validation and DatasetMode.train in self.multi_dataset_config.modes ): num_workers = self.num_workers_validation + prefetch_factor = self.prefetch_factor_validation else: num_workers = self.num_workers + prefetch_factor = self.prefetch_factor + + persistent_workers = self.persistent_workers and num_workers > 0 + multiprocessing_context = ( + self.multiprocessing_context if num_workers > 0 else None + ) generator = self.generators.get(mode) if generator is None: @@ -445,6 +458,9 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None) collate_fn=openfold_batch_collator, generator=self.generators[mode], worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + multiprocessing_context=multiprocessing_context, ) def train_dataloader(self) -> DataLoader: diff --git a/openfold3/core/data/framework/single_datasets/base_of3.py b/openfold3/core/data/framework/single_datasets/base_of3.py index 3b9e01de9..f756714e7 100644 --- a/openfold3/core/data/framework/single_datasets/base_of3.py +++ b/openfold3/core/data/framework/single_datasets/base_of3.py @@ -27,6 +27,7 @@ SingleDataset, register_dataset, ) +from openfold3.core.data.framework.single_datasets.dataset_utils import warm_lmdb_cache from openfold3.core.data.io.dataset_cache import read_datacache from openfold3.core.data.pipelines.featurization.conformer import ( featurize_reference_conformers_of3, @@ -153,15 +154,20 @@ def __init__(self, dataset_config) -> None: # TODO: rename dataset_cache_file to dataset_cache_path to signal that it can be # a directory or a file # TODO: potentially expose the LMDB database encoding types - self.dataset_cache = read_datacache( - dataset_config.dataset_paths.dataset_cache_file - ) + self._dataset_cache_file = dataset_config.dataset_paths.dataset_cache_file + self.dataset_cache = read_datacache(self._dataset_cache_file) + self.warm_cache() + self.datapoint_cache = {} - if dataset_config.dataset_paths.template_structures_directory is not None: - self.ccd = pdbx.CIFFile.read(dataset_config.dataset_paths.ccd_file) - else: - self.ccd = None + # Only used if template structures are not preprocessed + # Lazy-loaded so the dataset is picklable (forkserver) + self._ccd = None + self._ccd_file = ( + dataset_config.dataset_paths.ccd_file + if dataset_config.dataset_paths.template_structures_directory is not None + else None + ) # Dataset configuration # n_tokens can be set in the getitem method separately for each sample using @@ -174,6 +180,17 @@ def __init__(self, dataset_config) -> None: self.single_moltype = None self.debug_mode = dataset_config.debug_mode + def warm_cache(self) -> None: + """Warm the LMDB page cache (no-op for JSON cache).""" + if self._dataset_cache_file.is_dir(): + warm_lmdb_cache(self._dataset_cache_file) + + @property + def ccd(self): + if self._ccd is None and self._ccd_file is not None: + self._ccd = pdbx.CIFFile.read(self._ccd_file) + return self._ccd + @log_runtime_memory(runtime_dict_key="runtime-create-structure-features") def create_structure_features( self, diff --git a/openfold3/core/data/framework/single_datasets/dataset_utils.py b/openfold3/core/data/framework/single_datasets/dataset_utils.py index 7fbb2ed04..f8fe42b68 100644 --- a/openfold3/core/data/framework/single_datasets/dataset_utils.py +++ b/openfold3/core/data/framework/single_datasets/dataset_utils.py @@ -15,7 +15,9 @@ import copy import logging import os +import time from itertools import cycle, islice +from pathlib import Path import pandas as pd import torch @@ -30,6 +32,7 @@ naive_alignment, ) +logger = logging.getLogger(__name__) worker_seed_log = logging.getLogger(f"{__name__}.worker_seed") @@ -153,3 +156,19 @@ def getitem_debug_log(dataset_name: str = "") -> None: f"pid={os.getpid()} worker_id={worker_id} wi.seed={wi_seed} " f"wi.base_seed={wi_base_seed} torch.initial_seed={torch_seed}", ) + + +def warm_lmdb_cache(lmdb_directory: Path) -> None: + """Sequentially read the LMDB data file to warm the OS page cache""" + data_file = lmdb_directory / "data.mdb" + file_size_gb = data_file.stat().st_size / (1024**3) + logger.info( + f"Warming LMDB page cache for {lmdb_directory} ({file_size_gb:.1f} GB)..." + ) + t0 = time.monotonic() + chunk_size = 8 * 1024 * 1024 + with open(data_file, "rb") as f: + while f.read(chunk_size): + pass + elapsed = time.monotonic() - t0 + logger.info(f"LMDB cache warm complete in {elapsed:.1f}s") diff --git a/openfold3/core/data/primitives/caches/format.py b/openfold3/core/data/primitives/caches/format.py index db2d0c785..88e55d089 100755 --- a/openfold3/core/data/primitives/caches/format.py +++ b/openfold3/core/data/primitives/caches/format.py @@ -135,6 +135,7 @@ def from_json(cls, file: Path) -> PreprocessingDataCache: # rerunning preprocessing elif status == "failed": release_date = None + experimental_method = None resolution = None chains = None interfaces = None @@ -414,7 +415,7 @@ class DatasetCache: # TODO: update parsers for this base class @classmethod def from_json(cls, file: Path) -> DatasetCache: - """Costructs a datacache from a json. + """Constructs a datacache from a json. Args: file (Path): @@ -434,6 +435,7 @@ def from_json(cls, file: Path) -> DatasetCache: reference_molecule_data=cls._parse_ref_mol_data_json(data), ) + @staticmethod def _parse_type_json(data: dict) -> None: # Remove _type field (already an internal private attribute so shouldn't be # defined as an explicit field) @@ -441,6 +443,7 @@ def _parse_type_json(data: dict) -> None: # This is conditional for legacy compatibility, should be removed after del data["_type"] + @staticmethod def _parse_name_json(data: dict) -> str: return data["name"] @@ -532,10 +535,16 @@ def from_lmdb( _ = cls._parse_type_lmdb(transaction, str_encoding) name = cls._parse_name_lmdb(transaction, str_encoding) structure_data = cls._parse_structure_data_lmdb( - lmdb_env, str_encoding, structure_data_encoding + lmdb_env, + str_encoding, + structure_data_encoding, + lmdb_path=lmdb_directory, ) reference_molecule_data = cls._parse_ref_mol_data_lmdb( - lmdb_env, str_encoding, reference_molecule_data_encoding + lmdb_env, + str_encoding, + reference_molecule_data_encoding, + lmdb_path=lmdb_directory, ) return cls( @@ -544,6 +553,7 @@ def from_lmdb( reference_molecule_data=reference_molecule_data, ) + @staticmethod def _parse_type_lmdb( transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"] ) -> str: @@ -555,6 +565,7 @@ def _parse_type_lmdb( return _type + @staticmethod def _parse_name_lmdb( transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"] ) -> str: @@ -566,10 +577,12 @@ def _parse_name_lmdb( return name + @staticmethod def _parse_structure_data_lmdb( lmdb_env: lmdb.Environment, str_encoding: Literal["utf-8", "pkl"], structure_data_encoding: Literal["utf-8", "pkl"], + lmdb_path: Path | None = None, ) -> LMDBDict: from openfold3.core.data.primitives.caches.lmdb import ( LMDBDict, @@ -580,12 +593,15 @@ def _parse_structure_data_lmdb( prefix="structure_data", key_encoding=str_encoding, value_encoding=structure_data_encoding, + lmdb_path=lmdb_path, ) + @staticmethod def _parse_ref_mol_data_lmdb( lmdb_env: lmdb.Environment, str_encoding: Literal["utf-8", "pkl"], reference_molecule_data_encoding: Literal["utf-8", "pkl"], + lmdb_path: Path | None = None, ) -> LMDBDict: from openfold3.core.data.primitives.caches.lmdb import ( LMDBDict, @@ -596,6 +612,7 @@ def _parse_ref_mol_data_lmdb( prefix="reference_molecule_data", key_encoding=str_encoding, value_encoding=reference_molecule_data_encoding, + lmdb_path=lmdb_path, ) def to_lmdb( diff --git a/openfold3/core/data/primitives/caches/lmdb.py b/openfold3/core/data/primitives/caches/lmdb.py index c53edd472..97870c0df 100644 --- a/openfold3/core/data/primitives/caches/lmdb.py +++ b/openfold3/core/data/primitives/caches/lmdb.py @@ -131,9 +131,10 @@ def __init__( self, lmdb_env: lmdb.Environment, prefix: str, - separator: chr = ":", + separator: str = ":", key_encoding: Literal["utf-8", "pkl"] = "utf-8", value_encoding: Literal["utf-8", "pkl"] = "pkl", + lmdb_path: str | Path | None = None, ): """A dict-like class with an LMDB backend for lazy loading of datacache entries. @@ -141,27 +142,40 @@ def __init__( lmdb_env (lmdb.Environment): The LMDB environment object. prefix (str): header for fields used to construct keys in lmdb - separator (chr): Single separator character used to construct key + separator (str): Single separator character used to construct key key_encoding (Literal["utf-8", "pkl"]): Encoding of keys. Defaults to "utf-8". value_encoding (Literal["utf-8", "pkl"]): Encoding of values. Defaults to "pkl". + lmdb_path (str | Path | None): + Path to the LMDB directory. Used to reopen env after pickle + (e.g. for forkserver/spawn multiprocessing). Raises: KeyError: If a non-existent key is requested. """ self._lmdb_env = lmdb_env + self._lmdb_path = str(lmdb_path) if lmdb_path is not None else None self._prefix = prefix + separator self._key_encoding = key_encoding self._value_encoding = value_encoding - - with self._lmdb_env.begin() as transaction, transaction.cursor() as cursor: - # Collect all keys - encoded_prefix = prefix.encode(self._key_encoding) - # Assign the number of keys so don't have to store all keys in memory - self._n_keys = len( - [key for key, _ in cursor if key.startswith(encoded_prefix)] + self._n_keys = None # Computed on first __len__ call + + def __getstate__(self): + state = self.__dict__.copy() + del state["_lmdb_env"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + if self._lmdb_path is not None: + self._lmdb_env = lmdb.open( + self._lmdb_path, readonly=True, lock=False, subdir=True + ) + else: + raise RuntimeError( + "Cannot unpickle LMDBDict: no lmdb_path was provided at creation" ) def _decode_key(self, key): @@ -185,14 +199,32 @@ def __iter__(self): if not cursor.next(): break + def _count_keys(self): + """Count keys matching the prefix.""" + encoded_prefix = self._prefix.encode(self._key_encoding) + count = 0 + with self._lmdb_env.begin() as txn, txn.cursor() as cursor: + # Use set_range to jump to the first prefix occurrence + # and avoid scanning the entire LMDB. + if cursor.set_range(encoded_prefix): + while True: + if not cursor.key().startswith(encoded_prefix): + break + count += 1 + if not cursor.next(): + break + return count + def __len__(self): + if self._n_keys is None: + self._n_keys = self._count_keys() return self._n_keys def __getitem__(self, key): with self._lmdb_env.begin() as transaction: key_bytes = f"{self._prefix}{key}".encode(self._key_encoding) value_bytes = transaction.get(key_bytes) - if not value_bytes: + if value_bytes is None: raise KeyError(key) else: if self._value_encoding == "pkl": diff --git a/openfold3/entry_points/validator.py b/openfold3/entry_points/validator.py index 767cd7008..29b7f3dd2 100644 --- a/openfold3/entry_points/validator.py +++ b/openfold3/entry_points/validator.py @@ -112,7 +112,11 @@ class DataModuleArgs(BaseModel): batch_size: int = 1 data_seed: int | None = None num_workers: int = 10 + prefetch_factor: int | None = None num_workers_validation: int = 4 + prefetch_factor_validation: int | None = None + multiprocessing_context: str | None = None + persistent_workers: bool = False epoch_len: int = 4 diff --git a/pyproject.toml b/pyproject.toml index 6837874bd..e3f63957c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "boto3", "memory_profiler", "func_timeout", - "lmdb", # == 1.6 # This verison might be required to avoid Error during training + "lmdb", "ijson", "deepspeed", "pdbeccdutils", From 4e7fca61e9584948e7e52e4833c389f3300bbc58 Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Wed, 25 Mar 2026 15:21:15 -0400 Subject: [PATCH 2/7] LMDB fixes for forked processes with latest version of lmdb, also ensure connection is closed when dataset init finishes prior to forking --- .../framework/single_datasets/base_of3.py | 6 +- .../single_datasets/dataset_utils.py | 20 +++--- openfold3/core/data/io/dataset_cache.py | 3 + .../core/data/primitives/caches/format.py | 68 ++++++++----------- openfold3/core/data/primitives/caches/lmdb.py | 64 ++++++++++------- 5 files changed, 87 insertions(+), 74 deletions(-) diff --git a/openfold3/core/data/framework/single_datasets/base_of3.py b/openfold3/core/data/framework/single_datasets/base_of3.py index f756714e7..d27649510 100644 --- a/openfold3/core/data/framework/single_datasets/base_of3.py +++ b/openfold3/core/data/framework/single_datasets/base_of3.py @@ -180,8 +180,12 @@ def __init__(self, dataset_config) -> None: self.single_moltype = None self.debug_mode = dataset_config.debug_mode + # Release backend connections so worker processes inherit clean state. + # Each backend reopens lazily on first access in the worker. + self.dataset_cache.release_connections() + def warm_cache(self) -> None: - """Warm the LMDB page cache (no-op for JSON cache).""" + """Warm the OS page cache for LMDB. No-op for JSON.""" if self._dataset_cache_file.is_dir(): warm_lmdb_cache(self._dataset_cache_file) diff --git a/openfold3/core/data/framework/single_datasets/dataset_utils.py b/openfold3/core/data/framework/single_datasets/dataset_utils.py index f8fe42b68..70e80d72e 100644 --- a/openfold3/core/data/framework/single_datasets/dataset_utils.py +++ b/openfold3/core/data/framework/single_datasets/dataset_utils.py @@ -158,17 +158,19 @@ def getitem_debug_log(dataset_name: str = "") -> None: ) -def warm_lmdb_cache(lmdb_directory: Path) -> None: - """Sequentially read the LMDB data file to warm the OS page cache""" - data_file = lmdb_directory / "data.mdb" - file_size_gb = data_file.stat().st_size / (1024**3) - logger.info( - f"Warming LMDB page cache for {lmdb_directory} ({file_size_gb:.1f} GB)..." - ) +def warm_file_cache(file_path: Path) -> None: + """Sequentially read a file to warm the OS page cache.""" + file_size_gb = file_path.stat().st_size / (1024**3) + logger.info(f"Warming page cache for {file_path} ({file_size_gb:.1f} GB)...") t0 = time.monotonic() chunk_size = 8 * 1024 * 1024 - with open(data_file, "rb") as f: + with open(file_path, "rb") as f: while f.read(chunk_size): pass elapsed = time.monotonic() - t0 - logger.info(f"LMDB cache warm complete in {elapsed:.1f}s") + logger.info(f"Page cache warm complete in {elapsed:.1f}s") + + +def warm_lmdb_cache(lmdb_directory: Path) -> None: + """Sequentially read the LMDB data file to warm the OS page cache.""" + warm_file_cache(lmdb_directory / "data.mdb") diff --git a/openfold3/core/data/io/dataset_cache.py b/openfold3/core/data/io/dataset_cache.py index f19a1fceb..011be4a3c 100644 --- a/openfold3/core/data/io/dataset_cache.py +++ b/openfold3/core/data/io/dataset_cache.py @@ -192,6 +192,9 @@ def read_datacache( with lmdb_env.begin() as txn: dataset_cache_type = json.loads(txn.get(type_key).decode(str_encoding)) + # Only one connection can be open at a time, close before creating LMDBDict + lmdb_env.close() + if not dataset_cache_type: raise ValueError("No type found for this directory.") diff --git a/openfold3/core/data/primitives/caches/format.py b/openfold3/core/data/primitives/caches/format.py index 88e55d089..3d400d740 100755 --- a/openfold3/core/data/primitives/caches/format.py +++ b/openfold3/core/data/primitives/caches/format.py @@ -24,7 +24,7 @@ import lmdb -from openfold3.core.data.primitives.caches.lmdb import LMDBDict +from openfold3.core.data.primitives.caches.lmdb import LMDBDict, LMDBEnv from openfold3.core.data.resources.residues import MoleculeType K = TypeVar("K") @@ -482,6 +482,15 @@ def _parse_ref_mol_data_json(cls, data: dict) -> dict: ref_mol_data[ref_mol_id] = per_ref_mol_data_fmt return ref_mol_data + def release_connections(self) -> None: + """ + Close any open backend connections so fork inherits clean state. + Each backend reopens lazily on next access. No-op for plain dicts. + """ + for attr in (self.structure_data, self.reference_molecule_data): + if hasattr(attr, "close"): + attr.close() + def to_json(self, file: Path) -> None: """Write the dataset cache to a JSON file. @@ -526,32 +535,27 @@ def from_lmdb( DatasetCache: The constructed datacache. """ - - lmdb_env = lmdb.open( - str(lmdb_directory), readonly=True, lock=False, subdir=True - ) - - with lmdb_env.begin() as transaction: + lmdb_env = LMDBEnv(str(lmdb_directory)) + with lmdb_env.get().begin() as transaction: _ = cls._parse_type_lmdb(transaction, str_encoding) name = cls._parse_name_lmdb(transaction, str_encoding) - structure_data = cls._parse_structure_data_lmdb( - lmdb_env, - str_encoding, - structure_data_encoding, - lmdb_path=lmdb_directory, - ) - reference_molecule_data = cls._parse_ref_mol_data_lmdb( - lmdb_env, - str_encoding, - reference_molecule_data_encoding, - lmdb_path=lmdb_directory, - ) - return cls( - name=name, - structure_data=structure_data, - reference_molecule_data=reference_molecule_data, - ) + structure_data = cls._parse_structure_data_lmdb( + lmdb_env=lmdb_env, + str_encoding=str_encoding, + structure_data_encoding=structure_data_encoding, + ) + reference_molecule_data = cls._parse_ref_mol_data_lmdb( + lmdb_env=lmdb_env, + str_encoding=str_encoding, + reference_molecule_data_encoding=reference_molecule_data_encoding, + ) + + return cls( + name=name, + structure_data=structure_data, + reference_molecule_data=reference_molecule_data, + ) @staticmethod def _parse_type_lmdb( @@ -579,40 +583,28 @@ def _parse_name_lmdb( @staticmethod def _parse_structure_data_lmdb( - lmdb_env: lmdb.Environment, + lmdb_env: LMDBEnv, str_encoding: Literal["utf-8", "pkl"], structure_data_encoding: Literal["utf-8", "pkl"], - lmdb_path: Path | None = None, ) -> LMDBDict: - from openfold3.core.data.primitives.caches.lmdb import ( - LMDBDict, - ) - return LMDBDict( lmdb_env=lmdb_env, prefix="structure_data", key_encoding=str_encoding, value_encoding=structure_data_encoding, - lmdb_path=lmdb_path, ) @staticmethod def _parse_ref_mol_data_lmdb( - lmdb_env: lmdb.Environment, + lmdb_env: LMDBEnv, str_encoding: Literal["utf-8", "pkl"], reference_molecule_data_encoding: Literal["utf-8", "pkl"], - lmdb_path: Path | None = None, ) -> LMDBDict: - from openfold3.core.data.primitives.caches.lmdb import ( - LMDBDict, - ) - return LMDBDict( lmdb_env=lmdb_env, prefix="reference_molecule_data", key_encoding=str_encoding, value_encoding=reference_molecule_data_encoding, - lmdb_path=lmdb_path, ) def to_lmdb( diff --git a/openfold3/core/data/primitives/caches/lmdb.py b/openfold3/core/data/primitives/caches/lmdb.py index 97870c0df..f360c2f78 100644 --- a/openfold3/core/data/primitives/caches/lmdb.py +++ b/openfold3/core/data/primitives/caches/lmdb.py @@ -28,6 +28,30 @@ V = TypeVar("V") +class LMDBEnv: + """Lazy-opened LMDB environment shared between LMDBDict instances""" + + def __init__(self, path: str) -> None: + self._path = path + self._env: lmdb.Environment | None = None + + def get(self) -> lmdb.Environment: + if self._env is None: + self._env = lmdb.open(self._path, readonly=True, lock=False, subdir=True) + return self._env + + def close(self) -> None: + if self._env is not None: + self._env.close() + self._env = None + + def __getstate__(self) -> dict: + return {"_path": self._path, "_env": None} + + def __setstate__(self, state: dict) -> None: + self.__dict__.update(state) + + def convert_datacache_to_lmdb( dataset_cache_file_or_obj: Union[Path, "DatasetCache"], lmdb_directory: Path, @@ -129,54 +153,42 @@ def convert_datacache_to_lmdb( class LMDBDict(Mapping[K, V], Generic[K, V]): def __init__( self, - lmdb_env: lmdb.Environment, + lmdb_env: LMDBEnv, prefix: str, separator: str = ":", key_encoding: Literal["utf-8", "pkl"] = "utf-8", value_encoding: Literal["utf-8", "pkl"] = "pkl", - lmdb_path: str | Path | None = None, ): """A dict-like class with an LMDB backend for lazy loading of datacache entries. + Takes a shared LMDBEnv instance. Multiple LMDBDict objects for the same + file should share a single LMDBEnv so only one lmdb.Environment is opened + per file per process. Because pickle deduplicates shared references, this + sharing is preserved across fork/forkserver/spawn. + Args: - lmdb_env (lmdb.Environment): - The LMDB environment object. + lmdb_env (LMDBEnv): + Shared lazy env for this LMDB directory. prefix (str): header for fields used to construct keys in lmdb separator (str): Single separator character used to construct key key_encoding (Literal["utf-8", "pkl"]): Encoding of keys. Defaults to "utf-8". value_encoding (Literal["utf-8", "pkl"]): Encoding of values. Defaults to "pkl". - lmdb_path (str | Path | None): - Path to the LMDB directory. Used to reopen env after pickle - (e.g. for forkserver/spawn multiprocessing). Raises: KeyError: If a non-existent key is requested. """ self._lmdb_env = lmdb_env - self._lmdb_path = str(lmdb_path) if lmdb_path is not None else None self._prefix = prefix + separator self._key_encoding = key_encoding self._value_encoding = value_encoding self._n_keys = None # Computed on first __len__ call - def __getstate__(self): - state = self.__dict__.copy() - del state["_lmdb_env"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - if self._lmdb_path is not None: - self._lmdb_env = lmdb.open( - self._lmdb_path, readonly=True, lock=False, subdir=True - ) - else: - raise RuntimeError( - "Cannot unpickle LMDBDict: no lmdb_path was provided at creation" - ) + def close(self) -> None: + """Close the underlying env. Reopens lazily on next access.""" + self._lmdb_env.close() def _decode_key(self, key): encoded_prefix = self._prefix.encode(self._key_encoding) @@ -185,7 +197,7 @@ def _decode_key(self, key): def __iter__(self): "Use an iterative method to not have to store all keys in memory." encoded_prefix = self._prefix.encode(self._key_encoding) - with self._lmdb_env.begin() as txn, txn.cursor() as cursor: + with self._lmdb_env.get().begin() as txn, txn.cursor() as cursor: # Seek to the first key >= prefix if cursor.set_range(encoded_prefix): while True: @@ -203,7 +215,7 @@ def _count_keys(self): """Count keys matching the prefix.""" encoded_prefix = self._prefix.encode(self._key_encoding) count = 0 - with self._lmdb_env.begin() as txn, txn.cursor() as cursor: + with self._lmdb_env.get().begin() as txn, txn.cursor() as cursor: # Use set_range to jump to the first prefix occurrence # and avoid scanning the entire LMDB. if cursor.set_range(encoded_prefix): @@ -221,7 +233,7 @@ def __len__(self): return self._n_keys def __getitem__(self, key): - with self._lmdb_env.begin() as transaction: + with self._lmdb_env.get().begin() as transaction: key_bytes = f"{self._prefix}{key}".encode(self._key_encoding) value_bytes = transaction.get(key_bytes) if value_bytes is None: From 918ea308a26e496de7511b55a5606cf63f5e0666 Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Wed, 25 Mar 2026 16:44:53 -0400 Subject: [PATCH 3/7] Move release connection call --- openfold3/core/data/framework/single_datasets/base_of3.py | 4 ---- openfold3/core/data/framework/single_datasets/monomer.py | 3 +++ openfold3/core/data/framework/single_datasets/pdb.py | 3 +++ openfold3/core/data/framework/single_datasets/validation.py | 3 +++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/openfold3/core/data/framework/single_datasets/base_of3.py b/openfold3/core/data/framework/single_datasets/base_of3.py index d27649510..e9e118372 100644 --- a/openfold3/core/data/framework/single_datasets/base_of3.py +++ b/openfold3/core/data/framework/single_datasets/base_of3.py @@ -180,10 +180,6 @@ def __init__(self, dataset_config) -> None: self.single_moltype = None self.debug_mode = dataset_config.debug_mode - # Release backend connections so worker processes inherit clean state. - # Each backend reopens lazily on first access in the worker. - self.dataset_cache.release_connections() - def warm_cache(self) -> None: """Warm the OS page cache for LMDB. No-op for JSON.""" if self._dataset_cache_file.is_dir(): diff --git a/openfold3/core/data/framework/single_datasets/monomer.py b/openfold3/core/data/framework/single_datasets/monomer.py index db41cedbc..2180d4ece 100644 --- a/openfold3/core/data/framework/single_datasets/monomer.py +++ b/openfold3/core/data/framework/single_datasets/monomer.py @@ -50,6 +50,9 @@ def __init__(self, dataset_config: dict) -> None: # Datapoint cache self.create_datapoint_cache() + # Release so fork/forkserver inherits clean state for the first epoch + self.dataset_cache.release_connections() + def create_datapoint_cache(self): """Creates the datapoint_cache for uniform sampling. diff --git a/openfold3/core/data/framework/single_datasets/pdb.py b/openfold3/core/data/framework/single_datasets/pdb.py index 505f0eebc..c85d37609 100644 --- a/openfold3/core/data/framework/single_datasets/pdb.py +++ b/openfold3/core/data/framework/single_datasets/pdb.py @@ -189,6 +189,9 @@ def __init__(self, dataset_config: dict) -> None: # Datapoint cache self.create_datapoint_cache() + # Release so fork/forkserver inherits clean state for the first epoch + self.dataset_cache.release_connections() + def create_datapoint_cache(self) -> None: """Creates the datapoint_cache with chain/interface probabilities. diff --git a/openfold3/core/data/framework/single_datasets/validation.py b/openfold3/core/data/framework/single_datasets/validation.py index dc7b78abb..e7e6b6b2d 100644 --- a/openfold3/core/data/framework/single_datasets/validation.py +++ b/openfold3/core/data/framework/single_datasets/validation.py @@ -63,6 +63,9 @@ def __init__(self, dataset_config: dict, world_size: int | None = None) -> None: # Dataset/datapoint cache self.create_datapoint_cache() + # Release so fork/forkserver inherits clean state for the first epoch + self.dataset_cache.release_connections() + # Cropping should be disabled for validation datasets if self.crop["token_crop"]["enabled"]: logger.warning( From 79a038c1265064a9c14ed78e66727e076b69196e Mon Sep 17 00:00:00 2001 From: Jan Domanski Date: Sat, 11 Apr 2026 03:19:49 -0700 Subject: [PATCH 4/7] add test_lmdb.py --- openfold3/tests/test_lmdb.py | 129 ++++++++++++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 2 deletions(-) diff --git a/openfold3/tests/test_lmdb.py b/openfold3/tests/test_lmdb.py index 29f367b63..23a0c6709 100644 --- a/openfold3/tests/test_lmdb.py +++ b/openfold3/tests/test_lmdb.py @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Helper functions to test the lmdb dict""" +"""Tests for LMDB dict and multiprocessing safety.""" import json +import pickle +import sys -import pytest # noqa: F401 - used for pytest tmp fixture +import lmdb +import pytest +import torch +from torch.utils.data import DataLoader, Dataset from openfold3.core.data.io.dataset_cache import read_datacache from openfold3.core.data.primitives.caches.lmdb import ( + LMDBDict, + LMDBEnv, convert_datacache_to_lmdb, ) @@ -57,6 +64,44 @@ } +def create_test_lmdb(lmdb_dir, num_items=10): + """Create a small LMDB with ``item:0`` … ``item:N-1`` keys.""" + env = lmdb.open(str(lmdb_dir), map_size=1024 * 1024, subdir=True) + with env.begin(write=True) as txn: + for i in range(num_items): + key = f"item:{i}".encode() + value = json.dumps({"index": i}).encode("utf-8") + txn.put(key, value) + env.close() + + +class LMDBDataset(Dataset): + """Minimal Dataset backed by LMDBEnv + LMDBDict. + + Defined at module level so spawn/forkserver workers can import it. + """ + + def __init__(self, lmdb_dir: str, num_items: int): + self._lmdb_env = LMDBEnv(lmdb_dir) + self._dict = LMDBDict( + lmdb_env=self._lmdb_env, + prefix="item", + key_encoding="utf-8", + value_encoding="utf-8", + ) + self._n = num_items + + def __len__(self): + return self._n + + def __getitem__(self, idx): + item = self._dict[str(idx)] + return torch.tensor(item["index"]) + + def release_connections(self): + self._lmdb_env.close() + + class TestLMDBDict: def test_lmdb_roundtrip(self, tmp_path): # Save dummy json @@ -75,3 +120,83 @@ def test_lmdb_roundtrip(self, tmp_path): expected_cache = read_datacache(test_config_json) assert lmdb_cache == expected_cache + + +class TestLMDBEnvPickle: + def test_raw_lmdb_env_not_pickleable(self, tmp_path): + """Raw lmdb.Environment cannot be pickled — this is the root cause + of spawn/forkserver failures without the LMDBEnv wrapper.""" + lmdb_dir = tmp_path / "raw" + create_test_lmdb(lmdb_dir) + env = lmdb.open(str(lmdb_dir), readonly=True, lock=False, subdir=True) + with pytest.raises(TypeError, match="cannot pickle"): + pickle.dumps(env) + env.close() + + def test_lmdb_env_pickle_roundtrip(self, tmp_path): + """LMDBEnv can be pickled and reads correctly after unpickling.""" + lmdb_dir = tmp_path / "env_pkl" + create_test_lmdb(lmdb_dir, num_items=3) + + env = LMDBEnv(str(lmdb_dir)) + _ = env.get() # force open + + data = pickle.dumps(env) + env.close() # close original — LMDB forbids two open envs for same path + + env2 = pickle.loads(data) + assert env2._env is None # connection stripped by __getstate__ + with env2.get().begin() as txn: + assert txn.get(b"item:0") is not None + env2.close() + + def test_lmdb_dict_pickle_roundtrip(self, tmp_path): + """LMDBDict survives pickle roundtrip and reads correctly.""" + lmdb_dir = tmp_path / "dict_pkl" + num_items = 5 + create_test_lmdb(lmdb_dir, num_items=num_items) + + env = LMDBEnv(str(lmdb_dir)) + d = LMDBDict( + lmdb_env=env, + prefix="item", + key_encoding="utf-8", + value_encoding="utf-8", + ) + original = d["0"] + + data = pickle.dumps(d) + env.close() # close original — LMDB forbids two open envs for same path + + d2 = pickle.loads(data) + assert d2["0"] == original + assert len(d2) == num_items + + +class TestLMDBMultiprocessingDataLoader: + @pytest.mark.parametrize("mp_context", ["fork", "forkserver", "spawn"]) + def test_dataloader_reads_all_items(self, tmp_path, mp_context): + """DataLoader with num_workers>0 reads all LMDB items correctly + across fork, forkserver, and spawn multiprocessing contexts.""" + if mp_context == "fork" and sys.platform == "darwin": + pytest.skip("fork is unsafe on macOS with Python >= 3.8") + + num_items = 20 + lmdb_dir = tmp_path / "mp_lmdb" + create_test_lmdb(lmdb_dir, num_items=num_items) + + dataset = LMDBDataset(str(lmdb_dir), num_items=num_items) + dataset.release_connections() # mimic real codebase: clean state before fork + + loader = DataLoader( + dataset, + batch_size=1, + num_workers=2, + multiprocessing_context=mp_context, + ) + + results = [] + for batch in loader: + results.extend(batch.tolist()) + + assert sorted(results) == list(range(num_items)) From b2c5a31ad390c29b1490d97067415177939b924f Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Thu, 7 May 2026 11:39:25 +0200 Subject: [PATCH 5/7] Add prefetch factor fix and fix cache loading key error --- openfold3/core/data/framework/data_module.py | 1 + openfold3/core/data/io/dataset_cache.py | 24 ++++++++------------ 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/openfold3/core/data/framework/data_module.py b/openfold3/core/data/framework/data_module.py index 950a8d3d6..29abbfe91 100644 --- a/openfold3/core/data/framework/data_module.py +++ b/openfold3/core/data/framework/data_module.py @@ -429,6 +429,7 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None) prefetch_factor = self.prefetch_factor persistent_workers = self.persistent_workers and num_workers > 0 + prefetch_factor = prefetch_factor if num_workers > 0 else None multiprocessing_context = ( self.multiprocessing_context if num_workers > 0 else None ) diff --git a/openfold3/core/data/io/dataset_cache.py b/openfold3/core/data/io/dataset_cache.py index 011be4a3c..44a1d9c05 100644 --- a/openfold3/core/data/io/dataset_cache.py +++ b/openfold3/core/data/io/dataset_cache.py @@ -139,13 +139,10 @@ def _read_datacache_file(datacache_path: Path) -> "DataCacheType": else: raise ValueError("Could not determine the type of the dataset cache.") - try: - # Infer which class to build - dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) - except KeyError as exc: - raise ValueError( - f"Unknown dataset cache type: {dataset_cache_type}" - ) from exc + # Infer which class to build + dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) + if dataset_cache_class is None: + raise ValueError(f"Unknown dataset cache type: {dataset_cache_type}") return dataset_cache_class.from_json(datacache_path) @@ -192,19 +189,16 @@ def read_datacache( with lmdb_env.begin() as txn: dataset_cache_type = json.loads(txn.get(type_key).decode(str_encoding)) - # Only one connection can be open at a time, close before creating LMDBDict + # Close temporary handle before from_lmdb creates its own LMDBEnv lmdb_env.close() if not dataset_cache_type: raise ValueError("No type found for this directory.") - try: - # Infer which class to build - dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) - except KeyError as exc: - raise ValueError( - f"Unknown dataset cache type: {dataset_cache_type}" - ) from exc + # Infer which class to build + dataset_cache_class = DATASET_CACHE_CLASS_REGISTRY.get(dataset_cache_type) + if dataset_cache_class is None: + raise ValueError(f"Unknown dataset cache type: {dataset_cache_type}") dataset_cache = dataset_cache_class.from_lmdb( datacache_path, From ba7b90f7a0ed49f5069cf784fcd1e0874f90e930 Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Thu, 7 May 2026 11:49:50 +0200 Subject: [PATCH 6/7] Remove warm up, mostly unnecessary and placement isn't great. Add as optimization later if needed. --- .../framework/single_datasets/base_of3.py | 7 ------- .../single_datasets/dataset_utils.py | 20 ------------------- 2 files changed, 27 deletions(-) diff --git a/openfold3/core/data/framework/single_datasets/base_of3.py b/openfold3/core/data/framework/single_datasets/base_of3.py index e9e118372..7b6b432f1 100644 --- a/openfold3/core/data/framework/single_datasets/base_of3.py +++ b/openfold3/core/data/framework/single_datasets/base_of3.py @@ -27,7 +27,6 @@ SingleDataset, register_dataset, ) -from openfold3.core.data.framework.single_datasets.dataset_utils import warm_lmdb_cache from openfold3.core.data.io.dataset_cache import read_datacache from openfold3.core.data.pipelines.featurization.conformer import ( featurize_reference_conformers_of3, @@ -156,7 +155,6 @@ def __init__(self, dataset_config) -> None: # TODO: potentially expose the LMDB database encoding types self._dataset_cache_file = dataset_config.dataset_paths.dataset_cache_file self.dataset_cache = read_datacache(self._dataset_cache_file) - self.warm_cache() self.datapoint_cache = {} @@ -180,11 +178,6 @@ def __init__(self, dataset_config) -> None: self.single_moltype = None self.debug_mode = dataset_config.debug_mode - def warm_cache(self) -> None: - """Warm the OS page cache for LMDB. No-op for JSON.""" - if self._dataset_cache_file.is_dir(): - warm_lmdb_cache(self._dataset_cache_file) - @property def ccd(self): if self._ccd is None and self._ccd_file is not None: diff --git a/openfold3/core/data/framework/single_datasets/dataset_utils.py b/openfold3/core/data/framework/single_datasets/dataset_utils.py index 70e80d72e..ebee151a2 100644 --- a/openfold3/core/data/framework/single_datasets/dataset_utils.py +++ b/openfold3/core/data/framework/single_datasets/dataset_utils.py @@ -15,9 +15,7 @@ import copy import logging import os -import time from itertools import cycle, islice -from pathlib import Path import pandas as pd import torch @@ -156,21 +154,3 @@ def getitem_debug_log(dataset_name: str = "") -> None: f"pid={os.getpid()} worker_id={worker_id} wi.seed={wi_seed} " f"wi.base_seed={wi_base_seed} torch.initial_seed={torch_seed}", ) - - -def warm_file_cache(file_path: Path) -> None: - """Sequentially read a file to warm the OS page cache.""" - file_size_gb = file_path.stat().st_size / (1024**3) - logger.info(f"Warming page cache for {file_path} ({file_size_gb:.1f} GB)...") - t0 = time.monotonic() - chunk_size = 8 * 1024 * 1024 - with open(file_path, "rb") as f: - while f.read(chunk_size): - pass - elapsed = time.monotonic() - t0 - logger.info(f"Page cache warm complete in {elapsed:.1f}s") - - -def warm_lmdb_cache(lmdb_directory: Path) -> None: - """Sequentially read the LMDB data file to warm the OS page cache.""" - warm_file_cache(lmdb_directory / "data.mdb") From 975afdd859bffb048252cb262f0c66f7990216f8 Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Fri, 8 May 2026 18:21:47 +0200 Subject: [PATCH 7/7] Import correct conftest for lmdb test --- openfold3/tests/core/data/primitives/caches/test_lmdb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfold3/tests/core/data/primitives/caches/test_lmdb.py b/openfold3/tests/core/data/primitives/caches/test_lmdb.py index 26cf5c0e8..6fd14dbec 100644 --- a/openfold3/tests/core/data/primitives/caches/test_lmdb.py +++ b/openfold3/tests/core/data/primitives/caches/test_lmdb.py @@ -21,7 +21,6 @@ import lmdb import pytest import torch -from conftest import TEST_DATASET_CONFIG from torch.utils.data import DataLoader, Dataset from openfold3.core.data.io.dataset_cache import read_datacache @@ -30,6 +29,7 @@ LMDBEnv, convert_datacache_to_lmdb, ) +from openfold3.tests.core.data.primitives.caches.conftest import TEST_DATASET_CONFIG def create_test_lmdb(lmdb_dir, num_items=10):