8
8
from pathlib import Path
9
9
import argparse
10
10
import subprocess
11
+ import shutil
11
12
from ete3 import Tree
12
13
import pandas as pd
13
14
from collections import defaultdict
14
15
from matplotlib import pyplot as plt
15
16
from matplotlib .colors import LogNorm
16
17
import seaborn as sns
17
18
import tempfile
19
+ import logging
18
20
19
21
20
22
#####################################################################
@@ -98,8 +100,9 @@ def read_tree(input_path):
98
100
tree_string = f .read ()
99
101
formatted = re .sub (r";[^:]+:" , ":" , tree_string )
100
102
is_duplicated = check_formatted_tree (formatted )
103
+ is_small = formatted .count ("," ) < 3
101
104
102
- return Tree (formatted ), is_duplicated
105
+ return Tree (formatted ), is_duplicated , is_small
103
106
104
107
105
108
#####################################################################
@@ -111,33 +114,38 @@ def read_tree(input_path):
111
114
#####################################################################
112
115
113
116
114
- def root_tree (input_path , basename , output_path ):
115
- tre ,is_duplicated = read_tree (input_path )
117
+ def root_one_tree (input_path , basename , output_path ):
118
+ tre ,is_duplicated , is_small = read_tree (input_path )
116
119
midpoint = tre .get_midpoint_outgroup ()
117
120
tre .set_outgroup (midpoint )
118
121
if is_duplicated :
119
122
outdir = Path (output_path ) / "multiple"
120
123
Path (outdir ).mkdir (exist_ok = True , parents = True )
121
124
output_path = outdir / basename
122
125
output_path = str (output_path ).replace (".tre" , ".tre.multiple" )
126
+ elif is_small :
127
+ outdir = Path (output_path ) / "small"
128
+ Path (outdir ).mkdir (exist_ok = True , parents = True )
129
+ output_path = outdir / basename
130
+ output_path = str (output_path ).replace (".tre" , ".tre.small" )
123
131
else :
124
132
outdir = Path (output_path ) / "unique"
125
133
Path (outdir ).mkdir (exist_ok = True , parents = True )
126
134
output_path = outdir / basename
127
135
128
136
tre .write (outfile = output_path )
129
- return tre .write (), len (tre .get_leaves ()), output_path , is_duplicated
137
+ return tre .write (), len (tre .get_leaves ()), output_path , is_duplicated , is_small
130
138
131
139
def root_reference_tree (input_path , output_path ):
132
- tre , _ = read_tree (input_path )
140
+ tre , _ , _ = read_tree (input_path )
133
141
midpoint = tre .get_midpoint_outgroup ()
134
142
tre .set_outgroup (midpoint )
135
143
tre .write (outfile = output_path )
136
144
return tre .write (), len (tre .get_leaves ())
137
145
138
146
139
147
#####################################################################
140
- ### FUNCTION ROOT_TREE
148
+ ### FUNCTION ROOT_ALL_TREES
141
149
### Root all the unrooted input trees in directory
142
150
### core_tree: path of the core tree
143
151
### gene_trees: path of the csv file containing all the gene tree paths
@@ -148,8 +156,7 @@ def root_reference_tree(input_path, output_path):
148
156
#####################################################################
149
157
150
158
151
- def root_trees (core_tree , gene_trees_path , output_dir , results , merge_pair = False ):
152
- print ("Rooting trees" )
159
+ def root_all_trees (core_tree , gene_trees_path , output_dir , results , merge_pair = False ):
153
160
#'''
154
161
reference_tree = core_tree
155
162
@@ -165,11 +172,11 @@ def root_trees(core_tree, gene_trees_path, output_dir, results, merge_pair=False
165
172
rooted_gene_trees_path = os .path .join (output_dir , "rooted_gene_trees" )
166
173
for filename in df_gene_trees ["path" ]:
167
174
basename = Path (filename ).name
168
- gene_content , gene_tree_size , gene_tree_path , is_duplicated = root_tree (
175
+ gene_content , gene_tree_size , gene_tree_path , is_duplicated , is_small = root_one_tree (
169
176
filename ,
170
177
basename ,
171
178
rooted_gene_trees_path )
172
- if not is_duplicated :
179
+ if not ( is_duplicated or is_small ) :
173
180
results .loc [basename , "tree_size" ] = gene_tree_size
174
181
if merge_pair :
175
182
with open (gene_tree_path , "w" ) as f2 :
@@ -205,6 +212,9 @@ def extract_approx_distance(text):
205
212
206
213
def run_approx_rspr (results , input_file , lst_filename , rspr_path ):
207
214
input_file .seek (0 )
215
+
216
+ command_exists = shutil .which (rspr_path [0 ])
217
+
208
218
result = subprocess .run (
209
219
rspr_path , stdin = input_file , capture_output = True , text = True
210
220
)
@@ -231,7 +241,6 @@ def run_approx_rspr(results, input_file, lst_filename, rspr_path):
231
241
def approx_rspr (
232
242
rooted_gene_trees_path , results , min_branch_len = 0 , max_support_threshold = 0.7
233
243
):
234
- print ("Calculating approx distance" )
235
244
rspr_path = [
236
245
"rspr" ,
237
246
"-approx" ,
@@ -245,20 +254,73 @@ def approx_rspr(
245
254
lst_filename = []
246
255
with tempfile .TemporaryFile (mode = 'w+' ) as temp_file :
247
256
for filename in os .listdir (rooted_gene_trees_path ):
248
- if cur_count == group_size :
249
- run_approx_rspr (results , temp_file , lst_filename , rspr_path )
250
- temp_file .seek (0 )
251
- temp_file .truncate ()
252
- lst_filename .clear ()
253
- cur_count = 0
254
-
255
- gene_tree_path = os .path .join (rooted_gene_trees_path , filename )
256
- with open (gene_tree_path , "r" ) as infile :
257
- temp_file .write (infile .read () + "\n " )
258
- lst_filename .append (filename )
259
- cur_count += 1
260
- if cur_count > 0 :
261
- run_approx_rspr (results , temp_file , lst_filename , rspr_path )
257
+ if str (filename ) in results .index :
258
+ print ("Found " + str (filename ))
259
+ if cur_count == group_size :
260
+ run_approx_rspr (results , temp_file , lst_filename , rspr_path )
261
+ temp_file .seek (0 )
262
+ temp_file .truncate ()
263
+ lst_filename .clear ()
264
+ cur_count = 0
265
+
266
+ gene_tree_path = os .path .join (rooted_gene_trees_path , filename )
267
+ with open (gene_tree_path , "r" ) as infile :
268
+ lines = infile .readlines ()
269
+ if len (lines ) < 2 :
270
+ print (f"File { filename } does not have enough lines." )
271
+ continue
272
+ tree = Tree (lines [1 ].strip ())
273
+ # Calculate N: number of nodes at or above the support threshold
274
+ # num_resolved = sum(1 for node in tree.traverse() if node.support >= max_support_threshold and not node.is_leaf())
275
+ num_resolved = - 1
276
+ for node in tree .traverse ():
277
+ if node .support is not None and node .support >= max_support_threshold and not node .is_leaf ():
278
+ num_resolved += 1
279
+
280
+ tree_size = len (tree .get_leaves ())
281
+ results .loc [filename , "Num resolved" ] = num_resolved
282
+ results .loc [filename , "N/tree_size" ] = num_resolved / tree_size if tree_size > 0 else 0
283
+ lst_filename .append (filename )
284
+ temp_file .write (lines [0 ].strip () + "\n " + lines [1 ].strip () + "\n " )
285
+ cur_count += 1
286
+ if cur_count > 0 :
287
+ run_approx_rspr (results , temp_file , lst_filename , rspr_path )
288
+
289
+ # Add the approx_drSPR/N column
290
+ results ["approx_drSPR/N" ] = results .apply (lambda row : float (row ["approx_drSPR" ]) / row ["Num resolved" ] if row ["Num resolved" ] > 0 else 0 , axis = 1 )
291
+ print ("CBA " + str (results ))
292
+
293
+ #def approx_rspr_old(
294
+ # rooted_gene_trees_path, results, min_branch_len=0, max_support_threshold=0.7
295
+ #):
296
+ # print("Calculating approx distance")
297
+ # rspr_path = [
298
+ # "rspr",
299
+ # "-approx",
300
+ # "-multifurcating",
301
+ # "-length " + str(min_branch_len),
302
+ # "-support " + str(max_support_threshold),
303
+ # ]
304
+ #
305
+ # group_size = 10000
306
+ # cur_count = 0
307
+ # lst_filename = []
308
+ # with tempfile.TemporaryFile(mode='w+') as temp_file:
309
+ # for filename in os.listdir(rooted_gene_trees_path):
310
+ # if cur_count == group_size:
311
+ # run_approx_rspr(results, temp_file, lst_filename, rspr_path)
312
+ # temp_file.seek(0)
313
+ # temp_file.truncate()
314
+ # lst_filename.clear()
315
+ # cur_count = 0
316
+ #
317
+ # gene_tree_path = os.path.join(rooted_gene_trees_path, filename)
318
+ # with open(gene_tree_path, "r") as infile:
319
+ # temp_file.write(infile.read() + "\n")
320
+ # lst_filename.append(filename)
321
+ # cur_count += 1
322
+ # if cur_count > 0:
323
+ # run_approx_rspr(results, temp_file, lst_filename, rspr_path)
262
324
263
325
264
326
#####################################################################
@@ -289,7 +351,6 @@ def generate_heatmap(freq_table, output_path, log_scale=False):
289
351
#####################################################################
290
352
291
353
def make_heatmap (results , output_path , min_distance , max_distance ):
292
- print ("Generating heatmap" )
293
354
294
355
# create sub dataframe
295
356
sub_results = results [(results ["approx_drSPR" ] >= min_distance )]
@@ -306,7 +367,6 @@ def make_heatmap(results, output_path, min_distance, max_distance):
306
367
307
368
308
369
def make_heatmap_from_tsv (input_path , output_path , min_distance , max_distance ):
309
- print ("Generating heatmap from CSV" )
310
370
results = pd .read_table (input_path )
311
371
make_heatmap (results , output_path , min_distance , max_distance )
312
372
@@ -339,7 +399,6 @@ def get_heatmap_group_size(all_values, max_groups=15):
339
399
#####################################################################
340
400
341
401
def make_group_heatmap (results , output_path , min_distance , max_distance ):
342
- print ("Generating group heatmap" )
343
402
344
403
# create sub dataframe
345
404
sub_results = results [(results ["approx_drSPR" ] >= min_distance )]
@@ -383,7 +442,7 @@ def make_group_heatmap(results, output_path, min_distance, max_distance):
383
442
### RETURN groups of trees
384
443
#####################################################################
385
444
386
- def generate_group_sizes (target_sum , max_groups = 500 ):
445
+ def generate_group_sizes (target_sum , max_groups = 1000 ):
387
446
degree = 1
388
447
current_sum = 0
389
448
group_sizes = []
@@ -410,7 +469,6 @@ def generate_group_sizes(target_sum, max_groups=500):
410
469
#####################################################################
411
470
412
471
def make_groups_v1 (results , min_limit = 10 ):
413
- print ("Generating groups" )
414
472
min_group = results [results ["approx_drSPR" ] <= min_limit ]["file_name" ].tolist ()
415
473
groups = defaultdict ()
416
474
first_group = "group_0"
@@ -438,7 +496,6 @@ def make_groups_v1(results, min_limit=10):
438
496
#####################################################################
439
497
440
498
def make_groups (results , min_limit = 10 ):
441
- print ("Generating groups" )
442
499
min_group = results [results ["approx_drSPR" ] <= min_limit ]["file_name" ].tolist ()
443
500
groups = defaultdict ()
444
501
first_group = "group_0"
@@ -463,7 +520,6 @@ def make_groups(results, min_limit=10):
463
520
464
521
465
522
def make_groups_from_csv (input_df , min_limit ):
466
- print ("Generating groups from CSV" )
467
523
groups = make_groups_v1 (input_df , min_limit )
468
524
tidy_data = [
469
525
(key , val )
@@ -476,6 +532,24 @@ def make_groups_from_csv(input_df, min_limit):
476
532
return merged
477
533
478
534
535
+ # def join_annotation_data(df, annotation_data):
536
+ # ann_df = pd.read_table(annotation_data, dtype={"genome_id": "str"})
537
+ # ann_df.columns = map(str.lower, ann_df.columns)
538
+ # ann_df.columns = ann_df.columns.str.replace(" ", "_")
539
+ # ann_subset = ann_df[["gene", "product"]]
540
+ #
541
+ # df["tree_name"] = [f.split(".")[0] for f in df["file_name"]]
542
+ #
543
+ # merged = df.merge(ann_subset, how="left", left_on="tree_name", right_on="gene")
544
+ #
545
+ # if merged["gene"].isnull().all():
546
+ # ann_subset = ann_df[["locus_tag", "gene", "product"]]
547
+ # merged = df.merge(
548
+ # ann_subset, how="left", left_on="tree_name", right_on="locus_tag"
549
+ # )
550
+ #
551
+ # return merged.fillna(value="NULL").drop("tree_name", axis=1).drop_duplicates()
552
+
479
553
def join_annotation_data (df , annotation_data ):
480
554
ann_df = pd .read_table (annotation_data , dtype = {"genome_id" : "str" })
481
555
ann_df .columns = map (str .lower , ann_df .columns )
@@ -492,8 +566,23 @@ def join_annotation_data(df, annotation_data):
492
566
ann_subset , how = "left" , left_on = "tree_name" , right_on = "locus_tag"
493
567
)
494
568
495
- return merged .fillna (value = "NULL" ).drop ("tree_name" , axis = 1 ). drop_duplicates ( )
569
+ merged = merged .fillna ("NULL" ).drop ("tree_name" , axis = 1 )
496
570
571
+ # Group by all columns except 'product' and aggregate 'product'
572
+ grouped = (
573
+ merged .groupby (list (merged .columns .difference (['product' ])))
574
+ .agg ({'product' : lambda x : '||' .join (sorted (set (x )))})
575
+ .reset_index ()
576
+ )
577
+
578
+ # Reorder columns
579
+ desired_order = [
580
+ "file_name" , "gene" , "tree_size" , "product" , "N/tree_size" ,
581
+ "Num resolved" , "approx_drSPR" , "approx_drSPR/N"
582
+ ]
583
+ grouped = grouped [desired_order ]
584
+
585
+ return grouped .drop_duplicates ()
497
586
498
587
def main (args = None ):
499
588
args = parse_args (args )
@@ -502,7 +591,7 @@ def main(args=None):
502
591
#'''
503
592
results = pd .DataFrame (columns = ["file_name" , "tree_size" , "approx_drSPR" ])
504
593
results .set_index ("file_name" , inplace = True )
505
- rooted_paths = root_trees (
594
+ rooted_paths = root_all_trees (
506
595
args .CORE_TREE , args .GENE_TREES , args .OUTPUT_DIR , results , True
507
596
)
508
597
approx_rspr (
@@ -512,7 +601,10 @@ def main(args=None):
512
601
args .MAX_SUPPORT_THRESHOLD ,
513
602
)
514
603
604
+ #exit(11)
605
+
515
606
# Generate standard heatmap
607
+ # results["approx_drSPR"] = pd.to_numeric(results["approx_drSPR"]).fillna(1000000)
516
608
results ["approx_drSPR" ] = pd .to_numeric (results ["approx_drSPR" ])
517
609
fig_path = os .path .join (args .OUTPUT_DIR , "output.png" )
518
610
make_heatmap (
0 commit comments