Skip to content

Commit 33dc325

Browse files
committed
feat(ciffiles): Write Sampleworks metadata into CIF files produced by grid search and patching script; resolves #208
1 parent 98997f8 commit 33dc325

6 files changed

Lines changed: 66 additions & 18 deletions

File tree

scripts/patch_output_cif_files.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# utility script to put all header information from original PDB entry into our CIF files
22
import fnmatch
3+
import json
34
import re
45
from argparse import ArgumentParser
56
from pathlib import Path
@@ -12,6 +13,7 @@
1213
from biotite.structure.io.pdbx import CIFColumn, CIFFile, set_structure
1314
from loguru import logger
1415
from sampleworks.utils.atom_array_utils import remove_atoms_with_any_nan_coords
16+
from sampleworks.utils.cif_utils import add_category_to_cif
1517

1618

1719
SAMPLEWORKS_CACHE = Path("~/.sampleworks/rcsb").expanduser()
@@ -73,7 +75,11 @@ def parse_args():
7375
)
7476
parser.add_argument("--grid-search-input-dir", required=True)
7577
parser.add_argument(
76-
"--input-pdb-pattern", default="{pdb_id}/{pdb_id}_single_001_density_input.cif"
78+
"--input-pdb-pattern",
79+
default="{pdb_id}/{pdb_id}_single_001_density_input.cif",
80+
help="Pattern used by fnmatch/glob for input pdb files. The complete path of the input "
81+
"pdb must match f'{grid-search-input-dir}/{input-pdb-pattern}'. Defaults to "
82+
"'{pdb_id}/{pdb_id}_single_001_density_input.cif'",
7783
)
7884
args = parser.parse_args()
7985
return args
@@ -159,9 +165,23 @@ def patch_individual_cif_file(
159165
atom_keys = list(zip(asym_unit.chain_id.tolist(), asym_unit.res_id.tolist()))
160166
asym_unit.res_id = np.array([mapping[k] for k in atom_keys], dtype=asym_unit.res_id.dtype)
161167

162-
# load the actual PDB, we'll copy the new coordinates to it.
168+
# load the actual PDB, we'll copy the new coordinates and metadata into it.
163169
template = CIFFile.read(rcsb_path)
164170

171+
# Write sampleworks trial metadata to the CIF file, if we can find it
172+
cif_data = CIFFile.read(cif_path)
173+
if "sampleworks" in cif_data.block:
174+
template.block["sampleworks"] = cif_data.block["sampleworks"]
175+
elif (metadata_path := cif_path.parent / "job_metadata.json").exists():
176+
with open(metadata_path, "r") as fp:
177+
metadata = json.load(fp)
178+
if metadata is not None:
179+
add_category_to_cif(template, metadata, "sampleworks")
180+
else:
181+
logger.warning(f"Sampleworks metadata file at {metadata_path} is empty")
182+
else:
183+
logger.warning(f"No sampleworks metadata found for {cif_path}")
184+
165185
# remove any atoms with NaN coordinates -- these seem to come in because we sometimes use parse
166186
# (from AtomWorks) which creates them. Still, we'll do this here just in case.
167187
asym_unit = remove_atoms_with_any_nan_coords(asym_unit)

src/sampleworks/utils/cif_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,7 @@ def add_category_to_cif(
297297
raise ValueError("CIFFile has no blocks. Cannot add category.")
298298
elif len(blocks) > 1:
299299
raise ValueError(
300-
f"CIFFile has multiple blocks: {blocks}. "
301-
"Please specify block_name parameter."
300+
f"CIFFile has multiple blocks: {blocks}. Please specify block_name parameter."
302301
)
303302
block = ciffile[blocks[0]]
304303
else:

src/sampleworks/utils/guidance_script_utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
NoiseSpaceDPSScaler,
3131
NoScalingScaler,
3232
)
33-
from sampleworks.utils.cif_utils import resolve_mixed_hetatm_atom_altlocs
33+
from sampleworks.utils.cif_utils import add_category_to_cif, resolve_mixed_hetatm_atom_altlocs
3434
from sampleworks.utils.guidance_constants import (
3535
GuidanceType,
3636
StructurePredictor,
@@ -265,7 +265,7 @@ def get_reward_function_and_structure(
265265

266266

267267
def save_everything(
268-
output_dir: str | Path,
268+
args: GuidanceConfig | argparse.Namespace,
269269
losses: list[Any],
270270
refined_structure: dict,
271271
traj_denoised: list[Any],
@@ -283,8 +283,10 @@ def save_everything(
283283
284284
Parameters
285285
----------
286-
output_dir : str | Path
287-
Directory to write all output files into. Created if it doesn't exist.
286+
args : GuidanceConfig | argparse.Namespace
287+
The arguments for the guidance run. This function directly uses args.output_dir,
288+
and creates that directory if it does not exist. The result of args.as_dict() is
289+
written to a JSON file in the same directory, and inserted into the output CIF file.
288290
losses : list[Any]
289291
Per-step loss values (may contain ``None`` entries for unguided steps).
290292
refined_structure : dict
@@ -304,7 +306,7 @@ def save_everything(
304306
Optional model-space atom template. When provided (mismatch runs),
305307
this template is used for final structure and trajectory saving.
306308
"""
307-
output_dir = Path(output_dir)
309+
output_dir = Path(args.output_dir)
308310
output_dir.mkdir(parents=True, exist_ok=True)
309311

310312
logger.info("Saving results")
@@ -327,10 +329,20 @@ def save_everything(
327329
else:
328330
atom_array = base_atom_array
329331

332+
metadata = args.as_dict()
333+
330334
final_structure = CIFFile()
331335
set_structure(final_structure, atom_array)
336+
add_category_to_cif(final_structure, metadata, category_name="sampleworks")
332337
final_structure.write(str(output_dir / "refined.cif"))
333338

339+
# write out the job parameters to a JSON file in the same directory as the refined.cif file
340+
# Even though this is technically duplicated, keep it around as a backup in case metadata
341+
# is lost in some CIF transform.
342+
with open(Path(output_dir) / "job_metadata.json", "w") as fp:
343+
# use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects
344+
json.dump(metadata, fp)
345+
334346
# Two calls to save_trajectory, very similar, but saving different trajectories!
335347
save_trajectory(
336348
scaler_type,
@@ -565,7 +577,7 @@ def _run_guidance(
565577
model_atom_array = result.metadata.get("model_atom_array") if result.metadata else None
566578

567579
save_everything(
568-
args.output_dir,
580+
args,
569581
losses,
570582
refined_structure,
571583
traj_denoised,
@@ -641,10 +653,6 @@ def run_guidance_job_queue(job_queue_path: str) -> list[JobResult]:
641653
logger.info(f"Running job {i + 1}/{len(job_queue)}: {job}")
642654

643655
job_result = run_guidance(job, job.guidance_type, model_wrapper, device)
644-
# write out the job parameters to a JSON file in the same directory as the refined.cif file
645-
with open(Path(job_result.output_dir) / "job_metadata.json", "w") as fp:
646-
# use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects
647-
json.dump(job.as_dict(), fp)
648656

649657
job_results.append(job_result)
650658
torch.cuda.empty_cache() # just in case

tests/integration/test_mismatch_integration.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sampleworks.utils.atom_array_utils import make_normalized_atom_id
2222
from sampleworks.utils.atom_reconciler import AtomReconciler
2323
from sampleworks.utils.frame_transforms import apply_forward_transform
24+
from sampleworks.utils.guidance_script_arguments import GuidanceConfig
2425
from sampleworks.utils.guidance_script_utils import save_everything
2526

2627
from tests.mocks import MismatchCase, MismatchCaseWrapper
@@ -1011,8 +1012,18 @@ def test_save_with_model_template(self, tmp_path: Path):
10111012
refined = {"asym_unit": build_test_atom_array(n_atoms=n_struct)}
10121013
model_atom_array = build_test_atom_array(n_atoms=n_model, with_occupancy=False)
10131014

1015+
args = GuidanceConfig(
1016+
protein="1l63",
1017+
structure=Path("dummy"),
1018+
density=Path("dummy"),
1019+
model="boltz2",
1020+
guidance_type="pure_guidance",
1021+
log_path="dummy",
1022+
output_dir=str(tmp_path),
1023+
)
1024+
10141025
save_everything(
1015-
output_dir=tmp_path,
1026+
args,
10161027
losses=[0.5, 0.3],
10171028
refined_structure=refined,
10181029
traj_denoised=[],

tests/utils/test_cif_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
"""Tests for cif_utils module."""
22

33
import logging
4-
import tempfile
54
from pathlib import Path
65

76
import numpy as np
87
import pytest
98
from atomworks.io.utils.io_utils import load_any
109
from biotite.structure import array, Atom, AtomArray, AtomArrayStack
11-
from biotite.structure.io.pdbx.cif import CIFFile, CIFColumn
10+
from biotite.structure.io.pdbx.cif import CIFColumn, CIFFile
1211
from sampleworks.utils.atom_array_utils import save_structure_to_cif
1312
from sampleworks.utils.cif_utils import add_category_to_cif, resolve_mixed_hetatm_atom_altlocs
1413

tests/utils/test_guidance_script_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44

55
import torch
6+
from sampleworks.utils.guidance_script_arguments import GuidanceConfig
67
from sampleworks.utils.guidance_script_utils import save_everything
78

89
from tests.utils.atom_array_builders import build_test_atom_array
@@ -15,8 +16,18 @@ def test_save_everything_uses_model_atom_array_for_mismatch(tmp_path: Path):
1516

1617
final_state = torch.zeros((1, 5, 3), dtype=torch.float32)
1718

19+
args = GuidanceConfig(
20+
protein="1l63",
21+
structure=Path("dummy"),
22+
density=Path("dummy"),
23+
model="boltz2",
24+
guidance_type="pure_guidance",
25+
log_path="dummy",
26+
output_dir=str(tmp_path),
27+
)
28+
1829
save_everything(
19-
output_dir=tmp_path,
30+
args,
2031
losses=[],
2132
refined_structure=refined_structure,
2233
traj_denoised=[],

0 commit comments

Comments
 (0)