@@ -26,6 +26,7 @@ use arrow::datatypes::{i256, DataType};
26
26
use arrow:: record_batch:: RecordBatch ;
27
27
use datafusion_common:: Result ;
28
28
use datafusion_execution:: memory_pool:: proxy:: VecAllocExt ;
29
+ use datafusion_expr:: groups_accumulator:: EmitBlocksContext ;
29
30
use datafusion_expr:: EmitTo ;
30
31
use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: group_index_operations:: {
31
32
BlockedGroupIndexOperations , FlatGroupIndexOperations , GroupIndexOperations ,
@@ -101,8 +102,6 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
101
102
/// The values for each group index
102
103
values : Vec < Vec < T :: Native > > ,
103
104
104
- next_emit_block_id : usize ,
105
-
106
105
/// The random state used to generate hashes
107
106
random_state : RandomState ,
108
107
@@ -116,7 +115,14 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
116
115
///
117
116
block_size : Option < usize > ,
118
117
118
+ /// Number of groups currently stored
119
+ ///
120
+ /// We maintain it to avoid the expensive dynamic computation in
121
+ /// `blocked approach`.
119
122
num_groups : usize ,
123
+
124
+ /// Context used when emitting blocks in the `blocked approach`
125
+ emit_blocks_ctx : EmitBlocksContext ,
120
126
}
121
127
122
128
impl < T : ArrowPrimitiveType > GroupValuesPrimitive < T > {
@@ -132,11 +138,11 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
132
138
data_type,
133
139
map : HashTable :: with_capacity ( 128 ) ,
134
140
values,
135
- next_emit_block_id : 0 ,
136
141
null_group : None ,
137
142
random_state : Default :: default ( ) ,
138
143
block_size : None ,
139
144
num_groups : 0 ,
145
+ emit_blocks_ctx : EmitBlocksContext :: new ( ) ,
140
146
}
141
147
}
142
148
}
@@ -184,7 +190,6 @@ where
184
190
185
191
fn len ( & self ) -> usize {
186
192
self . num_groups
187
- // self.values.iter().map(|block| block.len()).sum::<usize>()
188
193
}
189
194
190
195
fn emit ( & mut self , emit_to : EmitTo ) -> Result < Vec < ArrayRef > > {
@@ -263,14 +268,27 @@ where
263
268
. block_size
264
269
. expect ( "only support EmitTo::Next in blocked group values" ) ;
265
270
271
+ // Mark that emitting has already started, to prevent new updates
272
+ if !self . emit_blocks_ctx . emitting ( ) {
273
+ let num_blocks = self . num_groups . div_ceil ( block_size) ;
274
+ self . emit_blocks_ctx . start_emit ( num_blocks) ;
275
+ }
276
+
266
277
// Similar to `EmitTo::All`, we clear the old index info both
267
278
// in `map` and `null_group`
268
279
self . map . clear ( ) ;
269
280
270
- // Get current emit block id firstly
271
- let emit_block_id = self . next_emit_block_id ;
281
+ // First, get the index of the block to emit
282
+ let emit_block_id = self . emit_blocks_ctx . cur_emit_block ( ) ;
272
283
let emit_blk = std:: mem:: take ( & mut self . values [ emit_block_id] ) ;
273
- self . next_emit_block_id += 1 ;
284
+ // Then advance to the next block index
285
+ self . emit_blocks_ctx . advance_emit_block ( ) ;
286
+ // Finally, check whether all blocks have been emitted; if so, reset the
287
+ // emit context to allow new updates
288
+ if self . emit_blocks_ctx . all_emitted ( ) {
289
+ self . emit_blocks_ctx . reset ( ) ;
290
+ self . values . clear ( ) ;
291
+ }
274
292
275
293
// Check if `null` is in current block
276
294
let null_block_pair_opt = self . null_group . map ( |group_index| {
@@ -430,8 +448,9 @@ mod tests {
430
448
431
449
use crate :: aggregates:: group_values:: single_group_by:: primitive:: GroupValuesPrimitive ;
432
450
use crate :: aggregates:: group_values:: GroupValues ;
433
- use arrow:: array:: { AsArray , UInt32Array } ;
451
+ use arrow:: array:: { AsArray , RecordBatch , UInt32Array } ;
434
452
use arrow:: datatypes:: { DataType , UInt32Type } ;
453
+ use arrow_schema:: Schema ;
435
454
use datafusion_expr:: EmitTo ;
436
455
use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: group_index_operations:: {
437
456
BlockedGroupIndexOperations , GroupIndexOperations ,
@@ -567,6 +586,7 @@ mod tests {
567
586
assert_eq ! ( actual, expected) ;
568
587
569
588
// Insert case 1.1~1.2 + Emit case 2.2
589
+ group_values. clear_shrink ( & RecordBatch :: new_empty ( Arc :: new ( Schema :: empty ( ) ) ) ) ;
570
590
group_values
571
591
. intern ( & [ Arc :: clone ( & data2) as _ ] , & mut group_indices)
572
592
. unwrap ( ) ;
0 commit comments