20
20
use crate :: aggregate:: utils:: down_cast_any_ref;
21
21
use crate :: expressions:: format_state_name;
22
22
use crate :: { AggregateExpr , PhysicalExpr } ;
23
- use arrow:: array:: { Array , ArrayRef , UInt32Array } ;
24
- use arrow:: compute:: sort_to_indices;
23
+ use arrow:: array:: { Array , ArrayRef } ;
25
24
use arrow:: datatypes:: { DataType , Field } ;
26
- use datafusion_common:: internal_err;
25
+ use arrow_array:: cast:: AsArray ;
26
+ use arrow_array:: { downcast_integer, ArrowNativeTypeOp , ArrowNumericType } ;
27
+ use arrow_buffer:: ArrowNativeType ;
27
28
use datafusion_common:: { DataFusionError , Result , ScalarValue } ;
28
29
use datafusion_expr:: Accumulator ;
29
30
use std:: any:: Any ;
31
+ use std:: fmt:: Formatter ;
30
32
use std:: sync:: Arc ;
31
33
32
34
/// MEDIAN aggregate expression. This uses a lot of memory because all values need to be
@@ -65,11 +67,29 @@ impl AggregateExpr for Median {
65
67
}
66
68
67
69
fn create_accumulator ( & self ) -> Result < Box < dyn Accumulator > > {
68
- Ok ( Box :: new ( MedianAccumulator {
69
- data_type : self . data_type . clone ( ) ,
70
- arrays : vec ! [ ] ,
71
- all_values : vec ! [ ] ,
72
- } ) )
70
+ use arrow_array:: types:: * ;
71
+ macro_rules! helper {
72
+ ( $t: ty, $dt: expr) => {
73
+ Ok ( Box :: new( MedianAccumulator :: <$t> {
74
+ data_type: $dt. clone( ) ,
75
+ all_values: vec![ ] ,
76
+ } ) )
77
+ } ;
78
+ }
79
+ let dt = & self . data_type ;
80
+ downcast_integer ! {
81
+ dt => ( helper, dt) ,
82
+ DataType :: Float16 => helper!( Float16Type , dt) ,
83
+ DataType :: Float32 => helper!( Float32Type , dt) ,
84
+ DataType :: Float64 => helper!( Float64Type , dt) ,
85
+ DataType :: Decimal128 ( _, _) => helper!( Decimal128Type , dt) ,
86
+ DataType :: Decimal256 ( _, _) => helper!( Decimal256Type , dt) ,
87
+ _ => Err ( DataFusionError :: NotImplemented ( format!(
88
+ "MedianAccumulator not supported for {} with {}" ,
89
+ self . name( ) ,
90
+ self . data_type
91
+ ) ) ) ,
92
+ }
73
93
}
74
94
75
95
fn state_fields ( & self ) -> Result < Vec < Field > > {
@@ -106,159 +126,75 @@ impl PartialEq<dyn Any> for Median {
106
126
}
107
127
}
108
128
109
- #[ derive( Debug ) ]
110
129
/// The median accumulator accumulates the raw input values
111
130
/// as `ScalarValue`s
112
131
///
113
132
/// The intermediate state is represented as a List of scalar values updated by
114
133
/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
115
134
/// in the final evaluation step so that we avoid expensive conversions and
116
135
/// allocations during `update_batch`.
117
- struct MedianAccumulator {
136
+ struct MedianAccumulator < T : ArrowNumericType > {
118
137
data_type : DataType ,
119
- arrays : Vec < ArrayRef > ,
120
- all_values : Vec < ScalarValue > ,
138
+ all_values : Vec < T :: Native > ,
121
139
}
122
140
123
- impl Accumulator for MedianAccumulator {
141
+ impl < T : ArrowNumericType > std:: fmt:: Debug for MedianAccumulator < T > {
142
+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
143
+ write ! ( f, "MedianAccumulator({})" , self . data_type)
144
+ }
145
+ }
146
+
147
+ impl < T : ArrowNumericType > Accumulator for MedianAccumulator < T > {
124
148
fn state ( & self ) -> Result < Vec < ScalarValue > > {
125
- let all_values = to_scalar_values ( & self . arrays ) ?;
149
+ let all_values = self
150
+ . all_values
151
+ . iter ( )
152
+ . map ( |x| ScalarValue :: new_primitive :: < T > ( Some ( * x) , & self . data_type ) )
153
+ . collect ( ) ;
126
154
let state = ScalarValue :: new_list ( Some ( all_values) , self . data_type . clone ( ) ) ;
127
155
128
156
Ok ( vec ! [ state] )
129
157
}
130
158
131
159
fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
132
- assert_eq ! ( values. len( ) , 1 ) ;
133
- let array = & values[ 0 ] ;
134
-
135
- // Defer conversions to scalar values to final evaluation.
136
- assert_eq ! ( array. data_type( ) , & self . data_type) ;
137
- self . arrays . push ( array. clone ( ) ) ;
138
-
160
+ let values = values[ 0 ] . as_primitive :: < T > ( ) ;
161
+ self . all_values . reserve ( values. len ( ) - values. null_count ( ) ) ;
162
+ self . all_values . extend ( values. iter ( ) . flatten ( ) ) ;
139
163
Ok ( ( ) )
140
164
}
141
165
142
166
fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
143
- assert_eq ! ( states. len( ) , 1 ) ;
144
-
145
- let array = & states[ 0 ] ;
146
- assert ! ( matches!( array. data_type( ) , DataType :: List ( _) ) ) ;
147
- for index in 0 ..array. len ( ) {
148
- match ScalarValue :: try_from_array ( array, index) ? {
149
- ScalarValue :: List ( Some ( mut values) , _) => {
150
- self . all_values . append ( & mut values) ;
151
- }
152
- ScalarValue :: List ( None , _) => { } // skip empty state
153
- v => {
154
- return internal_err ! (
155
- "unexpected state in median. Expected DataType::List, got {v:?}"
156
- )
157
- }
158
- }
167
+ let array = states[ 0 ] . as_list :: < i32 > ( ) ;
168
+ for v in array. iter ( ) . flatten ( ) {
169
+ self . update_batch ( & [ v] ) ?
159
170
}
160
171
Ok ( ( ) )
161
172
}
162
173
163
174
fn evaluate ( & self ) -> Result < ScalarValue > {
164
- let batch_values = to_scalar_values ( & self . arrays ) ?;
165
-
166
- if !self
167
- . all_values
168
- . iter ( )
169
- . chain ( batch_values. iter ( ) )
170
- . any ( |v| !v. is_null ( ) )
171
- {
172
- return ScalarValue :: try_from ( & self . data_type ) ;
173
- }
174
-
175
- // Create an array of all the non null values and find the
176
- // sorted indexes
177
- let array = ScalarValue :: iter_to_array (
178
- self . all_values
179
- . iter ( )
180
- . chain ( batch_values. iter ( ) )
181
- // ignore null values
182
- . filter ( |v| !v. is_null ( ) )
183
- . cloned ( ) ,
184
- ) ?;
185
-
186
- // find the mid point
187
- let len = array. len ( ) ;
188
- let mid = len / 2 ;
189
-
190
- // only sort up to the top size/2 elements
191
- let limit = Some ( mid + 1 ) ;
192
- let options = None ;
193
- let indices = sort_to_indices ( & array, options, limit) ?;
194
-
195
- // pick the relevant indices in the original arrays
196
- let result = if len >= 2 && len % 2 == 0 {
197
- // even number of values, average the two mid points
198
- let s1 = scalar_at_index ( & array, & indices, mid - 1 ) ?;
199
- let s2 = scalar_at_index ( & array, & indices, mid) ?;
200
- match s1. add ( s2) ? {
201
- ScalarValue :: Int8 ( Some ( v) ) => ScalarValue :: Int8 ( Some ( v / 2 ) ) ,
202
- ScalarValue :: Int16 ( Some ( v) ) => ScalarValue :: Int16 ( Some ( v / 2 ) ) ,
203
- ScalarValue :: Int32 ( Some ( v) ) => ScalarValue :: Int32 ( Some ( v / 2 ) ) ,
204
- ScalarValue :: Int64 ( Some ( v) ) => ScalarValue :: Int64 ( Some ( v / 2 ) ) ,
205
- ScalarValue :: UInt8 ( Some ( v) ) => ScalarValue :: UInt8 ( Some ( v / 2 ) ) ,
206
- ScalarValue :: UInt16 ( Some ( v) ) => ScalarValue :: UInt16 ( Some ( v / 2 ) ) ,
207
- ScalarValue :: UInt32 ( Some ( v) ) => ScalarValue :: UInt32 ( Some ( v / 2 ) ) ,
208
- ScalarValue :: UInt64 ( Some ( v) ) => ScalarValue :: UInt64 ( Some ( v / 2 ) ) ,
209
- ScalarValue :: Float32 ( Some ( v) ) => ScalarValue :: Float32 ( Some ( v / 2.0 ) ) ,
210
- ScalarValue :: Float64 ( Some ( v) ) => ScalarValue :: Float64 ( Some ( v / 2.0 ) ) ,
211
- ScalarValue :: Decimal128 ( Some ( v) , p, s) => {
212
- ScalarValue :: Decimal128 ( Some ( v / 2 ) , p, s)
213
- }
214
- v => {
215
- return internal_err ! ( "Unsupported type in MedianAccumulator: {v:?}" )
216
- }
217
- }
175
+ // TODO: evaluate could pass &mut self
176
+ let mut d = self . all_values . clone ( ) ;
177
+ let cmp = |x : & T :: Native , y : & T :: Native | x. compare ( * y) ;
178
+
179
+ let len = d. len ( ) ;
180
+ let median = if len == 0 {
181
+ None
182
+ } else if len % 2 == 0 {
183
+ let ( low, high, _) = d. select_nth_unstable_by ( len / 2 , cmp) ;
184
+ let ( _, low, _) = low. select_nth_unstable_by ( low. len ( ) - 1 , cmp) ;
185
+ let median = low. add_wrapping ( * high) . div_wrapping ( T :: Native :: usize_as ( 2 ) ) ;
186
+ Some ( median)
218
187
} else {
219
- // odd number of values, pick that one
220
- scalar_at_index ( & array , & indices , mid ) ?
188
+ let ( _ , median , _ ) = d . select_nth_unstable_by ( len / 2 , cmp ) ;
189
+ Some ( * median )
221
190
} ;
222
-
223
- Ok ( result)
191
+ Ok ( ScalarValue :: new_primitive :: < T > ( median, & self . data_type ) )
224
192
}
225
193
226
194
fn size ( & self ) -> usize {
227
- let arrays_size: usize = self . arrays . iter ( ) . map ( |a| a. len ( ) ) . sum ( ) ;
228
-
229
195
std:: mem:: size_of_val ( self )
230
- + ScalarValue :: size_of_vec ( & self . all_values )
231
- + arrays_size
232
- - std:: mem:: size_of_val ( & self . all_values )
233
- + self . data_type . size ( )
234
- - std:: mem:: size_of_val ( & self . data_type )
235
- }
236
- }
237
-
238
- fn to_scalar_values ( arrays : & [ ArrayRef ] ) -> Result < Vec < ScalarValue > > {
239
- let num_values: usize = arrays. iter ( ) . map ( |a| a. len ( ) ) . sum ( ) ;
240
- let mut all_values = Vec :: with_capacity ( num_values) ;
241
-
242
- for array in arrays {
243
- for index in 0 ..array. len ( ) {
244
- all_values. push ( ScalarValue :: try_from_array ( & array, index) ?) ;
245
- }
196
+ + self . all_values . capacity ( ) * std:: mem:: size_of :: < T :: Native > ( )
246
197
}
247
-
248
- Ok ( all_values)
249
- }
250
-
251
- /// Given a returns `array[indicies[indicie_index]]` as a `ScalarValue`
252
- fn scalar_at_index (
253
- array : & dyn Array ,
254
- indices : & UInt32Array ,
255
- indicies_index : usize ,
256
- ) -> Result < ScalarValue > {
257
- let array_index = indices
258
- . value ( indicies_index)
259
- . try_into ( )
260
- . expect ( "Convert uint32 to usize" ) ;
261
- ScalarValue :: try_from_array ( array, array_index)
262
198
}
263
199
264
200
#[ cfg( test) ]
@@ -269,7 +205,6 @@ mod tests {
269
205
use crate :: generic_test_op;
270
206
use arrow:: record_batch:: RecordBatch ;
271
207
use arrow:: { array:: * , datatypes:: * } ;
272
- use datafusion_common:: Result ;
273
208
274
209
#[ test]
275
210
fn median_decimal ( ) -> Result < ( ) > {
0 commit comments