Skip to content

Commit

Permalink
Exposed paccmann based BiModal MCA affinity predictor (#36)
Browse files Browse the repository at this point in the history
* exposed paccmann predictor - included a regression unit test

* fixed typo in doc

* feat: minor fixes in PaccMann affinity predictor.

Co-authored-by: Yoel Shoshan <[email protected]>
Co-authored-by: Matteo Manica <[email protected]>
  • Loading branch information
3 people authored Feb 22, 2022
1 parent 4438b5f commit df48a56
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/gt4sd/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
HuggingFaceXLNetGenerator,
)
from .generation.polymer_blocks.core import PolymerBlocksGenerator # noqa: F401
from .prediction.paccmann.core import PaccMann # noqa: F401
from .prediction.topics_zero_shot.core import TopicsPredictor # noqa: F401

# extras requirements
Expand Down
11 changes: 0 additions & 11 deletions src/gt4sd/algorithms/prediction/core.py

This file was deleted.

File renamed without changes.
111 changes: 111 additions & 0 deletions src/gt4sd/algorithms/prediction/paccmann/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Prediction algorithms based on PaccMann"""

import logging
from dataclasses import field
from typing import Any, ClassVar, List, Optional, TypeVar

from ...core import AlgorithmConfiguration, GeneratorAlgorithm, Untargeted
from ...registry import ApplicationsRegistry
from .implementation import BimodalMCAAffinityPredictor, MCAPredictor

# Module-level logger. The NullHandler ensures that importing this module does
# not emit log output unless the application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Generic type parameters for the algorithm and its configuration.
# T is the type of the target (see ``target: Optional[T]`` in PaccMann);
# S is presumably the type of the produced items -- confirm against the
# base classes in ``...core``.
T = TypeVar("T", bound=Any)
S = TypeVar("S", bound=Any)


class PaccMann(GeneratorAlgorithm[S, T]):
    """PaccMann predictor.

    Exposes a PaccMann-based prediction model through the generic algorithm
    interface: the callable returned by :meth:`get_generator` is the bound
    ``predict_values`` method of the implementation built by the configuration.
    """

    def __init__(
        self,
        configuration: AlgorithmConfiguration[S, T],
        target: Optional[T] = None,
    ):
        """Instantiate PaccMann for prediction.

        Args:
            configuration: domain and application
                specification defining parameters, types and validations.
            target: optional target forwarded to the base class. It is not
                used by :meth:`get_generator`; for :class:`AffinityPredictor`
                the prediction inputs are taken from the configuration.

        Example:
            An example for predicting affinity for a given ligand and target protein pair::

                config = AffinityPredictor(
                    protein_targets=["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSF"],
                    ligands=["CC(=O)OC1=CC=CC=C1C(=O)O"],
                )
                algorithm = PaccMann(configuration=config)
                items = list(algorithm.sample(1))
                print(items)
        """

        configuration = self.validate_configuration(configuration)
        # TODO there might also be a validation/check on the target input

        super().__init__(
            configuration=configuration,  # type:ignore
            target=target,  # type:ignore
        )

    def get_generator(
        self,
        configuration: AlgorithmConfiguration[S, T],
        target: Optional[T],
    ) -> Untargeted:
        """Get the function to perform the prediction via PaccMann's generator.

        Args:
            configuration: helps to set up specific application of PaccMann.
            target: context or condition for the generation. Accepted for
                interface compatibility; it is not used here.

        Returns:
            callable without arguments (the implementation's
            ``predict_values`` method) predicting properties using PaccMann.
        """
        logger.info("ensure artifacts for the application are present.")
        # Keep the artifacts location on the instance so it stays accessible
        # after the implementation has been built.
        self.local_artifacts = configuration.ensure_artifacts()
        implementation: MCAPredictor = configuration.get_conditional_generator(  # type: ignore
            self.local_artifacts
        )
        return implementation.predict_values


@ApplicationsRegistry.register_algorithm_application(PaccMann)
class AffinityPredictor(AlgorithmConfiguration[str, str]):
    """Configuration to predict affinity for a given ligand/protein target pair."""

    # Registry metadata identifying this application.
    algorithm_type: ClassVar[str] = "prediction"
    domain: ClassVar[str] = "materials"
    algorithm_version: str = "v0"

    # Inputs of the prediction. Both default to an empty list and are handed
    # over unchanged to the predictor in ``get_conditional_generator``.
    protein_targets: List[str] = field(
        default_factory=list,
        metadata=dict(description="List of protein targets as AA sequences."),
    )
    ligands: List[str] = field(
        default_factory=list,
        metadata=dict(description="List of ligands in SMILES format."),
    )
    confidence: bool = field(
        default=False,
        metadata=dict(
            description="Whether the confidence for the prediction should be returned."
        ),
    )

    def get_conditional_generator(
        self, resources_path: str
    ) -> BimodalMCAAffinityPredictor:
        """Instantiate the actual predictor implementation.

        Args:
            resources_path: local path to model files.

        Returns:
            instance with :meth:`gt4sd.algorithms.prediction.paccmann.implementation.BimodalMCAAffinityPredictor.predict` method for predicting affinity.
        """
        return BimodalMCAAffinityPredictor(
            resources_path=resources_path,
            protein_targets=self.protein_targets,
            ligands=self.ligands,
            confidence=self.confidence,
        )
161 changes: 161 additions & 0 deletions src/gt4sd/algorithms/prediction/paccmann/implementation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Implementation of the zero-shot classifier."""

import json
import logging
import os
from typing import Any, List, Optional, Union

