Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/sampleworks/core/rewards/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class RewardInputs:
once and pass them to scale() methods without redundant extraction.

The atom array passed to :meth:`from_atom_array` must already be clean:
all coordinates finite and all occupancies positive. Wrappers are
all coordinates finite and all occupancies non-negative. Wrappers are
responsible for ensuring this (e.g. replacing NaN coordinates with
noise and setting occupancy to 1.0 for model-operated atoms).
"""
Expand All @@ -44,7 +44,7 @@ def from_atom_array(
"""Construct RewardInputs from a Biotite AtomArray.

The atom array must contain only valid atoms (finite coordinates,
positive occupancy). Callers are responsible for filtering
non-negative occupancy). Callers are responsible for filtering
beforehand; no masking is applied here.

Parameters
Expand All @@ -69,9 +69,14 @@ def from_atom_array(
raise ValueError("Atom array must have 'element' annotation.")
if not hasattr(atom_array, "b_factor"):
raise ValueError("Atom array must have 'b_factor' annotation.")
if np.any(np.isnan(atom_array.b_factor)):
raise ValueError(
"Atom array contains NaN B-factors. Wrappers must replace NaN "
"B-factors before constructing RewardInputs (e.g., with a default of 20.0)."
)
if np.any(np.isnan(atom_array.coord)):
raise ValueError("Atom array contains NaN coordinates.")
if np.any((atom_array.occupancy <= 0) | (atom_array.occupancy > 1)):
if np.any((atom_array.occupancy < 0) | (atom_array.occupancy > 1)):
Comment thread
marcuscollins marked this conversation as resolved.
raise ValueError("Atom array contains invalid occupancy values.")
Comment thread
coderabbitai[bot] marked this conversation as resolved.

elements_list = elements_to_scattering_indices(atom_array.element)
Expand All @@ -97,6 +102,7 @@ def from_atom_array(
p=num_particles,
e=ensemble_size,
)
# TODO: eventually this should be configurable
occupancies = torch.ones_like(b_factors) / ensemble_size
input_coords = einx.rearrange(
"... -> b ...",
Expand Down
12 changes: 11 additions & 1 deletion src/sampleworks/models/protenix/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtenixConditionin
updated_json_path = json_path.with_name(f"{json_path.stem}-add-msa.json")
if not updated_json_path.exists():
# get all the required sequences from the json_dict
sequence_data = {
sequence_data: dict[str | int, str] = {
n: seq_data["proteinChain"]["sequence"]
for n, seq_data in enumerate(json_dict["sequences"])
if "proteinChain" in seq_data
Expand Down Expand Up @@ -476,6 +476,16 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtenixConditionin
model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32))
if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None:
model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32))
else:
nan_b_mask = np.isnan(model_aa.b_factor)
if nan_b_mask.any():
b_factors = model_aa.b_factor.copy()
b_factors[nan_b_mask] = 20.0
model_aa.set_annotation("b_factor", b_factors)
logger.info(
f"Replaced {int(nan_b_mask.sum())} NaN B-factors with default 20.0 "
f"(from unresolved atoms added by add_missing_atoms)"
)

num_atoms_protenix = len(atom_array_protenix)
conditioning = ProtenixConditioning(
Expand Down
10 changes: 10 additions & 0 deletions src/sampleworks/models/rf3/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,16 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]:
model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32))
if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None:
model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32))
else:
nan_b_mask = np.isnan(model_aa.b_factor)
if nan_b_mask.any():
b_factors = model_aa.b_factor.copy()
b_factors[nan_b_mask] = 20.0
model_aa.set_annotation("b_factor", b_factors)
logger.info(
f"Replaced {int(nan_b_mask.sum())} NaN B-factors with default 20.0 "
f"(from unresolved atoms added by add_missing_atoms)"
)

