Skip to content

Commit d1b474c

Browse files
committed
support_ansi_avg
1 parent 07c6acd commit d1b474c

File tree

7 files changed

+82
-229
lines changed

7 files changed

+82
-229
lines changed

docs/source/user-guide/latest/compatibility.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ The following cast operations are generally compatible with Spark except for the
8989
<!-- WARNING! DO NOT MANUALLY MODIFY CONTENT BETWEEN THE BEGIN AND END TAGS -->
9090

9191
<!--BEGIN:COMPAT_CAST_TABLE-->
92-
<!-- prettier-ignore-start -->
9392
| From Type | To Type | Notes |
9493
|-|-|-|
9594
| boolean | byte | |
@@ -166,7 +165,6 @@ The following cast operations are generally compatible with Spark except for the
166165
| timestamp | long | |
167166
| timestamp | string | |
168167
| timestamp | date | |
169-
<!-- prettier-ignore-end -->
170168
<!--END:COMPAT_CAST_TABLE-->
171169

172170
### Incompatible Casts
@@ -176,7 +174,6 @@ The following cast operations are not compatible with Spark for all inputs and a
176174
<!-- WARNING! DO NOT MANUALLY MODIFY CONTENT BETWEEN THE BEGIN AND END TAGS -->
177175

178176
<!--BEGIN:INCOMPAT_CAST_TABLE-->
179-
<!-- prettier-ignore-start -->
180177
| From Type | To Type | Notes |
181178
|-|-|-|
182179
| float | decimal | There can be rounding differences |
@@ -185,7 +182,6 @@ The following cast operations are not compatible with Spark for all inputs and a
185182
| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
186183
| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
187184
| string | timestamp | Not all valid formats are supported |
188-
<!-- prettier-ignore-end -->
189185
<!--END:INCOMPAT_CAST_TABLE-->
190186

191187
### Unsupported Casts

docs/source/user-guide/latest/configs.md

Lines changed: 0 additions & 29 deletions
Large diffs are not rendered by default.

