@@ -13,15 +13,18 @@ use std::{
1313 hash:: Hash ,
1414 pin:: Pin ,
1515 task:: { Context , Poll } ,
16+ thread:: available_parallelism,
1617} ;
1718
1819use arrow:: {
1920 array:: {
20- as_primitive_array, as_string_array, ArrayIter , ArrayRef , BooleanArray ,
21- BooleanBufferBuilder , PrimitiveArray , StringArray ,
21+ as_primitive_array, as_string_array, ArrayRef , BooleanArray , BooleanBufferBuilder ,
22+ PrimitiveArray , StringArray ,
23+ } ,
24+ compute:: {
25+ and, filter, filter_record_batch,
26+ kernels:: cmp:: { distinct, eq} ,
2227 } ,
23- compute:: kernels:: cmp:: eq,
24- compute:: { and, filter_record_batch} ,
2528 datatypes:: { ArrowPrimitiveType , DataType , Int32Type , Int64Type } ,
2629 error:: ArrowError ,
2730 record_batch:: RecordBatch ,
@@ -33,7 +36,7 @@ use futures::{
3336use itertools:: { iproduct, Itertools } ;
3437
3538use iceberg_rust_spec:: { partition:: BoundPartitionField , spec:: values:: Value } ;
36- use once_map :: OnceMap ;
39+ use lru :: LruCache ;
3740use pin_project_lite:: pin_project;
3841
3942use crate :: error:: Error ;
@@ -58,7 +61,7 @@ pin_project! {
5861 #[ pin]
5962 record_batches: Pin <Box <dyn Stream <Item = Result <RecordBatch , ArrowError >> + Send >>,
6063 partition_fields: & ' a [ BoundPartitionField <' a>] ,
61- partition_streams: OnceMap <Vec <Value >, RecordBatchSender >,
64+ partition_streams: LruCache <Vec <Value >, RecordBatchSender >,
6265 queue: VecDeque <Result <( Vec <Value >, RecordBatchReceiver ) , Error >>,
6366 sends: Vec <( RecordBatchSender , RecordBatch ) >,
6467 }
@@ -72,7 +75,7 @@ impl<'a> PartitionStream<'a> {
7275 Self {
7376 record_batches,
7477 partition_fields,
75- partition_streams : OnceMap :: new ( ) ,
78+ partition_streams : LruCache :: unbounded ( ) ,
7679 queue : VecDeque :: new ( ) ,
7780 sends : Vec :: new ( ) ,
7881 }
@@ -109,13 +112,22 @@ impl Stream for PartitionStream<'_> {
109112 }
110113 }
111114
115+ // Limit the number of open partition_streams by available parallelism
116+ if this. partition_streams . len ( ) > available_parallelism ( ) . unwrap ( ) . get ( ) {
117+ if let Some ( ( _, mut sender) ) = this. partition_streams . pop_lru ( ) {
118+ sender. close_channel ( ) ;
119+ }
120+ }
121+
112122 match this. record_batches . as_mut ( ) . poll_next ( cx) {
113123 Poll :: Pending => {
114124 break Poll :: Pending ;
115125 }
116126 Poll :: Ready ( None ) => {
117- for sender in this. partition_streams . read_only_view ( ) . values ( ) {
118- sender. clone ( ) . close_channel ( ) ;
127+ while let Some ( ( _, sender) ) = this. partition_streams . pop_lru ( ) {
128+ if !sender. is_closed ( ) {
129+ sender. clone ( ) . close_channel ( ) ;
130+ }
119131 }
120132 break Poll :: Ready ( None ) ;
121133 }
@@ -127,16 +139,16 @@ impl Stream for PartitionStream<'_> {
127139 let ( partition_values, batch) = result?;
128140
129141 let sender = if let Some ( sender) =
130- this. partition_streams . get_cloned ( & partition_values)
142+ this. partition_streams . get ( & partition_values) . cloned ( )
131143 {
132144 sender
133145 } else {
146+ let ( sender, reciever) = channel ( 1 ) ;
147+ this. queue
148+ . push_back ( Ok ( ( partition_values. clone ( ) , reciever) ) ) ;
134149 this. partition_streams
135- . insert_cloned ( partition_values, |key| {
136- let ( sender, reciever) = channel ( 1 ) ;
137- this. queue . push_back ( Ok ( ( key. clone ( ) , reciever) ) ) ;
138- sender
139- } )
150+ . push ( partition_values, sender. clone ( ) ) ;
151+ sender
140152 } ;
141153
142154 this. sends . push ( ( sender, batch) ) ;
@@ -250,12 +262,12 @@ fn distinct_values(array: ArrayRef) -> Result<DistinctValues, ArrowError> {
250262 DataType :: Int32 => Ok ( DistinctValues :: Int ( distinct_values_primitive :: <
251263 i32 ,
252264 Int32Type ,
253- > ( array) ) ) ,
265+ > ( array) ? ) ) ,
254266 DataType :: Int64 => Ok ( DistinctValues :: Long ( distinct_values_primitive :: <
255267 i64 ,
256268 Int64Type ,
257- > ( array) ) ) ,
258- DataType :: Utf8 => Ok ( DistinctValues :: String ( distinct_values_string ( array) ) ) ,
269+ > ( array) ? ) ) ,
270+ DataType :: Utf8 => Ok ( DistinctValues :: String ( distinct_values_string ( array) ? ) ) ,
259271 _ => Err ( ArrowError :: ComputeError (
260272 "Datatype not supported for transform." . to_string ( ) ,
261273 ) ) ,
@@ -275,15 +287,36 @@ fn distinct_values(array: ArrayRef) -> Result<DistinctValues, ArrowError> {
275287/// A HashSet containing all unique values from the array
276288fn distinct_values_primitive < T : Eq + Hash , P : ArrowPrimitiveType < Native = T > > (
277289 array : ArrayRef ,
278- ) -> HashSet < P :: Native > {
279- let mut set = HashSet :: new ( ) ;
290+ ) -> Result < HashSet < P :: Native > , ArrowError > {
280291 let array = as_primitive_array :: < P > ( & array) ;
281- for value in ArrayIter :: new ( array) . flatten ( ) {
282- if !set. contains ( & value) {
283- set. insert ( value) ;
284- }
292+
293+ let first = array. value ( 0 ) ;
294+
295+ let slice_len = array. len ( ) - 1 ;
296+
297+ if slice_len == 0 {
298+ return Ok ( HashSet :: from_iter ( [ first] ) ) ;
285299 }
286- set
300+
301+ let v1 = array. slice ( 0 , slice_len) ;
302+ let v2 = array. slice ( 1 , slice_len) ;
303+
304+ // Which consecutive entries are different
305+ let mask = distinct ( & v1, & v2) ?;
306+
307+ let unique = filter ( & v2, & mask) ?;
308+
309+ let unique = as_primitive_array :: < P > ( & unique) ;
310+
311+ let set = unique
312+ . iter ( )
313+ . fold ( HashSet :: from_iter ( [ first] ) , |mut acc, x| {
314+ if let Some ( x) = x {
315+ acc. insert ( x) ;
316+ }
317+ acc
318+ } ) ;
319+ Ok ( set)
287320}
288321
289322/// Extracts distinct string values from an Arrow array into a HashSet
@@ -293,15 +326,36 @@ fn distinct_values_primitive<T: Eq + Hash, P: ArrowPrimitiveType<Native = T>>(
293326///
294327/// # Returns
295328/// A HashSet containing all unique string values from the array
296- fn distinct_values_string ( array : ArrayRef ) -> HashSet < String > {
297- let mut set = HashSet :: new ( ) ;
329+ fn distinct_values_string ( array : ArrayRef ) -> Result < HashSet < String > , ArrowError > {
330+ let slice_len = array. len ( ) - 1 ;
331+
298332 let array = as_string_array ( & array) ;
299- for value in ArrayIter :: new ( array) . flatten ( ) {
300- if !set. contains ( value) {
301- set. insert ( value. to_owned ( ) ) ;
302- }
333+
334+ let first = array. value ( 0 ) . to_owned ( ) ;
335+
336+ if slice_len == 0 {
337+ return Ok ( HashSet :: from_iter ( [ first] ) ) ;
303338 }
304- set
339+
340+ let v1 = array. slice ( 0 , slice_len) ;
341+ let v2 = array. slice ( 1 , slice_len) ;
342+
343+ // Which consecutive entries are different
344+ let mask = distinct ( & v1, & v2) ?;
345+
346+ let unique = filter ( & v2, & mask) ?;
347+
348+ let unique = as_string_array ( & unique) ;
349+
350+ let set = unique
351+ . iter ( )
352+ . fold ( HashSet :: from_iter ( [ first] ) , |mut acc, x| {
353+ if let Some ( x) = x {
354+ acc. insert ( x. to_owned ( ) ) ;
355+ }
356+ acc
357+ } ) ;
358+ Ok ( set)
305359}
306360
307361/// Represents distinct values found in Arrow arrays during partitioning
0 commit comments