From 37425c487da9b1299d5df590727621b5eff05524 Mon Sep 17 00:00:00 2001 From: Karson Chrispens Date: Sun, 15 Mar 2026 04:09:35 +0000 Subject: [PATCH 1/4] fix(rf3): add_missing_atoms also, along with giving elements that weren't in the base structure, adds nan coords, b-factors, and occupancies of 0 --- src/sampleworks/core/rewards/protocol.py | 8 ++- src/sampleworks/models/protenix/wrapper.py | 12 +++- src/sampleworks/models/rf3/wrapper.py | 10 +++ tests/conftest.py | 21 +++++- tests/rewards/test_reward_inputs.py | 77 ++++++++++++++++++++++ 5 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 tests/rewards/test_reward_inputs.py diff --git a/src/sampleworks/core/rewards/protocol.py b/src/sampleworks/core/rewards/protocol.py index 93208996..0167f907 100644 --- a/src/sampleworks/core/rewards/protocol.py +++ b/src/sampleworks/core/rewards/protocol.py @@ -69,9 +69,14 @@ def from_atom_array( raise ValueError("Atom array must have 'element' annotation.") if not hasattr(atom_array, "b_factor"): raise ValueError("Atom array must have 'b_factor' annotation.") + if np.any(np.isnan(atom_array.b_factor)): + raise ValueError( + "Atom array contains NaN B-factors. Wrappers must replace NaN " + "B-factors before constructing RewardInputs (e.g., with a default of 20.0)." + ) if np.any(np.isnan(atom_array.coord)): raise ValueError("Atom array contains NaN coordinates.") - if np.any((atom_array.occupancy <= 0) | (atom_array.occupancy > 1)): + if np.any((atom_array.occupancy < 0) | (atom_array.occupancy > 1)): raise ValueError("Atom array contains invalid occupancy values.") elements_list = elements_to_scattering_indices(atom_array.element) @@ -97,6 +102,7 @@ def from_atom_array( p=num_particles, e=ensemble_size, ) + # TODO: eventually this should be configurable occupancies = torch.ones_like(b_factors) / ensemble_size input_coords = einx.rearrange( "... -> b ...", diff --git a/src/sampleworks/models/protenix/wrapper.py b/src/sampleworks/models/protenix/wrapper.py index bbe4501a..48a0328f 100644 --- a/src/sampleworks/models/protenix/wrapper.py +++ b/src/sampleworks/models/protenix/wrapper.py @@ -358,7 +358,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtenixConditionin updated_json_path = json_path.with_name(f"{json_path.stem}-add-msa.json") if not updated_json_path.exists(): # get all the required sequences from the json_dict - sequence_data = { + sequence_data: dict[str | int, str] = { n: seq_data["proteinChain"]["sequence"] for n, seq_data in enumerate(json_dict["sequences"]) if "proteinChain" in seq_data @@ -476,6 +476,16 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtenixConditionin model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None: model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32)) + else: + nan_b_mask = np.isnan(model_aa.b_factor) + if nan_b_mask.any(): + b_factors = model_aa.b_factor.copy() + b_factors[nan_b_mask] = 20.0 + model_aa.set_annotation("b_factor", b_factors) + logger.info( + f"Replaced {int(nan_b_mask.sum())} NaN B-factors with default 20.0 " + f"(from unresolved atoms added by add_missing_atoms)" + ) num_atoms_protenix = len(atom_array_protenix) conditioning = ProtenixConditioning( diff --git a/src/sampleworks/models/rf3/wrapper.py b/src/sampleworks/models/rf3/wrapper.py index bcb4b609..ab50f646 100644 --- a/src/sampleworks/models/rf3/wrapper.py +++ b/src/sampleworks/models/rf3/wrapper.py @@ -344,6 +344,16 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None: model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32)) + else: + nan_b_mask = np.isnan(model_aa.b_factor) + if nan_b_mask.any(): + b_factors = model_aa.b_factor.copy() + b_factors[nan_b_mask] = 20.0 + model_aa.set_annotation("b_factor", b_factors) + logger.info( + f"Replaced {int(nan_b_mask.sum())} NaN B-factors with default 20.0 " + f"(from unresolved atoms added by add_missing_atoms)" + ) conditioning = RF3Conditioning( s_inputs=pairformer_out["s_inputs"], diff --git a/tests/conftest.py b/tests/conftest.py index d1e13119..22219067 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ if str(_project_root) not in sys.path: sys.path.insert(0, str(_project_root)) -from atomworks.io.parser import parse +from atomworks.io.parser import parse, parse_atom_array from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArray, AtomArrayStack, stack from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig @@ -420,6 +420,23 @@ def structure_1vme(resources_dir: Path) -> dict: return parse(resources_dir / "1vme" / "1vme_final.cif", ccd_mirror_path=None) +@pytest.fixture(scope="session") +def atom_array_1vme_with_missing_atoms(structure_1vme) -> AtomArray: + """1VME atom array after parse_atom_array with add_missing_atoms=True. + + This reproduces what RF3's InferenceInput.from_atom_array does internally. + """ + parsed = parse_atom_array( + structure_1vme["asym_unit"], + add_missing_atoms=True, + hydrogen_policy="keep", + ) + aa = parsed["asym_unit"] + if isinstance(aa, AtomArrayStack): + aa = aa[0] + return aa + + @pytest.fixture(scope="session") def structure_6b8x(resources_dir: Path) -> dict: return parse(resources_dir / "6b8x" / "6b8x_final.pdb", ccd_mirror_path=None) @@ -1068,4 +1085,4 @@ def perturbed_coords( torch.manual_seed(42) base = converging_mock_wrapper.target perturbation = torch.randn_like(base) * 0.1 # ty: ignore[invalid-argument-type] - return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator] \ No newline at end of file + return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator] diff --git a/tests/rewards/test_reward_inputs.py b/tests/rewards/test_reward_inputs.py new file mode 100644 index 00000000..575a1b02 --- /dev/null +++ b/tests/rewards/test_reward_inputs.py @@ -0,0 +1,77 @@ +"""Tests for RewardInputs validation. + +Verifies that RewardInputs.from_atom_array rejects atom arrays with NaN +B-factors, NaN coordinates, and invalid occupancies. +""" + +import numpy as np +import pytest +from sampleworks.core.rewards.protocol import RewardInputs + + +class TestRewardInputsFromAtomArray: + """Validate that RewardInputs.from_atom_array rejects invalid atom arrays.""" + + def test_nan_b_factors_from_missing_atoms_rejected(self, atom_array_1vme_with_missing_atoms): + """ + RewardInputs rejects the raw atom array from add_missing_atoms. + """ + with pytest.raises(ValueError, match="NaN B-factors"): + RewardInputs.from_atom_array(atom_array_1vme_with_missing_atoms, ensemble_size=1) + + def test_cleaned_missing_atoms_accepted(self, atom_array_1vme_with_missing_atoms): + """After applying the same fixes as the RF3 wrapper, the atom array passes.""" + aa = atom_array_1vme_with_missing_atoms.copy() + + # Fix NaN coordinates + nan_coord_mask = np.any(np.isnan(aa.coord), axis=-1) + if nan_coord_mask.any(): + resolved_coords = aa.coord[~nan_coord_mask] + centroid = resolved_coords.mean(axis=0) + n_nan = int(nan_coord_mask.sum()) + noise = np.random.normal(loc=0.0, scale=1.0, size=(n_nan, 3)).astype(np.float32) + new_coords = aa.coord.copy() + new_coords[nan_coord_mask] = centroid + noise + aa.coord = new_coords + + # Fix occupancy + aa.set_annotation("occupancy", np.ones(len(aa), dtype=np.float32)) + + # Fix NaN b_factors + nan_b_mask = np.isnan(aa.b_factor) + if nan_b_mask.any(): + b_factors = aa.b_factor.copy() + b_factors[nan_b_mask] = 20.0 + aa.set_annotation("b_factor", b_factors) + + reward_inputs = RewardInputs.from_atom_array(aa, ensemble_size=1) + assert reward_inputs.b_factors.shape[-1] == len(aa) + + def test_nan_coordinates_rejected(self, structure_1vme): + """NaN coordinates must be caught before constructing reward tensors.""" + atom_array = structure_1vme["asym_unit"].copy() + # Fix any pre-existing NaN b_factors so the coordinate check passes + b_factors = np.nan_to_num(atom_array.b_factor, nan=20.0) + atom_array.set_annotation("b_factor", b_factors) + coords = atom_array.coord.copy() + coords[..., 3, :] = np.nan + atom_array.coord = coords + + with pytest.raises(ValueError, match="NaN coordinates"): + RewardInputs.from_atom_array(atom_array, ensemble_size=1) + + def test_zero_occupancy_rejected(self, structure_1vme): + """Zero occupancy (from unresolved atoms) must be caught.""" + atom_array = structure_1vme["asym_unit"].copy() + # Fix any pre-existing NaN b_factors so the occupancy check passes + b_factors = np.nan_to_num(atom_array.b_factor, nan=20.0) + atom_array.set_annotation("b_factor", b_factors) + # Fix any pre-existing NaN coords so the occupancy check passes + coords = np.nan_to_num(atom_array.coord, nan=0.0) + atom_array.coord = coords + occupancies = atom_array.occupancy.copy() + occupancies[0:3] = 0.0 + atom_array.set_annotation("occupancy", occupancies) + + with pytest.raises(ValueError, match="invalid occupancy"): + RewardInputs.from_atom_array(atom_array, ensemble_size=1) From 040715f7a5eaf83681650a26f90a1be9f4aa3caf Mon Sep 17 00:00:00 2001 From: Karson Chrispens Date: Sun, 15 Mar 2026 04:32:13 +0000 Subject: [PATCH 2/4] fix(tests): remove zero occupancy test, we can allow that --- tests/rewards/test_reward_inputs.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tests/rewards/test_reward_inputs.py b/tests/rewards/test_reward_inputs.py index 575a1b02..a0ac9b41 100644 --- a/tests/rewards/test_reward_inputs.py +++ b/tests/rewards/test_reward_inputs.py @@ -59,19 +59,3 @@ def test_nan_coordinates_rejected(self, structure_1vme): with pytest.raises(ValueError, match="NaN coordinates"): RewardInputs.from_atom_array(atom_array, ensemble_size=1) - - def test_zero_occupancy_rejected(self, structure_1vme): - """Zero occupancy (from unresolved atoms) must be caught.""" - atom_array = structure_1vme["asym_unit"].copy() - # Fix any pre-existing NaN b_factors so the occupancy check passes - b_factors = np.nan_to_num(atom_array.b_factor, nan=20.0) - atom_array.set_annotation("b_factor", b_factors) - # Fix any pre-existing NaN coords so the occupancy check passes - coords = np.nan_to_num(atom_array.coord, nan=0.0) - atom_array.coord = coords - occupancies = atom_array.occupancy.copy() - occupancies[0:3] = 0.0 - atom_array.set_annotation("occupancy", occupancies) - - with pytest.raises(ValueError, match="invalid occupancy"): - RewardInputs.from_atom_array(atom_array, ensemble_size=1) From e813218efa130dbb932743e5360660600490a0ce Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Tue, 17 Mar 2026 15:28:35 -0700 Subject: [PATCH 3/4] Apply suggestions from code review--handle case where all coordinates are NaN in tests. Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tests/rewards/test_reward_inputs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/rewards/test_reward_inputs.py b/tests/rewards/test_reward_inputs.py index a0ac9b41..467e3485 100644 --- a/tests/rewards/test_reward_inputs.py +++ b/tests/rewards/test_reward_inputs.py @@ -27,7 +27,10 @@ def test_cleaned_missing_atoms_accepted(self, atom_array_1vme_with_missing_atoms nan_coord_mask = np.any(np.isnan(aa.coord), axis=-1) if nan_coord_mask.any(): resolved_coords = aa.coord[~nan_coord_mask] - centroid = resolved_coords.mean(axis=0) + if len(resolved_coords) > 0: + centroid = resolved_coords.mean(axis=0) + else: + centroid = np.zeros(3) n_nan = int(nan_coord_mask.sum()) noise = np.random.normal(loc=0.0, scale=1.0, size=(n_nan, 3)).astype(np.float32) new_coords = aa.coord.copy() From 37d182eeda7f3cf420d837621bb9e5581bf606fe Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Tue, 17 Mar 2026 15:32:30 -0700 Subject: [PATCH 4/4] Update RewardInputs documentation to note that it accepts all non-negative occupancies, not just positive --- src/sampleworks/core/rewards/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sampleworks/core/rewards/protocol.py b/src/sampleworks/core/rewards/protocol.py index 0167f907..6b5fd72d 100644 --- a/src/sampleworks/core/rewards/protocol.py +++ b/src/sampleworks/core/rewards/protocol.py @@ -23,7 +23,7 @@ class RewardInputs: once and pass them to scale() methods without redundant extraction. The atom array passed to :meth:`from_atom_array` must already be clean: - all coordinates finite and all occupancies positive. Wrappers are + all coordinates finite and all occupancies non-negative. Wrappers are responsible for ensuring this (e.g. replacing NaN coordinates with noise and setting occupancy to 1.0 for model-operated atoms). """ @@ -44,7 +44,7 @@ def from_atom_array( """Construct RewardInputs from a Biotite AtomArray. The atom array must contain only valid atoms (finite coordinates, - positive occupancy). Callers are responsible for filtering + non-negative occupancy). Callers are responsible for filtering beforehand; no masking is applied here. Parameters