Skip to content

Commit 8ec258e

Browse files
authored
perf: optimize em matrix build performance and memory usage
1 parent 0f4551d commit 8ec258e

File tree

2 files changed

+24
-68
lines changed

2 files changed

+24
-68
lines changed

src/matrix.rs

Lines changed: 22 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::sam::{extract_alignment_score, SamReader};
22
use crate::{MultiMappingReads, PathoscopeError, UniqueReads};
33
use log::info;
4-
use rustc_hash::FxHashMap;
4+
use rustc_hash::{FxHashMap, FxHashSet};
55

66
/// A matrix containing alignment data and metadata.
77
///
@@ -34,13 +34,12 @@ impl PathoscopeMatrix {
3434
fn rescale_scores(&mut self, u_temp: MultiMappingReads) {
3535
let (u, nu) = rescale_samscore(
3636
u_temp,
37-
self.multi_mapping_reads.clone(),
37+
&mut self.multi_mapping_reads,
3838
self.max_score,
3939
self.min_score,
4040
);
4141
self.multi_mapping_reads = nu;
4242

43-
// Convert u to unique_reads format and store for build_unique_map
4443
self.unique_reads = FxHashMap::default();
4544
for (read_idx, (ref_indices, scores, _, _)) in u {
4645
if let (Some(first_ref), Some(first_score)) =
@@ -52,15 +51,14 @@ impl PathoscopeMatrix {
5251
}
5352
}
5453

55-
/// Build the final unique reads map (placeholder as rescale_scores now handles this)
56-
fn build_unique_map(&mut self) {
57-
// This is now handled in rescale_scores method
58-
}
59-
6054
/// Normalize multi-mapping read scores so they sum to 1.0
6155
fn normalize_multi_mapping(&mut self) {
62-
let nu_keys: Vec<i32> = self.multi_mapping_reads.keys().cloned().collect();
63-
for k in nu_keys {
56+
for k in self
57+
.multi_mapping_reads
58+
.keys()
59+
.cloned()
60+
.collect::<Vec<i32>>()
61+
{
6462
if let Some(nu_entry) = self.multi_mapping_reads.get(&k) {
6563
let p_score_sum = nu_entry.1.iter().sum::<f64>();
6664
if let Some(nu_entry_mut) = self.multi_mapping_reads.get_mut(&k) {
@@ -86,15 +84,12 @@ impl PathoscopeMatrix {
8684
read_index: i32,
8785
ref_index: i32,
8886
score: f64,
89-
read_alignments: &mut FxHashMap<i32, Vec<(i32, f64)>>,
87+
read_alignments: &mut FxHashMap<i32, (FxHashSet<i32>, Vec<(i32, f64)>)>,
9088
) {
91-
// Update score range
9289
self.max_score = self.max_score.max(score);
9390
self.min_score = self.min_score.min(score);
94-
95-
// Add to read_alignments map (skip duplicates)
96-
let alignments = read_alignments.entry(read_index).or_default();
97-
if !alignments.iter().any(|(ref_idx, _)| *ref_idx == ref_index) {
91+
let (seen_refs, alignments) = read_alignments.entry(read_index).or_default();
92+
if seen_refs.insert(ref_index) {
9893
alignments.push((ref_index, score));
9994
}
10095
}
@@ -106,21 +101,21 @@ impl PathoscopeMatrix {
106101
///
107102
/// # Arguments
108103
/// * `read_alignments` - Map of read indices to their alignment data
109-
pub fn finalize(&mut self, read_alignments: FxHashMap<i32, Vec<(i32, f64)>>) {
110-
// Classify reads as unique or multi-mapping
104+
pub fn finalize(
105+
&mut self,
106+
read_alignments: FxHashMap<i32, (FxHashSet<i32>, Vec<(i32, f64)>)>,
107+
) {
111108
let mut u_temp: MultiMappingReads = FxHashMap::default();
112109
let mut nu: MultiMappingReads = FxHashMap::default();
113110

114-
for (read_index, read_alignments) in read_alignments {
111+
for (read_index, (_, read_alignments)) in read_alignments {
115112
if read_alignments.len() == 1 {
116-
// Unique read: maps to exactly one reference
117113
let (ref_index, score) = read_alignments[0];
118114
u_temp.insert(
119115
read_index,
120116
(vec![ref_index], vec![score], vec![score], score),
121117
);
122118
} else {
123-
// Non-unique read: maps to multiple references
124119
let ref_indices: Vec<i32> = read_alignments
125120
.iter()
126121
.map(|(ref_idx, _)| *ref_idx)
@@ -136,9 +131,7 @@ impl PathoscopeMatrix {
136131

137132
self.multi_mapping_reads = nu;
138133

139-
// Rescale scores and build final data structures
140134
self.rescale_scores(u_temp);
141-
self.build_unique_map();
142135
self.normalize_multi_mapping();
143136

144137
let unique_count = self.unique_reads.len();
@@ -169,67 +162,46 @@ pub fn build_matrix(
169162
let mut reader = SamReader::new(alignment_path)?;
170163
let header = reader.header().clone();
171164

172-
// Initialize matrix for incremental building
173165
let mut matrix = PathoscopeMatrix::new();
174-
175-
// Tracking variables
176166
let mut h_read_id: FxHashMap<String, i32> = FxHashMap::default();
177167
let mut h_ref_id: FxHashMap<String, i32> = FxHashMap::default();
178168
let mut ref_count: i32 = 0;
179169
let mut read_count: i32 = 0;
180-
let mut read_alignments: FxHashMap<i32, Vec<(i32, f64)>> = FxHashMap::default();
170+
let mut read_alignments: FxHashMap<i32, (FxHashSet<i32>, Vec<(i32, f64)>)> =
171+
FxHashMap::default();
181172

182-
// Stream through BAM file and build matrix incrementally
183173
reader.stream_chunks(|chunk| {
184-
// Process this chunk
185174
for record in chunk {
186-
// Skip unmapped reads
187175
if record.is_unmapped() {
188176
continue;
189177
}
190-
191-
// Get read ID (qname)
192178
let read_id = match std::str::from_utf8(record.qname()) {
193179
Ok(id) => id.to_string(),
194180
Err(_) => continue,
195181
};
196-
197-
// Get reference name
198182
let name_bytes = header.tid2name(record.tid() as u32);
199183
let ref_id = std::str::from_utf8(name_bytes).unwrap_or("*").to_string();
200-
201-
// Get alignment score using shared function
202184
let total_score = match extract_alignment_score(record) {
203185
Some(score) => score,
204186
None => continue,
205187
};
206-
207-
// Apply score cutoff
208188
if total_score <= p_score_cutoff {
209189
continue;
210190
}
211-
212-
// Track score range
213191
matrix.min_score = total_score.min(matrix.min_score);
214192
matrix.max_score = total_score.max(matrix.max_score);
215-
216-
// Get or create reference index
217193
let ref_index = *h_ref_id.entry(ref_id.clone()).or_insert_with(|| {
218194
let idx = ref_count;
219195
matrix.refs.push(ref_id);
220196
ref_count += 1;
221197
idx
222198
});
223-
224-
// Get or create read index
225199
let read_index = *h_read_id.entry(read_id.clone()).or_insert_with(|| {
226200
let idx = read_count;
227201
matrix.reads.push(read_id);
228202
read_count += 1;
229203
idx
230204
});
231-
232-
// Add alignment to matrix incrementally
233205
matrix.add_alignment(
234206
read_index,
235207
ref_index,
@@ -240,7 +212,6 @@ pub fn build_matrix(
240212
Ok(())
241213
})?;
242214

243-
// Finalize matrix after all alignments are processed
244215
matrix.finalize(read_alignments);
245216

246217
Ok(matrix)
@@ -249,7 +220,7 @@ pub fn build_matrix(
249220
/// modifies the scores of u and nu with respect to max_score and min_score
250221
fn rescale_samscore(
251222
mut u: MultiMappingReads,
252-
mut nu: MultiMappingReads,
223+
nu: &mut MultiMappingReads,
253224
max_score: f64,
254225
min_score: f64,
255226
) -> (MultiMappingReads, MultiMappingReads) {
@@ -259,8 +230,7 @@ fn rescale_samscore(
259230
100.0 / max_score
260231
};
261232

262-
let u_keys: Vec<i32> = u.keys().cloned().collect();
263-
for k in u_keys {
233+
for k in u.keys().cloned().collect::<Vec<i32>>() {
264234
if let Some(entry) = u.get_mut(&k) {
265235
if min_score < 0.0 {
266236
entry.1[0] -= min_score;
@@ -270,8 +240,7 @@ fn rescale_samscore(
270240
}
271241
}
272242

273-
let nu_keys: Vec<i32> = nu.keys().cloned().collect();
274-
for k in nu_keys {
243+
for k in nu.keys().cloned().collect::<Vec<i32>>() {
275244
if let Some(entry) = nu.get_mut(&k) {
276245
entry.3 = 0.0;
277246

@@ -288,22 +257,12 @@ fn rescale_samscore(
288257
}
289258
}
290259
}
291-
(u, nu)
260+
(u, std::mem::take(nu))
292261
}
293262

294263
#[cfg(test)]
295264
mod tests {
296-
#![allow(unused)]
297-
298265
use crate::matrix::*;
299-
use crate::*;
300-
use std::fs::File;
301-
use std::io::BufRead;
302-
use std::io::BufReader;
303-
use std::io::Read;
304-
305-
extern crate yaml_rust;
306-
use yaml_rust::{YamlEmitter, YamlLoader};
307266

308267
#[test]
309268
fn test_build_matrix() {

src/subtraction.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,7 @@ pub fn process_isolate_file(
131131
let mut writer = bam::Writer::from_path(output_path, &header, Format::Bam)
132132
.map_err(PathoscopeError::Htslib)?;
133133

134-
writer
135-
.set_threads(proc)
136-
.map_err(PathoscopeError::Htslib)?;
134+
writer.set_threads(proc).map_err(PathoscopeError::Htslib)?;
137135

138136
let mut all_subtracted_read_ids = HashSet::new();
139137
let mut write_buffer: Vec<bam::Record> = Vec::with_capacity(CHUNK_SIZE);
@@ -147,8 +145,7 @@ pub fn process_isolate_file(
147145
continue;
148146
}
149147

150-
let read_id_str =
151-
unsafe { std::str::from_utf8_unchecked(record.qname()) };
148+
let read_id_str = unsafe { std::str::from_utf8_unchecked(record.qname()) };
152149

153150
if record.tid() < 0 {
154151
continue;

0 commit comments

Comments
 (0)