11// This check is new and seems buggy (possibly with PyO3 interaction)
22#![ allow( clippy:: borrow_deref_ref) ]
33
4- use std:: collections:: HashSet ;
4+ use std:: collections:: { BTreeMap , BTreeSet , HashSet } ;
5+ use std:: iter:: successors;
56use std:: num:: NonZeroU64 ;
67use std:: thread;
78
@@ -15,7 +16,17 @@ use rustc_hash::FxHashMap as HashMap;
1516
1617type Rank = u32 ;
1718
19+ const LARGE_ENCODER_CHARACTER_LIMIT : usize = 500 ;
20+
1821fn _byte_pair_merge ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < ( usize , Rank ) > {
22+ if piece. len ( ) < LARGE_ENCODER_CHARACTER_LIMIT {
23+ _byte_pair_merge_small ( ranks, piece) // Quadratic, but lightweight
24+ } else {
25+ _byte_pair_merge_large ( ranks, piece) // Linearithmic, but heavy
26+ }
27+ }
28+
29+ fn _byte_pair_merge_small ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < ( usize , Rank ) > {
1930 // This is a vector of (start, rank).
2031 // The rank is of the pair starting at position start.
2132 let mut parts = Vec :: with_capacity ( piece. len ( ) + 1 ) ;
@@ -73,6 +84,78 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
7384 parts
7485}
7586
/// Byte-pair merge for long pieces, driven by rank-ordered buckets.
///
/// Every position `0..=piece.len()` is a node in a doubly linked list kept in
/// the index arrays `prev_of` / `next_of`. `pair_rank[i]` is the rank of the
/// bytes spanning node `i` and its current successor, or `Rank::MAX` when that
/// span is not in `ranks` (or runs past the end of `piece`). `buckets` groups
/// node indices by that rank so the lowest-ranked pairs can be popped first.
///
/// Each round pops the lowest-rank bucket and merges every node in it with its
/// successor, in ascending index order. The loop ends once a single bucket is
/// left.
///
/// Returns the start offsets of the surviving nodes (including the terminal
/// offset `piece.len()`), each paired with `Rank::MAX`.
///
/// Panics if a bucket that should exist is missing (`unwrap`) or a list link
/// is out of bounds; the visible code relies on its own bookkeeping to keep
/// both from happening.
fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // Marker for "no neighbour" in the linked-list arrays.
    const NIL: usize = usize::MAX;
    let slots = piece.len() + 1;

    // Rank of the bytes in `from..to`; `Rank::MAX` if the range is invalid or unknown.
    let rank_of = |from: usize, to: usize| -> Rank {
        piece
            .get(from..to)
            .and_then(|bytes| ranks.get(bytes))
            .copied()
            .unwrap_or(Rank::MAX)
    };

    // Initially every node pairs with the byte right after it.
    let mut pair_rank: Vec<Rank> = (0..slots).map(|i| rank_of(i, i + 2)).collect();
    let mut prev_of: Vec<usize> = (0..slots)
        .map(|i| if i == 0 { NIL } else { i - 1 })
        .collect();
    let mut next_of: Vec<usize> = (0..slots)
        .map(|i| if i + 1 < slots { i + 1 } else { NIL })
        .collect();

    let mut buckets: BTreeMap<Rank, BTreeSet<usize>> = BTreeMap::new();
    for (node, &rank) in pair_rank.iter().enumerate() {
        buckets.entry(rank).or_default().insert(node);
    }

    while buckets.len() > 1 {
        let Some((_, batch)) = buckets.pop_first() else {
            break;
        };

        let mut pending = batch.into_iter();
        while let Some(node) = pending.next() {
            let merged_rank = pair_rank[node];

            let before = prev_of[node];
            let absorbed = next_of[node];
            let after = next_of[absorbed];
            let after_next = next_of[after];

            // The predecessor's pair now ends where the merged token ends.
            if before != NIL {
                let refreshed = rank_of(before, after);
                let stale = pair_rank[before];
                if stale != refreshed {
                    buckets.get_mut(&stale).unwrap().remove(&before);
                    pair_rank[before] = refreshed;
                    buckets.entry(refreshed).or_default().insert(before);
                }
            }

            // The merged node now pairs with the token that followed `absorbed`.
            let widened = rank_of(node, after_next);
            pair_rank[node] = widened;
            buckets.entry(widened).or_default().insert(node);

            // Unlink `absorbed` from the list.
            next_of[node] = after;
            prev_of[after] = node;

            let absorbed_rank = pair_rank[absorbed];
            if absorbed_rank == merged_rank {
                // The absorbed node shares this batch's rank: step over the
                // next batch entry instead of merging it.
                pending.next();
            } else if absorbed_rank != Rank::MAX {
                buckets.get_mut(&absorbed_rank).unwrap().remove(&absorbed);
            }
        }
    }

    // Walk the surviving nodes from the head of the list.
    successors(Some(0), |&at| {
        next_of.get(at).copied().filter(|&nx| nx != NIL)
    })
    .map(|start| (start, Rank::MAX))
    .collect()
}
158+
76159pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
77160 assert ! ( piece. len( ) > 1 ) ;
78161 _byte_pair_merge ( & ranks, & piece)
0 commit comments