Skip to content

Commit df25440

Browse files
committed
parallel building an initial tree: sort samples according to their estimated placement likelihood contributions
1 parent a04fd98 commit df25440

File tree

1 file changed

+50
-14
lines changed

1 file changed

+50
-14
lines changed

tree/tree.cpp

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,26 @@ void cmaple::Tree::doInferenceTemplate(
597597
cout.rdbuf(src_cout);
598598
}
599599

600+
/**
601+
* \internal
602+
* Helper class used for storing placements found for samples.
603+
*/
604+
class SamplePlacement {
605+
public:
606+
607+
// the sample id
608+
size_t _id;
609+
610+
// the lower regions representing that sample
611+
std::unique_ptr<SeqRegions> _lower_regions;
612+
613+
// the placement
614+
cmaple::Index _selected_node_index;
615+
616+
// the estimated likelihood contribution of that placement
617+
RealNumType _placement_lh;
618+
};
619+
600620
template <const StateType num_states>
601621
void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_stream) {
602622
assert(cumulative_rate);
@@ -675,8 +695,7 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
675695
chunk_size = params->num_samples_per_thread * num_actual_threads;
676696
#endif
677697
// Vectors store the placements found
678-
std::vector<Index> selected_node_index_vec(chunk_size);
679-
std::vector<std::unique_ptr<SeqRegions>> lower_regions_vec(chunk_size);
698+
std::vector<SamplePlacement> sample_placement_vec(chunk_size);
680699

681700
// iteratively place other samples (sequences)
682701
for (; i < num_seqs; ++i, ++sequence) {
@@ -692,6 +711,7 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
692711
if (num_seqs - i < chunk_size)
693712
{
694713
chunk_size = num_seqs - i;
714+
sample_placement_vec.resize(chunk_size);
695715
}
696716

697717
#pragma omp parallel for
@@ -700,16 +720,19 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
700720
// get the actual index of sequence
701721
size_t index = i + j;
702722

723+
// dummy variables
724+
Index selected_node_index = Index(root_vector_index, UNDEFINED);
725+
RealNumType best_lh_diff = MIN_NEGATIVE;
726+
std::unique_ptr<SeqRegions> lower_regions = nullptr;
727+
703728
// only seek a placement for a sequence that was NOT added in the input tree
704729
if (!(from_input_tree && sequence_added[index]))
705730
{
706731
// get the lower likelihood vector of the current sequence
707-
std::unique_ptr<SeqRegions> lower_regions =
732+
lower_regions =
708733
sequence[j].getLowerLhVector(seq_length, num_states, aln->getSeqType());
709734

710735
// seek a position for new sample placement
711-
Index selected_node_index;
712-
RealNumType best_lh_diff = MIN_NEGATIVE;
713736
bool is_mid_branch = false;
714737
RealNumType best_up_lh_diff = MIN_NEGATIVE;
715738
RealNumType best_down_lh_diff = MIN_NEGATIVE;
@@ -739,10 +762,20 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
739762
break;
740763
}
741764
}
742-
selected_node_index_vec[j] = selected_node_index;
743-
lower_regions_vec[j] = std::move(lower_regions);
744765
}
766+
767+
// record the placement
768+
sample_placement_vec[j]._id = index;
769+
sample_placement_vec[j]._lower_regions = std::move(lower_regions);
770+
sample_placement_vec[j]._selected_node_index = selected_node_index;
771+
sample_placement_vec[j]._placement_lh = best_lh_diff;
745772
}
773+
774+
// sort the sample placements
775+
std::sort(sample_placement_vec.begin(), sample_placement_vec.end(),
776+
[](const SamplePlacement& a, const SamplePlacement& b) {
777+
return a._placement_lh > b._placement_lh;
778+
});
746779
}
747780

748781
// sequentially seek placement (again from the found placement if found or from the root) and place the sample
@@ -762,6 +795,9 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
762795
++sequence;
763796
}
764797

798+
// get the actual index of sequence
799+
size_t index = parallel_search ? sample_placement_vec[j]._id : i;
800+
765801
// check to perform topology optimization
766802
if (params->num_samples_spr_during_inital_tree
767803
&& i % (params->num_samples_spr_during_inital_tree) == 0
@@ -778,13 +814,13 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
778814
}
779815

780816
// don't add sequence that was already added in the input tree
781-
if (from_input_tree && sequence_added[i]) {
817+
if (from_input_tree && sequence_added[index]) {
782818
--num_new_sequences;
783819
continue;
784820
}
785821
// otherwise, mark the current sequence as added
786822
else {
787-
sequence_added[i] = true;
823+
sequence_added[index] = true;
788824
}
789825

790826
// update the mutation matrix from empirical number of mutations observed
@@ -801,8 +837,8 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
801837
std::unique_ptr<SeqRegions> lower_regions = nullptr;
802838
if (parallel_search)
803839
{
804-
found_placement_index = selected_node_index_vec[j];
805-
lower_regions = std::move(lower_regions_vec[j]);
840+
found_placement_index = sample_placement_vec[j]._selected_node_index;
841+
lower_regions = std::move(sample_placement_vec[j]._lower_regions);
806842
}
807843
// otherise, start from the root
808844
else
@@ -838,7 +874,7 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
838874
RealNumType best_down_lh_diff = MIN_NEGATIVE;
839875
Index best_child_index;
840876
seekSamplePlacement<num_states>(
841-
Index(found_placement_index.getVectorIndex(), TOP), static_cast<NumSeqsType>(i),
877+
Index(found_placement_index.getVectorIndex(), TOP), static_cast<NumSeqsType>(index),
842878
lower_regions, selected_node_index, best_lh_diff, is_mid_branch,
843879
best_up_lh_diff, best_down_lh_diff, best_child_index);
844880

@@ -848,12 +884,12 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
848884
// place new sample as a descendant of a mid-branch point
849885
if (is_mid_branch) {
850886
placeNewSampleMidBranch<num_states>(selected_node_index, lower_regions,
851-
static_cast<NumSeqsType>(i), best_lh_diff);
887+
static_cast<NumSeqsType>(index), best_lh_diff);
852888
// otherwise, best lk so far is for appending directly to existing
853889
// node
854890
} else {
855891
placeNewSampleAtNode<num_states>(selected_node_index, lower_regions,
856-
static_cast<NumSeqsType>(i),
892+
static_cast<NumSeqsType>(index),
857893
best_lh_diff, best_up_lh_diff,
858894
best_down_lh_diff, best_child_index);
859895
}

0 commit comments

Comments
 (0)