Skip to content

Commit 42daa94

Browse files
committed
Support convert_to_state for AVG accumulator
1 parent 6aad19f commit 42daa94

File tree

4 files changed

+145
-1
lines changed

4 files changed

+145
-1
lines changed

datafusion/functions-aggregate/src/average.rs

+29-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
2020
use arrow::array::{
2121
self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
22-
AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
22+
AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
2323
};
24+
2425
use arrow::compute::sum;
2526
use arrow::datatypes::{
2627
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
@@ -35,6 +36,9 @@ use datafusion_expr::{
3536
Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
3637
};
3738
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;
39+
use datafusion_physical_expr_common::aggregate::groups_accumulator::nulls::{
40+
filtered_null_mask, set_nulls,
41+
};
3842
use datafusion_physical_expr_common::aggregate::utils::DecimalAverager;
3943
use log::debug;
4044
use std::any::Any;
@@ -554,6 +558,30 @@ where
554558
Ok(())
555559
}
556560

561+
/// Converts a batch of raw input values directly into this accumulator's
/// intermediate state representation, without touching any per-group state
/// (note the `&self` receiver).
///
/// Every input row becomes its own single-row partial aggregate: the row's
/// value is used as the partial sum and `1` as the partial count.
///
/// # Arguments
/// * `values` - input arrays; only `values[0]` is read.
/// * `opt_filter` - optional filter; see `filtered_null_mask` for how it is
///   combined with the input's own null mask.
///
/// # Returns
/// Two arrays of the same length as the input, in the order
/// `[counts, sums]`, sharing one validity mask.
///
/// NOTE(review): `values[0]` is indexed directly, so an empty `values` slice
/// would panic — presumably callers always pass exactly one argument array
/// for AVG; confirm.
fn convert_to_state(
562+
&self,
563+
values: &[ArrayRef],
564+
opt_filter: Option<&BooleanArray>,
565+
) -> Result<Vec<ArrayRef>> {
566+
// The input values themselves are the partial sums; take an owned copy of
// the array and retype it to the accumulator's sum data type.
let sums = values[0]
567+
.as_primitive::<T>()
568+
.clone()
569+
.with_data_type(self.sum_data_type.clone());
570+
// Each row contributes exactly one value to its partial count.
let counts = UInt64Array::from_value(1, sums.len());
571+
572+
// Validity mask combining the input's nulls with the optional filter.
let nulls = filtered_null_mask(opt_filter, &sums);
573+
574+
// Apply the same mask to both state arrays so that a row that is null in
// the input or rejected by the filter is null in both counts and sums.
575+
let counts = set_nulls(counts, nulls.clone());
576+
let sums = set_nulls(sums, nulls);
577+
578+
// State order is counts first, then sums.
Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
579+
}
580+
581+
/// Reports that this accumulator provides an implementation of
/// `convert_to_state`; always returns `true`.
fn supports_convert_to_state(&self) -> bool {
582+
true
583+
}
584+
557585
fn size(&self) -> usize {
558586
self.counts.capacity() * std::mem::size_of::<u64>()
559587
+ self.sums.capacity() * std::mem::size_of::<T>()

datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
2020
pub mod accumulate;
2121
pub mod bool_op;
22+
pub mod nulls;
2223
pub mod prim_op;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! XX utlities for working with nulls
19+
20+
use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray};
21+
use arrow::buffer::NullBuffer;
22+
23+
/// Sets the validity mask for a `PrimitiveArray` to `nulls`,
24+
/// replacing any existing null mask.
///
/// The array is consumed; its values and data type are carried over to the
/// returned array unchanged, and the previous null mask is discarded.
/// Passing `None` yields an array without a null buffer.
///
/// NOTE(review): `PrimitiveArray::new` is expected to panic if `nulls` has a
/// different length than the array — confirm against the arrow docs.
25+
pub fn set_nulls<T: ArrowNumericType + Send>(
26+
array: PrimitiveArray<T>,
27+
nulls: Option<NullBuffer>,
28+
) -> PrimitiveArray<T> {
29+
// Split the array apart; the old null mask is intentionally dropped.
let (dt, values, _old_nulls) = array.into_parts();
30+
// `new` only receives the values and nulls, so the original data type is
// re-applied afterwards.
PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
31+
}
32+
33+
/// Converts a `BooleanArray` representing a filter to a `NullBuffer`.
34+
///
35+
/// The `NullBuffer` is
36+
/// * `true` (representing valid) for values that were `true` in filter
37+
/// * `false` (representing null) for values that were `false` or `null` in filter
38+
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
39+
// Split an owned copy of the filter into its boolean values and its own
// (optional) null mask.
let (filter_bools, filter_nulls) = filter.clone().into_parts();
40+
// Reinterpret the filter's boolean values as a validity mask.
let filter_bools = NullBuffer::from(filter_bools);
41+
// Combine with the filter's null mask so that a null filter entry is
// treated like a `false` one.
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
42+
}
43+
44+
/// Compute an output validity mask for an array that has been filtered
45+
///
46+
/// This can be used to compute nulls for the output of
47+
/// [`GroupsAccumulator::convert_to_state`], which quickly applies an optional
48+
/// filter to the input rows by setting any filtered rows to NULL in the output.
49+
/// Subsequent applications of aggregate functions that ignore NULLs (most of
50+
/// them) will thus ignore the filtered rows as well.
51+
///
52+
/// # Output element is `true`
53+
/// * A `true` in the output represents non null output for all values that were both:
54+
///
55+
/// * `true` in any `opt_filter` (aka values that passed the filter)
56+
///
57+
/// * `non null` in `input`
58+
///
59+
/// # Output element is `false`
60+
/// * is false (null) for all values that were `false` or `null` in the filter, or null in the input
61+
///
62+
/// # Example
63+
///
64+
/// ```text
65+
/// ┌─────┐           ┌─────┐            ┌─────┐
66+
/// │true │           │NULL │            │false│
67+
/// │true │    │      │true │            │true │
68+
/// │true │ ───┼───   │false│  ────────▶ │false│       filtered_nulls
69+
/// │false│    │      │NULL │            │false│
70+
/// │false│           │true │            │false│
71+
/// └─────┘           └─────┘            └─────┘
72+
/// array           opt_filter           output nulls
73+
///  .nulls()
74+
///
75+
/// false = NULL    true  = pass         false = NULL       Meanings
76+
/// true  = valid   false = filter       true  = valid
77+
///                 NULL  = filter
78+
/// ```
79+
///
80+
/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr::groups_accumulator::GroupsAccumulator::convert_to_state
81+
pub fn filtered_null_mask(
82+
opt_filter: Option<&BooleanArray>,
83+
input: &dyn Array,
84+
) -> Option<NullBuffer> {
85+
// Turn the filter (if any) into a validity mask; with no filter there is
// nothing to combine and only the input's own nulls are used.
let opt_filter = opt_filter.and_then(filter_to_nulls);
86+
// A row stays valid only if it passed the filter and is non-null in `input`.
NullBuffer::union(opt_filter.as_ref(), input.nulls())
87+
}

