Commit
yapf lint google format on tdc/benchmark_group, chem_utils, generation, multi_pred
amva13 committed Mar 5, 2024
1 parent 6f7f4f9 commit 6713f6f
Showing 32 changed files with 527 additions and 466 deletions.
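
For readers who want to reproduce the reformatting locally, here is a minimal sketch (an assumption about tooling: it uses yapf's Python API, where FormatCode returns the reformatted source plus a changed flag) of the Google-style argument wrapping visible throughout this diff:

from yapf.yapflib.yapf_api import FormatCode

# Toy snippet whose call overflows the 80-column limit, mirroring the
# create_scaffold_split call reformatted in base_group.py below.
SRC = '''\
class BenchmarkGroup:

    def get_train_valid_split(self, train_val, seed, frac, split_method):
        if split_method == "scaffold":
            out = create_scaffold_split(train_val, seed, frac=frac, entity="Drug")
        return out
'''

formatted, changed = FormatCode(SRC, style_config="google")
print(formatted)
# yapf rewraps the overlong call with one argument per line, aligned with the
# opening parenthesis -- the same shape this commit produces.
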
46 changes: 26 additions & 20 deletions tdc/benchmark_group/base_group.py
@@ -29,7 +29,6 @@


class BenchmarkGroup:

"""Boilerplate of benchmark group class. It downloads, processes, and loads a set of benchmark classes along with their splits. It also provides evaluators and train/valid splitters."""

def __init__(self, name, path="./data", file_format="csv"):
@@ -118,15 +117,19 @@ def get_train_valid_split(self, seed, benchmark, split_type="default"):
frac = [0.875, 0.125, 0.0]

if split_method == "scaffold":
out = create_scaffold_split(train_val, seed, frac=frac, entity="Drug")
out = create_scaffold_split(train_val,
seed,
frac=frac,
entity="Drug")
elif split_method == "random":
out = create_fold(train_val, seed, frac=frac)
elif split_method == "combination":
out = create_combination_split(train_val, seed, frac=frac)
elif split_method == "group":
out = create_group_split(
train_val, seed, holdout_frac=0.2, group_column="Year"
)
out = create_group_split(train_val,
seed,
holdout_frac=0.2,
group_column="Year")
else:
raise NotImplementedError
return out["train"], out["valid"]
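
For context, a minimal usage sketch of the split API touched by this hunk (the ADMET group and dataset name are illustrative; any TDC benchmark group exposes the same methods):

from tdc.benchmark_group import admet_group

group = admet_group(path="./data")
benchmark = group.get("Caco2_Wang")  # dict with 'name', 'train_val', 'test'
train, valid = group.get_train_valid_split(seed=1,
                                            benchmark=benchmark["name"],
                                            split_type="default")
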
@@ -178,8 +181,11 @@ def evaluate(self, pred, testing=True, benchmark=None, save_dict=True):
elif self.file_format == "pkl":
test = pd.read_pickle(os.path.join(data_path, "test.pkl"))
y = test.Y.values
evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')")
out[data_name] = {metric_dict[data_name]: round(evaluator(y, pred_), 3)}
evaluator = eval("Evaluator(name = '" + metric_dict[data_name] +
"')")
out[data_name] = {
metric_dict[data_name]: round(evaluator(y, pred_), 3)
}

# If reporting accuracy across target classes
if "target_class" in test.columns:
@@ -190,13 +196,11 @@ def evaluate(self, pred, testing=True, benchmark=None, save_dict=True):
y_subset = test_subset.Y.values
pred_subset = test_subset.pred.values

evaluator = eval(
"Evaluator(name = '" + metric_dict[data_name_subset] + "')"
)
evaluator = eval("Evaluator(name = '" +
metric_dict[data_name_subset] + "')")
out[data_name_subset] = {
metric_dict[data_name_subset]: round(
evaluator(y_subset, pred_subset), 3
)
metric_dict[data_name_subset]:
round(evaluator(y_subset, pred_subset), 3)
}
return out
else:
@@ -207,10 +211,14 @@ def evaluate(self, pred, testing=True, benchmark=None, save_dict=True):
)
data_name = fuzzy_search(benchmark, self.dataset_names)
metric_dict = bm_metric_names[self.name]
evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')")
evaluator = eval("Evaluator(name = '" + metric_dict[data_name] +
"')")
return {metric_dict[data_name]: round(evaluator(true, pred), 3)}

def evaluate_many(self, preds, save_file_name=None, results_individual=None):
def evaluate_many(self,
preds,
save_file_name=None,
results_individual=None):
"""
This function returns the data in a format needed to submit to the Leaderboard
@@ -225,11 +233,9 @@ def evaluate_many(self, preds, save_file_name=None, results_individual=None):
min_requirement = 5

if len(preds) < min_requirement:
return ValueError(
"Must have predictions from at least "
+ str(min_requirement)
+ " runs for leaderboard submission"
)
return ValueError("Must have predictions from at least " +
str(min_requirement) +
" runs for leaderboard submission")
if results_individual is None:
individual_results = []
for pred in preds:
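
The min_requirement check above gates leaderboard submissions at five runs. A sketch of the intended loop, following the pattern in the TDC documentation (my_model is a placeholder for user training code):

from tdc.benchmark_group import admet_group

group = admet_group(path="./data")
predictions_list = []
for seed in [1, 2, 3, 4, 5]:  # at least min_requirement runs
    predictions = {}
    for name in group.dataset_names:
        benchmark = group.get(name)
        train, valid = group.get_train_valid_split(seed=seed, benchmark=name)
        predictions[name] = my_model(train, valid, benchmark["test"])  # placeholder
    predictions_list.append(predictions)

results = group.evaluate_many(predictions_list)
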
55 changes: 32 additions & 23 deletions tdc/benchmark_group/docking_group.py
@@ -34,9 +34,11 @@ class docking_group(BenchmarkGroup):
"""

def __init__(
self, path="./data", num_workers=None, num_cpus=None, num_max_call=5000
):
def __init__(self,
path="./data",
num_workers=None,
num_cpus=None,
num_max_call=5000):
"""Create a docking group benchmark loader.
Raises:
@@ -157,7 +159,12 @@ def get(self, benchmark, num_max_call=5000):
data = pd.read_csv(os.path.join(self.path, "zinc.tab"), sep="\t")
return {"oracle": oracle, "data": data, "name": dataset}

def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True):
def evaluate(self,
pred,
true=None,
benchmark=None,
m1_api=None,
save_dict=True):
"""Summary
Args:
@@ -227,7 +234,9 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True)

docking_scores = oracle(pred_)
print_sys("---- Calculating average docking scores ----")
if len(np.where(np.array(list(docking_scores.values())) > 0)[0]) > 0.7:
if len(
np.where(np.array(list(docking_scores.values())) > 0)
[0]) > 0.7:
## check if the scores are all positive.. if so, make them all negative
docking_scores = {j: -k for j, k in docking_scores.items()}
if save_dict:
@@ -275,7 +284,8 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True)
if save_dict:
results["pass_list"] = pred_filter
results["%pass"] = float(len(pred_filter)) / 100
results["top1_%pass"] = min([docking_scores[i] for i in pred_filter])
results["top1_%pass"] = min(
[docking_scores[i] for i in pred_filter])
print_sys("---- Calculating diversity ----")
from ..evaluator import Evaluator

@@ -284,19 +294,23 @@
results["diversity"] = score
print_sys("---- Calculating novelty ----")
evaluator = Evaluator(name="Novelty")
training = pd.read_csv(os.path.join(self.path, "zinc.tab"), sep="\t")
training = pd.read_csv(os.path.join(self.path, "zinc.tab"),
sep="\t")
score = evaluator(pred_, training.smiles.values)
results["novelty"] = score
results["top smiles"] = [
i[0] for i in sorted(docking_scores.items(), key=lambda x: x[1])
i[0]
for i in sorted(docking_scores.items(), key=lambda x: x[1])
]
results_max_call[num_max_call] = results
results_all[data_name] = results_max_call
return results_all

def evaluate_many(
self, preds, save_file_name=None, m1_api=None, results_individual=None
):
def evaluate_many(self,
preds,
save_file_name=None,
m1_api=None,
results_individual=None):
"""evaluate many runs together and output submission ready pkl file.
Args:
@@ -310,11 +324,9 @@
"""
min_requirement = 3
if len(preds) < min_requirement:
return ValueError(
"Must have predictions from at least "
+ str(min_requirement)
+ " runs for leaderboard submission"
)
return ValueError("Must have predictions from at least " +
str(min_requirement) +
" runs for leaderboard submission")
if results_individual is None:
individual_results = []
for pred in preds:
@@ -345,13 +357,10 @@ def evaluate_many(
for metric in metrics:
if metric == "top smiles":
results_agg_target_call[metric] = np.unique(
np.array(
[
individual_results[fold][target][num_calls][metric]
for fold in range(num_folds)
]
).reshape(-1)
).tolist()
np.array([
individual_results[fold][target][num_calls]
[metric] for fold in range(num_folds)
]).reshape(-1)).tolist()
else:
res = [
individual_results[fold][target][num_calls][metric]
8 changes: 4 additions & 4 deletions tdc/benchmark_group/drugcombo_group.py
@@ -15,11 +15,11 @@ class drugcombo_group(BenchmarkGroup):
def __init__(self, path="./data"):
"""create a drug combination benchmark group"""
super().__init__(name="DrugCombo_Group", path=path, file_format="pkl")



def get_cell_line_meta_data(self):
import os
from ..utils.load import download_wrapper
from ..utils import load_dict
name = download_wrapper('drug_comb_meta_data', self.path, ['drug_comb_meta_data'])
return load_dict(os.path.join(self.path, name + '.pkl'))
name = download_wrapper('drug_comb_meta_data', self.path,
['drug_comb_meta_data'])
return load_dict(os.path.join(self.path, name + '.pkl'))
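
A small usage sketch for the helper reformatted here (assuming drugcombo_group is exported from tdc.benchmark_group):

from tdc.benchmark_group import drugcombo_group

group = drugcombo_group(path="./data")
meta = group.get_cell_line_meta_data()  # dict loaded from drug_comb_meta_data.pkl
print(len(meta))
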
58 changes: 36 additions & 22 deletions tdc/chem_utils/evaluator.py
@@ -12,7 +12,8 @@

rdBase.DisableLog("rdApp.error")
except:
raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ")
raise ImportError(
"Please install rdkit by 'conda install -c conda-forge rdkit'! ")


def single_molecule_validity(smiles):
@@ -57,7 +58,8 @@ def canonicalize(smiles):

def unique_lst_of_smiles(list_of_smiles):
canonical_smiles_lst = list(map(canonicalize, list_of_smiles))
canonical_smiles_lst = list(filter(lambda x: x is not None, canonical_smiles_lst))
canonical_smiles_lst = list(
filter(lambda x: x is not None, canonical_smiles_lst))
canonical_smiles_lst = list(set(canonical_smiles_lst))
return canonical_smiles_lst

@@ -88,11 +90,9 @@ def novelty(generated_smiles_lst, training_smiles_lst):
"""
generated_smiles_lst = unique_lst_of_smiles(generated_smiles_lst)
training_smiles_lst = unique_lst_of_smiles(training_smiles_lst)
novel_ratio = (
sum([1 if i in training_smiles_lst else 0 for i in generated_smiles_lst])
* 1.0
/ len(generated_smiles_lst)
)
novel_ratio = (sum(
[1 if i in training_smiles_lst else 0 for i in generated_smiles_lst]) *
1.0 / len(generated_smiles_lst))
return 1 - novel_ratio
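
A toy check of the metric (assumes rdkit is installed and the module path tdc.chem_utils.evaluator): two of the four generated molecules also occur in the training list, so novelty = 1 - 2/4 = 0.5.

from tdc.chem_utils.evaluator import novelty

generated = ["CCO", "c1ccccc1", "CC(=O)O", "CCN"]  # 4 unique molecules
training = ["CCO", "CC(=O)O", "CCCC"]              # 2 of them seen in training
print(novelty(generated, training))                # 0.5
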


@@ -107,14 +107,19 @@ def diversity(list_of_smiles):
div: float
"""
list_of_unique_smiles = unique_lst_of_smiles(list_of_smiles)
list_of_mol = [Chem.MolFromSmiles(smiles) for smiles in list_of_unique_smiles]
list_of_mol = [
Chem.MolFromSmiles(smiles) for smiles in list_of_unique_smiles
]
list_of_fp = [
AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048, useChirality=False)
AllChem.GetMorganFingerprintAsBitVect(mol,
2,
nBits=2048,
useChirality=False)
for mol in list_of_mol
]
avg_lst = []
for idx, fp in enumerate(list_of_fp):
for fp2 in list_of_fp[idx + 1 :]:
for fp2 in list_of_fp[idx + 1:]:
sim = DataStructs.TanimotoSimilarity(fp, fp2)
### option I
distance = 1 - sim
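
The loop above accumulates 1 - Tanimoto similarity over Morgan fingerprints of every molecule pair. A minimal standalone check of the same pairwise computation with RDKit (test molecules are illustrative):

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

smiles = ["CCO", "CCCO", "c1ccccc1"]
fps = [
    AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(s), 2, nBits=2048)
    for s in smiles
]
for i in range(len(fps)):
    for j in range(i + 1, len(fps)):
        sim = DataStructs.TanimotoSimilarity(fps[i], fps[j])
        print(smiles[i], smiles[j], round(1 - sim, 3))  # pairwise distance
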
@@ -235,7 +240,9 @@ def get_fingerprints(mols, radius=2, length=4096):
Returns: a list of fingerprints
"""
return [AllChem.GetMorganFingerprintAsBitVect(m, radius, length) for m in mols]
return [
AllChem.GetMorganFingerprintAsBitVect(m, radius, length) for m in mols
]


def get_mols(smiles_list):
@@ -267,10 +274,8 @@ def calculate_internal_pairwise_similarities(smiles_list):
Symmetric matrix of pairwise similarities. Diagonal is set to zero.
"""
if len(smiles_list) > 10000:
logger.warning(
f"Calculating internal similarity on large set of "
f"SMILES strings ({len(smiles_list)})"
)
logger.warning(f"Calculating internal similarity on large set of "
f"SMILES strings ({len(smiles_list)})")

mols = get_mols(smiles_list)
fps = get_fingerprints(mols)
@@ -313,7 +318,8 @@ def kl_divergence(generated_smiles_lst, training_smiles_lst):
def canonical(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is not None:
return Chem.MolToSmiles(mol, isomericSmiles=True) ### todo double check
return Chem.MolToSmiles(mol,
isomericSmiles=True) ### todo double check
else:
return None

@@ -323,17 +329,20 @@ def canonical(smiles):
generated_lst_mol = list(filter(filter_out_func, generated_lst_mol))
training_lst_mol = list(filter(filter_out_func, training_lst_mol))

d_sampled = calculate_pc_descriptors(generated_lst_mol, pc_descriptor_subset)
d_sampled = calculate_pc_descriptors(generated_lst_mol,
pc_descriptor_subset)
d_chembl = calculate_pc_descriptors(training_lst_mol, pc_descriptor_subset)

kldivs = {}
for i in range(4):
kldiv = continuous_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i])
kldiv = continuous_kldiv(X_baseline=d_chembl[:, i],
X_sampled=d_sampled[:, i])
kldivs[pc_descriptor_subset[i]] = kldiv

# ... and for the int valued ones.
for i in range(4, 9):
kldiv = discrete_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i])
kldiv = discrete_kldiv(X_baseline=d_chembl[:, i],
X_sampled=d_sampled[:, i])
kldivs[pc_descriptor_subset[i]] = kldiv

# pairwise similarity
@@ -344,7 +353,8 @@ def canonical(smiles):
sampled_sim = calculate_internal_pairwise_similarities(generated_lst_mol)
sampled_sim = sampled_sim.max(axis=1)

kldiv_int_int = continuous_kldiv(X_baseline=chembl_sim, X_sampled=sampled_sim)
kldiv_int_int = continuous_kldiv(X_baseline=chembl_sim,
X_sampled=sampled_sim)
kldivs["internal_similarity"] = kldiv_int_int
"""
# for some reason, this runs into problems when both sets are identical.
@@ -395,10 +405,14 @@ def _calculate_distribution_statistics(chemnet, molecules):
cov = np.cov(gen_mol_act.T)
return mu, cov

mu_ref, cov_ref = _calculate_distribution_statistics(chemnet, training_smiles_lst)
mu_ref, cov_ref = _calculate_distribution_statistics(
chemnet, training_smiles_lst)
mu, cov = _calculate_distribution_statistics(chemnet, generated_smiles_lst)

FCD = fcd.calculate_frechet_distance(mu1=mu_ref, mu2=mu, sigma1=cov_ref, sigma2=cov)
FCD = fcd.calculate_frechet_distance(mu1=mu_ref,
mu2=mu,
sigma1=cov_ref,
sigma2=cov)
fcd_distance = np.exp(-0.2 * FCD)
return fcd_distance
