Commit 890ed7c (1 parent: 2fd2882), committed by Michelle M. Li on Jul 19, 2023. Showing 14 changed files with 2,932 additions and 0 deletions.
New file (+185 lines): evaluation utilities for saving results, plotting scores, and computing ranking metrics.
from typing import Dict

import json
import os

import numpy as np
import pandas as pd
import torch
import matplotlib
matplotlib.use('Agg')  # Select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import average_precision_score, roc_auc_score

from data_prep import process_and_split_data

def save_results(output_results_path: str, ap_scores: Dict[str, Dict[str, float]], auroc_scores: Dict[str, Dict[str, float]]):
    """
    Save the AP and AUROC score dictionaries to a JSON file.
    """
    res = {'ap': ap_scores, 'auroc': auroc_scores}
    with open(output_results_path, 'w') as f:
        json.dump(res, f)
    print(f"\nResults output to -> {output_results_path}")

def save_plots(output_figs_path: str, positive_proportion_train: Dict[str, Dict[str, float]], positive_proportion_test: Dict[str, Dict[str, float]], ap_scores: Dict[str, Dict[str, float]], auroc_scores: Dict[str, Dict[str, float]], disease, wandb):
    """
    Render bar plots of per-cell-type AP and AUROC scores, save them to disk, and log them to wandb.
    Note: the `disease` parameter is shadowed by the loop variable below.
    """
    for eval_results, eval_type in zip([ap_scores, auroc_scores], ['ap_scores', 'auroc_scores']):
        i = 0
        fig = plt.figure(figsize=(10 * len(eval_results), 6 * len(eval_results)))
        for disease, ct_res in eval_results.items():
            all_scores = []
            xlabels = []
            i += 1
            for celltype, score in ct_res.items():  # Experiments are not repeated, so each cell type has a single score
                all_scores.append(score)
                xlabels.append(celltype)
            ax = plt.subplot(len(eval_results), 1, i)
            plot_data = pd.DataFrame({'y': xlabels, 'x': all_scores})
            sns.barplot(x='x', y='y', data=plot_data, capsize=.2)
            if eval_type == 'ap_scores':
                # Overlay the positive-label proportions (the AP baseline) for train and test
                train_props = list(positive_proportion_train[disease].values())
                try:
                    plt.scatter(y=xlabels,
                                x=[train_props[0]] * (len(positive_proportion_test[disease]) - 1) + [train_props[1]],
                                c='grey', zorder=100, label='train', s=10)
                except IndexError:  # Only a single train proportion is available
                    plt.scatter(y=xlabels,
                                x=[train_props[0]] * len(positive_proportion_test[disease]),
                                c='grey', zorder=100, label='train', s=10)
                plt.scatter(y=xlabels,
                            x=positive_proportion_test[disease].values(),
                            c='black', zorder=100, label='test', s=10)
            ax.set_ylabel("")
            ax.set_xlabel(disease + '-' + eval_type)
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
        plt.tight_layout()
        plt.savefig(output_figs_path + eval_type + '.png')
        wandb.log({f'test {eval_type} bar': fig})
        plt.close(fig)
    return

def save_torch_train_val_preds(best_train_y, best_train_preds, best_train_groups, best_train_cts, best_val_y, best_val_preds, best_val_groups, best_val_cts, groups_map_train, groups_map_val, cts_map_train, cts_map_val, models_output_dir, embed_name, disease, mod, is_global, wandb):
    if not is_global:
        train_ranks = {}
        val_ranks = {}
        for ct in np.unique(best_val_cts):
            hits = np.where(best_val_cts == ct)[0]
            val_y_ct = best_val_y[hits]
            val_preds_ct = best_val_preds[hits]
            val_groups_ct = best_val_groups[hits]
            ct_name = cts_map_val[ct]

            if len(np.unique(val_y_ct)) < 2:
                # Only one class present: the metrics are undefined, so fill with -1 sentinels
                auroc_score, ap_score, ct_recall_5, ct_precision_5, ct_ap_5, ct_recall_10, ct_precision_10, ct_ap_10, ct_recall_20, ct_precision_20, ct_ap_20, sorted_val_y_ct, sorted_val_preds_ct, sorted_val_groups_ct, positive_proportion_val = -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, np.array([-1] * len(val_y_ct)), np.array([-1] * len(val_y_ct)), np.array([-1] * len(val_y_ct)), -1
            else:
                auroc_score, ap_score, ct_recall_5, ct_precision_5, ct_ap_5, ct_recall_10, ct_precision_10, ct_ap_10, ct_recall_20, ct_precision_20, ct_ap_20, sorted_val_y_ct, sorted_val_preds_ct, sorted_val_groups_ct, positive_proportion_val = get_metrics(val_y_ct, val_preds_ct, val_groups_ct, "training")
                if len(sorted_val_y_ct) > 0: sorted_val_y_ct = sorted_val_y_ct.squeeze(-1)
                if len(sorted_val_preds_ct) > 0: sorted_val_preds_ct = sorted_val_preds_ct.squeeze(-1)

            temp = pd.DataFrame({'y': sorted_val_y_ct, 'preds': sorted_val_preds_ct, 'name': [groups_map_val[prot_ind] for prot_ind in sorted_val_groups_ct]})
            temp['type'] = ['val'] * len(temp)
            val_ranks[ct_name] = temp
            temp.to_csv(f'{models_output_dir}/{embed_name}_{disease}_{mod}_val_preds_{ct_name}.csv', index=False)  # Save the validation predictions

            wandb.log({f'val AUPRC cell types {ct_name}': ap_score,
                       f'val AUROC cell types {ct_name}': auroc_score,
                       f'val recall@5 cell types {ct_name}': ct_recall_5,
                       f'val precision@5 cell types {ct_name}': ct_precision_5,
                       f'val AP@5 cell types {ct_name}': ct_ap_5,
                       f'val recall@10 cell types {ct_name}': ct_recall_10,
                       f'val precision@10 cell types {ct_name}': ct_precision_10,
                       f'val AP@10 cell types {ct_name}': ct_ap_10,
                       f'val recall@20 cell types {ct_name}': ct_recall_20,
                       f'val precision@20 cell types {ct_name}': ct_precision_20,
                       f'val AP@20 cell types {ct_name}': ct_ap_20,
                       f'val positive proportion {ct_name}': positive_proportion_val})

        for ct in np.unique(best_train_cts):  # Train and val are handled in separate loops so they cannot mix
            hits = np.where(best_train_cts == ct)[0]
            train_y_ct = best_train_y[hits]
            train_preds_ct = best_train_preds[hits]
            train_groups_ct = best_train_groups[hits]
            _, _, _, (sorted_train_y_ct, sorted_train_preds_ct, sorted_train_groups_ct) = precision_recall_at_k(train_y_ct, train_preds_ct, k=10, prots=train_groups_ct)

            ct_name = cts_map_train[ct]
            temp = pd.DataFrame({'y': sorted_train_y_ct.squeeze(-1), 'preds': sorted_train_preds_ct.squeeze(-1), 'name': [groups_map_train[prot_ind] for prot_ind in sorted_train_groups_ct]})
            temp['type'] = ['train'] * len(temp)
            train_ranks[ct_name] = temp
            temp.to_csv(f'{models_output_dir}/{embed_name}_{disease}_{mod}_train_preds_{ct_name}.csv', index=False)  # Save the training predictions

    else:  # Global model: rank across all cell types at once

        # Validation
        _, _, _, (sorted_val_y_global, sorted_val_preds_global, sorted_val_groups_global) = precision_recall_at_k(best_val_y, best_val_preds, k=5, prots=np.array(best_val_groups))
        positive_proportion_val = sum(best_val_y) / len(best_val_y)

        wandb.log({'val positive proportion global': positive_proportion_val})

        val_ranks = pd.DataFrame({'y': sorted_val_y_global.squeeze(-1), 'preds': sorted_val_preds_global.squeeze(-1), 'name': [groups_map_val[prot_ind] for prot_ind in sorted_val_groups_global]})
        val_ranks['type'] = ['val'] * len(val_ranks)
        val_ranks.to_csv(f'{models_output_dir}/{embed_name}_{disease}_{mod}_val_preds_global.csv', index=False)  # Save the validation predictions

        # Train
        _, _, _, (sorted_train_y_global, sorted_train_preds_global, sorted_train_groups_global) = precision_recall_at_k(best_train_y, best_train_preds, k=20, prots=best_train_groups)
        train_ranks = pd.DataFrame({'y': sorted_train_y_global.squeeze(-1), 'preds': sorted_train_preds_global.squeeze(-1), 'name': [groups_map_train[prot_ind] for prot_ind in sorted_train_groups_global]})
        train_ranks['type'] = ['train'] * len(train_ranks)
        train_ranks.to_csv(f'{models_output_dir}/{embed_name}_{disease}_{mod}_train_preds_global.csv', index=False)  # Save the training predictions

    return train_ranks, val_ranks

def precision_recall_at_k(y: np.ndarray, preds: np.ndarray, k: int = 10, prots: np.ndarray = None):
    """
    Calculate recall@k, precision@k, and AP@k for binary classification.
    """
    assert preds.shape[0] == y.shape[0]
    assert k > 0
    if k > preds.shape[0]: return -1, -1, -1, ([], [], [])

    # Sort the scores and the labels by the scores (descending)
    sorted_indices = np.argsort(preds.flatten())[::-1]
    sorted_preds = preds[sorted_indices]
    sorted_y = y[sorted_indices]
    if prots is not None:
        sorted_prots = prots[sorted_indices]
    else:
        sorted_prots = None

    # Get the labels and scores of the k highest predictions
    topk_preds = sorted_preds[:k]
    topk_y = sorted_y[:k]

    # Calculate recall@k and precision@k
    recall_k = np.sum(topk_y) / np.sum(y)
    precision_k = np.sum(topk_y) / k

    # Calculate AP@k
    ap_k = average_precision_score(topk_y, topk_preds)

    return recall_k, precision_k, ap_k, (sorted_y, sorted_preds, sorted_prots)

def get_metrics(y, y_pred, groups, celltype):
    if celltype in ["global", "esm", "training"]:  # Wrap in a dict to keep the data structure consistent between cell-type and global evaluation
        y = {celltype: y}
        groups = {celltype: groups}

    auroc_score = roc_auc_score(y[celltype], y_pred)
    ap_score = average_precision_score(y[celltype], y_pred)
    recall_5, precision_5, ap_5, _ = precision_recall_at_k(y[celltype], y_pred, k=5)
    recall_10, precision_10, ap_10, (sorted_y, sorted_preds, sorted_groups) = precision_recall_at_k(y[celltype], y_pred, k=10, prots=np.array(groups[celltype]))
    recall_20, precision_20, ap_20 = -1, -1, -1  # The @20 metrics are currently disabled; -1 is a sentinel

    # Calculate the positive label proportion for each cell type, i.e. the baseline for the AP metric
    positive_proportion = sum(y[celltype]) / len(y[celltype])

    return auroc_score, ap_score, recall_5, precision_5, ap_5, recall_10, precision_10, ap_10, recall_20, precision_20, ap_20, sorted_y, sorted_preds, sorted_groups, positive_proportion
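
As a quick sanity check of the ranking metrics (not part of the commit; the inputs are toy values), here is how precision_recall_at_k behaves on a small example. With three positives in six examples and two of them in the top 3, recall@3 and precision@3 are both 2/3, and AP@3 averages the precision at each retrieved positive:

    import numpy as np

    y = np.array([1, 0, 1, 0, 0, 1])                   # binary labels
    preds = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])   # scores, already in descending order for readability

    recall_3, precision_3, ap_3, (sorted_y, sorted_preds, _) = precision_recall_at_k(y, preds, k=3)
    # Top-3 contains 2 of the 3 positives:
    #   recall@3    = 2/3 ~= 0.667
    #   precision@3 = 2/3 ~= 0.667
    #   AP@3        = (1/1 + 2/3) / 2 ~= 0.833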
New file (+60 lines): a configurable MLP model.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, in_dim: int, hidden_dims: list, p: float, norm: str, actn: str, order: str = 'nd'):
        super().__init__()

        self.n_layer = len(hidden_dims) - 1
        self.in_dim = in_dim

        actn2actfunc = {'relu': nn.ReLU(), 'leakyrelu': nn.LeakyReLU(), 'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid(), 'selu': nn.SELU(), 'elu': nn.ELU(), 'softplus': nn.Softplus()}
        try:
            actn = actn2actfunc[actn]
        except KeyError:
            raise NotImplementedError(f"Unsupported activation: {actn}")

        # Input layer
        layers = [nn.Linear(self.in_dim, hidden_dims[0]), actn]

        # Hidden layers
        for i in range(self.n_layer):
            layers += self.compose_layer(in_dim=hidden_dims[i], out_dim=hidden_dims[i + 1], norm=norm, actn=actn, p=p, order=order)

        # Output layer: a single logit
        layers.append(nn.Linear(hidden_dims[-1], 1))

        self.fc = nn.Sequential(*layers)

    def compose_layer(self, in_dim: int, out_dim: int, norm: str, actn: nn.Module, p: float = 0.0, order: str = 'nd'):
        # The norm layer is constructed here because in_dim is only fixed at this point
        norm2normlayer = {'bn': nn.BatchNorm1d(in_dim), 'ln': nn.LayerNorm(in_dim), None: None, 'None': None}
        try:
            norm = norm2normlayer[norm]
        except KeyError:
            raise NotImplementedError(f"Unsupported norm: {norm}")

        # Options: norm --> dropout ('nd') or dropout --> norm ('dn')
        if order == 'nd':
            layers = [norm] if norm is not None else []
            if p != 0:
                layers.append(nn.Dropout(p))
        elif order == 'dn':
            layers = [nn.Dropout(p)] if p != 0 else []
            if norm is not None:
                layers.append(norm)
        else:
            raise NotImplementedError(f"Unsupported order: {order}")

        layers.append(nn.Linear(in_dim, out_dim))
        if actn is not None:
            layers.append(actn)
        return layers

    def forward(self, x):
        return self.fc(x)
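
A minimal usage sketch (not part of the commit; the dimensions and hyperparameters below are made up for illustration). MLP(128, [64, 32], ...) builds Linear(128, 64) + ReLU, then one composed hidden block (LayerNorm, Dropout, Linear(64, 32), ReLU under order='nd'), then the final Linear(32, 1):

    import torch

    # Hypothetical configuration: 128-dim input, hidden layers of 64 and 32 units
    model = MLP(in_dim=128, hidden_dims=[64, 32], p=0.1, norm='ln', actn='relu', order='nd')

    x = torch.randn(4, 128)   # batch of 4 examples
    logits = model(x)         # raw scores; the final Linear has no activation
    print(logits.shape)       # torch.Size([4, 1])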