diff --git a/analysis/compile_cas9_fidelity.py b/analysis/compile_cas9_fidelity.py index e5edb4b..08ff5ec 100644 --- a/analysis/compile_cas9_fidelity.py +++ b/analysis/compile_cas9_fidelity.py @@ -1,18 +1,24 @@ +from tqdm import tqdm import os from Bio.Align import PairwiseAligner, substitution_matrices from sequence_models.utils import parse_fasta -from tqdm import tqdm -from dayhoff.analysis_utils import get_all_paths, results_to_pandas -base_path = "/home/kevyan/generations/cas9/" +base_path = "/home/kevyan/generations/cas9-no-order/" -models = ["short_cas9s_1.0_minp0.00_new"] -for m in models: - pdb_paths, mpnn_paths = get_all_paths(os.path.join(base_path, "%s_structures/pdb/esmfold/" %m), os.path.join(base_path, "%s_structures/esmfoldmpnn_iftemp_1" %m)) - fold_df, mpnn_df, df = results_to_pandas(pdb_paths, mpnn_paths, name="") - df['model'] = m +model = "short_cas9s_1.0_minp0.00_new" +folding_df = pd.read_csv(os.path.join(base_path, 'esmfold_proteinmpnn_merge_data.csv')) +seqs, names = parse_fasta(os.path.join(base_path, "%s.fasta" % model), return_names=True) +df = folding_df[folding_df['if_temp'] == 1.0] +name_df = pd.DataFrame() +name_df['sequence'] = seqs +name_df['file'] = names +df = pd.merge(name_df, df, how='left', on='file') +# for m in models: +# pdb_paths, mpnn_paths = get_all_paths(os.path.join(base_path, "%s_structures/pdb/esmfold/" %m), os.path.join(base_path, "%s_structures/esmfoldmpnn_iftemp_1" %m)) +# fold_df, mpnn_df, df = results_to_pandas(pdb_paths, mpnn_paths, name="") +# df['model'] = m aligner = PairwiseAligner() aligner.substitution_matrix = substitution_matrices.load("BLOSUM62") @@ -22,41 +28,43 @@ aligner.query_end_gap_score = 0.0 with tqdm(total=len(df)) as pbar: homologs, homolog_names = parse_fasta(os.path.join('/home/kevyan/data/characterized_cas9s', "naturals.fasta"), return_names=True) + for idx, row in df.iterrows(): + s = row['sequence'] + s = s.replace("-", "") + s = s.replace("", "") + s = s.replace("", "") + s = s.replace("", "") + s = s.replace("", "") + best_matches = -1 + best_homolog_sequence = None + best_homolog_name = None + best_cterm_gaps = None + for hs, hn in zip(homologs, homolog_names): + alignment = aligner.align(s, hs) + if alignment.score > best_matches: + best_matches = alignment.score + best_homolog_sequence = hs + best_homolog_name = hn + best_cterm_gaps = len(hs) - alignment[0].aligned[1, -1, 1] + df.loc[idx, 'gen_length'] = len(s) + df.loc[idx, 'best_matches'] = best_matches + df.loc[idx, 'match_length'] = len(best_homolog_sequence) + df.loc[idx, 'homolog_name'] = best_homolog_name + df.loc[idx, 'homolog_sequence'] = best_homolog_sequence + df.loc[idx, 'cterm_gaps'] = best_cterm_gaps + pbar.update(1) - for model in models: - seqs, names = parse_fasta(os.path.join(base_path, "%s.fasta" %model), return_names=True) - for s, n in zip(seqs, names): - s = s.replace("-", "") - s = s.replace("", "") - s = s.replace("", "") - s = s.replace("", "") - s = s.replace("", "") - best_matches = -1 - best_homolog_sequence = None - best_homolog_name = None - best_cterm_gaps = None - for hs, hn in zip(homologs, homolog_names): - alignment = aligner.align(s, hs) - if alignment.score > best_matches: - best_matches = alignment.score - best_homolog_sequence = hs - best_homolog_name = hn - best_cterm_gaps = len(hs) - alignment[0].aligned[1, -1, 1] - - idx = df[(df['model'] == model) & (df['file'] == n)].index[0] - df.loc[idx, 'sequence'] = s - df.loc[idx, 'gen_length'] = len(s) - df.loc[idx, 'best_matches'] = best_matches - df.loc[idx, 'match_length'] = len(best_homolog_sequence) - df.loc[idx, 'homolog_name'] = best_homolog_name - df.loc[idx, 'homolog_sequence'] = best_homolog_sequence - df.loc[idx, 'cterm_gaps'] = best_cterm_gaps - pbar.update(1) - +df['plddt'] = df['esmfoldplddt'] +df['scperplexity'] = df['proteinmpnnperplexity'] df['seq_id'] = df['best_matches'] / df['gen_length'] df = df.sort_values(['cterm_gaps', 'plddt'], ascending=[True, False]) df['name'] = [f.split('_')[-1] for f in df['file']] -df.to_csv(os.path.join(base_path, "%s_fidelity.csv" %models[0]), index=False) +df.to_csv(os.path.join(base_path, "%s_fidelity.csv" %model), index=False) + +df = pd.read_csv(os.path.join(base_path, "%s_fidelity.csv" %model)) df[df['plddt'] > .70].head(10)[['name', 'match_length', 'gen_length', 'plddt', 'cterm_gaps', 'best_matches']] -# 52, 8, and 50 have the most domain hits \ No newline at end of file +# 52, 8, and 50 have the most domain hits +df[df['plddt'] > 0.7].shape +df.loc[[0, 1, 2, 18, 19, 21], ['name', 'sequence']].values +df.loc[[0, 1, 2, 18, 19, 21], ['name', 'homolog_name', 'homolog_sequence']].values \ No newline at end of file diff --git a/analysis/compile_msa_fidelity.py b/analysis/compile_msa_fidelity.py index 7ae1116..c2dc6f6 100644 --- a/analysis/compile_msa_fidelity.py +++ b/analysis/compile_msa_fidelity.py @@ -1,5 +1,6 @@ import os from collections import Counter +from tqdm import tqdm import matplotlib.pyplot as plt import numpy as np @@ -7,9 +8,10 @@ import seaborn as sns from Bio.Align import PairwiseAligner from scipy.stats import pearsonr -from sequence_models.utils import parse_fasta -from tqdm import tqdm +from scipy.stats import ttest_rel +from dayhoff.analysis_utils import results_to_pandas, get_all_paths, run_tmscore +from sequence_models.utils import parse_fasta from dayhoff.analysis_utils import get_all_paths, results_to_pandas, run_tmscore sns.set_theme(font_scale=1.2) @@ -18,9 +20,9 @@ base_path = "/home/kevyan/generations/queries_from_homologs" models = ["natural", "xlstm", "evodiff", "evodiff_nom", - "gap_1.0_0.01", "gap_1.2_0.01", "gap_1.0_0.00", "gap_1.1_0.05", "gap_1.0_0.00_nom", "gap_1.0_0.05_nom", - "ccmgen", "ccmgen_short", - "indel_1.0_0.00", "indel_1.2_0.01", "indel_1.0_0.01", "indel_1.1_0.05", "indel_1.0_0.00_nom", "indel_1.0_0.05_nom"] + "gap_1.0_0.05_nom", + "ccmgen", + "indel_1.0_0.05_nom"] raw_dfs = [] for m in models: pdb_paths, mpnn_paths = get_all_paths(os.path.join(base_path, "%s_structures/pdb/esmfold/" %m), os.path.join(base_path, "%s_structures/esmfoldmpnn_iftemp_1" %m)) @@ -28,7 +30,6 @@ merged_df['model'] = m raw_dfs.append(merged_df) df = pd.concat(raw_dfs, ignore_index=True) - model_to_name = {m: m for m in models} model_to_name['natural'] = 'queries' model_to_name['gap_1.0_0.01'] = '3b-cooled_25000_gap_t1.0_0.01' @@ -60,11 +61,12 @@ gen_length = len(s) n_homologs = len(homologs) best_id = -1 + if model == 'natural': + homologs = homologs[1:] for i, homolog in enumerate(homologs): if i == 0: query_length = len(homolog) - if model == 'natural': - continue + alignment = aligner.align(s, homolog) if alignment.score > best_matches: best_matches = alignment.score @@ -100,48 +102,67 @@ len(set(df['file'])) df.to_csv(os.path.join(base_path, "compiled_fidelities.csv"), index=False) df = pd.read_csv(os.path.join(base_path, "compiled_fidelities.csv")) +models_to_plot = { + "natural": "Natural", + "ccmgen": "CCMgen", + "evodiff_nom": "EvoDiff-MSA", + "xlstm": 'Prot-xLSTM', + "gap_1.0_0.05_nom": "Aligned", + "indel_1.0_0.05_nom": "Unaligned" +} - - -print("model R(n_homologs, plddt)") -for model in models: - df_lim = df[df['model'] == model] - print(model, pearsonr(df_lim['n_homologs'], df_lim['plddt']).statistic) - -print("model R(n_homologs, seq_id)") -for model in models: - df_lim = df[df['model'] == model] - print(model, pearsonr(df_lim['n_homologs'], df_lim['seq_id']).statistic) - -print("model R(n_homologs, tmscore)") -for model in models[1:]: - df_lim = df[df['model'] == model] - print(model, pearsonr(df_lim['n_homologs'], df_lim['tmscore']).statistic) - -print("model R(seq_id, tmscore)") -for model in models[1:]: - df_lim = df[df['model'] == model] - print(model, pearsonr(df_lim['seq_id'], df_lim['tmscore']).statistic) - - -grouped = df.groupby('model') +model1 = "indel_1.0_0.05_nom" +model2 = "gap_1.0_0.05_nom" +df.columns +all_names = list(set(df['file'])) +metrics = { + "plddt": [], + "perplexity": [], + "tmscore": [], + "seq_id": [] +} +m_dict = { + "plddt": "pLDDT", + "perplexity": "scPerplexity", + "tmscore": "TM score", + "seq_id": "Sequence identity" +} +for n in all_names: + for m in metrics: + first = df[(df['model'] == model1) & (df['file'] == n)][m].values[0] + second = df[(df['model'] == model2) & (df['file'] == n)][m].values[0] + metrics[m].append(first - second) +fig, axs = plt.subplots(1, 4, figsize=(14.4, 4.8)) +for i, (m, ax) in enumerate(zip(metrics, axs)): + print(np.array(metrics[m]).mean()) + _ = sns.stripplot(y=metrics[m], ax=ax, alpha=0.7) + if i == 0: + _ = ax.set_ylabel(models_to_plot[model1] + " - " + models_to_plot[model2]) + _ = ax.set_xlabel(m_dict[m]) + _ = ax.axhline(y=0, color='gray') + _ = ax.set_xticks(ax.get_xticks()) + _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right') +_ = fig.savefig(os.path.join(base_path, "ulal_diff.pdf"), dpi=300, bbox_inches="tight") + + +grouped = df[['model', 'plddt', 'perplexity', 'tmscore', 'seq_id']].groupby('model') +final_models = ["gap_1.0_0.05_nom", "indel_1.0_0.05_nom", "ccmgen", "xlstm", "evodiff_nom", "natural"] +metrics = grouped.agg(['mean', 'std']).loc[final_models].reset_index() +for i, row in metrics.iterrows(): + print_me = [models_to_plot[row['model'].values[0]]] + for m in ['plddt', 'perplexity', 'tmscore', 'seq_id']: + print_me.append("$%.2f \\pm %.2f$" % (row[m]['mean'], row[m]['std'])) + print_me = " & ".join(print_me) + "\\\\" + print(print_me) grouped.seq_id.agg(['mean', 'std']) grouped.tmscore.agg(['mean', 'std']) grouped.plddt.agg(['mean', 'std']) grouped.perplexity.agg(['mean', 'std']) -models_to_plot = { - "natural": "Natural", - "ccmgen_short": "CCMgen", - "evodiff_nom": "EvoDiff-MSA", - "xlstm": 'Prot-xLSTM', - "gap_1.0_0.05_nom": "Alignment conditioning", - "indel_1.0_0.05_nom": "Homolog conditioning" -} pal = sns.color_palette() model_to_hue = { "natural": "gray", - "ccmgen_short": pal[-4], + "ccmgen": pal[-4], "xlstm": pal[-1], "evodiff_nom": pal[-2], "gap_1.0_0.05_nom": pal[1], @@ -149,7 +170,7 @@ } cdfs = { "seq_id": (plt.subplots(), "Sequence Identity"), - "tmscore": (plt.subplots(), "TM score"), + "tmscore": (plt.subplots(), "TM-score"), "plddt": (plt.subplots(), "pLDDT"), "perplexity": (plt.subplots(), "scPerplexity"), } @@ -202,3 +223,21 @@ # 6202062 gap plddt 0.896246 # A0A174Z1L0 indel plddt 0.941450 longer than query # 76841376 indel plddt 0.938367 + +df.head() +df.columns +pivoted = {} +out_models = { +"gap_1.0_0.05_nom": "aligned", "indel_1.0_0.05_nom": "unaligned", "ccmgen": "CCMgen", "xlstm": "prot-xLSTM", "evodiff_nom": "EvoDiff-MSA", "natural": "natural" +} +ttest_outfile = os.path.join(base_path, "ttest_results.csv") +with open(ttest_outfile, 'w') as f: + f.write("metric,model1,model2,t,p\n") + for m in ['plddt', 'perplexity', 'tmscore', 'seq_id']: + df2 = df.pivot(index='file', columns='model', values=m) + for i, model1 in enumerate(final_models): + for model2 in final_models[i + 1:]: + r = ttest_rel(df2[model1], df2[model2]) + f.write(','.join([m, out_models[model1], out_models[model2], str(r.statistic), str(r.pvalue)]) + "\n") + if ttest_rel(df2[model1], df2[model2]).pvalue > 0.05: + print(m, model1, model2) \ No newline at end of file diff --git a/analysis/fidelity.py b/analysis/fidelity.py index 4b026f8..328c9df 100644 --- a/analysis/fidelity.py +++ b/analysis/fidelity.py @@ -15,7 +15,6 @@ PATH_TO_PROTEINMPNN = "ProteinMPNN/" CWD = os.getcwd() - def get_bfactor(filename, chain="A"): parser = PDBParser(PERMISSIVE=1) protein = parser.get_structure(chain, filename) @@ -52,7 +51,6 @@ def run_omegafold(input_fasta, output_dir, subbatch_size=1024): ],check=True ) - def convert_outputs_to_pdb(outputs): final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs) outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()} @@ -101,7 +99,7 @@ def get_all_paths(pdb_path, mpnn_path): all_files.append((os.path.join(pdb_path, i), i)) for j in os.listdir(mpnn_path): all_mpnn_files.append((j, os.path.join(mpnn_path))) - + print(f"PDB Files: {len(all_files)}, MPNN Files: {len(all_mpnn_files)}") return all_files, all_mpnn_files @@ -114,13 +112,13 @@ def results_to_pandas(all_files, all_mpnn_files, fold_method="omegafold", if_met fold_files = [] mpnn_files = [] - for i, f in all_files: + for i, f in all_files: if os.path.exists(i): plddts.append(get_bfactor(i)) fold_files.append(f.split('.pdb')[0]) fold_full_path.append(i) - - for f, mpnn_output_paths in all_mpnn_files: + + for f, mpnn_output_paths in all_mpnn_files: subdir_files = os.listdir(os.path.join(mpnn_output_paths, f, 'scores/')) for mfile in subdir_files: file = os.path.join(mpnn_output_paths, f, 'scores/', mfile) @@ -128,44 +126,44 @@ def results_to_pandas(all_files, all_mpnn_files, fold_method="omegafold", if_met perp = get_mpnn_perp(file) mpnn_files.append(mfile.split('/')[-1].split('.npz')[0]) perps.append(perp) - + fold_dict = { "full_path": fold_full_path, f"{fold_method}plddt": plddts, "file": fold_files, } - + mpnn_dict = { f"{if_method}perplexity": perps, "file": mpnn_files, } - + fold_df = pd.DataFrame(fold_dict) mpnn_df = pd.DataFrame(mpnn_dict) - merged_df = pd.merge(fold_df, mpnn_df, on='file', how='inner') # merge on file name + merged_df = pd.merge(fold_df, mpnn_df, on='file', how='inner') # merge on file name return fold_df, mpnn_df, merged_df - + def parse_csv(csv_path, return_names=False): """ Parse a CSV file and return the 'sequence' column """ import csv - + sequences = [] headers = [] - + with open(csv_path, 'r') as csvfile: csv_reader = csv.reader(csvfile) - + # Get headers from the first row headers = next(csv_reader) - + # Find the index of the 'sequence' column try: seq_index = headers.index('sequence') except ValueError: raise ValueError("CSV file does not contain a 'sequence' column") - + # Extract sequences from the appropriate column for row in csv_reader: if len(row) > seq_index: @@ -181,7 +179,7 @@ def run_esmfold(input_fasta: str, esm_chunk_size: int = 64, short_or_long: str = 'short', missing_caret: bool = False): - + # raise FileNotFoundError(f"Input fasta file {input_fasta} not found.") # Parse fasta if missing_caret: @@ -201,7 +199,7 @@ def run_esmfold(input_fasta: str, select_array = [len(s) >= 800 for s in seqs] filtered_seqs = [s for i, s in enumerate(seqs) if select_array[i]] filtered_seq_ids = [s for i, s in enumerate(seq_ids) if select_array[i]] - elif short_or_long == "crop": + elif short_or_long == "crop": filtered_seqs = [s[:1750] for s in seqs] filtered_seq_ids = seq_ids else: # dont filter if anything else @@ -331,10 +329,10 @@ def main(): parser.add_argument("--short-or-long", type=str, default='all') # short < 800, long >= 800 for running on <40GB gpu, `all` dont filter parser.add_argument("--skip-folding", action="store_true") # TODO clean up later parser.add_argument("--skip-if", action="store_true") # bypass running if/folding - + args = parser.parse_args() - + pdb_path = os.path.join(args.output_path, "pdb", args.fold_method) os.makedirs(pdb_path, exist_ok=True) @@ -348,22 +346,22 @@ def main(): # Parse the CSV file to get sequences seqs = parse_csv(args.input_fasta) seq_names = [f"seq_{i+1}" for i in range(len(seqs))] - + # Create a temporary FASTA file in the same location as the input CSV - temp_fasta_path = os.path.join(os.path.dirname(args.input_fasta), + temp_fasta_path = os.path.join(os.path.dirname(args.input_fasta), f"temp_{os.path.basename(args.input_fasta)}.fasta") - + # Write sequences to the temporary FASTA file with open(temp_fasta_path, 'w') as f: for seq_name, seq in zip(seq_names, seqs): f.write(f">{seq_name}\n{seq}\n") - + # Use the temporary file as input for the rest of the pipeline input_fasta_path = temp_fasta_path else: # Use the original input FASTA path input_fasta_path = args.input_fasta - + # Run the folding model with the appropriate input path if args.fold_method == "esmfold": run_esmfold(input_fasta=input_fasta_path, @@ -380,14 +378,13 @@ def main(): print("PDBs in omegafold directory") else: print("Only omegafold and esmfold methods are supported") - # Run inverse_fold if_temps = [1, 0.5, 0.1] pdb_indices = {} mpnn_output_paths = {} - if_method = 'mpnn' # TODO Might break with esmfold - have not tested + if_method = 'mpnn' # TODO Might break with esmfold - have not tested for t in if_temps: output_folder = os.path.join(args.output_path, args.fold_method + if_method + '_iftemp_' + str(t) + "/") pdb_files = os.listdir(pdb_path) @@ -398,7 +395,7 @@ def main(): run_inversefold(pdb_path, output_folder, pdb_files, method=if_method, temperature=t) - # Compile results + # Compile results all_results = [] for t in if_temps: all_files, all_mpnn_files = get_all_paths(pdb_path, mpnn_output_paths[t]) @@ -414,7 +411,8 @@ def main(): csv_path = os.path.join(args.output_path, args.fold_method + "_" + args.if_method + "_merge_data.csv") final_df.to_csv(csv_path, index=False) print(f"All results saved to {csv_path}") - - + + if __name__ == "__main__": main() + diff --git a/analysis/gigaref.py b/analysis/gigaref.py index 977cb5a..ae95ce5 100644 --- a/analysis/gigaref.py +++ b/analysis/gigaref.py @@ -67,9 +67,9 @@ db = "SMAG" elif db == "metaeuk": db = "MetaEuk" - count_df.loc[len(count_df), ["database", "count", "fraction", "ggr"]] = (db, singleton_sums[i], singleton_sums[i] / singleton_sums[-1], "singletons") + count_df.loc[len(count_df), ["database", "count", "fraction", "ggr"]] = (db, singleton_sums[i], singleton_sums[i] / singleton_sums[-1], "GigaRef-singletons") count_df.loc[len(count_df), ["database", "count", "fraction", "ggr"]] = (db, database_sizes[i] - singleton_sums[i], - (database_sizes[i] - singleton_sums[i]) / (database_sizes[-1] - singleton_sums[-1]), "clusters") + (database_sizes[i] - singleton_sums[i]) / (database_sizes[-1] - singleton_sums[-1]), "GigaRef-clusters") big_cluster_compositions = cluster_compositions[ns_ids] ur100id = list(columns).index("UniRef100") big_cluster_count = len(big_cluster_compositions) @@ -82,27 +82,74 @@ print("no ur100\tonly ur100\tmixed") print(no_ur100_count, only_ur100_count, big_cluster_count - no_ur100_count - only_ur100_count) pal = sns.color_palette() -skip = 50000 +skip = 1000 +plot_me = [mix_compositions[:, -1], mix_compositions[:, ur100id]] +plot_me = np.stack(plot_me) +plot_me = pd.DataFrame(plot_me.T, columns=["x", "y"]) +plot_me = plot_me.drop_duplicates() +plot_me = plot_me.sort_values(by=["x", "y"]) fig, ax = plt.subplots(1, 1) -_ = ax.plot(mix_compositions[::skip, -1], mix_compositions[::skip, ur100id], '.', color='gray', alpha=0.6, label="Metagenomic samples only") -_ = ax.plot(no_ur100_compositions[::skip, -1], no_ur100_compositions[::skip, ur100id], '.', color=pal[4], alpha=0.6, label="UR100 + metagenomic samples") +_ = ax.plot(plot_me.iloc[::skip, 0], plot_me.iloc[::skip, 1], '.', color='gray', ms=3, alpha=0.6, label="UR100 + metagenomic samples") +plot_me = [no_ur100_compositions[:, -1], no_ur100_compositions[:, ur100id]] +plot_me = np.stack(plot_me) +plot_me = pd.DataFrame(plot_me.T, columns=["x", "y"]) +plot_me = plot_me.drop_duplicates() +plot_me = plot_me.sort_values(by=["x", "y"]) +_ = ax.plot(plot_me.iloc[::5]['x'], plot_me.iloc[::5]['y'], '.', ms=3, color=pal[4], alpha=0.9, label="Metagenomic samples only") _ = ax.set_xlabel('Cluster size') _ = ax.set_ylabel('# UR100 members') _ = ax.legend(loc='best') _ = ax.set_xscale('log') _ = fig.savefig(os.path.join(out_dir, "gigaref_compositions.pdf"), dpi=300, bbox_inches='tight') +skip = 100 +fig, ax = plt.subplots(1, 1) +plot_me = [mix_compositions[:, -1], mix_compositions[:, -1] - mix_compositions[:, ur100id]] +plot_me = np.stack(plot_me) +plot_me = pd.DataFrame(plot_me.T, columns=["x", "y"]) +plot_me = plot_me.drop_duplicates() +plot_me = plot_me.sort_values(by=["x", "y"]) +_ = ax.plot(plot_me.iloc[::skip, 0], plot_me.iloc[::skip, 1], '.', color='gray', ms=3, alpha=0.6, label="UR100 + metagenomic") +plot_me = [no_ur100_compositions[:, -1], no_ur100_compositions[:, ur100id]] +plot_me = np.stack(plot_me) +plot_me = pd.DataFrame(plot_me.T, columns=["x", "y"]) +plot_me = plot_me.drop_duplicates() +plot_me = plot_me.sort_values(by=["x", "y"]) +plot_me['y'] = plot_me['x'] - plot_me['y'] +_ = ax.plot(plot_me.iloc[:]['x'], plot_me.iloc[:]['y'], '.', ms=3, color=pal[4], alpha=0.7, label="Metagenomic only") +_ = ax.set_xlabel('Cluster size') +_ = ax.set_ylabel('# Metagenomic members') +_ = ax.legend(loc='upper left') +_ = ax.set_xscale('log') +_ = ax.set_yscale('log') +_ = fig.savefig(os.path.join(out_dir, "gigaref_compositions_inverted.pdf"), dpi=300, bbox_inches='tight') + +def plot_cdf(x, color=sns.color_palette()[0], label=None, ax=None, **kwargs, ): + p = x.values + p.sort() + _ = ax.plot(p, np.linspace(0, 1, len(x)), '-', color=color, label=label, **kwargs) +mix_compositions.shape + +fig, ax = plt.subplots(1, 1) +_ = ax.plot(mix_compositions[::skip, -1], mix_compositions[::skip, ur100id], '.', color=pal[4], alpha=0.6, label="UR100 + metagenomic samples") +_ = ax.plot(no_ur100_compositions[::skip, -1], no_ur100_compositions[::skip, ur100id], '.', color='gray', alpha=0.6, label="Metagenomic samples only") +_ = ax.set_xlabel('Cluster size') +_ = ax.set_ylabel('# UR100 members') +_ = ax.legend(loc='best') +_ = ax.set_xscale('log') +_ = fig.savefig(os.path.join(out_dir, "gigaref_compositions_ecdf.pdf"), dpi=300, bbox_inches='tight') fig, ax = plt.subplots(1, 1) _ = sns.barplot(data=count_df[count_df['database'] != "total"], x="database", y="count", hue="ggr", ax=ax, palette=[pal[4], pal[6]], order=['SRC', 'MGY', 'MERC', 'UniRef100', 'SMAG', 'TOPAZ', 'MetaEuk', - 'MGV', 'GPD'], legend=True, hue_order=['clusters', 'singletons']) + 'MGV', 'GPD'], legend=True, hue_order=['GigaRef-clusters', 'GigaRef-singletons']) # _ = sns.barplot(data=count_df[count_df['database'] != "total"], x="database", y="fraction", hue="ggr", # ax=axs[1], palette=[pal[4], sns.color_palette("pastel")[4]], # order=['SRC', 'MGY', 'MERC', 'UniRef100', 'smag', 'TOPAZ', 'metaeuk', # 'MGV', 'GPD'], hue_order=['clusters', 'singletons']) _ = ax.semilogy() _ = ax.set_xlabel("") +_ = ax.set_ylabel("Count") _ = ax.set_xticks(ax.get_xticks()) _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right') # _ = ax.tick_params(axis="x", labelrotation=45, horizontalalignment="right") @@ -154,7 +201,7 @@ # get all the FPDs -fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.8)) +fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) _ = sns.barplot(plot_me, x='metric', y='value', hue='dataset', ax=ax, hue_order=[model_dict[m]['name'] for m in model_order], palette=model_palette) _ = ax.set_xlabel("") @@ -166,87 +213,58 @@ -## GIGAREF PLDDTs and scPERPs -pldddts = { - "Datasets": ["Reps", "Singletons", "UniRef50"], - "Mean": [61.824312446845234, 56.09045572735549, 65.20655783405739], - "StdDev": [19.86909229559785, 20.122267214806396, 19.72222931341456] +df_fid = pd.read_csv(os.path.join(out_dir, "ggr_plddt_mpnn.csv")) +df_fid = df_fid.drop_duplicates(subset='file') +model_name_dict = { + 'uniref50_': "UniRef50", + 'rep': "GigaRef-clusters", + 'singletons': "GigaRef-singletons" } -pldddts_df = pd.DataFrame(pldddts) +dataset_order = [ + 'UniRef50', + "GigaRef-clusters", + 'GigaRef-singletons', +] +pal = sns.color_palette() +hue_order = ['gray', pal[4], pal[6]] + +for i, row in df_fid.iterrows(): + df_fid.loc[i, 'dataset'] = model_name_dict[row['model']] + df_fid.loc[i, 'model_sort'] = dataset_order.index(df_fid.loc[i, 'dataset']) +df_fid['pLDDT'] = df_fid['esmfoldplddt'] +df_fid['scPerplexity'] = df_fid['mpnnperplexity'] +df_fid = df_fid.sort_values('model_sort') +fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) +_ = sns.scatterplot(df_fid, x='pLDDT', y='scPerplexity', hue='dataset', ax=ax, + hue_order=dataset_order, palette=hue_order, alpha=0.7, s=10) +legend = ax.legend( + loc='upper right', + # bbox_to_anchor=(0.9, 0.9), + # bbox_transform=fig.transFigure, + title=None +) +_ = fig.savefig(os.path.join(out_dir, "gigaref_fidelities.pdf"), bbox_inches='tight', dpi=300) -perps = { - "Datasets": ["Reps", "Singletons", "UniRef50"], - "Mean": [9.673995, 10.068662, 9.483173], - "StdDev": [2.866755, 2.866755, 2.8937333] +cdfs = { + "pLDDT": (plt.subplots(1, 1, figsize=(6.4, 4.8)), "pLDDT"), + "scPerplexity": (plt.subplots(1, 1, figsize=(6.4, 4.8)), "scPerplexity"), } -perps_df = pd.DataFrame(perps) - - -# basic bar plot -def bar_fpd(data, title, filename): - """ - Create and save a bar plot for FPD values. - """ - plt.figure(figsize=(6, 6)) - sns.barplot(x="Datasets", y="FPD", data=data, hue="Datasets", - palette=["grey"] * len(data["Datasets"])) # type: ignore - plt.ylim(0, 0.5) - plt.xlabel("Datasets") - plt.ylabel("FPD") - plt.title(title) - plt.savefig(filename, format="pdf") - plt.close() - - -def plot_bar(data, title, filename): - """ - Create and save a bar plot with error bars for mean and standard deviation values. - """ - plt.figure(figsize=(4, 6)) - sns.barplot(x="Datasets", y="Mean", data=data, palette="muted", ci=None) - plt.errorbar(x="Datasets", y="Mean", yerr="StdDev", data=data, fmt='none', c='black', capsize=5) - plt.xlabel("Datasets") - plt.ylabel(title) - plt.title(title) - plt.savefig(filename, format="pdf") - plt.close() - - -# Function to create and save pie charts -def gigaref_pie_chart(datasets, titles): - """ - Create and save a pie chart for the given data. - """ - fig, axs = plt.subplots(1, 2, figsize=(12, 5)) - for data, ax, title in zip(datasets, axs, titles): - labels = data.keys() - colors = sns.color_palette("pastel")[0:len(labels)] - values = data.values() - values = [v for v in values if v > 0] - labels = [la for v, la in zip(values, labels) if v > 0] - - wedges, texts, autotexts = ax.pie( - values, labels=labels, autopct='%1.1f%%', colors=colors, startangle=140, wedgeprops={'edgecolor': 'black'} - ) - ax.set_title(title) - # legend_labels = [f'{label}: {value}%' for label, value in - # zip(labels, [round(value / sum(values) * 100, 2) for value in values])] - axs[0].legend(loc="upper right") - for ax in axs: - ax.axis('equal') - fig.savefig(os.path.join(out_dir, "composition_pies.pdf"), dpi=300, bbox_inches='tight') - -gigaref_pie_chart((n_clusters, n_singletons), ("Clusters", "Singletons")) - -# plot taxonomy pie charts -# gigaref_pie_chart(n_clusters, "Composition of GigaRef Clusters", "cluster_composition.pdf") -# gigaref_pie_chart(n_singletons, "Composition of GigaRef Singletons", "singleton_composition.pdf") - -# plot fpd bar charts -# bar_fpd(fpds_df, "Distributional Distances for Dayhoff Datasets", "gigaref_fpd.pdf") -plot_bar(pldddts_df, "pLDDT", "gigaref_plddt.pdf") -plot_bar(perps_df, "scPerplexity", "gigaref_scperplexity.pdf") - - - - +model_to_hue = {d: h for d, h in zip(dataset_order, hue_order)} +for dataset in dataset_order: + for cdf in cdfs: + df_lim = df_fid[df_fid['dataset'] == dataset] + v = df_lim[cdf].values + v.sort() + item = cdfs[cdf] + _ = item[0][1].plot(v, np.linspace(0, 1, len(v)), '-', color=model_to_hue[dataset], label=dataset) + +for cdf in cdfs: + fig, ax = cdfs[cdf][0] + if cdf == 'pLDDT': + _ = ax.legend(loc='best') + _ = ax.set_xlabel(cdfs[cdf][1]) + _ = ax.set_ylabel('Percentile') + _ = fig.savefig(os.path.join(out_dir, "%s.pdf" %cdf), bbox_inches='tight', dpi=300) + + +print(df_fid.groupby('dataset').agg({"pLDDT": ["mean", "std"], "scPerplexity": ["mean", "std"]})) \ No newline at end of file diff --git a/analysis/pfam.py b/analysis/pfam.py index d014c9f..279b522 100644 --- a/analysis/pfam.py +++ b/analysis/pfam.py @@ -203,21 +203,37 @@ un_vector = np.array([un_counter[d] for d in all_domains]) input_dict = { "UniRef50": ur_vector, - "GGR": r_vector, - "GGR-singles": s_vector, - "BBR-sc": sc_vector, - "BBR-n": scn_vector, - "BBR-u": un_vector, + "GR": r_vector, + "GR-singles": s_vector, + "BRq": sc_vector, + "BRn": scn_vector, + "BRu": un_vector, "DayhoffRef": dr_vector, } +np.savez_compressed(os.path.join(in_dir, "pfam_annotations", "counts.npz"), **input_dict) + +d = np.load(os.path.join(in_dir, "pfam_annotations", "counts.npz"), allow_pickle=True) + +# fixed = { +# "UniRef50": d['UniRef50'], +# "GR": d['GGR'], +# "GR-singles": d['GGR-singles'], +# "BRq": d['BBR-sc'], +# "BRn": d['BBR-n'], +# "BRu": d['BBR-u'], +# "DayhoffRef": d['DayhoffRef'], +# } + +input_dict = {k: d[k] for k in d.keys()} keys = list(input_dict.keys()) for i, k1 in enumerate(keys): v1 = input_dict[k1] + sort_idx = v1.argsort() print(k1, v1.sum()) for k2 in keys[i + 1:]: v2 = input_dict[k2] fig, ax = plt.subplots(1, 1) - _ = ax.plot(v1, v2, '.', alpha=0.3, color='gray') + _ = ax.plot(v1[sort_idx][::10], v2[sort_idx][::10], '.', alpha=0.3, color='gray') _ = ax.set_xlabel("Occurences in %s" %k1) _ = ax.set_ylabel("Occurences in %s" %k2) ax.xaxis.set_major_locator(ticker.MultipleLocator(10000)) diff --git a/analysis/plot_metrics.py b/analysis/plot_metrics.py index 78e6806..c7305c9 100644 --- a/analysis/plot_metrics.py +++ b/analysis/plot_metrics.py @@ -3,7 +3,6 @@ import matplotlib.pyplot as plt import pandas as pd import seaborn as sns -from sklearn import metrics model_order = [ '170m-uniref50', @@ -15,6 +14,7 @@ '3b-uniref', '3b-msa-gigaclust', '3b-msa-uniref90-cooldown', + 'evodiff' ] pal3b = sns.color_palette() @@ -27,7 +27,101 @@ model_dict = { '170m-gigaclust': { - "name": "170m-GGR", + "name": "170m-GR", + "hue": pal170m[4], + }, + '170m-uniref50': { + "name": "170m-UR50", + "hue": pal170m[7], + }, + '170m-uniref90': { + "name": "170m-UR90", + "hue": pal170m[3], + }, + '170m-nofilter': { + "name": "170m-UR50-BRu", + "hue": pal170m[0], + }, + '170m-rmsd': { + "name": "170m-UR50-BRq", + "hue": pal170m[0], + }, + '170m-bothfilter': { + "name": "170m-UR50-BRn", + "hue": pal170m[0], + }, + '3b-uniref': { + "name": "3b-UR90", + "hue": sns.color_palette("deep")[3], + }, + '3b-msa-gigaclust': { + "name": "3b-GR-HM", + "hue": pal3b[1], + }, + '3b-msa-uniref90-cooldown': { + "name": "3b-GR-HM-c", + "hue": sns.color_palette("pastel")[1], + }, + "evodiff": { + "name": "EvoDiff-seq", + "hue": sns.color_palette("pastel")[7], + } +} +model_palette = { + d['name']: d['hue'] for d in model_dict.values() +} + +sns.set_theme(font_scale=1.2) +sns.set_style('white') +out_fpath = "/home/kevyan/generations/" +df = pd.read_csv(os.path.join(out_fpath, "scaffold_summary.csv")) +df['model'] = [model_dict[m]['name'] for m in df['model'].values] +df = df.melt(id_vars=['model']) +fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) +_ = sns.barplot(df, x='variable', y='value', hue='model', ax=ax, legend=True, + hue_order=[model_dict[m]['name'] for m in model_order], palette=model_palette) +_ = ax.set_xlabel("") +_ = ax.set_ylabel("Score") +# _ = ax.set_xticklabels(["to UR50", "to GGR"]) +hatch_me = [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19] +for i, bar in enumerate(ax.patches): + if i in hatch_me: + bar.set_hatch('.') +legend = ax.legend(fontsize=13, labelspacing=0.15) +# legend = ax.legend( +# loc='upper right', +# bbox_to_anchor=(0.9, 0.98 , 0.32, -0.102), +# mode='expand', +# ncol=1, +# bbox_transform=fig.transFigure, +# ) +_ = fig.savefig("/home/kevyan/generations/scaffold_summary.pdf", bbox_inches='tight', dpi=300) + + + +model_order = [ + '170m-uniref50', + '170m-uniref90', + '170m-gigaclust', + # '170m-nofilter', + # '170m-rmsd', + '170m-bothfilter', + '3b-uniref', + '3b-msa-gigaclust', + '3b-msa-uniref90-cooldown', +] + +pal3b = sns.color_palette() +pal170m = sns.color_palette("deep") +# UR50/90.OpenProteinSet: gray +# GGR: purple +# BBR: blue +# Alignment: orange +# Homologs: green + +model_dict = { + '170m-gigaclust': { + "name": "170m-GR", "hue": pal170m[4], "step": 76000, "UR50 perplexity": 13.67, @@ -48,21 +142,21 @@ "GGR perplexity": 11.85, }, '170m-nofilter': { - "name": "170m-UR50-BBR-u", + "name": "170m-UR50-BRu", "hue": pal170m[0], "step": 76000, "UR50 perplexity": 11.66, "GGR perplexity": 11.87, }, '170m-rmsd': { - "name": "170m-UR50-BBR-sc", + "name": "170m-UR50-BRq", "hue": pal170m[0], "step": 76000, "UR50 perplexity": 11.67, "GGR perplexity": 11.91, }, '170m-bothfilter': { - "name": "170m-UR50-BBR-n", + "name": "170m-UR50-BRn", "hue": pal170m[0], "step": 76000, "UR50 perplexity": 11.78, @@ -76,19 +170,23 @@ "GGR perplexity": 9.64, }, '3b-msa-gigaclust': { - "name": "3b-GGR-MSA", + "name": "3b-GR-HM", "hue": pal3b[1], "step": 52000, "UR50 perplexity": 11.95, "GGR perplexity": 6.68, }, '3b-msa-uniref90-cooldown': { - "name": "3b-cooled", + "name": "3b-GR-HM-c", "hue": sns.color_palette("pastel")[1], "step": 25000, "UR50 perplexity": 10.11, "GGR perplexity": 9.21, }, + "evodiff": { + "name": "EvoDiff-seq", + "hue": sns.color_palette("pastel")[3], + } } model_palette = { d['name']: d['hue'] for d in model_dict.values() @@ -110,17 +208,16 @@ models += [model_dict[model]['name']] * 4 values += [row['protbert_fd_to_uniref'], row['protbert_fd_to_gigaref']] values += [model_dict[model]['UR50 perplexity'], model_dict[model]['GGR perplexity']] - metrics += ["FPD to UR50", "FPD to GGR", "UR50 perplexity", "GGR perplexity"] + metrics += ["FPD to UR50", "FPD to GR", "UR50 perplexity", "GR perplexity"] df['model'] = models df['value'] = values df['metric'] = metrics - sns.set_theme(font_scale=1.2) sns.set_style('white') fig, axs = plt.subplots(1, 2, figsize=(12, 5)) ax = axs[1] -plot_me = df[df['metric'].isin(["FPD to UR50", "FPD to GGR"])] +plot_me = df[df['metric'].isin(["FPD to UR50", "FPD to GR"])] _ = sns.barplot(plot_me, x='metric', y='value', hue='model', ax=ax, hue_order=[model_dict[m]['name'] for m in model_order], palette=model_palette) _ = ax.set_xlabel("") @@ -140,7 +237,7 @@ # _ = fig.savefig("/home/kevyan/generations/model_fpd.pdf", bbox_inches='tight', dpi=300) ax = axs[0] -plot_me = df[df['metric'].isin(["UR50 perplexity", "GGR perplexity"])] +plot_me = df[df['metric'].isin(["UR50 perplexity", "GR perplexity"])] _ = sns.barplot(plot_me, x='metric', y='value', hue='model', ax=ax, hue_order=[model_dict[m]['name'] for m in model_order], palette=model_palette, legend=False) _ = ax.set_xlabel("") @@ -164,9 +261,9 @@ df = pd.read_csv(os.path.join(out_fpath, "valid_by_conditioning_%s_%d.csv" %(task, rank))) df['msa_id'] = current_id + df['msa_id'] if task == 'gap': - df['task'] = 'Alignment conditioning' + df['task'] = 'Aligned homologs' else: - df['task'] = "Homolog conditioning" + df['task'] = "Unaligned homologs" current_id = max(df['msa_id']) dfs.append(df) @@ -174,7 +271,7 @@ df0 = df[df['n_conditioning'] == 0] df0.groupby('task')['perplexity'].mean() df.tail() -pal = sns.color_palette() +pal = sns.color_palette("deep") fig1, ax1 = plt.subplots(1, 1, figsize=(7, 5)) # ax2 = ax1.twinx() _ = sns.lineplot(data=df[df['n_conditioning'] < 63], x='n_conditioning', y='perplexity', hue='task', ax=ax1, palette=[pal[1], pal[2]]) @@ -182,4 +279,353 @@ _ = ax1.set_ylabel("Average perplexity") ax1.legend().set_title(None) _ = fig1.savefig(os.path.join(out_fpath, "dayhoff-3b-cooled" + "_long_msas" + "_" + direction + "_conditioning64.pdf"), - dpi=300, bbox_inches="tight") \ No newline at end of file + dpi=300, bbox_inches="tight") + + +# expression data +data_path = '/home/kevyan/generations/expression' +df = pd.read_csv(os.path.join(data_path, 'ginkgo_merged_all_data.csv')) +model_order = [ + 'uniref90-170M', + "gigaclust-170M", + "3BCOOLED", + "gigaclust-3B", + "10mbothfilter" +] +model_dict = { + 'gigaclust-170M': { + "name": "170m-GR", + "hue": pal170m[4], + "step": 76000, + "UR50 perplexity": 13.67, + "GGR perplexity": 9.36, + }, + 'uniref90-170M': { + "name": "170m-UR90", + "hue": pal170m[3], + "step": 76000, + "UR50 perplexity": 11.52, + "GGR perplexity": 11.85, + }, + '10mbothfilter': { + "name": "170m-UR50-BRn", + "hue": pal170m[0], + "step": 76000, + "UR50 perplexity": 11.78, + "GGR perplexity": 12.03, + }, + 'gigaclust-3B': { + "name": "3b-GR-HM", + "hue": pal3b[1], + "step": 52000, + "UR50 perplexity": 11.95, + "GGR perplexity": 6.68, + }, + '3BCOOLED': { + "name": "3b-GR-HM-c", + "hue": sns.color_palette("pastel")[1], + "step": 25000, + "UR50 perplexity": 10.11, + "GGR perplexity": 9.21, + }, +} +model_palette = { + d['name']: d['hue'] for d in model_dict.values() +} +grouped = df.groupby('model_name') +dfe = grouped.agg({"Express in any system": np.mean}) +dfe = dfe.reset_index() +dfe['model'] = [model_dict[m]['name'] for m in dfe['model_name'].values] +dfe['Fraction expressed'] = dfe['Express in any system'] +dfe['UR50 perplexity'] = [model_dict[m]['UR50 perplexity'] for m in dfe['model_name'].values] +dfe['GR perplexity'] = [model_dict[m]['GGR perplexity'] for m in dfe['model_name'].values] +order = [model_dict[m]['name'] for m in model_order] +sns.set_theme(font_scale=1.2) +sns.set_style('white') +# Bar plot of expression by model +fig, ax = plt.subplots() +_ = sns.barplot(dfe, x='model', y='Fraction expressed', hue='model', ax=ax, + hue_order=order, order=order, palette=model_palette, legend=False) +# _ = ax.set_xticklabels(["to UR50", "to GGR"]) +hatch_me = [0, 1, 4, 5, 6, 9] +for i, bar in enumerate(ax.patches): + if i in hatch_me: + bar.set_hatch('.') +_ = ax.set_xticks(ax.get_xticks()) +_ = ax.set_xlabel("") +_ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right') +_ = fig.savefig(os.path.join(data_path, "expression_bars.pdf"), bbox_inches='tight', dpi=300) + + +# scPerp / plddt x expression +df['pLDDT'] = df['plddt'] / 100 +df['scPerplexity'] = df['perp'] +df['Expressed'] = df['Express in any system'] == 1 + + +df = df.sort_values('Express in any system', ascending=True) + +blast = pd.read_csv(os.path.join(data_path, "sow2_blast.csv"), header=None) +blast.columns = ['names_clean', 'hit', 'identity', 'align_len', 'mismatches', 'gap_opens', + 'qstart', 'qend', 'sstart', 'ssend', 'evalue', 'bitscore', 'positives'] +blast = blast.drop_duplicates(subset='names_clean', keep='first') +blast['hit_length'] = blast['qend'] - blast['qstart'] + 1 +for i, row in df.iterrows(): + if row['names_clean'] in blast['names_clean'].values: + blast_row = blast[blast['names_clean'] == row['names_clean']] + df.loc[i, 'homology'] = blast_row['hit_length'].values[0] * blast_row['identity'].values[0] / len(row['Sequence']) / 100 + else: + df.loc[i, 'homology'] = 0 + +# fig, ax = plt.subplots() +# _ = sns.scatterplot(data=dfe, x='pLDDT', y='scPerplexity', hue='Expressed', ax=ax, +# palette=['gray', sns.color_palette()[0]], hue_order=[False, True], alpha=0.7) +# _ = fig.savefig(os.path.join(data_path, "fidelity_expression.pdf"), bbox_inches='tight', dpi=300) + +# stripplots for plddt, scperplexity, and homology +# unroll it +value_vars = ['homology', 'pLDDT', 'scPerplexity'] +melted = df.melt(id_vars=['names_clean', 'Expressed'], value_vars=value_vars) +for i, row in melted.iterrows(): + if row['Expressed']: + melted.loc[i, 'expressed_str'] = "Expressed" + else: + melted.loc[i, 'expressed_str'] = "Not expressed" +for i, row in df.iterrows(): + if row['Expressed']: + df.loc[i, 'expressed_str'] = "Expressed" + else: + df.loc[i, 'expressed_str'] = "Not expressed" +for v in value_vars: + fig, ax = plt.subplots(figsize=(3.6, 4.8)) + _ = sns.stripplot(df, x='expressed_str', y=v, + hue='expressed_str', ax=ax, palette=[sns.color_palette()[0], "gray"], + hue_order=["Expressed", "Not expressed"], alpha=0.7) + # _ = sns.stripplot(melted[melted['variable'] == v], x='variable', y='value', + # hue='expressed_str', ax=ax, palette=[sns.color_palette()[0], "gray"], + # hue_order=["Expressed", "Not expressed"], dodge=True, alpha=0.7) + # _ = ax.legend(title=False) + if v != "scPerplexity": + _ = ax.set_ylim(0, 1) + _ = ax.set_ylabel(v) + _ = ax.set_xlabel("") + _ = fig.savefig(os.path.join(data_path, "expression_%s_strips.pdf" %v), bbox_inches='tight', dpi=300) + +def plot_cdf(x, color=sns.color_palette()[0], label=None, ax=None, **kwargs, ): + p = x.values + p.sort() + _ = ax.plot(p, np.linspace(0, 1, len(x)), '-', color=color, label=label, **kwargs) + +for metric in ['pLDDT', 'scPerplexity', 'homology']: + fig, ax = plt.subplots(figsize=(4.8, 4.8)) + for expression in [False, True]: + if expression: + e = "Expressed" + hue = sns.color_palette()[0] + else: + e = "Not expressed" + hue = "gray" + plot_cdf(df[df['Expressed'] == expression][metric], ax=ax, label=e, color=hue) + ax.set_ylabel("Percentile") + ax.set_xlabel(metric) + if metric == "pLDDT": + ax.legend() + _ = fig.savefig(os.path.join(data_path, "cdf_expression_%s.pdf" %metric), bbox_inches='tight', dpi=300) + + + +# perplexity and fpd by model +order_to_order = { + '170m-uniref90': 'uniref90-170M', + '170m-gigaclust': "gigaclust-170M", + '3b-msa-uniref90-cooldown': "3BCOOLED", + '3b-msa-gigaclust': "gigaclust-3B", + '170m-bothfilter': "10mbothfilter" +} +melted = dfe.melt(id_vars=['model_name', 'Fraction expressed'], value_vars=['UR50 perplexity', 'GR perplexity']) +fig, ax = plt.subplots(1, 1) +pal = sns.color_palette() +_ = sns.scatterplot(data=melted, x='value', y='Fraction expressed', hue='variable', style='variable', + ax=ax, palette=[pal[4], pal[7]], s=100) +_ = ax.set_xlabel("Perplexity") +handles, labels = ax.get_legend_handles_labels() +ax.legend(handles=handles[:], labels=labels[:]) # This gets rid of the title +_ = fig.savefig('/home/kevyan/generations/expression/perplexity_combined.pdf', bbox_inches='tight', dpi=300) +dfe.columns +fpd_df = pd.read_csv("/home/kevyan/generations/fpd.csv") +fpd_df = fpd_df[(fpd_df['direction'] == 'fwd') & (fpd_df['temperature'] == 1)] +for i, row in fpd_df.iterrows(): + if row['name'] not in order_to_order: + continue + model = model_dict[order_to_order[row['name']]]['name'] + idx = dfe[dfe['model'] == model].index + dfe.loc[idx, 'FPD to UR50'] = row['protbert_fd_to_uniref'] + dfe.loc[idx, 'FPD to GR'] = row['protbert_fd_to_gigaref'] + +melted = dfe.melt(id_vars=['model_name', 'Fraction expressed'], value_vars=['FPD to UR50', 'FPD to GR']) + +fig, ax = plt.subplots(1, 1) +pal = sns.color_palette() +_ = sns.scatterplot(data=melted, x='value', y='Fraction expressed', hue='variable', style='variable', + ax=ax, palette=[pal[4], pal[7]], s=100) +_ = ax.set_xlabel("FPD") +handles, labels = ax.get_legend_handles_labels() +ax.legend(handles=handles[:], labels=labels[:]) # This gets rid of the title +_ = fig.savefig('/home/kevyan/generations/expression/fpd_combined.pdf', bbox_inches='tight', dpi=300) + + +# mets = ['UR50 perplexity', 'GGR perplexity', 'FPD to UR50', 'FPD to GR'] +# for m in mets: +# fig, ax = plt.subplots(figsize=(4.8, 4.8)) +# _ = sns.scatterplot(dfe, x=m, y='Fraction expressed', hue='model', legend='UR50' in m, ax=ax, +# hue_order=order, palette=model_palette) +# if 'UR50' in m: +# legend = ax.legend(title=False) +# +# +# _ = fig.savefig(os.path.join(data_path, "model_%s.pdf" %m), bbox_inches='tight', dpi=300) + +for model in model_order: + df_ = df[df['model_name'] == model] + print(model) + print(metrics.roc_auc_score(df_['Expressed'], df_['pLDDT'])) + print(metrics.roc_auc_score(df_['Expressed'], df_['scPerplexity'])) + print(metrics.roc_auc_score(df_['Expressed'], df_['homology'])) + print(metrics.roc_auc_score(df_['Expressed'], df_['pLDDT'] / df['scPerplexity'])) + +with open(os.path.join(data_path, "sow2.fasta"), "w") as f: + for i, row in df.iterrows(): + f.write(">{}\n".format(row['names_clean'])) + f.write("{}\n".format(row['Sequence'])) + +stats.pearsonr(df['pLDDT'], df['homology']) +metrics.roc_auc_score(df['Expressed'], df['pLDDT']) +metrics.roc_auc_score(df['Expressed'], df['scPerplexity']) +metrics.roc_auc_score(df['Expressed'], df['pLDDT'] / df['scPerplexity']) +metrics.roc_auc_score(df['Expressed'], df['homology']) + +data_path = "/home/kevyan/generations/" +df = pd.read_parquet(os.path.join(data_path, "mmd_results.parquet")) +df.head() +df.iloc[:, :3] + + +names = ['3BCOOLEDSEQUENCE11', ' 3BCOOLEDSEQUENCE86 ', ' 10mbothfilterSEQUENCE123 ', + ' 10mbothfilterSEQUENCE133 ', ' gigaclust3BSEQUENCE10 ', 'uniref90170MSEQUENCE49'] +df[df['Microsoft sequence name'].isin(names)][['Microsoft sequence name', 'homology', 'pLDDT', 'scPerplexity']] +df[df['Microsoft sequence name'] == '3BCOOLEDSEQUENCE86'][['Microsoft sequence name', 'homology', 'pLDDT', 'scPerplexity']] + +df['Microsoft sequence name'] + +# unconditional fidelities + +df = pd.read_csv('/home/kevyan/generations/folding_t1_allmodels.csv') +df['scPerplexity'] = df['proteinmpnnperplexity'] +df_natural = pd.read_csv("/home/kevyan/generations/gigaref_analysis/ggr_plddt_mpnn.csv") +df_natural['scPerplexity'] = df_natural['mpnnperplexity'] +df = pd.concat([df, df_natural]) +df['pLDDT'] = df['esmfoldplddt'] +grouped = df.groupby(['model']) +grouped = grouped.agg({'pLDDT': ['mean', 'std'], 'scPerplexity': ['mean', 'std']}) +model_order = [ + 'jamba-170m-seq-36w_76000', + 'jamba-170m-seqsam-36w_76000', + 'jamba-170m-gigaclust-36w_76000', + 'jamba-170m-10mnofilter-36w_76000', + 'jamba-170m-10mrmsd-36w_76000', + 'jamba-170m-10mbothfilter-36w_76000', + 'jamba-3b-seq-sam-biar-fsdp-tok90k_43300', + 'jamba-3b-indel-gigaclust-120k-2_52000', + 'jamba-3b-cooldown7_25000', + 'uniref50_', + 'rep', + 'singletons' +] +grouped = grouped.loc[model_order] +grouped = grouped.reset_index() +model_dict = { + 'uniref50_': {'name':"UniRef50"}, + 'rep': {'name':"GigaRef-clusters"}, + 'singletons': {'name': "GigaRef-singletons"}, + 'jamba-170m-gigaclust-36w_76000': { + "name": "170m-GGR", + "hue": pal170m[4], + "step": 76000, + "UR50 perplexity": 13.67, + "GGR perplexity": 9.36, + }, + 'jamba-170m-seq-36w_76000': { + "name": "170m-UR50", + "hue": pal170m[7], + "step": 76000, + "UR50 perplexity": 11.62, + "GGR perplexity": 11.88, + }, + 'jamba-170m-seqsam-36w_76000': { + "name": "170m-UR90", + "hue": pal170m[3], + "step": 76000, + "UR50 perplexity": 11.52, + "GGR perplexity": 11.85, + }, + 'jamba-170m-10mnofilter-36w_76000': { + "name": "170m-UR50-BBR-u", + "hue": pal170m[0], + "step": 76000, + "UR50 perplexity": 11.66, + "GGR perplexity": 11.87, + }, + 'jamba-170m-10mrmsd-36w_76000': { + "name": "170m-UR50-BBR-sc", + "hue": pal170m[0], + "step": 76000, + "UR50 perplexity": 11.67, + "GGR perplexity": 11.91, + }, + 'jamba-170m-10mbothfilter-36w_76000': { + "name": "170m-UR50-BBR-n", + "hue": pal170m[0], + "step": 76000, + "UR50 perplexity": 11.78, + "GGR perplexity": 12.03, + }, + 'jamba-3b-seq-sam-biar-fsdp-tok90k_43300': { + "name": "3b-UR90", + "hue": sns.color_palette("deep")[3], + "step": 43300, + "UR50 perplexity": 8.95, + "GGR perplexity": 9.64, + }, + 'jamba-3b-indel-gigaclust-120k-2_52000': { + "name": "3b-GGR-MSA", + "hue": pal3b[1], + "step": 52000, + "UR50 perplexity": 11.95, + "GGR perplexity": 6.68, + }, + 'jamba-3b-cooldown7_25000': { + "name": "3b-cooled", + "hue": sns.color_palette("pastel")[1], + "step": 25000, + "UR50 perplexity": 10.11, + "GGR perplexity": 9.21, + }, +} +for i, row in grouped.iterrows(): + print(' & '.join([model_dict[row['model'].values[0]]['name'], + "$%.3f \\pm %.3f$" %(row['pLDDT']['mean'], row['pLDDT']['std']), + "$%.2f \\pm %.2f$\\\\" % (row['scPerplexity']['mean'], row['scPerplexity']['std']) + ])) + + +for model in model_order[:-3]: + print(model) + print(stats.ttest_ind(df[(df['model'] == model) & (df['direction'] == "fwd")]['pLDDT'], + df[(df['model'] == model) & (df['direction'] == "rev")]['pLDDT'])) + print(stats.ttest_ind(df[(df['model'] == model) & (df['direction'] == "fwd")]['scPerplexity'], + df[(df['model'] == model) & (df['direction'] == "rev")]['scPerplexity'])) +model_palette = { + d['name']: d['hue'] for d in model_dict.values() +} + +set(df['model']) \ No newline at end of file diff --git a/analysis/plot_scaffolding.py b/analysis/plot_scaffolding.py new file mode 100644 index 0000000..31cc258 --- /dev/null +++ b/analysis/plot_scaffolding.py @@ -0,0 +1,159 @@ +import os + +import pandas as pd + +import matplotlib.pyplot as plt +import seaborn as sns + + + +base_path = '/home/kevyan/generations/scaffolding_results/' +benchmarks = [ + 'rfdiff', 'motifbench' +] + +models = ['dayhoff-170m', 'dayhoff-3b', 'evodiff'] +pal3b = sns.color_palette() +pal170m = sns.color_palette("deep") + +model_dict = { + '170m': { + "name": "170m-GGR", + "hue": pal170m[4], + "UR50 perplexity": 13.67, + "GGR perplexity": 9.36, + }, + '170m-uniref50': { + "name": "170m-UR50", + "hue": pal170m[7], + "UR50 perplexity": 11.62, + "GGR perplexity": 11.88, + }, + '170m-uniref90': { + "name": "170m-UR90", + "hue": pal170m[3], + "UR50 perplexity": 11.52, + "GGR perplexity": 11.85, + }, + '170m-nofilter': { + "name": "170m-UR50-BBR-u", + "hue": pal170m[0], + "UR50 perplexity": 11.66, + "GGR perplexity": 11.87, + }, + '170m-rmsd': { + "name": "170m-UR50-BBR-sc", + "hue": pal170m[0], + "UR50 perplexity": 11.67, + "GGR perplexity": 11.91, + }, + 'dayhoff-170m': { + "name": "170m-UR50-BBR-n", + "hue": pal170m[0], + "UR50 perplexity": 11.78, + "GGR perplexity": 12.03, + }, + '3b-uniref': { + "name": "3b-UR90", + "hue": sns.color_palette("deep")[3], + "UR50 perplexity": 8.95, + "GGR perplexity": 9.64, + }, + '3b-msa-gigaclust': { + "name": "3b-GGR-MSA", + "hue": pal3b[1], + "UR50 perplexity": 11.95, + "GGR perplexity": 6.68, + }, + 'dayhoff-3b': { + "name": "3b-cooled", + "hue": sns.color_palette("pastel")[1], + "UR50 perplexity": 10.11, + "GGR perplexity": 9.21, + }, + 'evodiff': { + 'name': 'EvoDiff-Seq', + 'hue': pal3b[5] + } +} + +model_palette = { + d['name']: d['hue'] for d in model_dict.values() +} +model_order = [model_dict[model]['name'] for model in models] +dfs = [] +for model in models: + for benchmark in benchmarks: + files = os.listdir(os.path.join(base_path, benchmark, model)) + for file in files: + if file == 'successes.csv': + continue + df = pd.read_csv(os.path.join(base_path, benchmark, model, file)) + df['problem'] = '_'.join(file.split('_')[:2]) + df['model'] = model_dict[model]['name'] + df['benchmark'] = benchmark + dfs.append(df) +df = pd.concat(dfs) +df = df.reset_index() +df['pLDDT'] = df['plddt'] +df['motif RMSD'] = df['scrmsd'] + +cutoff = {'pLDDT': 0.7, 'motif RMSD': 1.0} + +sns.set_theme(font_scale=1.2) +sns.set_style('white') +for met in ['pLDDT', 'motif RMSD']: + for benchmark in benchmarks: + problem_order = sorted(set(df[df['benchmark'] == benchmark]['problem'])) + fig, ax = plt.subplots(figsize=(16, 4)) + legend = met == 'motif RMSD' and benchmark == 'rfdiff' + _ = sns.stripplot(df[df['benchmark'] == benchmark], x='problem', y=met, hue='model', palette=model_palette, hue_order=model_order, ax=ax, + legend=legend, dodge=True, s=4, alpha=0.7, order=problem_order) + if legend: + ax.legend(title=None) + _ = ax.axhline(cutoff[met], color='gray', linestyle='-') + _ = ax.set_xticks(ax.get_xticks()) + _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right') + _ = ax.tick_params(axis='x', which='both', direction='inout') + _ = fig.savefig(os.path.join(base_path, benchmark + '_' + met + '.pdf'), dpi=300, bbox_inches='tight') + +grouped = df.groupby(['benchmark', 'problem', 'model']) +df_s = grouped.agg({'success': 'sum'}) +df_s = df_s.reset_index() +for i, row in df_s.iterrows(): + if row['benchmark'] == 'rfdiff': + new_problem = 'RFdiffusion ' + else: + new_problem = "MotifBench " + df_s.loc[i, 'problem'] = new_problem + row['problem'] + +grouped = df_s.groupby('problem') +df_p = grouped.agg({'success': 'sum'}) +df_p = df_p.reset_index() +keep = df_p[df_p['success'] > 0][['problem']].values[:, 0] +df_s = df_s[df_s['problem'].isin(keep)] +problem_order = sorted(set(df_s['problem'])) + +fig, ax = plt.subplots(figsize=(12, 4.8)) +_ = sns.barplot(df_s, x='problem', y='success', hue='model', palette=model_palette, + hue_order=model_order, ax=ax, + legend=True, order=problem_order) +hatch_me = [i for i in range(16) ] + [48] +# for i, bar in enumerate(ax.patches): +# print(i, bar) +for i, bar in enumerate(ax.patches): + if i in hatch_me: + bar.set_hatch('.') +ax.legend(title=None) +_ = ax.set_ylabel("Successes (100 attempts)") +_ = ax.set_xticks(ax.get_xticks()) +_ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right') +_ = fig.savefig(os.path.join(base_path, 'successes.pdf'), dpi=300, bbox_inches='tight') + +df_s['s'] = df_s['success'] > 0 +grouped = df_s.groupby(['benchmark', 'model']) +summary = grouped.agg({'s': 'sum', 'success': 'sum'}).reset_index() +summary = summary.pivot(index='model', columns='benchmark').reset_index() +for i, row in summary.iterrows(): + print(' & '.join([row['model'].values[0], str(row['s']['rfdiff']), str(row['success']['rfdiff']), + str(row['s']['motifbench']), str(row['success']['motifbench'])]) + '\\\\') \ No newline at end of file diff --git a/analysis/plot_valid.py b/analysis/plot_valid.py index 37d2981..ea2a17a 100644 --- a/analysis/plot_valid.py +++ b/analysis/plot_valid.py @@ -207,8 +207,8 @@ df = pd.DataFrame() current_row = 0 current_msa_id = 0 -for task in ["indel"]: - for rank in [2, 3, 4, 5, 6, 7]: +for task in ["gap", "indel"]: + for rank in range(8): # for rank in range(7): out_file = os.path.join(out_fpath, "valid_long_" + model + '_' + str( checkpoint) + "_" + task + "_" + direction + "_%d.pt" % rank) @@ -288,6 +288,8 @@ tasks = ['indel', 'gap'] for task in tasks: for rank in range(world_size): + if task == "gap" and rank == 7: + continue df = pd.read_csv(os.path.join(out_fpath, "valid_by_conditioning_%s_%d.csv" %(task, rank))) df['msa_id'] = current_id + df['msa_id'] current_id = max(df['msa_id']) diff --git a/analysis/plot_zs.py b/analysis/plot_zs.py index 6a78a36..92c75ee 100644 --- a/analysis/plot_zs.py +++ b/analysis/plot_zs.py @@ -87,7 +87,7 @@ for model in model_names: for dms in dmss: for rank in range(world_size): - df_path = os.path.join(out_fpath, dms, model + '_{}.csv'.format(rank)) + df_path = os.path.join(out_dir, dms, model + '_{}.csv'.format(rank)) if os.path.exists(df_path): df = pd.read_csv(df_path) if 'seq_spearman' in df: