Skip to content

Commit dc94961

Browse files
committed
extract EmitBlocksContext to common, and prevent all new updates during emitting.
1 parent da4c590 commit dc94961

File tree

2 files changed

+86
-8
lines changed

2 files changed

+86
-8
lines changed

datafusion/expr-common/src/groups_accumulator.rs

+58
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,64 @@ impl EmitTo {
6464
}
6565
}
6666

67+
/// Emitting context used in blocked management
68+
#[derive(Debug, Default)]
69+
pub struct EmitBlocksContext {
70+
/// Mark if it is during blocks emitting, if so states can't
71+
/// be updated until all blocks are emitted
72+
pub emitting: bool,
73+
74+
/// Idx of next emitted block
75+
pub next_emit_block: usize,
76+
77+
/// Number of blocks needed to emit
78+
pub num_blocks: usize,
79+
}
80+
81+
impl EmitBlocksContext {
82+
#[inline]
83+
pub fn new() -> Self {
84+
Self::default()
85+
}
86+
87+
#[inline]
88+
pub fn start_emit(&mut self, num_blocks: usize) {
89+
self.emitting = true;
90+
self.num_blocks = num_blocks;
91+
}
92+
93+
#[inline]
94+
pub fn emitting(&self) -> bool {
95+
self.emitting
96+
}
97+
98+
#[inline]
99+
pub fn all_emitted(&self) -> bool {
100+
self.next_emit_block == self.num_blocks
101+
}
102+
103+
#[inline]
104+
pub fn cur_emit_block(&self) -> usize {
105+
assert!(self.emitting, "must start emit first");
106+
self.next_emit_block
107+
}
108+
109+
#[inline]
110+
pub fn advance_emit_block(&mut self) {
111+
assert!(self.emitting, "must start emit first");
112+
if self.next_emit_block < self.num_blocks {
113+
self.next_emit_block += 1;
114+
}
115+
}
116+
117+
#[inline]
118+
pub fn reset(&mut self) {
119+
self.emitting = false;
120+
self.next_emit_block = 0;
121+
self.num_blocks = 0;
122+
}
123+
}
124+
67125
/// `GroupsAccumulator` implements a single aggregate (e.g. AVG) and
68126
/// stores the state for *all* groups internally.
69127
///

datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs

+28-8
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use arrow::datatypes::{i256, DataType};
2626
use arrow::record_batch::RecordBatch;
2727
use datafusion_common::Result;
2828
use datafusion_execution::memory_pool::proxy::VecAllocExt;
29+
use datafusion_expr::groups_accumulator::EmitBlocksContext;
2930
use datafusion_expr::EmitTo;
3031
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::group_index_operations::{
3132
BlockedGroupIndexOperations, FlatGroupIndexOperations, GroupIndexOperations,
@@ -101,8 +102,6 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
101102
/// The values for each group index
102103
values: Vec<Vec<T::Native>>,
103104

104-
next_emit_block_id: usize,
105-
106105
/// The random state used to generate hashes
107106
random_state: RandomState,
108107

@@ -116,7 +115,14 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
116115
///
117116
block_size: Option<usize>,
118117

118+
/// Number of current storing groups
119+
///
120+
/// We maintain it to avoid the expansive dynamic computation in
121+
/// `blocked approach`.
119122
num_groups: usize,
123+
124+
/// Context used in emitting in `blocked approach`
125+
emit_blocks_ctx: EmitBlocksContext,
120126
}
121127

122128
impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
@@ -132,11 +138,11 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
132138
data_type,
133139
map: HashTable::with_capacity(128),
134140
values,
135-
next_emit_block_id: 0,
136141
null_group: None,
137142
random_state: Default::default(),
138143
block_size: None,
139144
num_groups: 0,
145+
emit_blocks_ctx: EmitBlocksContext::new(),
140146
}
141147
}
142148
}
@@ -184,7 +190,6 @@ where
184190

185191
fn len(&self) -> usize {
186192
self.num_groups
187-
// self.values.iter().map(|block| block.len()).sum::<usize>()
188193
}
189194

190195
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
@@ -263,14 +268,27 @@ where
263268
.block_size
264269
.expect("only support EmitTo::Next in blocked group values");
265270

271+
// To mark the emitting has already started, and prevent new updates
272+
if !self.emit_blocks_ctx.emitting() {
273+
let num_blocks = self.num_groups.div_ceil(block_size);
274+
self.emit_blocks_ctx.start_emit(num_blocks);
275+
}
276+
266277
// Similar as `EmitTo:All`, we will clear the old index infos both
267278
// in `map` and `null_group`
268279
self.map.clear();
269280

270-
// Get current emit block id firstly
271-
let emit_block_id = self.next_emit_block_id;
281+
// Get current emit block idx firstly
282+
let emit_block_id = self.emit_blocks_ctx.cur_emit_block();
272283
let emit_blk = std::mem::take(&mut self.values[emit_block_id]);
273-
self.next_emit_block_id += 1;
284+
// And then we advance the block idx
285+
self.emit_blocks_ctx.advance_emit_block();
286+
// Finally we check if all blocks emitted, if so, we reset the
287+
// emit context to allow new updates
288+
if self.emit_blocks_ctx.all_emitted() {
289+
self.emit_blocks_ctx.reset();
290+
self.values.clear();
291+
}
274292

275293
// Check if `null` is in current block
276294
let null_block_pair_opt = self.null_group.map(|group_index| {
@@ -430,8 +448,9 @@ mod tests {
430448

431449
use crate::aggregates::group_values::single_group_by::primitive::GroupValuesPrimitive;
432450
use crate::aggregates::group_values::GroupValues;
433-
use arrow::array::{AsArray, UInt32Array};
451+
use arrow::array::{AsArray, RecordBatch, UInt32Array};
434452
use arrow::datatypes::{DataType, UInt32Type};
453+
use arrow_schema::Schema;
435454
use datafusion_expr::EmitTo;
436455
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::group_index_operations::{
437456
BlockedGroupIndexOperations, GroupIndexOperations,
@@ -567,6 +586,7 @@ mod tests {
567586
assert_eq!(actual, expected);
568587

569588
// Insert case 1.1~1.2 + Emit case 2.2
589+
group_values.clear_shrink(&RecordBatch::new_empty(Arc::new(Schema::empty())));
570590
group_values
571591
.intern(&[Arc::clone(&data2) as _], &mut group_indices)
572592
.unwrap();

0 commit comments

Comments
 (0)