diff --git a/docs/source/Inference.md b/docs/source/Inference.md index 0aa607ced..d70eccbb2 100644 --- a/docs/source/Inference.md +++ b/docs/source/Inference.md @@ -24,12 +24,12 @@ Supported: - Template-based prediction - using ColabFold template alignments - using pre-computed template alignments + - using direct CIF template files (no alignments required) - Non-canonical residues Coming soon: - Covalently modified residues and other cross-chain covalent bonds -- User-specified template structures (as opposed to top 4) ### 1.2 DNA @@ -301,6 +301,61 @@ model_update: --- +(inference-cif-direct-templates)= +#### 🧬 CIF Direct Template Mode + +OpenFold3 supports providing template structures directly as CIF files without requiring pre-computed template alignments. In this mode, the system automatically: +1. Parses each provided CIF file +2. Extracts all chains and their sequences +3. Aligns each chain to your query sequence +4. Selects the best matching chain based on sequence identity × coverage score + +This is particularly useful for stateless inference environments or when you have specific template structures but no alignment files. + +**Usage:** + +In your query JSON, specify `template_cif_paths` instead of `template_alignment_file_path`: + +```json +{ + "queries": { + "my_query": { + "chains": [ + { + "molecule_type": "protein", + "chain_ids": ["A"], + "sequence": "MKLLVVDDAGQKFT...", + "template_cif_paths": [ + "path/to/template1.cif", + "path/to/template2.cif", + "path/to/template3.cif" + ], + "template_cif_chain_ids": ["A", null, "B"] + } + ] + } + } +} +``` + +Optionally, use `template_cif_chain_ids` to specify which chain to use from each CIF file. Use `null` to let the system automatically select the best-matching chain. + +**Configuration:** + +You can adjust the minimum score threshold for chain selection in your `runner.yml`: + +```yaml +template_preprocessor_settings: + cif_direct_min_score: 0.1 # Default: 0.1 (seq_identity × coverage) +``` + +**Notes:** +- For multi-chain CIF files, only the best matching chain per file is used as a template +- The `template_cif_paths` field cannot be used together with `template_alignment_file_path` +- This mode is currently supported for protein chains only + +--- + ### 3.4 Customized ColabFold MSA Server Settings Using `runner.yml` All settings for the ColabFold server and outputs can be set under [`msa_computation_settings`](https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/data/tools/colabfold_msa_server.py#L904) @@ -478,9 +533,13 @@ This file representing the full input query in a validated internal format defin - `template_alignment_file_path`: Path to the preprocessed template cache entry `.npz` file used for template featurization. By default, template cache entries are automatically created in a short preprocessing step using the raw template alignment files provided under this same field and the template structures identified in the alignment. + - `template_cif_paths`: List of paths to CIF template files when using {ref}`CIF direct template mode `. This field is mutually exclusive with `template_alignment_file_path`. + + - `template_cif_chain_ids`: List of chain IDs to use from each corresponding CIF file in `template_cif_paths`. Use `null` for entries where automatic chain selection is desired. Must have the same length as `template_cif_paths` if provided. + - `template_entry_chain_ids`: List of template chains, identified by their entry (typically PDB) IDs and chain IDs, used for featurization. By default, up to the first 4 of these chains are used. - Note: Refer to the {doc}`Template How-To Documentation ` for how to specify these fields if you want to use precomputed template alignments instead of Colabfold alignments for template inputs. + Note: Refer to the {doc}`Template How-To Documentation ` for how to specify these fields if you want to use precomputed template alignments instead of Colabfold alignments for template inputs, or see {ref}`CIF Direct Template Mode ` for using template structures directly without alignments. Note: If MSA and template files are persisted between runs, the same `inference_query_set.json` file can be used to resubmit the query without needing to rerun the template and MSA pipelines. To do so: diff --git a/docs/source/input_format.md b/docs/source/input_format.md index acb02525c..32393f36c 100644 --- a/docs/source/input_format.md +++ b/docs/source/input_format.md @@ -64,6 +64,8 @@ All chains must define a unique ```chain_ids``` field and appropriate sequence o "paired_msa_file_paths": "/absolute/path/to/paired_msas", "template_alignment_file_path": "/absolute/path/to/template_msa", "template_entry_chain_ids": ["entry1_A", "entry2_B", "entry3_A"], + "template_cif_paths": ["/path/to/template1.cif", "/path/to/template2.cif"], + "template_cif_chain_ids": ["A", null], } ``` @@ -119,6 +121,19 @@ All chains must define a unique ```chain_ids``` field and appropriate sequence o - Use this field only when running inference with **precomputed alignments**. See the {doc}`Running with Templates Documentation ` for details. - If using the ColabFold MSA server, this field is automatically populated and will **override any user-provided path**. + - `template_cif_paths` *(list[str], optional, default = null)* + - List of paths to CIF files to use as templates for this chain. + - Enables **CIF-direct template mode**, which parses templates directly from CIF files without requiring pre-computed alignments. + - Alignments are computed on-the-fly using Kalign. + - This is useful when you have known template structures but no pre-computed MSA/template alignments. + - Example: `["/path/to/template1.cif", "/path/to/template2.cif"]` + + - `template_cif_chain_ids` *(list[str | null], optional, default = null)* + - List of chain IDs to use from each corresponding CIF file in `template_cif_paths`. + - Must have the same length as `template_cif_paths` if provided. + - Use `null` for a specific entry to let the parser automatically select the best-matching chain. + - Example: `["A", null, "B"]` - uses chain A from the first CIF, auto-selects from the second, and uses chain B from the third. + ### 3.2. RNA Chains ``` diff --git a/docs/source/template_how_to.md b/docs/source/template_how_to.md index 7d9dd427a..7a352e920 100644 --- a/docs/source/template_how_to.md +++ b/docs/source/template_how_to.md @@ -1,6 +1,13 @@ # Running OpenFold3 Inference with Templates -This document contains instructions on how to use template information for OF3 predictions. Here, we assume that you already generated all of your template alignments or intend to fetch them from Colabfold on-the-fly. If you do not have any precomputed template alignments and do not want to use Colabfold, refer to our {doc}`MSA Generation Guide ` before consulting this document. If you need further clarifications on how some of the template components of our inference pipeline work, refer to {doc}`this explanatory document `. +This document contains instructions on how to use template information for OF3 predictions. OpenFold3 supports two template modes: + +1. **Alignment-based templates** (traditional): Requires template alignments and template structures +2. **CIF direct templates** (simplified): Requires only template CIF files, no alignments needed + +For alignment-based templates, we assume you already generated all of your template alignments or intend to fetch them from Colabfold on-the-fly. If you do not have any precomputed template alignments and do not want to use Colabfold, refer to our {doc}`MSA Generation Guide ` before consulting this document. + +If you need further clarifications on how some of the template components of our inference pipeline work, refer to {doc}`this explanatory document `. The template pipeline currently supports monomeric templates and has been tested for protein chains only. @@ -12,10 +19,18 @@ The main steps detailed in this guide are: (1-template-files)= ## 1. Template Files -Template featurization requires query-to-template **alignments** and template **structures**. +OpenFold3 supports two modes for providing template information: + +### Alignment-Based Mode (Traditional) +Requires query-to-template **alignments** and template **structures**. Sections 1.1 and 1.2 below describe the required file formats. + +### CIF Direct Mode (Simplified) +Requires only template **CIF files**. The system automatically aligns template chains to your query sequence and selects the best matching chain. See {ref}`Section 2.3 <23-cif-direct-templates>` for usage details. + +--- (11-template-aligment-file-format)= -### 1.1. Template Aligment File Format +### 1.1. Template Alignment File Format (Alignment-Based Mode) Template alignments can be provided in either `sto`, `a3m` or `m8` format. Template alignments from the Colabfold server are in `m8` format. @@ -73,16 +88,18 @@ query_A template_C 71.4 14 4 0 5 18 75 88 2e-03 22.3 Note that since `m8` files do not provide actual alignments, we only use them to identify which structure files to get templates from, retrieve sequences from these structure files and always realign them to the query sequence using Kalign. More on this in the [template processing explanatory document](template_explanation.md). -### 1.2. Template Structure File Format +### 1.2. Template Structure File Format (Alignment-Based Mode) + +For alignment-based templates, template structures currently can only be provided in `cif` format. An upcoming release will add support for parsing templates from `pdb` files. -Template structures currently can only be provided in `cif` format. An upcoming release will add support for parsing templates from `pdb` files. +**Note:** For {ref}`CIF direct mode <23-cif-direct-templates>`, template CIF files are specified directly in the query JSON without separate structure directories. (2-specifying-template-information-in-the-inference-query-file)= ## 2. Specifying Template Information in the Inference Query File -### 2.1. Specifying Alignments +### 2.1. Specifying Alignments (Alignment-Based Mode) -The data pipeline needs to know which template alignment to use for which chain. This information is provided by specifying the {ref}`paths to the alignments <31-protein-chains>` for each chain's `template_alignment_file_path` field in the inference query json file. +For alignment-based templates, the data pipeline needs to know which template alignment to use for which chain. This information is provided by specifying the {ref}`paths to the alignments <31-protein-chains>` for each chain's `template_alignment_file_path` field in the inference query json file. Note that when fetching alignments from the Colabfold server, `template_alignment_file_path` fields are automatically populated. @@ -118,9 +135,9 @@ Note that when fetching alignments from the Colabfold server, `template_alignmen -### 2.2. Using Specific Templates +### 2.2. Using Specific Templates (Alignment-Based Mode) -By default, the template pipeline automatically populates the `template_entry_chain_ids` field with [n templates](https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/data/pipelines/preprocessing/template.py#L1535) from the alignment, which is then further subset to the [top k templates](https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/config/dataset_config_components.py#L116) during featurization for inference. +By default, for alignment-based templates, the template pipeline automatically populates the `template_entry_chain_ids` field with [n templates](https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/data/pipelines/preprocessing/template.py#L1535) from the alignment, which is then further subset to the [top k templates](https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/config/dataset_config_components.py#L116) during featurization for inference. In an **upcoming release**, we will add support for specifying *specific templates* for the data pipeline to use for featurization. This will be possible through the `template_entry_chain_ids` field: @@ -156,10 +173,77 @@ entry3_A MK----DDARGQGKFT // ``` +(23-cif-direct-templates)= +### 2.3. CIF Direct Templates (No Alignments Required) + +OpenFold3 supports providing template structures directly as CIF files without requiring pre-computed template alignments. This is particularly useful for: +- Stateless inference environments (e.g., NVIDIA Inference Microservices) +- Quick predictions when you have specific template structures +- Simplified workflows without external alignment tools + +#### How It Works + +In CIF direct mode, the system automatically: +1. Parses each provided CIF file to extract all chains and their sequences +2. Aligns each chain sequence to your query sequence using sequence alignment +3. Scores each chain by `sequence_identity × coverage` +4. Selects the best matching chain as the template (if score ≥ minimum threshold) + +For multi-chain CIF files, only the best matching chain per file is used. + +#### Usage Example + +Specify `template_cif_paths` instead of `template_alignment_file_path` in your query JSON: + +```json +{ + "queries": { + "my_protein": { + "chains": [ + { + "molecule_type": "protein", + "chain_ids": ["A", "B"], + "sequence": "XRMKQLEDKVEELLSKNYHLENEVARLKKLVGER", + "template_cif_paths": [ + "templates/1dgc.cif", + "templates/1ysa.cif", + "templates/1zta.cif" + ] + } + ] + } + } +} +``` + +**Example query files:** +- [Homomer with direct CIF templates](https://github.com/aqlaboratory/openfold-3/blob/main/examples/example_inference_inputs/query_homomer_with_direct_cif_templates.json) +- [Multimer with direct CIF templates](https://github.com/aqlaboratory/openfold-3/blob/main/examples/example_inference_inputs/query_multimer_with_direct_cif_templates.json) + +#### Configuration + +Adjust the minimum score threshold for chain selection in your `runner.yml`: + +```yaml +template_preprocessor_settings: + cif_direct_min_score: 0.1 # Default: 0.1 (seq_identity × coverage) +``` + +Only chains with a score (sequence identity × coverage) above this threshold will be considered as valid templates. + +#### Important Notes + +- The `template_cif_paths` field is **mutually exclusive** with `template_alignment_file_path` - you must use one or the other, not both +- Template structures must be in CIF format +- Currently supported for protein chains only +- For multi-chain CIF files, the system automatically selects the best matching chain per file + (3-optimizations-for-high-throughput-workflows)= ## 3. Optimizations for High-Throughput Workflows -For high-throughput use cases, where a large number of structures are to be predicted, template processing can take a significant amount of time even with the built-in {doc}`deduplication utility ` we have for template alignment and structure processing. To avoid having to spend GPU compute on data transformations, we provide separate template preprocessing scripts to generate the necessary inputs from which template featurization can run efficiently in a subsequent job without being a bottleneck to the model forward pass. +**Note:** The optimizations described in this section apply to **alignment-based templates**. If you're using {ref}`CIF direct templates <23-cif-direct-templates>`, the workflow is already simplified and these preprocessing steps are not necessary. + +For high-throughput use cases with alignment-based templates, where a large number of structures are to be predicted, template processing can take a significant amount of time even with the built-in {doc}`deduplication utility ` we have for template alignment and structure processing. To avoid having to spend GPU compute on data transformations, we provide separate template preprocessing scripts to generate the necessary inputs from which template featurization can run efficiently in a subsequent job without being a bottleneck to the model forward pass. ### 3.1. Template Alignment Preprocessing diff --git a/openfold3/core/data/framework/data_module.py b/openfold3/core/data/framework/data_module.py index 8c3a7e67d..d40a5a0da 100644 --- a/openfold3/core/data/framework/data_module.py +++ b/openfold3/core/data/framework/data_module.py @@ -41,6 +41,7 @@ import dataclasses import enum +import logging import random import warnings from typing import Any @@ -74,6 +75,7 @@ from openfold3.core.utils.tensor_utils import dict_multimap _NUMPY_AVAILABLE = RequirementCache("numpy") +logger = logging.getLogger(__name__) class DatasetMode(enum.Enum): @@ -516,8 +518,15 @@ def __init__( self.inference_config = _configs.configs[0] def prepare_data(self) -> None: + logger.info("=" * 60) + logger.info( + f"Prepare data: use_msa_server={self.use_msa_server}, use_templates={self.use_templates}" + ) + logger.info("=" * 60) + # Colabfold msa preparation if self.use_msa_server: + logger.info("Running ColabFold MSA server...") self.inference_config.query_set = preprocess_colabfold_msas( inference_query_set=self.inference_config.query_set, compute_settings=self.msa_computation_settings, @@ -529,11 +538,13 @@ def prepare_data(self) -> None: ) if self.use_templates: + logger.info("Running template preprocessing...") template_preprocessor = TemplatePreprocessor( input_set=self.inference_config.query_set, config=self.inference_config.template_preprocessor_settings, ) template_preprocessor() + logger.info("Template preprocessing complete!") def setup(self, stage=None): """Broadcast updated query set to all ranks if multiple GPUs are used.""" diff --git a/openfold3/core/data/io/sequence/template.py b/openfold3/core/data/io/sequence/template.py index 5cea16682..1dc0db44c 100644 --- a/openfold3/core/data/io/sequence/template.py +++ b/openfold3/core/data/io/sequence/template.py @@ -16,6 +16,7 @@ Parsers for template alignments. """ +import logging import re from abc import ABC, abstractmethod from collections.abc import Iterable @@ -29,6 +30,8 @@ from openfold3.core.data.resources.residues import MoleculeType from openfold3.core.data.tools.kalign import run_kalign +logger = logging.getLogger(__name__) + """ Updates compared to the old OpenFold version: @@ -370,6 +373,8 @@ class TemplateData(NamedTuple): Coverage of the full query sequence in the template alignment. template_sequence (str | None): The ungapped template sequence. + cif_path (Path | None): + Original CIF file path for CIF-direct mode. """ index: int @@ -380,6 +385,7 @@ class TemplateData(NamedTuple): seq_id: float q_cov: float | None seq: str | None + cif_path: Path | None = None def calculate_ids_hit( @@ -807,7 +813,212 @@ def __call__( return templates -TEMPLATE_PARSER_REGISTRY = {".a3m": A3mParser, ".sto": StoParser, ".m8": M8Parser} +class CifDirectParser(TemplateParser): + """Parses templates directly from CIF files without pre-computed alignments. + + For multi-chain CIF files, automatically selects the chain with the best + alignment to the query sequence (highest seq_identity × coverage). + + This parser is designed for stateless inference environments (e.g., NIM) + where users provide CIF files directly without external alignment servers. + """ + + def __init__( + self, + max_sequences: int, + min_score_threshold: float = 0.1, + ): + """Initialize CifDirectParser. + + Args: + max_sequences: Maximum number of templates to parse (not used for + single CIF parsing, but kept for interface compatibility) + min_score_threshold: Minimum seq_id × coverage score to consider + a chain as a valid template candidate + """ + super().__init__(max_sequences) + self.min_score_threshold = min_score_threshold + logger.info(f"CifDirectParser initialized with min_score_threshold={min_score_threshold}") + + def _process_chain( + self, + chain_id: str, + chain_seq: str, + query_seq_str: str, + entry_id: str, + ) -> dict | None: + if not chain_seq or len(chain_seq) == 0: + return None + + fasta_input = ( + f">query\n{query_seq_str}\n" + f">{entry_id}_{chain_id}\n{chain_seq}\n" + ) + + try: + aligned_str = run_kalign(fasta_input) + + if not aligned_str or not aligned_str.strip(): + logger.debug( + f"Chain {chain_id} from {entry_id}: Kalign returned empty result" + ) + return None + + alignments, _ = parse_fasta(aligned_str) + + if not alignments: + logger.debug( + f"Chain {chain_id} from {entry_id}: parse_fasta returned empty list. " + f"Kalign output:\n{aligned_str[:200]}" + ) + return None + + if len(alignments) < 2: + logger.debug( + f"Chain {chain_id} from {entry_id}: only {len(alignments)} sequences in alignment. " + f"Kalign output:\n{aligned_str[:200]}" + ) + return None + + query_aln_seq = alignments[0] + template_aln_seq = alignments[1] + + if not query_aln_seq or not template_aln_seq: + logger.debug( + f"Chain {chain_id} from {entry_id}: empty alignment sequences" + ) + return None + + query_aln_arr = np.fromiter( + query_aln_seq, dtype=' dict[int, TemplateData]: + """Parse a CIF file and select the best matching chain. + + Args: + cif_file_path: Path to the CIF file + query_seq_str: Query sequence to align against + chain_id_seq_map: Mapping of chain IDs to their sequences + (extracted from CIF metadata) + entry_id: Optional entry ID (defaults to filename stem) + specified_chain_id: Optional specific chain ID to use. + If provided, only this chain will be aligned and used. + If None, performs automatic chain selection from all chains. + + Returns: + Dictionary with single entry (index 0) containing TemplateData + for the best matching chain, or empty dict if no valid chains found + """ + if entry_id is None: + entry_id = cif_file_path.stem + + candidate_chains = [] + + if specified_chain_id is not None: + logger.info( + f"CIF Direct: Processing {entry_id}, using specified chain " + f"{specified_chain_id} (skipping other chains)" + ) + if specified_chain_id in chain_id_seq_map: + chain_result = self._process_chain( + specified_chain_id, + chain_id_seq_map[specified_chain_id], + query_seq_str, + entry_id, + ) + if chain_result is not None: + candidate_chains.append(chain_result) + else: + logger.info( + f"CIF Direct: Processing {entry_id} with {len(chain_id_seq_map)} chains " + "(automatic selection)" + ) + for chain_id, chain_seq in chain_id_seq_map.items(): + chain_result = self._process_chain( + chain_id, + chain_seq, + query_seq_str, + entry_id, + ) + if chain_result is not None: + candidate_chains.append(chain_result) + + if not candidate_chains: + if specified_chain_id is not None: + logger.warning( + f"CIF Direct: Specified chain {specified_chain_id} not found or " + f"failed alignment in {entry_id}" + ) + else: + logger.info(f"CIF Direct: No valid chains found in {entry_id}") + return {} + + best_chain = max(candidate_chains, key=lambda x: x['score']) + + logger.info(f"CIF Direct: Selected chain {best_chain['chain_id']} from {entry_id} (seq_id={best_chain['seq_id']:.3f}, q_cov={best_chain['q_cov']:.3f})") + + template_data = TemplateData( + index=0, + entry_id=entry_id, + chain_id=best_chain['chain_id'], + query_aln_pos=best_chain['query_aln_pos'], + aln_pos=best_chain['template_aln_pos'], + seq_id=best_chain['seq_id'], + q_cov=best_chain['q_cov'], + seq=best_chain['template_seq'], + ) + + return {0: template_data} + + +TEMPLATE_PARSER_REGISTRY = {".a3m": A3mParser, ".sto": StoParser, ".m8": M8Parser, ".cif": CifDirectParser} def parse_template_alignment( diff --git a/openfold3/core/data/pipelines/preprocessing/template.py b/openfold3/core/data/pipelines/preprocessing/template.py index 4221c08b2..d9628cce4 100644 --- a/openfold3/core/data/pipelines/preprocessing/template.py +++ b/openfold3/core/data/pipelines/preprocessing/template.py @@ -85,6 +85,7 @@ InferenceQuerySet, ) +logger = logging.getLogger(__name__) # TODO: rename variables to be PDB-agnostic # --- Template alignment preprocessing --- @@ -1451,9 +1452,33 @@ class TemplatePreprocessorInputTrain(BaseModel): class TemplatePreprocessorInputInference(BaseModel): - aln_path: Path + aln_path: Path | None = None query_seq_str: str template_entry_chain_ids: list[str] | None = None + template_cif_paths: list[Path] | None = None + template_cif_chain_ids: list[str | None] | None = None + + @model_validator(mode='after') + def validate_inputs(self) -> "TemplatePreprocessorInputInference": + """Validate template input consistency.""" + if self.aln_path is not None and self.template_cif_paths is not None: + raise ValueError( + "Cannot provide both 'aln_path' and 'template_cif_paths'. " + "Choose one mode: alignment-based or CIF-direct." + ) + + if self.template_cif_chain_ids is not None: + if self.template_cif_paths is None: + raise ValueError( + "'template_cif_chain_ids' requires 'template_cif_paths'" + ) + if len(self.template_cif_chain_ids) != len(self.template_cif_paths): + raise ValueError( + f"Length mismatch: {len(self.template_cif_paths)} CIF files " + f"but {len(self.template_cif_chain_ids)} chain IDs" + ) + + return self class TemplatePreprocessorSettings(BaseModel): @@ -1533,6 +1558,8 @@ class TemplatePreprocessorSettings(BaseModel): max_release_date: datetime | None = None min_release_date_diff: int | None = None max_templates: int = 20 + + cif_direct_min_score: float = 0.1 fetch_missing_structures: bool = True create_precache: bool = False @@ -1623,6 +1650,8 @@ def __init__( self.max_release_date = config.max_release_date self.min_release_date_diff = config.min_release_date_diff self.max_templates = config.max_templates + + self.cif_direct_min_score = config.cif_direct_min_score self.fetch_missing_structures = config.fetch_missing_structures self.create_precache = config.create_precache @@ -1667,24 +1696,45 @@ def _parse_inference_query_set(self) -> None: if chain.molecule_type not in self.moltypes: continue - if chain.template_alignment_file_path is None: + # CASE 1: Alignment file mode + if chain.template_alignment_file_path is not None: + template_alignment_path = Path(chain.template_alignment_file_path) + if template_alignment_path not in paths_seen: + paths_seen.add(template_alignment_path) + inputs.append( + TemplatePreprocessorInputInference( + aln_path=template_alignment_path, + query_seq_str=chain.sequence, + template_entry_chain_ids=chain.template_entry_chain_ids, + template_cif_paths=None, + ) + ) + + # CASE 2: CIF-direct mode + elif chain.template_cif_paths is not None: + cif_paths_tuple = tuple(sorted(str(p) for p in chain.template_cif_paths)) + chain_ids_tuple = tuple(chain.template_cif_chain_ids) if chain.template_cif_chain_ids else () + cache_key = (cif_paths_tuple, chain_ids_tuple) + + if cache_key not in paths_seen: + paths_seen.add(cache_key) + inputs.append( + TemplatePreprocessorInputInference( + aln_path=None, + query_seq_str=chain.sequence, + template_entry_chain_ids=None, + template_cif_paths=chain.template_cif_paths, + template_cif_chain_ids=chain.template_cif_chain_ids, + ) + ) + + else: print( - f"Warning: No template alignment file path provided for chain " + f"Warning: No template data provided for chain " f"{chain.chain_ids} of query {query_name}, skipping..." ) continue - - template_alignment_path = Path(chain.template_alignment_file_path) - - if template_alignment_path not in paths_seen: - paths_seen.add(template_alignment_path) - inputs.append( - TemplatePreprocessorInputInference( - aln_path=template_alignment_path, - query_seq_str=chain.sequence, - template_entry_chain_ids=chain.template_entry_chain_ids, - ) - ) + self.inputs = inputs def _update_dataset_cache(self) -> None: @@ -1747,6 +1797,116 @@ def __call__(self) -> None: elif isinstance(self.input_set, InferenceQuerySet): self._update_inference_query_set() + def _parse_templates_from_cif_files( + self, + input_data: TemplatePreprocessorInputInference, + ) -> dict[int, TemplateData]: + """Parse templates from CIF files in CIF-direct mode. + + Args: + input_data: Input data containing CIF file paths + + Returns: + Dictionary mapping indices to TemplateData objects + """ + from openfold3.core.data.io.sequence.template import CifDirectParser + from openfold3.core.data.io.structure.cif import _load_ciffile + from openfold3.core.data.primitives.structure.metadata import ( + get_asym_id_to_canonical_seq_dict, + ) + if self.create_logs: + worker_logger = logging.getLogger(f"template_preprocess_{os.getpid()}") + worker_logger.info(f"CIF-direct mode: Processing {len(input_data.template_cif_paths)} CIF files") + + templates = {} + parser = CifDirectParser( + max_sequences=None, + min_score_threshold=self.cif_direct_min_score, + ) + + if self.create_logs: + worker_logger = logging.getLogger(f"template_preprocess_{os.getpid()}") + worker_logger.info( + f"Processing {len(input_data.template_cif_paths)} CIF files " + "in CIF-direct mode..." + ) + + for idx, cif_path in enumerate(input_data.template_cif_paths): + try: + if not cif_path.exists(): + if self.create_logs: + worker_logger.warning(f"CIF file not found: {cif_path}") + continue + + entry_id = cif_path.stem + + specified_chain_id = None + if input_data.template_cif_chain_ids is not None: + specified_chain_id = input_data.template_cif_chain_ids[idx] + + if self.precache_directory is not None: + precache_file = self.precache_directory / f"{entry_id}.npz" + if precache_file.exists(): + precache_entry = np.load(precache_file, allow_pickle=True) + chain_id_seq_map = precache_entry["chain_id_seq_map"].item() + else: + cif_file = _load_ciffile(cif_path) + chain_id_seq_map = get_asym_id_to_canonical_seq_dict(cif_file) + else: + cif_file = _load_ciffile(cif_path) + chain_id_seq_map = get_asym_id_to_canonical_seq_dict(cif_file) + + cif_templates = parser( + cif_file_path=cif_path, + query_seq_str=input_data.query_seq_str, + chain_id_seq_map=chain_id_seq_map, + entry_id=entry_id, + specified_chain_id=specified_chain_id, + ) + + if cif_templates: + template = cif_templates[0] + template = template._replace(index=idx, cif_path=cif_path) + templates[idx] = template + + if self.create_logs: + worker_logger.info( + f"Selected chain {template.chain_id} from {entry_id} " + f"(seq_id={template.seq_id:.3f}, q_cov={template.q_cov:.3f})" + ) + + except Exception as e: + if self.create_logs: + worker_logger.warning( + f"Failed to process CIF file {cif_path}: {e}" + ) + continue + + if templates: + templates_list = list(templates.values()) + templates_list.sort( + key=lambda t: t.seq_id * (t.q_cov if t.q_cov is not None else 0.0), + reverse=True + ) + templates = {i: t._replace(index=i) for i, t in enumerate(templates_list)} + + if self.create_logs: + worker_logger.info(f"Parsed {len(templates)} templates from CIF files") + for t in templates_list[:5]: + if self.create_logs: + worker_logger.info(f" - {t.entry_id}_{t.chain_id}: seq_id={t.seq_id:.3f}, q_cov={t.q_cov:.3f}") + + if self.create_logs: + worker_logger.info( + f"Parsed {len(templates)} templates from CIF files, " + f"sorted by alignment quality" + ) + else: + if self.create_logs: + worker_logger.info("No templates parsed from CIF files") + + return templates + def _preprocess_templates_for_query( self, input_data: TemplatePreprocessorInputTrain | TemplatePreprocessorInputInference, @@ -1763,9 +1923,17 @@ def _preprocess_templates_for_query( worker_logger.propagate = False # preprocess templates for a single chain - # 1. Parse template alignment file - if self.create_logs: - worker_logger.info(f"Parsing template alignment {input_data.aln_path}...") + # 1. Parse template alignment file or CIF files + if isinstance(input_data, TemplatePreprocessorInputInference): + if input_data.aln_path is not None: + if self.create_logs: + worker_logger.info(f"Parsing template alignment {input_data.aln_path}...") + elif input_data.template_cif_paths is not None: + if self.create_logs: + worker_logger.info( + f"Processing {len(input_data.template_cif_paths)} CIF files " + "in CIF-direct mode..." + ) # TODO: 2. if template entry and chain ID list provided in the IQS - skip some # below: @@ -1794,9 +1962,21 @@ def _preprocess_templates_for_query( worker_logger.info( f"Creating new cache entry {template_cache_entry_file}." ) - templates = parse_template_alignment( - input_data.aln_path, input_data.query_seq_str, self.max_sequences_parse - ) + + # Parse templates based on mode + if isinstance(input_data, TemplatePreprocessorInputInference): + if input_data.aln_path is not None: + templates = parse_template_alignment( + input_data.aln_path, input_data.query_seq_str, self.max_sequences_parse + ) + elif input_data.template_cif_paths is not None: + templates = self._parse_templates_from_cif_files(input_data) + else: + return + else: + templates = parse_template_alignment( + input_data.aln_path, input_data.query_seq_str, self.max_sequences_parse + ) if self.create_logs: worker_logger.info(f"Parsed {len(templates)} templates...") template_cache_entry = {} @@ -1983,7 +2163,7 @@ def _preprocess_templates_for_query( continue # H. Add to cache entry - template_cache_entry[f"{template.entry_id}_{chain_id_matched}"] = { + cache_entry_data = { "index": template.index, "release_date": release_date, "idx_map": np.concatenate( @@ -1994,6 +2174,10 @@ def _preprocess_templates_for_query( axis=1, ), } + # Store CIF path for CIF-direct mode + if template.cif_path is not None: + cache_entry_data["cif_path"] = str(template.cif_path) + template_cache_entry[f"{template.entry_id}_{chain_id_matched}"] = cache_entry_data template_ids.append(f"{template.entry_id}_{chain_id_matched}") if self.create_logs: worker_logger.info( diff --git a/openfold3/core/data/pipelines/sample_processing/template.py b/openfold3/core/data/pipelines/sample_processing/template.py index 37f974a26..f228fd738 100644 --- a/openfold3/core/data/pipelines/sample_processing/template.py +++ b/openfold3/core/data/pipelines/sample_processing/template.py @@ -44,6 +44,7 @@ def process_template_structures_of3( template_file_format: str, ccd: CIFFile | None, use_roda_monomer_format: bool = False, + cif_assembly_cache: dict[str, tuple] | None = None, ) -> TemplateSliceCollection: """Processes template structures for all chains of a given target structure. @@ -81,6 +82,10 @@ def process_template_structures_of3( use_roda_monomer_format (bool): Whether template cache filepath is expected to be in the s3 RODA monomer format: //template.npz + cif_assembly_cache (dict[str, tuple] | None): + Optional cache dictionary mapping PDB IDs to (cif_file, atom_array_assembly) + tuples. Used to avoid re-parsing the same CIF file when multiple chains from + the same structure are needed. If None, a new cache will be created. Returns: TemplateSliceCollection: The sliced template atomarrays for each chain in the crop. @@ -95,6 +100,10 @@ def process_template_structures_of3( ): return TemplateSliceCollection(template_slices={}) + # Initialize CIF assembly cache if not provided (to avoid re-parsing same PDB files) + if cif_assembly_cache is None: + cif_assembly_cache = {} + # Iterate over protein chains in the atom array # TODO: currently, this re-processes templates identical chains, add redundancy # logic if becomes a bottleneck @@ -120,6 +129,7 @@ def process_template_structures_of3( template_file_format=template_file_format, ccd=ccd, atom_array_query_chain=atom_array[atom_array.chain_id == chain_id], + cif_assembly_cache=cif_assembly_cache, ) return TemplateSliceCollection(template_slices=template_slices) diff --git a/openfold3/core/data/primitives/sequence/msa.py b/openfold3/core/data/primitives/sequence/msa.py index 16aa01af4..e7e3afd0f 100644 --- a/openfold3/core/data/primitives/sequence/msa.py +++ b/openfold3/core/data/primitives/sequence/msa.py @@ -637,6 +637,9 @@ def extract_alignments_to_pair( if m is not None: msa_arrays_to_pair_i.append(m) + if not msa_arrays_to_pair_i: + continue + msa_arrays_to_pair_cat = MsaArray.multi_concatenate( msa_arrays=msa_arrays_to_pair_i, axis=0, diff --git a/openfold3/core/data/primitives/structure/template.py b/openfold3/core/data/primitives/structure/template.py index ad9df44e7..280a8cb2a 100644 --- a/openfold3/core/data/primitives/structure/template.py +++ b/openfold3/core/data/primitives/structure/template.py @@ -57,11 +57,14 @@ class TemplateCacheEntry: The release date of the template structure. idx_map (np.ndarray[int]): Dictionary mapping tokens that fall into the crop to corresponding residue - indices in the matching alignment.""" + indices in the matching alignment. + cif_path (Path | None): + Original CIF file path for CIF-direct mode.""" index: int release_date: str idx_map: np.ndarray[int] + cif_path: Path | None = None @dataclasses.dataclass(frozen=False) @@ -294,6 +297,8 @@ def parse_template_structure( template_pdb_chain_id: str, template_file_format: str, ccd: CIFFile | None, + cif_assembly_cache: dict[str, tuple] | None = None, + cif_path: Path | None = None, ) -> AtomArray: """Parses the template structure for the given chain. @@ -309,11 +314,17 @@ def parse_template_structure( The format of the template structures. ccd (CIFFile | None): Parsed CCD file. + cif_assembly_cache (dict[str, tuple] | None): + Optional cache dictionary mapping PDB IDs to (cif_file, atom_array_assembly) + tuples. Used to avoid re-parsing the same CIF file when multiple chains from + the same structure are needed. + cif_path (Path | None): + Direct path to the CIF file for CIF-direct mode. Raises: ValueError: If neither template_structure_array_directory nor - template_structures_directory is provided. + template_structures_directory nor cif_path is provided. Returns: AtomArray: @@ -321,6 +332,10 @@ def parse_template_structure( """ # Parse template IDs pdb_id, chain_id = template_pdb_chain_id.split("_") + + # Initialize cache if not provided + if cif_assembly_cache is None: + cif_assembly_cache = {} # Parse the pre-parsed template structure array if template_structure_array_directory is not None: @@ -342,15 +357,40 @@ def parse_template_structure( "Only pickle or npz formats are supported." ) - # Parse and clean the raw template structure file - elif template_structures_directory is not None: - # Parse the full template assembly and subset assembly to template chain - result = parse_mmcif( - template_structures_directory / Path(f"{pdb_id}.{template_file_format}") + # CIF-direct mode: use the provided CIF path directly + elif cif_path is not None: + cache_key = str(cif_path) + if cache_key in cif_assembly_cache: + cif_file, atom_array_template_assembly = cif_assembly_cache[cache_key] + logger.info(f"[CACHE HIT] Using cached CIF-direct assembly for {cif_path}") + else: + result = parse_mmcif(cif_path) + if isinstance(result, SkippedStructure): + return None + cif_file, atom_array_template_assembly = result + cif_assembly_cache[cache_key] = (cif_file, atom_array_template_assembly) + logger.info(f"[CIF-DIRECT] Parsed {cif_path}") + + atom_array_template_chain = clean_template_atom_array( + atom_array_template_assembly, cif_file, chain_id, ccd ) - if isinstance(result, SkippedStructure): - return None - cif_file, atom_array_template_assembly = result + + # Parse and clean the raw template structure file from directory + elif template_structures_directory is not None: + # Check cache first to avoid re-parsing the same CIF file for different chains + if pdb_id in cif_assembly_cache: + cif_file, atom_array_template_assembly = cif_assembly_cache[pdb_id] + logger.info(f"[CACHE HIT] Using cached assembly for {pdb_id}") + else: + result = parse_mmcif( + template_structures_directory / Path(f"{pdb_id}.{template_file_format}") + ) + if isinstance(result, SkippedStructure): + return None + cif_file, atom_array_template_assembly = result + cif_assembly_cache[pdb_id] = (cif_file, atom_array_template_assembly) + logger.info(f"[CACHE MISS] Parsed and cached assembly for {pdb_id}") + # Clean up the template atom array and subset to the chosen template chain atom_array_template_chain = clean_template_atom_array( atom_array_template_assembly, cif_file, chain_id, ccd @@ -358,7 +398,7 @@ def parse_template_structure( else: raise ValueError( "Either template_structure_array_directory or " - "template_structures_directory must be provided." + "template_structures_directory or cif_path must be provided." ) return atom_array_template_chain @@ -493,6 +533,7 @@ def align_template_to_query( template_file_format: str, ccd: CIFFile | None, atom_array_query_chain: AtomArray, + cif_assembly_cache: dict[str, tuple] | None = None, ) -> list[AtomArray]: """Identifies the subset of atoms in the template that align to the query. @@ -510,6 +551,10 @@ def align_template_to_query( Parsed CCD file. atom_array_query_chain (AtomArray): The cropped atom array containing atoms of the current protein chain.). + cif_assembly_cache (dict[str, tuple] | None): + Optional cache dictionary mapping PDB IDs to (cif_file, atom_array_assembly) + tuples. Used to avoid re-parsing the same CIF file when multiple chains from + the same structure are needed. Returns: list[AtomArray]: @@ -535,6 +580,8 @@ def align_template_to_query( template_pdb_chain_id, template_file_format, ccd, + cif_assembly_cache=cif_assembly_cache, + cif_path=template_cache_entry.cif_path, ) if not atom_array_template_chain: continue diff --git a/openfold3/core/data/tools/colabfold_msa_server.py b/openfold3/core/data/tools/colabfold_msa_server.py index be95746fc..849dc47d1 100644 --- a/openfold3/core/data/tools/colabfold_msa_server.py +++ b/openfold3/core/data/tools/colabfold_msa_server.py @@ -896,7 +896,15 @@ def add_msa_paths_to_iqs( stacklevel=2, ) chain.paired_msa_file_paths = [paired_msa_file_paths] - + + if chain.template_cif_paths is not None: + warnings.warn( + f"Query {query_name} chain {chain} already has " + "template_cif_paths set. These are not overwritten with " + "path(s) to the template CIF files from the ColabFold MSA server.", + stacklevel=2, + ) + continue # Add template alignment file paths template_alignment_file_path = ( output_directory diff --git a/openfold3/projects/of3_all_atom/config/inference_query_format.py b/openfold3/projects/of3_all_atom/config/inference_query_format.py index 0fa7d2096..f0fc86e26 100644 --- a/openfold3/projects/of3_all_atom/config/inference_query_format.py +++ b/openfold3/projects/of3_all_atom/config/inference_query_format.py @@ -20,6 +20,7 @@ DirectoryPath, FilePath, field_serializer, + model_validator, ) from openfold3.core.config.config_utils import ( @@ -66,12 +67,43 @@ class Chain(BaseModel): template_entry_chain_ids: ( Annotated[list[str], BeforeValidator(_ensure_list)] | None ) = None + template_cif_paths: ( + Annotated[list[FilePath], BeforeValidator(_ensure_list)] | None + ) = None + template_cif_chain_ids: ( + Annotated[list[str | None], BeforeValidator(_ensure_list)] | None + ) = None sdf_file_path: FilePath | None = None @field_serializer("molecule_type", return_type=str) def serialize_enum_name(self, v: MoleculeType, _info): return v.name + @model_validator(mode='after') + def validate_template_inputs(self) -> 'Chain': + """Validate template input consistency.""" + if (self.template_alignment_file_path is not None and + self.template_cif_paths is not None): + raise ValueError( + f"Chain {self.chain_ids}: Cannot specify both " + "'template_alignment_file_path' and 'template_cif_paths'" + ) + + if self.template_cif_chain_ids is not None: + if self.template_cif_paths is None: + raise ValueError( + f"Chain {self.chain_ids}: 'template_cif_chain_ids' can only " + "be specified when 'template_cif_paths' is provided" + ) + if len(self.template_cif_chain_ids) != len(self.template_cif_paths): + raise ValueError( + f"Chain {self.chain_ids}: Length mismatch - " + f"{len(self.template_cif_paths)} CIF files but " + f"{len(self.template_cif_chain_ids)} chain IDs specified" + ) + + return self + # TODO(jennifer): Add validations to this class # - if molecule type is protein / dna / rna - must specify sequence # - if molecule type is ligand - either ccd or smiles needs to be specifified