diff --git a/asapdiscovery-cli/asapdiscovery/cli/cli.py b/asapdiscovery-cli/asapdiscovery/cli/cli.py index a792e854..7f25430f 100644 --- a/asapdiscovery-cli/asapdiscovery/cli/cli.py +++ b/asapdiscovery-cli/asapdiscovery/cli/cli.py @@ -4,7 +4,6 @@ @click.group() def cli(help="Command-line interface for asapdiscovery"): ... - from asapdiscovery.workflows.docking_workflows.cli import ( # noqa: F401, E402, F811 docking, ) diff --git a/asapdiscovery-docking/asapdiscovery/docking/docking.py b/asapdiscovery-docking/asapdiscovery/docking/docking.py index 86477ff9..9240fd89 100644 --- a/asapdiscovery-docking/asapdiscovery/docking/docking.py +++ b/asapdiscovery-docking/asapdiscovery/docking/docking.py @@ -361,11 +361,13 @@ def write_results_to_multi_sdf( raise ValueError("Must provide reconstruct_cls if using disk backend") for res in results: - if backend == BackendType.IN_MEMORY: + if backend == BackendType.DISK: + res = reconstruct_cls.from_json_file(res) + if backend in [BackendType.DISK, BackendType.IN_MEMORY]: lig = res.posed_ligand - elif backend == BackendType.DISK: - lig = reconstruct_cls.from_json_file(res).posed_ligand else: raise ValueError(f"Unknown backend type {backend}") + lig.set_SD_data({"ReferenceStructureName": res.input_pair.complex.target.target_name, + "ReferenceLigandName": res.input_pair.complex.ligand.compound_name}) lig.to_sdf(sdf_file, allow_append=True) diff --git a/asapdiscovery-docking/asapdiscovery/docking/openeye.py b/asapdiscovery-docking/asapdiscovery/docking/openeye.py index d59a413c..c2c35ddc 100644 --- a/asapdiscovery-docking/asapdiscovery/docking/openeye.py +++ b/asapdiscovery-docking/asapdiscovery/docking/openeye.py @@ -225,6 +225,7 @@ def _dock( docking_results = [] + logger.info(f"Running docking on {len(inputs)} sets") for set in inputs: try: # make sure its a path @@ -253,9 +254,11 @@ def _dock( ) # run docking if output does not exist else: + logger.info("Making design units") dus = set.to_design_units() lig_oemol = oechem.OEMol(set.ligand.to_oemol()) if self.use_omega: + logger.info("Running omega on ligand") if self.omega_dense: omegaOpts = oeomega.OEOmegaOptions( oeomega.OEOmegaSampling_Dense @@ -276,12 +279,13 @@ def _dock( raise ValueError( f"Unknown error handling option {failure_mode}" ) - + logger.info("Prepping docking options") opts = oedocking.OEPositOptions() opts.SetIgnoreNitrogenStereo(True) opts.SetPositMethods(self.posit_method.value) opts.SetPoseRelaxMode(self.relax_mode.value) + logger.info("Run docking") pose_res = oedocking.OEPositResults() pose_res, retcode = self.run_oe_posit_docking( opts, pose_res, dus, lig_oemol, self.num_poses @@ -336,9 +340,11 @@ def _dock( ) if retcode == oedocking.OEDockingReturnCode_Success: + logger.info("Docking Success!") input_pairs = [] posed_ligands = [] num_poses = pose_res.GetNumPoses() + logger.info(f"Prepping results for {num_poses} poses") for i, result in enumerate(pose_res.GetSinglePoseResults()): posed_mol = result.GetPose() prob = result.GetProbability() @@ -370,6 +376,7 @@ def _dock( input_pairs.append(set) # Create Docking Results Objects + logger.info("Creating POSITDockingResults objects") docking_results_objects = [] for input_pair, posed_ligand in zip(input_pairs, posed_ligands): docking_results_objects.append( diff --git a/asapdiscovery-docking/asapdiscovery/docking/scorer.py b/asapdiscovery-docking/asapdiscovery/docking/scorer.py index dc60b2dc..42cc3ef0 100644 --- a/asapdiscovery-docking/asapdiscovery/docking/scorer.py +++ b/asapdiscovery-docking/asapdiscovery/docking/scorer.py @@ -1,20 +1,15 @@ import abc import logging -import warnings from enum import Enum -from io import StringIO from pathlib import Path from typing import Any, ClassVar, Optional, Union -import MDAnalysis as mda import numpy as np import pandas as pd -from asapdiscovery.data.backend.openeye import oedocking, oemol_to_pdb_string -from asapdiscovery.dataviz.plip import compute_fint_score +from asapdiscovery.data.backend.openeye import oedocking from asapdiscovery.data.schema.complex import Complex from asapdiscovery.data.schema.ligand import Ligand, LigandIdentifiers from asapdiscovery.data.schema.target import TargetIdentifiers -from asapdiscovery.data.services.postera.manifold_data_validation import TargetTags from asapdiscovery.data.util.dask_utils import ( BackendType, FailureMode, @@ -23,12 +18,8 @@ ) from asapdiscovery.docking.docking import DockingResult from asapdiscovery.docking.docking_data_validation import DockingResultCols -from asapdiscovery.ml.inference import InferenceBase, get_inference_cls_from_model_type -from asapdiscovery.ml.models import MLModelSpecBase -from asapdiscovery.spectrum.fitness import target_has_fitness_data -from mtenn.config import ModelType from multimethod import multimethod -from pydantic.v1 import BaseModel, Field, validator +from pydantic.v1 import BaseModel, Field logger = logging.getLogger(__name__) @@ -60,39 +51,6 @@ class ScoreUnits(str, Enum): # TODO: this is a massive kludge, need to refactor -def endpoint_and_model_type_to_score_type(endpoint: str, model_type: str) -> ScoreType: - """ - Convert an endpoint to a score type. - - Parameters - ---------- - endpoint : str - Endpoint to convert - - Returns - ------- - ScoreType - Score type - """ - if model_type == ModelType.GAT: - if endpoint == "pIC50": # TODO: make this an enum - return ScoreType.GAT_pIC50 - elif endpoint == "LogD": - return ScoreType.GAT_LogD - else: - raise ValueError(f"Endpoint {endpoint} not recognized, for GAT") - elif model_type == ModelType.schnet: - if endpoint == "pIC50": - return ScoreType.schnet_pIC50 - else: - raise ValueError(f"Endpoint {endpoint} not recognized, for Schnet") - elif model_type == ModelType.e3nn: - if endpoint == "pIC50": - return ScoreType.e3nn_pIC50 - else: - raise ValueError(f"Endpoint {endpoint} not recognized for E3NN") - else: - raise ValueError(f"Model type {model_type} not recognized") _SCORE_MANIFOLD_ALIAS = { @@ -261,7 +219,7 @@ class ScorerBase(BaseModel): score_units: ClassVar[ScoreUnits.INVALID] = ScoreUnits.INVALID @abc.abstractmethod - def _score() -> list[DockingResult]: ... + def _score(self) -> list[DockingResult]: ... def score( self, @@ -355,15 +313,6 @@ def scores_to_df(scores: list[Score]) -> pd.DataFrame: return df -def _get_disk_path_from_docking_result(docking_result: DockingResult) -> Path: - if docking_result.provenance is None: - raise ValueError("DockingResult does not have provenance") - disk_path = docking_result.provenance.get("on_disk_location", None) - if not disk_path: - raise ValueError("DockingResult provenance does not have on_disk_location") - return disk_path - - class ChemGauss4Scorer(ScorerBase): """ Scoring using ChemGauss. @@ -452,405 +401,6 @@ def _dispatch(self, inputs: list[Path], **kwargs) -> list[Score]: return self._dispatch(complexes) -class FINTScorer(ScorerBase): - """ - Score using Fitness Interaction Score - - Overloaded to accept DockingResults, Complexes, or Paths to PDB files. - """ - - score_type: ScoreType = Field(ScoreType.FINT, description="Type of score") - units: ClassVar[ScoreUnits.arbitrary] = ScoreUnits.arbitrary - target: TargetTags = Field(..., description="Which target to use for scoring") - - @validator("target") - @classmethod - def validate_target(cls, v): - if not target_has_fitness_data(v): - raise ValueError( - "target does not have fitness data so cannot use FINTScorer" - ) - return v - - @dask_vmap(["inputs"]) - @backend_wrapper("inputs") - def _score( - self, - inputs: Union[list[DockingResult], list[Complex], list[Path]], - return_for_disk_backend: bool = False, - **kwargs, - ) -> list[Score]: - """ - Score the inputs, dispatching based on type. - """ - return self._dispatch( - inputs, return_for_disk_backend=return_for_disk_backend, **kwargs - ) - - @multimethod - def _dispatch( - self, - inputs: list[DockingResult], - return_for_disk_backend: bool = False, - **kwargs, - ) -> list[Score]: - """ - Dispatch for DockingResults - """ - results = [] - for inp in inputs: - _, fint_score = compute_fint_score( - inp.to_protein(), inp.posed_ligand.to_oemol(), self.target - ) - - sc = Score.from_score_and_docking_result( - fint_score, self.score_type, self.units, inp - ) - # overwrite the input with the path to the file - if return_for_disk_backend: - sc.input = _get_disk_path_from_docking_result(inp) - - results.append(sc) - - return results - - @_dispatch.register - def _dispatch(self, inputs: list[Complex], **kwargs): - """ - Dispatch for Complexes - """ - results = [] - for inp in inputs: - _, fint_score = compute_fint_score( - inp.target.to_oemol(), inp.ligand.to_oemol(), self.target - ) - results.append( - Score.from_score_and_complex( - fint_score, self.score_type, self.units, inp - ) - ) - return results - - @_dispatch.register - def _dispatch(self, inputs: list[Path], **kwargs): - """ - Dispatch for PDB files from disk - """ - # assuming reading PDB files from disk - complexes = [ - Complex.from_pdb( - p, - ligand_kwargs={"compound_name": f"{p.stem}_ligand"}, - target_kwargs={"target_name": f"{p.stem}_target"}, - ) - for p in inputs - ] - - return self._dispatch(complexes, **kwargs) - - -# keep track of all the ml scorers -_ml_scorer_classes_meta = [] - - -# decorator to register all the ml scorers -def register_ml_scorer(cls): - _ml_scorer_classes_meta.append(cls) - return cls - - -class MLModelScorer(ScorerBase): - """ - Baseclass to score from some kind of ML model, including 2D or 3D models - """ - - model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID - score_type: ScoreType = Field(..., description="Type of score") - endpoint: Optional[str] = Field(None, description="Endpoint biological property") - units: ClassVar[ScoreUnits.INVALID] = ScoreUnits.INVALID - - targets: Any = Field( - ..., - description="Which targets can this model do predictions for", # FIXME: Optional[set[TargetTags]] - ) - model_name: str = Field(..., description="String indicating which model to use") - inference_cls: InferenceBase = Field(..., description="Inference class") - - @classmethod - def from_latest_by_target(cls, target: TargetTags): - if cls.model_type == ModelType.INVALID: - raise Exception("trying to instantiate some kind a baseclass") - inference_cls = get_inference_cls_from_model_type(cls.model_type) - inference_instance = inference_cls.from_latest_by_target(target) - if inference_instance is None: - logger.warn( - f"no ML model of type {cls.model_type} found for target: {target}, skipping" - ) - return None - else: - try: - instance = cls( - targets=inference_instance.targets, - model_name=inference_instance.model_name, - inference_cls=inference_instance, - endpoint=inference_instance.model_spec.endpoint, - score_type=endpoint_and_model_type_to_score_type( - inference_instance.model_spec.endpoint, cls.model_type - ), - ) - return instance - except Exception as e: - logger.error(f"error instantiating MLModelScorer: {e}") - return None - - @staticmethod - def from_latest_by_target_and_type(target: TargetTags, type: ModelType): - """ - Get the latest ML Scorer by target and type. - - Parameters - ---------- - target : TargetTags - Target to get the scorer for - type : ModelType - Type of model to get the scorer for - """ - if type == ModelType.INVALID: - raise Exception("trying to instantiate some kind a baseclass") - scorer_class = get_ml_scorer_cls_from_model_type(type) - return scorer_class.from_latest_by_target(target) - - @classmethod - def from_model_name(cls, model_name: str): - if cls.model_type == ModelType.INVALID: - raise Exception("trying to instantiate some kind a baseclass") - inference_cls = get_inference_cls_from_model_type(cls.model_type) - inference_instance = inference_cls.from_model_name(model_name) - if inference_instance is None: - logger.warn( - f"no ML model of type {cls.model_type} found for model_name: {model_name}, skipping" - ) - return None - else: - try: - instance = cls( - targets=inference_instance.targets, - model_name=inference_instance.model_name, - inference_cls=inference_instance, - endpoint=inference_instance.model_spec.endpoint, - score_type=endpoint_and_model_type_to_score_type( - inference_instance.model_spec.endpoint, cls.model_type - ), - ) - return instance - except Exception as e: - logger.error(f"error instantiating MLModelScorer: {e}") - return None - - @staticmethod - def load_model_specs( - models: list[MLModelSpecBase], - ) -> list["MLModelScorer"]: # noqa: F821 - """ - Load a list of models into scorers. - - Parameters - ---------- - models : list[MLModelSpecBase] - List of models to load - """ - scorers = [] - for model in models: - scorer_class = get_ml_scorer_cls_from_model_type(model.type) - scorer = scorer_class.from_model_name(model.name) - if scorer is not None: - scorers.append(scorer) - return scorers - - -@register_ml_scorer -class GATScorer(MLModelScorer): - """ - Scoring using GAT ML Model - """ - - model_type: ClassVar[ModelType.GAT] = ModelType.GAT - units: ClassVar[ScoreUnits.pIC50] = ScoreUnits.pIC50 - - @dask_vmap(["inputs"]) - @backend_wrapper("inputs") - def _score( - self, - inputs: Union[list[DockingResult], list[str], list[Ligand]], - return_for_disk_backend: bool = False, - **kwargs, - ) -> list[Score]: - """ - Score the inputs, dispatching based on type. - """ - return self._dispatch( - inputs, return_for_disk_backend=return_for_disk_backend, **kwargs - ) - - @multimethod - def _dispatch( - self, - inputs: list[DockingResult], - return_for_disk_backend: bool = False, - **kwargs, - ) -> list[Score]: - """ - Dispatch for DockingResults - """ - results = [] - for inp in inputs: - gat_score = self.inference_cls.predict_from_smiles(inp.posed_ligand.smiles) - sc = Score.from_score_and_docking_result( - gat_score, - self.score_type, - self.units, - inp, - ) - # overwrite the input with the path to the file - if return_for_disk_backend: - sc.input = _get_disk_path_from_docking_result(inp) - results.append(sc) - return results - - @_dispatch.register - def _dispatch(self, inputs: list[str], **kwargs) -> list[Score]: - """ - Dispatch for SMILES strings - """ - results = [] - for inp in inputs: - gat_score = self.inference_cls.predict_from_smiles(inp) - results.append( - Score.from_score_and_smiles( - gat_score, - inp, - self.score_type, - self.units, - ) - ) - return results - - @_dispatch.register - def _dispatch(self, inputs: list[Ligand], **kwargs) -> list[Score]: - """ - Dispatch for Ligands - """ - results = [] - for inp in inputs: - gat_score = self.inference_cls.predict_from_smiles(inp.smiles) - results.append( - Score.from_score_and_ligand( - gat_score, - inp, - self.score_type, - self.units, - ) - ) - return results - - -class E3MLModelScorer(MLModelScorer): - """ - Scoring using ML Models that operate over 3D structures - These all share an interface so we can use multimethods to dispatch - for the different input types for all subclasses. - """ - - model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID - units: ClassVar[ScoreUnits.INVALID] = ScoreUnits.INVALID - - @dask_vmap(["inputs"]) - @backend_wrapper("inputs") - def _score( - self, - inputs: Union[list[DockingResult], list[Complex], list[Path]], - return_for_disk_backend: bool = False, - **kwargs, - ) -> list[Score]: - return self._dispatch( - inputs, return_for_disk_backend=return_for_disk_backend, **kwargs - ) - - @multimethod - def _dispatch( - self, - inputs: list[DockingResult], - return_for_disk_backend: bool = False, - **kwargs, - ) -> list[Score]: - results = [] - for inp in inputs: - score = self.inference_cls.predict_from_oemol(inp.to_posed_oemol()) - - sc = Score.from_score_and_docking_result( - score, self.score_type, self.units, inp - ) - # overwrite the input with the path to the file - if return_for_disk_backend: - sc.input = _get_disk_path_from_docking_result(inp) - results.append(sc) - - return results - - @_dispatch.register - def _dispatch(self, inputs: list[Complex], **kwargs) -> list[Score]: - results = [] - for inp in inputs: - score = self.inference_cls.predict_from_oemol(inp.to_combined_oemol()) - results.append( - Score.from_score_and_complex(score, self.score_type, self.units, inp) - ) - return results - - @_dispatch.register - def _dispatch(self, inputs: list[Path], **kwargs) -> list[Score]: - # assuming reading PDB files from disk - complexes = [ - Complex.from_pdb( - p, - ligand_kwargs={"compound_name": f"{p.stem}_ligand"}, - target_kwargs={"target_name": f"{p.stem}_target"}, - ) - for i, p in enumerate(inputs) - ] - return self._dispatch(complexes, **kwargs) - - -@register_ml_scorer -class SchnetScorer(E3MLModelScorer): - """ - Scoring using Schnet ML Model - """ - - model_type: ClassVar[ModelType.schnet] = ModelType.schnet - units: ClassVar[ScoreUnits.pIC50] = ScoreUnits.pIC50 - - -@register_ml_scorer -class E3NNScorer(E3MLModelScorer): - """ - Scoring using e3nn ML Model - """ - - model_type: ClassVar[ModelType.e3nn] = ModelType.e3nn - units: ClassVar[ScoreUnits.pIC50] = ScoreUnits.pIC50 - - -def get_ml_scorer_cls_from_model_type(model_type: ModelType): - instantiable_classes = [ - m for m in _ml_scorer_classes_meta if m.model_type != ModelType.INVALID - ] - scorer_class = [m for m in instantiable_classes if m.model_type == model_type] - if len(scorer_class) != 1: - raise Exception("Somehow got multiple scorers") - return scorer_class[0] - - class MetaScorer(BaseModel): """ Score from a combination of other scorers, the scorers must share an input type, @@ -893,98 +443,10 @@ def score( return np.ravel(results).tolist() -class SymClashScorer(ScorerBase): - """ - Scoring, checking for clashes between ligand and target - in neighboring unit cells. - """ - - score_type: ScoreType = Field(ScoreType.sym_clash, description="Type of score") - units: ClassVar[ScoreUnits.arbitrary] = ScoreUnits.arbitrary - - count_clashing_pairs: bool = Field( - False, - description="Whether to count clashing distance pairs, rather than unique clashing ligand atoms", - ) - - vdw_radii_fudge_factor: float = Field( - 1.0, - description="fudge factor multiplier for vdw radii, lower to decrease clash sensitivity, higher to increase", - ) - - @dask_vmap(["inputs"]) - @backend_wrapper("inputs") - def _score(self, inputs, **kwargs) -> list[Score]: - """ - Score the inputs, dispatching based on type. - """ - return self._dispatch(inputs, **kwargs) - - @multimethod - def _dispatch(self, inputs: list[Complex], **kwargs) -> list[Score]: - """ - Dispatch for Complex - """ - results = [] - warnings.warn( - "SymClashScorer relies on expanded protein units having chain X as constructed by SymmetryExpander" - ) - for inp in inputs: - # load into MDA universe - u = mda.Universe( - mda.lib.util.NamedStream( - StringIO(oemol_to_pdb_string(inp.to_combined_oemol())), - "complex.pdb", - ) - ) - lig = u.select_atoms("not protein") - symmetry_expanded_prot = u.select_atoms("protein and chainID X") - # hacky but expand to real space with mega box - # multiply first 3 dimensions by 20 - expanded_box = u.dimensions - expanded_box[:3] *= 20 - pair_indices, pair_distances = mda.lib.distances.capped_distance( - lig, - symmetry_expanded_prot, - 4, - box=expanded_box, # large cutoff to loop in a good amount of distances up to 8Å - ) - # check if distance for an atom pair is less than summed vdw radii - num_clashes = 0 - clashing_lig_at = set() - clashing_prot_at = set() - - for k, [i, j] in enumerate(pair_indices): - lig_atom = lig[i] - prot_atom = symmetry_expanded_prot[j] - distance = pair_distances[k] - if ( - ( - distance - < ( - ( - mda.topology.tables.vdwradii[lig_atom.element.upper()] - * self.vdw_radii_fudge_factor - ) - + ( - mda.topology.tables.vdwradii[prot_atom.element.upper()] - * self.vdw_radii_fudge_factor - ) - ) - ) - and lig_atom.element != "H" - and prot_atom.element != "H" - ): - num_clashes += 1 - clashing_lig_at.add(i) - clashing_prot_at.add(j) - - if self.count_clashing_pairs: - val = num_clashes - else: - val = len(clashing_lig_at) # seems ok as metric for now - - results.append( - Score.from_score_and_complex(val, self.score_type, self.units, inp) - ) - return results +def _get_disk_path_from_docking_result(docking_result: DockingResult) -> Path: + if docking_result.provenance is None: + raise ValueError("DockingResult does not have provenance") + disk_path = docking_result.provenance.get("on_disk_location", None) + if not disk_path: + raise ValueError("DockingResult provenance does not have on_disk_location") + return disk_path diff --git a/asapdiscovery-docking/asapdiscovery/docking/tests/test_docking.py b/asapdiscovery-docking/asapdiscovery/docking/tests/test_docking.py index 8d1bb1c1..2dead0cf 100644 --- a/asapdiscovery-docking/asapdiscovery/docking/tests/test_docking.py +++ b/asapdiscovery-docking/asapdiscovery/docking/tests/test_docking.py @@ -116,6 +116,36 @@ def test_multipose_docking_with_cache_and_writing( ) assert len(list(tmp_path.glob("docking_results/*/*.sdf"))) == num_poses_expected + @pytest.mark.parametrize("num_poses", [1, 5, 10, 20, 50]) + @pytest.mark.parametrize("use_omega", [True, False]) + def test_multipose_docking_speed( + self, docking_input_pair, tmp_path, num_poses, use_omega + ): + """Test how docking time scales with number of poses requested.""" + import time + + # Initialize docker with specific number of poses + docker = POSITDocker(use_omega=use_omega, num_poses=num_poses) + + # Time the docking + start_time = time.time() + results = docker.dock( + [docking_input_pair], output_dir=tmp_path / f"docking_results_{num_poses}" + ) + end_time = time.time() + + # Basic assertions + assert len(results) > 0 + assert results[0].probability > 0.0 + + # Store timing info in the test report + timing = end_time - start_time + poses_generated = len(results) + print( + f"Docking with {num_poses} poses took {timing:.2f}s " + f"and generated {poses_generated} poses" + ) + def test_results_to_df(self, results_simple): df = results_simple[0].to_df() assert DockingResultCols.SMILES in df.columns diff --git a/asapdiscovery-docking/asapdiscovery/docking/tests/test_scorers.py b/asapdiscovery-docking/asapdiscovery/docking/tests/test_scorers.py index 3de820fa..56726ca8 100644 --- a/asapdiscovery-docking/asapdiscovery/docking/tests/test_scorers.py +++ b/asapdiscovery-docking/asapdiscovery/docking/tests/test_scorers.py @@ -1,12 +1,9 @@ import pytest from asapdiscovery.docking.scorer import ( ChemGauss4Scorer, - E3NNScorer, - FINTScorer, - GATScorer, MetaScorer, - SchnetScorer, ) +from asapdiscovery.workflows.scorers import GATScorer, SchnetScorer, E3NNScorer, FINTScorer # parametrize over fixtures diff --git a/asapdiscovery-docking/asapdiscovery/docking/workflows/__init__.py b/asapdiscovery-docking/asapdiscovery/docking/workflows/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/asapdiscovery-docking/asapdiscovery/docking/workflows/cli.py b/asapdiscovery-docking/asapdiscovery/docking/workflows/cli.py new file mode 100644 index 00000000..0a95d5c3 --- /dev/null +++ b/asapdiscovery-docking/asapdiscovery/docking/workflows/cli.py @@ -0,0 +1,140 @@ +import click +import logging +from typing import Optional, Union +from asapdiscovery.docking.workflows.cross_docking import CrossDockingWorkflowInputs, cross_docking_workflow +from asapdiscovery.cli.cli_args import target, ligands, pdb_file, fragalysis_dir, structure_dir, save_to_cache, cache_dir, use_only_cache, dask_args, output_dir, overwrite, input_json, loglevel +from asapdiscovery.docking.openeye import POSIT_METHOD, POSIT_RELAX_MODE +from asapdiscovery.docking.selectors.selector_list import StructureSelector +from asapdiscovery.data.services.postera.manifold_data_validation import TargetTags +from asapdiscovery.data.util.dask_utils import DaskType, FailureMode + +@click.group() +def cli(help="Command-line interface for asapdiscovery-docking"): ... + +@cli.command() +@target +@click.option( + "--use-omega", + is_flag=True, + default=False, + help="Whether to use OEOmega conformer enumeration before docking (slower, more accurate)", +) +@click.option( + "--omega-dense", + is_flag=True, + default=False, + help="Whether to use dense conformer enumeration with OEOmega (slower, more accurate)", +) +@click.option( + "--posit-method", + type=click.Choice(POSIT_METHOD.get_names(), case_sensitive=False), + default="all", + help="The set of methods POSIT can use. Defaults to all.", +) +@click.option( + "--relax-mode", + type=click.Choice(POSIT_RELAX_MODE.get_names(), case_sensitive=False), + default="none", + help="When to check for relaxation either, 'clash', 'all', 'none'", +) +@click.option( + "--allow-retries", + is_flag=True, + default=False, + help="Whether to allow POSIT to retry with relaxed parameters if docking fails (slower, more likely to succeed)", +) +@click.option( + "--allow-final-clash", + is_flag=True, + default=False, + help="Allow clashing poses in last stage of docking", +) +@click.option( + "--multi-reference", + is_flag=True, + default=False, + help="Whether to pass multiple references to the docker for each ligand instead of just one at a time", +) +@click.option( + "--structure-selector", + type=click.Choice(StructureSelector.get_values(), case_sensitive=False), + default=StructureSelector.LEAVE_SIMILAR_OUT.value, + help="The type of structure selector to use.", +) +@click.option("--num-poses", type=int, default=1, help="Number of poses to generate") +@ligands +@pdb_file +@fragalysis_dir +@structure_dir +@save_to_cache +@cache_dir +@use_only_cache +@dask_args +@output_dir +@overwrite +@input_json +@loglevel +def cross_docking( + target: TargetTags, + multi_reference: bool = False, + structure_selector: StructureSelector = StructureSelector.LEAVE_SIMILAR_OUT, + use_omega: bool = False, + omega_dense: bool = False, + posit_method: Optional[str] = POSIT_METHOD.ALL.name, + relax_mode: Optional[str] = POSIT_RELAX_MODE.NONE.name, + num_poses: int = 1, + allow_retries: bool = False, + allow_final_clash: bool = False, + ligands: Optional[str] = None, + pdb_file: Optional[str] = None, + fragalysis_dir: Optional[str] = None, + structure_dir: Optional[str] = None, + use_only_cache: bool = False, + save_to_cache: Optional[bool] = True, + cache_dir: Optional[str] = None, + output_dir: str = "output", + overwrite: bool = True, + input_json: Optional[str] = None, + use_dask: bool = False, + dask_type: DaskType = DaskType.LOCAL, + dask_n_workers: Optional[int] = None, + failure_mode: FailureMode = FailureMode.SKIP, + loglevel: Union[int, str] = logging.INFO, +): + """ + Run cross docking on a set of ligands, against a set of targets. + """ + + if input_json is not None: + print("Loading inputs from json file... Will override all other inputs.") + inputs = CrossDockingWorkflowInputs.from_json_file(input_json) + + else: + inputs = CrossDockingWorkflowInputs( + target=target, + multi_reference=multi_reference, + structure_selector=structure_selector, + use_dask=use_dask, + dask_type=dask_type, + dask_n_workers=dask_n_workers, + failure_mode=failure_mode, + use_omega=use_omega, + omega_dense=omega_dense, + posit_method=POSIT_METHOD[posit_method], + relax_mode=POSIT_RELAX_MODE[relax_mode], + num_poses=num_poses, + allow_retries=allow_retries, + ligands=ligands, + pdb_file=pdb_file, + fragalysis_dir=fragalysis_dir, + structure_dir=structure_dir, + cache_dir=cache_dir, + use_only_cache=use_only_cache, + save_to_cache=save_to_cache, + output_dir=output_dir, + overwrite=overwrite, + allow_final_clash=allow_final_clash, + loglevel=loglevel, + ) + + cross_docking_workflow(inputs) \ No newline at end of file diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/cross_docking.py b/asapdiscovery-docking/asapdiscovery/docking/workflows/cross_docking.py similarity index 91% rename from asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/cross_docking.py rename to asapdiscovery-docking/asapdiscovery/docking/workflows/cross_docking.py index 01340fdc..bad6b451 100644 --- a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/cross_docking.py +++ b/asapdiscovery-docking/asapdiscovery/docking/workflows/cross_docking.py @@ -22,9 +22,7 @@ from asapdiscovery.docking.openeye import POSIT_METHOD, POSIT_RELAX_MODE, POSITDocker from asapdiscovery.docking.scorer import ChemGauss4Scorer, MetaScorer from asapdiscovery.modeling.protein_prep import ProteinPrepper -from asapdiscovery.workflows.docking_workflows.workflows import ( - DockingWorkflowInputsBase, -) +from asapdiscovery.docking.workflows.docking_workflows import DockingWorkflowInputsBase from pydantic.v1 import Field, PositiveInt @@ -203,11 +201,10 @@ def cross_docking_workflow(inputs: CrossDockingWorkflowInputs): ) results = docker.dock( sets, - output_dir=output_dir / "docking_results", use_dask=inputs.use_dask, dask_client=dask_client, failure_mode=inputs.failure_mode, - return_for_disk_backend=True, + return_for_disk_backend=False, ) n_results = len(results) logger.info(f"Docked {n_results} pairs successfully") @@ -227,7 +224,7 @@ def cross_docking_workflow(inputs: CrossDockingWorkflowInputs): write_results_to_multi_sdf( output_dir / "docking_results.sdf", results, - backend=BackendType.DISK, + backend=BackendType.IN_MEMORY, reconstruct_cls=docker.result_cls, ) @@ -238,23 +235,11 @@ def cross_docking_workflow(inputs: CrossDockingWorkflowInputs): dask_client=dask_client, failure_mode=inputs.failure_mode, return_df=True, - backend=BackendType.DISK, + backend=BackendType.IN_MEMORY, reconstruct_cls=docker.result_cls, ) del results - scores_df.to_csv(data_intermediates / "docking_scores_raw.csv", index=False) - - # rename columns for manifold - logger.info("Renaming columns for manifold") - result_df = rename_output_columns_for_manifold( - scores_df, - inputs.target, - [DockingResultCols], - manifold_validate=True, - drop_non_output=True, - allow=[DockingResultCols.LIGAND_ID.value], - ) - - result_df.to_csv(output_dir / "docking_results_final.csv", index=False) + scores_df.to_csv(output_dir / "docking_scores_raw.csv", index=False) + logger.info("Finished successfully!") diff --git a/asapdiscovery-docking/asapdiscovery/docking/workflows/docking_workflows.py b/asapdiscovery-docking/asapdiscovery/docking/workflows/docking_workflows.py new file mode 100644 index 00000000..7b5e5bd0 --- /dev/null +++ b/asapdiscovery-docking/asapdiscovery/docking/workflows/docking_workflows.py @@ -0,0 +1,133 @@ +import logging + +from pathlib import Path + +from typing import Optional, Union + +from pydantic.v1 import BaseModel, Field, PositiveInt, root_validator + +from asapdiscovery.data.metadata.resources import active_site_chains +from asapdiscovery.data.services.postera.manifold_data_validation import TargetTags +from asapdiscovery.data.util.dask_utils import DaskType, FailureMode + + +class DockingWorkflowInputsBase(BaseModel): + ligands: Optional[str] = Field( + None, description="Path to a molecule file containing query ligands." + ) + + pdb_file: Optional[Path] = Field( + None, description="Path to a PDB file to prep and dock to." + ) + + fragalysis_dir: Optional[Path] = Field( + None, description="Path to a directory containing a Fragalysis dump." + ) + structure_dir: Optional[Path] = Field( + None, + description="Path to a directory containing structures to dock instead of a full fragalysis database.", + ) + + cache_dir: Optional[str] = Field( + None, description="Path to a directory where a cache has been generated" + ) + + use_only_cache: bool = Field( + False, + description="Whether to only use the cached structures, otherwise try to prep uncached structures.", + ) + + save_to_cache: bool = Field( + True, + description="Generate a cache from structures prepped in this workflow run in this directory", + ) + + target: TargetTags = Field(None, description="The target to dock against.") + + write_final_sdf: bool = Field( + default=True, + description="Whether to write the final docked poses to an SDF file.", + ) + use_dask: bool = Field(True, description="Whether to use dask for parallelism.") + + dask_type: DaskType = Field( + DaskType.LOCAL, description="Dask client to use for parallelism." + ) + + dask_n_workers: Optional[PositiveInt] = Field(None, description="Number of workers") + + failure_mode: FailureMode = Field( + FailureMode.SKIP, description="Dask failure mode." + ) + + n_select: PositiveInt = Field( + 5, description="Number of targets to dock each ligand against." + ) + logname: str = Field( + "", description="Name of the log file." + ) # use root logger for proper forwarding of logs from dask + + loglevel: Union[int, str] = Field(logging.INFO, description="Logging level") + + output_dir: Path = Field(Path("output"), description="Output directory") + + overwrite: bool = Field( + False, description="Whether to overwrite existing output directory." + ) + ref_chain: Optional[str] = Field( + None, + description="Chain ID to align to in reference structure containing the active site", + ) + active_site_chain: Optional[str] = Field( + None, + description="Active site chain ID to align to ref_chain in reference structure", + ) + + class Config: + arbitrary_types_allowed = True + + @classmethod + def from_json_file(cls, file: str | Path): + return cls.parse_file(str(file)) + + def to_json_file(self, file: str | Path): + with open(file, "w") as f: + f.write(self.json(indent=2)) + + @root_validator + @classmethod + def check_inputs(cls, values): + """ + Validate inputs + """ + ligands = values.get("ligands") + fragalysis_dir = values.get("fragalysis_dir") + structure_dir = values.get("structure_dir") + postera = values.get("postera") + pdb_file = values.get("pdb_file") + + if postera and ligands: + raise ValueError("Cannot specify both ligands and postera.") + + if not postera and not ligands: + raise ValueError("Must specify either ligands or postera.") + + # can only specify one of fragalysis dir, structure dir and PDB file + if sum([bool(fragalysis_dir), bool(structure_dir), bool(pdb_file)]) != 1: + raise ValueError( + "Must specify exactly one of fragalysis_dir, structure_dir or pdb_file" + ) + + return values + + @root_validator(pre=True) + def check_and_set_chains(cls, values): + active_site_chain = values.get("active_site_chain") + ref_chain = values.get("ref_chain") + target = values.get("target") + if not active_site_chain: + values["active_site_chain"] = active_site_chains[target] + # set same chain for active site if not specified + if not ref_chain: + values["ref_chain"] = active_site_chains[target] + return values diff --git a/asapdiscovery-docking/pyproject.toml b/asapdiscovery-docking/pyproject.toml index a12f9713..09c54f1e 100644 --- a/asapdiscovery-docking/pyproject.toml +++ b/asapdiscovery-docking/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [] [project.scripts] make-docked-complexes = "asapdiscovery.docking.scripts.make_docked_complexes_schema_v2:main" +asap-docking = "asapdiscovery.docking.workflows.cli:cli" [tool.setuptools.packages.find] where = ["."] diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/cli.py b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/cli.py index 3955df2c..f156c150 100644 --- a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/cli.py +++ b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/cli.py @@ -22,15 +22,10 @@ target, use_only_cache, ) -from asapdiscovery.docking.selectors.selector_list import StructureSelector + from asapdiscovery.data.services.postera.manifold_data_validation import TargetTags from asapdiscovery.data.util.dask_utils import DaskType, FailureMode -from asapdiscovery.docking.openeye import POSIT_METHOD, POSIT_RELAX_MODE from asapdiscovery.simulation.simulate import OpenMMPlatform -from asapdiscovery.workflows.docking_workflows.cross_docking import ( - CrossDockingWorkflowInputs, - cross_docking_workflow, -) from asapdiscovery.workflows.docking_workflows.large_scale_docking import ( LargeScaleDockingInputs, large_scale_docking_workflow, @@ -170,135 +165,6 @@ def large_scale( large_scale_docking_workflow(inputs) -@docking.command() -@target -@click.option( - "--use-omega", - is_flag=True, - default=False, - help="Whether to use OEOmega conformer enumeration before docking (slower, more accurate)", -) -@click.option( - "--omega-dense", - is_flag=True, - default=False, - help="Whether to use dense conformer enumeration with OEOmega (slower, more accurate)", -) -@click.option( - "--posit-method", - type=click.Choice(POSIT_METHOD.get_names(), case_sensitive=False), - default="all", - help="The set of methods POSIT can use. Defaults to all.", -) -@click.option( - "--relax-mode", - type=click.Choice(POSIT_RELAX_MODE.get_names(), case_sensitive=False), - default="none", - help="When to check for relaxation either, 'clash', 'all', 'none'", -) -@click.option( - "--allow-retries", - is_flag=True, - default=False, - help="Whether to allow POSIT to retry with relaxed parameters if docking fails (slower, more likely to succeed)", -) -@click.option( - "--allow-final-clash", - is_flag=True, - default=False, - help="Allow clashing poses in last stage of docking", -) -@click.option( - "--multi-reference", - is_flag=True, - default=False, - help="Whether to pass multiple references to the docker for each ligand instead of just one at a time", -) -@click.option( - "--structure-selector", - type=click.Choice(StructureSelector.get_values(), case_sensitive=False), - default=StructureSelector.LEAVE_SIMILAR_OUT.value, - help="The type of structure selector to use.", -) -@click.option("--num-poses", type=int, default=1, help="Number of poses to generate") -@ligands -@pdb_file -@fragalysis_dir -@structure_dir -@save_to_cache -@cache_dir -@use_only_cache -@dask_args -@output_dir -@overwrite -@input_json -@loglevel -def cross_docking( - target: TargetTags, - multi_reference: bool = False, - structure_selector: StructureSelector = StructureSelector.LEAVE_SIMILAR_OUT, - use_omega: bool = False, - omega_dense: bool = False, - posit_method: Optional[str] = POSIT_METHOD.ALL.name, - relax_mode: Optional[str] = POSIT_RELAX_MODE.NONE.name, - num_poses: int = 1, - allow_retries: bool = False, - allow_final_clash: bool = False, - ligands: Optional[str] = None, - pdb_file: Optional[str] = None, - fragalysis_dir: Optional[str] = None, - structure_dir: Optional[str] = None, - use_only_cache: bool = False, - save_to_cache: Optional[bool] = True, - cache_dir: Optional[str] = None, - output_dir: str = "output", - overwrite: bool = True, - input_json: Optional[str] = None, - use_dask: bool = False, - dask_type: DaskType = DaskType.LOCAL, - dask_n_workers: Optional[int] = None, - failure_mode: FailureMode = FailureMode.SKIP, - loglevel: Union[int, str] = logging.INFO, -): - """ - Run cross docking on a set of ligands, against a set of targets. - """ - - if input_json is not None: - print("Loading inputs from json file... Will override all other inputs.") - inputs = CrossDockingWorkflowInputs.from_json_file(input_json) - - else: - inputs = CrossDockingWorkflowInputs( - target=target, - multi_reference=multi_reference, - structure_selector=structure_selector, - use_dask=use_dask, - dask_type=dask_type, - dask_n_workers=dask_n_workers, - failure_mode=failure_mode, - use_omega=use_omega, - omega_dense=omega_dense, - posit_method=POSIT_METHOD[posit_method], - relax_mode=POSIT_RELAX_MODE[relax_mode], - num_poses=num_poses, - allow_retries=allow_retries, - ligands=ligands, - pdb_file=pdb_file, - fragalysis_dir=fragalysis_dir, - structure_dir=structure_dir, - cache_dir=cache_dir, - use_only_cache=use_only_cache, - save_to_cache=save_to_cache, - output_dir=output_dir, - overwrite=overwrite, - allow_final_clash=allow_final_clash, - loglevel=loglevel, - ) - - cross_docking_workflow(inputs) - - @docking.command() @target @click.option( diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/large_scale_docking.py b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/large_scale_docking.py index 1d91053f..ada286e7 100644 --- a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/large_scale_docking.py +++ b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/large_scale_docking.py @@ -33,10 +33,9 @@ from asapdiscovery.docking.openeye import POSITDocker from asapdiscovery.docking.scorer import ( ChemGauss4Scorer, - FINTScorer, MetaScorer, - MLModelScorer, ) +from asapdiscovery.workflows.scorers import MLModelScorer, FINTScorer from asapdiscovery.ml.models import ASAPMLModelRegistry from asapdiscovery.modeling.protein_prep import ProteinPrepper from asapdiscovery.spectrum.fitness import target_has_fitness_data diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/ligand_transfer_docking.py b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/ligand_transfer_docking.py index 73d83c0b..af3031fb 100644 --- a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/ligand_transfer_docking.py +++ b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/ligand_transfer_docking.py @@ -28,13 +28,12 @@ from asapdiscovery.docking.docking import write_results_to_multi_sdf from asapdiscovery.docking.docking_data_validation import DockingResultCols from asapdiscovery.docking.openeye import POSIT_METHOD, POSIT_RELAX_MODE, POSITDocker -from asapdiscovery.docking.scorer import ChemGauss4Scorer, MetaScorer, MLModelScorer +from asapdiscovery.docking.scorer import ChemGauss4Scorer, MetaScorer +from asapdiscovery.workflows.scorers import MLModelScorer from asapdiscovery.ml.models import ASAPMLModelRegistry from asapdiscovery.modeling.protein_prep import LigandTransferProteinPrepper from asapdiscovery.simulation.simulate import OpenMMPlatform, VanillaMDSimulator -from asapdiscovery.workflows.docking_workflows.workflows import ( - DockingWorkflowInputsBase, -) +from asapdiscovery.docking.workflows.docking_workflows import DockingWorkflowInputsBase from pydantic.v1 import Field, PositiveInt, root_validator diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/small_scale_docking.py b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/small_scale_docking.py index d9d080b6..461fc80e 100644 --- a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/small_scale_docking.py +++ b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/small_scale_docking.py @@ -39,10 +39,9 @@ from asapdiscovery.docking.openeye import POSITDocker from asapdiscovery.docking.scorer import ( ChemGauss4Scorer, - FINTScorer, MetaScorer, - MLModelScorer, ) +from asapdiscovery.workflows.scorers import MLModelScorer, FINTScorer from asapdiscovery.ml.models import ASAPMLModelRegistry from asapdiscovery.modeling.protein_prep import ProteinPrepper from asapdiscovery.simulation.simulate import OpenMMPlatform, VanillaMDSimulator diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/symexp_crystal_packing.py b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/symexp_crystal_packing.py index 43e0c92a..cd3d01b7 100644 --- a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/symexp_crystal_packing.py +++ b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/symexp_crystal_packing.py @@ -29,7 +29,8 @@ from asapdiscovery.docking.docking import write_results_to_multi_sdf from asapdiscovery.docking.docking_data_validation import DockingResultCols from asapdiscovery.docking.openeye import POSITDocker -from asapdiscovery.docking.scorer import ChemGauss4Scorer, SymClashScorer +from asapdiscovery.docking.scorer import ChemGauss4Scorer +from asapdiscovery.workflows.scorers import SymClashScorer from asapdiscovery.modeling.protein_prep import ProteinPrepper from asapdiscovery.workflows.docking_workflows.workflows import ( PosteraDockingWorkflowInputs, diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/workflows.py b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/workflows.py index b6ee00e0..43de3002 100644 --- a/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/workflows.py +++ b/asapdiscovery-workflows/asapdiscovery/workflows/docking_workflows/workflows.py @@ -2,136 +2,11 @@ Schema for workflows base classes """ -import logging -from pathlib import Path -from typing import Optional, Union +from typing import Optional -from asapdiscovery.data.metadata.resources import active_site_chains -from asapdiscovery.data.services.postera.manifold_data_validation import TargetTags -from asapdiscovery.data.util.dask_utils import DaskType, FailureMode -from pydantic.v1 import BaseModel, Field, PositiveInt, root_validator +from pydantic.v1 import Field - -class DockingWorkflowInputsBase(BaseModel): - ligands: Optional[str] = Field( - None, description="Path to a molecule file containing query ligands." - ) - - pdb_file: Optional[Path] = Field( - None, description="Path to a PDB file to prep and dock to." - ) - - fragalysis_dir: Optional[Path] = Field( - None, description="Path to a directory containing a Fragalysis dump." - ) - structure_dir: Optional[Path] = Field( - None, - description="Path to a directory containing structures to dock instead of a full fragalysis database.", - ) - - cache_dir: Optional[str] = Field( - None, description="Path to a directory where a cache has been generated" - ) - - use_only_cache: bool = Field( - False, - description="Whether to only use the cached structures, otherwise try to prep uncached structures.", - ) - - save_to_cache: bool = Field( - True, - description="Generate a cache from structures prepped in this workflow run in this directory", - ) - - target: TargetTags = Field(None, description="The target to dock against.") - - write_final_sdf: bool = Field( - default=True, - description="Whether to write the final docked poses to an SDF file.", - ) - use_dask: bool = Field(True, description="Whether to use dask for parallelism.") - - dask_type: DaskType = Field( - DaskType.LOCAL, description="Dask client to use for parallelism." - ) - - dask_n_workers: Optional[PositiveInt] = Field(None, description="Number of workers") - - failure_mode: FailureMode = Field( - FailureMode.SKIP, description="Dask failure mode." - ) - - n_select: PositiveInt = Field( - 5, description="Number of targets to dock each ligand against." - ) - logname: str = Field( - "", description="Name of the log file." - ) # use root logger for proper forwarding of logs from dask - - loglevel: Union[int, str] = Field(logging.INFO, description="Logging level") - - output_dir: Path = Field(Path("output"), description="Output directory") - - overwrite: bool = Field( - False, description="Whether to overwrite existing output directory." - ) - ref_chain: Optional[str] = Field( - None, - description="Chain ID to align to in reference structure containing the active site", - ) - active_site_chain: Optional[str] = Field( - None, - description="Active site chain ID to align to ref_chain in reference structure", - ) - - class Config: - arbitrary_types_allowed = True - - @classmethod - def from_json_file(cls, file: str | Path): - return cls.parse_file(str(file)) - - def to_json_file(self, file: str | Path): - with open(file, "w") as f: - f.write(self.json(indent=2)) - - @root_validator - @classmethod - def check_inputs(cls, values): - """ - Validate inputs - """ - ligands = values.get("ligands") - fragalysis_dir = values.get("fragalysis_dir") - structure_dir = values.get("structure_dir") - postera = values.get("postera") - pdb_file = values.get("pdb_file") - - if postera and ligands: - raise ValueError("Cannot specify both ligands and postera.") - - if not postera and not ligands: - raise ValueError("Must specify either ligands or postera.") - - # can only specify one of fragalysis dir, structure dir and PDB file - if sum([bool(fragalysis_dir), bool(structure_dir), bool(pdb_file)]) != 1: - raise ValueError( - "Must specify exactly one of fragalysis_dir, structure_dir or pdb_file" - ) - - return values - - @root_validator(pre=True) - def check_and_set_chains(cls, values): - active_site_chain = values.get("active_site_chain") - ref_chain = values.get("ref_chain") - target = values.get("target") - if not active_site_chain: - values["active_site_chain"] = active_site_chains[target] - # set same chain for active site if not specified - if not ref_chain: - values["ref_chain"] = active_site_chains[target] - return values +from asapdiscovery.docking.workflows.docking_workflows import DockingWorkflowInputsBase class PosteraDockingWorkflowInputs(DockingWorkflowInputsBase): diff --git a/asapdiscovery-workflows/asapdiscovery/workflows/scorers.py b/asapdiscovery-workflows/asapdiscovery/workflows/scorers.py new file mode 100644 index 00000000..8e850925 --- /dev/null +++ b/asapdiscovery-workflows/asapdiscovery/workflows/scorers.py @@ -0,0 +1,556 @@ +from io import StringIO + +import MDAnalysis as mda +import warnings + +from multimethod import multimethod +from pydantic.v1 import Field, validator +from typing import ClassVar, Optional, Any, Union + +from pathlib import Path + +from mtenn.config import ModelType + +from asapdiscovery.data.backend.openeye import oemol_to_pdb_string +from asapdiscovery.data.schema.complex import Complex +from asapdiscovery.data.schema.ligand import Ligand +from asapdiscovery.data.services.postera.manifold_data_validation import TargetTags +from asapdiscovery.data.util.dask_utils import dask_vmap, backend_wrapper +from asapdiscovery.dataviz.plip import compute_fint_score +from asapdiscovery.docking.docking import DockingResult +from asapdiscovery.docking.scorer import ScoreType, ScorerBase, ScoreUnits, logger, Score, \ + _get_disk_path_from_docking_result +from asapdiscovery.ml.inference import InferenceBase, get_inference_cls_from_model_type +from asapdiscovery.ml.models import MLModelSpecBase +from asapdiscovery.spectrum.fitness import target_has_fitness_data + + +def endpoint_and_model_type_to_score_type(endpoint: str, model_type: str) -> ScoreType: + """ + Convert an endpoint to a score type. + + Parameters + ---------- + endpoint : str + Endpoint to convert + + Returns + ------- + ScoreType + Score type + """ + if model_type == ModelType.GAT: + if endpoint == "pIC50": # TODO: make this an enum + return ScoreType.GAT_pIC50 + elif endpoint == "LogD": + return ScoreType.GAT_LogD + else: + raise ValueError(f"Endpoint {endpoint} not recognized, for GAT") + elif model_type == ModelType.schnet: + if endpoint == "pIC50": + return ScoreType.schnet_pIC50 + else: + raise ValueError(f"Endpoint {endpoint} not recognized, for Schnet") + elif model_type == ModelType.e3nn: + if endpoint == "pIC50": + return ScoreType.e3nn_pIC50 + else: + raise ValueError(f"Endpoint {endpoint} not recognized for E3NN") + else: + raise ValueError(f"Model type {model_type} not recognized") + + +# keep track of all the ml scorers +_ml_scorer_classes_meta = [] + + +# decorator to register all the ml scorers +def register_ml_scorer(cls): + _ml_scorer_classes_meta.append(cls) + return cls + + +class MLModelScorer(ScorerBase): + """ + Baseclass to score from some kind of ML model, including 2D or 3D models + """ + + model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID + score_type: ScoreType = Field(..., description="Type of score") + endpoint: Optional[str] = Field(None, description="Endpoint biological property") + units: ClassVar[ScoreUnits.INVALID] = ScoreUnits.INVALID + + targets: Any = Field( + ..., + description="Which targets can this model do predictions for", # FIXME: Optional[set[TargetTags]] + ) + model_name: str = Field(..., description="String indicating which model to use") + inference_cls: InferenceBase = Field(..., description="Inference class") + + @classmethod + def from_latest_by_target(cls, target: TargetTags): + if cls.model_type == ModelType.INVALID: + raise Exception("trying to instantiate some kind a baseclass") + inference_cls = get_inference_cls_from_model_type(cls.model_type) + inference_instance = inference_cls.from_latest_by_target(target) + if inference_instance is None: + logger.warn( + f"no ML model of type {cls.model_type} found for target: {target}, skipping" + ) + return None + else: + try: + instance = cls( + targets=inference_instance.targets, + model_name=inference_instance.model_name, + inference_cls=inference_instance, + endpoint=inference_instance.model_spec.endpoint, + score_type=endpoint_and_model_type_to_score_type( + inference_instance.model_spec.endpoint, cls.model_type + ), + ) + return instance + except Exception as e: + logger.error(f"error instantiating MLModelScorer: {e}") + return None + + @staticmethod + def from_latest_by_target_and_type(target: TargetTags, type: ModelType): + """ + Get the latest ML Scorer by target and type. + + Parameters + ---------- + target : TargetTags + Target to get the scorer for + type : ModelType + Type of model to get the scorer for + """ + if type == ModelType.INVALID: + raise Exception("trying to instantiate some kind a baseclass") + scorer_class = get_ml_scorer_cls_from_model_type(type) + return scorer_class.from_latest_by_target(target) + + @classmethod + def from_model_name(cls, model_name: str): + if cls.model_type == ModelType.INVALID: + raise Exception("trying to instantiate some kind a baseclass") + inference_cls = get_inference_cls_from_model_type(cls.model_type) + inference_instance = inference_cls.from_model_name(model_name) + if inference_instance is None: + logger.warn( + f"no ML model of type {cls.model_type} found for model_name: {model_name}, skipping" + ) + return None + else: + try: + instance = cls( + targets=inference_instance.targets, + model_name=inference_instance.model_name, + inference_cls=inference_instance, + endpoint=inference_instance.model_spec.endpoint, + score_type=endpoint_and_model_type_to_score_type( + inference_instance.model_spec.endpoint, cls.model_type + ), + ) + return instance + except Exception as e: + logger.error(f"error instantiating MLModelScorer: {e}") + return None + + @staticmethod + def load_model_specs( + models: list[MLModelSpecBase], + ) -> list["MLModelScorer"]: # noqa: F821 + """ + Load a list of models into scorers. + + Parameters + ---------- + models : list[MLModelSpecBase] + List of models to load + """ + scorers = [] + for model in models: + scorer_class = get_ml_scorer_cls_from_model_type(model.type) + scorer = scorer_class.from_model_name(model.name) + if scorer is not None: + scorers.append(scorer) + return scorers + + +@register_ml_scorer +class GATScorer(MLModelScorer): + """ + Scoring using GAT ML Model + """ + + model_type: ClassVar[ModelType.GAT] = ModelType.GAT + units: ClassVar[ScoreUnits.pIC50] = ScoreUnits.pIC50 + + @dask_vmap(["inputs"]) + @backend_wrapper("inputs") + def _score( + self, + inputs: Union[list[DockingResult], list[str], list[Ligand]], + return_for_disk_backend: bool = False, + **kwargs, + ) -> list[Score]: + """ + Score the inputs, dispatching based on type. + """ + return self._dispatch( + inputs, return_for_disk_backend=return_for_disk_backend, **kwargs + ) + + @multimethod + def _dispatch( + self, + inputs: list[DockingResult], + return_for_disk_backend: bool = False, + **kwargs, + ) -> list[Score]: + """ + Dispatch for DockingResults + """ + results = [] + for inp in inputs: + gat_score = self.inference_cls.predict_from_smiles(inp.posed_ligand.smiles) + sc = Score.from_score_and_docking_result( + gat_score, + self.score_type, + self.units, + inp, + ) + # overwrite the input with the path to the file + if return_for_disk_backend: + sc.input = _get_disk_path_from_docking_result(inp) + results.append(sc) + return results + + @_dispatch.register + def _dispatch(self, inputs: list[str], **kwargs) -> list[Score]: + """ + Dispatch for SMILES strings + """ + results = [] + for inp in inputs: + gat_score = self.inference_cls.predict_from_smiles(inp) + results.append( + Score.from_score_and_smiles( + gat_score, + inp, + self.score_type, + self.units, + ) + ) + return results + + @_dispatch.register + def _dispatch(self, inputs: list[Ligand], **kwargs) -> list[Score]: + """ + Dispatch for Ligands + """ + results = [] + for inp in inputs: + gat_score = self.inference_cls.predict_from_smiles(inp.smiles) + results.append( + Score.from_score_and_ligand( + gat_score, + inp, + self.score_type, + self.units, + ) + ) + return results + + +class E3MLModelScorer(MLModelScorer): + """ + Scoring using ML Models that operate over 3D structures + These all share an interface so we can use multimethods to dispatch + for the different input types for all subclasses. + """ + + model_type: ClassVar[ModelType.INVALID] = ModelType.INVALID + units: ClassVar[ScoreUnits.INVALID] = ScoreUnits.INVALID + + @dask_vmap(["inputs"]) + @backend_wrapper("inputs") + def _score( + self, + inputs: Union[list[DockingResult], list[Complex], list[Path]], + return_for_disk_backend: bool = False, + **kwargs, + ) -> list[Score]: + return self._dispatch( + inputs, return_for_disk_backend=return_for_disk_backend, **kwargs + ) + + @multimethod + def _dispatch( + self, + inputs: list[DockingResult], + return_for_disk_backend: bool = False, + **kwargs, + ) -> list[Score]: + results = [] + for inp in inputs: + score = self.inference_cls.predict_from_oemol(inp.to_posed_oemol()) + + sc = Score.from_score_and_docking_result( + score, self.score_type, self.units, inp + ) + # overwrite the input with the path to the file + if return_for_disk_backend: + sc.input = _get_disk_path_from_docking_result(inp) + results.append(sc) + + return results + + @_dispatch.register + def _dispatch(self, inputs: list[Complex], **kwargs) -> list[Score]: + results = [] + for inp in inputs: + score = self.inference_cls.predict_from_oemol(inp.to_combined_oemol()) + results.append( + Score.from_score_and_complex(score, self.score_type, self.units, inp) + ) + return results + + @_dispatch.register + def _dispatch(self, inputs: list[Path], **kwargs) -> list[Score]: + # assuming reading PDB files from disk + complexes = [ + Complex.from_pdb( + p, + ligand_kwargs={"compound_name": f"{p.stem}_ligand"}, + target_kwargs={"target_name": f"{p.stem}_target"}, + ) + for i, p in enumerate(inputs) + ] + return self._dispatch(complexes, **kwargs) + + +@register_ml_scorer +class SchnetScorer(E3MLModelScorer): + """ + Scoring using Schnet ML Model + """ + + model_type: ClassVar[ModelType.schnet] = ModelType.schnet + units: ClassVar[ScoreUnits.pIC50] = ScoreUnits.pIC50 + + +@register_ml_scorer +class E3NNScorer(E3MLModelScorer): + """ + Scoring using e3nn ML Model + """ + + model_type: ClassVar[ModelType.e3nn] = ModelType.e3nn + units: ClassVar[ScoreUnits.pIC50] = ScoreUnits.pIC50 + + +def get_ml_scorer_cls_from_model_type(model_type: ModelType): + instantiable_classes = [ + m for m in _ml_scorer_classes_meta if m.model_type != ModelType.INVALID + ] + scorer_class = [m for m in instantiable_classes if m.model_type == model_type] + if len(scorer_class) != 1: + raise Exception("Somehow got multiple scorers") + return scorer_class[0] + + +class SymClashScorer(ScorerBase): + """ + Scoring, checking for clashes between ligand and target + in neighboring unit cells. + """ + + score_type: ScoreType = Field(ScoreType.sym_clash, description="Type of score") + units: ClassVar[ScoreUnits.arbitrary] = ScoreUnits.arbitrary + + count_clashing_pairs: bool = Field( + False, + description="Whether to count clashing distance pairs, rather than unique clashing ligand atoms", + ) + + vdw_radii_fudge_factor: float = Field( + 1.0, + description="fudge factor multiplier for vdw radii, lower to decrease clash sensitivity, higher to increase", + ) + + @dask_vmap(["inputs"]) + @backend_wrapper("inputs") + def _score(self, inputs, **kwargs) -> list[Score]: + """ + Score the inputs, dispatching based on type. + """ + return self._dispatch(inputs, **kwargs) + + @multimethod + def _dispatch(self, inputs: list[Complex], **kwargs) -> list[Score]: + """ + Dispatch for Complex + """ + results = [] + warnings.warn( + "SymClashScorer relies on expanded protein units having chain X as constructed by SymmetryExpander" + ) + for inp in inputs: + # load into MDA universe + u = mda.Universe( + mda.lib.util.NamedStream( + StringIO(oemol_to_pdb_string(inp.to_combined_oemol())), + "complex.pdb", + ) + ) + lig = u.select_atoms("not protein") + symmetry_expanded_prot = u.select_atoms("protein and chainID X") + # hacky but expand to real space with mega box + # multiply first 3 dimensions by 20 + expanded_box = u.dimensions + expanded_box[:3] *= 20 + pair_indices, pair_distances = mda.lib.distances.capped_distance( + lig, + symmetry_expanded_prot, + 4, + box=expanded_box, # large cutoff to loop in a good amount of distances up to 8Å + ) + # check if distance for an atom pair is less than summed vdw radii + num_clashes = 0 + clashing_lig_at = set() + clashing_prot_at = set() + + for k, [i, j] in enumerate(pair_indices): + lig_atom = lig[i] + prot_atom = symmetry_expanded_prot[j] + distance = pair_distances[k] + if ( + ( + distance + < ( + ( + mda.topology.tables.vdwradii[lig_atom.element.upper()] + * self.vdw_radii_fudge_factor + ) + + ( + mda.topology.tables.vdwradii[prot_atom.element.upper()] + * self.vdw_radii_fudge_factor + ) + ) + ) + and lig_atom.element != "H" + and prot_atom.element != "H" + ): + num_clashes += 1 + clashing_lig_at.add(i) + clashing_prot_at.add(j) + + if self.count_clashing_pairs: + val = num_clashes + else: + val = len(clashing_lig_at) # seems ok as metric for now + + results.append( + Score.from_score_and_complex(val, self.score_type, self.units, inp) + ) + return results + + +class FINTScorer(ScorerBase): + """ + Score using Fitness Interaction Score + + Overloaded to accept DockingResults, Complexes, or Paths to PDB files. + """ + + score_type: ScoreType = Field(ScoreType.FINT, description="Type of score") + units: ClassVar[ScoreUnits.arbitrary] = ScoreUnits.arbitrary + target: TargetTags = Field(..., description="Which target to use for scoring") + + @validator("target") + @classmethod + def validate_target(cls, v): + if not target_has_fitness_data(v): + raise ValueError( + "target does not have fitness data so cannot use FINTScorer" + ) + return v + + @dask_vmap(["inputs"]) + @backend_wrapper("inputs") + def _score( + self, + inputs: Union[list[DockingResult], list[Complex], list[Path]], + return_for_disk_backend: bool = False, + **kwargs, + ) -> list[Score]: + """ + Score the inputs, dispatching based on type. + """ + return self._dispatch( + inputs, return_for_disk_backend=return_for_disk_backend, **kwargs + ) + + @multimethod + def _dispatch( + self, + inputs: list[DockingResult], + return_for_disk_backend: bool = False, + **kwargs, + ) -> list[Score]: + """ + Dispatch for DockingResults + """ + results = [] + for inp in inputs: + _, fint_score = compute_fint_score( + inp.to_protein(), inp.posed_ligand.to_oemol(), self.target + ) + + sc = Score.from_score_and_docking_result( + fint_score, self.score_type, self.units, inp + ) + # overwrite the input with the path to the file + if return_for_disk_backend: + sc.input = _get_disk_path_from_docking_result(inp) + + results.append(sc) + + return results + + @_dispatch.register + def _dispatch(self, inputs: list[Complex], **kwargs): + """ + Dispatch for Complexes + """ + results = [] + for inp in inputs: + _, fint_score = compute_fint_score( + inp.target.to_oemol(), inp.ligand.to_oemol(), self.target + ) + results.append( + Score.from_score_and_complex( + fint_score, self.score_type, self.units, inp + ) + ) + return results + + @_dispatch.register + def _dispatch(self, inputs: list[Path], **kwargs): + """ + Dispatch for PDB files from disk + """ + # assuming reading PDB files from disk + complexes = [ + Complex.from_pdb( + p, + ligand_kwargs={"compound_name": f"{p.stem}_ligand"}, + target_kwargs={"target_name": f"{p.stem}_target"}, + ) + for p in inputs + ] + + return self._dispatch(complexes, **kwargs)