Skip to content

Commit b261ed1

Browse files
committed
try to make parallel iterator work
1 parent 632e867 commit b261ed1

8 files changed

+1188
-66
lines changed

.cargo/config.toml

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[target.x86_64-unknown-linux-gnu]
2+
# SSE3 is requred by simd-varint. tokio_unstable is for tokio-console.
3+
rustflags = ["-C", "target-feature=+ssse3,+avx2,+avx512,+popcnt,+bmi2", "--cfg", "tokio_unstable"]
4+
5+
[target.x86_64-apple-darwin]
6+
# SSE3 is requred by simd-varint. tokio_unstable is for tokio-console.
7+
rustflags = ["-C", "target-feature=+ssse3,+avx2,+bmi2", "--cfg", "tokio_unstable"]

crates/bpe/src/binary_prefix_tree.rs

+355
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
use std::{fmt::Debug, time::Instant};
2+
3+
use itertools::Itertools;
4+
use rand::{thread_rng, Rng};
5+
6+
struct BitField {
7+
data: Vec<u64>,
8+
bits: usize,
9+
rank: Vec<u32>,
10+
}
11+
12+
impl BitField {
13+
fn new() -> Self {
14+
Self {
15+
data: vec![],
16+
bits: 0,
17+
rank: vec![],
18+
}
19+
}
20+
21+
fn push(&mut self, value: bool) {
22+
if self.data.len() * 64 <= self.bits {
23+
self.rank.push(
24+
self.rank.last().copied().unwrap_or_default()
25+
+ self.data.last().copied().unwrap_or_default().count_ones(),
26+
);
27+
self.data.push(0);
28+
}
29+
if value {
30+
self.data[self.bits / 64] |= 1 << (self.bits & 63);
31+
}
32+
self.bits += 1;
33+
}
34+
35+
fn get(&self, index: usize) -> bool {
36+
self.data[index / 64] & (1 << (index & 63)) != 0
37+
}
38+
39+
fn get2(&self, index: usize) -> u32 {
40+
(self.data[index / 32] >> ((2 * index) & 63)) as u32 & 3
41+
}
42+
43+
fn rank(&self, index: usize) -> usize {
44+
let r = self.rank[index / 64];
45+
(r + (self.data[index / 64] & !(u64::MAX << (index & 63))).count_ones()) as usize
46+
}
47+
}
48+
49+
impl Debug for BitField {
50+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51+
f.debug_struct("BitField")
52+
.field("data", &self.data)
53+
.field("bits", &self.bits)
54+
.finish()
55+
}
56+
}
57+
58+
struct BinaryPrefixTree {
59+
data: BitField,
60+
}
61+
62+
const MAX_LEVEL: usize = 30;
63+
64+
impl BinaryPrefixTree {
65+
fn new(values: &[u32]) -> Self {
66+
let mut data = BitField::new();
67+
for level in (0..MAX_LEVEL).rev() {
68+
let mut previous_prefix = None;
69+
let mut occurrences = 0;
70+
for value in values {
71+
let prefix = *value >> level;
72+
if let Some(prev) = previous_prefix {
73+
if prefix >> 1 != prev {
74+
data.push(occurrences & 1 != 0);
75+
data.push(occurrences & 2 != 0);
76+
occurrences = 0;
77+
}
78+
}
79+
if prefix & 1 == 1 {
80+
occurrences |= 2;
81+
} else {
82+
occurrences |= 1;
83+
}
84+
previous_prefix = Some(prefix >> 1);
85+
}
86+
data.push(occurrences & 1 != 0);
87+
data.push(occurrences & 2 != 0);
88+
}
89+
println!("encoded size: {}", data.bits as f32 / values.len() as f32);
90+
Self { data }
91+
}
92+
93+
fn check(&self, value: u32) -> bool {
94+
let mut i = 0;
95+
for level in (0..MAX_LEVEL).rev() {
96+
if value & (1 << level) != 0 {
97+
i += 1;
98+
}
99+
if !self.data.get(i) {
100+
return false;
101+
}
102+
i = self.data.rank(i + 1) * 2;
103+
}
104+
true
105+
}
106+
107+
fn as_iter(&self) -> BinaryPrefixTreeIterator<'_> {
108+
BinaryPrefixTreeIterator::new(self)
109+
}
110+
111+
#[inline(always)]
112+
fn recurse(&self, node: usize, level: usize, value: u32, results: &mut Vec<u32>) {
113+
if level == 1 {
114+
self.recurse1(node, value, results);
115+
} else {
116+
let n = self.data.get2(node);
117+
let value = value * 2;
118+
let mut r = self.data.rank(node * 2);
119+
if n & 1 != 0 {
120+
r += 1;
121+
self.recurse(r, level - 1, value, results);
122+
}
123+
if n & 2 != 0 {
124+
r += 1;
125+
self.recurse(r, level - 1, value + 1, results);
126+
}
127+
}
128+
}
129+
130+
#[inline(always)]
131+
fn recurse1(&self, node: usize, value: u32, results: &mut Vec<u32>) {
132+
let n = self.data.get2(node);
133+
let value = value * 2;
134+
let mut r = self.data.rank(node * 2);
135+
if n & 1 != 0 {
136+
r += 1;
137+
self.recurse0(r, value, results);
138+
}
139+
if n & 2 != 0 {
140+
r += 1;
141+
self.recurse0(r, value + 1, results);
142+
}
143+
}
144+
145+
#[inline(always)]
146+
fn recurse0(&self, node: usize, value: u32, results: &mut Vec<u32>) {
147+
let n = self.data.get2(node);
148+
let value = value * 2;
149+
if n & 1 != 0 {
150+
results.push(value);
151+
}
152+
if n & 2 != 0 {
153+
results.push(value + 1);
154+
}
155+
}
156+
}
157+
158+
trait TreeIterator {
159+
fn get(&self, level: usize) -> u32;
160+
fn down(&mut self, level: usize, value: bool);
161+
}
162+
163+
struct BinaryPrefixTreeIterator<'a> {
164+
bpt: &'a BinaryPrefixTree,
165+
pos: [u32; MAX_LEVEL],
166+
}
167+
168+
impl<'a> BinaryPrefixTreeIterator<'a> {
169+
fn new(bpt: &'a BinaryPrefixTree) -> Self {
170+
Self {
171+
bpt,
172+
pos: [0; MAX_LEVEL],
173+
}
174+
}
175+
176+
fn into_iter(self) -> RustIterator<'a> {
177+
RustIterator::new(self)
178+
}
179+
180+
fn into_vec(mut self) -> Vec<u32> {
181+
let mut res = vec![];
182+
// init positions
183+
for level in (1..MAX_LEVEL).rev() {
184+
let index = self.pos[level];
185+
let new_index = self.bpt.data.rank(index as usize) as u32 + 1;
186+
self.pos[level - 1] = new_index * 2;
187+
}
188+
let mut level = MAX_LEVEL - 1;
189+
let mut value = 0u32;
190+
while level < MAX_LEVEL {
191+
let v = self.bpt.data.get(self.pos[level] as usize);
192+
// println!("{value} {level} {} {v}", self.pos[level]);
193+
self.pos[level] += 1;
194+
if v {
195+
if level == 0 {
196+
res.push(value);
197+
value += 1 << level;
198+
level = value.trailing_zeros() as usize;
199+
} else {
200+
level -= 1;
201+
}
202+
} else {
203+
value += 1 << level;
204+
level = value.trailing_zeros() as usize;
205+
}
206+
}
207+
res
208+
}
209+
}
210+
211+
impl<'a> TreeIterator for BinaryPrefixTreeIterator<'a> {
212+
fn get(&self, level: usize) -> u32 {
213+
self.bpt.data.get2(self.pos[level] as usize)
214+
}
215+
216+
fn down(&mut self, level: usize, value: bool) {
217+
let index = self.pos[level] * 2 + value as u32;
218+
let new_index = self.bpt.data.rank(index as usize + 1);
219+
self.pos[level - 1] = new_index as u32;
220+
}
221+
}
222+
223+
struct RustIterator<'a> {
224+
inner: BinaryPrefixTreeIterator<'a>,
225+
level: usize,
226+
item: u32,
227+
}
228+
229+
impl<'a> RustIterator<'a> {
230+
fn new(inner: BinaryPrefixTreeIterator<'a>) -> Self {
231+
Self {
232+
inner,
233+
level: MAX_LEVEL - 1,
234+
item: 0,
235+
}
236+
}
237+
}
238+
239+
impl<'a> Iterator for RustIterator<'a> {
240+
type Item = u32;
241+
242+
fn next(&mut self) -> Option<u32> {
243+
while self.level < MAX_LEVEL {
244+
let bitmask = self.inner.get(self.level);
245+
let curr = self.item & (1 << self.level);
246+
if curr == 0 && bitmask & 1 != 0 {
247+
if self.level == 0 {
248+
let res = self.item;
249+
self.item += 1;
250+
return Some(res);
251+
}
252+
self.inner.down(self.level, false);
253+
self.level -= 1;
254+
} else if bitmask & 2 != 0 {
255+
if curr == 0 {
256+
self.item += 1 << self.level;
257+
}
258+
if self.level == 0 {
259+
let res = self.item;
260+
self.item += 1;
261+
self.level = self.item.trailing_zeros() as usize + 1;
262+
return Some(res);
263+
}
264+
self.inner.down(self.level, true);
265+
self.level -= 1;
266+
} else {
267+
self.item += 1 << self.level;
268+
self.level = self.item.trailing_zeros() as usize;
269+
}
270+
}
271+
None
272+
}
273+
}
274+
275+
#[test]
276+
fn test_bpt() {
277+
let values = &[3, 6, 7, 10];
278+
let bpt = BinaryPrefixTree::new(values);
279+
println!("{:x?} {:?}", bpt.data.data, bpt.data.rank);
280+
assert!(!bpt.check(0));
281+
assert!(!bpt.check(1));
282+
assert!(!bpt.check(2));
283+
assert!(bpt.check(3));
284+
assert!(!bpt.check(4));
285+
assert!(!bpt.check(5));
286+
assert!(bpt.check(6));
287+
assert!(bpt.check(7));
288+
assert!(!bpt.check(8));
289+
assert!(!bpt.check(9));
290+
assert!(bpt.check(10));
291+
292+
/* let mut iter = bpt.as_iter();
293+
assert_eq!(iter.get(7), 1);
294+
iter.down(7, false);
295+
assert_eq!(iter.get(6), 1);
296+
iter.down(6, false);
297+
assert_eq!(iter.get(5), 1);
298+
iter.down(5, false);
299+
assert_eq!(iter.get(4), 1);
300+
iter.down(4, false);
301+
assert_eq!(iter.get(3), 3);
302+
iter.down(3, false);
303+
assert_eq!(iter.get(2), 3);
304+
iter.down(2, false);
305+
assert_eq!(iter.get(1), 2);
306+
iter.down(1, true);
307+
assert_eq!(iter.get(0), 2);*/
308+
309+
let iter = bpt.as_iter().into_iter();
310+
assert_eq!(iter.collect_vec(), values.iter().copied().collect_vec());
311+
312+
assert_eq!(
313+
bpt.as_iter().into_vec(),
314+
values.iter().copied().collect_vec()
315+
);
316+
}
317+
318+
#[test]
319+
fn test_intersection() {
320+
let start = Instant::now();
321+
// let mut values = (0..1)
322+
let mut values = (0..1000000)
323+
.map(|_| thread_rng().gen_range(0..100000000))
324+
.collect_vec();
325+
values.sort();
326+
values.dedup();
327+
println!("generation of values {:?}", start.elapsed());
328+
329+
let start = Instant::now();
330+
let bpt = BinaryPrefixTree::new(&values);
331+
println!("construction of tree {:?}", start.elapsed());
332+
333+
let start = Instant::now();
334+
assert_eq!(bpt.as_iter().into_iter().collect_vec(), values);
335+
println!("iteration {:?}", start.elapsed());
336+
337+
let start = Instant::now();
338+
let mut v = Vec::with_capacity(values.len());
339+
bpt.recurse(0, MAX_LEVEL - 1, 0, &mut v);
340+
assert_eq!(v, values);
341+
println!("recursive collect {:?}", start.elapsed());
342+
343+
let start = Instant::now();
344+
let v = bpt.as_iter().into_vec();
345+
println!("iteration {:?}", start.elapsed());
346+
assert_eq!(bpt.as_iter().into_vec(), values);
347+
348+
let start = Instant::now();
349+
let mut v = Vec::with_capacity(values.len());
350+
for i in values.iter() {
351+
v.push(*i);
352+
}
353+
println!("iteration {:?}", start.elapsed());
354+
assert_eq!(v, values);
355+
}

0 commit comments

Comments
 (0)