Skip to content

Commit e5d7693

Browse files
committed
Rate variation update
Avoid many small allocations of memory and perform one large one instead.
1 parent 9e5f9ab commit e5d7693

File tree

2 files changed

+40
-52
lines changed

2 files changed

+40
-52
lines changed

model/model_dna_rate_variation.cpp

Lines changed: 39 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,11 @@ void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) {
129129

130130
void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
131131
std::cout << "Estimating mutation rate per site..." << std::endl;
132-
RealNumType** waitingTimes = new RealNumType*[num_states_];
133-
for(int j = 0; j < num_states_; j++) {
134-
waitingTimes[j] = new RealNumType[genomeSize];
135-
}
132+
RealNumType* waitingTimes = new RealNumType[num_states_ * genomeSize];
136133
RealNumType* numSubstitutions = new RealNumType[genomeSize];
137134
for(int i = 0; i < genomeSize; i++) {
138135
for(int j = 0; j < num_states_; j++) {
139-
waitingTimes[j][i] = 0;
136+
waitingTimes[i * num_states_ + j] = 0;
140137
}
141138
numSubstitutions[i] = 0;
142139
}
@@ -183,12 +180,12 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
183180
// both states are type REF
184181
for(int i = pos; i <= end_pos; i++) {
185182
int state = tree->aln->ref_seq[static_cast<std::vector<cmaple::StateType>::size_type>(i)];
186-
waitingTimes[state][i] += blength;
183+
waitingTimes[i * num_states_ + state] += blength;
187184
}
188185
} else if(seq1_region->type == seq2_region->type && seq1_region->type < TYPE_R) {
189186
// both states are equal but not of type REF
190187
for(int i = pos; i <= end_pos; i++) {
191-
waitingTimes[seq1_region->type][i] += blength;
188+
waitingTimes[i * num_states_ + seq1_region->type] += blength;
192189
}
193190
} else if(seq1_region->type <= TYPE_R && seq2_region->type <= TYPE_R) {
194191
// both states are not equal
@@ -207,7 +204,7 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
207204
} else {
208205
RealNumType expectedRateNoSubstitution = 0;
209206
for(int j = 0; j < num_states_; j++) {
210-
RealNumType summand = waitingTimes[j][i] * abs(diagonal_mut_mat[j]);
207+
RealNumType summand = waitingTimes[i * num_states_ + j] * abs(diagonal_mut_mat[j]);
211208
expectedRateNoSubstitution += summand;
212209
}
213210
if(expectedRateNoSubstitution <= 0.01) {
@@ -247,24 +244,19 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
247244
}
248245
}
249246

250-
for(int j = 0; j < num_states_; j++) {
251-
delete[] waitingTimes[j];
252-
}
253247
delete[] waitingTimes;
254248
delete[] numSubstitutions;
255249
}
256250

257251
void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
258252

