@@ -76,7 +76,7 @@ pub struct NullState<O: GroupIndexOperations> {
76
76
77
77
row_offset_buffer : Vec < usize > ,
78
78
79
- values_buffer : MutableBuffer ,
79
+ row_idxs_buffer : Vec < usize > ,
80
80
81
81
/// phantom data for required type `<O>`
82
82
_phantom : PhantomData < O > ,
@@ -139,7 +139,7 @@ impl<O: GroupIndexOperations> NullState<O> {
139
139
mut value_fn : F ,
140
140
) where
141
141
T : ArrowPrimitiveType + Send ,
142
- F : FnMut ( & [ u32 ] , & [ u64 ] , & [ usize ] , & [ T :: Native ] ) + Send ,
142
+ F : FnMut ( & [ u32 ] , & [ u64 ] , & [ usize ] , & [ usize ] ) + Send ,
143
143
{
144
144
// ensure the seen_values is big enough (start everything at
145
145
// "not seen" valid)
@@ -149,14 +149,14 @@ impl<O: GroupIndexOperations> NullState<O> {
149
149
let block_ids_buffer = & mut self . block_ids_buffer ;
150
150
let block_offsets_buffer = & mut self . block_offsets_buffer ;
151
151
let row_offset_buffer = & mut self . row_offset_buffer ;
152
- let values_buffer = & mut self . values_buffer ;
152
+ let row_idxs_buffer = & mut self . row_idxs_buffer ;
153
153
154
154
block_ids_buffer. clear ( ) ;
155
155
block_offsets_buffer. clear ( ) ;
156
156
row_offset_buffer. clear ( ) ;
157
- values_buffer . clear ( ) ;
157
+ row_idxs_buffer . clear ( ) ;
158
158
159
- accumulate ( group_indices, values, opt_filter, |packed_index, value | {
159
+ accumulate_blocks ( group_indices, values, opt_filter, |packed_index, row_idx | {
160
160
let packed_index = packed_index as u64 ;
161
161
let block_id = O :: get_block_id ( packed_index) ;
162
162
let block_offset = O :: get_block_offset ( packed_index) ;
@@ -168,7 +168,7 @@ impl<O: GroupIndexOperations> NullState<O> {
168
168
row_offset_buffer. push ( block_offsets_buffer. len ( ) ) ;
169
169
}
170
170
171
- values_buffer . push ( value ) ;
171
+ row_idxs_buffer . push ( row_idx ) ;
172
172
block_offsets_buffer. push ( block_offset) ;
173
173
} ) ;
174
174
row_offset_buffer. push ( block_offsets_buffer. len ( ) ) ;
@@ -182,7 +182,7 @@ impl<O: GroupIndexOperations> NullState<O> {
182
182
& block_ids_buffer,
183
183
& block_offsets_buffer,
184
184
& row_offset_buffer,
185
- values_buffer . typed_data ( ) ,
185
+ & row_idxs_buffer ,
186
186
) ;
187
187
}
188
188
@@ -351,7 +351,7 @@ impl NullStateAdapter {
351
351
value_fn : F ,
352
352
) where
353
353
T : ArrowPrimitiveType + Send ,
354
- F : FnMut ( & [ u32 ] , & [ u64 ] , & [ usize ] , & [ T :: Native ] ) + Send ,
354
+ F : FnMut ( & [ u32 ] , & [ u64 ] , & [ usize ] , & [ usize ] ) + Send ,
355
355
{
356
356
match self {
357
357
NullStateAdapter :: Flat ( null_state) => null_state. accumulate_blocks (
@@ -493,7 +493,7 @@ impl Default for FlatNullState {
493
493
block_ids_buffer : Vec :: new ( ) ,
494
494
block_offsets_buffer : Vec :: new ( ) ,
495
495
row_offset_buffer : Vec :: new ( ) ,
496
- values_buffer : MutableBuffer :: new ( 0 ) ,
496
+ row_idxs_buffer : Vec :: new ( ) ,
497
497
_phantom : PhantomData { } ,
498
498
}
499
499
}
@@ -529,7 +529,7 @@ impl BlockedNullState {
529
529
block_ids_buffer : Vec :: new ( ) ,
530
530
block_offsets_buffer : Vec :: new ( ) ,
531
531
row_offset_buffer : Vec :: new ( ) ,
532
- values_buffer : MutableBuffer :: new ( 0 ) ,
532
+ row_idxs_buffer : Vec :: new ( ) ,
533
533
_phantom : PhantomData { } ,
534
534
}
535
535
}
@@ -767,6 +767,103 @@ pub fn accumulate<T, F>(
767
767
}
768
768
}
769
769
770
+ pub fn accumulate_blocks < T , F > (
771
+ group_indices : & [ usize ] ,
772
+ values : & PrimitiveArray < T > ,
773
+ opt_filter : Option < & BooleanArray > ,
774
+ mut value_fn : F ,
775
+ ) where
776
+ T : ArrowPrimitiveType + Send ,
777
+ F : FnMut ( usize , usize ) + Send ,
778
+ {
779
+ let data: & [ T :: Native ] = values. values ( ) ;
780
+ assert_eq ! ( data. len( ) , group_indices. len( ) ) ;
781
+
782
+ match ( values. null_count ( ) > 0 , opt_filter) {
783
+ // no nulls, no filter,
784
+ ( false , None ) => {
785
+ for ( row_idx, & group_index) in group_indices. iter ( ) . enumerate ( ) {
786
+ value_fn ( group_index, row_idx) ;
787
+ }
788
+ }
789
+ // nulls, no filter
790
+ ( true , None ) => {
791
+ let nulls = values. nulls ( ) . unwrap ( ) ;
792
+ // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
793
+ // iterate over in chunks of 64 bits for more efficient null checking
794
+ let group_indices_chunks = group_indices. chunks_exact ( 64 ) ;
795
+ let bit_chunks = nulls. inner ( ) . bit_chunks ( ) ;
796
+ let group_indices_remainder = group_indices_chunks. remainder ( ) ;
797
+
798
+ let mut row_idx = 0 ;
799
+ group_indices_chunks. zip ( bit_chunks. iter ( ) ) . for_each (
800
+ |( group_index_chunk, mask) | {
801
+ // index_mask has value 1 << i in the loop
802
+ let mut index_mask = 1 ;
803
+ group_index_chunk. iter ( ) . for_each ( |& group_index| {
804
+ // valid bit was set, real value
805
+ let is_valid = ( mask & index_mask) != 0 ;
806
+ if is_valid {
807
+ value_fn ( group_index, row_idx) ;
808
+ }
809
+ index_mask <<= 1 ;
810
+ row_idx += 1 ;
811
+ } )
812
+ } ,
813
+ ) ;
814
+
815
+ // handle any remaining bits (after the initial 64)
816
+ let remainder_bits = bit_chunks. remainder_bits ( ) ;
817
+ group_indices_remainder
818
+ . iter ( )
819
+ . enumerate ( )
820
+ . for_each ( |( i, & group_index) | {
821
+ let is_valid = remainder_bits & ( 1 << i) != 0 ;
822
+ if is_valid {
823
+ value_fn ( group_index, row_idx) ;
824
+ }
825
+ row_idx += 1 ;
826
+ } ) ;
827
+ }
828
+
829
+ // no nulls, but a filter
830
+ ( false , Some ( filter) ) => {
831
+ assert_eq ! ( filter. len( ) , group_indices. len( ) ) ;
832
+ // The performance with a filter could be improved by
833
+ // iterating over the filter in chunks, rather than a single
834
+ // iterator. TODO file a ticket
835
+ group_indices
836
+ . iter ( )
837
+ . zip ( filter. iter ( ) )
838
+ . enumerate ( )
839
+ . for_each ( |( row_idx, ( & group_index, filter_value) ) | {
840
+ if let Some ( true ) = filter_value {
841
+ value_fn ( group_index, row_idx) ;
842
+ }
843
+ } )
844
+ }
845
+
846
+ // both null values and filters
847
+ ( true , Some ( filter) ) => {
848
+ assert_eq ! ( filter. len( ) , group_indices. len( ) ) ;
849
+ // The performance with a filter could be improved by
850
+ // iterating over the filter in chunks, rather than using
851
+ // iterators. TODO file a ticket
852
+ filter
853
+ . iter ( )
854
+ . zip ( group_indices. iter ( ) )
855
+ . enumerate ( )
856
+ . for_each ( |( row_idx, ( filter_value, & group_index) ) | {
857
+ if let Some ( true ) = filter_value {
858
+ if values. is_null ( row_idx) {
859
+ value_fn ( group_index, row_idx)
860
+ }
861
+ }
862
+ } )
863
+ }
864
+ }
865
+ }
866
+
770
867
/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
771
868
///
772
869
/// This method assumes that for any input record index, if any of the value column
0 commit comments