From dcb90397685bfa0e21879f0c9147a7c15c88c0a3 Mon Sep 17 00:00:00 2001 From: Karson Chrispens <33336327+k-chrispens@users.noreply.github.com> Date: Wed, 15 Apr 2026 21:56:50 +0000 Subject: [PATCH 1/5] feat: script for classifying conformational changes --- .gitignore | 5 +- scripts/eval/classify_altloc_regions.py | 434 ++++++++++++++++++++++++ 2 files changed, 438 insertions(+), 1 deletion(-) create mode 100644 scripts/eval/classify_altloc_regions.py diff --git a/.gitignore b/.gitignore index 827be6ce..67a0addd 100644 --- a/.gitignore +++ b/.gitignore @@ -224,4 +224,7 @@ checkpoints/ # Large results from runs grid_search_results/ outputs/ -initial_dataset_40/ +initial_dataset_40*/ +*.tar.gz +*.tgz +*.csv \ No newline at end of file diff --git a/scripts/eval/classify_altloc_regions.py b/scripts/eval/classify_altloc_regions.py new file mode 100644 index 00000000..8dc25977 --- /dev/null +++ b/scripts/eval/classify_altloc_regions.py @@ -0,0 +1,434 @@ +"""Classify altloc regions. + +This script consumes the output of ``scripts/eval/find_altloc_selections.py`` +and classifies each contiguous altloc span into one of four bins: + +1. ``side_chain_only`` : altloc atoms exist in the residue, but none of its + backbone atoms have altlocs. +2. ``small_loop`` : a contiguous backbone altloc span whose mean per-residue + backbone lDDT score (defined below) between altlocs is above + ``--loop-lddt-threshold`` (default 0.75). +3. ``large_loop`` : a contiguous backbone-altloc span whose mean per-residue + backbone lDDT score between altlocs is below ``--loop-lddt-threshold``. +4. ``domain_shift`` : a single contiguous backbone-altloc span longer than + ``--domain-shift-min-span`` residues (default 50). Classified before the + loop lDDT test. 
+ +Score definition (important, slightly different from canonical lDDT): + + For a given pair of altlocs, the score is the **equal-weighted arithmetic + mean** of per residue backbone lDDT scores across the span: + + score = (1 / N_span_residues) * sum_k score_k + + Each ``score_k`` is the standard per-residue local lDDT from + :class:`sampleworks.metrics.lddt.AllAtomLDDT`, which is the fraction of residue + k's neighbor distances (within 15 Å) that are preserved between altlocs across + the four lDDT thresholds (0.5, 1, 2, 4 Å). + + The canonical atom pair weighted lDDT would instead aggregate as + ``sum_k(score_k * n_pairs_k) / sum_k(n_pairs_k)``. This script's + equal residue mean is equivalent to that only when every span residue + has the same neighbor count. The 0.75 default is calibrated for this specific + calculation. + +Altloc pairing: when > 2 altlocs are present, the score above is +computed for every combination of altloc pairs and the span is +classified by the *minimum* score over pair combinations. + +Use ``find_altloc_selections.py --min-span 1`` so that single-residue side-chain-only +selections are included. +""" + +import argparse +import json +import sys +from pathlib import Path + +import numpy as np +import pandas as pd +from loguru import logger +from sampleworks.metrics.lddt import AllAtomLDDT +from sampleworks.utils.atom_array_utils import ( + BACKBONE_ATOM_TYPES, + BLANK_ALTLOC_IDS, + detect_altlocs, + filter_to_common_atoms, + load_structure_with_altlocs, + select_altloc, +) + + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from lddt_evaluation_script import translate_selection + + +# np.isin requires a sequence; BLANK_ALTLOC_IDS is a set. Cache the list form. 
+_BLANK_ALTLOC_ID_LIST = list(BLANK_ALTLOC_IDS) + + +OUTPUT_COLUMNS = [ + "protein", + "selection", + "chain", + "start_res", + "end_res", + "span_length", + "classification", + "worst_pair_mean_backbone_lddt", + "n_backbone_altloc_residues", + "n_altlocs", + "pair_lddts", +] + + +def _resolve_cif_path(row: pd.Series, cif_root: Path | None) -> Path: + """Resolve a CIF path from a row, preferring ``structure`` then ``structure_pattern``. + + When resolving ``structure_pattern`` against ``cif_root``, this tries both + ``{cif_root}/{pattern}`` (flat layout) and ``{cif_root}/{protein}/{pattern}`` + (per-protein subdirectory layout, as used by the initial_dataset processed dir). + """ + if "structure" in row and isinstance(row["structure"], str) and row["structure"]: + p = Path(row["structure"]) + if p.is_absolute() or p.exists(): + return p + if cif_root is not None: + return cif_root / p + return p + + if "structure_pattern" not in row or not row["structure_pattern"]: + raise ValueError(f"Row has neither 'structure' nor 'structure_pattern': {row.to_dict()}") + + pattern = Path(row["structure_pattern"]) + if pattern.is_absolute(): + return pattern + if cif_root is None: + return pattern + + flat = cif_root / pattern + if flat.exists(): + return flat + + protein = row.get("protein", "") + if isinstance(protein, str) and protein: + for candidate in (cif_root / protein / pattern, cif_root / protein.upper() / pattern): + if candidate.exists(): + return candidate + + return flat # fall back to flat so caller's existence check emits the right error + + +def _max_contiguous_run(sorted_res_ids: list[int]) -> int: + """Return the length of the longest contiguous run of integers in a sorted list.""" + if not sorted_res_ids: + return 0 + best = cur = 1 + for prev, r in zip(sorted_res_ids, sorted_res_ids[1:]): + cur = cur + 1 if r == prev + 1 else 1 + if cur > best: + best = cur + return best + + +def _build_pairwise_altloc_arrays( + atom_array, altloc_ids: list[str] +) -> 
dict[tuple[str, str], tuple[object, object]]: + """Return ``{(id_i, id_j): (array_i, array_j)}`` pre-filtered to common atoms. + + For each unordered altloc pair we build the two per-altloc AtomArrays + (via ``select_altloc(return_full_array=True)``, which includes blank-altloc + atoms as shared context) and then run ``filter_to_common_atoms`` so the two + inputs have identical atom order and count. + + We build per-pair so residues whose altloc set is a subset of those in the whole structure + (e.g. 2YL0 res 60–64 carry only altlocs A and B, not C) still get scored for the pairs where + they exist. + """ + pairs: dict[tuple[str, str], tuple[object, object]] = {} + for i in range(len(altloc_ids)): + for j in range(i + 1, len(altloc_ids)): + a_i = select_altloc(atom_array, altloc_ids[i], return_full_array=True) + a_j = select_altloc(atom_array, altloc_ids[j], return_full_array=True) + try: + f_i, f_j = filter_to_common_atoms(a_i, a_j) + except RuntimeError as e: + logger.warning( + f"could not match atoms between altlocs " + f"{altloc_ids[i]} and {altloc_ids[j]}: {e}" + ) + continue + pairs[(altloc_ids[i], altloc_ids[j])] = (f_i, f_j) + return pairs + + +def _mean_residue_lddt_for_pair( + gt_array, + pred_array, + chain: str, + residues: list[int], +) -> float: + """ + Equal weighted arithmetic mean of per residue backbone lDDT across the span. 
+ """ + if len(residues) < 2 or gt_array is None or pred_array is None: + return float("nan") + + res_clause = " or ".join(f"res_id == {r}" for r in residues) + selection = f"chain_id == '{chain}' and ({res_clause}) and atom_name in ['C','CA','N','O']" + try: + result = AllAtomLDDT().compute( + predicted_atom_array_stack=pred_array, + ground_truth_atom_array_stack=gt_array, + selection=selection, + ) + except Exception as e: + logger.warning(f"lDDT compute failed for chain {chain} residues {residues}: {e}") + return float("nan") + + residue_scores = result.get("residue_lddt_scores", {}) + if not residue_scores: + return float("nan") + + flat = [ + v[0] if isinstance(v, (list, tuple, np.ndarray)) else v for v in residue_scores.values() + ] + return float(np.mean(flat)) + + +def _classify_selection( + atom_array, + pair_arrays: dict[tuple[str, str], tuple[object, object]], + altloc_ids: list[str], + selection_str: str, + protein: str, + structure_altloc_mask: np.ndarray, + structure_backbone_mask: np.ndarray, + domain_shift_min_span: int, + loop_lddt_threshold: float, +) -> tuple[dict, set[tuple[str, int]]] | None: + """Classify one contiguous altloc selection. + + Returns ``(row_dict, covered_altloc_residues)`` on success or ``None`` if + the selection could not be applied. ``covered_altloc_residues`` is the set + of (chain_id, res_id) pairs inside the selection that carry any altloc, + used for the caller's residue-coverage invariant check. 
+ """ + aw_selection = translate_selection(selection_str) + + try: + sel_mask = atom_array.mask(aw_selection) + except Exception as e: + logger.error(f"[{protein}] failed to apply selection '{selection_str}': {e}") + return None + + if not sel_mask.any(): + logger.warning(f"[{protein}] selection matched no atoms: {selection_str}") + return None + + sel_res_ids = np.unique(atom_array.res_id[sel_mask]) + sel_chain_ids = np.unique(atom_array.chain_id[sel_mask]) + if len(sel_chain_ids) != 1: + logger.warning( + f"[{protein}] selection spans multiple chains {sel_chain_ids}, using first: " + f"{selection_str}" + ) + chain = str(sel_chain_ids[0]) + + sel_altloc_mask = sel_mask & structure_altloc_mask + covered_altloc_residues: set[tuple[str, int]] = { + (str(c), int(r)) + for c, r in zip(atom_array.chain_id[sel_altloc_mask], atom_array.res_id[sel_altloc_mask]) + } + + backbone_altloc_mask = sel_altloc_mask & structure_backbone_mask + backbone_altloc_res_ids = sorted( + int(r) for r in np.unique(atom_array.res_id[backbone_altloc_mask]) + ) + n_backbone = len(backbone_altloc_res_ids) + + row = { + "protein": protein, + "selection": selection_str, + "chain": chain, + "start_res": int(sel_res_ids.min()), + "end_res": int(sel_res_ids.max()), + "span_length": int(len(sel_res_ids)), + "n_backbone_altloc_residues": n_backbone, + "n_altlocs": len(altloc_ids), + "pair_lddts": json.dumps({}), + "worst_pair_mean_backbone_lddt": float("nan"), + "classification": "", + } + + # Side chain only: no backbone altlocs anywhere in the span. + if n_backbone == 0: + row["classification"] = "side_chain_only" + return row, covered_altloc_residues + + # Domain shift: contiguous backbone-altloc run exceeds threshold (default 50). 
+ if _max_contiguous_run(backbone_altloc_res_ids) > domain_shift_min_span: + row["classification"] = "domain_shift" + return row, covered_altloc_residues + + # Single residue backbone altlocs cannot yield a meaningful lDDT, since you need at least two + # residues for inter-residue distances. A lone backbone-altloc residue is + # a minor local perturbation by definition, which we label as small_loop. + if n_backbone < 2: + row["classification"] = "small_loop" + return row, covered_altloc_residues + + # 3. Loop classification via pairwise lDDT across all altloc pairs + pair_lddts: dict[str, float] = {} + for i in range(len(altloc_ids)): + for j in range(i + 1, len(altloc_ids)): + pair = pair_arrays.get((altloc_ids[i], altloc_ids[j])) + gt, pred = pair if pair is not None else (None, None) + pair_lddts[f"{altloc_ids[i]}-{altloc_ids[j]}"] = _mean_residue_lddt_for_pair( + gt, pred, chain, backbone_altloc_res_ids + ) + row["pair_lddts"] = json.dumps(pair_lddts) + + finite_vals = [v for v in pair_lddts.values() if np.isfinite(v)] + if not finite_vals: + raise RuntimeError( + f"[{protein}] could not compute lDDT for any altloc pair in span " + f"'{selection_str}' (backbone-altloc residues: {backbone_altloc_res_ids}). " + "Refusing to emit an indeterminate classification." 
+ ) + + worst = float(min(finite_vals)) + row["worst_pair_mean_backbone_lddt"] = worst + row["classification"] = "small_loop" if worst > loop_lddt_threshold else "large_loop" + return row, covered_altloc_residues + + +def _process_structure( + row: pd.Series, + cif_root: Path | None, + domain_shift_min_span: int, + loop_lddt_threshold: float, +) -> list[dict]: + protein = str(row["protein"]) + cif_path = _resolve_cif_path(row, cif_root) + if not cif_path.exists(): + logger.error(f"[{protein}] CIF file not found: {cif_path}") + return [] + + selection_field = row.get("selection", "") + if not isinstance(selection_field, str) or not selection_field.strip(): + logger.warning(f"[{protein}] no selections in CSV row for {cif_path}") + return [] + + logger.info(f"[{protein}] loading {cif_path}") + atom_array = load_structure_with_altlocs(cif_path) + altloc_info = detect_altlocs(atom_array) + if len(altloc_info.altloc_ids) < 2: + logger.warning( + f"[{protein}] structure has <2 altloc IDs ({altloc_info.altloc_ids}); skipping" + ) + return [] + + pair_arrays = _build_pairwise_altloc_arrays(atom_array, altloc_info.altloc_ids) + + structure_altloc_mask = ~np.isin(atom_array.altloc_id, _BLANK_ALTLOC_ID_LIST) + structure_backbone_mask = np.isin(atom_array.atom_name, BACKBONE_ATOM_TYPES) + + rows: list[dict] = [] + classified_res_ids: set[tuple[str, int]] = set() + for selection_str in [s.strip() for s in selection_field.split(";") if s.strip()]: + # find_altloc_selections.py appends a combined all altloc selection + # (atomworks-style with " or " clauses) at the end of each row. That one is + # a union over every span we already processed individually, so skip it. 
+ if " or " in selection_str: + continue + out = _classify_selection( + atom_array=atom_array, + pair_arrays=pair_arrays, + altloc_ids=altloc_info.altloc_ids, + selection_str=selection_str, + protein=protein, + structure_altloc_mask=structure_altloc_mask, + structure_backbone_mask=structure_backbone_mask, + domain_shift_min_span=domain_shift_min_span, + loop_lddt_threshold=loop_lddt_threshold, + ) + if out is None: + continue + row, covered = out + rows.append(row) + classified_res_ids.update(covered) + + # residues across all classified spans should equal total unique + # (chain, res_id) pairs that carry any altloc in the structure. + all_altloc_res_ids: set[tuple[str, int]] = { + (str(c), int(r)) + for c, r in zip( + atom_array.chain_id[structure_altloc_mask], + atom_array.res_id[structure_altloc_mask], + ) + } + if classified_res_ids != all_altloc_res_ids: + missing = all_altloc_res_ids - classified_res_ids + extra = classified_res_ids - all_altloc_res_ids + logger.warning( + f"[{protein}] residue coverage invariant not satisfied: " + f"{len(missing)} altloc residues missing from classification, " + f"{len(extra)} classified residues not in full altloc set. " + "This typically means --min-span > 1 was used upstream." 
+ ) + + return rows + + +def main(args: argparse.Namespace) -> None: + input_df = pd.read_csv(args.input_csv) + required = {"protein", "selection"} + missing = required - set(input_df.columns) + if missing: + raise ValueError(f"Input CSV missing required columns: {missing}") + + all_rows: list[dict] = [] + for _, row in input_df.iterrows(): + all_rows.extend( + _process_structure( + row=row, + cif_root=args.cif_root, + domain_shift_min_span=args.domain_shift_min_span, + loop_lddt_threshold=args.loop_lddt_threshold, + ) + ) + + out_df = pd.DataFrame(all_rows, columns=OUTPUT_COLUMNS) + args.output_file.parent.mkdir(parents=True, exist_ok=True) + out_df.to_csv(args.output_file, index=False) + logger.info(f"Wrote {len(out_df)} classified spans to {args.output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Classify altloc regions into side_chain_only / small_loop / " + "large_loop / domain_shift bins. Consumes the CSV produced by " + "find_altloc_selections.py (run with --min-span 1 to include " + "side-chain-only regions)." + ) + ) + parser.add_argument( + "--input-csv", + type=Path, + required=True, + help="Output CSV from find_altloc_selections.py (must contain 'protein' " + "and 'selection'. 
May contain 'structure' or 'structure_pattern').", + ) + parser.add_argument( + "--cif-root", + type=Path, + default=None, + help="Optional root directory to resolve 'structure_pattern' entries against.", + ) + parser.add_argument("--output-file", type=Path, required=True) + parser.add_argument("--domain-shift-min-span", type=int, default=50) + parser.add_argument("--loop-lddt-threshold", type=float, default=0.75) + args = parser.parse_args() + main(args) From 815ae32393d9f4f9f2836dcde6755998685822bc Mon Sep 17 00:00:00 2001 From: Karson Chrispens <33336327+k-chrispens@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:03:58 +0000 Subject: [PATCH 2/5] refactor: move resolve_cif_path to grid_search_eval_utils.py --- scripts/eval/classify_altloc_regions.py | 40 +------------------ .../eval/grid_search_eval_utils.py | 38 ++++++++++++++++++ 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/scripts/eval/classify_altloc_regions.py b/scripts/eval/classify_altloc_regions.py index 8dc25977..120cb71e 100644 --- a/scripts/eval/classify_altloc_regions.py +++ b/scripts/eval/classify_altloc_regions.py @@ -48,6 +48,7 @@ import numpy as np import pandas as pd from loguru import logger +from sampleworks.eval.grid_search_eval_utils import resolve_cif_path from sampleworks.metrics.lddt import AllAtomLDDT from sampleworks.utils.atom_array_utils import ( BACKBONE_ATOM_TYPES, @@ -82,43 +83,6 @@ ] -def _resolve_cif_path(row: pd.Series, cif_root: Path | None) -> Path: - """Resolve a CIF path from a row, preferring ``structure`` then ``structure_pattern``. - - When resolving ``structure_pattern`` against ``cif_root``, this tries both - ``{cif_root}/{pattern}`` (flat layout) and ``{cif_root}/{protein}/{pattern}`` - (per-protein subdirectory layout, as used by the initial_dataset processed dir). 
- """ - if "structure" in row and isinstance(row["structure"], str) and row["structure"]: - p = Path(row["structure"]) - if p.is_absolute() or p.exists(): - return p - if cif_root is not None: - return cif_root / p - return p - - if "structure_pattern" not in row or not row["structure_pattern"]: - raise ValueError(f"Row has neither 'structure' nor 'structure_pattern': {row.to_dict()}") - - pattern = Path(row["structure_pattern"]) - if pattern.is_absolute(): - return pattern - if cif_root is None: - return pattern - - flat = cif_root / pattern - if flat.exists(): - return flat - - protein = row.get("protein", "") - if isinstance(protein, str) and protein: - for candidate in (cif_root / protein / pattern, cif_root / protein.upper() / pattern): - if candidate.exists(): - return candidate - - return flat # fall back to flat so caller's existence check emits the right error - - def _max_contiguous_run(sorted_res_ids: list[int]) -> int: """Return the length of the longest contiguous run of integers in a sorted list.""" if not sorted_res_ids: @@ -310,7 +274,7 @@ def _process_structure( loop_lddt_threshold: float, ) -> list[dict]: protein = str(row["protein"]) - cif_path = _resolve_cif_path(row, cif_root) + cif_path = resolve_cif_path(row, cif_root) if not cif_path.exists(): logger.error(f"[{protein}] CIF file not found: {cif_path}") return [] diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index aa294e50..11df145a 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -9,6 +9,7 @@ from importlib.resources import files from pathlib import Path +import pandas as pd from loguru import logger from sampleworks.eval.constants import OCCUPANCY_LEVELS from sampleworks.eval.eval_dataclasses import ProteinConfig, Trial, TrialList @@ -16,6 +17,43 @@ from sampleworks.utils.guidance_constants import StructurePredictor +def resolve_cif_path(row: pd.Series, cif_root: 
Path | None) -> Path: + """Resolve a CIF path from a row, preferring ``structure`` then ``structure_pattern``. + + When resolving ``structure_pattern`` against ``cif_root``, this tries both + ``{cif_root}/{pattern}`` (flat layout) and ``{cif_root}/{protein}/{pattern}`` + (per-protein subdirectory layout, as used by the initial_dataset processed dir). + """ + if "structure" in row and isinstance(row["structure"], str) and row["structure"]: + p = Path(row["structure"]) + if p.is_absolute() or p.exists(): + return p + if cif_root is not None: + return cif_root / p + return p + + if "structure_pattern" not in row or not row["structure_pattern"]: + raise ValueError(f"Row has neither 'structure' nor 'structure_pattern': {row.to_dict()}") + + pattern = Path(row["structure_pattern"]) + if pattern.is_absolute(): + return pattern + if cif_root is None: + return pattern + + flat = cif_root / pattern + if flat.exists(): + return flat + + protein = row.get("protein", "") + if isinstance(protein, str) and protein: + for candidate in (cif_root / protein / pattern, cif_root / protein.upper() / pattern): + if candidate.exists(): + return candidate + + return flat # fall back to flat so caller's existence check emits the right error + + # TODO: this either (both) needs tests or (and) there needs to be a clearer "API" # for how the folder names are generated. 
# https://github.com/diff-use/sampleworks/issues/121 From 4ecbde32f1a037c4e664cc3c24dcfd30e81924cb Mon Sep 17 00:00:00 2001 From: Karson Chrispens <33336327+k-chrispens@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:16:09 +0000 Subject: [PATCH 3/5] review: addressing bot review fb --- .gitignore | 5 ++++- scripts/eval/classify_altloc_regions.py | 27 +++++++++++++++---------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 67a0addd..2469ff2a 100644 --- a/.gitignore +++ b/.gitignore @@ -227,4 +227,7 @@ outputs/ initial_dataset_40*/ *.tar.gz *.tgz -*.csv \ No newline at end of file +*.csv + +# Reinclude +!src/sampleworks/data/protein_configs.csv \ No newline at end of file diff --git a/scripts/eval/classify_altloc_regions.py b/scripts/eval/classify_altloc_regions.py index 120cb71e..3dca863e 100644 --- a/scripts/eval/classify_altloc_regions.py +++ b/scripts/eval/classify_altloc_regions.py @@ -42,13 +42,16 @@ import argparse import json -import sys from pathlib import Path import numpy as np import pandas as pd from loguru import logger from sampleworks.eval.grid_search_eval_utils import resolve_cif_path +from sampleworks.eval.structure_utils import ( + ATOMWORKS_COMPARISON_OPS, + get_mask_from_old_selection_string, +) from sampleworks.metrics.lddt import AllAtomLDDT from sampleworks.utils.atom_array_utils import ( BACKBONE_ATOM_TYPES, @@ -60,10 +63,6 @@ ) -sys.path.insert(0, str(Path(__file__).resolve().parent)) -from lddt_evaluation_script import translate_selection - - # np.isin requires a sequence; BLANK_ALTLOC_IDS is a set. Cache the list form. _BLANK_ALTLOC_ID_LIST = list(BLANK_ALTLOC_IDS) @@ -178,10 +177,11 @@ def _classify_selection( of (chain_id, res_id) pairs inside the selection that carry any altloc, used for the caller's residue-coverage invariant check. 
""" - aw_selection = translate_selection(selection_str) - try: - sel_mask = atom_array.mask(aw_selection) + if not any(op in selection_str for op in ATOMWORKS_COMPARISON_OPS): + sel_mask = get_mask_from_old_selection_string(atom_array, selection_str) + else: + sel_mask = atom_array.mask(selection_str) except Exception as e: logger.error(f"[{protein}] failed to apply selection '{selection_str}': {e}") return None @@ -193,9 +193,14 @@ def _classify_selection( sel_res_ids = np.unique(atom_array.res_id[sel_mask]) sel_chain_ids = np.unique(atom_array.chain_id[sel_mask]) if len(sel_chain_ids) != 1: - logger.warning( - f"[{protein}] selection spans multiple chains {sel_chain_ids}, using first: " - f"{selection_str}" + # res_ids are per chain, so mixing chains would put residues from + # distinct chains into one list, corrupting backbone_altloc_res_ids, + # _max_contiguous_run, and the lDDT residue selection. + # find_altloc_selections.py emits per chain spans, so callers must split + # multi chain inputs upstream rather than have us guess. + raise RuntimeError( + f"[{protein}] selection '{selection_str}' spans multiple chains " + f"{sel_chain_ids.tolist()}. Split it per-chain upstream." 
) chain = str(sel_chain_ids[0]) From f9be48e8ec7f460f04270aa0dec32d6d78a53155 Mon Sep 17 00:00:00 2001 From: Karson Chrispens <33336327+k-chrispens@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:33:15 +0000 Subject: [PATCH 4/5] review: addressing @marcuscollins fb and bot fb --- scripts/eval/classify_altloc_regions.py | 157 ++++++++++++------ .../eval/grid_search_eval_utils.py | 6 +- 2 files changed, 107 insertions(+), 56 deletions(-) diff --git a/scripts/eval/classify_altloc_regions.py b/scripts/eval/classify_altloc_regions.py index 3dca863e..814d8513 100644 --- a/scripts/eval/classify_altloc_regions.py +++ b/scripts/eval/classify_altloc_regions.py @@ -42,15 +42,18 @@ import argparse import json +import re from pathlib import Path import numpy as np import pandas as pd +from biotite.structure import AtomArray from loguru import logger from sampleworks.eval.grid_search_eval_utils import resolve_cif_path from sampleworks.eval.structure_utils import ( ATOMWORKS_COMPARISON_OPS, get_mask_from_old_selection_string, + parse_selection_string, ) from sampleworks.metrics.lddt import AllAtomLDDT from sampleworks.utils.atom_array_utils import ( @@ -63,8 +66,7 @@ ) -# np.isin requires a sequence; BLANK_ALTLOC_IDS is a set. Cache the list form. 
-_BLANK_ALTLOC_ID_LIST = list(BLANK_ALTLOC_IDS) +_ATOMWORKS_CHAIN_RE = re.compile(r"chain_id\s*==\s*['\"]([^'\"]+)['\"]") OUTPUT_COLUMNS = [ @@ -82,21 +84,36 @@ ] -def _max_contiguous_run(sorted_res_ids: list[int]) -> int: +def _max_contiguous_run(sorted_res_ids: np.ndarray | list[int]) -> int: """Return the length of the longest contiguous run of integers in a sorted list.""" - if not sorted_res_ids: + arr = np.asarray(sorted_res_ids, dtype=int) + if arr.size == 0: return 0 - best = cur = 1 - for prev, r in zip(sorted_res_ids, sorted_res_ids[1:]): - cur = cur + 1 if r == prev + 1 else 1 - if cur > best: - best = cur - return best + breaks = np.concatenate(([0], np.nonzero(np.diff(arr) != 1)[0] + 1, [arr.size])) + return int(np.diff(breaks).max()) + + +def _chain_from_selection(selection: str) -> str | None: + """Extract the chain_id named by a selection string, or None if absent. + + Handles atomworks-style (``chain_id == 'A'``) and the legacy ``chain A`` + syntax accepted by ``parse_selection_string``. + + TODO: deprecate when we move all to atomworks style selections. + """ + m = _ATOMWORKS_CHAIN_RE.search(selection) + if m is not None: + return m.group(1) + if any(op in selection for op in ATOMWORKS_COMPARISON_OPS): + # Atomworks style selection without a chain_id + return None + chain_id, _, _ = parse_selection_string(selection) + return chain_id def _build_pairwise_altloc_arrays( atom_array, altloc_ids: list[str] -) -> dict[tuple[str, str], tuple[object, object]]: +) -> dict[tuple[str, str], tuple[AtomArray, AtomArray]]: """Return ``{(id_i, id_j): (array_i, array_j)}`` pre-filtered to common atoms. For each unordered altloc pair we build the two per-altloc AtomArrays @@ -104,11 +121,18 @@ def _build_pairwise_altloc_arrays( atoms as shared context) and then run ``filter_to_common_atoms`` so the two inputs have identical atom order and count. - We build per-pair so residues whose altloc set is a subset of those in the whole structure - (e.g. 
2YL0 res 60–64 carry only altlocs A and B, not C) still get scored for the pairs where - they exist. + We build per-pair rather than using ``map_altlocs_to_stack`` so residues whose + altloc set is a subset of those in the whole structure (e.g. 2YL0 res 60–64 + carry only altlocs A and B, not C) still get scored for the pairs where they + exist. A stack level ``filter_to_common_atoms`` would drop them entirely. + + TODO: this helper hits the broader issue in how we + handle structures with >2 altlocs. + Fixing that upstream would let us replace this helper + with a direct ``map_altlocs_to_stack`` call and remove a source of + duplication. """ - pairs: dict[tuple[str, str], tuple[object, object]] = {} + pairs: dict[tuple[str, str], tuple[AtomArray, AtomArray]] = {} for i in range(len(altloc_ids)): for j in range(i + 1, len(altloc_ids)): a_i = select_altloc(atom_array, altloc_ids[i], return_full_array=True) @@ -126,15 +150,13 @@ def _build_pairwise_altloc_arrays( def _mean_residue_lddt_for_pair( - gt_array, - pred_array, + gt_array: AtomArray, + pred_array: AtomArray, chain: str, residues: list[int], ) -> float: - """ - Equal weighted arithmetic mean of per residue backbone lDDT across the span. - """ - if len(residues) < 2 or gt_array is None or pred_array is None: + """Equal weighted arithmetic mean of per residue lDDT across the span.""" + if gt_array is None or pred_array is None or not residues: return float("nan") res_clause = " or ".join(f"res_id == {r}" for r in residues) @@ -150,18 +172,22 @@ def _mean_residue_lddt_for_pair( return float("nan") residue_scores = result.get("residue_lddt_scores", {}) - if not residue_scores: + keys = [f"{chain}{r}" for r in residues] + missing = [k for k in keys if k not in residue_scores] + if missing: + logger.warning( + f"lDDT result missing residues {missing} for chain {chain}. 
This means the result" + f"averaged only over the {len(keys) - len(missing)} residues it returned" + ) + flat = [residue_scores[k][0] for k in keys if k in residue_scores] + if not flat: return float("nan") - - flat = [ - v[0] if isinstance(v, (list, tuple, np.ndarray)) else v for v in residue_scores.values() - ] return float(np.mean(flat)) def _classify_selection( - atom_array, - pair_arrays: dict[tuple[str, str], tuple[object, object]], + atom_array: AtomArray, + pair_arrays: dict[tuple[str, str], tuple[AtomArray, AtomArray]], altloc_ids: list[str], selection_str: str, protein: str, @@ -170,19 +196,36 @@ def _classify_selection( domain_shift_min_span: int, loop_lddt_threshold: float, ) -> tuple[dict, set[tuple[str, int]]] | None: - """Classify one contiguous altloc selection. - - Returns ``(row_dict, covered_altloc_residues)`` on success or ``None`` if - the selection could not be applied. ``covered_altloc_residues`` is the set - of (chain_id, res_id) pairs inside the selection that carry any altloc, - used for the caller's residue-coverage invariant check. + """Classify one contiguous altloc selection into a conformational type. + + 1. If the span has no backbone altlocs anywhere, it is classified as ``side_chain_only``. + 2. Else if the longest contiguous backbone altloc run exceeds + ``domain_shift_min_span``, it is classified as ``domain_shift``. + 3. Else compute the per residue backbone lDDT for every altloc pair over + the backbone altloc residues in the span and take the minimum + pair mean. Compare against ``loop_lddt_threshold``: if it is above, it is classified as + ``small_loop``; otherwise it is classified as ``large_loop``. + + Returns ``(row_dict, covered_altloc_residues)`` on success or ``None`` if the + selection could not be applied. 
+ + ``row_dict`` has the keys: + ``protein``, ``selection``, ``chain``, ``start_res``, ``end_res``, + ``span_length``, ``classification``, ``worst_pair_mean_backbone_lddt``, + ``n_backbone_altloc_residues``, ``n_altlocs``, and ``pair_lddts`` (a + JSON encoded ``{pair_label: mean_lddt}`` map so the dict can be loaded + through the CSV intact via ``json.loads``). + + ``covered_altloc_residues`` is the set of ``(chain_id, res_id)`` pairs in the + span that carry any altloc, used for the caller's residue-coverage invariant + check. """ try: if not any(op in selection_str for op in ATOMWORKS_COMPARISON_OPS): sel_mask = get_mask_from_old_selection_string(atom_array, selection_str) else: sel_mask = atom_array.mask(selection_str) - except Exception as e: + except (ValueError, SyntaxError) as e: logger.error(f"[{protein}] failed to apply selection '{selection_str}': {e}") return None @@ -192,17 +235,27 @@ def _classify_selection( sel_res_ids = np.unique(atom_array.res_id[sel_mask]) sel_chain_ids = np.unique(atom_array.chain_id[sel_mask]) - if len(sel_chain_ids) != 1: - # res_ids are per chain, so mixing chains would put residues from - # distinct chains into one list, corrupting backbone_altloc_res_ids, - # _max_contiguous_run, and the lDDT residue selection. - # find_altloc_selections.py emits per chain spans, so callers must split - # multi chain inputs upstream rather than have us guess. - raise RuntimeError( - f"[{protein}] selection '{selection_str}' spans multiple chains " - f"{sel_chain_ids.tolist()}. Split it per-chain upstream." - ) - chain = str(sel_chain_ids[0]) + + # Chain is taken from the selection string. Fall back to the + # mask-matched atoms when the selection has no chain clause. 
+ chain_from_sel = _chain_from_selection(selection_str) + if chain_from_sel is None: + if len(sel_chain_ids) != 1: + logger.warning( + f"{protein} selection '{selection_str}' did not specify a chain and " + f"matched atoms that exist in these chains {sel_chain_ids.tolist()}, skipping" + ) + return None + chain = str(sel_chain_ids[0]) + else: + if not (len(sel_chain_ids) == 1 and str(sel_chain_ids[0]) == chain_from_sel): + logger.warning( + f"{protein} selection '{selection_str}' has chain " + f"'{chain_from_sel}' but mask matched atoms exist in chains " + f"{sel_chain_ids.tolist()} skipping" + ) + return None + chain = chain_from_sel sel_altloc_mask = sel_mask & structure_altloc_mask covered_altloc_residues: set[tuple[str, int]] = { @@ -225,6 +278,7 @@ def _classify_selection( "span_length": int(len(sel_res_ids)), "n_backbone_altloc_residues": n_backbone, "n_altlocs": len(altloc_ids), + # JSON encoded so the pair calculation can be loaded back through the CSV "pair_lddts": json.dumps({}), "worst_pair_mean_backbone_lddt": float("nan"), "classification": "", @@ -240,14 +294,7 @@ def _classify_selection( row["classification"] = "domain_shift" return row, covered_altloc_residues - # Single residue backbone altlocs cannot yield a meaningful lDDT, since you need at least two - # residues for inter-residue distances. A lone backbone-altloc residue is - # a minor local perturbation by definition, which we label as small_loop. - if n_backbone < 2: - row["classification"] = "small_loop" - return row, covered_altloc_residues - - # 3. 
Loop classification via pairwise lDDT across all altloc pairs + # Loop classification via pairwise lDDT across all altloc pairs pair_lddts: dict[str, float] = {} for i in range(len(altloc_ids)): for j in range(i + 1, len(altloc_ids)): @@ -300,7 +347,7 @@ def _process_structure( pair_arrays = _build_pairwise_altloc_arrays(atom_array, altloc_info.altloc_ids) - structure_altloc_mask = ~np.isin(atom_array.altloc_id, _BLANK_ALTLOC_ID_LIST) + structure_altloc_mask = ~np.isin(atom_array.altloc_id, list(BLANK_ALTLOC_IDS)) structure_backbone_mask = np.isin(atom_array.atom_name, BACKBONE_ATOM_TYPES) rows: list[dict] = [] diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index 11df145a..e234a9ef 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -32,7 +32,11 @@ def resolve_cif_path(row: pd.Series, cif_root: Path | None) -> Path: return cif_root / p return p - if "structure_pattern" not in row or not row["structure_pattern"]: + if ( + "structure_pattern" not in row + or pd.isna(row["structure_pattern"]) + or not row["structure_pattern"] + ): raise ValueError(f"Row has neither 'structure' nor 'structure_pattern': {row.to_dict()}") pattern = Path(row["structure_pattern"]) From da69e98a21215d9b0cc1900aee42b40be3cb301e Mon Sep 17 00:00:00 2001 From: Karson Chrispens <33336327+k-chrispens@users.noreply.github.com> Date: Tue, 21 Apr 2026 22:06:26 +0000 Subject: [PATCH 5/5] docs: add grid_search_eval_utils.py resolve_cif_path NumPy style docstring --- .../eval/grid_search_eval_utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index e234a9ef..1bd32fd8 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -20,6 +20,25 @@ def resolve_cif_path(row: pd.Series, cif_root: Path | None) 
-> Path: """Resolve a CIF path from a row, preferring ``structure`` then ``structure_pattern``. + Parameters + ---------- + row : pd.Series + Row containing a ``structure`` and/or ``structure_pattern`` field. + cif_root : Path | None + Root directory used to resolve relative paths. + + Returns + ------- + Path + The resolved CIF path. + + Raises + ------ + ValueError + If the row has neither ``structure`` nor ``structure_pattern``. + + Notes + ----- When resolving ``structure_pattern`` against ``cif_root``, this tries both ``{cif_root}/{pattern}`` (flat layout) and ``{cif_root}/{protein}/{pattern}`` (per-protein subdirectory layout, as used by the initial_dataset processed dir).