Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 18 additions & 27 deletions scripts/eval/rscc_grid_search_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@
import torch

# Import local modules for density calculation
from atomworks.io.parser import parse
from biotite.structure import AtomArrayStack
from biotite.structure import AtomArray
from loguru import logger
from sampleworks.core.forward_models.xray.real_space_density_deps.qfit.volume import XMap
from sampleworks.eval.constants import DEFAULT_SELECTION_PADDING
from sampleworks.eval.grid_search_eval_utils import parse_eval_args, setup_evaluation_parameters
from sampleworks.eval.metrics import rscc
from sampleworks.eval.structure_utils import (
get_asym_unit_from_structure,
get_reference_structure_coords,
)
from sampleworks.utils.atom_array_utils import (
filter_to_common_atoms,
load_structure_with_altlocs,
remove_atoms_with_any_nan_coords,
)
from sampleworks.utils.density_utils import compute_density_from_atomarray
Expand Down Expand Up @@ -71,7 +70,7 @@ def main(args: argparse.Namespace):

results = []
base_map_cache: dict[tuple[str, tuple[tuple[str, float], ...], str], tuple[XMap, XMap]] = {}
ref_full_structure_cache: dict[tuple[str, tuple[tuple[str, float], ...]], AtomArrayStack] = {}
ref_full_structure_cache: dict[tuple[str, tuple[tuple[str, float], ...]], AtomArray] = {}
# TODO parallelize this loop? It uses GPU, so be careful.
for i, trial in enumerate(all_trials):
prev_result_count = len(results)
Expand Down Expand Up @@ -107,8 +106,6 @@ def main(args: argparse.Namespace):
trial_result["selection"] = selection
trial_result["error"] = None
try:
# TODO: this needs to be better unified with what's in generate_synthetic_density
#
# Load base map for canonical unit cell,
# don't overwrite the base map with selection map--we'll use the full map later too.
if (protein, trial.occ_key, selection) not in base_map_cache:
Expand Down Expand Up @@ -136,19 +133,11 @@ def main(args: argparse.Namespace):
if extracted_base is None or extracted_base.shape[0] == 0:
raise ValueError(f"Extracted base map from {base_map_path} is empty")

# Load refined structure
structure = parse(trial.refined_cif_path, ccd_mirror_path=None)

# Compute density from refined structure
atom_array = get_asym_unit_from_structure(structure)
if not hasattr(atom_array, "coord") or atom_array.coord is None:
raise AttributeError("AtomArray | AtomArrayStack is missing coordinates")

if not hasattr(atom_array, "b_factor"):
logger.warning(
f"No b-factor array found in {trial.refined_cif_path}, setting to 20."
)
atom_array.set_annotation("b_factor", np.full(atom_array.coord.shape[-2], 20.0))
# Load refined structure twice:
# - all altlocs (occupancy-weighted) for density computation
# - highest-occupancy altloc for alignment (no duplicate atoms)
atom_array_all_altlocs = load_structure_with_altlocs(trial.refined_cif_path)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW these shouldn't have altlocs (the current structure predictors produce structures with no altlocs, although I suppose this could change). I'm curious why you chose this path?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this could be a problem, because load_structure_with_altlocs currently returns only the first model in the CIF/PDB file. Later on, we need to handle multiple models (each of our output CIF files contains the entire ensemble, but its members are represented as models, not as altlocs).

As part of the PR, could you include a test that makes sure this code works correctly? I can point you to files if you need example input. But make sure that the new version produces the same output as the old version.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I thought the predicted ensemble would be represented as altlocs. Can you point me to files that currently represent the ensemble predictions?

atom_array = load_structure_with_altlocs(trial.refined_cif_path, altloc="occupancy")

# Lines ~183-245 are to align the refined structure to the reference structure.
# so that the calculated maps are also aligned, for a correct RSCC calculation
Expand All @@ -163,9 +152,9 @@ def main(args: argparse.Namespace):
f"occupancy {trial.altloc_occupancies}"
)

