diff --git a/src/sampleworks/core/rewards/protocol.py b/src/sampleworks/core/rewards/protocol.py index 93208996..6b5fd72d 100644 --- a/src/sampleworks/core/rewards/protocol.py +++ b/src/sampleworks/core/rewards/protocol.py @@ -23,7 +23,7 @@ class RewardInputs: once and pass them to scale() methods without redundant extraction. The atom array passed to :meth:`from_atom_array` must already be clean: - all coordinates finite and all occupancies positive. Wrappers are + all coordinates finite and all occupancies non-negative. Wrappers are responsible for ensuring this (e.g. replacing NaN coordinates with noise and setting occupancy to 1.0 for model-operated atoms). """ @@ -44,7 +44,7 @@ def from_atom_array( """Construct RewardInputs from a Biotite AtomArray. The atom array must contain only valid atoms (finite coordinates, - positive occupancy). Callers are responsible for filtering + non-negative occupancy). Callers are responsible for filtering beforehand; no masking is applied here. Parameters @@ -69,9 +69,14 @@ def from_atom_array( raise ValueError("Atom array must have 'element' annotation.") if not hasattr(atom_array, "b_factor"): raise ValueError("Atom array must have 'b_factor' annotation.") + if np.any(np.isnan(atom_array.b_factor)): + raise ValueError( + "Atom array contains NaN B-factors. Wrappers must replace NaN " + "B-factors before constructing RewardInputs (e.g., with a default of 20.0)." + ) if np.any(np.isnan(atom_array.coord)): raise ValueError("Atom array contains NaN coordinates.") - if np.any((atom_array.occupancy <= 0) | (atom_array.occupancy > 1)): + if np.any((atom_array.occupancy < 0) | (atom_array.occupancy > 1)): raise ValueError("Atom array contains invalid occupancy values.") elements_list = elements_to_scattering_indices(atom_array.element) @@ -97,6 +102,7 @@ def from_atom_array( p=num_particles, e=ensemble_size, ) + # TODO: eventually this should be configurable occupancies = torch.ones_like(b_factors) / ensemble_size input_coords = einx.rearrange( "... -> b ...", diff --git a/src/sampleworks/models/protenix/wrapper.py b/src/sampleworks/models/protenix/wrapper.py index bbe4501a..48a0328f 100644 --- a/src/sampleworks/models/protenix/wrapper.py +++ b/src/sampleworks/models/protenix/wrapper.py @@ -358,7 +358,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtenixConditionin updated_json_path = json_path.with_name(f"{json_path.stem}-add-msa.json") if not updated_json_path.exists(): # get all the required sequences from the json_dict - sequence_data = { + sequence_data: dict[str | int, str] = { n: seq_data["proteinChain"]["sequence"] for n, seq_data in enumerate(json_dict["sequences"]) if "proteinChain" in seq_data @@ -476,6 +476,16 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtenixConditionin model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None: model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32)) + else: + nan_b_mask = np.isnan(model_aa.b_factor) + if nan_b_mask.any(): + b_factors = model_aa.b_factor.copy() + b_factors[nan_b_mask] = 20.0 + model_aa.set_annotation("b_factor", b_factors) + logger.info( + f"Replaced {int(nan_b_mask.sum())} NaN B-factors with default 20.0 " + f"(from unresolved atoms added by add_missing_atoms)" + ) num_atoms_protenix = len(atom_array_protenix) conditioning = ProtenixConditioning( diff --git a/src/sampleworks/models/rf3/wrapper.py b/src/sampleworks/models/rf3/wrapper.py index bcb4b609..ab50f646 100644 --- a/src/sampleworks/models/rf3/wrapper.py +++ b/src/sampleworks/models/rf3/wrapper.py @@ -344,6 +344,16 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None: model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32)) + else: + nan_b_mask = np.isnan(model_aa.b_factor) + if nan_b_mask.any(): + b_factors = model_aa.b_factor.copy() + b_factors[nan_b_mask] = 20.0 + model_aa.set_annotation("b_factor", b_factors) + logger.info( + f"Replaced {int(nan_b_mask.sum())} NaN B-factors with default 20.0 " + f"(from unresolved atoms added by add_missing_atoms)" + ) conditioning = RF3Conditioning( s_inputs=pairformer_out["s_inputs"], diff --git a/tests/conftest.py b/tests/conftest.py index d1e13119..22219067 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ if str(_project_root) not in sys.path: sys.path.insert(0, str(_project_root)) -from atomworks.io.parser import parse +from atomworks.io.parser import parse, parse_atom_array from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArray, AtomArrayStack, stack from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig @@ -420,6 +420,23 @@ def structure_1vme(resources_dir: Path) -> dict: return parse(resources_dir / "1vme" / "1vme_final.cif", ccd_mirror_path=None) +@pytest.fixture(scope="session") +def atom_array_1vme_with_missing_atoms(structure_1vme) -> AtomArray: + """1VME atom array after parse_atom_array with add_missing_atoms=True. + + This reproduces what RF3's InferenceInput.from_atom_array does internally. + """ + parsed = parse_atom_array( + structure_1vme["asym_unit"], + add_missing_atoms=True, + hydrogen_policy="keep", + ) + aa = parsed["asym_unit"] + if isinstance(aa, AtomArrayStack): + aa = aa[0] + return aa + + @pytest.fixture(scope="session") def structure_6b8x(resources_dir: Path) -> dict: return parse(resources_dir / "6b8x" / "6b8x_final.pdb", ccd_mirror_path=None) @@ -1068,4 +1085,4 @@ def perturbed_coords( torch.manual_seed(42) base = converging_mock_wrapper.target perturbation = torch.randn_like(base) * 0.1 # ty: ignore[invalid-argument-type] - return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator] \ No newline at end of file + return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator] diff --git a/tests/rewards/test_reward_inputs.py b/tests/rewards/test_reward_inputs.py new file mode 100644 index 00000000..467e3485 --- /dev/null +++ b/tests/rewards/test_reward_inputs.py @@ -0,0 +1,64 @@ +"""Tests for RewardInputs validation. + +Verifies that RewardInputs.from_atom_array rejects atom arrays with NaN +B-factors, NaN coordinates, and invalid occupancies. +""" + +import numpy as np +import pytest +from sampleworks.core.rewards.protocol import RewardInputs + + +class TestRewardInputsFromAtomArray: + """Validate that RewardInputs.from_atom_array rejects invalid atom arrays.""" + + def test_nan_b_factors_from_missing_atoms_rejected(self, atom_array_1vme_with_missing_atoms): + """ + RewardInputs rejects the raw atom array from add_missing_atoms. + """ + with pytest.raises(ValueError, match="NaN B-factors"): + RewardInputs.from_atom_array(atom_array_1vme_with_missing_atoms, ensemble_size=1) + + def test_cleaned_missing_atoms_accepted(self, atom_array_1vme_with_missing_atoms): + """After applying the same fixes as the RF3 wrapper, the atom array passes.""" + aa = atom_array_1vme_with_missing_atoms.copy() + + # Fix NaN coordinates + nan_coord_mask = np.any(np.isnan(aa.coord), axis=-1) + if nan_coord_mask.any(): + resolved_coords = aa.coord[~nan_coord_mask] + if len(resolved_coords) > 0: + centroid = resolved_coords.mean(axis=0) + else: + centroid = np.zeros(3) + n_nan = int(nan_coord_mask.sum()) + noise = np.random.normal(loc=0.0, scale=1.0, size=(n_nan, 3)).astype(np.float32) + new_coords = aa.coord.copy() + new_coords[nan_coord_mask] = centroid + noise + aa.coord = new_coords + + # Fix occupancy + aa.set_annotation("occupancy", np.ones(len(aa), dtype=np.float32)) + + # Fix NaN b_factors + nan_b_mask = np.isnan(aa.b_factor) + if nan_b_mask.any(): + b_factors = aa.b_factor.copy() + b_factors[nan_b_mask] = 20.0 + aa.set_annotation("b_factor", b_factors) + + reward_inputs = RewardInputs.from_atom_array(aa, ensemble_size=1) + assert reward_inputs.b_factors.shape[-1] == len(aa) + + def test_nan_coordinates_rejected(self, structure_1vme): + """NaN coordinates must be caught before constructing reward tensors.""" + atom_array = structure_1vme["asym_unit"].copy() + # Fix any pre-existing NaN b_factors so the coordinate check passes + b_factors = np.nan_to_num(atom_array.b_factor, nan=20.0) + atom_array.set_annotation("b_factor", b_factors) + coords = atom_array.coord.copy() + coords[..., 3, :] = np.nan + atom_array.coord = coords + + with pytest.raises(ValueError, match="NaN coordinates"): + RewardInputs.from_atom_array(atom_array, ensemble_size=1)