conditioning = RF3Conditioning(
s_inputs=pairformer_out["s_inputs"],
Expand Down
21 changes: 19 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
if str(_project_root) not in sys.path:
sys.path.insert(0, str(_project_root))

from atomworks.io.parser import parse
from atomworks.io.parser import parse, parse_atom_array
from atomworks.io.utils.io_utils import load_any
from biotite.structure import AtomArray, AtomArrayStack, stack
from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig
Expand Down Expand Up @@ -420,6 +420,23 @@ def structure_1vme(resources_dir: Path) -> dict:
return parse(resources_dir / "1vme" / "1vme_final.cif", ccd_mirror_path=None)


@pytest.fixture(scope="session")
def atom_array_1vme_with_missing_atoms(structure_1vme) -> AtomArray:
"""1VME atom array after parse_atom_array with add_missing_atoms=True.

This reproduces what RF3's InferenceInput.from_atom_array does internally.
"""
parsed = parse_atom_array(
structure_1vme["asym_unit"],
add_missing_atoms=True,
hydrogen_policy="keep",
)
aa = parsed["asym_unit"]
if isinstance(aa, AtomArrayStack):
aa = aa[0]
return aa


@pytest.fixture(scope="session")
def structure_6b8x(resources_dir: Path) -> dict:
return parse(resources_dir / "6b8x" / "6b8x_final.pdb", ccd_mirror_path=None)
Expand Down Expand Up @@ -1068,4 +1085,4 @@ def perturbed_coords(
torch.manual_seed(42)
base = converging_mock_wrapper.target
perturbation = torch.randn_like(base) * 0.1 # ty: ignore[invalid-argument-type]
return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator]
return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator]
64 changes: 64 additions & 0 deletions tests/rewards/test_reward_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Tests for RewardInputs validation.

Verifies that RewardInputs.from_atom_array rejects atom arrays with NaN
B-factors, NaN coordinates, and invalid occupancies.
"""

import numpy as np
import pytest
from sampleworks.core.rewards.protocol import RewardInputs


class TestRewardInputsFromAtomArray:
"""Validate that RewardInputs.from_atom_array rejects invalid atom arrays."""

def test_nan_b_factors_from_missing_atoms_rejected(self, atom_array_1vme_with_missing_atoms):
"""
RewardInputs rejects the raw atom array from add_missing_atoms.
"""
with pytest.raises(ValueError, match="NaN B-factors"):
RewardInputs.from_atom_array(atom_array_1vme_with_missing_atoms, ensemble_size=1)

def test_cleaned_missing_atoms_accepted(self, atom_array_1vme_with_missing_atoms):
"""After applying the same fixes as the RF3 wrapper, the atom array passes."""
Comment thread
k-chrispens marked this conversation as resolved.
aa = atom_array_1vme_with_missing_atoms.copy()

# Fix NaN coordinates
nan_coord_mask = np.any(np.isnan(aa.coord), axis=-1)
if nan_coord_mask.any():
resolved_coords = aa.coord[~nan_coord_mask]
if len(resolved_coords) > 0:
centroid = resolved_coords.mean(axis=0)
else:
centroid = np.zeros(3)
n_nan = int(nan_coord_mask.sum())
Comment thread
coderabbitai[bot] marked this conversation as resolved.
noise = np.random.normal(loc=0.0, scale=1.0, size=(n_nan, 3)).astype(np.float32)
Comment thread
marcuscollins marked this conversation as resolved.
new_coords = aa.coord.copy()
new_coords[nan_coord_mask] = centroid + noise
aa.coord = new_coords

# Fix occupancy
aa.set_annotation("occupancy", np.ones(len(aa), dtype=np.float32))

# Fix NaN b_factors
nan_b_mask = np.isnan(aa.b_factor)
if nan_b_mask.any():
b_factors = aa.b_factor.copy()
b_factors[nan_b_mask] = 20.0
aa.set_annotation("b_factor", b_factors)

reward_inputs = RewardInputs.from_atom_array(aa, ensemble_size=1)
assert reward_inputs.b_factors.shape[-1] == len(aa)

def test_nan_coordinates_rejected(self, structure_1vme):
"""NaN coordinates must be caught before constructing reward tensors."""
atom_array = structure_1vme["asym_unit"].copy()
# Fix any pre-existing NaN b_factors so the coordinate check passes
b_factors = np.nan_to_num(atom_array.b_factor, nan=20.0)
atom_array.set_annotation("b_factor", b_factors)
coords = atom_array.coord.copy()
coords[..., 3, :] = np.nan
atom_array.coord = coords
Comment thread
k-chrispens marked this conversation as resolved.

with pytest.raises(ValueError, match="NaN coordinates"):
RewardInputs.from_atom_array(atom_array, ensemble_size=1)
Loading