Skip to content

Commit b4069a6

Browse files
authored
Remove AggregateFunctionDefinition (#11803)
* Remove `if udf.name() == count => {` * Apply review suggestions
1 parent c8e5996 commit b4069a6

File tree

13 files changed

+144
-199
lines changed

13 files changed

+144
-199
lines changed

datafusion/core/src/physical_planner.rs

+32-37
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ use datafusion_common::{
7474
};
7575
use datafusion_expr::dml::CopyTo;
7676
use datafusion_expr::expr::{
77-
self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr,
78-
Cast, GroupingSet, InList, Like, TryCast, WindowFunction,
77+
self, AggregateFunction, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like,
78+
TryCast, WindowFunction,
7979
};
8080
use datafusion_expr::expr_rewriter::unnormalize_cols;
8181
use datafusion_expr::expr_vec_fmt;
@@ -223,18 +223,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
223223
create_function_physical_name(&fun.to_string(), false, args, Some(order_by))
224224
}
225225
Expr::AggregateFunction(AggregateFunction {
226-
func_def,
226+
func,
227227
distinct,
228228
args,
229229
filter: _,
230230
order_by,
231231
null_treatment: _,
232-
}) => create_function_physical_name(
233-
func_def.name(),
234-
*distinct,
235-
args,
236-
order_by.as_ref(),
237-
),
232+
}) => {
233+
create_function_physical_name(func.name(), *distinct, args, order_by.as_ref())
234+
}
238235
Expr::GroupingSet(grouping_set) => match grouping_set {
239236
GroupingSet::Rollup(exprs) => Ok(format!(
240237
"ROLLUP ({})",
@@ -1817,7 +1814,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18171814
) -> Result<AggregateExprWithOptionalArgs> {
18181815
match e {
18191816
Expr::AggregateFunction(AggregateFunction {
1820-
func_def,
1817+
func,
18211818
distinct,
18221819
args,
18231820
filter,
@@ -1839,36 +1836,34 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18391836
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
18401837
== NullTreatment::IgnoreNulls;
18411838

1842-
let (agg_expr, filter, order_by) = match func_def {
1843-
AggregateFunctionDefinition::UDF(fun) => {
1844-
let sort_exprs = order_by.clone().unwrap_or(vec![]);
1845-
let physical_sort_exprs = match order_by {
1846-
Some(exprs) => Some(create_physical_sort_exprs(
1847-
exprs,
1848-
logical_input_schema,
1849-
execution_props,
1850-
)?),
1851-
None => None,
1852-
};
1839+
let (agg_expr, filter, order_by) = {
1840+
let sort_exprs = order_by.clone().unwrap_or(vec![]);
1841+
let physical_sort_exprs = match order_by {
1842+
Some(exprs) => Some(create_physical_sort_exprs(
1843+
exprs,
1844+
logical_input_schema,
1845+
execution_props,
1846+
)?),
1847+
None => None,
1848+
};
18531849

1854-
let ordering_reqs: Vec<PhysicalSortExpr> =
1855-
physical_sort_exprs.clone().unwrap_or(vec![]);
1850+
let ordering_reqs: Vec<PhysicalSortExpr> =
1851+
physical_sort_exprs.clone().unwrap_or(vec![]);
18561852

1857-
let agg_expr = udaf::create_aggregate_expr_with_dfschema(
1858-
fun,
1859-
&physical_args,
1860-
args,
1861-
&sort_exprs,
1862-
&ordering_reqs,
1863-
logical_input_schema,
1864-
name,
1865-
ignore_nulls,
1866-
*distinct,
1867-
false,
1868-
)?;
1853+
let agg_expr = udaf::create_aggregate_expr_with_dfschema(
1854+
func,
1855+
&physical_args,
1856+
args,
1857+
&sort_exprs,
1858+
&ordering_reqs,
1859+
logical_input_schema,
1860+
name,
1861+
ignore_nulls,
1862+
*distinct,
1863+
false,
1864+
)?;
18691865

1870-
(agg_expr, filter, physical_sort_exprs)
1871-
}
1866+
(agg_expr, filter, physical_sort_exprs)
18721867
};
18731868

18741869
Ok((agg_expr, filter, order_by))

datafusion/expr/src/expr.rs

+9-25
Original file line numberDiff line numberDiff line change
@@ -627,22 +627,6 @@ impl Sort {
627627
}
628628
}
629629

630-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
631-
/// Defines which implementation of an aggregate function DataFusion should call.
632-
pub enum AggregateFunctionDefinition {
633-
/// Resolved to a user defined aggregate function
634-
UDF(Arc<crate::AggregateUDF>),
635-
}
636-
637-
impl AggregateFunctionDefinition {
638-
/// Function's name for display
639-
pub fn name(&self) -> &str {
640-
match self {
641-
AggregateFunctionDefinition::UDF(udf) => udf.name(),
642-
}
643-
}
644-
}
645-
646630
/// Aggregate function
647631
///
648632
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
@@ -651,7 +635,7 @@ impl AggregateFunctionDefinition {
651635
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
652636
pub struct AggregateFunction {
653637
/// Name of the function
654-
pub func_def: AggregateFunctionDefinition,
638+
pub func: Arc<crate::AggregateUDF>,
655639
/// List of expressions to feed to the functions as arguments
656640
pub args: Vec<Expr>,
657641
/// Whether this is a DISTINCT aggregation or not
@@ -666,15 +650,15 @@ pub struct AggregateFunction {
666650
impl AggregateFunction {
667651
/// Create a new AggregateFunction expression with a user-defined function (UDF)
668652
pub fn new_udf(
669-
udf: Arc<crate::AggregateUDF>,
653+
func: Arc<crate::AggregateUDF>,
670654
args: Vec<Expr>,
671655
distinct: bool,
672656
filter: Option<Box<Expr>>,
673657
order_by: Option<Vec<Expr>>,
674658
null_treatment: Option<NullTreatment>,
675659
) -> Self {
676660
Self {
677-
func_def: AggregateFunctionDefinition::UDF(udf),
661+
func,
678662
args,
679663
distinct,
680664
filter,
@@ -1666,14 +1650,14 @@ impl Expr {
16661650
func.hash(hasher);
16671651
}
16681652
Expr::AggregateFunction(AggregateFunction {
1669-
func_def,
1653+
func,
16701654
args: _args,
16711655
distinct,
16721656
filter: _filter,
16731657
order_by: _order_by,
16741658
null_treatment,
16751659
}) => {
1676-
func_def.hash(hasher);
1660+
func.hash(hasher);
16771661
distinct.hash(hasher);
16781662
null_treatment.hash(hasher);
16791663
}
@@ -1870,15 +1854,15 @@ impl fmt::Display for Expr {
18701854
Ok(())
18711855
}
18721856
Expr::AggregateFunction(AggregateFunction {
1873-
func_def,
1857+
func,
18741858
distinct,
18751859
ref args,
18761860
filter,
18771861
order_by,
18781862
null_treatment,
18791863
..
18801864
}) => {
1881-
fmt_function(f, func_def.name(), *distinct, args, true)?;
1865+
fmt_function(f, func.name(), *distinct, args, true)?;
18821866
if let Some(nt) = null_treatment {
18831867
write!(f, " {}", nt)?;
18841868
}
@@ -2190,14 +2174,14 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) -> Result<()> {
21902174
write!(w, "{window_frame}")?;
21912175
}
21922176
Expr::AggregateFunction(AggregateFunction {
2193-
func_def,
2177+
func,
21942178
distinct,
21952179
args,
21962180
filter,
21972181
order_by,
21982182
null_treatment,
21992183
}) => {
2200-
write_function_name(w, func_def.name(), *distinct, args)?;
2184+
write_function_name(w, func.name(), *distinct, args)?;
22012185
if let Some(fe) = filter {
22022186
write!(w, " FILTER (WHERE {fe})")?;
22032187
};

datafusion/expr/src/expr_schema.rs

+21-26
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
use super::{Between, Expr, Like};
1919
use crate::expr::{
20-
AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, InList,
21-
InSubquery, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
20+
AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder,
21+
ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
2222
};
2323
use crate::type_coercion::binary::get_result_type;
2424
use crate::type_coercion::functions::{
@@ -193,28 +193,24 @@ impl ExprSchemable for Expr {
193193
_ => fun.return_type(&data_types, &nullability),
194194
}
195195
}
196-
Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
196+
Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
197197
let data_types = args
198198
.iter()
199199
.map(|e| e.get_type(schema))
200200
.collect::<Result<Vec<_>>>()?;
201-
match func_def {
202-
AggregateFunctionDefinition::UDF(fun) => {
203-
let new_types = data_types_with_aggregate_udf(&data_types, fun)
204-
.map_err(|err| {
205-
plan_datafusion_err!(
206-
"{} {}",
207-
err,
208-
utils::generate_signature_error_msg(
209-
fun.name(),
210-
fun.signature().clone(),
211-
&data_types
212-
)
201+
let new_types = data_types_with_aggregate_udf(&data_types, func)
202+
.map_err(|err| {
203+
plan_datafusion_err!(
204+
"{} {}",
205+
err,
206+
utils::generate_signature_error_msg(
207+
func.name(),
208+
func.signature().clone(),
209+
&data_types
213210
)
214-
})?;
215-
Ok(fun.return_type(&new_types)?)
216-
}
217-
}
211+
)
212+
})?;
213+
Ok(func.return_type(&new_types)?)
218214
}
219215
Expr::Not(_)
220216
| Expr::IsNull(_)
@@ -329,13 +325,12 @@ impl ExprSchemable for Expr {
329325
}
330326
}
331327
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
332-
Expr::AggregateFunction(AggregateFunction { func_def, .. }) => {
333-
match func_def {
334-
// TODO: UDF should be able to customize nullability
335-
AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => {
336-
Ok(false)
337-
}
338-
AggregateFunctionDefinition::UDF(_) => Ok(true),
328+
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
329+
// TODO: UDF should be able to customize nullability
330+
if func.name() == "count" {
331+
Ok(false)
332+
} else {
333+
Ok(true)
339334
}
340335
}
341336
Expr::ScalarVariable(_, _)

datafusion/expr/src/tree_node.rs

+13-18
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
//! Tree node implementation for logical expr
1919
2020
use crate::expr::{
21-
AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
22-
Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort,
23-
TryCast, Unnest, WindowFunction,
21+
AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList,
22+
InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
2423
};
2524
use crate::{Expr, ExprFunctionExt};
2625

@@ -304,7 +303,7 @@ impl TreeNode for Expr {
304303
}),
305304
Expr::AggregateFunction(AggregateFunction {
306305
args,
307-
func_def,
306+
func,
308307
distinct,
309308
filter,
310309
order_by,
@@ -316,20 +315,16 @@ impl TreeNode for Expr {
316315
order_by,
317316
transform_option_vec(order_by, &mut f)
318317
)?
319-
.map_data(
320-
|(new_args, new_filter, new_order_by)| match func_def {
321-
AggregateFunctionDefinition::UDF(fun) => {
322-
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
323-
fun,
324-
new_args,
325-
distinct,
326-
new_filter,
327-
new_order_by,
328-
null_treatment,
329-
)))
330-
}
331-
},
332-
)?,
318+
.map_data(|(new_args, new_filter, new_order_by)| {
319+
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
320+
func,
321+
new_args,
322+
distinct,
323+
new_filter,
324+
new_order_by,
325+
null_treatment,
326+
)))
327+
})?,
333328
Expr::GroupingSet(grouping_set) => match grouping_set {
334329
GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
335330
.update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),

datafusion/functions-nested/src/planner.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
use datafusion_common::{exec_err, utils::list_ndims, DFSchema, Result};
2121
use datafusion_expr::expr::ScalarFunction;
2222
use datafusion_expr::{
23-
expr::AggregateFunctionDefinition,
2423
planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
2524
sqlparser, Expr, ExprSchemable, GetFieldAccess,
2625
};
@@ -171,6 +170,5 @@ impl ExprPlanner for FieldAccessPlanner {
171170
}
172171

173172
fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
174-
let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def;
175-
return udf.name() == "array_agg";
173+
return agg_func.func.name() == "array_agg";
176174
}

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ use crate::utils::NamePreserver;
2121
use datafusion_common::config::ConfigOptions;
2222
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
2323
use datafusion_common::Result;
24-
use datafusion_expr::expr::{
25-
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
26-
};
24+
use datafusion_expr::expr::{AggregateFunction, WindowFunction};
2725
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
2826
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
2927

@@ -56,10 +54,10 @@ fn is_wildcard(expr: &Expr) -> bool {
5654
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
5755
matches!(aggregate_function,
5856
AggregateFunction {
59-
func_def: AggregateFunctionDefinition::UDF(udf),
57+
func,
6058
args,
6159
..
62-
} if udf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
60+
} if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
6361
}
6462

6563
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {

0 commit comments

Comments
 (0)