Skip to content

Commit 6c785d1

Browse files
authored
Specialize Avg and Sum accumulators (#6842) (#7358)
* Specialize SUM and AVG (#6842)
* Specialize Distinct Sum
* Review feedback
* Update sqllogictest
1 parent 65821eb commit 6c785d1

File tree

9 files changed

+356
-412
lines changed

9 files changed

+356
-412
lines changed

datafusion/core/src/execution/context.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -2452,12 +2452,7 @@ mod tests {
24522452
vec![DataType::Float64],
24532453
Arc::new(DataType::Float64),
24542454
Volatility::Immutable,
2455-
Arc::new(|_| {
2456-
Ok(Box::new(AvgAccumulator::try_new(
2457-
&DataType::Float64,
2458-
&DataType::Float64,
2459-
)?))
2460-
}),
2455+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
24612456
Arc::new(vec![DataType::UInt64, DataType::Float64]),
24622457
);
24632458

datafusion/core/tests/sql/udf.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,7 @@ async fn simple_udaf() -> Result<()> {
237237
vec![DataType::Float64],
238238
Arc::new(DataType::Float64),
239239
Volatility::Immutable,
240-
Arc::new(|_| {
241-
Ok(Box::new(AvgAccumulator::try_new(
242-
&DataType::Float64,
243-
&DataType::Float64,
244-
)?))
245-
}),
240+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
246241
Arc::new(vec![DataType::UInt64, DataType::Float64]),
247242
);
248243

datafusion/optimizer/src/analyzer/type_coercion.rs

+3-12
Original file line numberDiff line numberDiff line change
@@ -906,12 +906,7 @@ mod test {
906906
vec![DataType::Float64],
907907
Arc::new(DataType::Float64),
908908
Volatility::Immutable,
909-
Arc::new(|_| {
910-
Ok(Box::new(AvgAccumulator::try_new(
911-
&DataType::Float64,
912-
&DataType::Float64,
913-
)?))
914-
}),
909+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
915910
Arc::new(vec![DataType::UInt64, DataType::Float64]),
916911
);
917912
let udaf = Expr::AggregateUDF(expr::AggregateUDF::new(
@@ -932,12 +927,8 @@ mod test {
932927
Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
933928
let state_type: StateTypeFunction =
934929
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
935-
let accumulator: AccumulatorFactoryFunction = Arc::new(|_| {
936-
Ok(Box::new(AvgAccumulator::try_new(
937-
&DataType::Float64,
938-
&DataType::Float64,
939-
)?))
940-
});
930+
let accumulator: AccumulatorFactoryFunction =
931+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
941932
let my_avg = AggregateUDF::new(
942933
"MY_AVG",
943934
&Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),

datafusion/physical-expr/src/aggregate/average.rs

+133-73
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,13 @@ use arrow::array::{AsArray, PrimitiveBuilder};
2121
use log::debug;
2222

2323
use std::any::Any;
24-
use std::convert::TryFrom;
2524
use std::sync::Arc;
2625

2726
use crate::aggregate::groups_accumulator::accumulate::NullState;
28-
use crate::aggregate::sum;
29-
use crate::aggregate::sum::sum_batch;
30-
use crate::aggregate::utils::calculate_result_decimal_for_avg;
3127
use crate::aggregate::utils::down_cast_any_ref;
3228
use crate::expressions::format_state_name;
3329
use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
34-
use arrow::compute;
30+
use arrow::compute::sum;
3531
use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type};
3632
use arrow::{
3733
array::{ArrayRef, UInt64Array},
@@ -40,9 +36,7 @@ use arrow::{
4036
use arrow_array::{
4137
Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray,
4238
};
43-
use datafusion_common::{
44-
downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
45-
};
39+
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
4640
use datafusion_expr::type_coercion::aggregates::avg_return_type;
4741
use datafusion_expr::Accumulator;
4842

@@ -87,11 +81,27 @@ impl AggregateExpr for Avg {
8781
}
8882

8983
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
90-
Ok(Box::new(AvgAccumulator::try_new(
91-
// avg is f64 or decimal
92-
&self.input_data_type,
93-
&self.result_data_type,
94-
)?))
84+
use DataType::*;
85+
// instantiate the specialized accumulator based on the input and result types
86+
match (&self.input_data_type, &self.result_data_type) {
87+
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
88+
(
89+
Decimal128(sum_precision, sum_scale),
90+
Decimal128(target_precision, target_scale),
91+
) => Ok(Box::new(DecimalAvgAccumulator {
92+
sum: None,
93+
count: 0,
94+
sum_scale: *sum_scale,
95+
sum_precision: *sum_precision,
96+
target_precision: *target_precision,
97+
target_scale: *target_scale,
98+
})),
99+
_ => not_impl_err!(
100+
"AvgAccumulator for ({} --> {})",
101+
self.input_data_type,
102+
self.result_data_type
103+
),
104+
}
95105
}
96106

97107
fn state_fields(&self) -> Result<Vec<Field>> {
@@ -122,10 +132,7 @@ impl AggregateExpr for Avg {
122132
}
123133

124134
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
125-
Ok(Box::new(AvgAccumulator::try_new(
126-
&self.input_data_type,
127-
&self.result_data_type,
128-
)?))
135+
self.create_accumulator()
129136
}
130137

131138
fn groups_accumulator_supported(&self) -> bool {
@@ -189,91 +196,141 @@ impl PartialEq<dyn Any> for Avg {
189196
}
190197

191198
/// An accumulator to compute the average
192-
#[derive(Debug)]
199+
#[derive(Debug, Default)]
193200
pub struct AvgAccumulator {
194-
// sum is used for null
195-
sum: ScalarValue,
196-
return_data_type: DataType,
201+
sum: Option<f64>,
197202
count: u64,
198203
}
199204

200-
impl AvgAccumulator {
201-
/// Creates a new `AvgAccumulator`
202-
pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result<Self> {
203-
Ok(Self {
204-
sum: ScalarValue::try_from(datatype)?,
205-
return_data_type: return_data_type.clone(),
206-
count: 0,
207-
})
205+
impl Accumulator for AvgAccumulator {
206+
fn state(&self) -> Result<Vec<ScalarValue>> {
207+
Ok(vec![
208+
ScalarValue::from(self.count),
209+
ScalarValue::Float64(self.sum),
210+
])
211+
}
212+
213+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
214+
let values = values[0].as_primitive::<Float64Type>();
215+
self.count += (values.len() - values.null_count()) as u64;
216+
if let Some(x) = sum(values) {
217+
let v = self.sum.get_or_insert(0.);
218+
*v += x;
219+
}
220+
Ok(())
221+
}
222+
223+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
224+
let values = values[0].as_primitive::<Float64Type>();
225+
self.count -= (values.len() - values.null_count()) as u64;
226+
if let Some(x) = sum(values) {
227+
self.sum = Some(self.sum.unwrap() - x);
228+
}
229+
Ok(())
230+
}
231+
232+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
233+
// counts are summed
234+
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
235+
236+
// sums are summed
237+
if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
238+
let v = self.sum.get_or_insert(0.);
239+
*v += x;
240+
}
241+
Ok(())
242+
}
243+
244+
fn evaluate(&self) -> Result<ScalarValue> {
245+
Ok(ScalarValue::Float64(
246+
self.sum.map(|f| f / self.count as f64),
247+
))
248+
}
249+
fn supports_retract_batch(&self) -> bool {
250+
true
251+
}
252+
253+
fn size(&self) -> usize {
254+
std::mem::size_of_val(self)
208255
}
209256
}
210257

211-
impl Accumulator for AvgAccumulator {
258+
/// An accumulator to compute the average for decimals
259+
#[derive(Debug)]
260+
struct DecimalAvgAccumulator {
261+
sum: Option<i128>,
262+
count: u64,
263+
sum_scale: i8,
264+
sum_precision: u8,
265+
target_precision: u8,
266+
target_scale: i8,
267+
}
268+
269+
impl Accumulator for DecimalAvgAccumulator {
212270
fn state(&self) -> Result<Vec<ScalarValue>> {
213-
Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
271+
Ok(vec![
272+
ScalarValue::from(self.count),
273+
ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
274+
])
214275
}
215276

216277
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
217-
let values = &values[0];
278+
let values = values[0].as_primitive::<Decimal128Type>();
218279

219280
self.count += (values.len() - values.null_count()) as u64;
220-
self.sum = self.sum.add(&sum::sum_batch(values)?)?;
281+
if let Some(x) = sum(values) {
282+
let v = self.sum.get_or_insert(0);
283+
*v += x;
284+
}
221285
Ok(())
222286
}
223287

224288
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
225-
let values = &values[0];
289+
let values = values[0].as_primitive::<Decimal128Type>();
226290
self.count -= (values.len() - values.null_count()) as u64;
227-
let delta = sum_batch(values)?;
228-
self.sum = self.sum.sub(&delta)?;
291+
if let Some(x) = sum(values) {
292+
self.sum = Some(self.sum.unwrap() - x);
293+
}
229294
Ok(())
230295
}
231296

232297
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
233-
let counts = downcast_value!(states[0], UInt64Array);
234298
// counts are summed
235-
self.count += compute::sum(counts).unwrap_or(0);
299+
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
236300

237301
// sums are summed
238-
self.sum = self.sum.add(&sum::sum_batch(&states[1])?)?;
302+
if let Some(x) = sum(states[1].as_primitive::<Decimal128Type>()) {
303+
let v = self.sum.get_or_insert(0);
304+
*v += x;
305+
}
239306
Ok(())
240307
}
241308

242309
fn evaluate(&self) -> Result<ScalarValue> {
243-
match self.sum {
244-
ScalarValue::Float64(e) => {
245-
Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64)))
246-
}
247-
ScalarValue::Decimal128(value, _, scale) => {
248-
match value {
249-
None => match &self.return_data_type {
250-
DataType::Decimal128(p, s) => {
251-
Ok(ScalarValue::Decimal128(None, *p, *s))
252-
}
253-
other => internal_err!(
254-
"Error returned data type in AvgAccumulator {other:?}"
255-
),
256-
},
257-
Some(value) => {
258-
// the sum type and the return type are not the same, so the sum must be converted to the return type
259-
calculate_result_decimal_for_avg(
260-
value,
261-
self.count as i128,
262-
scale,
263-
&self.return_data_type,
264-
)
265-
}
266-
}
267-
}
268-
_ => internal_err!("Sum should be f64 or decimal128 on average"),
269-
}
310+
let v = self
311+
.sum
312+
.map(|v| {
313+
Decimal128Averager::try_new(
314+
self.sum_scale,
315+
self.target_precision,
316+
self.target_scale,
317+
)?
318+
.avg(v, self.count as _)
319+
})
320+
.transpose()?;
321+
322+
Ok(ScalarValue::Decimal128(
323+
v,
324+
self.target_precision,
325+
self.target_scale,
326+
))
270327
}
271328
fn supports_retract_batch(&self) -> bool {
272329
true
273330
}
274331