import torch
from paccmann_predictor.models import MODEL_FACTORY
from pytoda.proteins.protein_language import ProteinLanguage
from pytoda.smiles.smiles_language import SMILESLanguage
from pytoda.transforms import LeftPadding, ToTensor

from ....frameworks.torch import device_claim

# Module-level logger. The NullHandler ensures that importing this module does
# not emit log output unless the application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class MCAPredictor:
    """Common interface for MCA-based predictors.

    Both methods are placeholders on this base class and always raise;
    concrete predictors are expected to override them.
    """

    def predict(self) -> Any:
        """Run the prediction.

        Returns:
            the prediction, in whatever form the concrete predictor defines.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        message = "No prediction implemented for base MCAPredictor"
        raise NotImplementedError(message)

    def predict_values(self) -> Any:
        """Run the prediction in the form consumed by the algorithm's sample method.

        Returns:
            predicted values as list.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        message = "No values prediction implemented for base MCAPredictor"
        raise NotImplementedError(message)


class BimodalMCAAffinityPredictor(MCAPredictor):
    """Bimodal MCA (Multiscale Convolutional Attention) affinity prediction model.

    For details see: https://pubs.acs.org/doi/10.1021/acs.molpharmaceut.9b00520
    and https://iopscience.iop.org/article/10.1088/2632-2153/abe808.
    """

    def __init__(
        self,
        resources_path: str,
        protein_targets: List[str],
        ligands: List[str],
        confidence: bool,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """Initialize BimodalMCAAffinityPredictor.

        Args:
            resources_path: path where to load model weights and configuration.
                The directory must contain ``mca_model_params.json``,
                ``mca_weights.pt``, ``protein_language.pkl`` and
                ``smiles_language.pkl``.
            protein_targets: list of protein targets as AA sequences.
            ligands: list of ligands in SMILES format.
            confidence: whether the confidence for the prediction should be returned.
            device: device where the inference
                is running either as a dedicated class or a string. If not provided is inferred.
        """
        self.device = device_claim(device)
        self.resources_path = resources_path
        self.protein_targets = protein_targets
        self.ligands = ligands
        self.confidence = confidence

        # setting affinity predictor parameters
        # JSON is read explicitly as UTF-8 so that parsing does not depend on
        # the locale's default encoding.
        with open(
            os.path.join(resources_path, "mca_model_params.json"), encoding="utf-8"
        ) as f:
            self.predictor_params = json.load(f)
        self.affinity_predictor = MODEL_FACTORY["bimodal_mca"](self.predictor_params)
        self.affinity_predictor.load(
            os.path.join(resources_path, "mca_weights.pt"),
            map_location=self.device,
        )
        affinity_protein_language = ProteinLanguage.load(
            os.path.join(resources_path, "protein_language.pkl")
        )
        affinity_smiles_language = SMILESLanguage.load(
            os.path.join(resources_path, "smiles_language.pkl")
        )
        self.affinity_predictor._associate_language(affinity_smiles_language)
        self.affinity_predictor._associate_language(affinity_protein_language)
        # inference only
        self.affinity_predictor.eval()

        # Padding transforms are built from the lengths and padding indexes
        # exposed by the model and its associated languages.
        self.pad_smiles_predictor = LeftPadding(
            self.affinity_predictor.smiles_padding_length,
            self.affinity_predictor.smiles_language.padding_index,
        )

        self.pad_protein_predictor = LeftPadding(
            self.affinity_predictor.protein_padding_length,
            self.affinity_predictor.protein_language.padding_index,
        )

        self.to_tensor = ToTensor()

    def _encode_batch(self, sequences: List[str], tokenize: Any, pad: Any) -> Any:
        """Tokenize, pad and convert each sequence, then stack them into one batch.

        Args:
            sequences: raw input strings (SMILES or AA sequences).
            tokenize: callable mapping a string to its token indexes.
            pad: padding transform applied to the token indexes.

        Returns:
            tensor with one leading batch entry per input sequence.
        """
        # Stacking along a new first dimension is equivalent to unsqueezing
        # every tensor at dim 0 and concatenating along dim 0.
        return torch.stack(
            [self.to_tensor(pad(tokenize(sequence))) for sequence in sequences],
            dim=0,
        )

    def predict(self) -> Any:
        """Get predicted affinity.

        Returns:
            tuple ``(predictions, predictions_dict)`` exactly as returned by
            the underlying affinity model.
        """
        # NOTE(review): ligands and protein targets are presumably paired
        # element-wise by the model, which would require both lists to have
        # the same length -- confirm against the model implementation.

        # prepare ligand representation
        ligand_tensor = self._encode_batch(
            self.ligands,
            self.affinity_predictor.smiles_language.smiles_to_token_indexes,
            self.pad_smiles_predictor,
        )

        # prepare target protein representation
        target_tensor = self._encode_batch(
            self.protein_targets,
            self.affinity_predictor.protein_language.sequence_to_token_indexes,
            self.pad_protein_predictor,
        )

        # no gradients are needed for inference
        with torch.no_grad():
            predictions, predictions_dict = self.affinity_predictor(
                ligand_tensor,
                target_tensor,
                confidence=self.confidence,
            )

        return predictions, predictions_dict

    def predict_values(self) -> List[float]:
        """Get prediction for algorithm sample method.

        Returns:
            predicted values as list, one entry per row of the first column
            of the model predictions.
        """
        predictions, _ = self.predict()
        # NOTE(review): if ``predictions`` is a torch tensor the elements are
        # 0-dim tensors rather than Python floats; kept as-is on purpose since
        # converting would change hashing/equality of the items for callers.
        return list(predictions[:, 0])
Loading

0 comments on commit df48a56

Please sign in to comment.