Skip to content

Commit 2cd5ea9

Browse files
committed
support dynamic dispatching for NullState.
1 parent 72a44a8 commit 2cd5ea9

File tree

1 file changed

+199
-71
lines changed
  • datafusion/functions-aggregate-common/src/aggregate/groups_accumulator

1 file changed

+199
-71
lines changed

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

+199-71
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,6 @@ pub struct NullState<V: SeenValues, O: GroupIndexOperations> {
7474
}
7575

7676
impl<V: SeenValues, O: GroupIndexOperations> NullState<V, O> {
77-
pub fn new() -> Self {
78-
Self {
79-
seen_values: V::default(),
80-
_phantom: PhantomData {},
81-
}
82-
}
83-
8477
/// return the size of all buffers allocated by this null state, not including self
8578
pub fn size(&self) -> usize {
8679
// capacity is in bits, so convert to bytes
@@ -237,66 +230,10 @@ impl<V: SeenValues, O: GroupIndexOperations> NullState<V, O> {
237230
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
238231
self.seen_values.emit(emit_to)
239232
}
240-
241-
/// Clone and build a single [`BooleanBuffer`] from `seen_values`,
242-
/// only used for testing.
243-
#[cfg(test)]
244-
fn build_cloned_seen_values(&self) -> BooleanBuffer {
245-
if let Some(seen_values) =
246-
self.seen_values.as_any().downcast_ref::<FlatSeenValues>()
247-
{
248-
seen_values.builder.finish_cloned()
249-
} else if let Some(seen_values) = self
250-
.seen_values
251-
.as_any()
252-
.downcast_ref::<BlockedSeenValues>()
253-
{
254-
let mut return_builder = BooleanBufferBuilder::new(0);
255-
for builder in &seen_values.blocked_builders {
256-
for idx in 0..builder.len() {
257-
return_builder.append(builder.get_bit(idx));
258-
}
259-
}
260-
return_builder.finish()
261-
} else {
262-
unreachable!("unknown impl of SeenValues")
263-
}
264-
}
265-
266-
/// Emit a single [`NullBuffer`], only used for testing.
267-
#[cfg(test)]
268-
fn emit_all_in_once(&mut self, total_num_groups: usize) -> NullBuffer {
269-
if let Some(seen_values) =
270-
self.seen_values.as_any().downcast_ref::<FlatSeenValues>()
271-
{
272-
seen_values.emit(EmitTo::All)
273-
} else if let Some(seen_values) = self
274-
.seen_values
275-
.as_any()
276-
.downcast_ref::<BlockedSeenValues>()
277-
{
278-
let mut return_builder = BooleanBufferBuilder::new(0);
279-
let num_blocks = seen_values.blocked_builders.len();
280-
for _ in 0..num_blocks {
281-
let blocked_nulls = seen_values.emit(EmitTo::NextBlock(true));
282-
for bit in blocked_nulls.inner().iter() {
283-
return_builder.append(bit);
284-
}
285-
}
286-
287-
NullBuffer::new(return_builder.finish())
288-
} else {
289-
unreachable!("unknown impl of SeenValues")
290-
}
291-
}
292233
}
293234

294235
/// Structure marking if accumulating groups are seen at least one
295236
pub trait SeenValues: Default + Debug + Send {
296-
fn as_any(&self) -> &dyn std::any::Any {
297-
self
298-
}
299-
300237
fn resize(&mut self, total_num_groups: usize, default_value: bool);
301238

302239
fn set_bit(&mut self, block_id: u32, block_offset: u64, value: bool);
@@ -401,6 +338,15 @@ pub struct BlockedSeenValues {
401338
block_size: usize,
402339
}
403340

341+
impl BlockedSeenValues {
342+
pub fn new(block_size: usize) -> Self {
343+
Self {
344+
blocked_builders: VecDeque::new(),
345+
block_size,
346+
}
347+
}
348+
}
349+
404350
impl SeenValues for BlockedSeenValues {
405351
fn resize(&mut self, total_num_groups: usize, default_value: bool) {
406352
let block_size = self.block_size;
@@ -471,7 +417,10 @@ impl SeenValues for BlockedSeenValues {
471417
fn emit(&mut self, emit_to: EmitTo) -> NullBuffer {
472418
assert!(matches!(emit_to, EmitTo::NextBlock(_)));
473419

474-
let mut block = self.blocked_builders.pop_front().expect("");
420+
let mut block = self
421+
.blocked_builders
422+
.pop_front()
423+
.expect("should not try to emit empty blocks");
475424
let nulls = block.finish();
476425

477426
NullBuffer::new(nulls)
@@ -485,9 +434,148 @@ impl SeenValues for BlockedSeenValues {
485434
}
486435
}
487436

437+
/// Adapter for supporting dynamic dispatching of [`FlatNullState`] and [`BlockedNullState`].
438+
/// For performance, the cost of batch-level dynamic dispatching is acceptable.
439+
pub enum NullStateAdapter {
440+
Flat(FlatNullState),
441+
Blocked(BlockedNullState),
442+
}
443+
444+
impl NullStateAdapter {
445+
pub fn new(block_size: Option<usize>) -> Self {
446+
if let Some(blk_size) = block_size {
447+
Self::Blocked(BlockedNullState::new(blk_size))
448+
} else {
449+
Self::Flat(FlatNullState::new())
450+
}
451+
}
452+
453+
pub fn accumulate<T, F>(
454+
&mut self,
455+
group_indices: &[usize],
456+
values: &PrimitiveArray<T>,
457+
opt_filter: Option<&BooleanArray>,
458+
total_num_groups: usize,
459+
value_fn: F,
460+
) where
461+
T: ArrowPrimitiveType + Send,
462+
F: FnMut(u32, u64, T::Native) + Send,
463+
{
464+
match self {
465+
NullStateAdapter::Flat(null_state) => null_state.accumulate(
466+
group_indices,
467+
values,
468+
opt_filter,
469+
total_num_groups,
470+
value_fn,
471+
),
472+
NullStateAdapter::Blocked(null_state) => null_state.accumulate(
473+
group_indices,
474+
values,
475+
opt_filter,
476+
total_num_groups,
477+
value_fn,
478+
),
479+
}
480+
}
481+
482+
pub fn accumulate_boolean<F>(
483+
&mut self,
484+
group_indices: &[usize],
485+
values: &BooleanArray,
486+
opt_filter: Option<&BooleanArray>,
487+
total_num_groups: usize,
488+
value_fn: F,
489+
) where
490+
F: FnMut(u32, u64, bool) + Send,
491+
{
492+
match self {
493+
NullStateAdapter::Flat(null_state) => null_state.accumulate_boolean(
494+
group_indices,
495+
values,
496+
opt_filter,
497+
total_num_groups,
498+
value_fn,
499+
),
500+
NullStateAdapter::Blocked(null_state) => null_state.accumulate_boolean(
501+
group_indices,
502+
values,
503+
opt_filter,
504+
total_num_groups,
505+
value_fn,
506+
),
507+
}
508+
}
509+
510+
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
511+
match self {
512+
NullStateAdapter::Flat(null_state) => null_state.build(emit_to),
513+
NullStateAdapter::Blocked(null_state) => null_state.build(emit_to),
514+
}
515+
}
516+
517+
/// Clone and build a single [`BooleanBuffer`] from `seen_values`,
518+
/// only used for testing.
519+
#[cfg(test)]
520+
fn build_cloned_seen_values(&self) -> BooleanBuffer {
521+
match self {
522+
NullStateAdapter::Flat(null_state) => {
523+
null_state.seen_values.builder.finish_cloned()
524+
}
525+
NullStateAdapter::Blocked(null_state) => {
526+
let mut return_builder = BooleanBufferBuilder::new(0);
527+
for builder in &null_state.seen_values.blocked_builders {
528+
for idx in 0..builder.len() {
529+
return_builder.append(builder.get_bit(idx));
530+
}
531+
}
532+
return_builder.finish()
533+
}
534+
}
535+
}
536+
537+
#[cfg(test)]
538+
fn build_all_in_once(&mut self) -> NullBuffer {
539+
match self {
540+
NullStateAdapter::Flat(null_state) => null_state.build(EmitTo::All),
541+
NullStateAdapter::Blocked(null_state) => {
542+
let mut return_builder = BooleanBufferBuilder::new(0);
543+
let num_blocks = null_state.seen_values.blocked_builders.len();
544+
for _ in 0..num_blocks {
545+
let blocked_nulls = null_state.build(EmitTo::NextBlock(true));
546+
for bit in blocked_nulls.inner().iter() {
547+
return_builder.append(bit);
548+
}
549+
}
550+
551+
NullBuffer::new(return_builder.finish())
552+
}
553+
}
554+
}
555+
}
556+
488557
pub type FlatNullState = NullState<FlatSeenValues, FlatGroupIndexOperations>;
558+
559+
impl FlatNullState {
560+
pub fn new() -> Self {
561+
Self {
562+
seen_values: FlatSeenValues::default(),
563+
_phantom: PhantomData {},
564+
}
565+
}
566+
}
567+
489568
pub type BlockedNullState = NullState<BlockedSeenValues, BlockedGroupIndexOperations>;
490569

570+
impl BlockedNullState {
571+
pub fn new(block_size: usize) -> Self {
572+
Self {
573+
seen_values: BlockedSeenValues::new(block_size),
574+
_phantom: PhantomData {},
575+
}
576+
}
577+
}
578+
491579
/// Invokes `value_fn(group_index, value)` for each non null, non
492580
/// filtered value of `value`,
493581
///
@@ -873,6 +961,7 @@ mod test {
873961
values,
874962
values_with_nulls,
875963
filter,
964+
block_size: None,
876965
}
877966
.run()
878967
}
@@ -953,6 +1042,7 @@ mod test {
9531042
values,
9541043
values_with_nulls,
9551044
filter,
1045+
block_size: None,
9561046
}
9571047
}
9581048

@@ -977,14 +1067,21 @@ mod test {
9771067
let filter = &self.filter;
9781068

9791069
// no null, no filters
980-
Self::accumulate_test(group_indices, &values_array, None, total_num_groups);
1070+
Self::accumulate_test(
1071+
group_indices,
1072+
&values_array,
1073+
None,
1074+
total_num_groups,
1075+
self.block_size,
1076+
);
9811077

9821078
// nulls, no filters
9831079
Self::accumulate_test(
9841080
group_indices,
9851081
&values_with_nulls_array,
9861082
None,
9871083
total_num_groups,
1084+
self.block_size,
9881085
);
9891086

9901087
// no nulls, filters
@@ -993,6 +1090,7 @@ mod test {
9931090
&values_array,
9941091
Some(filter),
9951092
total_num_groups,
1093+
self.block_size,
9961094
);
9971095

9981096
// nulls, filters
@@ -1001,6 +1099,7 @@ mod test {
10011099
&values_with_nulls_array,
10021100
Some(filter),
10031101
total_num_groups,
1102+
self.block_size,
10041103
);
10051104
}
10061105

@@ -1012,12 +1111,14 @@ mod test {
10121111
values: &UInt32Array,
10131112
opt_filter: Option<&BooleanArray>,
10141113
total_num_groups: usize,
1114+
block_size: Option<usize>,
10151115
) {
10161116
Self::accumulate_values_test(
10171117
group_indices,
10181118
values,
10191119
opt_filter,
10201120
total_num_groups,
1121+
block_size,
10211122
);
10221123
Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter);
10231124

@@ -1041,17 +1142,44 @@ mod test {
10411142
values: &UInt32Array,
10421143
opt_filter: Option<&BooleanArray>,
10431144
total_num_groups: usize,
1145+
block_size: Option<usize>,
10441146
) {
10451147
let mut accumulated_values = vec![];
1046-
let mut null_state = FlatNullState::new();
1148+
let (mut null_state, block_size, acc_group_indices) = if let Some(blk_size) =
1149+
block_size
1150+
{
1151+
let acc_group_indices = group_indices
1152+
.iter()
1153+
.copied()
1154+
.map(|index| {
1155+
let block_id = (index / blk_size) as u32;
1156+
let block_offset = (index % blk_size) as u64;
1157+
BlockedGroupIndexOperations::pack_index(block_id, block_offset)
1158+
as usize
1159+
})
1160+
.collect::<Vec<_>>();
1161+
(
1162+
NullStateAdapter::new(Some(blk_size)),
1163+
blk_size,
1164+
acc_group_indices,
1165+
)
1166+
} else {
1167+
(
1168+
NullStateAdapter::new(None),
1169+
0,
1170+
group_indices.iter().copied().collect(),
1171+
)
1172+
};
10471173

10481174
null_state.accumulate(
1049-
group_indices,
1175+
&acc_group_indices,
10501176
values,
10511177
opt_filter,
10521178
total_num_groups,
1053-
|_, group_index, value| {
1054-
accumulated_values.push((group_index as usize, value));
1179+
|block_id, block_offset, value| {
1180+
let flatten_index =
1181+
((block_id as u64 * block_size as u64) + block_offset) as usize;
1182+
accumulated_values.push((flatten_index as usize, value));
10551183
},
10561184
);
10571185

@@ -1087,13 +1215,13 @@ mod test {
10871215

10881216
assert_eq!(accumulated_values, expected_values,
10891217
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}");
1090-
let seen_values = null_state.seen_values.builder.finish_cloned();
1218+
let seen_values = null_state.build_cloned_seen_values();
10911219
mock.validate_seen_values(&seen_values);
10921220

10931221
// Validate the final buffer (one value per group)
10941222
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
10951223

1096-
let null_buffer = null_state.build(EmitTo::All);
1224+
let null_buffer = null_state.build_all_in_once();
10971225

10981226
assert_eq!(null_buffer, expected_null_buffer);
10991227
}

0 commit comments

Comments
 (0)