275332
fn size(&self) -> usize {
276-
std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size()
333+
std::mem::size_of_val(self)
277334
}
278335
}
279336

@@ -484,6 +541,7 @@ mod tests {
484541
assert_aggregate(
485542
array,
486543
AggregateFunction::Avg,
544+
false,
487545
ScalarValue::Decimal128(Some(35000), 14, 4),
488546
);
489547
}
@@ -500,6 +558,7 @@ mod tests {
500558
assert_aggregate(
501559
array,
502560
AggregateFunction::Avg,
561+
false,
503562
ScalarValue::Decimal128(Some(32500), 14, 4),
504563
);
505564
}
@@ -517,14 +576,15 @@ mod tests {
517576
assert_aggregate(
518577
array,
519578
AggregateFunction::Avg,
579+
false,
520580
ScalarValue::Decimal128(None, 14, 4),
521581
);
522582
}
523583

524584
#[test]
525585
fn avg_i32() {
526586
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
527-
assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3_f64));
587+
assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
528588
}
529589

530590
#[test]
@@ -536,33 +596,33 @@ mod tests {
536596
Some(4),
537597
Some(5),
538598
]));
539-
assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3.25f64));
599+
assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.25f64));
540600
}
541601

542602
#[test]
543603
fn avg_i32_all_nulls() {
544604
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
545-
assert_aggregate(a, AggregateFunction::Avg, ScalarValue::Float64(None));
605+
assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::Float64(None));
546606
}
547607

548608
#[test]
549609
fn avg_u32() {
550610
let a: ArrayRef =
551611
Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
552-
assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3.0f64));
612+
assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.0f64));
553613
}
554614

555615
#[test]
556616
fn avg_f32() {
557617
let a: ArrayRef =
558618
Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
559-
assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3_f64));
619+
assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
560620
}
561621

562622
#[test]
563623
fn avg_f64() {
564624
let a: ArrayRef =
565625
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
566-
assert_aggregate(a, AggregateFunction::Avg, ScalarValue::from(3_f64));
626+
assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
567627
}
568628
}

0 commit comments

Comments
 (0)