Skip to content

Commit 65821eb

Browse files
authored
Specialize Median Accumulator (#7376)
* Specialize Median Accumulator * Tweak memory limit test
1 parent ae4b52a commit 65821eb

File tree

2 files changed

+66
-131
lines changed

2 files changed

+66
-131
lines changed

datafusion/core/tests/memory_limit.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ async fn oom_sort() {
6868
#[tokio::test]
6969
async fn group_by_none() {
7070
TestCase::new()
71-
.with_query("select median(image) from t")
71+
.with_query("select median(request_bytes) from t")
7272
.with_expected_errors(vec![
7373
"Resources exhausted: Failed to allocate additional",
7474
"AggregateStream",
7575
])
76-
.with_memory_limit(20_000)
76+
.with_memory_limit(2_000)
7777
.run()
7878
.await
7979
}

datafusion/physical-expr/src/aggregate/median.rs

+64-129
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
use crate::aggregate::utils::down_cast_any_ref;
2121
use crate::expressions::format_state_name;
2222
use crate::{AggregateExpr, PhysicalExpr};
23-
use arrow::array::{Array, ArrayRef, UInt32Array};
24-
use arrow::compute::sort_to_indices;
23+
use arrow::array::{Array, ArrayRef};
2524
use arrow::datatypes::{DataType, Field};
26-
use datafusion_common::internal_err;
25+
use arrow_array::cast::AsArray;
26+
use arrow_array::{downcast_integer, ArrowNativeTypeOp, ArrowNumericType};
27+
use arrow_buffer::ArrowNativeType;
2728
use datafusion_common::{DataFusionError, Result, ScalarValue};
2829
use datafusion_expr::Accumulator;
2930
use std::any::Any;
31+
use std::fmt::Formatter;
3032
use std::sync::Arc;
3133

3234
/// MEDIAN aggregate expression. This uses a lot of memory because all values need to be
@@ -65,11 +67,29 @@ impl AggregateExpr for Median {
6567
}
6668

6769
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
68-
Ok(Box::new(MedianAccumulator {
69-
data_type: self.data_type.clone(),
70-
arrays: vec![],
71-
all_values: vec![],
72-
}))
70+
use arrow_array::types::*;
71+
macro_rules! helper {
72+
($t:ty, $dt:expr) => {
73+
Ok(Box::new(MedianAccumulator::<$t> {
74+
data_type: $dt.clone(),
75+
all_values: vec![],
76+
}))
77+
};
78+
}
79+
let dt = &self.data_type;
80+
downcast_integer! {
81+
dt => (helper, dt),
82+
DataType::Float16 => helper!(Float16Type, dt),
83+
DataType::Float32 => helper!(Float32Type, dt),
84+
DataType::Float64 => helper!(Float64Type, dt),
85+
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
86+
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
87+
_ => Err(DataFusionError::NotImplemented(format!(
88+
"MedianAccumulator not supported for {} with {}",
89+
self.name(),
90+
self.data_type
91+
))),
92+
}
7393
}
7494

7595
fn state_fields(&self) -> Result<Vec<Field>> {
@@ -106,159 +126,75 @@ impl PartialEq<dyn Any> for Median {
106126
}
107127
}
108128

109-
#[derive(Debug)]
110129
/// The median accumulator accumulates the raw input values
111130
/// as `ScalarValue`s
112131
///
113132
/// The intermediate state is represented as a List of scalar values updated by
114133
/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
115134
/// in the final evaluation step so that we avoid expensive conversions and
116135
/// allocations during `update_batch`.
117-
struct MedianAccumulator {
136+
struct MedianAccumulator<T: ArrowNumericType> {
118137
data_type: DataType,
119-
arrays: Vec<ArrayRef>,
120-
all_values: Vec<ScalarValue>,
138+
all_values: Vec<T::Native>,
121139
}
122140

123-
impl Accumulator for MedianAccumulator {
141+
impl<T: ArrowNumericType> std::fmt::Debug for MedianAccumulator<T> {
142+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
143+
write!(f, "MedianAccumulator({})", self.data_type)
144+
}
145+
}
146+
147+
impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
124148
fn state(&self) -> Result<Vec<ScalarValue>> {
125-
let all_values = to_scalar_values(&self.arrays)?;
149+
let all_values = self
150+
.all_values
151+
.iter()
152+
.map(|x| ScalarValue::new_primitive::<T>(Some(*x), &self.data_type))
153+
.collect();
126154
let state = ScalarValue::new_list(Some(all_values), self.data_type.clone());
127155

128156
Ok(vec![state])
129157
}
130158

131159
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
132-
assert_eq!(values.len(), 1);
133-
let array = &values[0];
134-
135-
// Defer conversions to scalar values to final evaluation.
136-
assert_eq!(array.data_type(), &self.data_type);
137-
self.arrays.push(array.clone());
138-
160+
let values = values[0].as_primitive::<T>();
161+
self.all_values.reserve(values.len() - values.null_count());
162+
self.all_values.extend(values.iter().flatten());
139163
Ok(())
140164
}
141165

142166
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
143-
assert_eq!(states.len(), 1);
144-
145-
let array = &states[0];
146-
assert!(matches!(array.data_type(), DataType::List(_)));
147-
for index in 0..array.len() {
148-
match ScalarValue::try_from_array(array, index)? {
149-
ScalarValue::List(Some(mut values), _) => {
150-
self.all_values.append(&mut values);
151-
}
152-
ScalarValue::List(None, _) => {} // skip empty state
153-
v => {
154-
return internal_err!(
155-
"unexpected state in median. Expected DataType::List, got {v:?}"
156-
)
157-
}
158-
}
167+
let array = states[0].as_list::<i32>();
168+
for v in array.iter().flatten() {
169+
self.update_batch(&[v])?
159170
}
160171
Ok(())
161172
}
162173

163174
fn evaluate(&self) -> Result<ScalarValue> {
164-
let batch_values = to_scalar_values(&self.arrays)?;
165-
166-
if !self
167-
.all_values
168-
.iter()
169-
.chain(batch_values.iter())
170-
.any(|v| !v.is_null())
171-
{
172-
return ScalarValue::try_from(&self.data_type);
173-
}
174-
175-
// Create an array of all the non null values and find the
176-
// sorted indexes
177-
let array = ScalarValue::iter_to_array(
178-
self.all_values
179-
.iter()
180-
.chain(batch_values.iter())
181-
// ignore null values
182-
.filter(|v| !v.is_null())
183-
.cloned(),
184-
)?;
185-
186-
// find the mid point
187-
let len = array.len();
188-
let mid = len / 2;
189-
190-
// only sort up to the top size/2 elements
191-
let limit = Some(mid + 1);
192-
let options = None;
193-
let indices = sort_to_indices(&array, options, limit)?;
194-
195-
// pick the relevant indices in the original arrays
196-
let result = if len >= 2 && len % 2 == 0 {
197-
// even number of values, average the two mid points
198-
let s1 = scalar_at_index(&array, &indices, mid - 1)?;
199-
let s2 = scalar_at_index(&array, &indices, mid)?;
200-
match s1.add(s2)? {
201-
ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)),
202-
ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)),
203-
ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)),
204-
ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)),
205-
ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)),
206-
ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)),
207-
ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)),
208-
ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)),
209-
ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)),
210-
ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)),
211-
ScalarValue::Decimal128(Some(v), p, s) => {
212-
ScalarValue::Decimal128(Some(v / 2), p, s)
213-
}
214-
v => {
215-
return internal_err!("Unsupported type in MedianAccumulator: {v:?}")
216-
}
217-
}
175+
// TODO: evaluate could pass &mut self
176+
let mut d = self.all_values.clone();
177+
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
178+
179+
let len = d.len();
180+
let median = if len == 0 {
181+
None
182+
} else if len % 2 == 0 {
183+
let (low, high, _) = d.select_nth_unstable_by(len / 2, cmp);
184+
let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp);
185+
let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2));
186+
Some(median)
218187
} else {
219-
// odd number of values, pick that one
220-
scalar_at_index(&array, &indices, mid)?
188+
let (_, median, _) = d.select_nth_unstable_by(len / 2, cmp);
189+
Some(*median)
221190
};
222-
223-
Ok(result)
191+
Ok(ScalarValue::new_primitive::<T>(median, &self.data_type))
224192
}
225193