datafusion/sqllogictest/test_files/aggregate_skip_partial.slt

+28
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,20 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
209209
4 29 9.531112968922
210210
5 -194 7.074412226677
211211

212+
# Test avg for bigint / float
213+
query IRR
214+
SELECT
215+
c2,
216+
avg(c10),
217+
avg(c11)
218+
FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
219+
----
220+
1 9803675241365398000 0.552420626987
221+
2 6843194947657418000 0.435355657881
222+
3 10700987547561746000 0.504783755855
223+
4 7199224282513318000 0.41439621604
224+
5 9295051061697067000 0.505315159048
225+
212226
# Enabling PG dialect for filtered aggregates tests
213227
statement ok
214228
set datafusion.sql_parser.dialect = 'Postgres';
@@ -267,6 +281,20 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
267281
4 11 14
268282
5 8 7
269283

284+
# Test avg for bigint / float with filter
285+
query IRR
286+
SELECT
287+
c2,
288+
avg(c10) FILTER (WHERE c2 != 'e'),
289+
avg(c11) FILTER (WHERE c2 != 'e')
290+
FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
291+
----
292+
1 9803675241365398000 0.552420626987
293+
2 6843194947657418000 0.435355657881
294+
3 10700987547561746000 0.504783755855
295+
4 7199224282513318000 0.41439621604
296+
5 9295051061697067000 0.505315159048
297+
270298
# Test count with nullable fields and nullable filter
271299
query III
272300
SELECT c2,

0 commit comments

Comments
 (0)