Skip to content

Commit 5bb6b35

Browse files
authored
fix: use total ordering in the min & max accumulator for floats (#10627)
* fix: use total ordering in the min & max accumulator for floats to match the ordering used by arrow kernels * change unit test to expect min to be nan * changed behavior again since the partial_cmp approach doesn't handle nulls correctly * Revert change to describe test. It did not originate from a NaN/finite discrepancy but from a null/defined discrepancy, and we don't want that behavior to change * Update the test to check the min function and also verify the result
1 parent cb9068c commit 5bb6b35

File tree

1 file changed

+56
-4
lines changed

1 file changed

+56
-4
lines changed

datafusion/physical-expr/src/aggregate/min_max.rs

+56-4
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,20 @@ macro_rules! typed_min_max {
488488
}};
489489
}
490490

491+
// min/max of two scalar float values.
//
// `$VALUE` and `$DELTA` are matched as `Option`s and the chosen payload is copied
// into `ScalarValue::$SCALAR`. A `None` side is ignored, so the result is `None`
// only when both sides are `None`. When both sides are present they are compared
// with `total_cmp` instead of a partial comparison, so NaN takes part in the
// ordering rather than being incomparable.
macro_rules! typed_min_max_float {
492+
($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
493+
ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
494+
(None, None) => None,
495+
// Only one side is defined: keep it.
(Some(a), None) => Some(*a),
496+
(None, Some(b)) => Some(*b),
497+
// Both sides are defined: order them with the total ordering for floats.
(Some(a), Some(b)) => match a.total_cmp(b) {
498+
// `choose_min_max!` names the ordering of `a` relative to `b` for which `b`
// is the result (`Greater` for `min`).
choose_min_max!($OP) => Some(*b),
499+
// Any other ordering, including `Equal`, keeps `a`.
_ => Some(*a),
500+
},
501+
})
502+
}};
503+
}
504+
491505
// min/max of two scalar string values.
492506
macro_rules! typed_min_max_string {
493507
($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
@@ -500,7 +514,7 @@ macro_rules! typed_min_max_string {
500514
}};
501515
}
502516

503-
macro_rules! interval_choose_min_max {
517+
macro_rules! choose_min_max {
504518
(min) => {
505519
std::cmp::Ordering::Greater
506520
};
@@ -512,7 +526,7 @@ macro_rules! interval_choose_min_max {
512526
macro_rules! interval_min_max {
513527
($OP:tt, $LHS:expr, $RHS:expr) => {{
514528
match $LHS.partial_cmp(&$RHS) {
515-
Some(interval_choose_min_max!($OP)) => $RHS.clone(),
529+
Some(choose_min_max!($OP)) => $RHS.clone(),
516530
Some(_) => $LHS.clone(),
517531
None => {
518532
return internal_err!("Comparison error while computing interval min/max")
@@ -555,10 +569,10 @@ macro_rules! min_max {
555569
typed_min_max!(lhs, rhs, Boolean, $OP)
556570
}
557571
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
558-
typed_min_max!(lhs, rhs, Float64, $OP)
572+
typed_min_max_float!(lhs, rhs, Float64, $OP)
559573
}
560574
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
561-
typed_min_max!(lhs, rhs, Float32, $OP)
575+
typed_min_max_float!(lhs, rhs, Float32, $OP)
562576
}
563577
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
564578
typed_min_max!(lhs, rhs, UInt64, $OP)
@@ -1103,3 +1117,41 @@ impl Accumulator for SlidingMinAccumulator {
11031117
std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size()
11041118
}
11051119
}
1120+
1121+
#[cfg(test)]
1122+
mod tests {
1123+
use super::*;
1124+
1125+
// Checks that the Float32 min and max accumulators give the same answer for NaN
// and negative infinity whether the values arrive in one batch or in several.
#[test]
1126+
fn float_min_max_with_nans() {
1127+
// NOTE(review): the test treats `f32::NAN` as a positive NaN; the standard
// library does not guarantee the sign of this constant - confirm.
let pos_nan = f32::NAN;
1128+
let zero = 0_f32;
1129+
let neg_inf = f32::NEG_INFINITY;
1130+
1131+
// Feeds every inner slice to `acc` as its own Float32 batch, then asserts that
// the evaluated result equals `Float32(Some(expected))`.
let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
1132+
for batch in values.iter() {
1133+
let batch =
1134+
Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
1135+
acc.update_batch(&[batch]).unwrap();
1136+
}
1137+
let result = acc.evaluate().unwrap();
1138+
assert_eq!(result, ScalarValue::Float32(Some(expected)));
1139+
};
1140+
1141+
// Each case is run twice: once with the values split across batches (merged by the
1142+
// min_max macro defined above) and once within a single batch (handled by the arrow
1143+
// min/max compute function). Both paths must respect the total order for floats.
1144+
1145+
// Fresh accumulators for every case so that no state leaks between checks.
let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
1146+
let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();
1147+
1148+
// min: NaN is expected to lose to zero, and negative infinity to win over zero.
check(&mut min(), &[&[zero], &[pos_nan]], zero);
1149+
check(&mut min(), &[&[zero, pos_nan]], zero);
1150+
check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
1151+
check(&mut min(), &[&[zero, neg_inf]], neg_inf);
1152+
// max: NaN is expected to win over zero, and zero to win over negative infinity.
check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
1153+
check(&mut max(), &[&[zero, pos_nan]], pos_nan);
1154+
check(&mut max(), &[&[zero], &[neg_inf]], zero);
1155+
check(&mut max(), &[&[zero, neg_inf]], zero);
1156+
}
1157+
}

0 commit comments

Comments
 (0)