Skip to content

Commit 6f475a1

Browse files
committed
make a local reference at a node
1 parent 1e73b88 commit 6f475a1

File tree

1 file changed

+153
-21
lines changed

1 file changed

+153
-21
lines changed

tree/tree.cpp

Lines changed: 153 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5096,22 +5096,16 @@ void cmaple::Tree::connectNewSample2Branch(
50965096
->integrateMutations<num_states>(sibling_node_mutations, aln, true)));
50975097
}
50985098

5099-
// lower lh vec at the new internal node
5100-
internal.setPartialLh(TOP, std::move(internal.getPartialLh(TOP)
5101-
->integrateMutations<num_states>(sibling_node_mutations, aln, true)));
5099+
// all lh vectors/regions at the internal node
5100+
internal.integrateMutAllRegions<num_states>(sibling_node_mutations, aln, true);
51025101

5103-
// upper left/right lh vec at the internal node
5104-
internal.setPartialLh(LEFT, std::move(internal.getPartialLh(LEFT)
5105-
->integrateMutations<num_states>(sibling_node_mutations, aln, true)));
5106-
internal.setPartialLh(RIGHT, std::move(internal.getPartialLh(RIGHT)
5107-
->integrateMutations<num_states>(sibling_node_mutations, aln, true)));
5102+
// NHAN added: total lh at the new internal node
5103+
// recompute the total lh vec
5104+
internal.computeTotalLhAtNode<num_states>(internal.getTotalLh(), parent_node, aln,
5105+
model, threshold_prob, root_vector_index == internal_vec_index);
51085106

5109-
// mid-branch lh vec at the internal node
5110-
if (internal.getMidBranchLh())
5111-
{
5112-
internal.setMidBranchLh(std::move(internal.getMidBranchLh()
5113-
->integrateMutations<num_states>(sibling_node_mutations, aln, true)));
5114-
}
5107+
// don't need to de-integrate the mutations since it has been
5108+
// recomputed on the correct lh vectors
51155109

51165110
// NHAN added: total lh vec must be computed after others
51175111
// total lh vec at the new sample
@@ -5124,13 +5118,45 @@ void cmaple::Tree::connectNewSample2Branch(
51245118
// don't need to de-integrate the mutations since it's has been
51255119
// recomputed on the correct lh vectors
51265120
}
5127-
// total lh vec at the new internal node
5128-
// recompute the total lh vec
5129-
internal.computeTotalLhAtNode<num_states>(
5130-
internal.getTotalLh(), parent_node, aln, model, threshold_prob,
5131-
root_vector_index == internal_vec_index);
5132-
// don't need to de-integrate the mutations since it's has been
5133-
// recomputed on the correct lh vectors
5121+
}
5122+
5123+
// traverse upward to update the number of descendants
5124+
assert(num_new_descendant >= 0);
5125+
if (num_new_descendant)
5126+
{
5127+
Index traverse_parent_index = internal.getNeighborIndex(TOP);
5128+
NumSeqsType traverse_parent_vec_index = traverse_parent_index.getVectorIndex();
5129+
5130+
// traverse upward until going beyond the root
5131+
while (traverse_parent_index.getMiniIndex() != UNDEFINED)
5132+
{
5133+
// we still process the root and the nearest local reference here
5134+
// update the number of descendants
5135+
corrected_num_descendants[traverse_parent_vec_index] += num_new_descendant;
5136+
5137+
// stop if reaching the nearest local reference
5138+
if (node_mutations[traverse_parent_vec_index])
5139+
break;
5140+
5141+
PhyloNode& traverse_node = nodes[traverse_parent_vec_index];
5142+
5143+
// make the internal node a new local ref node, if it meets the requirements
5144+
if (corrected_num_descendants[traverse_parent_vec_index] >= params->max_desc_ref
5145+
&& traverse_node.getPartialLh(TOP)->containAtLeastNMuts<num_states>(params->min_mut_ref))
5146+
{
5147+
// make the internal node a new local ref node
5148+
5149+
// stop traversing further
5150+
break;
5151+
}
5152+
5153+
5154+
// move upward
5155+
traverse_parent_index = traverse_node.getNeighborIndex(TOP);
5156+
traverse_parent_vec_index = traverse_parent_index.getVectorIndex();
5157+
5158+
5159+
}
51345160
}
51355161

51365162
// NHANLT: LOGS FOR DEBUGGING
@@ -11380,3 +11406,109 @@ void Tree::expandVectorsAfterTreeExpansion()
1138011406
sprta_alt_branches.resize(num_nodes);
1138111407
sprta_support_list.resize(num_nodes);
1138211408
}
11409+
11410+
/**
 * Turn an internal node into a local reference node.
 *
 * Steps:
 *  1. Walk upward from the node's TOP neighbor and subtract `old_num_desc`
 *     from `corrected_num_descendants` of every visited ancestor. The walk
 *     stops after the first ancestor that has an entry in `node_mutations`
 *     (that ancestor is still updated), or when the parent index becomes
 *     UNDEFINED.
 *  2. Build `node_mutations[node]` from the node's TOP partial lh regions:
 *     every region whose type is `< num_states` is cloned, and R regions are
 *     inserted so that the list covers positions up to `seq_length - 1`.
 *  3. Integrate these mutations into all regions of the node and recompute
 *     its total lh.
 *  4. Traverse downward and integrate the same mutations into every
 *     descendant, recomputing its total lh. Descendants that already have an
 *     entry in `node_mutations` are not modified and not descended into
 *     (merging is still a TODO).
 *
 * @tparam num_states  regions with `type < num_states` are treated as mutations
 * @param node         the node to convert; must be internal
 * @param node_index   index of `node`; its vector index addresses `node_mutations`
 * @param old_num_desc amount subtracted from the ancestors' descendant counts
 */
