diff --git a/drugforge-spectrum/drugforge/spectrum/alphafold.py b/drugforge-spectrum/drugforge/spectrum/alphafold.py new file mode 100644 index 00000000..b8cda914 --- /dev/null +++ b/drugforge-spectrum/drugforge/spectrum/alphafold.py @@ -0,0 +1,331 @@ +""" +drugforge/spectrum/alphafold.py +================================ +Pydantic models and pure functions for building AlphaFold 3 JSON input files. + +Public API +---------- +Af3ProteinChain – AF3 JSON representation of a single protein chain +Af3Input – top-level AF3 JSON input (Pydantic model) +make_msa_inputs – FASTA → list of MSA-stage Af3Input objects +make_fold_inputs – MSA output dir → list of fold-stage Af3Input objects +""" + +import json +import logging +from pathlib import Path +from typing import Optional + +from drugforge.spectrum.schema import ProteinSequence, SequenceList +from pydantic import BaseModel, Field + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + + +class Af3ProteinChain(BaseModel): + """AF3 JSON representation of a single protein chain.""" + + id: str = Field("A", description="Chain ID letter used in the AF3 JSON.") + sequence: str = Field(..., description="One-letter amino acid sequence.") + description: Optional[str] = Field(None, description="Human-readable label.") + unpairedMsa: Optional[str] = Field( + None, + description=( + "Pre-computed unpaired MSA in A3M format. " + "Set to None to let AF3 build the MSA (data pipeline). " + "Set to '' to run MSA-free." + ), + ) + pairedMsa: Optional[str] = Field( + None, + description=( + "Pre-computed paired MSA in A3M format. " + "Set to None alongside unpairedMsa=None so AF3 builds both. " + "Set to '' when providing a custom unpairedMsa for a single chain." + ), + ) + templates: Optional[list] = Field( + None, + description=( + "List of structural templates. None → AF3 searches for templates. " + "[] → run template-free." + ), + ) + + def to_af3_dict(self, version: int = 2) -> dict: + """Serialise to the AF3 JSON 'protein' sub-dict.""" + d: dict = {"id": self.id, "sequence": self.sequence} + if self.description is not None: + if version >= 4: + d["description"] = self.description + # Always write MSA fields explicitly so AF3 interprets them correctly + d["unpairedMsa"] = self.unpairedMsa + d["pairedMsa"] = self.pairedMsa + d["templates"] = self.templates + return d + + +class Af3Input(BaseModel): + """Top-level AF3 JSON input for a single folding job.""" + + name: str = Field(..., description="Job name; used to name output files.") + model_seeds: list[int] = Field( + default_factory=lambda: [1, 2, 5, 10], + description="List of integer random seeds. At least one required.", + ) + chains: list[Af3ProteinChain] = Field( + ..., description="Protein chains to include in the folding job." + ) + dialect: str = Field("alphafold3", description="Must be 'alphafold3'.") + version: int = Field(2, description="AF3 JSON format version.") + + def to_af3_dict(self) -> dict: + """Serialise to the full AF3 JSON structure.""" + return { + "name": self.name, + "modelSeeds": self.model_seeds, + "sequences": [{"protein": chain.to_af3_dict()} for chain in self.chains], + "dialect": self.dialect, + "version": self.version, + } + + def write(self, output_dir: str | Path) -> Path: + """Write this input to ``/.json`` and return the path.""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + out_path = output_dir / f"{self.name}.json" + with open(out_path, "w") as fh: + json.dump(self.to_af3_dict(), fh, indent=2) + return out_path + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _read_msa_output(msa_output_dir: Path, name: str) -> tuple[str, list, str]: + """Extract ``(unpairedMsa, templates, sequence)`` from an AF3 MSA output. + + AF3 writes ``//_data.json`` after the data + pipeline completes. + + Parameters + ---------- + msa_output_dir: + Root directory of AF3 MSA outputs. + name: + Sequence / job name. + + Returns + ------- + tuple[str, list, str] + ``(unpairedMsa, templates, sequence)`` + """ + # resolve() follows symlinks – Nextflow stages inputs as symlinks into + # the process work directory. + data_json = ( + Path(msa_output_dir).resolve() / name.lower() / f"{name.lower()}_data.json" + ) + if not data_json.exists(): + raise FileNotFoundError( + f"MSA output not found for '{name}': expected {data_json}" + ) + with open(data_json) as fh: + data = json.load(fh) + protein = data["sequences"][0]["protein"] + return protein["unpairedMsa"], protein["templates"], protein["sequence"] + + +# --------------------------------------------------------------------------- +# Core functions +# --------------------------------------------------------------------------- + + +def make_msa_inputs( + fasta_path: str | Path, + seeds: list[int] | None = None, + description_prefix: str = "", +) -> list[Af3Input]: + """Build MSA-stage AF3 inputs from a FASTA file. + + Each returned :class:`Af3Input` has all MSA fields set to ``None`` so that + AlphaFold 3 runs its data pipeline (Jackhmmer / Nhmmer) to build MSAs. + Run the resulting JSONs with ``--norun_inference``. + + Parameters + ---------- + fasta_path: + Path to a FASTA file containing protein sequences. + seeds: + Integer model seeds. Defaults to ``[1, 2, 5, 10]``. + description_prefix: + Optional string prepended to each chain description, e.g. ``"2A protease"``. + + Returns + ------- + list[Af3Input] + One :class:`Af3Input` per sequence in the FASTA. + """ + if seeds is None: + seeds = [1, 2, 5, 10] + + seq_list = SequenceList.from_fasta(fasta_path, aligned=False) + + return [ + Af3Input( + name=seq.seq_id, + model_seeds=seeds, + chains=[ + Af3ProteinChain( + id="A", + sequence=seq.sequence, + description=( + f"{description_prefix} – {seq.seq_id}" + if description_prefix + else seq.seq_id + ), + unpairedMsa=None, + pairedMsa=None, + templates=None, + ) + ], + ) + for seq in seq_list + ] + + +def make_fold_inputs( + msa_output_dir: str | Path, + seeds: list[int] | None = None, + fasta_path: str | Path | None = None, +) -> list[Af3Input]: + """Build fold-stage AF3 inputs from pre-computed MSA outputs. + + Reads each ``/_data.json`` written by the AF3 data pipeline + and embeds the ``unpairedMsa`` and ``templates`` into a new + :class:`Af3Input` ready for GPU inference. Run with + ``--norun_data_pipeline``. + + Parameters + ---------- + msa_output_dir: + Directory containing one sub-directory per sequence, each holding the + AF3 data-pipeline output JSON. + seeds: + Integer model seeds. Defaults to ``[1, 2, 5, 10]``. + fasta_path: + Optional FASTA used to control which sequences are processed and in + what order. When omitted every sub-directory in *msa_output_dir* is + used (sorted alphabetically). + + Returns + ------- + list[Af3Input] + One :class:`Af3Input` per sequence. + """ + if seeds is None: + seeds = [1, 2, 5, 10] + + # resolve() follows symlinks – important when the directory is staged + # by Nextflow as a symlink pointing to another process's work directory. + msa_dir = Path(msa_output_dir).resolve() + + if fasta_path is not None: + seq_list = SequenceList.from_fasta(fasta_path, aligned=False) + names = [seq.seq_id for seq in seq_list] + else: + # is_dir() follows symlinks, so staged symlink dirs are included. + names = sorted(p.name for p in msa_dir.iterdir() if p.is_dir()) + + if not names: + raise ValueError(f"No sequence directories found in {msa_dir}") + + inputs: list[Af3Input] = [] + for name in names: + unpaired_msa, templates, sequence = _read_msa_output(msa_dir, name) + inputs.append( + Af3Input( + name=name, + model_seeds=seeds, + chains=[ + Af3ProteinChain( + id="A", + sequence=sequence, + description=name, + unpairedMsa=unpaired_msa, + pairedMsa="", # single chain – no inter-chain pairing + templates=templates, + ) + ], + ) + ) + return inputs + + +def select_best_af3( + af3_output_dir: str | Path, + seq_name: str, + ref_pdb: str | Path, + chain: str = "A", + final_pdb: str | Path = "aligned_protein.pdb", +) -> tuple[float, str]: + """Select the best-ranked AF3 model for a sequence and align it to a reference. + + AF3 writes one subdirectory per job named after the sequence. Inside, models + are ranked and named ``_model.cif`` (rank 0 is best). This function + picks the rank-0 model, aligns it to *ref_pdb*, and saves the aligned + structure as a PDB. + + Parameters + ---------- + af3_output_dir: + Root directory of AF3 fold outputs (one sub-directory per sequence). + seq_name: + Name of the sequence / job (must match the sub-directory name). + ref_pdb: + Path to the reference PDB to align against. + chain: + Chain ID to use for alignment, by default ``"A"``. + final_pdb: + Path where the aligned PDB will be saved. + + Returns + ------- + tuple[float, str] + ``(rmsd, path_to_aligned_pdb)`` + + Raises + ------ + FileNotFoundError + If the AF3 output directory or the sequence sub-directory is not found, + or if no CIF model files are present. + """ + # Defer import to avoid circular dependency with calculate_rmsd + from drugforge.spectrum.calculate_rmsd import rmsd_alignment + + af3_output_dir = Path(af3_output_dir) + seq_dir = af3_output_dir / seq_name.lower() + if not seq_dir.exists(): + raise FileNotFoundError( + f"AF3 output directory for '{seq_name}' not found: {seq_dir}" + ) + + # AF3 lowercases job names; rank-0 model is first when sorted by name. + candidates = sorted(seq_dir.glob(f"{seq_name.lower()}*model*.cif")) + if not candidates: + raise FileNotFoundError( + f"No AF3 CIF model files found for '{seq_name}' in {seq_dir}" + ) + + best_cif = candidates[0] + logging.info(f"Selected AF3 model for '{seq_name}': {best_cif.name}") + + rmsd, aligned_pdb = rmsd_alignment( + str(best_cif), str(ref_pdb), str(final_pdb), chain, chain + ) + logging.info(f"RMSD for '{seq_name}' vs reference: {rmsd:.3f} Å") + + return rmsd, str(aligned_pdb) diff --git a/drugforge-spectrum/drugforge/spectrum/boltz.py b/drugforge-spectrum/drugforge/spectrum/boltz.py new file mode 100644 index 00000000..63999860 --- /dev/null +++ b/drugforge-spectrum/drugforge/spectrum/boltz.py @@ -0,0 +1,124 @@ +""" +drugforge/spectrum/boltz.py +============================ +Pydantic models and pure functions for building Boltz-1 YAML input files. + +Boltz-1 takes a single YAML per folding job that describes all chains +(proteins, ligands, nucleic acids). This module covers the single-protein +and protein+ligand cases relevant to the 2A-protease panel. + +Public API +---------- +BoltzProteinChain – a protein chain entry in a Boltz YAML +BoltzLigandChain – a small-molecule ligand entry in a Boltz YAML +BoltzInput – top-level Boltz YAML input (Pydantic model) +make_boltz_inputs – FASTA → list of BoltzInput objects (one per sequence) +""" + +from pathlib import Path +from typing import Optional + +import yaml +from drugforge.spectrum.schema import SequenceList +from pydantic import BaseModel, Field + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + + +class BoltzProteinChain(BaseModel): + """A single protein chain in a Boltz YAML input.""" + + id: str = Field("A", description="Chain ID letter.") + sequence: str = Field(..., description="One-letter amino acid sequence.") + + def to_boltz_dict(self) -> dict: + return {"protein": {"id": self.id, "sequence": self.sequence}} + + +class BoltzLigandChain(BaseModel): + """A small-molecule ligand chain in a Boltz YAML input.""" + + id: str = Field("L", description="Chain ID letter for the ligand.") + smiles: str = Field(..., description="SMILES string for the ligand.") + + def to_boltz_dict(self) -> dict: + return {"ligand": {"id": self.id, "smiles": self.smiles}} + + +class BoltzInput(BaseModel): + """Top-level Boltz YAML input for a single folding job.""" + + name: str = Field(..., description="Job name; used to name the output YAML file.") + version: int = Field(1, description="Boltz input format version.") + protein_chains: list[BoltzProteinChain] = Field( + ..., description="Protein chains to include in the folding job." + ) + ligand_chains: list[BoltzLigandChain] = Field( + default_factory=list, + description="Optional small-molecule ligand chains.", + ) + + def to_boltz_dict(self) -> dict: + """Serialise to the Boltz YAML structure.""" + sequences = [chain.to_boltz_dict() for chain in self.protein_chains] + sequences += [chain.to_boltz_dict() for chain in self.ligand_chains] + return {"version": self.version, "sequences": sequences} + + def write(self, output_dir: str | Path) -> Path: + """Write this input to ``/.yaml`` and return the path.""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + out_path = output_dir / f"{self.name}.yaml" + with open(out_path, "w") as fh: + yaml.dump( + self.to_boltz_dict(), fh, default_flow_style=False, sort_keys=False + ) + return out_path + + +# --------------------------------------------------------------------------- +# Core functions +# --------------------------------------------------------------------------- + + +def make_boltz_inputs( + fasta_path: str | Path, + ligand_smiles: Optional[str] = None, + ligand_id: str = "L", +) -> list[BoltzInput]: + """Build Boltz YAML inputs from a FASTA file. + + Creates one :class:`BoltzInput` per sequence. If *ligand_smiles* is + provided, a ligand chain is appended to every input — useful when folding + a panel of proteins all against the same small molecule. + + Parameters + ---------- + fasta_path: + Path to a FASTA file containing protein sequences. + ligand_smiles: + Optional SMILES string for a ligand to include in every input. + ligand_id: + Chain ID for the ligand, by default ``"L"``. + + Returns + ------- + list[BoltzInput] + One :class:`BoltzInput` per sequence in the FASTA. + """ + seq_list = SequenceList.from_fasta(fasta_path, aligned=False) + + ligand_chains = ( + [BoltzLigandChain(id=ligand_id, smiles=ligand_smiles)] if ligand_smiles else [] + ) + + return [ + BoltzInput( + name=seq.seq_id, + protein_chains=[BoltzProteinChain(id="A", sequence=seq.sequence)], + ligand_chains=ligand_chains, + ) + for seq in seq_list + ] diff --git a/drugforge-spectrum/drugforge/spectrum/docking.py b/drugforge-spectrum/drugforge/spectrum/docking.py new file mode 100644 index 00000000..4cc72ebf --- /dev/null +++ b/drugforge-spectrum/drugforge/spectrum/docking.py @@ -0,0 +1,283 @@ +""" +drugforge/spectrum/docking.py +============================== +Ligand transfer docking for predicted protein structures. + +Strips away the target-tag, ML scoring, MD, Dask, HTML visualisation, and +caching concerns from the full drugforge-workflows implementation so that the +workflow can be run against novel viral protease sequences without a registered +TargetTags entry. + +Public API +---------- +ligand_transfer_docking – run POSIT ligand transfer docking for a directory + of predicted structures against Fragalysis reference + complexes, returning a results DataFrame and SDF. +""" + +import logging +import warnings +from pathlib import Path +from shutil import rmtree + +import pandas as pd +from drugforge.data.readers.meta_structure_factory import MetaStructureFactory +from drugforge.data.readers.structure_dir import StructureDirFactory +from drugforge.data.util.dask_utils import BackendType, FailureMode +from drugforge.data.util.logging import FileLogger +from drugforge.docking.docking import write_results_to_multi_sdf +from drugforge.docking.docking_data_validation import DockingResultCols +from drugforge.docking.meta_scorer import MetaScorer +from drugforge.docking.openeye import POSIT_METHOD, POSIT_RELAX_MODE, POSITDocker +from drugforge.docking.scorer import ChemGauss4Scorer +from drugforge.docking.selectors.selector_list import StructureSelector +from drugforge.modeling.protein_prep import LigandTransferProteinPrepper + +logger = logging.getLogger(__name__) + + +def ligand_transfer_docking( + target_structure_dir: str | Path, + reference_fragalysis_dir: str | Path, + output_dir: str | Path, + ref_chain: str = "A", + active_site_chain: str = "A", + posit_method: POSIT_METHOD = POSIT_METHOD.ALL, + relax_mode: POSIT_RELAX_MODE = POSIT_RELAX_MODE.NONE, + use_omega: bool = False, + num_poses: int = 1, + allow_retries: bool = True, + allow_final_clash: bool = True, + posit_confidence_cutoff: float = 0.1, + overwrite: bool = True, + failure_mode: FailureMode = FailureMode.SKIP, + loglevel: int = logging.INFO, +) -> pd.DataFrame: + """Run POSIT ligand transfer docking for predicted structures. + + Loads predicted protein structures (AF3 / Boltz PDB files) from + ``target_structure_dir``, aligns each to every reference complex in + ``reference_fragalysis_dir`` using ``LigandTransferProteinPrepper``, + then runs POSIT self-docking on each transferred pose. + + No target tag, ML scoring, MD, HTML visualisation, Dask, or caching is + required – this function is intentionally minimal. + + Parameters + ---------- + target_structure_dir: + Directory of predicted PDB files (one per sequence). Globs ``*.pdb``. + reference_fragalysis_dir: + Fragalysis-format directory of reference crystal complexes + (``/.pdb`` + ``/.sdf``). + output_dir: + Directory to write results. Created if absent; overwritten if + ``overwrite=True``. + ref_chain: + Chain ID in the reference complex used for structural alignment. + active_site_chain: + Chain ID in the target structure used for structural alignment. + posit_method: + POSIT method(s) to use. Defaults to ``ALL``. + relax_mode: + When to relax clashing atoms. Defaults to ``NONE``. + use_omega: + Whether to enumerate conformers with OEOmega before docking. + num_poses: + Number of docked poses to return per pair. + allow_retries: + Whether POSIT may retry with relaxed settings on failure. + allow_final_clash: + Whether to keep poses that clash in the final docking stage. + posit_confidence_cutoff: + Minimum POSIT confidence score to keep a result. + overwrite: + Whether to wipe and recreate ``output_dir`` if it already exists. + failure_mode: + How to handle per-structure failures – ``SKIP`` or ``RAISE``. + loglevel: + Python logging level. + + Returns + ------- + pd.DataFrame + Final scored and filtered docking results, sorted by ChemGauss4 score. + Also written to ``output_dir/docking_results_final.csv``. + """ + output_dir = Path(output_dir) + if output_dir.exists() and overwrite: + rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + file_logger = FileLogger( + logname="ligand_transfer_docking", + path=str(output_dir), + logfile="ligand_transfer_docking.log", + level=loglevel, + stdout=True, + ) + log = file_logger.getLogger() + data_intermediates = output_dir / "data_intermediates" + data_intermediates.mkdir(exist_ok=True) + + # ------------------------------------------------------------------ + # Load reference complexes from Fragalysis directory + # ------------------------------------------------------------------ + log.info(f"Loading reference complexes from {reference_fragalysis_dir}") + ref_factory = MetaStructureFactory( + fragalysis_dir=reference_fragalysis_dir, + structure_dir=None, + pdb_file=None, + ) + ref_complexes = ref_factory.load(use_dask=False, failure_mode=failure_mode) + log.info(f"Loaded {len(ref_complexes)} reference complexes") + + # ------------------------------------------------------------------ + # Load predicted target structures from flat PDB directory + # ------------------------------------------------------------------ + log.info(f"Loading target structures from {target_structure_dir}") + target_factory = StructureDirFactory.from_dir(target_structure_dir) + targets = target_factory.load(use_dask=False, failure_mode=failure_mode) + log.info(f"Loaded {len(targets)} target structures") + + # ------------------------------------------------------------------ + # Prep: align each target to each reference, transfer ligand coordinates + # ------------------------------------------------------------------ + log.info("Running LigandTransferProteinPrepper (align + ligand transfer)...") + prepper = LigandTransferProteinPrepper( + reference_complexes=ref_complexes, + ref_chain=ref_chain, + active_site_chain=active_site_chain, + seqres_yaml=None, # no mutation – predicted structures are complete + loop_db=None, # no loop filling – predicted structures are complete + ) + prepped = prepper.prep( + targets, + use_dask=False, + failure_mode=failure_mode, + cache_dir=None, + use_only_cache=False, + ) + log.info(f"Prepped {len(prepped)} target-reference pairs") + + # ------------------------------------------------------------------ + # Select pairs for docking (self-docking: each ligand docked back into + # the structure it was transferred from) + # ------------------------------------------------------------------ + selector = StructureSelector.SELF_DOCKING.selector_cls() + + # De-duplicate ligands by InChIKey so pivot() doesn't fail when the same + # reference ligand appears in multiple prepped complexes. + seen: set[str] = set() + unique_ligands = [] + for pc in prepped: + ik = pc.ligand.inchikey + if ik not in seen: + seen.add(ik) + unique_ligands.append(pc.ligand) + + pairs = selector.select(unique_ligands, prepped) + log.info( + f"Selected {len(pairs)} pairs from {len(unique_ligands)} unique ligands " + f"and {len(prepped)} prepped complexes" + ) + + # ------------------------------------------------------------------ + # Dock + # ------------------------------------------------------------------ + log.info("Running POSIT docking...") + docker = POSITDocker( + relax_mode=relax_mode, + posit_method=posit_method, + use_omega=use_omega, + omega_dense=False, + num_poses=num_poses, + allow_low_posit_prob=True, + low_posit_prob_thresh=posit_confidence_cutoff, + allow_final_clash=allow_final_clash, + allow_retries=allow_retries, + last_ditch_fred=False, + ) + results = docker.dock( + pairs, + output_dir=output_dir / "docking_results", + use_dask=False, + failure_mode=failure_mode, + ) + log.info(f"Docked {len(results)} pairs successfully") + + if not results: + raise ValueError("No docking results generated – check structures and inputs.") + + # ------------------------------------------------------------------ + # Write SDF of all poses before filtering + # ------------------------------------------------------------------ + sdf_path = output_dir / "docking_results.sdf" + write_results_to_multi_sdf( + sdf_path, + results, + backend=BackendType.IN_MEMORY, + reconstruct_cls=docker.result_cls, + ) + log.info(f"Wrote all poses to {sdf_path}") + + # ------------------------------------------------------------------ + # Score with ChemGauss4 + # ------------------------------------------------------------------ + log.info("Scoring with ChemGauss4...") + scorer = MetaScorer(scorers=[ChemGauss4Scorer()]) + scores_df = scorer.score( + results, + use_dask=False, + failure_mode=failure_mode, + return_df=True, + backend=BackendType.IN_MEMORY, + reconstruct_cls=docker.result_cls, + return_for_disk_backend=True, + ) + scores_df.to_csv(data_intermediates / "docking_scores_raw.csv", index=False) + + # ------------------------------------------------------------------ + # Filter by POSIT confidence + # ------------------------------------------------------------------ + n_before = len(scores_df) + scores_df = scores_df[ + scores_df[DockingResultCols.DOCKING_CONFIDENCE_POSIT.value] + > posit_confidence_cutoff + ] + log.info( + f"POSIT confidence filter: {len(scores_df)} / {n_before} results kept " + f"(cutoff={posit_confidence_cutoff})" + ) + + if scores_df.empty: + warnings.warn( + "No docking results passed the POSIT confidence cutoff – " + "raw results written to data_intermediates/docking_scores_raw.csv" + ) + + # ------------------------------------------------------------------ + # Optionally filter clashes (ChemGauss4 > 0 → clash) + # ------------------------------------------------------------------ + if not allow_final_clash: + n_before = len(scores_df) + scores_df = scores_df[ + scores_df[DockingResultCols.DOCKING_SCORE_POSIT.value] <= 0 + ] + log.info(f"Clash filter: {len(scores_df)} / {n_before} results kept") + + # ------------------------------------------------------------------ + # Sort and write final CSV + # ------------------------------------------------------------------ + scores_df = scores_df.sort_values( + DockingResultCols.DOCKING_SCORE_POSIT.value, ascending=True + ) + scores_df.to_csv( + data_intermediates / "docking_scores_filtered_sorted.csv", index=False + ) + + final_csv = output_dir / "docking_results_final.csv" + scores_df.to_csv(final_csv, index=False) + log.info(f"Wrote final results to {final_csv}") + + return scores_df diff --git a/drugforge-spectrum/drugforge/spectrum/schema.py b/drugforge-spectrum/drugforge/spectrum/schema.py new file mode 100644 index 00000000..1d9f6c6c --- /dev/null +++ b/drugforge-spectrum/drugforge/spectrum/schema.py @@ -0,0 +1,703 @@ +import io +import subprocess +import tempfile +from pathlib import Path +from warnings import warn + +import numpy as np +import pandas as pd +from Bio import AlignIO, SeqIO +from Bio.Align import MultipleSeqAlignment +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from bokeh.layouts import column +from bokeh.models import ColumnDataSource, LabelSet, LinearAxis, Range1d +from bokeh.models.glyphs import Rect, Text + +# Bokeh imports +from bokeh.plotting import figure, output_file, save +from drugforge.spectrum.seq_alignment import get_colors_by_aa_group, get_colors_protein +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self + + +class ProteinSequence(BaseModel): + seq_id: str = Field(..., description="Unique identifier for the protein sequence") + aligned: bool = Field( + ..., + description="Indicates whether the sequence is aligned (True) or unaligned (False)", + ) + sequence: str = Field( + ..., + description="Amino acid sequence of the protein", + ) + + @model_validator(mode="after") + def validate_sequence(self) -> Self: + if not self.aligned: + valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") + if not all(residue in valid_amino_acids for residue in self.sequence): + raise ValueError( + "Protein sequence contains invalid characters. Only standard amino acids are allowed." + ) + return self + + def get_unaligned_sequence(self) -> Self: + if self.aligned: + return ProteinSequence( + seq_id=self.seq_id, + sequence=self.sequence.replace("-", ""), + aligned=False, + ) + else: + return self + + +class SequenceList(BaseModel): + aligned: bool = Field( + None, + description="Indicates whether all the sequences are aligned (True) or unaligned (False).", + ) + sequences: list[ProteinSequence] = Field( + ..., description="List of protein sequences" + ) + + def __iter__(self): + return iter(self.sequences) + + @model_validator(mode="after") + def validate_sequence_length(self) -> Self: + if self.aligned: + seq_lengths = {len(seq.sequence) for seq in self.sequences} + if len(seq_lengths) > 1: + raise ValueError( + "All sequences must be the same length when 'aligned' is True." + ) + return self + + @classmethod + def from_fasta(cls, input_fasta, aligned: bool): + """Load sequences from a FASTA file and return a list of ProteinSequence or AlignedSequence objects.""" + input_fasta = Path(input_fasta) + if not input_fasta.exists(): + raise ValueError(f"FASTA file does not exist: {input_fasta}") + if not input_fasta.suffix == ".fasta": + raise ValueError("Fasta file must be in FASTA format") + sequences = [] + for record in SeqIO.parse(input_fasta, "fasta"): + sequences.append( + ProteinSequence( + seq_id=record.id, sequence=str(record.seq), aligned=aligned + ) + ) + return cls(aligned=aligned, sequences=sequences) + + def to_bio_seq_records(self) -> list[SeqRecord]: + seq_recs = [ + SeqRecord(Seq(sequence.sequence), id=sequence.seq_id) + for sequence in self.sequences + ] + return seq_recs + + def to_dataframe(self) -> pd.DataFrame: + records = [seq.model_dump() for seq in self.sequences] + return pd.DataFrame.from_records(records) + + @classmethod + def from_dataframe(cls, df: pd.DataFrame) -> Self: + records = df.to_dict(orient="records", index=True) + sequences = [ProteinSequence(**record) for record in records] + aligned_list = [seq.aligned for seq in sequences] + if not all(aligned_list): + if any(aligned_list): + warn(f"Some, but not all of the sequences are aligned: {sequences}") + aligned = False + if all(aligned_list): + aligned = True + return cls(aligned=aligned, sequences=sequences) + + def to_fasta(self, output_fasta: str | Path): + """Convert sequences to a FASTA file""" + output_fasta = Path(output_fasta) + output_fasta.parent.mkdir(parents=True, exist_ok=True) + if not output_fasta.suffix == ".fasta": + raise ValueError( + "Fasta file must be in FASTA format and have the .fasta extension" + ) + seq_recs = self.to_bio_seq_records() + SeqIO.write(seq_recs, output_fasta, "fasta") + + if not output_fasta.exists(): + raise RuntimeError( + f"SeqIO failed to write these sequences to {output_fasta}:\n {seq_recs}" + ) + + return output_fasta + + def to_csv(self, output_csv: str | Path): + """Convert sequences to a CSV file""" + output_csv = Path(output_csv) + output_csv.parent.mkdir(parents=True, exist_ok=True) + if not output_csv.suffix == ".csv": + raise ValueError("CSV file must be in .csv format") + df = self.to_dataframe() + df.to_csv(output_csv, index=False) + if not output_csv.exists(): + raise RuntimeError( + f"CSV failed to write these sequences to {output_csv}:\n {df}" + ) + return output_csv + + @classmethod + def from_csv(cls, input_csv: str | Path) -> Self: + """Load sequences from a CSV file""" + input_csv = Path(input_csv) + if not input_csv.suffix == ".csv": + raise ValueError("CSV file must be in .csv format") + df = pd.read_csv(input_csv) + return cls.from_dataframe(df) + + def serialize(self, output_dir: str | Path) -> Path: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + self.to_fasta(output_dir / "sequences.fasta") + self.to_csv(output_dir / "sequences.csv") + + def to_bio_alignment_obj( + self, + ) -> MultipleSeqAlignment: + """Returns a MultipleSeqAlignment object representing the alignment""" + buf = io.StringIO() + SeqIO.write(self.to_bio_seq_records(), buf, "fasta") + buf.seek(0) + return AlignIO.read(buf, "fasta") + + +def run_multiple_sequence_alignment(sequences: SequenceList) -> SequenceList: + """Run a multiple sequence alignment using MAFFT. + + Writes the input sequences to a temporary FASTA file, runs MAFFT, + captures the aligned output in a second temporary file, and returns + the result as an aligned SequenceList. + + Parameters + ---------- + sequences : SequenceList + Unaligned sequences to be aligned. + + Returns + ------- + SequenceList + A new SequenceList with ``aligned=True`` containing the MAFFT-aligned sequences. + """ + with tempfile.NamedTemporaryFile( + suffix=".fasta", mode="w", delete=False + ) as input_file: + input_path = Path(input_file.name) + + try: + SeqIO.write(sequences.to_bio_seq_records(), input_path, "fasta") + cmd = ["mafft", str(input_path)] + result = subprocess.run(cmd, capture_output=True, check=True, text=True) + aligned_records = list(SeqIO.parse(io.StringIO(result.stdout), "fasta")) + aligned_sequences = [ + ProteinSequence(seq_id=rec.id, sequence=str(rec.seq), aligned=True) + for rec in aligned_records + ] + return SequenceList(aligned=True, sequences=aligned_sequences) + finally: + input_path.unlink(missing_ok=True) + + +def find_bsite_resids( + pdb: str | Path, + ligres: str = "LIG", + chain: str = "A", + bsite_dist: float = 4.5, + res_threshold: int = 5, +) -> np.ndarray: + """Find binding site residues in a single protein-ligand complex based on ligand proximity. + + Unlike the version in ``calculate_rmsd``, this function operates on a single + input structure and does not require a separate reference PDB or alignment + step — the ligand must already be present in ``pdb``. + + Parameters + ---------- + pdb : str or Path + Path to the PDB file containing the protein-ligand complex. + ligres : str, optional + Residue name of the ligand, by default "LIG" + chain : str, optional + Chain ID of the protein and ligand, by default "A" + bsite_dist : float, optional + Distance from the ligand in angstroms used to define the binding site, + by default 4.5 + res_threshold : int, optional + Minimum residue ID to be considered a binding site residue. Avoids + assigning terminal residues incorrectly, by default 5 + + Returns + ------- + np.ndarray + Sorted array of binding site residue IDs. + + Raises + ------ + ValueError + If no ligand atoms are found for ``ligres`` in ``chain``. + ValueError + If no binding site residues are found above ``res_threshold``. + """ + import MDAnalysis as mda + + u = mda.Universe(str(pdb)) + + lig_atoms = u.select_atoms(f"chainid {chain} and resname {ligres}") + if len(lig_atoms) == 0: + raise ValueError( + f"No ligand atoms found for resname '{ligres}' in chain '{chain}' of {pdb}" + ) + + bs_atoms = u.select_atoms( + f"protein and chainid {chain} and around {bsite_dist} resname {ligres}" + ) + + bs_resids = np.unique(bs_atoms.resids) + bs_resids = bs_resids[bs_resids >= res_threshold] + + if len(bs_resids) == 0: + raise ValueError( + f"No binding site residues found within {bsite_dist} Å of '{ligres}' " + f"with residue ID >= {res_threshold}" + ) + + return bs_resids + + +def view_alignment( + sequence_list: SequenceList, + fontsize="11pt", + plot_width=800, + file_name="alignment", + color_by_group=False, + start_idx=0, + skip=4, + max_mismatch=2, + reorder: list = None, + output_dir=None, + scores: list = None, + x_offsets: float = None, + bsite_resids: np.ndarray | list | None = None, +): + """ "Bokeh sequence alignment view + From: https://dmnfarrell.github.io/bioinformatics/bokeh-sequence-aligner + + Parameters + ---------- + sequence_list : SequenceList + fontsize : str, optional + Size of aminoacid one-letter IDs, by default "11pt" + plot_width : int, optional + width of alignment plot, by default 800 + file_name : str, optional + suffix for html file, by default "alignment" + color_by_group : bool, optional + View mode where matching aminoacids are colored, by default False + start_idx : int, optional + Index of first aminiacid of reference sequence, by default 0 + skip : int, optional + Skip for displayed indexes of reference sequence , by default 4 + max_mismatch : int, optional + How many mismatches are tolerated for highlighted group match, by default 2 + reorder : list, optional + List of indices to reorder sequences, by default None + output_dir : str or Path, optional + Directory to save the output html file, by default None (saves to cwd) + scores : list, optional + List of per-sequence scores to display as right-side labels, by default None + x_offsets : float, optional + X position of the score labels relative to the alignment width. Defaults + to ``N + 8`` (just past the right edge) when scores are provided. + bsite_resids : array-like, optional + Residue IDs of binding site residues (in the reference sequence, i.e. the + last sequence in the alignment). Continuous runs of binding site columns + are highlighted with a blue box overlay on both plots, by default None. + + Returns + ------- + (bokeh.Column, str) + Bokeh Column of layouts, path to saved html file. + """ + if not sequence_list.aligned: + raise ValueError("Sequence list must have aligned sequences.") + + # The function takes a biopython alignment object as input. + aln = sequence_list.to_bio_alignment_obj() + if reorder is not None: + aln_ref = aln[:1] # ref + aln_sorted = [aln[int(i)] for i in reorder] + aln_ref.extend(aln_sorted) + aln = aln_ref + + aln = aln[::-1] # So outputs are ordered from top to bottom + seqs = [rec.seq for rec in (aln)] # Each sequence input + text = [i for s in list(seqs) for i in s] # Al units joind on same list + + N = len(seqs[-1]) + S = len(seqs) + + # Shorten the description for display — take the part after the last ":" + # in the description field, falling back to rec.id if no ":" is present. + def matches(x): + return x.split(":")[-1] if ":" in x else x + + desc = [ + matches(rec.description) if rec.description != rec.id else rec.id for rec in aln + ] + colors_dict = {"exact": "white", "group": "orange", "none": "red"} + + # List with ALL colors + # By aminoacid group or exact match + if color_by_group: + col_colors = [] + font_colors = [] + match_keys = [] + for col in range(N): # Go through each column + # Note: AlignIO item retrieval is done through a get_item function, so this has to be done with a loop + col_string = aln[:, col] + color, font_color, match_key = get_colors_by_aa_group( + col_string, max_mismatch, colors_dict + ) + col_colors.append(color) + font_colors.append(font_color) + match_keys.append(match_key) + colors = col_colors * S + # Append each font_color list "colum-wise" + font_colors = np.array(font_colors).T.flatten() + else: + colors = get_colors_protein(seqs) + font_colors = ["black"] * len(colors) + + # Defining x indexes only for non-gap characters of ref sequence (seqs[-1]) + seq_array = np.array(list(seqs[-1])) + x_non_gap = np.full(len(seqs[-1]), " ", dtype="= 0) & (bsite_positions < len(non_gap_idx)) + ] + bsite_cols = np.sort(non_gap_idx[valid]) + + # Group consecutive alignment columns into contiguous spans + if len(bsite_cols) > 0: + span_start = bsite_cols[0] + span_end = bsite_cols[0] + for col in bsite_cols[1:]: + if col == span_end + 1: + span_end = col + else: + bsite_spans.append((span_start, span_end)) + span_start = span_end = col + bsite_spans.append((span_start, span_end)) + # creates a 2D grid of coords from the 1D arrays + xx, yy = np.meshgrid(x, y) + # flattens the arrays + gx = xx.ravel() + gy = yy.flatten() + # use recty for rect coords with an offset + recty = gy + 0.5 + # now we can create the ColumnDataSource with all the arrays + # logging.info(f"Aligning {S} sequences of lenght {N}") + # ColumnDataSource is a JSON dict that maps names to arrays of values + source = ColumnDataSource(dict(x=gx, y=gy, recty=recty, text=text, colors=colors)) + plot_height = len(seqs) * 10 + 50 + x_range = Range1d(gx[0] - 1, N + 8, bounds="auto") # (start, end) + if N > 150: + viewlen = 150 + else: + viewlen = N + # view_range is for the close up view + view_range = (gx[0] - 1, viewlen) + tools = "xpan, xwheel_zoom, reset, save" + + # Custom right-side labels — only rendered when scores are provided + if scores is not None: + _x_offsets = x_offsets if x_offsets is not None else N + 8 + right_labels1 = [f"{round(score, 1)}%" for score in scores][::-1] + source2 = ColumnDataSource( + data=dict( + x=[_x_offsets] * len(desc), + y=desc, + labels=right_labels1, + ) + ) + labels = LabelSet( + x="x", + y="y", + text="labels", + level="glyph", + x_offset=0, + y_offset=0, + source=source2, + text_align="left", + text_baseline="middle", + text_font_size=str(int(fontsize[:-2]) - 2) + "pt", + ) + else: + labels = None + + # entire sequence view (no text, with zoom) + p1 = figure( + title=None, + width=plot_width, + height=plot_height, + x_range=x_range, + y_range=desc, + tools=tools, + min_border=0, + ) + p1.toolbar_location = None + # Rect simply places rectangles of with "width" into the positions defined by x and y + rects = Rect( + x="x", + y="recty", + width=1, + height=1, + fill_color="colors", + line_color=None, + fill_alpha=0.6, + ) + # Source does mapping from keys in rects to values in ColumnDataSource definition + p1.add_glyph(source, rects) + p1.grid.visible = False + p1.xaxis.major_label_text_font_style = "bold" + p1.yaxis.major_label_text_font_size = "8pt" + p1.yaxis.minor_tick_line_width = 0 + p1.yaxis.major_tick_line_width = 0 + if labels is not None: + p1.add_layout(labels) + + def _add_bsite_boxes(plot, spans, n_seqs): + """Overlay a blue box for each continuous run of binding site columns.""" + if not spans: + return + from bokeh.models import BoxAnnotation + + for col_start, col_end in spans: + box = BoxAnnotation( + left=col_start - 0.5, + right=col_end + 0.5, + fill_color="steelblue", + fill_alpha=0.25, + line_color="steelblue", + line_alpha=0.6, + line_width=1.5, + ) + plot.add_layout(box) + + _add_bsite_boxes(p1, bsite_spans, S) + + plot_height = len(seqs) * 20 + 30 + + # sequence text view with ability to scroll along x axis + p2 = figure( + title=None, + width=plot_width, + height=plot_height, + x_range=view_range, + y_range=desc, + tools=tools, + min_border=0, + toolbar_location="below", + ) + # Text does the same thing as rectangles but placing letter (or words) instead, aligned accordingly + text_source = ColumnDataSource( + dict(x=gx, y=gy, recty=recty, text=text, colors=font_colors) + ) + glyph = Text( + x="x", + y="y", + text="text", + text_color="colors", + text_align="center", + text_font_size=fontsize, + ) + rects = Rect( + x="x", + y="recty", + width=1, + height=1, + fill_color="colors", + line_color=None, + fill_alpha=0.4, + ) + + # Blank plot to hold the position labels + p_blank = figure( + width=plot_width, + height=40, + x_range=view_range, + y_range=Range1d(0, 1), + title=None, + toolbar_location=None, + tools="", + outline_line_alpha=0, + ) + p_blank.xaxis.visible = False + p_blank.yaxis.visible = False + p_blank.grid.visible = False + label_source = ColumnDataSource(dict(x=x, y=[0.05] * len(x), text=x_non_gap)) + labels_b = Text( + x="x", + y="y", + text="text", + text_color="black", + text_align="center", + text_font_size=str(int(fontsize[:-2]) - 2) + "pt", + ) + p2.add_glyph(text_source, glyph) + p2.add_glyph(source, rects) + p_blank.add_glyph(label_source, labels_b) + if labels is not None: + p2.add_layout(labels) + _add_bsite_boxes(p2, bsite_spans, S) + _add_bsite_boxes(p_blank, bsite_spans, S) + + view_range = Range1d(gx[0] - 1, viewlen) + p2.grid.visible = True + p2.xaxis.major_label_text_font_style = "bold" + p2.yaxis.major_label_text_font_style = "bold" + p2.yaxis.minor_tick_line_width = 0 + p2.yaxis.major_tick_line_width = 0 + p2.xaxis.major_label_text_font_size = "0pt" + p2.add_layout( + LinearAxis(major_label_text_font_size="0pt", ticker=list(x_non_gap_locs)), + "above", + ) + p2.x_range = view_range + p_blank.x_range = view_range + + # --- Legend --- + # Build a row of labelled swatches explaining the colours used in the plot. + from bokeh.models import Div + + if color_by_group: + legend_items = [ + ("Exact match (all identical)", "white", "black"), + ("Group match (same amino acid group)", "orange", "black"), + ("No match", "red", "black"), + ] + legend_title = "Color key: amino acid group matching" + else: + # Per-amino-acid colour map from _AMINO_ACID_COLORS + _AA_COLORS = { + "A": "red", + "R": "blue", + "N": "green", + "D": "yellow", + "C": "orange", + "Q": "purple", + "E": "cyan", + "G": "magenta", + "H": "pink", + "I": "brown", + "L": "gray", + "K": "lime", + "M": "teal", + "F": "navy", + "P": "olive", + "S": "maroon", + "T": "silver", + "W": "gold", + "Y": "skyblue", + "V": "violet", + "-": "white", + } + _AA_NAMES = { + "A": "Ala", + "R": "Arg", + "N": "Asn", + "D": "Asp", + "C": "Cys", + "Q": "Gln", + "E": "Glu", + "G": "Gly", + "H": "His", + "I": "Ile", + "L": "Leu", + "K": "Lys", + "M": "Met", + "F": "Phe", + "P": "Pro", + "S": "Ser", + "T": "Thr", + "W": "Trp", + "Y": "Tyr", + "V": "Val", + "-": "Gap", + } + legend_items = [ + (f"{aa} ({_AA_NAMES[aa]})", color, "black") + for aa, color in _AA_COLORS.items() + ] + legend_title = "Color key: amino acid identity" + + def _swatch_html(label, fill, text_color="black", border="black"): + return ( + f'' + f'' + f'{label}' + f"" + ) + + swatches_html = "".join( + _swatch_html(label, fill, tc) for label, fill, tc in legend_items + ) + if bsite_spans: + swatches_html += _swatch_html( + "Binding site region", + fill="rgba(70,130,180,0.25)", + border="steelblue", + ) + + legend_div = Div( + text=( + f'
' + f"{legend_title}
" + f'
{swatches_html}
' + ), + width=plot_width, + ) + + p = column(p1, p_blank, p2, legend_div) + + out_dir = Path(output_dir) if output_dir is not None else Path.cwd() + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / f"{file_name}.html" + output_file(filename=str(out_path), title="Alignment result") + save(p) + + return p, str(out_path) diff --git a/drugforge-spectrum/drugforge/spectrum/tests/test_schema.py b/drugforge-spectrum/drugforge/spectrum/tests/test_schema.py new file mode 100644 index 00000000..79b9c9c5 --- /dev/null +++ b/drugforge-spectrum/drugforge/spectrum/tests/test_schema.py @@ -0,0 +1,78 @@ +import pytest +from Bio import SeqIO +from drugforge.spectrum.schema import ( + ProteinSequence, + SequenceList, + run_multiple_sequence_alignment, +) + + +def test_protein_sequence(): + seq = ProteinSequence( + seq_id="P12345", aligned=False, sequence="MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP" + ) + assert seq.seq_id == "P12345" + assert seq.sequence == "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP" + + +@pytest.fixture() +def unaligned_fasta_file(fasta_alignment_path, tmpdir): + sequences = SequenceList.from_fasta(fasta_alignment_path, aligned=True) + seq_list = SequenceList( + aligned=False, + sequences=[sequence.get_unaligned_sequence() for sequence in sequences], + ) + return seq_list.to_fasta(tmpdir / "unaligned.fasta") + + +class TestSequenceList: + def test_from_aligned_fasta(self, fasta_alignment_path): + sequences = SequenceList.from_fasta(fasta_alignment_path, aligned=True) + + with pytest.raises(ValueError): + sequences = SequenceList.from_fasta(fasta_alignment_path, aligned=False) + + def test_from_fasta_unaligned(self, unaligned_fasta_file): + sequences = SequenceList.from_fasta(unaligned_fasta_file, aligned=False) + + with pytest.raises(ValueError): + sequences = SequenceList.from_fasta(unaligned_fasta_file, aligned=True) + + def test_fasta_roundtrip_unaligned(self, unaligned_fasta_file, tmpdir): + sequences = SequenceList.from_fasta(unaligned_fasta_file, aligned=False) + created_path = sequences.to_fasta(tmpdir / "unaligned.fasta") + roundtripped = SequenceList.from_fasta(created_path, aligned=False) + assert sequences == roundtripped + + def test_csv_roundtrip_unaligned(self, unaligned_fasta_file, tmpdir): + sequences = SequenceList.from_fasta(unaligned_fasta_file, aligned=False) + created_path = sequences.to_csv(tmpdir / "unaligned.csv") + roundtripped = SequenceList.from_csv(created_path) + assert sequences == roundtripped + + def test_fasta_roundtrip_aligned(self, fasta_alignment_path, tmpdir): + sequences = SequenceList.from_fasta(fasta_alignment_path, aligned=True) + created_path = sequences.to_fasta(tmpdir / "unaligned.fasta") + roundtripped = SequenceList.from_fasta(created_path, aligned=True) + assert sequences == roundtripped + + def test_csv_roundtrip_aligned(self, fasta_alignment_path, tmpdir): + sequences = SequenceList.from_fasta(fasta_alignment_path, aligned=True) + created_path = sequences.to_csv(tmpdir / "unaligned.csv") + roundtripped = SequenceList.from_csv(created_path) + assert sequences == roundtripped + + +class TestSequenceAlignment: + def test_alignment(self, fasta_alignment_path, unaligned_fasta_file): + unaligned_sequences = SequenceList.from_fasta( + unaligned_fasta_file, aligned=False + ) + + aligned_sequences = run_multiple_sequence_alignment(unaligned_sequences) + + reference_aligned_sequences = SequenceList.from_fasta( + fasta_alignment_path, aligned=True + ) + + assert aligned_sequences == reference_aligned_sequences diff --git a/drugforge-spectrum/drugforge/spectrum/tests/test_score.py b/drugforge-spectrum/drugforge/spectrum/tests/test_score.py index e9dfff7c..6b2942e3 100644 --- a/drugforge-spectrum/drugforge/spectrum/tests/test_score.py +++ b/drugforge-spectrum/drugforge/spectrum/tests/test_score.py @@ -12,6 +12,7 @@ minimize_structure, score_autodock_vina, ) +from torch.fx.experimental.unification.multipledispatch.utils import raises def click_success(result): diff --git a/drugforge-spectrum/drugforge/spectrum/workflows/cli.py b/drugforge-spectrum/drugforge/spectrum/workflows/cli.py new file mode 100644 index 00000000..14b98710 --- /dev/null +++ b/drugforge-spectrum/drugforge/spectrum/workflows/cli.py @@ -0,0 +1,480 @@ +from pathlib import Path + +import click +from drugforge.data.util.logging import FileLogger +from drugforge.spectrum.alphafold import ( + make_fold_inputs, + make_msa_inputs, + select_best_af3, +) +from drugforge.spectrum.boltz import make_boltz_inputs +from drugforge.spectrum.calculate_rmsd import save_alignment_pymol +from drugforge.spectrum.docking import ligand_transfer_docking +from drugforge.spectrum.schema import ( + SequenceList, + find_bsite_resids, + run_multiple_sequence_alignment, + view_alignment, +) + + +def _parse_seeds(ctx, param, value: str) -> list[int]: + """Click callback – convert a comma-separated string to a list of ints.""" + try: + return [int(s.strip()) for s in value.split(",")] + except ValueError: + raise click.BadParameter( + f"Seeds must be comma-separated integers, got: {value!r}" + ) + + +@click.group("spectrum-cli") +def cli(): + pass + + +@cli.command( + "align-fasta", help="Use mafft to run multiple sequence alignment in FASTA format" +) +@click.argument("input_fasta") +@click.argument("output_dir", type=click.Path()) +def align_fasta(input_fasta, output_dir): + """""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger = FileLogger( + logname="align_fasta", path=output_dir, logfile="align_fasta.log" + ).getLogger() + logger.info(f"Aligning {input_fasta}...") + seq_list = SequenceList.from_fasta(input_fasta, aligned=False) + aligned = run_multiple_sequence_alignment(seq_list) + aligned.serialize(output_dir) + logger.info(f"Successfully aligned {len(aligned.sequences)} sequences.") + + +@cli.command( + "vizualize-alignment", help="Visualize aligned sequences using a bokeh plot." +) +@click.argument("input_fasta") +@click.argument("output_dir", type=click.Path()) +@click.option( + "--color-by-group/--no-color-by-group", + default=None, + help=( + "Color sequences by amino acid group match (--color-by-group) or by amino acid " + "identity (--no-color-by-group). By default both plots are produced." + ), +) +@click.option( + "--pdb", + default=None, + type=click.Path(exists=True), + help=( + "Path to a PDB file containing a protein-ligand complex. " + "When provided, binding site residues are highlighted with a blue box overlay." + ), +) +@click.option( + "--ligres", + default="LIG", + show_default=True, + help="Residue name of the ligand in the PDB file.", +) +@click.option( + "--chain", + default="A", + show_default=True, + help="Chain ID of the protein/ligand in the PDB file.", +) +@click.option( + "--bsite-dist", + default=4.5, + show_default=True, + type=float, + help="Distance in Å from the ligand used to define the binding site.", +) +def vizualize_alignment( + input_fasta, output_dir, color_by_group, pdb, ligres, chain, bsite_dist +): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger = FileLogger( + logname="vizualize_alignment", + path=output_dir, + logfile="vizualize_alignment.log", + ).getLogger() + + seq_list = SequenceList.from_fasta(input_fasta, aligned=True) + + bsite_resids = None + if pdb is not None: + logger.info(f"Detecting binding site residues from {pdb}...") + bsite_resids = find_bsite_resids( + pdb, ligres=ligres, chain=chain, bsite_dist=bsite_dist + ) + logger.info( + f"Found {len(bsite_resids)} binding site residues: {bsite_resids.tolist()}" + ) + + # Determine which modes to render: both by default, one if explicitly specified. + if color_by_group is None: + modes = [(True, "colored_by_group"), (False, "colored_by_amino_acid")] + elif color_by_group: + modes = [(True, "colored_by_group")] + else: + modes = [(False, "colored_by_amino_acid")] + + for by_group, file_name in modes: + view_alignment( + seq_list, + color_by_group=by_group, + start_idx=1, + output_dir=output_dir, + file_name=file_name, + plot_width=2400, + bsite_resids=bsite_resids, + ) + logger.info(f"Saved alignment plot to {output_dir / file_name}.html") + + +@cli.command( + "msa-input", help="Generate AF3 JSON inputs for the MSA (data pipeline) step." +) +@click.argument("fasta", type=click.Path(exists=True, dir_okay=False)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "--seeds", + "-s", + default="1,2,5,10", + show_default=True, + callback=_parse_seeds, + is_eager=True, + help="Comma-separated integer model seeds.", +) +@click.option( + "--description-prefix", + default="", + show_default=False, + help="Optional string prepended to each chain description, e.g. '2A protease'.", +) +def msa_input(fasta, output_dir, seeds, description_prefix): + """Generate one AF3 JSON per sequence in FASTA for the MSA step. + + MSA fields are left null so AlphaFold 3 runs its data pipeline + (Jackhmmer / Nhmmer). Run the resulting JSONs with --norun_inference. + + FASTA: path to the input FASTA file. + + OUTPUT_DIR: directory to write per-sequence JSON files. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger = FileLogger( + logname="msa_input", path=output_dir, logfile="msa_input.log" + ).getLogger() + inputs = make_msa_inputs( + fasta_path=fasta, seeds=seeds, description_prefix=description_prefix + ) + for af3_input in inputs: + out_path = af3_input.write(output_dir) + logger.info(f"Wrote {out_path}") + click.echo(f" Wrote {out_path}") + click.secho(f"\nWrote {len(inputs)} MSA-input JSON(s) to {output_dir}", fg="green") + + +@cli.command("fold-input", help="Generate AF3 fold-input JSONs from MSA outputs.") +@click.argument("msa_output_dir", type=click.Path(exists=True, file_okay=False)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "--seeds", + "-s", + default="1,2,5,10", + show_default=True, + callback=_parse_seeds, + is_eager=True, + help="Comma-separated integer model seeds.", +) +@click.option( + "--fasta", + "-f", + default=None, + type=click.Path(exists=True, dir_okay=False), + help=( + "Optional FASTA to control which sequences are processed and their " + "order. If omitted, every sub-directory in MSA_OUTPUT_DIR is used." + ), +) +def fold_input(msa_output_dir, output_dir, seeds, fasta): + """Generate fold-ready AF3 JSONs from pre-computed MSA outputs. + + Reads each /_data.json written by the AF3 data pipeline and + embeds unpairedMsa + templates for GPU inference with --norun_data_pipeline. + + MSA_OUTPUT_DIR: directory containing AF3 MSA outputs (one sub-dir per sequence). + + OUTPUT_DIR: directory to write per-sequence fold-input JSON files. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger = FileLogger( + logname="fold_input", path=output_dir, logfile="fold_input.log" + ).getLogger() + inputs = make_fold_inputs( + msa_output_dir=msa_output_dir, seeds=seeds, fasta_path=fasta + ) + for af3_input in inputs: + out_path = af3_input.write(output_dir) + logger.info(f"Wrote {out_path}") + click.echo(f" Wrote {out_path}") + click.secho(f"\nWrote {len(inputs)} fold-input JSON(s) to {output_dir}", fg="green") + + +@cli.command("af3-struct-alignment") +@click.argument( + "struct_dir", type=click.Path(exists=True, file_okay=False, path_type=Path) +) +@click.argument("ref_pdb", type=click.Path(exists=True, dir_okay=False, path_type=Path)) +@click.argument("output_dir", type=click.Path(file_okay=False, path_type=Path)) +@click.option( + "--chain", + "-c", + default="A", + show_default=True, + help="Chain ID to use for structural alignment.", +) +@click.option( + "--pymol-save", + default="af3_aligned.pse", + show_default=True, + help="Filename for the saved PyMOL session (written inside OUTPUT_DIR).", +) +@click.option( + "--color-by-rmsd", + is_flag=True, + default=False, + help="Color aligned structures by per-residue RMSD in the PyMOL session.", +) +@click.option( + "--fasta", + "-f", + default=None, + type=click.Path(exists=True, dir_okay=False, path_type=Path), + help=( + "Optional FASTA to control which sequences are processed and their " + "order. If omitted, every sub-directory in STRUCT_DIR is used." + ), +) +def af3_struct_alignment( + struct_dir, ref_pdb, output_dir, chain, pymol_save, color_by_rmsd, fasta +): + """Align AF3 fold outputs to a reference structure and save a PyMOL session. + + Walks STRUCT_DIR (one sub-directory per sequence), picks the top-ranked + AF3 CIF model for each sequence, aligns it to REF_PDB, saves each aligned + structure as a PDB in OUTPUT_DIR, then saves a combined PyMOL session. + + STRUCT_DIR: root AF3 fold output directory (one sub-dir per sequence). + + REF_PDB: reference PDB to align all structures against. + + OUTPUT_DIR: directory to write aligned PDBs and the PyMOL session. + """ + from drugforge.spectrum.schema import SequenceList + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger = FileLogger( + logname="af3_struct_alignment", + path=output_dir, + logfile="af3_struct_alignment.log", + ).getLogger() + + # Determine which sequences to process and in what order + if fasta is not None: + seq_list = SequenceList.from_fasta(fasta, aligned=False) + names = [seq.seq_id for seq in seq_list] + else: + names = sorted(p.name for p in struct_dir.iterdir() if p.is_dir()) + + if not names: + raise click.ClickException(f"No sequence directories found in {struct_dir}") + + aligned_pdbs = [] + seq_labels = [] + + for name in names: + final_pdb = output_dir / f"{name}_aligned.pdb" + try: + rmsd, aligned_pdb = select_best_af3( + af3_output_dir=struct_dir, + seq_name=name, + ref_pdb=ref_pdb, + chain=chain, + final_pdb=final_pdb, + ) + aligned_pdbs.append(aligned_pdb) + seq_labels.append(name) + logger.info(f"{name}: RMSD = {rmsd:.3f} Å → {aligned_pdb}") + click.echo(f" {name}: RMSD = {rmsd:.3f} Å") + except FileNotFoundError as e: + logger.warning(str(e)) + click.secho(f" WARN: {e}", fg="yellow") + + if not aligned_pdbs: + raise click.ClickException("No structures were successfully aligned.") + + session_save = output_dir / pymol_save + save_alignment_pymol( + aligned_pdbs, seq_labels, str(ref_pdb), str(session_save), chain, color_by_rmsd + ) + logger.info(f"Saved PyMOL session to {session_save}") + click.secho( + f"\nAligned {len(aligned_pdbs)} structure(s). PyMOL session → {session_save}", + fg="green", + ) + + +@cli.command("make-boltz-input", help="Generate Boltz YAML inputs from a FASTA file.") +@click.argument("fasta", type=click.Path(exists=True, dir_okay=False)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "--ligand-smiles", + "-l", + default=None, + help="SMILES string for a ligand to include in every input YAML.", +) +@click.option( + "--ligand-id", + default="L", + show_default=True, + help="Chain ID to assign to the ligand.", +) +def make_boltz_input(fasta, output_dir, ligand_smiles, ligand_id): + """Generate one Boltz YAML per sequence in FASTA. + + Boltz uses --use_msa_server so no separate MSA step is required. + Run the resulting YAMLs with: boltz predict --use_msa_server + + FASTA: path to the input FASTA file. + + OUTPUT_DIR: directory to write per-sequence YAML files. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger = FileLogger( + logname="make_boltz_input", path=output_dir, logfile="make_boltz_input.log" + ).getLogger() + inputs = make_boltz_inputs( + fasta_path=fasta, + ligand_smiles=ligand_smiles, + ligand_id=ligand_id, + ) + for boltz_input in inputs: + out_path = boltz_input.write(output_dir) + logger.info(f"Wrote {out_path}") + click.echo(f" Wrote {out_path}") + click.secho(f"\nWrote {len(inputs)} Boltz YAML(s) to {output_dir}", fg="green") + + +@cli.command("ligand-transfer-docking") +@click.argument( + "target_structure_dir", + type=click.Path(exists=True, file_okay=False, path_type=Path), +) +@click.argument( + "reference_fragalysis_dir", + type=click.Path(exists=True, file_okay=False, path_type=Path), +) +@click.argument("output_dir", type=click.Path(file_okay=False, path_type=Path)) +@click.option( + "--ref-chain", + default="A", + show_default=True, + help="Chain ID in the reference complex used for structural alignment.", +) +@click.option( + "--active-site-chain", + default="A", + show_default=True, + help="Chain ID in the target structure used for structural alignment.", +) +@click.option( + "--posit-confidence-cutoff", + default=0.1, + show_default=True, + type=float, + help="Minimum POSIT confidence score to keep a docking result.", +) +@click.option( + "--num-poses", + default=1, + show_default=True, + type=int, + help="Number of docked poses to return per pair.", +) +@click.option( + "--use-omega", + is_flag=True, + default=False, + help="Enumerate conformers with OEOmega before docking (slower, more accurate).", +) +@click.option( + "--allow-final-clash/--no-allow-final-clash", + default=True, + show_default=True, + help="Keep poses that clash in the final docking stage.", +) +@click.option( + "--no-overwrite", + is_flag=True, + default=False, + help="Do not overwrite output_dir if it already exists.", +) +def ligand_transfer_docking_cmd( + target_structure_dir, + reference_fragalysis_dir, + output_dir, + ref_chain, + active_site_chain, + posit_confidence_cutoff, + num_poses, + use_omega, + allow_final_clash, + no_overwrite, +): + """Run POSIT ligand transfer docking for predicted protein structures. + + Aligns each PDB in TARGET_STRUCTURE_DIR to every reference complex in + REFERENCE_FRAGALYSIS_DIR, transfers the reference ligand coordinates, and + runs POSIT self-docking on each pair. No target tag or ML models required. + + TARGET_STRUCTURE_DIR: directory of predicted PDB files (one per sequence). + + REFERENCE_FRAGALYSIS_DIR: Fragalysis-format directory of crystal complexes + (/.pdb + /.sdf). + + OUTPUT_DIR: directory to write docking results. + """ + import logging + + df = ligand_transfer_docking( + target_structure_dir=target_structure_dir, + reference_fragalysis_dir=reference_fragalysis_dir, + output_dir=output_dir, + ref_chain=ref_chain, + active_site_chain=active_site_chain, + posit_confidence_cutoff=posit_confidence_cutoff, + num_poses=num_poses, + use_omega=use_omega, + allow_final_clash=allow_final_clash, + overwrite=not no_overwrite, + loglevel=logging.INFO, + ) + click.secho( + f"\nDocking complete: {len(df)} results → {output_dir / 'docking_results_final.csv'}", + fg="green", + ) + + +if __name__ == "__main__": + cli() diff --git a/drugforge-spectrum/pyproject.toml b/drugforge-spectrum/pyproject.toml index 7f66ba0b..bc95452e 100644 --- a/drugforge-spectrum/pyproject.toml +++ b/drugforge-spectrum/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [] "Bug Tracker" = "https://github.com/choderalab/drugforge/issues" [project.scripts] +drugforge-spectrum = "drugforge.spectrum.workflows.cli:cli" [tool.setuptools.packages.find] where = ["."] diff --git a/drugforge-workflows/drugforge/workflows/spectrum_workflows/cli.py b/drugforge-workflows/drugforge/workflows/spectrum_workflows/cli.py index 3ba832ae..957e2f39 100644 --- a/drugforge-workflows/drugforge/workflows/spectrum_workflows/cli.py +++ b/drugforge-workflows/drugforge/workflows/spectrum_workflows/cli.py @@ -668,5 +668,23 @@ def score( score_complex_workflow(inputs) +from drugforge.spectrum.workflows.cli import ( + af3_struct_alignment, + align_fasta, + fold_input, + ligand_transfer_docking_cmd, + make_boltz_input, + msa_input, + vizualize_alignment, +) + +spectrum.add_command(align_fasta, name="align-fasta") +spectrum.add_command(vizualize_alignment, name="vizualize-alignment") +spectrum.add_command(msa_input, name="msa-input") +spectrum.add_command(fold_input, name="fold-input") +spectrum.add_command(af3_struct_alignment, name="af3-struct-alignment") +spectrum.add_command(make_boltz_input, name="make-boltz-input") +spectrum.add_command(ligand_transfer_docking_cmd, name="ligand-transfer-docking") + if __name__ == "__main__": spectrum()