226194
fn size(&self) -> usize {
227-
let arrays_size: usize = self.arrays.iter().map(|a| a.len()).sum();
228-
229195
std::mem::size_of_val(self)
230-
+ ScalarValue::size_of_vec(&self.all_values)
231-
+ arrays_size
232-
- std::mem::size_of_val(&self.all_values)
233-
+ self.data_type.size()
234-
- std::mem::size_of_val(&self.data_type)
235-
}
236-
}
237-
238-
fn to_scalar_values(arrays: &[ArrayRef]) -> Result<Vec<ScalarValue>> {
239-
let num_values: usize = arrays.iter().map(|a| a.len()).sum();
240-
let mut all_values = Vec::with_capacity(num_values);
241-
242-
for array in arrays {
243-
for index in 0..array.len() {
244-
all_values.push(ScalarValue::try_from_array(&array, index)?);
245-
}
196+
+ self.all_values.capacity() * std::mem::size_of::<T::Native>()
246197
}
247-
248-
Ok(all_values)
249-
}
250-
251-
/// Given a returns `array[indicies[indicie_index]]` as a `ScalarValue`
252-
fn scalar_at_index(
253-
array: &dyn Array,
254-
indices: &UInt32Array,
255-
indicies_index: usize,
256-
) -> Result<ScalarValue> {
257-
let array_index = indices
258-
.value(indicies_index)
259-
.try_into()
260-
.expect("Convert uint32 to usize");
261-
ScalarValue::try_from_array(array, array_index)
262198
}
263199

264200
#[cfg(test)]
@@ -269,7 +205,6 @@ mod tests {
269205
use crate::generic_test_op;
270206
use arrow::record_batch::RecordBatch;
271207
use arrow::{array::*, datatypes::*};
272-
use datafusion_common::Result;
273208

274209
#[test]
275210
fn median_decimal() -> Result<()> {

0 commit comments

Comments
 (0)