@@ -488,6 +488,20 @@ macro_rules! typed_min_max {
488
488
} } ;
489
489
}
490
490
491
+ macro_rules! typed_min_max_float {
492
+ ( $VALUE: expr, $DELTA: expr, $SCALAR: ident, $OP: ident) => { {
493
+ ScalarValue :: $SCALAR( match ( $VALUE, $DELTA) {
494
+ ( None , None ) => None ,
495
+ ( Some ( a) , None ) => Some ( * a) ,
496
+ ( None , Some ( b) ) => Some ( * b) ,
497
+ ( Some ( a) , Some ( b) ) => match a. total_cmp( b) {
498
+ choose_min_max!( $OP) => Some ( * b) ,
499
+ _ => Some ( * a) ,
500
+ } ,
501
+ } )
502
+ } } ;
503
+ }
504
+
491
505
// min/max of two scalar string values.
492
506
macro_rules! typed_min_max_string {
493
507
( $VALUE: expr, $DELTA: expr, $SCALAR: ident, $OP: ident) => { {
@@ -500,7 +514,7 @@ macro_rules! typed_min_max_string {
500
514
} } ;
501
515
}
502
516
503
- macro_rules! interval_choose_min_max {
517
+ macro_rules! choose_min_max {
504
518
( min) => {
505
519
std:: cmp:: Ordering :: Greater
506
520
} ;
@@ -512,7 +526,7 @@ macro_rules! interval_choose_min_max {
512
526
macro_rules! interval_min_max {
513
527
( $OP: tt, $LHS: expr, $RHS: expr) => { {
514
528
match $LHS. partial_cmp( & $RHS) {
515
- Some ( interval_choose_min_max !( $OP) ) => $RHS. clone( ) ,
529
+ Some ( choose_min_max !( $OP) ) => $RHS. clone( ) ,
516
530
Some ( _) => $LHS. clone( ) ,
517
531
None => {
518
532
return internal_err!( "Comparison error while computing interval min/max" )
@@ -555,10 +569,10 @@ macro_rules! min_max {
555
569
typed_min_max!( lhs, rhs, Boolean , $OP)
556
570
}
557
571
( ScalarValue :: Float64 ( lhs) , ScalarValue :: Float64 ( rhs) ) => {
558
- typed_min_max !( lhs, rhs, Float64 , $OP)
572
+ typed_min_max_float !( lhs, rhs, Float64 , $OP)
559
573
}
560
574
( ScalarValue :: Float32 ( lhs) , ScalarValue :: Float32 ( rhs) ) => {
561
- typed_min_max !( lhs, rhs, Float32 , $OP)
575
+ typed_min_max_float !( lhs, rhs, Float32 , $OP)
562
576
}
563
577
( ScalarValue :: UInt64 ( lhs) , ScalarValue :: UInt64 ( rhs) ) => {
564
578
typed_min_max!( lhs, rhs, UInt64 , $OP)
@@ -1103,3 +1117,41 @@ impl Accumulator for SlidingMinAccumulator {
1103
1117
std:: mem:: size_of_val ( self ) - std:: mem:: size_of_val ( & self . min ) + self . min . size ( )
1104
1118
}
1105
1119
}
1120
+
1121
+ #[ cfg( test) ]
1122
+ mod tests {
1123
+ use super :: * ;
1124
+
1125
+ #[ test]
1126
+ fn float_min_max_with_nans ( ) {
1127
+ let pos_nan = f32:: NAN ;
1128
+ let zero = 0_f32 ;
1129
+ let neg_inf = f32:: NEG_INFINITY ;
1130
+
1131
+ let check = |acc : & mut dyn Accumulator , values : & [ & [ f32 ] ] , expected : f32 | {
1132
+ for batch in values. iter ( ) {
1133
+ let batch =
1134
+ Arc :: new ( Float32Array :: from_iter_values ( batch. iter ( ) . copied ( ) ) ) ;
1135
+ acc. update_batch ( & [ batch] ) . unwrap ( ) ;
1136
+ }
1137
+ let result = acc. evaluate ( ) . unwrap ( ) ;
1138
+ assert_eq ! ( result, ScalarValue :: Float32 ( Some ( expected) ) ) ;
1139
+ } ;
1140
+
1141
+ // This test checks both comparison between batches (which uses the min_max macro
1142
+ // defined above) and within a batch (which uses the arrow min/max compute function
1143
+ // and verifies both respect the total order comparison for floats)
1144
+
1145
+ let min = || MinAccumulator :: try_new ( & DataType :: Float32 ) . unwrap ( ) ;
1146
+ let max = || MaxAccumulator :: try_new ( & DataType :: Float32 ) . unwrap ( ) ;
1147
+
1148
+ check ( & mut min ( ) , & [ & [ zero] , & [ pos_nan] ] , zero) ;
1149
+ check ( & mut min ( ) , & [ & [ zero, pos_nan] ] , zero) ;
1150
+ check ( & mut min ( ) , & [ & [ zero] , & [ neg_inf] ] , neg_inf) ;
1151
+ check ( & mut min ( ) , & [ & [ zero, neg_inf] ] , neg_inf) ;
1152
+ check ( & mut max ( ) , & [ & [ zero] , & [ pos_nan] ] , pos_nan) ;
1153
+ check ( & mut max ( ) , & [ & [ zero, pos_nan] ] , pos_nan) ;
1154
+ check ( & mut max ( ) , & [ & [ zero] , & [ neg_inf] ] , zero) ;
1155
+ check ( & mut max ( ) , & [ & [ zero, neg_inf] ] , zero) ;
1156
+ }
1157
+ }
0 commit comments