@@ -21,17 +21,13 @@ use arrow::array::{AsArray, PrimitiveBuilder};
21
21
use log:: debug;
22
22
23
23
use std:: any:: Any ;
24
- use std:: convert:: TryFrom ;
25
24
use std:: sync:: Arc ;
26
25
27
26
use crate :: aggregate:: groups_accumulator:: accumulate:: NullState ;
28
- use crate :: aggregate:: sum;
29
- use crate :: aggregate:: sum:: sum_batch;
30
- use crate :: aggregate:: utils:: calculate_result_decimal_for_avg;
31
27
use crate :: aggregate:: utils:: down_cast_any_ref;
32
28
use crate :: expressions:: format_state_name;
33
29
use crate :: { AggregateExpr , GroupsAccumulator , PhysicalExpr } ;
34
- use arrow:: compute;
30
+ use arrow:: compute:: sum ;
35
31
use arrow:: datatypes:: { DataType , Decimal128Type , Float64Type , UInt64Type } ;
36
32
use arrow:: {
37
33
array:: { ArrayRef , UInt64Array } ,
@@ -40,9 +36,7 @@ use arrow::{
40
36
use arrow_array:: {
41
37
Array , ArrowNativeTypeOp , ArrowNumericType , ArrowPrimitiveType , PrimitiveArray ,
42
38
} ;
43
- use datafusion_common:: {
44
- downcast_value, internal_err, not_impl_err, DataFusionError , Result , ScalarValue ,
45
- } ;
39
+ use datafusion_common:: { not_impl_err, DataFusionError , Result , ScalarValue } ;
46
40
use datafusion_expr:: type_coercion:: aggregates:: avg_return_type;
47
41
use datafusion_expr:: Accumulator ;
48
42
@@ -87,11 +81,27 @@ impl AggregateExpr for Avg {
87
81
}
88
82
89
83
fn create_accumulator ( & self ) -> Result < Box < dyn Accumulator > > {
90
- Ok ( Box :: new ( AvgAccumulator :: try_new (
91
- // avg is f64 or decimal
92
- & self . input_data_type ,
93
- & self . result_data_type ,
94
- ) ?) )
84
+ use DataType :: * ;
85
+ // instantiate specialized accumulator based for the type
86
+ match ( & self . input_data_type , & self . result_data_type ) {
87
+ ( Float64 , Float64 ) => Ok ( Box :: < AvgAccumulator > :: default ( ) ) ,
88
+ (
89
+ Decimal128 ( sum_precision, sum_scale) ,
90
+ Decimal128 ( target_precision, target_scale) ,
91
+ ) => Ok ( Box :: new ( DecimalAvgAccumulator {
92
+ sum : None ,
93
+ count : 0 ,
94
+ sum_scale : * sum_scale,
95
+ sum_precision : * sum_precision,
96
+ target_precision : * target_precision,
97
+ target_scale : * target_scale,
98
+ } ) ) ,
99
+ _ => not_impl_err ! (
100
+ "AvgAccumulator for ({} --> {})" ,
101
+ self . input_data_type,
102
+ self . result_data_type
103
+ ) ,
104
+ }
95
105
}
96
106
97
107
fn state_fields ( & self ) -> Result < Vec < Field > > {
@@ -122,10 +132,7 @@ impl AggregateExpr for Avg {
122
132
}
123
133
124
134
fn create_sliding_accumulator ( & self ) -> Result < Box < dyn Accumulator > > {
125
- Ok ( Box :: new ( AvgAccumulator :: try_new (
126
- & self . input_data_type ,
127
- & self . result_data_type ,
128
- ) ?) )
135
+ self . create_accumulator ( )
129
136
}
130
137
131
138
fn groups_accumulator_supported ( & self ) -> bool {
@@ -189,91 +196,141 @@ impl PartialEq<dyn Any> for Avg {
189
196
}
190
197
191
198
/// An accumulator to compute the average
192
- #[ derive( Debug ) ]
199
+ #[ derive( Debug , Default ) ]
193
200
pub struct AvgAccumulator {
194
- // sum is used for null
195
- sum : ScalarValue ,
196
- return_data_type : DataType ,
201
+ sum : Option < f64 > ,
197
202
count : u64 ,
198
203
}
199
204
200
- impl AvgAccumulator {
201
- /// Creates a new `AvgAccumulator`
202
- pub fn try_new ( datatype : & DataType , return_data_type : & DataType ) -> Result < Self > {
203
- Ok ( Self {
204
- sum : ScalarValue :: try_from ( datatype) ?,
205
- return_data_type : return_data_type. clone ( ) ,
206
- count : 0 ,
207
- } )
205
+ impl Accumulator for AvgAccumulator {
206
+ fn state ( & self ) -> Result < Vec < ScalarValue > > {
207
+ Ok ( vec ! [
208
+ ScalarValue :: from( self . count) ,
209
+ ScalarValue :: Float64 ( self . sum) ,
210
+ ] )
211
+ }
212
+
213
+ fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
214
+ let values = values[ 0 ] . as_primitive :: < Float64Type > ( ) ;
215
+ self . count += ( values. len ( ) - values. null_count ( ) ) as u64 ;
216
+ if let Some ( x) = sum ( values) {
217
+ let v = self . sum . get_or_insert ( 0. ) ;
218
+ * v += x;
219
+ }
220
+ Ok ( ( ) )
221
+ }
222
+
223
+ fn retract_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
224
+ let values = values[ 0 ] . as_primitive :: < Float64Type > ( ) ;
225
+ self . count -= ( values. len ( ) - values. null_count ( ) ) as u64 ;
226
+ if let Some ( x) = sum ( values) {
227
+ self . sum = Some ( self . sum . unwrap ( ) - x) ;
228
+ }
229
+ Ok ( ( ) )
230
+ }
231
+
232
+ fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
233
+ // counts are summed
234
+ self . count += sum ( states[ 0 ] . as_primitive :: < UInt64Type > ( ) ) . unwrap_or_default ( ) ;
235
+
236
+ // sums are summed
237
+ if let Some ( x) = sum ( states[ 1 ] . as_primitive :: < Float64Type > ( ) ) {
238
+ let v = self . sum . get_or_insert ( 0. ) ;
239
+ * v += x;
240
+ }
241
+ Ok ( ( ) )
242
+ }
243
+
244
+ fn evaluate ( & self ) -> Result < ScalarValue > {
245
+ Ok ( ScalarValue :: Float64 (
246
+ self . sum . map ( |f| f / self . count as f64 ) ,
247
+ ) )
248
+ }
249
+ fn supports_retract_batch ( & self ) -> bool {
250
+ true
251
+ }
252
+
253
+ fn size ( & self ) -> usize {
254
+ std:: mem:: size_of_val ( self )
208
255
}
209
256
}
210
257
211
- impl Accumulator for AvgAccumulator {
258
+ /// An accumulator to compute the average for decimals
259
+ #[ derive( Debug ) ]
260
+ struct DecimalAvgAccumulator {
261
+ sum : Option < i128 > ,
262
+ count : u64 ,
263
+ sum_scale : i8 ,
264
+ sum_precision : u8 ,
265
+ target_precision : u8 ,
266
+ target_scale : i8 ,
267
+ }
268
+
269
+ impl Accumulator for DecimalAvgAccumulator {
212
270
fn state ( & self ) -> Result < Vec < ScalarValue > > {
213
- Ok ( vec ! [ ScalarValue :: from( self . count) , self . sum. clone( ) ] )
271
+ Ok ( vec ! [
272
+ ScalarValue :: from( self . count) ,
273
+ ScalarValue :: Decimal128 ( self . sum, self . sum_precision, self . sum_scale) ,
274
+ ] )
214
275
}
215
276
216
277
fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
217
- let values = & values[ 0 ] ;
278
+ let values = values[ 0 ] . as_primitive :: < Decimal128Type > ( ) ;
218
279
219
280
self . count += ( values. len ( ) - values. null_count ( ) ) as u64 ;
220
- self . sum = self . sum . add ( & sum:: sum_batch ( values) ?) ?;
281
+ if let Some ( x) = sum ( values) {
282
+ let v = self . sum . get_or_insert ( 0 ) ;
283
+ * v += x;
284
+ }
221
285
Ok ( ( ) )
222
286
}
223
287
224
288
fn retract_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
225
- let values = & values[ 0 ] ;
289
+ let values = values[ 0 ] . as_primitive :: < Decimal128Type > ( ) ;
226
290
self . count -= ( values. len ( ) - values. null_count ( ) ) as u64 ;
227
- let delta = sum_batch ( values) ?;
228
- self . sum = self . sum . sub ( & delta) ?;
291
+ if let Some ( x) = sum ( values) {
292
+ self . sum = Some ( self . sum . unwrap ( ) - x) ;
293
+ }
229
294
Ok ( ( ) )
230
295
}
231
296
232
297
fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
233
- let counts = downcast_value ! ( states[ 0 ] , UInt64Array ) ;
234
298
// counts are summed
235
- self . count += compute :: sum ( counts ) . unwrap_or ( 0 ) ;
299
+ self . count += sum ( states [ 0 ] . as_primitive :: < UInt64Type > ( ) ) . unwrap_or_default ( ) ;
236
300
237
301
// sums are summed
238
- self . sum = self . sum . add ( & sum:: sum_batch ( & states[ 1 ] ) ?) ?;
302
+ if let Some ( x) = sum ( states[ 1 ] . as_primitive :: < Decimal128Type > ( ) ) {
303
+ let v = self . sum . get_or_insert ( 0 ) ;
304
+ * v += x;
305
+ }
239
306
Ok ( ( ) )
240
307
}
241
308
242
309
fn evaluate ( & self ) -> Result < ScalarValue > {
243
- match self . sum {
244
- ScalarValue :: Float64 ( e) => {
245
- Ok ( ScalarValue :: Float64 ( e. map ( |f| f / self . count as f64 ) ) )
246
- }
247
- ScalarValue :: Decimal128 ( value, _, scale) => {
248
- match value {
249
- None => match & self . return_data_type {
250
- DataType :: Decimal128 ( p, s) => {
251
- Ok ( ScalarValue :: Decimal128 ( None , * p, * s) )
252
- }
253
- other => internal_err ! (
254
- "Error returned data type in AvgAccumulator {other:?}"
255
- ) ,
256
- } ,
257
- Some ( value) => {
258
- // now the sum_type and return type is not the same, need to convert the sum type to return type
259
- calculate_result_decimal_for_avg (
260
- value,
261
- self . count as i128 ,
262
- scale,
263
- & self . return_data_type ,
264
- )
265
- }
266
- }
267
- }
268
- _ => internal_err ! ( "Sum should be f64 or decimal128 on average" ) ,
269
- }
310
+ let v = self
311
+ . sum
312
+ . map ( |v| {
313
+ Decimal128Averager :: try_new (
314
+ self . sum_scale ,
315
+ self . target_precision ,
316
+ self . target_scale ,
317
+ ) ?
318
+ . avg ( v, self . count as _ )
319
+ } )
320
+ . transpose ( ) ?;
321
+
322
+ Ok ( ScalarValue :: Decimal128 (
323
+ v,
324
+ self . target_precision ,
325
+ self . target_scale ,
326
+ ) )
270
327
}
271
328
fn supports_retract_batch ( & self ) -> bool {
272
329
true
273
330
}
274
331
275
332
fn size ( & self ) -> usize {
276
- std:: mem:: size_of_val ( self ) - std :: mem :: size_of_val ( & self . sum ) + self . sum . size ( )
333
+ std:: mem:: size_of_val ( self )
277
334
}
278
335
}
279
336
@@ -484,6 +541,7 @@ mod tests {
484
541
assert_aggregate (
485
542
array,
486
543
AggregateFunction :: Avg ,
544
+ false ,
487
545
ScalarValue :: Decimal128 ( Some ( 35000 ) , 14 , 4 ) ,
488
546
) ;
489
547
}
@@ -500,6 +558,7 @@ mod tests {
500
558
assert_aggregate (
501
559
array,
502
560
AggregateFunction :: Avg ,
561
+ false ,
503
562
ScalarValue :: Decimal128 ( Some ( 32500 ) , 14 , 4 ) ,
504
563
) ;
505
564
}
@@ -517,14 +576,15 @@ mod tests {
517
576
assert_aggregate (
518
577
array,
519
578
AggregateFunction :: Avg ,
579
+ false ,
520
580
ScalarValue :: Decimal128 ( None , 14 , 4 ) ,
521
581
) ;
522
582
}
523
583
524
584
#[ test]
525
585
fn avg_i32 ( ) {
526
586
let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
527
- assert_aggregate ( a, AggregateFunction :: Avg , ScalarValue :: from ( 3_f64 ) ) ;
587
+ assert_aggregate ( a, AggregateFunction :: Avg , false , ScalarValue :: from ( 3_f64 ) ) ;
528
588
}
529
589
530
590
#[ test]
@@ -536,33 +596,33 @@ mod tests {
536
596
Some ( 4 ) ,
537
597
Some ( 5 ) ,
538
598
] ) ) ;
539
- assert_aggregate ( a, AggregateFunction :: Avg , ScalarValue :: from ( 3.25f64 ) ) ;
599
+ assert_aggregate ( a, AggregateFunction :: Avg , false , ScalarValue :: from ( 3.25f64 ) ) ;
540
600
}
541
601
542
602
#[ test]
543
603
fn avg_i32_all_nulls ( ) {
544
604
let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ None , None ] ) ) ;
545
- assert_aggregate ( a, AggregateFunction :: Avg , ScalarValue :: Float64 ( None ) ) ;
605
+ assert_aggregate ( a, AggregateFunction :: Avg , false , ScalarValue :: Float64 ( None ) ) ;
546
606
}
547
607
548
608
#[ test]
549
609
fn avg_u32 ( ) {
550
610
let a: ArrayRef =
551
611
Arc :: new ( UInt32Array :: from ( vec ! [ 1_u32 , 2_u32 , 3_u32 , 4_u32 , 5_u32 ] ) ) ;
552
- assert_aggregate ( a, AggregateFunction :: Avg , ScalarValue :: from ( 3.0f64 ) ) ;
612
+ assert_aggregate ( a, AggregateFunction :: Avg , false , ScalarValue :: from ( 3.0f64 ) ) ;
553
613
}
554
614
555
615
#[ test]
556
616
fn avg_f32 ( ) {
557
617
let a: ArrayRef =
558
618
Arc :: new ( Float32Array :: from ( vec ! [ 1_f32 , 2_f32 , 3_f32 , 4_f32 , 5_f32 ] ) ) ;
559
- assert_aggregate ( a, AggregateFunction :: Avg , ScalarValue :: from ( 3_f64 ) ) ;
619
+ assert_aggregate ( a, AggregateFunction :: Avg , false , ScalarValue :: from ( 3_f64 ) ) ;
560
620
}
561
621
562
622
#[ test]
563
623
fn avg_f64 ( ) {
564
624
let a: ArrayRef =
565
625
Arc :: new ( Float64Array :: from ( vec ! [ 1_f64 , 2_f64 , 3_f64 , 4_f64 , 5_f64 ] ) ) ;
566
- assert_aggregate ( a, AggregateFunction :: Avg , ScalarValue :: from ( 3_f64 ) ) ;
626
+ assert_aggregate ( a, AggregateFunction :: Avg , false , ScalarValue :: from ( 3_f64 ) ) ;
567
627
}
568
628
}
0 commit comments