Skip to content

Commit f07cca8

Browse files
committed
refactor: extract shared structure loading for synthetic reward generation into synthetic_utils
1 parent 561ffa7 commit f07cca8

5 files changed

Lines changed: 250 additions & 183 deletions

File tree

src/sampleworks/eval/generate_synthetic_density.py

Lines changed: 12 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,11 @@
77
from typing import ClassVar
88

99
import torch
10-
from atomworks.io.transforms.atom_array import remove_waters
1110
from joblib import delayed, Parallel
1211
from loguru import logger
1312
from sampleworks.core.forward_models.xray.real_space_density import XMap_torch
14-
from sampleworks.eval.structure_utils import apply_selection
15-
from sampleworks.utils.atom_array_utils import (
16-
assign_occupancies,
17-
detect_altlocs,
18-
keep_amino_acids,
19-
keep_polymer,
20-
load_structure_with_altlocs,
21-
remove_hydrogens,
22-
save_structure_to_cif,
23-
)
13+
from sampleworks.eval.synthetic_utils import load_structure_for_synthetic_reward
14+
from sampleworks.utils.atom_array_utils import save_structure_to_cif
2415
from sampleworks.utils.density_utils import compute_density_from_atomarray
2516
from sampleworks.utils.torch_utils import try_gpu
2617

@@ -187,54 +178,18 @@ def _process_single_row(
187178
If True, save the processed structure to a CIF file in the input directory. Default is True.
188179
"""
189180
structure_path = base_dir / row.filename
190-
if not structure_path.exists():
191-
logger.error(f"Structure not found: {structure_path}")
192-
return
193-
194-
try:
195-
atom_array = load_structure_with_altlocs(structure_path)
196-
except Exception as e:
197-
logger.error(
198-
f"Failed to load {row.filename} ({type(e).__name__}): {e}\n"
199-
f"{''.join(traceback.format_tb(e.__traceback__))}"
200-
)
201-
return
202-
203-
try:
204-
atom_array = apply_selection(atom_array, row.selection)
205-
except ValueError as e:
206-
logger.error(f"Selection error for {row.filename}: {e}")
181+
atom_array = load_structure_for_synthetic_reward(
182+
structure_path,
183+
occ_mode=occ_mode,
184+
occ_values=row.occ_values,
185+
strip_hydrogens=strip_hydrogens,
186+
strip_waters=strip_waters,
187+
strip_ligands=strip_ligands,
188+
selection=row.selection,
189+
)
190+
if atom_array is None:
207191
return
208192

209-
atom_array = remove_hydrogens(atom_array) if strip_hydrogens else atom_array
210-
atom_array = remove_waters(atom_array) if strip_waters else atom_array
211-
# This is currently a sort of hacky way to remove ligands by keeping only polymer atoms
212-
# TODO: there's probably a more robust way to do this
213-
atom_array = keep_polymer(keep_amino_acids(atom_array)) if strip_ligands else atom_array
214-
215-
altloc_info = detect_altlocs(atom_array) # ty: ignore[invalid-argument-type]
216-
if row.occ_values:
217-
if occ_mode != "custom":
218-
logger.warning(
219-
f"Custom occupancy values provided for {row.filename}, "
220-
f"but occ_mode is '{occ_mode}'. Using 'custom' mode."
221-
)
222-
occ_mode = "custom"
223-
try:
224-
atom_array = assign_occupancies(atom_array, altloc_info, "custom", row.occ_values)
225-
except ValueError as e:
226-
logger.error(f"Occupancy assignment error for {row.filename}: {e}")
227-
raise
228-
elif occ_mode in {"uniform", "default"}:
229-
try:
230-
atom_array = assign_occupancies(atom_array, altloc_info, occ_mode)
231-
except ValueError as e:
232-
logger.error(f"Occupancy assignment error for {row.filename}: {e}")
233-
raise
234-
else:
235-
logger.error(f"Invalid occupancy mode '{occ_mode}' for {row.filename}")
236-
raise ValueError(f"Invalid occupancy mode '{occ_mode}'")
237-
238193
try:
239194
density, xmap_torch = compute_density_from_atomarray(
240195
atom_array, resolution=resolution, em_mode=em_mode, device=device

src/sampleworks/eval/generate_synthetic_sf.py

Lines changed: 56 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,10 @@
1717
import numpy as np
1818
import reciprocalspaceship as rs
1919
import torch
20-
from atomworks.io.transforms.atom_array import remove_waters
2120
from biotite.structure import AtomArray
2221
from loguru import logger
23-
from sampleworks.utils.atom_array_utils import (
24-
assign_occupancies,
25-
BLANK_ALTLOC_IDS,
26-
detect_altlocs,
27-
keep_amino_acids,
28-
keep_polymer,
29-
load_structure_with_altlocs,
30-
remove_hydrogens,
31-
)
22+
from sampleworks.eval.synthetic_utils import load_structure_for_synthetic_reward
23+
from sampleworks.utils.atom_array_utils import BLANK_ALTLOC_IDS
3224
from sampleworks.utils.torch_utils import try_gpu
3325
from SFC_Torch import SFcalculator
3426
from SFC_Torch.io import array2hier, PDBParser
@@ -49,6 +41,8 @@ class BatchRow:
4941
space_group
5042
Optional space group (in Hermann-Mauguin string format) to override the
5143
one in the structure file.
44+
selection
45+
Optional atom selection string in PyMOL-like syntax (e.g., 'chain A and resi 10-50')
5246
occ_values
5347
Custom list of occupancy values for altlocs, must be in range [0.0, 1.0]
5448
"""
@@ -59,6 +53,7 @@ class BatchRow:
5953
mtzfile: str | None = None
6054
unit_cell: gemmi.UnitCell | None = None
6155
space_group: str | None = None
56+
selection: str | None = None
6257
occ_values: list[float] = field(default_factory=list)
6358

6459
def __post_init__(self) -> None:
@@ -108,6 +103,7 @@ def from_dict(cls, row: dict[str, Any]) -> "BatchRow":
108103
mtzfile=row.get("mtzfile") or None,
109104
unit_cell=unit_cell,
110105
space_group=space_group,
106+
selection=row.get("selection") or None,
111107
occ_values=occ_values,
112108
)
113109