# 2. Load the reference structure with parse() to get only the first altloc
ref_structure = parse(ref_path, ccd_mirror_path=None)
ref_atom_array = get_asym_unit_from_structure(ref_structure)
# 2. Load the reference structure with highest-occupancy altloc
# (used only for alignment, so no need for all altlocs)
ref_atom_array = load_structure_with_altlocs(ref_path, altloc="occupancy")
logger.info(
f"Caching reference structure for {protein} "
f"altloc_occupancies={trial.altloc_occupancies}"
Expand All @@ -175,9 +164,11 @@ def main(args: argparse.Namespace):
ref_atom_array = ref_full_structure_cache[(protein, trial.occ_key)]

# 3. Find the common atoms with non-nan coords between the reference
# and the refined structure
# and the refined structure for alignment. Both use the highest
# occupancy altloc for alignment (no duplicate atoms).
ref_atom_array = remove_atoms_with_any_nan_coords(ref_atom_array)
atom_array = remove_atoms_with_any_nan_coords(atom_array)
atom_array_all_altlocs = remove_atoms_with_any_nan_coords(atom_array_all_altlocs)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
ref_common, pred_common = filter_to_common_atoms(ref_atom_array, atom_array)

# 4. Align the refined structure to the reference
Expand Down Expand Up @@ -214,16 +205,16 @@ def main(args: argparse.Namespace):
allow_gradients=False,
)

# 5. Apply the transform to the entire refined structure (atom_array)
atom_array_coords_torch = torch.from_numpy(atom_array.coord)
# 5. Apply the transform to the entire refined structure (atom_array_all_altlocs)
atom_array_coords_torch = torch.from_numpy(atom_array_all_altlocs.coord)
aligned_coords_torch = apply_forward_transform(
atom_array_coords_torch, transform, rotation_only=False
)
atom_array.coord = aligned_coords_torch.numpy()
atom_array_all_altlocs.coord = aligned_coords_torch.numpy()

# Compute density from the aligned refined structure
computed_density, _ = compute_density_from_atomarray(
atom_array, xmap=base_xmap, em_mode=False, device=device
atom_array_all_altlocs, xmap=base_xmap, em_mode=False, device=device
)

# Create an XMap from the computed density by copying the base xmap
Expand Down
17 changes: 13 additions & 4 deletions src/sampleworks/utils/atom_array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,32 @@ class AltlocInfo:
atom_masks: dict[str, np.ndarray[Any, np.dtype[np.bool_]]]


def load_structure_with_altlocs(path: Path) -> AtomArray:
"""Load a structure file with alternate conformations and occupancy data.
def load_structure_with_altlocs(
path: Path, altloc: Literal["all", "occupancy", "first"] = "all"
) -> AtomArray:
"""Load a structure file with occupancy and B-factor data.

Takes the first model if multiple models are present.

Parameters
----------
path
Path to the structure file (PDB, mmCIF, etc.)
altloc
How to handle alternate conformations (passed directly to biotite):

- ``"all"`` — keep every altloc as a separate atom (default); use for
density computation where all conformers should contribute.
- ``"occupancy"`` — keep the highest-occupancy altloc per residue; use
for tasks requiring a single unambiguous conformation (e.g. alignment).
- ``"first"`` — keep the first altloc ID appearing in each residue.

Returns
-------
AtomArray
Loaded structure with occupancy and B-factor data
"""
# Currently, we need to specify extra_fields=["occupancy"] to load altlocs properly
atom_array = load_any(path, altloc="all", extra_fields=["occupancy", "b_factor"])
atom_array = load_any(path, altloc=altloc, extra_fields=["occupancy", "b_factor"])
if isinstance(atom_array, AtomArrayStack):
atom_array = cast(AtomArray, atom_array[0])
return atom_array
Expand Down
Loading