template <const StateType num_states>
auto cmaple::Tree::makeReferenceNode(PhyloNode& node, const cmaple::Index node_index, const int old_num_desc) -> void
{
    // LEFT/RIGHT neighbors of this node are accessed in step 4, so validate
    // the precondition before modifying any state
    assert(node.isInternal());

    // dummy variables
    const PositionType seq_length = static_cast<PositionType>(aln->ref_seq.size());
    const RealNumType threshold_prob = params->threshold_prob;

    // 1. traverse upward, reduce the number of descendants
    cmaple::Index traverse_parent_index = node.getNeighborIndex(TOP);
    while (traverse_parent_index.getMiniIndex() != UNDEFINED)
    {
        const NumSeqsType traverse_parent_vec_index = traverse_parent_index.getVectorIndex();
        corrected_num_descendants[traverse_parent_vec_index] -= old_num_desc;

        // stop if reaching a local reference (it has been updated above)
        if (node_mutations[traverse_parent_vec_index])
            break;

        // move upward
        traverse_parent_index = nodes[traverse_parent_vec_index].getNeighborIndex(TOP);
    }

    // 2. define mutations at node
    const NumSeqsType node_vec_index = node_index.getVectorIndex();
    node_mutations[node_vec_index] = cmaple::make_unique<SeqRegions>();
    std::unique_ptr<SeqRegions>& this_node_mutations = node_mutations[node_vec_index];
    std::unique_ptr<SeqRegions>& lower_regions = node.getPartialLh(TOP);

    // loop over the lower regions of this node,
    // extract mutations and add R regions if needed
    for (std::size_t i = 0; i < lower_regions->size(); ++i)
    {
        const SeqRegion& seq_region = lower_regions->at(i);

        // only handle mutations
        if (!(seq_region.type < num_states))
            continue;

        // add an R region covering the gap before this mutation, if any
        if (seq_region.position > 0)
        {
            const PositionType prev_region_pos = seq_region.position - 1;

            if (!this_node_mutations->size()
                || this_node_mutations->back().position < prev_region_pos)
            {
                this_node_mutations->push_back(SeqRegion(TYPE_R, prev_region_pos));
            }
        }

        // add this mutation
        this_node_mutations->push_back(SeqRegion::clone(seq_region));
    }

    // add the last R region if needed
    if (!this_node_mutations->size()
        || this_node_mutations->back().position < seq_length - 1)
    {
        this_node_mutations->push_back(SeqRegion(TYPE_R, seq_length - 1));
    }

    // 3. update the lh vectors of this node
    node.integrateMutAllRegions<num_states>(this_node_mutations, aln);

    // recompute the total lh vec at this node
    // NOTE(review): if `node` is the root, its TOP index is UNDEFINED (see the
    // loop in step 1); this assumes getVectorIndex() still yields a valid
    // index into `nodes` in that case - TODO confirm
    const NumSeqsType parent_vec_index = node.getNeighborIndex(TOP).getVectorIndex();
    node.computeTotalLhAtNode<num_states>(node.getTotalLh(), nodes[parent_vec_index], aln,
        model, threshold_prob, root_vector_index == node_vec_index);

    // 4. traverse downward to integrate the mutations to descendant nodes
    std::stack<NumSeqsType> node_stack;
    node_stack.push(node.getNeighborIndex(LEFT).getVectorIndex());
    node_stack.push(node.getNeighborIndex(RIGHT).getVectorIndex());
    while (!node_stack.empty())
    {
        // extract the corresponding node
        const NumSeqsType child_node_vec_index = node_stack.top();
        node_stack.pop();
        PhyloNode& child_node = nodes[child_node_vec_index];

        // if child node is also a local reference node
        if (node_mutations[child_node_vec_index])
        {
            // TODO merge these two lists of mutations
            continue;
        }

        // otherwise, simply integrate the mutations from the new reference node
        child_node.integrateMutAllRegions<num_states>(this_node_mutations, aln);

        // recompute the total lh vec at this child node
        const NumSeqsType parent_child_vec_index = child_node.getNeighborIndex(TOP).getVectorIndex();
        child_node.computeTotalLhAtNode<num_states>(child_node.getTotalLh(),
            nodes[parent_child_vec_index], aln, model, threshold_prob,
            root_vector_index == child_node_vec_index);

        // keep traversing downward; only internal nodes have LEFT/RIGHT
        // children, and without this guard the traversal would never
        // terminate at the leaves
        if (child_node.isInternal())
        {
            node_stack.push(child_node.getNeighborIndex(LEFT).getVectorIndex());
            node_stack.push(child_node.getNeighborIndex(RIGHT).getVectorIndex());
        }
    }
}

0 commit comments

Comments
 (0)