Skip to content

Commit 69279b1

Browse files
authored
Merge pull request #208 from beiko-lab/fix-approx-rspr
Updated rspr_approx.py to remove double counting
2 parents c5a4004 + a45a8ee commit 69279b1

File tree

1 file changed

+126
-34
lines changed

1 file changed

+126
-34
lines changed

bin/rspr_approx.py

Lines changed: 126 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from pathlib import Path
99
import argparse
1010
import subprocess
11+
import shutil
1112
from ete3 import Tree
1213
import pandas as pd
1314
from collections import defaultdict
1415
from matplotlib import pyplot as plt
1516
from matplotlib.colors import LogNorm
1617
import seaborn as sns
1718
import tempfile
19+
import logging
1820

1921

2022
#####################################################################
@@ -98,8 +100,9 @@ def read_tree(input_path):
98100
tree_string = f.read()
99101
formatted = re.sub(r";[^:]+:", ":", tree_string)
100102
is_duplicated = check_formatted_tree(formatted)
103+
is_small = formatted.count(",") < 3
101104

102-
return Tree(formatted), is_duplicated
105+
return Tree(formatted), is_duplicated, is_small
103106

104107

105108
#####################################################################
@@ -111,33 +114,38 @@ def read_tree(input_path):
111114
#####################################################################
112115

113116

114-
def root_one_tree(input_path, basename, output_path):
    """Midpoint-root one gene tree and write it into a category sub-directory.

    The tree is loaded with ``read_tree``, re-rooted on its midpoint outgroup
    and written below ``output_path`` in one of three sub-directories, chosen
    from the flags ``read_tree`` reports:

    * ``multiple`` -- ``is_duplicated`` is true; ".tre" in the file name
      becomes ".tre.multiple".
    * ``small``    -- ``is_small`` is true; ".tre" in the file name becomes
      ".tre.small".
    * ``unique``   -- neither flag is set; the file name is kept as is.

    Args:
        input_path: Path of the unrooted input tree file.
        basename: File name to use for the rooted tree.
        output_path: Directory under which the category folders are created.

    Returns:
        A 5-tuple ``(newick, n_leaves, rooted_path, is_duplicated, is_small)``.
        ``rooted_path`` is a ``str`` for the "multiple"/"small" categories and
        a ``pathlib.Path`` for "unique" (kept that way for existing callers).
    """
    tre, is_duplicated, is_small = read_tree(input_path)
    tre.set_outgroup(tre.get_midpoint_outgroup())

    # Duplicated takes precedence over small, as in the original branch order.
    if is_duplicated:
        category, renamed_suffix = "multiple", ".tre.multiple"
    elif is_small:
        category, renamed_suffix = "small", ".tre.small"
    else:
        category, renamed_suffix = "unique", None

    outdir = Path(output_path) / category
    outdir.mkdir(exist_ok=True, parents=True)

    if renamed_suffix is None:
        rooted_path = outdir / basename
    else:
        # Rewrite only the file name. Replacing ".tre" across the whole path
        # string would also alter any directory component that contains
        # ".tre", pointing the write at a directory that was never created.
        rooted_path = str(outdir / basename.replace(".tre", renamed_suffix))

    tre.write(outfile=rooted_path)
    return tre.write(), len(tre.get_leaves()), rooted_path, is_duplicated, is_small
130138

131139
def root_reference_tree(input_path, output_path):
    """Midpoint-root the reference tree and save it to ``output_path``.

    Args:
        input_path: Path of the unrooted reference tree file.
        output_path: Destination file for the rooted tree.

    Returns:
        A 2-tuple ``(newick, n_leaves)``: the rooted tree as produced by
        ``Tree.write()`` and the number of leaves it contains.
    """
    # The two flags returned by read_tree are not needed for the reference.
    reference, _is_duplicated, _is_small = read_tree(input_path)
    reference.set_outgroup(reference.get_midpoint_outgroup())
    reference.write(outfile=output_path)

    newick = reference.write()
    leaf_count = len(reference.get_leaves())
    return newick, leaf_count
137145

138146

