
Commit

Generalize finetune
michellemli committed Dec 30, 2023
1 parent 9f757d6 commit 1651e4b
Showing 7 changed files with 641 additions and 886 deletions.
603 changes: 47 additions & 556 deletions finetune_pinnacle/data_prep.py

Large diffs are not rendered by default.

393 changes: 393 additions & 0 deletions finetune_pinnacle/extract_txdata_utils.py

Large diffs are not rendered by default.

88 changes: 20 additions & 68 deletions finetune_pinnacle/metrics_utils.py
@@ -17,52 +17,13 @@ def save_results(output_results_path: str, ap_scores: Dict[str, Dict[str, float]]
"""
Save results in the form of dictionary to a json file.
"""
res = {'ap':ap_scores, 'auroc':auroc_scores}
res = {'ap': ap_scores, 'auroc': auroc_scores}
with open(output_results_path, 'w') as f:
json.dump(res, f)
print(f"\nResults output to -> {output_results_path}")


def save_plots(output_figs_path: str, positive_proportion_train: Dict[str, Dict[str, float]], positive_proportion_test: Dict[str, Dict[str, float]], ap_scores: Dict[str, Dict[str, float]], auroc_scores: Dict[str, Dict[str, float]], disease, wandb):
"""
Render and save/log plots.
"""
for eval_results, eval_type in zip([ap_scores, auroc_scores], ['ap_scores', 'auroc_scores']):
i = 0
fig = plt.figure(figsize=(10 * len(eval_results), 6 * len(eval_results)))
for disease, ct_res in eval_results.items():
all_scores = []
xlabels = []
i += 1
for celltype, score in ct_res.items(): # No repetition of experiments, so it's score but not scores for each cell type
all_scores = all_scores + [score]
xlabels = xlabels + [celltype]
ax = plt.subplot(len(eval_results), 1, i)
plot_data = pd.DataFrame({'y': xlabels, 'x': all_scores})
sns.barplot(x='x', y='y', data=plot_data, capsize=.2)
if eval_type == 'ap_scores':
try:
plt.scatter(y = xlabels,
x = [list(positive_proportion_train[disease].values())[0]] * (len(positive_proportion_test[disease])-1) + [list(positive_proportion_train[disease].values())[1]],
c = 'grey', zorder = 100, label = 'train', s = 10) # np.arange(len(positive_proportion_test[disease]))
except:
plt.scatter(y = xlabels,
x = [list(positive_proportion_train[disease].values())[0]] * (len(positive_proportion_test[disease])),
c = 'grey', zorder = 100, label='train', s = 10) # np.arange(len(positive_proportion_test[disease]))
plt.scatter(y = xlabels,
x = positive_proportion_test[disease].values(),
c = 'black', zorder = 100, label = 'test', s = 10)
ax.set_ylabel("")
ax.set_xlabel(disease + '-' + eval_type)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
plt.tight_layout()
plt.savefig(output_figs_path + eval_type + '.png')
wandb.log({f'test {eval_type} bar':fig})
plt.close(fig)
return


def save_torch_train_val_preds(best_train_y, best_train_preds, best_train_groups, best_train_cts, best_val_y, best_val_preds, best_val_groups, best_val_cts, groups_map_train, groups_map_val, cts_map_train, cts_map_val, models_output_dir, embed_name, disease, mod, wandb):
def save_torch_train_val_preds(best_train_y, best_train_preds, best_train_groups, best_train_cts, best_val_y, best_val_preds, best_val_groups, best_val_cts, groups_map_train, groups_map_val, cts_map_train, cts_map_val, models_output_dir, embed_name, wandb):
train_ranks = {}
val_ranks = {}
for ct in np.unique(best_val_cts):
@@ -73,51 +34,46 @@ def save_torch_train_val_preds(best_train_y, best_train_preds, best_train_groups
ct_name = cts_map_val[ct]

if len(np.unique(val_y_ct)) < 2:
auroc_score, ap_score, ct_recall_5, ct_precision_5, ct_ap_5, ct_recall_10, ct_precision_10, ct_ap_10, ct_recall_20, ct_precision_20, ct_ap_20, sorted_val_y_ct, sorted_val_preds_ct, sorted_val_groups_ct, positive_proportion_val = -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, np.array([-1] * len(val_y_ct)), np.array([-1] * len(val_y_ct)), np.array([-1] * len(val_y_ct)), -1
auroc_score, ap_score, ct_recall_5, ct_precision_5, ct_ap_5, ct_recall_10, ct_precision_10, ct_ap_10, sorted_val_y_ct, sorted_val_preds_ct, sorted_val_groups_ct, positive_proportion_val = -1, -1, -1, -1, -1, -1, -1, -1, np.array([-1] * len(val_y_ct)), np.array([-1] * len(val_y_ct)), np.array([-1] * len(val_y_ct)), -1
else:
auroc_score, ap_score, ct_recall_5, ct_precision_5, ct_ap_5, ct_recall_10, ct_precision_10, ct_ap_10, ct_recall_20, ct_precision_20, ct_ap_20, sorted_val_y_ct, sorted_val_preds_ct, sorted_val_groups_ct, positive_proportion_val = get_metrics(val_y_ct, val_preds_ct, val_groups_ct, "training")
auroc_score, ap_score, ct_recall_5, ct_precision_5, ct_ap_5, ct_recall_10, ct_precision_10, ct_ap_10, sorted_val_y_ct, sorted_val_preds_ct, sorted_val_groups_ct, positive_proportion_val = get_metrics(val_y_ct, val_preds_ct, val_groups_ct, "training")
if len(sorted_val_y_ct) > 0: sorted_val_y_ct = sorted_val_y_ct.squeeze(-1)
if len(sorted_val_preds_ct) > 0: sorted_val_preds_ct = sorted_val_preds_ct.squeeze(-1)

temp = pd.DataFrame({'y':sorted_val_y_ct, 'preds':sorted_val_preds_ct, 'name':[groups_map_val[prot_ind] for prot_ind in sorted_val_groups_ct]})
temp['type'] = ['val'] * len(temp)
val_ranks[ct_name] = temp
temp.to_csv(f'{models_output_dir}/{embed_name}_{disease}_{mod}_val_preds_{ct_name}.csv', index=False) # Save the validation predictions
temp.to_csv(f'{models_output_dir}/{embed_name}_val_preds_{ct_name}.csv', index=False) # Save the validation predictions

wandb.log({f'val AUPRC cell types {ct_name}': ap_score,
f'val AUROC cell types {ct_name}': auroc_score,
f'val recall@5 cell types {ct_name}': ct_recall_5,
f'val precision@5 cell types {ct_name}': ct_precision_5,
f'val AP@5 cell types {ct_name}': ct_ap_5,
f'val recall@10 cell types {ct_name}': ct_recall_10,
f'val precision@10 cell types {ct_name}': ct_precision_10,
f'val AP@10 cell types {ct_name}': ct_ap_10,
f'val recall@20 cell types {ct_name}': ct_recall_20,
f'val precision@20 cell types {ct_name}': ct_precision_20,
f'val AP@20 cell types {ct_name}': ct_ap_20,
f'val positive proportion {ct_name}': positive_proportion_val})
f'val AUROC cell types {ct_name}': auroc_score,
f'val recall@5 cell types {ct_name}': ct_recall_5,
f'val precision@5 cell types {ct_name}': ct_precision_5,
f'val AP@5 cell types {ct_name}': ct_ap_5,
f'val recall@10 cell types {ct_name}': ct_recall_10,
f'val precision@10 cell types {ct_name}': ct_precision_10,
f'val AP@10 cell types {ct_name}': ct_ap_10,
f'val positive proportion {ct_name}': positive_proportion_val})

for ct in np.unique(best_train_cts): # We don't want to mess up train & val, so better separate
hits = np.where(best_train_cts==ct)[0]
for ct in np.unique(best_train_cts):
hits = np.where(best_train_cts == ct)[0]
train_y_ct = best_train_y[hits]
train_preds_ct = best_train_preds[hits]
train_groups_ct = best_train_groups[hits]
# ct_recall_5, ct_precision_5, ct_ap_5, _ = precision_recall_at_k(train_y_ct, train_preds_ct, k=5)
# ct_recall_10, ct_precision_10, ct_ap_10, _ = precision_recall_at_k(train_y_ct, train_preds_ct, k=10)
#_, _, _, (sorted_train_y_ct, sorted_train_preds_ct, sorted_train_groups_ct) = precision_recall_at_k(train_y_ct, train_preds_ct, k=20, prots=train_groups_ct)
_, _, _, (sorted_train_y_ct, sorted_train_preds_ct, sorted_train_groups_ct) = precision_recall_at_k(train_y_ct, train_preds_ct, k=10, prots=train_groups_ct)

ct_name = cts_map_train[ct]
temp = pd.DataFrame({'y': sorted_train_y_ct.squeeze(-1), 'preds':sorted_train_preds_ct.squeeze(-1), 'name':[groups_map_train[prot_ind] for prot_ind in sorted_train_groups_ct]})
temp['type'] = ['train'] * len(temp)
train_ranks[ct_name] = temp
temp.to_csv(f'{models_output_dir}/{embed_name}_{disease}_{mod}_train_preds_{ct_name}.csv', index=False) # Save the validation predictions
temp.to_csv(f'{models_output_dir}/{embed_name}_train_preds_{ct_name}.csv', index=False) # Save the training predictions

return train_ranks, val_ranks


def precision_recall_at_k(y: np.ndarray, preds: np.ndarray, k: int = 10, prots: np.ndarray = None):
""" Calculate recall@k, precision@k, and AP@k for binary classification.
"""
Calculate recall@k, precision@k, and AP@k for binary classification.
"""
assert preds.shape[0] == y.shape[0]
assert k > 0
@@ -140,7 +96,6 @@ def precision_recall_at_k(y: np.ndarray, preds: np.ndarray, k: int = 10, prots:
precision_k = np.sum(topk_y) / k

# Calculate the AP@k
# print(topk_y, topk_preds)
ap_k = average_precision_score(topk_y, topk_preds)

return recall_k, precision_k, ap_k, (sorted_y, sorted_preds, sorted_prots)
@@ -151,15 +106,12 @@ def get_metrics(y, y_pred, groups, celltype):
y = {celltype: y}
groups = {celltype: groups}

auroc_score = roc_auc_score(y[celltype], y_pred) # s[disease][celltype]
auroc_score = roc_auc_score(y[celltype], y_pred)
ap_score = average_precision_score(y[celltype], y_pred)
recall_5, precision_5, ap_5, _ = precision_recall_at_k(y[celltype], y_pred, k=5)
#recall_10, precision_10, ap_10, _ = precision_recall_at_k(y[celltype], y_pred, k=10)
recall_10, precision_10, ap_10, (sorted_y, sorted_preds, sorted_groups) = precision_recall_at_k(y[celltype], y_pred, k=10, prots=np.array(groups[celltype]))
#recall_20, precision_20, ap_20, (sorted_y, sorted_preds, sorted_groups) = precision_recall_at_k(y[celltype], y_pred, k=20, prots=np.array(groups[celltype]))
recall_20, precision_20, ap_20 = -1, -1, -1

# Calculate positive label proportions for each cell type, i.e. baseline for AP metric
positive_proportion = sum(y[celltype]) / len(y[celltype])

return auroc_score, ap_score, recall_5, precision_5, ap_5, recall_10, precision_10, ap_10, recall_20, precision_20, ap_20, sorted_y, sorted_preds, sorted_groups, positive_proportion
return auroc_score, ap_score, recall_5, precision_5, ap_5, recall_10, precision_10, ap_10, sorted_y, sorted_preds, sorted_groups, positive_proportion
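For context, the retained metrics hinge on precision_recall_at_k, which ranks proteins by predicted score and evaluates only the top k. Below is a minimal, self-contained sketch of that top-k logic with toy values; it is illustrative only and not part of this commit.

# Illustrative sketch of the top-k metric logic in precision_recall_at_k
# (toy values; not part of this commit).
import numpy as np
from sklearn.metrics import average_precision_score

y = np.array([1, 0, 1, 0, 0, 1])                   # binary target labels
preds = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])   # predicted scores
k = 3

order = np.argsort(preds)[::-1]                    # rank proteins by score, descending
topk_y, topk_preds = y[order][:k], preds[order][:k]

recall_k = np.sum(topk_y) / np.sum(y)              # positives recovered within the top k
precision_k = np.sum(topk_y) / k                   # fraction of the top k that are positives
ap_k = average_precision_score(topk_y, topk_preds) # AP restricted to the top k

print(recall_k, precision_k, ap_k)                 # ~0.67, ~0.67, ~0.83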
11 changes: 9 additions & 2 deletions finetune_pinnacle/run_model.sh
@@ -4,8 +4,11 @@ conda activate pinnacle

# Rheumatoid Arthritis (EFO_0000685)
python train.py \
--disease=EFO_0000685 \
--task_name=EFO_0000685 \
--embeddings_dir=../data/pinnacle_embeds/ \
--positive_proteins_prefix ../data/therapeutic_target_task/positive_proteins_EFO_0000685 \
--negative_proteins_prefix ../data/therapeutic_target_task/negative_proteins_EFO_0000685 \
--data_split_path ../data/therapeutic_target_task/data_split_EFO_0000685 \
--actn=relu \
--dropout=0.2 \
--hidden_dim_1=32 \
@@ -19,8 +22,11 @@ python train.py \

# Inflammatory bowel disease (EFO_0003767)
python train.py \
--disease=EFO_0003767 \
--task_name=EFO_0003767 \
--embeddings_dir=../data/pinnacle_embeds/ \
--positive_proteins_prefix ../data/therapeutic_target_task/positive_proteins_EFO_0003767 \
--negative_proteins_prefix ../data/therapeutic_target_task/negative_proteins_EFO_0003767 \
--data_split_path ../data/therapeutic_target_task/data_split_EFO_0003767 \
--actn=relu \
--dropout=0.4 \
--hidden_dim_1=32 \
@@ -31,3 +37,4 @@ python train.py \
--wd=0.0001 \
--random_state 1 \
--num_epoch=2000

84 changes: 84 additions & 0 deletions finetune_pinnacle/setup.py
@@ -0,0 +1,84 @@
# Set up model environment, parameters, and data

import os
import pandas as pd
import argparse
import json
import torch

from read_data import read_labels_from_evidence


def create_parser():
parser = argparse.ArgumentParser()

parser.add_argument("--model", type=str, default="torch_mlp")

parser.add_argument("--hidden_dim_1", type=int, default=64, help="1st hidden dim size")
parser.add_argument("--hidden_dim_2", type=int, default=32, help="2nd hidden dim size, discard if 0")
parser.add_argument("--hidden_dim_3", type=int, default=0, help="3rd hidden dim size, discard if 0")
parser.add_argument("--dropout", type=float, default=0, help="dropout rate")
parser.add_argument("--norm", type=str, default=None, help="normalization layer")
parser.add_argument("--actn", type=str, default="relu", help="activation type")
parser.add_argument("--order", type=str, default="nd", help="order of normalization and dropout")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
parser.add_argument("--num_epoch", type=int, default=1, help="epoch num")
parser.add_argument("--batch_size", type=int, help="batch size")

# Input data for finetuning task
parser.add_argument("--task_name", type=str)
parser.add_argument("--data_split_path", type=str, default="targets/data_split")
parser.add_argument('--positive_proteins_prefix', type=str, default="../data/therapeutic_target_task/positive_proteins")
parser.add_argument('--negative_proteins_prefix', type=str, default="../data/therapeutic_target_task/negative_proteins")

# Input PINNACLE representations
parser.add_argument("--embeddings_dir", type=str)
parser.add_argument("--embed", type=str, default="pinnacle")

# Output directories
parser.add_argument("--metrics_output_dir", type=str, default="./tmp_evaluation_results/")
parser.add_argument("--models_output_dir", type=str, default="./tmp_model_outputs/")
parser.add_argument("--random_state", type=int, default=1)
parser.add_argument("--random", action="store_true", help="random runs without fixed seeds")
parser.add_argument("--overwrite", action="store_true", help="whether to overwrite the label data or not")
parser.add_argument("--train_size", type=float, default=0.6)
parser.add_argument("--val_size", type=float, default=0.2)
parser.add_argument("--weigh_sample", action="store_true", help="whether to weigh samples or not") # default = False
parser.add_argument("--weigh_loss", action="store_true", help="whether to weigh losses or not") # default = False

args = parser.parse_args()
return args


def get_hparams(args):

hparams = {
"lr": args.lr,
"wd": args.wd,
"hidden_dim_1": args.hidden_dim_1,
"hidden_dim_2": args.hidden_dim_2,
"hidden_dim_3": args.hidden_dim_3,
"dropout": args.dropout,
"actn": args.actn,
"order": args.order,
"norm": args.norm,
"task_name": args.task_name
}
return hparams


def setup_paths(args):
random_state = args.random_state if args.random_state >= 0 else None
if random_state == None:
models_output_dir = args.models_output_dir + args.embed + "/"
metrics_output_dir = args.metrics_output_dir + args.embed + "/"
else:
models_output_dir = args.models_output_dir + args.embed + ("_seed=%s" % str(random_state)) + "/"
metrics_output_dir = args.metrics_output_dir + args.embed + ("_seed=%s" % str(random_state)) + "/"
if not os.path.exists(models_output_dir): os.makedirs(models_output_dir)
if not os.path.exists(metrics_output_dir): os.makedirs(metrics_output_dir)

embed_path = args.embeddings_dir + args.embed + "_ppi_embed.pth" #"_protein_embed.pth"
labels_path = args.embeddings_dir + args.embed + "_labels_dict.txt"
return models_output_dir, metrics_output_dir, random_state, embed_path, labels_path
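A hypothetical sketch of how these helpers might be wired together by the training script; the call pattern is an assumption for illustration, since train.py itself is not shown in this commit.

# Hypothetical usage of the setup.py helpers (assumed wiring; train.py is not
# shown in this commit).
from setup import create_parser, get_hparams, setup_paths

args = create_parser()        # parses CLI flags such as --task_name and --embeddings_dir
hparams = get_hparams(args)   # hyperparameters passed on to model training
models_output_dir, metrics_output_dir, random_state, embed_path, labels_path = setup_paths(args)

print(hparams["task_name"], embed_path)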