native/core/src/execution/planner.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ use datafusion_comet_proto::{
115115
};
116116
use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
117117
use datafusion_comet_spark_expr::{
118-
ArrayInsert, Avg, AvgDecimal, AvgInt, Cast, CheckOverflow, Correlation, Covariance,
119-
CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
120-
NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr,
121-
SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance,
118+
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
119+
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike,
120+
RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr,
121+
ToJson, UnboundColumn, Variance,
122122
};
123123
use itertools::Itertools;
124124
use jni::objects::GlobalRef;
@@ -1893,28 +1893,24 @@ impl PhysicalPlanner {
18931893
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
18941894
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
18951895
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
1896+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
1897+
18961898
let builder = match datatype {
1897-
DataType::Int8
1898-
| DataType::UInt8
1899-
| DataType::Int16
1900-
| DataType::UInt16
1901-
| DataType::Int32 => {
1902-
let func =
1903-
AggregateUDF::new_from_impl(AvgInt::new(datatype, input_datatype));
1904-
AggregateExprBuilder::new(Arc::new(func), vec![child])
1905-
}
19061899
DataType::Decimal128(_, _) => {
19071900
let func =
19081901
AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
19091902
AggregateExprBuilder::new(Arc::new(func), vec![child])
19101903
}
19111904
_ => {
1912-
// cast to the result data type of AVG if the result data type is different
1913-
// from the input type, e.g. AVG(Int32). We should not expect a cast
1914-
// failure since it should have already been checked at Spark side.
1905+
// For all other numeric types (Int8/16/32/64, Float32/64):
1906+
// Cast to Float64 for accumulation
19151907
let child: Arc<dyn PhysicalExpr> =
1916-
Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
1917-
let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
1908+
Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
1909+
let func = AggregateUDF::new_from_impl(Avg::new(
1910+
"avg",
1911+
DataType::Float64,
1912+
eval_mode,
1913+
));
19181914
AggregateExprBuilder::new(Arc::new(func), vec![child])
19191915
}
19201916
};

native/spark-expr/src/agg_funcs/avg.rs

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::EvalMode;
1819
use arrow::array::{
1920
builder::PrimitiveBuilder,
2021
cast::AsArray,
2122
types::{Float64Type, Int64Type},
22-
Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
23+
Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray,
2324
};
2425
use arrow::compute::sum;
2526
use arrow::datatypes::{DataType, Field, FieldRef};
@@ -31,45 +32,43 @@ use datafusion::logical_expr::{
3132
use datafusion::physical_expr::expressions::format_state_name;
3233
use std::{any::Any, sync::Arc};
3334

34-
use arrow::array::ArrowNativeTypeOp;
3535
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
3636
use datafusion::logical_expr::Volatility::Immutable;
3737
use DataType::*;
3838

39-
/// AVG aggregate expression
4039
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
4140
pub struct Avg {
4241
name: String,
4342
signature: Signature,
44-
// expr: Arc<dyn PhysicalExpr>,
4543
input_data_type: DataType,
4644
result_data_type: DataType,
45+
eval_mode: EvalMode,
4746
}
4847

4948
impl Avg {
5049
/// Create a new AVG aggregate function
51-
pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
50+
pub fn new(name: impl Into<String>, data_type: DataType, eval_mode: EvalMode) -> Self {
5251
let result_data_type = avg_return_type("avg", &data_type).unwrap();
5352

5453
Self {
5554
name: name.into(),
5655
signature: Signature::user_defined(Immutable),
5756
input_data_type: data_type,
5857
result_data_type,
58+
eval_mode,
5959
}
6060
}
6161
}
6262

6363
impl AggregateUDFImpl for Avg {
64-
/// Return a reference to Any that can be used for downcasting
6564
fn as_any(&self) -> &dyn Any {
6665
self
6766
}
6867

6968
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
70-
// instantiate specialized accumulator based for the type
69+
// All numeric types use Float64 accumulation after casting
7170
match (&self.input_data_type, &self.result_data_type) {
72-
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
71+
(Float64, Float64) => Ok(Box::new(AvgAccumulator::new(self.eval_mode))),
7372
_ => not_impl_err!(
7473
"AvgAccumulator for ({} --> {})",
7574
self.input_data_type,
@@ -109,10 +108,10 @@ impl AggregateUDFImpl for Avg {
109108
&self,
110109
_args: AccumulatorArgs,
111110
) -> Result<Box<dyn GroupsAccumulator>> {
112-
// instantiate specialized accumulator based for the type
113111
match (&self.input_data_type, &self.result_data_type) {
114112
(Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
115113
&self.input_data_type,
114+
self.eval_mode,
116115
|sum: f64, count: i64| Ok(sum / count as f64),
117116
))),
118117

@@ -137,11 +136,22 @@ impl AggregateUDFImpl for Avg {
137136
}
138137
}
139138

140-
/// An accumulator to compute the average
141-
#[derive(Debug, Default)]
139+
#[derive(Debug)]
142140
pub struct AvgAccumulator {
143141
sum: Option<f64>,
144142
count: i64,
143+
#[allow(dead_code)]
144+
eval_mode: EvalMode,
145+
}
146+
147+
impl AvgAccumulator {
148+
pub fn new(eval_mode: EvalMode) -> Self {
149+
Self {
150+
sum: None,
151+
count: 0,
152+
eval_mode,
153+
}
154+
}
145155
}
146156

147157
impl Accumulator for AvgAccumulator {
@@ -166,7 +176,7 @@ impl Accumulator for AvgAccumulator {
166176
// counts are summed
167177
self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
168178

169-
// sums are summed
179+
// sums are summed - no overflow checking
170180
if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
171181
let v = self.sum.get_or_insert(0.);
172182
*v += x;
@@ -176,8 +186,6 @@ impl Accumulator for AvgAccumulator {
176186

177187
fn evaluate(&mut self) -> Result<ScalarValue> {
178188
if self.count == 0 {
179-
// If all input are nulls, count will be 0 and we will get null after the division.
180-
// This is consistent with Spark Average implementation.
181189
Ok(ScalarValue::Float64(None))
182190
} else {
183191
Ok(ScalarValue::Float64(
@@ -192,7 +200,7 @@ impl Accumulator for AvgAccumulator {
192200
}
193201

194202
/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
195-
/// Stores values as native types, and does overflow checking
203+
/// Stores values as native types.
196204
///
197205
/// F: Function that calculates the average value from a sum of
198206
/// T::Native and a total count
@@ -211,6 +219,10 @@ where
211219
/// Sums per group, stored as the native type
212220
sums: Vec<T::Native>,
213221

222+
/// Evaluation mode (stored but not used for Float64)
223+
#[allow(dead_code)]
224+
eval_mode: EvalMode,
225+
214226
/// Function that computes the final average (value / count)
215227
avg_fn: F,
216228
}
@@ -220,11 +232,12 @@ where
220232
T: ArrowNumericType + Send,
221233
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
222234
{
223-
pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
235+
pub fn new(return_data_type: &DataType, eval_mode: EvalMode, avg_fn: F) -> Self {
224236
Self {
225237
return_data_type: return_data_type.clone(),
226238
counts: vec![],
227239
sums: vec![],
240+
eval_mode,
228241
avg_fn,
229242
}
230243
}
@@ -254,6 +267,7 @@ where
254267
if values.null_count() == 0 {
255268
for (&group_index, &value) in iter {
256269
let sum = &mut self.sums[group_index];
270+
// No overflow checking - INFINITY is a valid result
257271
*sum = (*sum).add_wrapping(value);
258272
self.counts[group_index] += 1;
259273
}
@@ -264,7 +278,6 @@ where
264278
}
265279
let sum = &mut self.sums[group_index];
266280
*sum = (*sum).add_wrapping(value);
267-
268281
self.counts[group_index] += 1;
269282
}
270283
}
@@ -280,17 +293,17 @@ where
280293
total_num_groups: usize,
281294
) -> Result<()> {
282295
assert_eq!(values.len(), 2, "two arguments to merge_batch");
283-
// first batch is partial sums, second is counts
284296
let partial_sums = values[0].as_primitive::<T>();
285297
let partial_counts = values[1].as_primitive::<Int64Type>();
298+
286299
// update counts with partial counts
287300
self.counts.resize(total_num_groups, 0);
288301
let iter1 = group_indices.iter().zip(partial_counts.values().iter());
289302
for (&group_index, &partial_count) in iter1 {
290303
self.counts[group_index] += partial_count;
291304
}
292305

293-
// update sums
306+
// update sums - no overflow checking
294307
self.sums.resize(total_num_groups, T::default_value());
295308
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
296309
for (&group_index, &new_value) in iter2 {
@@ -319,7 +332,6 @@ where
319332
Ok(Arc::new(array))
320333
}
321334

322-
// return arrays for sums and counts
323335
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
324336
let counts = emit_to.take_needed(&mut self.counts);
325337
let counts = Int64Array::new(counts.into(), None);

0 commit comments

Comments (0)