139147
#####################################################################
140-
### FUNCTION ROOT_TREE
148+
### FUNCTION ROOT_ALL_TREES
141149
### Root all the unrooted input trees in directory
142150
### core_tree: path of the core tree
143151
### gene_trees: path of the csv file containing all the gene tree paths
@@ -148,8 +156,7 @@ def root_reference_tree(input_path, output_path):
148156
#####################################################################
149157

150158

151-
def root_trees(core_tree, gene_trees_path, output_dir, results, merge_pair=False):
152-
print("Rooting trees")
159+
def root_all_trees(core_tree, gene_trees_path, output_dir, results, merge_pair=False):
153160
#'''
154161
reference_tree = core_tree
155162

@@ -165,11 +172,11 @@ def root_trees(core_tree, gene_trees_path, output_dir, results, merge_pair=False
165172
rooted_gene_trees_path = os.path.join(output_dir, "rooted_gene_trees")
166173
for filename in df_gene_trees["path"]:
167174
basename = Path(filename).name
168-
gene_content, gene_tree_size, gene_tree_path, is_duplicated = root_tree(
175+
gene_content, gene_tree_size, gene_tree_path, is_duplicated, is_small = root_one_tree(
169176
filename,
170177
basename,
171178
rooted_gene_trees_path)
172-
if not is_duplicated:
179+
if not (is_duplicated or is_small):
173180
results.loc[basename, "tree_size"] = gene_tree_size
174181
if merge_pair:
175182
with open(gene_tree_path, "w") as f2:
@@ -205,6 +212,9 @@ def extract_approx_distance(text):
205212

206213
def run_approx_rspr(results, input_file, lst_filename, rspr_path):
207214
input_file.seek(0)
215+
216+
command_exists = shutil.which(rspr_path[0])
217+
208218
result = subprocess.run(
209219
rspr_path, stdin=input_file, capture_output=True, text=True
210220
)
@@ -231,7 +241,6 @@ def run_approx_rspr(results, input_file, lst_filename, rspr_path):
231241
def approx_rspr(
232242
rooted_gene_trees_path, results, min_branch_len=0, max_support_threshold=0.7
233243
):
234-
print("Calculating approx distance")
235244
rspr_path = [
236245
"rspr",
237246
"-approx",
@@ -245,20 +254,73 @@ def approx_rspr(
245254
lst_filename = []
246255
with tempfile.TemporaryFile(mode='w+') as temp_file:
247256
for filename in os.listdir(rooted_gene_trees_path):
248-
if cur_count == group_size:
249-
run_approx_rspr(results, temp_file, lst_filename, rspr_path)
250-
temp_file.seek(0)
251-
temp_file.truncate()
252-
lst_filename.clear()
253-
cur_count = 0
254-
255-
gene_tree_path = os.path.join(rooted_gene_trees_path, filename)
256-
with open(gene_tree_path, "r") as infile:
257-
temp_file.write(infile.read() + "\n")
258-
lst_filename.append(filename)
259-
cur_count += 1
260-
if cur_count > 0:
261-
run_approx_rspr(results, temp_file, lst_filename, rspr_path)
257+
if str(filename) in results.index:
258+
print("Found " + str(filename))
259+
if cur_count == group_size:
260+
run_approx_rspr(results, temp_file, lst_filename, rspr_path)
261+
temp_file.seek(0)
262+
temp_file.truncate()
263+
lst_filename.clear()
264+
cur_count = 0
265+
266+
gene_tree_path = os.path.join(rooted_gene_trees_path, filename)
267+
with open(gene_tree_path, "r") as infile:
268+
lines = infile.readlines()
269+
if len(lines) < 2:
270+
print(f"File {filename} does not have enough lines.")
271+
continue
272+
tree = Tree(lines[1].strip())
273+
# Calculate N: number of internal (non-leaf) nodes whose support is at or above the threshold
274+
# num_resolved = sum(1 for node in tree.traverse() if node.support >= max_support_threshold and not node.is_leaf())
275+
num_resolved = -1
276+
for node in tree.traverse():
277+
if node.support is not None and node.support >= max_support_threshold and not node.is_leaf():
278+
num_resolved += 1
279+
280+
tree_size = len(tree.get_leaves())
281+
results.loc[filename, "Num resolved"] = num_resolved
282+
results.loc[filename, "N/tree_size"] = num_resolved / tree_size if tree_size > 0 else 0
283+
lst_filename.append(filename)
284+
temp_file.write(lines[0].strip() + "\n" + lines[1].strip() + "\n")
285+
cur_count += 1
286+
if cur_count > 0:
287+
run_approx_rspr(results, temp_file, lst_filename, rspr_path)
288+
289+
# Add the approx_drSPR/N column
290+
results["approx_drSPR/N"] = results.apply(lambda row: float(row["approx_drSPR"]) / row["Num resolved"] if row["Num resolved"] > 0 else 0, axis=1)
291+
print("CBA " + str(results))
292+
293+
#def approx_rspr_old(
294+
# rooted_gene_trees_path, results, min_branch_len=0, max_support_threshold=0.7
295+
#):
296+
# print("Calculating approx distance")
297+
# rspr_path = [
298+
# "rspr",
299+
# "-approx",
300+
# "-multifurcating",
301+
# "-length " + str(min_branch_len),
302+
# "-support " + str(max_support_threshold),
303+
# ]
304+
#
305+
# group_size = 10000
306+
# cur_count = 0
307+
# lst_filename = []
308+
# with tempfile.TemporaryFile(mode='w+') as temp_file:
309+
# for filename in os.listdir(rooted_gene_trees_path):
310+
# if cur_count == group_size:
311+
# run_approx_rspr(results, temp_file, lst_filename, rspr_path)
312+
# temp_file.seek(0)
313+
# temp_file.truncate()
314+
# lst_filename.clear()
315+
# cur_count = 0
316+
#
317+
# gene_tree_path = os.path.join(rooted_gene_trees_path, filename)
318+
# with open(gene_tree_path, "r") as infile:
319+
# temp_file.write(infile.read() + "\n")
320+
# lst_filename.append(filename)
321+
# cur_count += 1
322+
# if cur_count > 0:
323+
# run_approx_rspr(results, temp_file, lst_filename, rspr_path)
262324

263325

264326
#####################################################################
@@ -289,7 +351,6 @@ def generate_heatmap(freq_table, output_path, log_scale=False):
289351
#####################################################################
290352

291353
def make_heatmap(results, output_path, min_distance, max_distance):
292-
print("Generating heatmap")
293354

294355
# create sub dataframe
295356
sub_results = results[(results["approx_drSPR"] >= min_distance)]
@@ -306,7 +367,6 @@ def make_heatmap(results, output_path, min_distance, max_distance):
306367

307368

308369
def make_heatmap_from_tsv(input_path, output_path, min_distance, max_distance):
    """Read a results table with ``pd.read_table`` and pass it to ``make_heatmap``.

    Args:
        input_path: Path of the delimited results file to load.
        output_path: Forwarded unchanged to ``make_heatmap``.
        min_distance: Forwarded unchanged to ``make_heatmap``.
        max_distance: Forwarded unchanged to ``make_heatmap``.
    """
    make_heatmap(pd.read_table(input_path), output_path, min_distance, max_distance)
312372

@@ -339,7 +399,6 @@ def get_heatmap_group_size(all_values, max_groups=15):
339399
#####################################################################
340400

341401
def make_group_heatmap(results, output_path, min_distance, max_distance):
342-
print("Generating group heatmap")
343402

344403
# create sub dataframe
345404
sub_results = results[(results["approx_drSPR"] >= min_distance)]
@@ -383,7 +442,7 @@ def make_group_heatmap(results, output_path, min_distance, max_distance):
383442
### RETURN groups of trees
384443
#####################################################################
385444

386-
def generate_group_sizes(target_sum, max_groups=500):
445+
def generate_group_sizes(target_sum, max_groups=1000):
387446
degree = 1
388447
current_sum = 0
389448
group_sizes = []
@@ -410,7 +469,6 @@ def generate_group_sizes(target_sum, max_groups=500):
410469
#####################################################################
411470

412471
def make_groups_v1(results, min_limit=10):
413-
print("Generating groups")
414472
min_group = results[results["approx_drSPR"] <= min_limit]["file_name"].tolist()
415473
groups = defaultdict()
416474
first_group = "group_0"
@@ -438,7 +496,6 @@ def make_groups_v1(results, min_limit=10):
438496
#####################################################################
439497

440498
def make_groups(results, min_limit=10):
441-
print("Generating groups")
442499
min_group = results[results["approx_drSPR"] <= min_limit]["file_name"].tolist()
443500
groups = defaultdict()
444501
first_group = "group_0"
@@ -463,7 +520,6 @@ def make_groups(results, min_limit=10):
463520

464521

465522
def make_groups_from_csv(input_df, min_limit):
466-
print("Generating groups from CSV")
467523
groups = make_groups_v1(input_df, min_limit)
468524
tidy_data = [
469525
(key, val)
@@ -476,6 +532,24 @@ def make_groups_from_csv(input_df, min_limit):
476532
return merged
477533

478534

535+
# def join_annotation_data(df, annotation_data):
536+
# ann_df = pd.read_table(annotation_data, dtype={"genome_id": "str"})
537+
# ann_df.columns = map(str.lower, ann_df.columns)
538+
# ann_df.columns = ann_df.columns.str.replace(" ", "_")
539+
# ann_subset = ann_df[["gene", "product"]]
540+
#
541+
# df["tree_name"] = [f.split(".")[0] for f in df["file_name"]]
542+
#
543+
# merged = df.merge(ann_subset, how="left", left_on="tree_name", right_on="gene")
544+
#
545+
# if merged["gene"].isnull().all():
546+
# ann_subset = ann_df[["locus_tag", "gene", "product"]]
547+
# merged = df.merge(
548+
# ann_subset, how="left", left_on="tree_name", right_on="locus_tag"
549+
# )
550+
#
551+
# return merged.fillna(value="NULL").drop("tree_name", axis=1).drop_duplicates()
552+
479553
def join_annotation_data(df, annotation_data):
480554
ann_df = pd.read_table(annotation_data, dtype={"genome_id": "str"})
481555
ann_df.columns = map(str.lower, ann_df.columns)
@@ -492,8 +566,23 @@ def join_annotation_data(df, annotation_data):
492566
ann_subset, how="left", left_on="tree_name", right_on="locus_tag"
493567
)
494568

495-
return merged.fillna(value="NULL").drop("tree_name", axis=1).drop_duplicates()
569+
merged = merged.fillna("NULL").drop("tree_name", axis=1)
496570

571+
# Group by all columns except 'product' and aggregate 'product'
572+
grouped = (
573+
merged.groupby(list(merged.columns.difference(['product'])))
574+
.agg({'product': lambda x: '||'.join(sorted(set(x)))})
575+
.reset_index()
576+
)
577+
578+
# Reorder columns
579+
desired_order = [
580+
"file_name", "gene", "tree_size", "product", "N/tree_size",
581+
"Num resolved", "approx_drSPR", "approx_drSPR/N"
582+
]
583+
grouped = grouped[desired_order]
584+
585+
return grouped.drop_duplicates()
497586

498587
def main(args=None):
499588
args = parse_args(args)
@@ -502,7 +591,7 @@ def main(args=None):
502591
#'''
503592
results = pd.DataFrame(columns=["file_name", "tree_size", "approx_drSPR"])
504593
results.set_index("file_name", inplace=True)
505-
rooted_paths = root_trees(
594+
rooted_paths = root_all_trees(
506595
args.CORE_TREE, args.GENE_TREES, args.OUTPUT_DIR, results, True
507596
)
508597
approx_rspr(
@@ -512,7 +601,10 @@ def main(args=None):
512601
args.MAX_SUPPORT_THRESHOLD,
513602
)
514603

604+
#exit(11)
605+
515606
# Generate standard heatmap
607+
# results["approx_drSPR"] = pd.to_numeric(results["approx_drSPR"]).fillna(1000000)
516608
results["approx_drSPR"] = pd.to_numeric(results["approx_drSPR"])
517609
fig_path = os.path.join(args.OUTPUT_DIR, "output.png")
518610
make_heatmap(

0 commit comments

Comments
 (0)