-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Exposed paccmann based BiModal MCA affinity predictor (#36)
* exposed paccmann predictor - included a regression unit test * fixed typo in doc * feat: minor fixes in PaccMannn affinity predictor. Co-authored-by: Yoel Shoshan <[email protected]> Co-authored-by: Matteo Manica <[email protected]>
- Loading branch information
1 parent
4438b5f
commit df48a56
Showing
6 changed files
with
402 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
"""Prediction algorithms based on PaccMann""" | ||
|
||
import logging | ||
from dataclasses import field | ||
from typing import Any, ClassVar, List, Optional, TypeVar | ||
|
||
from ...core import AlgorithmConfiguration, GeneratorAlgorithm, Untargeted | ||
from ...registry import ApplicationsRegistry | ||
from .implementation import BimodalMCAAffinityPredictor, MCAPredictor | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.addHandler(logging.NullHandler()) | ||
|
||
T = TypeVar("T", bound=Any) | ||
S = TypeVar("S", bound=Any) | ||
|
||
|
||
class PaccMann(GeneratorAlgorithm[S, T]): | ||
"""PaccMann predictor.""" | ||
|
||
def __init__( | ||
self, | ||
configuration: AlgorithmConfiguration[S, T], | ||
target: Optional[T] = None, | ||
): | ||
"""Instantiate PaccMann for prediction. | ||
Args: | ||
configuration: domain and application | ||
specification defining parameters, types and validations. | ||
target: a target for which to generate items. | ||
Example: | ||
An example for predicting affinity for a given ligand and target protein pair:: | ||
config = AffinityPredictor() | ||
algorithm = TopicsZeroShot(configuration=config, target="This is a text I want to understand better") | ||
items = list(algorithm.sample(1)) | ||
print(items) | ||
""" | ||
|
||
configuration = self.validate_configuration(configuration) | ||
# TODO there might also be a validation/check on the target input | ||
|
||
super().__init__( | ||
configuration=configuration, # type:ignore | ||
target=target, # type:ignore | ||
) | ||
|
||
def get_generator( | ||
self, | ||
configuration: AlgorithmConfiguration[S, T], | ||
target: Optional[T], | ||
) -> Untargeted: | ||
"""Get the function to perform the prediction via PaccMann's generator. | ||
Args: | ||
configuration: helps to set up specific application of PaccMann. | ||
target: context or condition for the generation. | ||
Returns: | ||
callable with target predicting properties using PaccMann. | ||
""" | ||
logger.info("ensure artifacts for the application are present.") | ||
self.local_artifacts = configuration.ensure_artifacts() | ||
implementation: MCAPredictor = configuration.get_conditional_generator( # type: ignore | ||
self.local_artifacts | ||
) | ||
return implementation.predict_values | ||
|
||
|
||
@ApplicationsRegistry.register_algorithm_application(PaccMann) | ||
class AffinityPredictor(AlgorithmConfiguration[str, str]): | ||
"""Configuration to predict affinity for a given ligand/protrin target pair.""" | ||
|
||
algorithm_type: ClassVar[str] = "prediction" | ||
domain: ClassVar[str] = "materials" | ||
algorithm_version: str = "v0" | ||
|
||
protein_targets: List[str] = field( | ||
default_factory=list, | ||
metadata=dict(description="List of protein targets as AA sequences."), | ||
) | ||
ligands: List[str] = field( | ||
default_factory=list, | ||
metadata=dict(description="List of ligands in SMILES format."), | ||
) | ||
confidence: bool = field( | ||
default=False, | ||
metadata=dict( | ||
description="Whether the confidence for the prediction should be returned." | ||
), | ||
) | ||
|
||
def get_conditional_generator( | ||
self, resources_path: str | ||
) -> BimodalMCAAffinityPredictor: | ||
"""Instantiate the actual predictor implementation. | ||
Args: | ||
resources_path: local path to model files. | ||
Returns: | ||
instance with :meth:`gt4sd.algorithms.prediction.affinity._predicto.implementation.BimodalMCAAffinityPredictor.predict` method for predicting affinity. | ||
""" | ||
return BimodalMCAAffinityPredictor( | ||
resources_path=resources_path, | ||
protein_targets=self.protein_targets, | ||
ligands=self.ligands, | ||
confidence=self.confidence, | ||
) |
161 changes: 161 additions & 0 deletions
161
src/gt4sd/algorithms/prediction/paccmann/implementation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
"""Implementation of the zero-shot classifier.""" | ||
|
||
import json | ||
import logging | ||
import os | ||
from typing import Any, List, Optional, Union | ||
|
||
import torch | ||
from paccmann_predictor.models import MODEL_FACTORY | ||
from pytoda.proteins.protein_language import ProteinLanguage | ||
from pytoda.smiles.smiles_language import SMILESLanguage | ||
from pytoda.transforms import LeftPadding, ToTensor | ||
|
||
from ....frameworks.torch import device_claim | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.addHandler(logging.NullHandler()) | ||
|
||
|
||
class MCAPredictor: | ||
"""Base implementation of an MCAPredictor.""" | ||
|
||
def predict(self) -> Any: | ||
"""Get prediction. | ||
Returns: | ||
predicted affinity | ||
""" | ||
raise NotImplementedError("No prediction implemented for base MCAPredictor") | ||
|
||
def predict_values(self) -> Any: | ||
"""Get prediction for algorithm sample method. | ||
Returns: | ||
predicted values as list. | ||
""" | ||
raise NotImplementedError( | ||
"No values prediction implemented for base MCAPredictor" | ||
) | ||
|
||
|
||
class BimodalMCAAffinityPredictor(MCAPredictor): | ||
"""Bimodal MCA (Multiscale Convolutional Attention) affinity prediction model. | ||
For details see: https://pubs.acs.org/doi/10.1021/acs.molpharmaceut.9b00520 | ||
and https://iopscience.iop.org/article/10.1088/2632-2153/abe808. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
resources_path: str, | ||
protein_targets: List[str], | ||
ligands: List[str], | ||
confidence: bool, | ||
device: Optional[Union[torch.device, str]] = None, | ||
): | ||
"""Initialize BimodalMCAAffinityPredictor. | ||
Args: | ||
resources_path: path where to load model weights and cofiguration. | ||
protein_targets: list of protein targets as AA sequences. | ||
ligands: list of ligands in SMILES format. | ||
confidence: whether the confidence for the prediction should be returned. | ||
device: device where the inference | ||
is running either as a dedicated class or a string. If not provided is inferred. | ||
""" | ||
self.device = device_claim(device) | ||
self.resources_path = resources_path | ||
self.protein_targets = protein_targets | ||
self.ligands = ligands | ||
self.confidence = confidence | ||
|
||
# setting affinity predictor parameters | ||
with open(os.path.join(resources_path, "mca_model_params.json")) as f: | ||
self.predictor_params = json.load(f) | ||
self.affinity_predictor = MODEL_FACTORY["bimodal_mca"](self.predictor_params) | ||
self.affinity_predictor.load( | ||
os.path.join(resources_path, "mca_weights.pt"), | ||
map_location=self.device, | ||
) | ||
affinity_protein_language = ProteinLanguage.load( | ||
os.path.join(resources_path, "protein_language.pkl") | ||
) | ||
affinity_smiles_language = SMILESLanguage.load( | ||
os.path.join(resources_path, "smiles_language.pkl") | ||
) | ||
self.affinity_predictor._associate_language(affinity_smiles_language) | ||
self.affinity_predictor._associate_language(affinity_protein_language) | ||
self.affinity_predictor.eval() | ||
|
||
self.pad_smiles_predictor = LeftPadding( | ||
self.affinity_predictor.smiles_padding_length, | ||
self.affinity_predictor.smiles_language.padding_index, | ||
) | ||
|
||
self.pad_protein_predictor = LeftPadding( | ||
self.affinity_predictor.protein_padding_length, | ||
self.affinity_predictor.protein_language.padding_index, | ||
) | ||
|
||
self.to_tensor = ToTensor() | ||
|
||
def predict(self) -> Any: | ||
"""Get predicted affinity. | ||
Returns: | ||
predicted affinity. | ||
""" | ||
# prepare ligand representation | ||
ligand_tensor = torch.cat( | ||
[ | ||
torch.unsqueeze( | ||
self.to_tensor( | ||
self.pad_smiles_predictor( | ||
self.affinity_predictor.smiles_language.smiles_to_token_indexes( | ||
ligand_smiles | ||
) | ||
) | ||
), | ||
0, | ||
) | ||
for ligand_smiles in self.ligands | ||
], | ||
dim=0, | ||
) | ||
|
||
# prepare target protein representation | ||
target_tensor = torch.cat( | ||
[ | ||
torch.unsqueeze( | ||
self.to_tensor( | ||
self.pad_protein_predictor( | ||
self.affinity_predictor.protein_language.sequence_to_token_indexes( | ||
protein_target | ||
) | ||
) | ||
), | ||
0, | ||
) | ||
for protein_target in self.protein_targets | ||
], | ||
dim=0, | ||
) | ||
|
||
with torch.no_grad(): | ||
predictions, predictions_dict = self.affinity_predictor( | ||
ligand_tensor, | ||
target_tensor, | ||
confidence=self.confidence, | ||
) | ||
|
||
return predictions, predictions_dict | ||
|
||
def predict_values(self) -> List[float]: | ||
"""Get prediction for algorithm sample method. | ||
Returns: | ||
predicted values as list. | ||
""" | ||
predictions, _ = self.predict() | ||
return list(predictions[:, 0]) |
Oops, something went wrong.