Skip to content

Commit 632e867

Browse files
committed
use standard bitrank instead of virtualized version
add quaternary parallel implementation (~50ns vs. ~80ns for binary).
1 parent 56dddba commit 632e867

File tree

4 files changed

+430
-170
lines changed

4 files changed

+430
-170
lines changed

crates/quaternary_trie/src/lib.rs

+5-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use virtual_bitrank::{VirtualBitRank, Word, WORD_BITS};
22

33
pub mod parallel;
4+
mod parallel4;
45
mod virtual_bitrank;
56

67
const MAX_LEVEL: usize = 14;
@@ -200,22 +201,13 @@ impl QuarternaryTrie {
200201
s.fill_bit_rank(&mut consumed, MAX_LEVEL - 1);
201202
}
202203
s.data.build();
203-
s.reset_stats();
204204
println!(
205205
"encoded size: {}",
206206
4.0 * s.level_idx[0] as f32 / values.len() as f32
207207
);
208208
s
209209
}
210210

211-
fn reset_stats(&mut self) {
212-
self.data.reset_stats();
213-
}
214-
215-
fn page_count(&self) -> (usize, usize) {
216-
self.data.page_count()
217-
}
218-
219211
fn recurse(&self, node: usize, level: usize, value: u32, results: &mut Vec<u32>) {
220212
if level == 1 {
221213
self.recurse2(node, value, results);
@@ -567,10 +559,10 @@ mod tests {
567559

568560
#[test]
569561
fn test_large() {
570-
let mut values: Vec<_> = (0..10000000)
562+
let mut values: Vec<_> = (0..100_000)
571563
.map(|_| thread_rng().gen_range(0..100000000))
572564
.collect();
573-
// let mut values: Vec<_> = (0..100).map(|_| thread_rng().gen_range(0..10000)).collect();
565+
// let mut values: Vec<_> = (0..10).map(|_| thread_rng().gen_range(0..100)).collect();
574566
values.sort();
575567
values.dedup();
576568

@@ -587,7 +579,7 @@ mod tests {
587579
let start = Instant::now();
588580
let result: Vec<_> = iter.collect();
589581
println!("iteration {:?}", start.elapsed() / values.len() as u32);
590-
// assert_eq!(result, values);
582+
assert_eq!(result, values);
591583
}
592584

593585
#[test]
@@ -670,17 +662,12 @@ mod tests {
670662
));
671663
let start = Instant::now();
672664
let result: Vec<_> = iter.collect();
673-
let count = trie.page_count();
674-
let count2 = trie2.page_count();
675-
page_counts[i] += count.0 + count2.0;
676665
println!(
677-
"trie intersection {:?} {}",
666+
"trie intersection {:?}",
678667
start.elapsed() / values.len() as u32,
679-
(count.0 + count2.0) as f32 / (count.1 + count2.1) as f32
680668
);
681669
assert_eq!(result, intersection);
682670
}
683-
println!("{page_counts:?}");
684671
}
685672
}
686673

crates/quaternary_trie/src/parallel.rs

+47-24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::arch::x86_64::{_pdep_u64, _pext_u64};
1+
use std::arch::x86_64::_pdep_u64;
22

33
use crate::virtual_bitrank::VirtualBitRank;
44

@@ -21,10 +21,10 @@ impl ParallelTrie {
2121
// !("fill_bit_rank {prefix} {mask:064b} {level}");
2222
for t in [0, 64 << level] {
2323
let mut sub_mask = 0;
24-
for i in 0..64 {
25-
if (1 << i) & mask == 0 {
26-
continue;
27-
}
24+
let mut mask = mask;
25+
while mask != 0 {
26+
let i = mask.trailing_zeros() as usize;
27+
mask &= mask - 1;
2828
if let Some(&value) = slices[i].get(0) {
2929
if (value ^ prefix) >> (level + 7) == 0 && value & (64 << level) == t {
3030
if WRITE {
@@ -134,19 +134,46 @@ impl ParallelTrie {
134134
}
135135
} else {
136136
let required_bits = word.count_ones();
137-
if required_bits == 0 {
138-
return;
137+
let mut new_rank = self.data.rank(rank) as usize + self.root_ones;
138+
139+
if required_bits == 1 {
140+
// TODO: simply switch to single bit recursion here instead of checking on every level again.
141+
// NOTE: we cannot easily read here a nibble, since the rank is only a multiple of 2, but not necessarily of 4.
142+
let mut w = self.data.get_word(rank) & 3;
143+
while w != 0 {
144+
let zeros = w.trailing_zeros();
145+
w &= w - 1;
146+
self.recurse(pos * 2 + zeros as usize, word, new_rank * 2, level - 1, v);
147+
new_rank += 1;
148+
}
149+
} else if required_bits <= 32 {
150+
let w = self.data.get_word(rank);
151+
let new_word = unsafe { _pdep_u64(w, word) };
152+
if new_word != 0 {
153+
self.recurse(pos * 2, new_word, new_rank * 2, level - 1, v);
154+
new_rank += new_word.count_ones() as usize;
155+
}
156+
157+
let w = w >> required_bits;
158+
let new_word = unsafe { _pdep_u64(w, word) };
159+
if new_word != 0 {
160+
self.recurse(pos * 2 + 1, new_word, new_rank * 2, level - 1, v);
161+
}
162+
} else {
163+
let w = self.data.get_word(rank);
164+
let new_word = unsafe { _pdep_u64(w, word) };
165+
if new_word != 0 {
166+
self.recurse(pos * 2, new_word, new_rank * 2, level - 1, v);
167+
new_rank += new_word.count_ones() as usize;
168+
}
169+
170+
let rank = rank + required_bits as usize;
171+
let w = self.data.get_word(rank);
172+
let new_word = unsafe { _pdep_u64(w, word) };
173+
if new_word != 0 {
174+
self.recurse(pos * 2 + 1, new_word, new_rank * 2, level - 1, v);
175+
}
139176
}
140-
let w = self.data.get_word(rank);
141-
let new_word = unsafe { _pdep_u64(w, word) };
142-
let new_rank = self.data.rank(rank) as usize + self.root_ones;
143-
self.recurse(pos * 2, new_word, new_rank * 2, level - 1, v);
144-
145-
let rank = rank + required_bits as usize;
146-
let w = self.data.get_word(rank);
147-
let new_word = unsafe { _pdep_u64(w, word) };
148-
let new_rank = self.data.rank(rank) as usize + self.root_ones;
149-
self.recurse(pos * 2 + 1, new_word, new_rank * 2, level - 1, v);
150177
}
151178
}
152179
}
@@ -155,19 +182,15 @@ impl ParallelTrie {
155182
mod tests {
156183
use std::time::Instant;
157184

158-
use itertools::{kmerge, Itertools};
159185
use rand::{thread_rng, Rng};
160186

161-
use crate::{
162-
parallel::ParallelTrie, Intersection, Layout, QuarternaryTrie, TrieIterator, TrieTraversal,
163-
Union,
164-
};
187+
use crate::parallel::ParallelTrie;
165188

166189
#[test]
167-
fn test_parallel() {
190+
fn test_parallel_large() {
168191
// let values = vec![3, 6, 7, 10, 90, 91, 120, 128, 129, 130, 231, 321, 999];
169192
// let values = vec![3, 6, 7, 321, 999];
170-
let mut values: Vec<_> = (0..10_000_000)
193+
let mut values: Vec<_> = (0..100_000)
171194
.map(|_| thread_rng().gen_range(0..100_000_000))
172195
.collect();
173196
values.sort();

0 commit comments

Comments
 (0)