259-
RealNumType** C = new RealNumType*[genomeSize];
260-
RealNumType** W = new RealNumType*[genomeSize];
253+
RealNumType* C = new RealNumType[genomeSize * matSize];
254+
RealNumType* W = new RealNumType[genomeSize * num_states_];
261255
for(int i = 0; i < genomeSize; i++) {
262-
C[i] = new RealNumType[matSize];
263-
W[i] = new RealNumType[num_states_];
264256
for(int j = 0; j < num_states_; j++) {
265-
W[i][j] = 0;
257+
W[i * num_states_ + j] = 0;
266258
for(int k = 0; k < num_states_; k++) {
267-
C[i][row_index[j] + k] = 0;
259+
C[i * num_states_ + row_index[j] + k] = 0;
268260
}
269261
}
270262
}
@@ -325,12 +317,12 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
325317
// both states are type REF
326318
for(int i = pos; i <= end_pos; i++) {
327319
StateType state = tree->aln->ref_seq[static_cast<std::vector<cmaple::StateType>::size_type>(i)];
328-
W[i][state] += branchLengthToObservation;
320+
W[i * num_states_ + state] += branchLengthToObservation;
329321
}
330322
} else if(seqP_region->type == seqC_region->type && seqP_region->type < TYPE_R) {
331323
// both states are equal but not of type REF
332324
for(int i = pos; i <= end_pos; i++) {
333-
W[i][seqP_region->type] += branchLengthToObservation;
325+
W[i * num_states_ + seqP_region->type] += branchLengthToObservation;
334326
}
335327
} else if(seqP_region->type <= TYPE_R && seqC_region->type <= TYPE_R) {
336328
//states are not equal
@@ -345,9 +337,9 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
345337
// Case 1: Last observation was this side of the root node
346338
if(seqP_region->plength_observation2root <= 0) {
347339
for(int i = pos; i <= end_pos; i++) {
348-
W[i][stateA] += branchLengthToObservation/2;
349-
W[i][stateB] += branchLengthToObservation/2;
350-
C[i][stateB + row_index[stateA]] += 1;
340+
W[i * num_states_ + stateA] += branchLengthToObservation/2;
341+
W[i * num_states_ + stateB] += branchLengthToObservation/2;
342+
C[i * matSize + stateB + row_index[stateA]] += 1;
351343
}
352344
} else {
353345
// Case 2: Last observation was the other side of the root.
@@ -387,12 +379,12 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
387379
for(StateType stateB = 0; stateB < num_states_; stateB++) {
388380
RealNumType prob = weightVector[stateB];
389381
if(stateB != stateA) {
390-
C[end_pos][stateB + row_index[stateA]] += prob;
382+
C[end_pos * matSize + stateB + row_index[stateA]] += prob;
391383

392-
W[end_pos][stateA] += prob * branchLengthToObservation/2;
393-
W[end_pos][stateB] += prob * branchLengthToObservation/2;
384+
W[end_pos * num_states_ + stateA] += prob * branchLengthToObservation/2;
385+
W[end_pos * num_states_ + stateB] += prob * branchLengthToObservation/2;
394386
} else {
395-
W[end_pos][stateA] += prob * branchLengthToObservation;
387+
W[end_pos * num_states_ + stateA] += prob * branchLengthToObservation;
396388
}
397389
}
398390
} else {
@@ -434,12 +426,12 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
434426
for(StateType stateA = 0; stateA < num_states_; stateA++) {
435427
RealNumType prob = weightVector[stateA];
436428
if(stateB != stateA) {
437-
C[end_pos][stateB + row_index[stateA]] += prob;
429+
C[end_pos * matSize + stateB + row_index[stateA]] += prob;
438430

439-
W[end_pos][stateA] += prob * branchLengthToObservation/2;
440-
W[end_pos][stateB] += prob * branchLengthToObservation/2;
431+
W[end_pos * num_states_ + stateA] += prob * branchLengthToObservation/2;
432+
W[end_pos * num_states_ + stateB] += prob * branchLengthToObservation/2;
441433
} else {
442-
W[end_pos][stateA] += prob * branchLengthToObservation;
434+
W[end_pos * num_states_ + stateA] += prob * branchLengthToObservation;
443435
}
444436
}
445437
} else {
@@ -480,12 +472,12 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
480472
for(StateType stateB = 0; stateB < num_states_; stateB++) {
481473
RealNumType prob = weightVector[row_index[stateA] + stateB];
482474
if(stateB != stateA) {
483-
C[end_pos][stateB + row_index[stateA]] += prob;
475+
C[end_pos * matSize + stateB + row_index[stateA]] += prob;
484476

485-
W[end_pos][stateA] += prob * branchLengthToObservation/2;
486-
W[end_pos][stateB] += prob * branchLengthToObservation/2;
477+
W[end_pos * num_states_ + stateA] += prob * branchLengthToObservation/2;
478+
W[end_pos * num_states_ + stateB] += prob * branchLengthToObservation/2;
487479
} else {
488-
W[end_pos][stateA] += prob * branchLengthToObservation;
480+
W[end_pos * num_states_ + stateA] += prob * branchLengthToObservation;
489481
}
490482
}
491483
}
@@ -513,7 +505,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
513505
for(int i = 0; i < genomeSize; i++) {
514506
outFile << "Position: " << i << std::endl;
515507
outFile << "Count Matrix: " << std::endl;
516-
printCountsAndWaitingTimes(C[i], W[i], &outFile);
508+
printCountsAndWaitingTimes(C + (i * matSize), W + (i * num_states_), &outFile);
517509
outFile << std::endl;
518510
}
519511
outFile.close();
@@ -531,9 +523,9 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
531523

532524
for(int i = 0; i < genomeSize; i++) {
533525
for(int j = 0; j < num_states_; j++) {
534-
globalWaitingTimes[j] += W[i][j];
526+
globalWaitingTimes[j] += W[i * num_states_ + j];
535527
for(int k = 0; k < num_states_; k++) {
536-
globalCounts[row_index[j] + k] += C[i][row_index[j] + k];
528+
globalCounts[row_index[j] + k] += C[i * matSize + row_index[j] + k];
537529
}
538530
}
539531
}
@@ -549,10 +541,10 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
549541
for(int i = 0; i < genomeSize; i++) {
550542
for(int j = 0; j < num_states_; j++) {
551543
// Add pseudocounts to waitingTimes
552-
W[i][j] += waitingTimePseudoCount;
544+
W[i * num_states_ + j] += waitingTimePseudoCount;
553545
for(int k = 0; k < num_states_; k++) {
554546
// Add pseudocount of average rate across genome * waitingTime pseudocount for counts
555-
C[i][row_index[j] + k] += globalCounts[row_index[j] + k] * waitingTimePseudoCount / globalWaitingTimes[j];
547+
C[i * matSize + row_index[j] + k] += globalCounts[row_index[j] + k] * waitingTimePseudoCount / globalWaitingTimes[j];
556548
}
557549
}
558550
}
@@ -589,8 +581,8 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
589581
RealNumType totalRate = 0;
590582
// Update mutation matrices with new rate estimation
591583
for(int i = 0; i < genomeSize; i++) {
592-
RealNumType* Ci = C[i];
593-
RealNumType* Wi = W[i];
584+
RealNumType* Ci = C + (i * matSize);
585+
RealNumType* Wi = W + (i * num_states_);
594586
StateType refState = tree->aln->ref_seq[static_cast<std::vector<cmaple::StateType>::size_type>(i)];
595587

596588
for(int stateA = 0; stateA < num_states_; stateA++) {
@@ -642,35 +634,31 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
642634
}
643635

644636
// Clean-up
645-
for(int i = 0; i < genomeSize; i++) {
646-
delete[] C[i];
647-
delete[] W[i];
648-
}
649637
delete[] C;
650638
delete[] W;
651639
}
652640

653641
void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end,
654642
StateType parentState, StateType childState,
655643
RealNumType distToRoot, RealNumType distToObserved,
656-
RealNumType** waitingTimes, RealNumType** counts,
644+
RealNumType* waitingTimes, RealNumType* counts,
657645
RealNumType weight)
658646
{
659647
if(parentState != childState) {
660648
for(int i = start; i <= end; i++) {
661649
RealNumType pRootIsStateParent = root_freqs[parentState] * getMutationMatrixEntry(parentState, childState, i) * distToRoot;
662650
RealNumType pRootIsStateChild = root_freqs[childState] * getMutationMatrixEntry(childState, parentState, i) * distToObserved;
663651
RealNumType relativeRootIsStateParent = pRootIsStateParent / (pRootIsStateParent + pRootIsStateChild);
664-
waitingTimes[i][parentState] += weight * relativeRootIsStateParent * distToRoot/2;
665-
waitingTimes[i][childState] += weight * relativeRootIsStateParent * distToRoot/2;
666-
counts[i][childState + row_index[parentState]] += weight * relativeRootIsStateParent;
652+
waitingTimes[i * num_states_ + parentState] += weight * relativeRootIsStateParent * distToRoot/2;
653+
waitingTimes[i * num_states_ + childState] += weight * relativeRootIsStateParent * distToRoot/2;
654+
counts[i * matSize + childState + row_index[parentState]] += weight * relativeRootIsStateParent;
667655

668656
RealNumType relativeRootIsStateChild = 1 - relativeRootIsStateParent;
669-
waitingTimes[i][childState] += weight * relativeRootIsStateChild * distToRoot;
657+
waitingTimes[i * num_states_ + childState] += weight * relativeRootIsStateChild * distToRoot;
670658
}
671659
} else {
672660
for(int i = start; i <= end; i++) {
673-
waitingTimes[i][childState] += weight * distToRoot;
661+
waitingTimes[i * num_states_ + childState] += weight * distToRoot;
674662
}
675663
}
676664
}

model/model_dna_rate_variation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class ModelDNARateVariation : public ModelDNA {
8282
void updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end,
8383
StateType parentState, StateType childState,
8484
RealNumType distToRoot, RealNumType distToObserved,
85-
RealNumType** waitingTimes, RealNumType** counts,
85+
RealNumType* waitingTimes, RealNumType* counts,
8686
RealNumType weight = 1.);
8787

8888
void readRatesFile();

0 commit comments

Comments
 (0)