@@ -597,6 +597,26 @@ void cmaple::Tree::doInferenceTemplate(
597597 cout.rdbuf (src_cout);
598598}
599599
600+ /* *
601+ * \internal
602+ * Helper class used for storing placements found for samples.
603+ */
604+ class SamplePlacement {
605+ public:
606+
607+ // the sample id
608+ size_t _id;
609+
610+ // the lower regions representing that sample
611+ std::unique_ptr<SeqRegions> _lower_regions;
612+
613+ // the placement
614+ cmaple::Index _selected_node_index;
615+
616+ // the estimated likelihood contribution of that placement
617+ RealNumType _placement_lh;
618+ };
619+
600620template <const StateType num_states>
601621void cmaple::Tree::doPlacementTemplate (const int num_threads, std::ostream& out_stream) {
602622 assert (cumulative_rate);
@@ -675,8 +695,7 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
675695 chunk_size = params->num_samples_per_thread * num_actual_threads;
676696#endif
677697 // Vectors store the placements found
678- std::vector<Index> selected_node_index_vec (chunk_size);
679- std::vector<std::unique_ptr<SeqRegions>> lower_regions_vec (chunk_size);
698+ std::vector<SamplePlacement> sample_placement_vec (chunk_size);
680699
681700 // iteratively place other samples (sequences)
682701 for (; i < num_seqs; ++i, ++sequence) {
@@ -692,6 +711,7 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
692711 if (num_seqs - i < chunk_size)
693712 {
694713 chunk_size = num_seqs - i;
714+ sample_placement_vec.resize (chunk_size);
695715 }
696716
697717 #pragma omp parallel for
@@ -700,16 +720,19 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
700720 // get the actual index of sequence
701721 size_t index = i + j;
702722
723+ // dummy variables
724+ Index selected_node_index = Index (root_vector_index, UNDEFINED);
725+ RealNumType best_lh_diff = MIN_NEGATIVE;
726+ std::unique_ptr<SeqRegions> lower_regions = nullptr ;
727+
703728 // only seek a placement for a sequence that was NOT added in the input tree
704729 if (!(from_input_tree && sequence_added[index]))
705730 {
706731 // get the lower likelihood vector of the current sequence
707- std::unique_ptr<SeqRegions> lower_regions =
732+ lower_regions =
708733 sequence[j].getLowerLhVector (seq_length, num_states, aln->getSeqType ());
709734
710735 // seek a position for new sample placement
711- Index selected_node_index;
712- RealNumType best_lh_diff = MIN_NEGATIVE;
713736 bool is_mid_branch = false ;
714737 RealNumType best_up_lh_diff = MIN_NEGATIVE;
715738 RealNumType best_down_lh_diff = MIN_NEGATIVE;
@@ -739,10 +762,20 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
739762 break ;
740763 }
741764 }
742- selected_node_index_vec[j] = selected_node_index;
743- lower_regions_vec[j] = std::move (lower_regions);
744765 }
766+
767+ // record the placement
768+ sample_placement_vec[j]._id = index;
769+ sample_placement_vec[j]._lower_regions = std::move (lower_regions);
770+ sample_placement_vec[j]._selected_node_index = selected_node_index;
771+ sample_placement_vec[j]._placement_lh = best_lh_diff;
745772 }
773+
774+ // sort the sample placements
775+ std::sort (sample_placement_vec.begin (), sample_placement_vec.end (),
776+ [](const SamplePlacement& a, const SamplePlacement& b) {
777+ return a._placement_lh > b._placement_lh ;
778+ });
746779 }
747780
748781 // sequentially seek placement (again from the found placement if found or from the root) and place the sample
@@ -762,6 +795,9 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
762795 ++sequence;
763796 }
764797
798+ // get the actual index of sequence
799+ size_t index = parallel_search ? sample_placement_vec[j]._id : i;
800+
765801 // check to perform topology optimization
766802 if (params->num_samples_spr_during_inital_tree
767803 && i % (params->num_samples_spr_during_inital_tree ) == 0
@@ -778,13 +814,13 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
778814 }
779815
780816 // don't add sequence that was already added in the input tree
781- if (from_input_tree && sequence_added[i ]) {
817+ if (from_input_tree && sequence_added[index ]) {
782818 --num_new_sequences;
783819 continue ;
784820 }
785821 // otherwise, mark the current sequence as added
786822 else {
787- sequence_added[i ] = true ;
823+ sequence_added[index ] = true ;
788824 }
789825
790826 // update the mutation matrix from empirical number of mutations observed
@@ -801,8 +837,8 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
801837 std::unique_ptr<SeqRegions> lower_regions = nullptr ;
802838 if (parallel_search)
803839 {
804- found_placement_index = selected_node_index_vec [j];
805- lower_regions = std::move (lower_regions_vec [j]);
840+ found_placement_index = sample_placement_vec [j]. _selected_node_index ;
841+ lower_regions = std::move (sample_placement_vec [j]. _lower_regions );
806842 }
807843 // otherise, start from the root
808844 else
@@ -838,7 +874,7 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
838874 RealNumType best_down_lh_diff = MIN_NEGATIVE;
839875 Index best_child_index;
840876 seekSamplePlacement<num_states>(
841- Index (found_placement_index.getVectorIndex (), TOP), static_cast <NumSeqsType>(i ),
877+ Index (found_placement_index.getVectorIndex (), TOP), static_cast <NumSeqsType>(index ),
842878 lower_regions, selected_node_index, best_lh_diff, is_mid_branch,
843879 best_up_lh_diff, best_down_lh_diff, best_child_index);
844880
@@ -848,12 +884,12 @@ void cmaple::Tree::doPlacementTemplate(const int num_threads, std::ostream& out_
848884 // place new sample as a descendant of a mid-branch point
849885 if (is_mid_branch) {
850886 placeNewSampleMidBranch<num_states>(selected_node_index, lower_regions,
851- static_cast <NumSeqsType>(i ), best_lh_diff);
887+ static_cast <NumSeqsType>(index ), best_lh_diff);
852888 // otherwise, best lk so far is for appending directly to existing
853889 // node
854890 } else {
855891 placeNewSampleAtNode<num_states>(selected_node_index, lower_regions,
856- static_cast <NumSeqsType>(i ),
892+ static_cast <NumSeqsType>(index ),
857893 best_lh_diff, best_up_lh_diff,
858894 best_down_lh_diff, best_child_index);
859895 }
0 commit comments