diff --git a/scripts/patch_output_cif_files.py b/scripts/patch_output_cif_files.py index 0d38620..4ffe6a9 100644 --- a/scripts/patch_output_cif_files.py +++ b/scripts/patch_output_cif_files.py @@ -1,5 +1,6 @@ # utility script to put all header information from original PDB entry into our CIF files import fnmatch +import json import re from argparse import ArgumentParser from pathlib import Path @@ -12,6 +13,7 @@ from biotite.structure.io.pdbx import CIFColumn, CIFFile, set_structure from loguru import logger from sampleworks.utils.atom_array_utils import remove_atoms_with_any_nan_coords +from sampleworks.utils.cif_utils import add_category_to_cif SAMPLEWORKS_CACHE = Path("~/.sampleworks/rcsb").expanduser() @@ -73,7 +75,11 @@ def parse_args(): ) parser.add_argument("--grid-search-input-dir", required=True) parser.add_argument( - "--input-pdb-pattern", default="{pdb_id}/{pdb_id}_single_001_density_input.cif" + "--input-pdb-pattern", + default="{pdb_id}/{pdb_id}_single_001_density_input.cif", + help="Pattern used by fnmatch/glob for input pdb files. The complete path of the input " + "pdb must match f'{grid-search-input-dir}/{input-pdb-pattern}'. Defaults to " + "'{pdb_id}/{pdb_id}_single_001_density_input.cif'", ) args = parser.parse_args() return args @@ -159,9 +165,23 @@ def patch_individual_cif_file( atom_keys = list(zip(asym_unit.chain_id.tolist(), asym_unit.res_id.tolist())) asym_unit.res_id = np.array([mapping[k] for k in atom_keys], dtype=asym_unit.res_id.dtype) - # load the actual PDB, we'll copy the new coordinates to it. + # load the actual PDB, we'll copy the new coordinates and metadata into it. template = CIFFile.read(rcsb_path) + # Write sampleworks trial metadata to the CIF file, if we can find it + cif_data = CIFFile.read(cif_path) + if "sampleworks" in cif_data.block: + template.block["sampleworks"] = cif_data.block["sampleworks"] + elif (metadata_path := cif_path.parent / "job_metadata.json").exists(): + with open(metadata_path, "r") as fp: + metadata = json.load(fp) + if metadata is not None: + add_category_to_cif(template, metadata, "sampleworks") + else: + logger.warning(f"Sampleworks metadata file at {metadata_path} is empty") + else: + logger.warning(f"No sampleworks metadata found for {cif_path}") + # remove any atoms with nan coordinates--these seem to come in because we sometimes use parse # (from AtomWorks) which creates them. Still, we'll do this here just in case. asym_unit = remove_atoms_with_any_nan_coords(asym_unit) diff --git a/src/sampleworks/utils/cif_utils.py b/src/sampleworks/utils/cif_utils.py index 26b0109..048b5f9 100644 --- a/src/sampleworks/utils/cif_utils.py +++ b/src/sampleworks/utils/cif_utils.py @@ -3,10 +3,12 @@ from collections import OrderedDict from collections.abc import Iterable from pathlib import Path +from typing import Any import numpy as np from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArrayStack +from biotite.structure.io.pdbx.cif import CIFCategory, CIFFile from loguru import logger from sampleworks.utils.atom_array_utils import ( @@ -235,3 +237,88 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: save_structure_to_cif(fixed_array, tmp_path) logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}") return tmp_path + + +def add_category_to_cif( + ciffile: CIFFile, + data: dict[str, Any], + category_name: str, + overwrite: bool = False, + block_name: str | None = None, +) -> None: + """Add a custom category in-place to a CIFFile. + + Parameters + ---------- + ciffile : CIFFile + The CIF file object to modify. + data : dict[str, Any] + Dictionary with column names as keys and column data as values. + category_name : str + Name of the category to add (e.g., "custom_data"). + overwrite : bool, optional + If False and the category already exists, raise RuntimeError. Default is False. + block_name : str | None, optional + Name of the block to add the category to. If None, check that there is only + one block and add to that block. Default is None. + + Raises + ------ + RuntimeError + If category already exists and overwrite is False. + ValueError + If block_name is None but the file has multiple blocks, or if the specified + block_name does not exist. + + Examples + -------- + >>> from biotite.structure.io.pdbx.cif import CIFFile + >>> ciffile = CIFFile.read("example.cif") # assuming it contains a single block + >>> data = {"id": [1, 2, 3], "value": ["a", "b", "c"]} + >>> add_category_to_cif(ciffile, data, "my_custom_data") + >>> print(ciffile.block["my_custom_data"].serialize()) + loop_ + _my_custom_data.id + _my_custom_data.value + 1 a + 2 b + 3 c + >>> data = {"sampleworks_version": "0.4.0", "pdb_id": "1L63"} + >>> add_category_to_cif(ciffile, data, "sampleworks_metadata") + >>> print(ciffile.block["sampleworks_metadata"].serialize()) + _sampleworks_metadata.sampleworks_version 0.4.0 + _sampleworks_metadata.pdb_id 1L63 + """ + # Determine which block to use + if block_name is None: + # CIFFile is a Mapping, so inherits .keys(), which ultimately iterates over blocks + blocks = list(ciffile.keys()) + if len(blocks) == 0: + raise ValueError("CIFFile has no blocks. Cannot add category.") + elif len(blocks) > 1: + raise ValueError( + f"CIFFile has multiple blocks: {blocks}. Please specify block_name parameter." + ) + block = ciffile[blocks[0]] + else: + if block_name not in ciffile: + raise ValueError(f"Block '{block_name}' not found in CIFFile.") + block = ciffile[block_name] + + # Check if a category with name category_name already exists + if category_name in block and not overwrite: + raise RuntimeError( + f"Category '{category_name}' already exists in block with value: {block[category_name]}" + ) + + # Create and add the category--remove any None values, CIF requires non-null values + category = CIFCategory( + columns={k: _normalize_nulls(v) for k, v in data.items()}, name=category_name + ) + block[category_name] = category + + +def _normalize_nulls(value: Any) -> Any: + if isinstance(value, Iterable) and not isinstance(value, str | bytes): + return ["?" if item is None else item for item in value] + return "?" if value is None else value diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index 0262b08..d8b2246 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -1,6 +1,5 @@ from __future__ import annotations -import argparse import json import os import pickle @@ -30,7 +29,7 @@ NoiseSpaceDPSScaler, NoScalingScaler, ) -from sampleworks.utils.cif_utils import resolve_mixed_hetatm_atom_altlocs +from sampleworks.utils.cif_utils import add_category_to_cif, resolve_mixed_hetatm_atom_altlocs from sampleworks.utils.guidance_constants import ( GuidanceType, StructurePredictor, @@ -265,7 +264,7 @@ def get_reward_function_and_structure( def save_everything( - output_dir: str | Path, + args: GuidanceConfig, losses: list[Any], refined_structure: dict, traj_denoised: list[Any], @@ -283,8 +282,10 @@ def save_everything( Parameters ---------- - output_dir : str | Path - Directory to write all output files into. Created if it doesn't exist. + args : GuidanceConfig + The arguments for the guidance run. This method directly uses args.output_dir, + and creates that directory if it does not exist. The result of args.as_dict() is + written to a JSON file in the same directory, and inserted into the output CIF file. losses : list[Any] Per-step loss values (may contain ``None`` entries for unguided steps). refined_structure : dict @@ -304,7 +305,7 @@ def save_everything( Optional model-space atom template. When provided (mismatch runs), this template is used for final structure and trajectory saving. """ - output_dir = Path(output_dir) + output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) logger.info("Saving results") @@ -327,10 +328,20 @@ def save_everything( else: atom_array = base_atom_array + metadata = args.as_dict() + final_structure = CIFFile() set_structure(final_structure, atom_array) + add_category_to_cif(final_structure, metadata, category_name="sampleworks") final_structure.write(str(output_dir / "refined.cif")) + # write out the job parameters to a JSON file in the same directory as the refined.cif file + # Even though this is technically duplicated, keep it around as a backup in case metadata + # is lost in some CIF transform. + with open(Path(output_dir) / "job_metadata.json", "w") as fp: + # use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects + json.dump(metadata, fp) + # Two calls to save_trajectory, very similar, but saving different trajectories! save_trajectory( scaler_type, @@ -363,14 +374,12 @@ def save_everything( # Methods for running model guidance in separate processes, avoiding reloading of the model. ##################### # These args are passed from run_grid_search.py via GuidanceConfig. -def run_guidance( - args: GuidanceConfig | argparse.Namespace, guidance_type: str, model_wrapper, device -) -> JobResult: +def run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device) -> JobResult: """Wrapper around ``_run_guidance`` to redirect logs and generate a JobResult. Parameters ---------- - args : GuidanceConfig | argparse.Namespace + args : GuidanceConfig Configuration for the guidance run. guidance_type : str Type of guidance/scaler to apply. @@ -410,9 +419,7 @@ def run_guidance( # "guidance_type" is also called "scaler" in many places -def _run_guidance( - args: GuidanceConfig | argparse.Namespace, guidance_type: str, model_wrapper, device -): +def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device): reward_function, structure = get_reward_function_and_structure( args.density, # str/path to a map file. device, # this needs to come from the global context, not the args object. @@ -565,7 +572,7 @@ def _run_guidance( model_atom_array = result.metadata.get("model_atom_array") if result.metadata else None save_everything( - args.output_dir, + args, losses, refined_structure, traj_denoised, @@ -587,7 +594,7 @@ def epoch_seconds(time_to_convert: datetime) -> float: def get_job_result( - args: GuidanceConfig | argparse.Namespace, + args: GuidanceConfig, device: torch.device, started_at: datetime, ended_at: datetime, @@ -641,10 +648,6 @@ def run_guidance_job_queue(job_queue_path: str) -> list[JobResult]: logger.info(f"Running job {i + 1}/{len(job_queue)}: {job}") job_result = run_guidance(job, job.guidance_type, model_wrapper, device) - # write out the job parameters to a JSON file in the same directory as the refined.cif file - with open(Path(job_result.output_dir) / "job_metadata.json", "w") as fp: - # use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects - json.dump(job.as_dict(), fp) job_results.append(job_result) torch.cuda.empty_cache() # just in case diff --git a/tests/integration/test_mismatch_integration.py b/tests/integration/test_mismatch_integration.py index 00bdc9f..01393d4 100644 --- a/tests/integration/test_mismatch_integration.py +++ b/tests/integration/test_mismatch_integration.py @@ -21,6 +21,7 @@ from sampleworks.utils.atom_array_utils import make_normalized_atom_id from sampleworks.utils.atom_reconciler import AtomReconciler from sampleworks.utils.frame_transforms import apply_forward_transform +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import save_everything from tests.mocks import MismatchCase, MismatchCaseWrapper @@ -1011,8 +1012,18 @@ def test_save_with_model_template(self, tmp_path: Path): refined = {"asym_unit": build_test_atom_array(n_atoms=n_struct)} model_atom_array = build_test_atom_array(n_atoms=n_model, with_occupancy=False) + args = GuidanceConfig( + protein="1l63", + structure=Path("dummy"), + density=Path("dummy"), + model="boltz2", + guidance_type="pure_guidance", + log_path="dummy", + output_dir=str(tmp_path), + ) + save_everything( - output_dir=tmp_path, + args, losses=[0.5, 0.3], refined_structure=refined, traj_denoised=[], diff --git a/tests/utils/test_cif_utils.py b/tests/utils/test_cif_utils.py index 5c17303..f52d2e5 100644 --- a/tests/utils/test_cif_utils.py +++ b/tests/utils/test_cif_utils.py @@ -7,8 +7,9 @@ import pytest from atomworks.io.utils.io_utils import load_any from biotite.structure import array, Atom, AtomArray, AtomArrayStack +from biotite.structure.io.pdbx.cif import CIFColumn, CIFFile from sampleworks.utils.atom_array_utils import save_structure_to_cif -from sampleworks.utils.cif_utils import resolve_mixed_hetatm_atom_altlocs +from sampleworks.utils.cif_utils import add_category_to_cif, resolve_mixed_hetatm_atom_altlocs # --------------------------------------------------------------------------- @@ -210,3 +211,169 @@ def test_real_6ni6_warning_mentions_residue_and_modified_name(self, resources_di resolve_mixed_hetatm_atom_altlocs(cif_path) assert "101" in caplog.text assert "CSO" in caplog.text + + +# --------------------------------------------------------------------------- +# Tests for add_category_to_cif +# --------------------------------------------------------------------------- + + +class TestAddCategoryToCif: + """Tests for add_category_to_cif function.""" + + def test_add_category_to_single_block_ciffile(self, tmp_path): + """Add a category to a CIFFile with a single block.""" + # Create a simple CIF file with structure + atoms = [_atom("A", 1, "ALA", False), _atom("A", 2, "VAL", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + + # Read it back + ciffile = CIFFile.read(str(cif_path)) + + # Add a custom category + data = {"id": [1, 2, 3], "value": ["a", "b", "c"], "score": [1.0, 2.0, 3.0]} + add_category_to_cif(ciffile, data, "custom_data") + + # Verify the category was added + block = ciffile[list(ciffile.keys())[0]] + assert "custom_data" in block + category = block["custom_data"] + assert category["id"] == CIFColumn([1, 2, 3]) + assert category["value"] == CIFColumn(["a", "b", "c"]) + assert category["score"] == CIFColumn([1.0, 2.0, 3.0]) + + def test_add_category_with_explicit_block_name(self, tmp_path): + """Add a category to a specific block by name.""" + atoms = [_atom("A", 1, "ALA", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + ciffile = CIFFile.read(str(cif_path)) + + block_name = list(ciffile.keys())[0] + data = {"id": [1]} + add_category_to_cif(ciffile, data, "custom_data", block_name=block_name) + + assert "custom_data" in ciffile[block_name] + + def test_category_already_exists_raises_error(self, tmp_path): + """Adding a category that already exists should raise RuntimeError.""" + atoms = [_atom("A", 1, "ALA", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + ciffile = CIFFile.read(str(cif_path)) + + data = {"id": [1]} + add_category_to_cif(ciffile, data, "custom_data") + + # Try to add the same category again + with pytest.raises(RuntimeError, match="Category 'custom_data' already exists"): + add_category_to_cif(ciffile, data, "custom_data") + + def test_overwrite_existing_category(self, tmp_path): + """Overwriting an existing category should succeed when overwrite=True.""" + atoms = [_atom("A", 1, "ALA", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + ciffile = CIFFile.read(str(cif_path)) + + # Add initial category + data1 = {"id": [1], "value": ["old"]} + add_category_to_cif(ciffile, data1, "custom_data") + + # Overwrite with new data + data2 = {"id": [2, 3], "value": ["new1", "new2"]} + add_category_to_cif(ciffile, data2, "custom_data", overwrite=True) + + # Verify the category was overwritten + block = ciffile[list(ciffile.keys())[0]] + category = block["custom_data"] + assert category["id"] == CIFColumn([2, 3]) + assert category["value"] == CIFColumn(["new1", "new2"]) + + def test_multiple_blocks_without_block_name_raises_error(self, tmp_path): + """If CIFFile has multiple blocks and block_name is None, should raise ValueError.""" + # Create a CIF file with two blocks manually + ciffile = CIFFile() + from biotite.structure.io.pdbx.cif import CIFBlock + + ciffile["block1"] = CIFBlock() + ciffile["block2"] = CIFBlock() + + data = {"id": [1]} + with pytest.raises(ValueError, match="multiple blocks"): + add_category_to_cif(ciffile, data, "custom_data") + + def test_nonexistent_block_name_raises_error(self, tmp_path): + """Specifying a block_name that doesn't exist should raise ValueError.""" + atoms = [_atom("A", 1, "ALA", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + ciffile = CIFFile.read(str(cif_path)) + + data = {"id": [1]} + with pytest.raises(ValueError, match="Block 'nonexistent' not found"): + add_category_to_cif(ciffile, data, "custom_data", block_name="nonexistent") + + def test_write_and_read_back_category(self, tmp_path): + """Demonstrate that a custom category can be written to disk and read back.""" + # Create initial CIF + atoms = [_atom("A", 1, "ALA", False), _atom("A", 2, "VAL", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + + # Read, add category, and write + ciffile = CIFFile.read(str(cif_path)) + data = { + "experiment_id": [1, 2, 3], + "method": ["xray", "nmr", "em"], + "resolution": [2.5, 1.8, 3.2], + } + add_category_to_cif(ciffile, data, "experiment_metadata") + + output_path = tmp_path / "test_with_metadata.cif" + ciffile.write(str(output_path)) + + # Read back and verify + reloaded = CIFFile.read(str(output_path)) + block = reloaded[list(reloaded.keys())[0]] + assert "experiment_metadata" in block + + category = block["experiment_metadata"] + assert category["experiment_id"] == CIFColumn([1, 2, 3]) + assert category["method"] == CIFColumn(["xray", "nmr", "em"]) + assert category["resolution"] == CIFColumn([2.5, 1.8, 3.2]) + + def test_empty_data_dict(self, tmp_path): + """Adding a category with empty data should work.""" + atoms = [_atom("A", 1, "ALA", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + ciffile = CIFFile.read(str(cif_path)) + + data = {} + add_category_to_cif(ciffile, data, "empty_category") + + block = ciffile[list(ciffile.keys())[0]] + assert "empty_category" in block + + def test_single_item_data(self, tmp_path): + """Adding a category with single items (not lists) should work.""" + atoms = [_atom("A", 1, "ALA", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + ciffile = CIFFile.read(str(cif_path)) + + data = {"name": "test_structure", "version": 1.0} + add_category_to_cif(ciffile, data, "metadata") + + block = ciffile[list(ciffile.keys())[0]] + category = block["metadata"] + assert "name" in category + assert "version" in category + + def test_none_values_converted(self, tmp_path): + """None values in data dict should be converted to placeholder string.""" + atoms = [_atom("A", 1, "ALA", False)] + cif_path = _write_cif(atoms, tmp_path / "test.cif") + ciffile = CIFFile.read(str(cif_path)) + + data = {"present": "value", "missing": None} + add_category_to_cif(ciffile, data, "test_category") + + block = ciffile[list(ciffile.keys())[0]] + category = block["test_category"] + # Verify None was replaced (with "none" or "?" depending on implementation) + assert "missing" in category diff --git a/tests/utils/test_guidance_script_utils.py b/tests/utils/test_guidance_script_utils.py index a0b46fc..df5e75e 100644 --- a/tests/utils/test_guidance_script_utils.py +++ b/tests/utils/test_guidance_script_utils.py @@ -3,6 +3,7 @@ from pathlib import Path import torch +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import save_everything from tests.utils.atom_array_builders import build_test_atom_array @@ -15,8 +16,18 @@ def test_save_everything_uses_model_atom_array_for_mismatch(tmp_path: Path): final_state = torch.zeros((1, 5, 3), dtype=torch.float32) + args = GuidanceConfig( + protein="1l63", + structure=Path("dummy"), + density=Path("dummy"), + model="boltz2", + guidance_type="pure_guidance", + log_path="dummy", + output_dir=str(tmp_path), + ) + save_everything( - output_dir=tmp_path, + args, losses=[], refined_structure=refined_structure, traj_denoised=[],