@@ -193,8 +189,9 @@ def write_amplitudes_to_mtz(
193189
is False, which uses Phenix convention (1 = test, 0 = working).
194190
sigf_scale: float
195191
Scale factor to make a fake SIGFP column from FP values so that
196-
SFcalculator can load the output MTZ file without errors. Default
197-
is 0.2. The actual SIGFP values only matter when computing R-factor.
192+
SFcalculator can load the output MTZ file as a synthetic reward
193+
without errors. The actual SIGFP values only matter when
194+
computing R-factor.
198195
"""
199196
output_path.parent.mkdir(parents=True, exist_ok=True)
200197
dataset = sfc.prepare_dataset(hkl_attr, f_attr)
@@ -226,6 +223,7 @@ def _process_single_row(
226223
strip_waters: bool = False,
227224
strip_ligands: bool = False,
228225
simulate_solvent_and_scale: bool = False,
226+
save_structure: bool = False,
229227
) -> None:
230228
"""Compute synthetic protein structure factors for a single structure.
231229
Assume no anomalous scattering.
@@ -261,49 +259,24 @@ def _process_single_row(
261259
simulate_solvent_and_scale
262260
If True, compute bulk solvent and scale factors for Ftotal instead of Fprotein.
263261
Default is False.
262+
save_structure
263+
If True, save the processed structure (after selection and occupancy assignment)
264+
as mmCIF to output_dir. Unit cell and space group are preserved. Default is False.
264265
"""
265266
structure_path = base_dir / row.filename
266-
if not structure_path.exists():
267-
logger.error(f"Structure not found: {structure_path}")
267+
atom_array = load_structure_for_synthetic_reward(
268+
structure_path,
269+
occ_mode=occ_mode,
270+
occ_values=row.occ_values,
271+
strip_hydrogens=strip_hydrogens,
272+
strip_waters=strip_waters,
273+
strip_ligands=strip_ligands,
274+
selection=row.selection,
275+
)
276+
if atom_array is None:
268277
return
269278

270-
# Load structure and strip off unwanted atoms
271-
try:
272-
atom_array = load_structure_with_altlocs(structure_path)
273-
except Exception as e:
274-
logger.error(
275-
f"Failed to load {row.filename} ({type(e).__name__}): {e}\n"
276-
f"{''.join(traceback.format_tb(e.__traceback__))}"
277-
)
278-
return
279-
atom_array = remove_hydrogens(atom_array) if strip_hydrogens else atom_array
280-
atom_array = remove_waters(atom_array) if strip_waters else atom_array
281-
atom_array = keep_polymer(keep_amino_acids(atom_array)) if strip_ligands else atom_array
282-
283-
# Altloc detection and occupancy assignment (reused from density script)
284-
altloc_info = detect_altlocs(atom_array) # ty: ignore[invalid-argument-type]
285-
if row.occ_values:
286-
if occ_mode != "custom":
287-
logger.warning(
288-
f"Custom occupancy values provided for {row.filename}, "
289-
f"but occ_mode is '{occ_mode}'. Using 'custom' mode."
290-
)
291-
try:
292-
atom_array = assign_occupancies(atom_array, altloc_info, "custom", row.occ_values)
293-
except ValueError as e:
294-
logger.error(f"Occupancy assignment error for {row.filename}: {e}")
295-
raise
296-
elif occ_mode in {"uniform", "default"}:
297-
try:
298-
atom_array = assign_occupancies(atom_array, altloc_info, occ_mode)
299-
except ValueError as e:
300-
logger.error(f"Occupancy assignment error for {row.filename}: {e}")
301-
raise
302-
else:
303-
logger.error(f"Invalid occupancy mode '{occ_mode}' for {row.filename}")
304-
raise ValueError(f"Invalid occupancy mode '{occ_mode}'")
305-
306-
# Convert to gemmi and initialize SFcalculator
279+
# Convert to gemmi for SFcalculator
307280
try:
308281
unit_cell = row.unit_cell
309282
space_group = row.space_group
@@ -321,6 +294,18 @@ def _process_single_row(
321294
)
322295
return
323296

297+
if save_structure:
298+
structure_output_path = output_dir / f"{structure_path.stem}_sf_input.cif"
299+
structure_output_path.parent.mkdir(parents=True, exist_ok=True)
300+
try:
301+
gemmi_structure.make_mmcif_document().write_file(str(structure_output_path))
302+
logger.info(f"Saved processed structure to {structure_output_path}")
303+
except Exception as e:
304+
logger.error(
305+
f"Failed to save structure for {row.filename} ({type(e).__name__}): {e}\n"
306+
f"{''.join(traceback.format_tb(e.__traceback__))}"
307+
)
308+
324309
# Compute structure factors
325310
try:
326311
sfc = SFcalculator(
@@ -375,7 +360,7 @@ def load_batch_csv(csv_path: Path) -> list[BatchRow]:
375360
----------
376361
csv_path
377362
Path to CSV file with columns: filename (required), mtzfile, unit_cell,
378-
space_group, occ_values (all optional)
363+
space_group, selection, occ_values (all optional)
379364
380365
Returns
381366
-------
@@ -412,6 +397,7 @@ def process_batch(
412397
strip_waters: bool = False,
413398
strip_ligands: bool = False,
414399
simulate_solvent_and_scale: bool = False,
400+
save_structure: bool = False,
415401
) -> None:
416402
"""Process multiple structures from a CSV file in batch mode.
417403
@@ -445,6 +431,8 @@ def process_batch(
445431
If True, keep only polymer amino-acid atoms (removes ligands and waters).
446432
simulate_solvent_and_scale
447433
If True, compute bulk solvent and scale factors in addition to F_protein.
434+
save_structure
435+
If True, save each processed structure as mmCIF to output_dir.
448436
"""
449437
from joblib import delayed, Parallel
450438

@@ -466,6 +454,7 @@ def process_batch(
466454
strip_waters=strip_waters,
467455
strip_ligands=strip_ligands,
468456
simulate_solvent_and_scale=simulate_solvent_and_scale,
457+
save_structure=save_structure,
469458
)
470459
for row in rows
471460
)
@@ -488,6 +477,13 @@ def parse_args() -> argparse.Namespace:
488477
help="Base directory for relative paths in CSV, not used in single-structure mode",
489478
)
490479

480+
selection_group = parser.add_argument_group("Selection Options")
481+
selection_group.add_argument(
482+
"--selection",
483+
type=str,
484+
help="Atom selection (e.g., 'chain A and resi 10-50' or 'chain A and resi 10')",
485+
)
486+
491487
occ_group = parser.add_argument_group("Occupancy Options")
492488
occ_group.add_argument(
493489
"--occ-mode",
@@ -559,6 +555,11 @@ def parse_args() -> argparse.Namespace:
559555
)
560556

561557
output_group = parser.add_argument_group("Output Options")
558+
output_group.add_argument(
559+
"--save-structure",
560+
action="store_true",
561+
help="Save the processed structure (after selection, occupancy assignment) to CIF",
562+
)
562563
output_group.add_argument("--output", "-o", type=Path, help="Output MTZ file path")
563564
output_group.add_argument(
564565
"--output-dir", type=Path, default=Path("."), help="Output directory for batch mode"
@@ -595,6 +596,7 @@ def main() -> None:
595596
strip_waters=args.remove_waters,
596597
strip_ligands=args.remove_ligands,
597598
simulate_solvent_and_scale=args.simulate_solvent_and_scale,
599+
save_structure=args.save_structure,
598600
)
599601
elif args.structure:
600602
row = BatchRow.from_dict(
@@ -603,6 +605,7 @@ def main() -> None:
603605
"mtzfile": args.output.name if args.output else None,
604606
"unit_cell": args.unit_cell,
605607
"space_group": args.space_group,
608+
"selection": args.selection,
606609
"occ_values": args.occ_values,
607610
}
608611
)
@@ -620,6 +623,7 @@ def main() -> None:
620623
strip_waters=args.remove_waters,
621624
strip_ligands=args.remove_ligands,
622625
simulate_solvent_and_scale=args.simulate_solvent_and_scale,
626+
save_structure=args.save_structure,
623627
)
624628
else:
625629
logger.error("Please specify --structure or --batch-csv")

0 commit comments

Comments
 (0)