1717import numpy as np
1818import reciprocalspaceship as rs
1919import torch
20- from atomworks .io .transforms .atom_array import remove_waters
2120from biotite .structure import AtomArray
2221from loguru import logger
23- from sampleworks .utils .atom_array_utils import (
24- assign_occupancies ,
25- BLANK_ALTLOC_IDS ,
26- detect_altlocs ,
27- keep_amino_acids ,
28- keep_polymer ,
29- load_structure_with_altlocs ,
30- remove_hydrogens ,
31- )
22+ from sampleworks .eval .synthetic_utils import load_structure_for_synthetic_reward
23+ from sampleworks .utils .atom_array_utils import BLANK_ALTLOC_IDS
3224from sampleworks .utils .torch_utils import try_gpu
3325from SFC_Torch import SFcalculator
3426from SFC_Torch .io import array2hier , PDBParser
@@ -49,6 +41,8 @@ class BatchRow:
4941 space_group
5042 Optional space group (in Hermann-Mauguin string format) to override the
5143 one in the structure file.
44+ selection
45+ Optional atom selection string in PyMOL-like syntax (e.g., 'chain A and resi 10-50')
5246 occ_values
5347 Custom list of occupancy values for altlocs, must be in range [0.0, 1.0]
5448 """
@@ -59,6 +53,7 @@ class BatchRow:
5953 mtzfile : str | None = None
6054 unit_cell : gemmi .UnitCell | None = None
6155 space_group : str | None = None
56+ selection : str | None = None
6257 occ_values : list [float ] = field (default_factory = list )
6358
6459 def __post_init__ (self ) -> None :
@@ -108,6 +103,7 @@ def from_dict(cls, row: dict[str, Any]) -> "BatchRow":
108103 mtzfile = row .get ("mtzfile" ) or None ,
109104 unit_cell = unit_cell ,
110105 space_group = space_group ,
106+ selection = row .get ("selection" ) or None ,
111107 occ_values = occ_values ,
112108 )
113109
@@ -193,8 +189,9 @@ def write_amplitudes_to_mtz(
193189 is False, which uses Phenix convention (1 = test, 0 = working).
194190 sigf_scale: float
195191 Scale factor to make a fake SIGFP column from FP values so that
196- SFcalculator can load the output MTZ file without errors. Default
197- is 0.2. The actual SIGFP values only matter when computing R-factor.
192+ SFcalculator can load the output MTZ file (used as a synthetic
193+ reward) without errors. The actual SIGFP values only matter when
194+ computing the R-factor.
198195 """
199196 output_path .parent .mkdir (parents = True , exist_ok = True )
200197 dataset = sfc .prepare_dataset (hkl_attr , f_attr )
@@ -226,6 +223,7 @@ def _process_single_row(
226223 strip_waters : bool = False ,
227224 strip_ligands : bool = False ,
228225 simulate_solvent_and_scale : bool = False ,
226+ save_structure : bool = False ,
229227) -> None :
230228 """Compute synthetic protein structure factors for a single structure.
231229 Assume no anomalous scattering.
@@ -261,49 +259,24 @@ def _process_single_row(
261259 simulate_solvent_and_scale
262260 If True, compute bulk solvent and scale factors for Ftotal instead of Fprotein.
263261 Default is False.
262+ save_structure
263+ If True, save the processed structure (after selection and occupancy assignment)
264+ as mmCIF to output_dir. Unit cell and space group are preserved. Default is False.
264265 """
265266 structure_path = base_dir / row .filename
266- if not structure_path .exists ():
267- logger .error (f"Structure not found: { structure_path } " )
267+ atom_array = load_structure_for_synthetic_reward (
268+ structure_path ,
269+ occ_mode = occ_mode ,
270+ occ_values = row .occ_values ,
271+ strip_hydrogens = strip_hydrogens ,
272+ strip_waters = strip_waters ,
273+ strip_ligands = strip_ligands ,
274+ selection = row .selection ,
275+ )
276+ if atom_array is None :
268277 return
269278
270- # Load structure and strip off unwanted atoms
271- try :
272- atom_array = load_structure_with_altlocs (structure_path )
273- except Exception as e :
274- logger .error (
275- f"Failed to load { row .filename } ({ type (e ).__name__ } ): { e } \n "
276- f"{ '' .join (traceback .format_tb (e .__traceback__ ))} "
277- )
278- return
279- atom_array = remove_hydrogens (atom_array ) if strip_hydrogens else atom_array
280- atom_array = remove_waters (atom_array ) if strip_waters else atom_array
281- atom_array = keep_polymer (keep_amino_acids (atom_array )) if strip_ligands else atom_array
282-
283- # Altloc detection and occupancy assignment (reused from density script)
284- altloc_info = detect_altlocs (atom_array ) # ty: ignore[invalid-argument-type]
285- if row .occ_values :
286- if occ_mode != "custom" :
287- logger .warning (
288- f"Custom occupancy values provided for { row .filename } , "
289- f"but occ_mode is '{ occ_mode } '. Using 'custom' mode."
290- )
291- try :
292- atom_array = assign_occupancies (atom_array , altloc_info , "custom" , row .occ_values )
293- except ValueError as e :
294- logger .error (f"Occupancy assignment error for { row .filename } : { e } " )
295- raise
296- elif occ_mode in {"uniform" , "default" }:
297- try :
298- atom_array = assign_occupancies (atom_array , altloc_info , occ_mode )
299- except ValueError as e :
300- logger .error (f"Occupancy assignment error for { row .filename } : { e } " )
301- raise
302- else :
303- logger .error (f"Invalid occupancy mode '{ occ_mode } ' for { row .filename } " )
304- raise ValueError (f"Invalid occupancy mode '{ occ_mode } '" )
305-
306- # Convert to gemmi and initialize SFcalculator
279+ # Convert to gemmi for SFcalculator
307280 try :
308281 unit_cell = row .unit_cell
309282 space_group = row .space_group
@@ -321,6 +294,18 @@ def _process_single_row(
321294 )
322295 return
323296
297+ if save_structure :
298+ structure_output_path = output_dir / f"{ structure_path .stem } _sf_input.cif"
299+ structure_output_path .parent .mkdir (parents = True , exist_ok = True )
300+ try :
301+ gemmi_structure .make_mmcif_document ().write_file (str (structure_output_path ))
302+ logger .info (f"Saved processed structure to { structure_output_path } " )
303+ except Exception as e :
304+ logger .error (
305+ f"Failed to save structure for { row .filename } ({ type (e ).__name__ } ): { e } \n "
306+ f"{ '' .join (traceback .format_tb (e .__traceback__ ))} "
307+ )
308+
324309 # Compute structure factors
325310 try :
326311 sfc = SFcalculator (
@@ -375,7 +360,7 @@ def load_batch_csv(csv_path: Path) -> list[BatchRow]:
375360 ----------
376361 csv_path
377362 Path to CSV file with columns: filename (required), mtzfile, unit_cell,
378- space_group, occ_values (all optional)
363+ space_group, selection, occ_values (all optional)
379364
380365 Returns
381366 -------
@@ -412,6 +397,7 @@ def process_batch(
412397 strip_waters : bool = False ,
413398 strip_ligands : bool = False ,
414399 simulate_solvent_and_scale : bool = False ,
400+ save_structure : bool = False ,
415401) -> None :
416402 """Process multiple structures from a CSV file in batch mode.
417403
@@ -445,6 +431,8 @@ def process_batch(
445431 If True, keep only polymer amino-acid atoms (removes ligands and waters).
446432 simulate_solvent_and_scale
447433 If True, compute bulk solvent and scale factors in addition to F_protein.
434+ save_structure
435+ If True, save each processed structure as mmCIF to output_dir.
448436 """
449437 from joblib import delayed , Parallel
450438
@@ -466,6 +454,7 @@ def process_batch(
466454 strip_waters = strip_waters ,
467455 strip_ligands = strip_ligands ,
468456 simulate_solvent_and_scale = simulate_solvent_and_scale ,
457+ save_structure = save_structure ,
469458 )
470459 for row in rows
471460 )
@@ -488,6 +477,13 @@ def parse_args() -> argparse.Namespace:
488477 help = "Base directory for relative paths in CSV, not used in single-structure mode" ,
489478 )
490479
480+ selection_group = parser .add_argument_group ("Selection Options" )
481+ selection_group .add_argument (
482+ "--selection" ,
483+ type = str ,
484+ help = "Atom selection (e.g., 'chain A and resi 10-50' or 'chain A and resi 10')" ,
485+ )
486+
491487 occ_group = parser .add_argument_group ("Occupancy Options" )
492488 occ_group .add_argument (
493489 "--occ-mode" ,
@@ -559,6 +555,11 @@ def parse_args() -> argparse.Namespace:
559555 )
560556
561557 output_group = parser .add_argument_group ("Output Options" )
558+ output_group .add_argument (
559+ "--save-structure" ,
560+ action = "store_true" ,
561+ help = "Save the processed structure (after selection, occupancy assignment) to CIF" ,
562+ )
562563 output_group .add_argument ("--output" , "-o" , type = Path , help = "Output MTZ file path" )
563564 output_group .add_argument (
564565 "--output-dir" , type = Path , default = Path ("." ), help = "Output directory for batch mode"
@@ -595,6 +596,7 @@ def main() -> None:
595596 strip_waters = args .remove_waters ,
596597 strip_ligands = args .remove_ligands ,
597598 simulate_solvent_and_scale = args .simulate_solvent_and_scale ,
599+ save_structure = args .save_structure ,
598600 )
599601 elif args .structure :
600602 row = BatchRow .from_dict (
@@ -603,6 +605,7 @@ def main() -> None:
603605 "mtzfile" : args .output .name if args .output else None ,
604606 "unit_cell" : args .unit_cell ,
605607 "space_group" : args .space_group ,
608+ "selection" : args .selection ,
606609 "occ_values" : args .occ_values ,
607610 }
608611 )
@@ -620,6 +623,7 @@ def main() -> None:
620623 strip_waters = args .remove_waters ,
621624 strip_ligands = args .remove_ligands ,
622625 simulate_solvent_and_scale = args .simulate_solvent_and_scale ,
626+ save_structure = args .save_structure ,
623627 )
624628 else :
625629 logger .error ("Please specify --structure or --batch-csv" )
0 commit comments