Skip to content

Commit 7f529b9

Browse files
committed
use EmitBlocksContext to refactor BlockedNullState.
1 parent 9d0b73b commit 7f529b9

File tree

1 file changed

+115
-38
lines changed
  • datafusion/functions-aggregate-common/src/aggregate/groups_accumulator

1 file changed

+115
-38
lines changed

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

+115-38
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
2727
use arrow::buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
2828
use arrow::datatypes::ArrowPrimitiveType;
2929

30-
use datafusion_expr_common::groups_accumulator::EmitTo;
30+
use datafusion_expr_common::groups_accumulator::{EmitBlocksContext, EmitTo};
3131

3232
use crate::aggregate::groups_accumulator::blocks::{Block, Blocks};
3333
use crate::aggregate::groups_accumulator::group_index_operations::{
@@ -72,20 +72,10 @@ pub struct NullState<O: GroupIndexOperations> {
7272

7373
block_size: Option<usize>,
7474

75-
emit_context: Option<EmitBlocksContext>,
76-
7775
/// phantom data for required type `<O>`
7876
_phantom: PhantomData<O>,
7977
}
8078

81-
#[derive(Debug)]
82-
struct EmitBlocksContext {
83-
next_emit_block_id: usize,
84-
last_block_len: usize,
85-
num_blocks: usize,
86-
buffer: BooleanBuffer,
87-
}
88-
8979
impl<O: GroupIndexOperations> NullState<O> {
9080
/// Returns the size of all buffers allocated by this null state, not including `self`
9181
pub fn size(&self) -> usize {
@@ -186,7 +176,8 @@ impl<O: GroupIndexOperations> NullState<O> {
186176
.for_each(|((&group_index, new_value), is_valid)| {
187177
if is_valid {
188178
let block_id = O::get_block_id(group_index, block_size);
189-
let block_offset = O::get_block_offset(group_index, block_size);
179+
let block_offset =
180+
O::get_block_offset(group_index, block_size);
190181
seen_values.set_bit(group_index, false);
191182
value_fn(block_id, block_offset, new_value);
192183
}
@@ -203,7 +194,8 @@ impl<O: GroupIndexOperations> NullState<O> {
203194
.for_each(|((&group_index, new_value), filter_value)| {
204195
if let Some(true) = filter_value {
205196
let block_id = O::get_block_id(group_index, block_size);
206-
let block_offset = O::get_block_offset(group_index, block_size);
197+
let block_offset =
198+
O::get_block_offset(group_index, block_size);
207199
seen_values.set_bit(group_index, false);
208200
value_fn(block_id, block_offset, new_value);
209201
}
@@ -220,7 +212,8 @@ impl<O: GroupIndexOperations> NullState<O> {
220212
if let Some(true) = filter_value {
221213
if let Some(new_value) = new_value {
222214
let block_id = O::get_block_id(group_index, block_size);
223-
let block_offset = O::get_block_offset(group_index, block_size);
215+
let block_offset =
216+
O::get_block_offset(group_index, block_size);
224217
seen_values.set_bit(group_index, false);
225218
value_fn(block_id, block_offset, new_value);
226219
}
@@ -264,6 +257,7 @@ impl NullStateAdapter {
264257
}
265258
}
266259

260+
#[inline]
267261
pub fn accumulate<T, F>(
268262
&mut self,
269263
group_indices: &[usize],
@@ -293,6 +287,7 @@ impl NullStateAdapter {
293287
}
294288
}
295289

290+
#[inline]
296291
pub fn accumulate_boolean<F>(
297292
&mut self,
298293
group_indices: &[usize],
@@ -321,13 +316,15 @@ impl NullStateAdapter {
321316
}
322317
}
323318

319+
#[inline]
324320
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
325321
match self {
326322
NullStateAdapter::Flat(null_state) => null_state.build(emit_to),
327323
NullStateAdapter::Blocked(null_state) => null_state.build(),
328324
}
329325
}
330326

327+
#[inline]
331328
pub fn size(&self) -> usize {
332329
match self {
333330
NullStateAdapter::Flat(null_state) => null_state.size(),
@@ -411,7 +408,6 @@ impl Default for FlatNullState {
411408
Self {
412409
seen_values: BooleanBufferBuilder::new(0),
413410
block_size: None,
414-
emit_context: None,
415411
_phantom: PhantomData,
416412
}
417413
}
@@ -460,56 +456,137 @@ impl FlatNullState {
460456
///
461457
/// [`GroupsAccumulator::supports_blocked_groups`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator::supports_blocked_groups
462458
///
463-
pub type BlockedNullState = NullState<BlockedGroupIndexOperations>;
459+
#[derive(Debug)]
460+
pub struct BlockedNullState {
461+
inner: NullState<BlockedGroupIndexOperations>,
462+
emit_ctx: NullsEmitContext,
463+
}
464+
465+
#[derive(Debug, Default)]
466+
struct NullsEmitContext {
467+
base_ctx: EmitBlocksContext,
468+
last_block_len: usize,
469+
buffer: Option<BooleanBuffer>,
470+
}
471+
472+
impl NullsEmitContext {
473+
fn new() -> Self {
474+
Self::default()
475+
}
476+
}
464477

465478
impl BlockedNullState {
466479
pub fn new(block_size: usize) -> Self {
467-
Self {
480+
let inner = NullState {
468481
seen_values: BooleanBufferBuilder::new(0),
469482
block_size: Some(block_size),
470-
emit_context: None,
471483
_phantom: PhantomData {},
472-
}
484+
};
485+
486+
let emit_ctx = NullsEmitContext::new();
487+
488+
Self { inner, emit_ctx }
489+
}
490+
491+
#[inline]
492+
pub fn accumulate<T, F>(
493+
&mut self,
494+
group_indices: &[usize],
495+
values: &PrimitiveArray<T>,
496+
opt_filter: Option<&BooleanArray>,
497+
total_num_groups: usize,
498+
value_fn: F,
499+
) where
500+
T: ArrowPrimitiveType + Send,
501+
F: FnMut(usize, usize, T::Native) + Send,
502+
{
503+
assert!(!self.emit_ctx.base_ctx.emitting());
504+
self.inner.accumulate(
505+
group_indices,
506+
values,
507+
opt_filter,
508+
total_num_groups,
509+
value_fn,
510+
);
511+
}
512+
513+
#[inline]
514+
pub fn accumulate_boolean<F>(
515+
&mut self,
516+
group_indices: &[usize],
517+
values: &BooleanArray,
518+
opt_filter: Option<&BooleanArray>,
519+
total_num_groups: usize,
520+
value_fn: F,
521+
) where
522+
F: FnMut(usize, usize, bool) + Send,
523+
{
524+
assert!(!self.emit_ctx.base_ctx.emitting());
525+
self.inner.accumulate_boolean(
526+
group_indices,
527+
values,
528+
opt_filter,
529+
total_num_groups,
530+
value_fn,
531+
);
473532
}
474-
}
475533

476-
impl BlockedNullState {
477534
pub fn build(&mut self) -> NullBuffer {
478-
let block_size = self.block_size.unwrap();
535+
let block_size = self.inner.block_size.unwrap();
479536

480-
if self.emit_context.is_none() {
481-
let buffer = self.seen_values.finish();
537+
if !self.emit_ctx.base_ctx.emitting() {
538+
// Initialize the emit context state (buffer and last block length)
539+
let buffer = self.inner.seen_values.finish();
482540
let num_blocks = buffer.len().div_ceil(block_size);
483541
let mut last_block_len = buffer.len() % block_size;
484542
last_block_len = if last_block_len > 0 {
485543
last_block_len
486544
} else {
487545
usize::MAX
488546
};
547+
self.emit_ctx.buffer = Some(buffer);
548+
self.emit_ctx.last_block_len = last_block_len;
489549

490-
self.emit_context = Some(EmitBlocksContext {
491-
next_emit_block_id: 0,
492-
last_block_len,
493-
num_blocks,
494-
buffer,
495-
});
550+
// Start emit
551+
self.emit_ctx.base_ctx.start_emit(num_blocks);
496552
}
497553

498-
let emit_context = self.emit_context.as_mut().unwrap();
499-
let cur_emit_block_id = emit_context.next_emit_block_id;
500-
emit_context.next_emit_block_id += 1;
554+
// Get current emit block idx
555+
let emit_block_id = self.emit_ctx.base_ctx.cur_emit_block();
556+
// And then we advance the block idx
557+
self.emit_ctx.base_ctx.advance_emit_block();
501558

502-
assert!(cur_emit_block_id < emit_context.num_blocks);
503-
let slice_offset = cur_emit_block_id * block_size;
504-
let slice_len = if cur_emit_block_id == emit_context.num_blocks - 1 {
505-
cmp::min(emit_context.last_block_len, block_size)
559+
// Process and generate the emit block
560+
let buffer = self.emit_ctx.buffer.as_ref().unwrap();
561+
let slice_offset = emit_block_id * block_size;
562+
let slice_len = if self.emit_ctx.base_ctx.all_emitted() {
563+
cmp::min(self.emit_ctx.last_block_len, block_size)
506564
} else {
507565
block_size
508566
};
567+
let emit_block = buffer.slice(slice_offset, slice_len);
568+
569+
// Finally, check whether all blocks have been emitted; if so, reset the
570+
// emit context to allow new updates
571+
if self.emit_ctx.base_ctx.all_emitted() {
572+
self.emit_ctx.base_ctx.reset();
573+
self.emit_ctx.buffer = None;
574+
self.emit_ctx.last_block_len = 0;
575+
}
509576

510-
let emit_block = emit_context.buffer.slice(slice_offset, slice_len);
511577
NullBuffer::new(emit_block)
512578
}
579+
580+
fn size(&self) -> usize {
581+
self.inner.size()
582+
+ size_of::<NullsEmitContext>()
583+
+ self
584+
.emit_ctx
585+
.buffer
586+
.as_ref()
587+
.map(|b| b.len() / 8)
588+
.unwrap_or_default()
589+
}
513590
}
514591

515592
/// Invokes `value_fn(group_index, value)` for each non null, non

0 commit comments

Comments
 (0)