@@ -5096,22 +5096,16 @@ void cmaple::Tree::connectNewSample2Branch(
50965096 ->integrateMutations <num_states>(sibling_node_mutations, aln, true )));
50975097 }
50985098
5099- // lower lh vec at the new internal node
5100- internal.setPartialLh (TOP, std::move (internal.getPartialLh (TOP)
5101- ->integrateMutations <num_states>(sibling_node_mutations, aln, true )));
5099+ // all lh vectors/regions at the internal node
5100+ internal.integrateMutAllRegions <num_states>(sibling_node_mutations, aln, true );
51025101
5103- // upper left/right lh vec at the internal node
5104- internal.setPartialLh (LEFT, std::move (internal.getPartialLh (LEFT)
5105- ->integrateMutations <num_states>(sibling_node_mutations, aln, true )));
5106- internal.setPartialLh (RIGHT, std::move (internal.getPartialLh (RIGHT)
5107- ->integrateMutations <num_states>(sibling_node_mutations, aln, true )));
5102+ // NHAN added: total lh at the new internal node
5103+ // recompute the total lh vec
5104+ internal.computeTotalLhAtNode <num_states>(internal.getTotalLh (), parent_node, aln,
5105+ model, threshold_prob, root_vector_index == internal_vec_index);
51085106
5109- // mid-branch lh vec at the internal node
5110- if (internal.getMidBranchLh ())
5111- {
5112- internal.setMidBranchLh (std::move (internal.getMidBranchLh ()
5113- ->integrateMutations <num_states>(sibling_node_mutations, aln, true )));
5114- }
5107+ // don't need to de-integrate the mutations since it's has been
5108+ // recomputed on the correct lh vectors
51155109
51165110 // NHAN added: total lh vec must be computed after others
51175111 // total lh vec at the new sample
@@ -5124,13 +5118,45 @@ void cmaple::Tree::connectNewSample2Branch(
51245118 // don't need to de-integrate the mutations since it's has been
51255119 // recomputed on the correct lh vectors
51265120 }
5127- // total lh vec at the new internal node
5128- // recompute the total lh vec
5129- internal.computeTotalLhAtNode <num_states>(
5130- internal.getTotalLh (), parent_node, aln, model, threshold_prob,
5131- root_vector_index == internal_vec_index);
5132- // don't need to de-integrate the mutations since it's has been
5133- // recomputed on the correct lh vectors
5121+ }
5122+
5123+ // traverse upward to update the number of descendants
5124+ assert (num_new_descendant >= 0 );
5125+ if (num_new_descendant)
5126+ {
5127+ Index traverse_parent_index = internal.getNeighborIndex (TOP);
5128+ NumSeqsType traverse_parent_vec_index = traverse_parent_index.getVectorIndex ();
5129+
5130+ // traverse upward until going beyond the root
5131+ while (traverse_parent_index.getMiniIndex () != UNDEFINED)
5132+ {
5133+ // we still process the root and the nearest local reference here
5134+ // update the number of descendants
5135+ corrected_num_descendants[traverse_parent_vec_index] += num_new_descendant;
5136+
5137+ // stop if reaching the nearest local reference
5138+ if (node_mutations[traverse_parent_vec_index])
5139+ break ;
5140+
5141+ PhyloNode& traverse_node = nodes[traverse_parent_vec_index];
5142+
5143+ // make the internal node a new new local ref node, if it meets the requirements
5144+ if (corrected_num_descendants[traverse_parent_vec_index] >= params->max_desc_ref
5145+ && traverse_node.getPartialLh (TOP)->containAtLeastNMuts <num_states>(params->min_mut_ref ))
5146+ {
5147+ // make the internal node a new new local ref node
5148+
5149+ // stop traversing further
5150+ break ;
5151+ }
5152+
5153+
5154+ // move upward
5155+ traverse_parent_index = traverse_node.getNeighborIndex (TOP);
5156+ traverse_parent_vec_index = traverse_parent_index.getVectorIndex ();
5157+
5158+
5159+ }
51345160 }
51355161
51365162 // NHANLT: LOGS FOR DEBUGGING
@@ -11380,3 +11406,109 @@ void Tree::expandVectorsAfterTreeExpansion()
1138011406 sprta_alt_branches.resize (num_nodes);
1138111407 sprta_support_list.resize (num_nodes);
1138211408}
11409+
11410+ template <const StateType num_states>
11411+ auto cmaple::Tree::makeReferenceNode (PhyloNode& node, const cmaple::Index node_index, const int old_num_desc) -> void
11412+ {
11413+ // dummy variables
11414+ const PositionType seq_length = static_cast <PositionType>(aln->ref_seq .size ());
11415+ const RealNumType threshold_prob = params->threshold_prob ;
11416+
11417+ // 1. traverse upward, reduce the number of descendants
11418+ cmaple::Index traverse_parent_index = node.getNeighborIndex (TOP);
11419+ while (traverse_parent_index.getMiniIndex () != UNDEFINED)
11420+ {
11421+ const NumSeqsType& traverse_parent_vec_index = traverse_parent_index.getVectorIndex ();
11422+ corrected_num_descendants[traverse_parent_vec_index] -= old_num_desc;
11423+
11424+ // stop if reaching a local reference
11425+ if (node_mutations[traverse_parent_vec_index])
11426+ break ;
11427+
11428+ // move upward
11429+ PhyloNode& traverse_parent_node = nodes[traverse_parent_vec_index];
11430+ traverse_parent_index = traverse_parent_node.getNeighborIndex (TOP);
11431+ }
11432+
11433+ // 2. define mutations at node
11434+ const NumSeqsType& node_vec_index = node_index.getVectorIndex ();
11435+ node_mutations[node_vec_index] = cmaple::make_unique<SeqRegions>();
11436+ std::unique_ptr<SeqRegions>& this_node_mutations = node_mutations[node_vec_index];
11437+ std::unique_ptr<SeqRegions>& lower_regions = node.getPartialLh (TOP);
11438+ // loop over the lower regions of this node
11439+ // extract mutations and at R regions if needed
11440+ // loop over the vector of regions
11441+ for (auto i = 0 ; i < lower_regions->size (); ++i)
11442+ {
11443+ const SeqRegion& seq_region = lower_regions->at (i);
11444+
11445+ // only handle mutations
11446+ if (seq_region.type < num_states)
11447+ {
11448+ // add an R region if needed
11449+ PositionType prev_region_pos = seq_region.position ;
11450+ if (prev_region_pos > 0 )
11451+ {
11452+ --prev_region_pos;
11453+
11454+ if (!this_node_mutations->size ()
11455+ || this_node_mutations->back ().position < prev_region_pos)
11456+ {
11457+ this_node_mutations->push_back (SeqRegion (TYPE_R, prev_region_pos));
11458+ }
11459+ }
11460+
11461+ // add this mutation
11462+ this_node_mutations->push_back (SeqRegion::clone (seq_region));
11463+ }
11464+ }
11465+ // add the last R region if needed
11466+ if (!this_node_mutations->size ()
11467+ || this_node_mutations->back ().position < seq_length - 1 )
11468+ {
11469+ this_node_mutations->push_back (SeqRegion (TYPE_R, seq_length - 1 ));
11470+ }
11471+
11472+ // 3. update the lh vectors of this node
11473+ node.integrateMutAllRegions <num_states>(this_node_mutations, aln);
11474+ // NHAN added: total lh at the new internal node
11475+ // recompute the total lh vec
11476+ const NumSeqsType parent_vec_index = node.getNeighborIndex (TOP).getVectorIndex ();
11477+ node.computeTotalLhAtNode <num_states>(node.getTotalLh (), nodes[parent_vec_index], aln,
11478+ model, threshold_prob, root_vector_index == node_vec_index);
11479+
11480+ // 4. Traverse downward to integrate the mutations to descendant nodes
11481+ assert (node.isInternal ());
11482+ stack<NumSeqsType> node_stack;
11483+ node_stack.push (node.getNeighborIndex (LEFT).getVectorIndex ());
11484+ node_stack.push (node.getNeighborIndex (RIGHT).getVectorIndex ());
11485+ while (!node_stack.empty ()) {
11486+ // extract the corresponding node
11487+ const NumSeqsType child_node_vec_index = node_stack.top ();
11488+ node_stack.pop ();
11489+ PhyloNode& child_node = nodes[child_node_vec_index];
11490+
11491+ // if child node is also a local reference node
11492+ if (node_mutations[child_node_vec_index])
11493+ {
11494+ // TODO merge these two lists of mutations
11495+
11496+ }
11497+ // otherwise, simply integrate the mutations from the new reference node
11498+ else
11499+ {
11500+ // update the lh vectors of this child node
11501+ child_node.integrateMutAllRegions <num_states>(this_node_mutations, aln);
11502+ // NHAN added: total lh at the new internal node
11503+ // recompute the total lh vec
11504+ const NumSeqsType parent_child_vec_index = child_node.getNeighborIndex (TOP).getVectorIndex ();
11505+ child_node.computeTotalLhAtNode <num_states>(child_node.getTotalLh (),
11506+ nodes[parent_child_vec_index], aln, model, threshold_prob,
11507+ root_vector_index == child_node_vec_index);
11508+
11509+ // keep traversing downward
11510+ node_stack.push (child_node.getNeighborIndex (LEFT).getVectorIndex ());
11511+ node_stack.push (child_node.getNeighborIndex (RIGHT).getVectorIndex ());
11512+ }
11513+ }
11514+ }
0 commit comments