From cca24a77d644c98a7249f682fae0746104a7ac9e Mon Sep 17 00:00:00 2001 From: Doris Mai Date: Sat, 28 Mar 2026 23:17:41 +0000 Subject: [PATCH] refactor(eval): unify rscc script structure loading with generate_synthetic_density.py --- pixi.lock | 4 +- scripts/eval/rscc_grid_search_script.py | 45 +++++++++-------------- src/sampleworks/utils/atom_array_utils.py | 17 +++++++-- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/pixi.lock b/pixi.lock index 12ed24ce..601f87f6 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9719,8 +9719,8 @@ packages: timestamp: 1753407970803 - pypi: ./ name: sampleworks - version: 0.4.0 - sha256: 5db03aab50df2b70618c97837dfeb5cda94af7877f98bbc922f7438fafc86e77 + version: 0.4.1 + sha256: 71a8650ab9088e0cd195579bf80abdf5ed4fbc90c35af2bace27c123238380eb requires_dist: - atomworks[ml]==2.1.1 - python-dotenv diff --git a/scripts/eval/rscc_grid_search_script.py b/scripts/eval/rscc_grid_search_script.py index 78bbca61..bdfc751c 100644 --- a/scripts/eval/rscc_grid_search_script.py +++ b/scripts/eval/rscc_grid_search_script.py @@ -23,19 +23,18 @@ import torch # Import local modules for density calculation -from atomworks.io.parser import parse -from biotite.structure import AtomArrayStack +from biotite.structure import AtomArray from loguru import logger from sampleworks.core.forward_models.xray.real_space_density_deps.qfit.volume import XMap from sampleworks.eval.constants import DEFAULT_SELECTION_PADDING from sampleworks.eval.grid_search_eval_utils import parse_eval_args, setup_evaluation_parameters from sampleworks.eval.metrics import rscc from sampleworks.eval.structure_utils import ( - get_asym_unit_from_structure, get_reference_structure_coords, ) from sampleworks.utils.atom_array_utils import ( filter_to_common_atoms, + load_structure_with_altlocs, remove_atoms_with_any_nan_coords, ) from sampleworks.utils.density_utils import compute_density_from_atomarray @@ -71,7 +70,7 @@ def main(args: argparse.Namespace): results = [] base_map_cache: dict[tuple[str, tuple[tuple[str, float], ...], str], tuple[XMap, XMap]] = {} - ref_full_structure_cache: dict[tuple[str, tuple[tuple[str, float], ...]], AtomArrayStack] = {} + ref_full_structure_cache: dict[tuple[str, tuple[tuple[str, float], ...]], AtomArray] = {} # TODO parallelize this loop? It uses GPU, so be careful. for i, trial in enumerate(all_trials): prev_result_count = len(results) @@ -107,8 +106,6 @@ def main(args: argparse.Namespace): trial_result["selection"] = selection trial_result["error"] = None try: - # TODO: this needs to be better unified with what's in generate_synthetic_density - # # Load base map for canonical unit cell, # don't overwrite the base map with selection map--we'll use the full map later too. if (protein, trial.occ_key, selection) not in base_map_cache: @@ -136,19 +133,11 @@ def main(args: argparse.Namespace): if extracted_base is None or extracted_base.shape[0] == 0: raise ValueError(f"Extracted base map from {base_map_path} is empty") - # Load refined structure - structure = parse(trial.refined_cif_path, ccd_mirror_path=None) - - # Compute density from refined structure - atom_array = get_asym_unit_from_structure(structure) - if not hasattr(atom_array, "coord") or atom_array.coord is None: - raise AttributeError("AtomArray | AtomArrayStack is missing coordinates") - - if not hasattr(atom_array, "b_factor"): - logger.warning( - f"No b-factor array found in {trial.refined_cif_path}, setting to 20." - ) - atom_array.set_annotation("b_factor", np.full(atom_array.coord.shape[-2], 20.0)) + # Load refined structure twice: + # - all altlocs (occupancy-weighted) for density computation + # - highest-occupancy altloc for alignment (no duplicate atoms) + atom_array_all_altlocs = load_structure_with_altlocs(trial.refined_cif_path) + atom_array = load_structure_with_altlocs(trial.refined_cif_path, altloc="occupancy") # Lines ~183-245 are to align the refined structure to the reference structure. # so that the calculated maps are also aligned, for a correct RSCC calculation @@ -163,9 +152,9 @@ def main(args: argparse.Namespace): f"occupancy {trial.altloc_occupancies}" ) - # 2. Load the reference structure with parse() to get only the first altloc - ref_structure = parse(ref_path, ccd_mirror_path=None) - ref_atom_array = get_asym_unit_from_structure(ref_structure) + # 2. Load the reference structure with highest-occupancy altloc + # (used only for alignment, so no need for all altlocs) + ref_atom_array = load_structure_with_altlocs(ref_path, altloc="occupancy") logger.info( f"Caching reference structure for {protein} " f"altloc_occupancies={trial.altloc_occupancies}" @@ -175,9 +164,11 @@ def main(args: argparse.Namespace): ref_atom_array = ref_full_structure_cache[(protein, trial.occ_key)] # 3. Find the common atoms with non-nan coords between the reference - # and the refined structure + # and the refined structure for alignment. Both use the highest + # occupancy altloc for alignment (no duplicate atoms). ref_atom_array = remove_atoms_with_any_nan_coords(ref_atom_array) atom_array = remove_atoms_with_any_nan_coords(atom_array) + atom_array_all_altlocs = remove_atoms_with_any_nan_coords(atom_array_all_altlocs) ref_common, pred_common = filter_to_common_atoms(ref_atom_array, atom_array) # 4. Align the refined structure to the reference @@ -214,16 +205,16 @@ def main(args: argparse.Namespace): allow_gradients=False, ) - # 5. Apply the transform to the entire refined structure (atom_array) - atom_array_coords_torch = torch.from_numpy(atom_array.coord) + # 5. Apply the transform to the entire refined structure (atom_array_all_altlocs) + atom_array_coords_torch = torch.from_numpy(atom_array_all_altlocs.coord) aligned_coords_torch = apply_forward_transform( atom_array_coords_torch, transform, rotation_only=False ) - atom_array.coord = aligned_coords_torch.numpy() + atom_array_all_altlocs.coord = aligned_coords_torch.numpy() # Compute density from the aligned refined structure computed_density, _ = compute_density_from_atomarray( - atom_array, xmap=base_xmap, em_mode=False, device=device + atom_array_all_altlocs, xmap=base_xmap, em_mode=False, device=device ) # Create an XMap from the computed density by copying the base xmap diff --git a/src/sampleworks/utils/atom_array_utils.py b/src/sampleworks/utils/atom_array_utils.py index 7d87cb94..af3afc55 100644 --- a/src/sampleworks/utils/atom_array_utils.py +++ b/src/sampleworks/utils/atom_array_utils.py @@ -33,8 +33,10 @@ class AltlocInfo: atom_masks: dict[str, np.ndarray[Any, np.dtype[np.bool_]]] -def load_structure_with_altlocs(path: Path) -> AtomArray: - """Load a structure file with alternate conformations and occupancy data. +def load_structure_with_altlocs( + path: Path, altloc: Literal["all", "occupancy", "first"] = "all" +) -> AtomArray: + """Load a structure file with occupancy and B-factor data. Takes the first model if multiple models are present. @@ -42,14 +44,21 @@ def load_structure_with_altlocs(path: Path) -> AtomArray: ---------- path Path to the structure file (PDB, mmCIF, etc.) + altloc + How to handle alternate conformations (passed directly to biotite): + + - ``"all"`` — keep every altloc as a separate atom (default); use for + density computation where all conformers should contribute. + - ``"occupancy"`` — keep the highest-occupancy altloc per residue; use + for tasks requiring a single unambiguous conformation (e.g. alignment). + - ``"first"`` — keep the first altloc ID appearing in each residue. Returns ------- AtomArray Loaded structure with occupancy and B-factor data """ - # Currently, we need to specify extra_fields=["occupancy"] to load altlocs properly - atom_array = load_any(path, altloc="all", extra_fields=["occupancy", "b_factor"]) + atom_array = load_any(path, altloc=altloc, extra_fields=["occupancy", "b_factor"]) if isinstance(atom_array, AtomArrayStack): atom_array = cast(AtomArray, atom_array[0]) return atom_array