Skip to content

Commit 37f57fe

Browse files
lutengdaroseboy-liu
authored andcommitted
add "can_be_pushed_down" in AggregateFunction
1 parent 21b2303 commit 37f57fe

File tree

17 files changed

+46
-4
lines changed

17 files changed

+46
-4
lines changed

datafusion/core/src/physical_planner.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
16501650
args,
16511651
filter,
16521652
order_by,
1653+
..
16531654
}) => {
16541655
let args = args
16551656
.iter()

datafusion/core/tests/provider_aggregation_pushdown.rs

+2
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ impl TableProvider for CustomAggregationProvider {
249249
distinct,
250250
filter,
251251
order_by,
252+
can_be_pushed_down,
252253
}) => {
253254
let support_agg_func = match fun {
254255
aggregate_function::AggregateFunction::Count => true,
@@ -263,6 +264,7 @@ impl TableProvider for CustomAggregationProvider {
263264
&& !distinct
264265
&& filter.is_none()
265266
&& order_by.is_none()
267+
&& *can_be_pushed_down
266268
}
267269
_ => false,
268270
}

datafusion/expr/src/expr.rs

+5
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,8 @@ pub struct AggregateFunction {
425425
pub filter: Option<Box<Expr>>,
426426
/// Optional ordering
427427
pub order_by: Option<Vec<Expr>>,
428+
/// Whether it can be pushed down
429+
pub can_be_pushed_down: bool,
428430
}
429431

430432
impl AggregateFunction {
@@ -434,13 +436,15 @@ impl AggregateFunction {
434436
distinct: bool,
435437
filter: Option<Box<Expr>>,
436438
order_by: Option<Vec<Expr>>,
439+
can_be_pushed_down: bool,
437440
) -> Self {
438441
Self {
439442
fun,
440443
args,
441444
distinct,
442445
filter,
443446
order_by,
447+
can_be_pushed_down,
444448
}
445449
}
446450
}
@@ -1364,6 +1368,7 @@ fn create_name(e: &Expr) -> Result<String> {
13641368
args,
13651369
filter,
13661370
order_by,
1371+
..
13671372
}) => {
13681373
let mut name = create_function_name(&fun.to_string(), *distinct, args)?;
13691374
if let Some(fe) = filter {

datafusion/expr/src/expr_fn.rs

+12
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ pub fn min(expr: Expr) -> Expr {
111111
false,
112112
None,
113113
None,
114+
false,
114115
))
115116
}
116117

@@ -122,6 +123,7 @@ pub fn max(expr: Expr) -> Expr {
122123
false,
123124
None,
124125
None,
126+
false,
125127
))
126128
}
127129

@@ -133,6 +135,7 @@ pub fn sum(expr: Expr) -> Expr {
133135
false,
134136
None,
135137
None,
138+
false,
136139
))
137140
}
138141

@@ -144,6 +147,7 @@ pub fn avg(expr: Expr) -> Expr {
144147
false,
145148
None,
146149
None,
150+
false,
147151
))
148152
}
149153

@@ -155,6 +159,7 @@ pub fn count(expr: Expr) -> Expr {
155159
false,
156160
None,
157161
None,
162+
true,
158163
))
159164
}
160165

@@ -211,6 +216,7 @@ pub fn count_distinct(expr: Expr) -> Expr {
211216
true,
212217
None,
213218
None,
219+
false,
214220
))
215221
}
216222

@@ -263,6 +269,7 @@ pub fn approx_distinct(expr: Expr) -> Expr {
263269
false,
264270
None,
265271
None,
272+
false,
266273
))
267274
}
268275

@@ -274,6 +281,7 @@ pub fn median(expr: Expr) -> Expr {
274281
false,
275282
None,
276283
None,
284+
false,
277285
))
278286
}
279287

@@ -285,6 +293,7 @@ pub fn approx_median(expr: Expr) -> Expr {
285293
false,
286294
None,
287295
None,
296+
false,
288297
))
289298
}
290299

@@ -296,6 +305,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
296305
false,
297306
None,
298307
None,
308+
false,
299309
))
300310
}
301311

@@ -311,6 +321,7 @@ pub fn approx_percentile_cont_with_weight(
311321
false,
312322
None,
313323
None,
324+
false,
314325
))
315326
}
316327

@@ -381,6 +392,7 @@ pub fn stddev(expr: Expr) -> Expr {
381392
false,
382393
None,
383394
None,
395+
false,
384396
))
385397
}
386398

datafusion/expr/src/tree_node/expr.rs

+2
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,14 @@ impl TreeNode for Expr {
299299
distinct,
300300
filter,
301301
order_by,
302+
can_be_pushed_down,
302303
}) => Expr::AggregateFunction(AggregateFunction::new(
303304
fun,
304305
transform_vec(args, &mut transform)?,
305306
distinct,
306307
transform_option_box(filter, &mut transform)?,
307308
transform_option_vec(order_by, &mut transform)?,
309+
can_be_pushed_down,
308310
)),
309311
Expr::GroupingSet(grouping_set) => match grouping_set {
310312
GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup(

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

+2
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,15 @@ impl TreeNodeRewriter for CountWildcardRewriter {
176176
distinct,
177177
filter,
178178
order_by,
179+
can_be_pushed_down,
179180
}) if args.len() == 1 => match args[0] {
180181
Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
181182
fun: aggregate_function::AggregateFunction::Count,
182183
args: vec![lit(COUNT_STAR_EXPANSION)],
183184
distinct,
184185
filter,
185186
order_by,
187+
can_be_pushed_down,
186188
}),
187189
_ => old_expr,
188190
},

datafusion/optimizer/src/analyzer/type_coercion.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
409409
distinct,
410410
filter,
411411
order_by,
412+
can_be_pushed_down,
412413
}) => {
413414
let new_expr = coerce_agg_exprs_for_signature(
414415
&fun,
@@ -417,7 +418,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
417418
&aggregate_function::signature(&fun),
418419
)?;
419420
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
420-
fun, new_expr, distinct, filter, order_by,
421+
fun, new_expr, distinct, filter, order_by, can_be_pushed_down,
421422
));
422423
Ok(expr)
423424
}
@@ -993,6 +994,7 @@ mod test {
993994
false,
994995
None,
995996
None,
997+
false,
996998
));
997999
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
9981000
let expected = "Projection: AVG(Int64(12))\n EmptyRelation";
@@ -1006,6 +1008,7 @@ mod test {
10061008
false,
10071009
None,
10081010
None,
1011+
false,
10091012
));
10101013
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
10111014
let expected = "Projection: AVG(a)\n EmptyRelation";
@@ -1023,6 +1026,7 @@ mod test {
10231026
false,
10241027
None,
10251028
None,
1029+
false,
10261030
));
10271031
let err = Projection::try_new(vec![agg_expr], empty).err().unwrap();
10281032
assert_eq!(

datafusion/optimizer/src/push_down_projection.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,7 @@ mod tests {
10641064
false,
10651065
Some(Box::new(col("c").gt(lit(42)))),
10661066
None,
1067+
false,
10671068
));
10681069

10691070
let plan = LogicalPlanBuilder::from(table_scan)

datafusion/optimizer/src/single_distinct_to_groupby.rs

+3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
132132
args,
133133
filter,
134134
order_by,
135+
can_be_pushed_down,
135136
..
136137
}) => {
137138
// is_single_distinct_agg ensure args.len=1
@@ -146,6 +147,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
146147
false, // intentional to remove distinct here
147148
filter.clone(),
148149
order_by.clone(),
150+
can_be_pushed_down.clone(),
149151
)))
150152
}
151153
_ => Ok(aggr_expr.clone()),
@@ -402,6 +404,7 @@ mod tests {
402404
true,
403405
None,
404406
None,
407+
false,
405408
)),
406409
],
407410
)?

datafusion/proto/src/logical_plan/from_proto.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,7 @@ pub fn parse_expr(
10111011
expr.distinct,
10121012
parse_optional_expr(expr.filter.as_deref(), registry)?.map(Box::new),
10131013
parse_vec_expr(&expr.order_by, registry)?,
1014+
false,
10141015
)))
10151016
}
10161017
ExprType::Alias(alias) => Ok(Expr::Alias(

datafusion/proto/src/logical_plan/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -2570,6 +2570,7 @@ mod roundtrip_tests {
25702570
false,
25712571
None,
25722572
None,
2573+
false,
25732574
));
25742575
let ctx = SessionContext::new();
25752576
roundtrip_expr_test(test_expr, ctx);
@@ -2583,6 +2584,7 @@ mod roundtrip_tests {
25832584
true,
25842585
None,
25852586
None,
2587+
false,
25862588
));
25872589
let ctx = SessionContext::new();
25882590
roundtrip_expr_test(test_expr, ctx);
@@ -2596,6 +2598,7 @@ mod roundtrip_tests {
25962598
false,
25972599
None,
25982600
None,
2601+
false,
25992602
));
26002603

26012604
let ctx = SessionContext::new();

datafusion/proto/src/logical_plan/to_proto.rs

+1
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
632632
ref distinct,
633633
ref filter,
634634
ref order_by,
635+
..
635636
}) => {
636637
let aggr_function = match fun {
637638
AggregateFunction::ApproxDistinct => {

datafusion/sql/src/expr/function.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
135135
self.function_args_to_expr(function.args, schema, planner_context)?;
136136

137137
return Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
138-
fun, args, distinct, None, order_by,
138+
fun, args, distinct, None, order_by, true,
139139
)));
140140
};
141141

datafusion/sql/src/expr/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
365365
// next, aggregate built-ins
366366
let fun = AggregateFunction::ArrayAgg;
367367
Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
368-
fun, args, distinct, None, order_by,
368+
fun, args, distinct, None, order_by, false,
369369
)))
370370
}
371371

@@ -500,6 +500,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
500500
args,
501501
distinct,
502502
order_by,
503+
can_be_pushed_down,
503504
..
504505
}) => Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
505506
fun,
@@ -511,6 +512,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
511512
planner_context,
512513
)?)),
513514
order_by,
515+
can_be_pushed_down,
514516
))),
515517
_ => Err(DataFusionError::Plan(
516518
"AggregateExpressionWithFilter expression was not an AggregateFunction"

datafusion/sql/src/utils.rs

+2
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ where
167167
distinct,
168168
filter,
169169
order_by,
170+
can_be_pushed_down,
170171
}) => Ok(Expr::AggregateFunction(AggregateFunction::new(
171172
fun.clone(),
172173
args.iter()
@@ -175,6 +176,7 @@ where
175176
*distinct,
176177
filter.clone(),
177178
order_by.clone(),
179+
can_be_pushed_down.clone(),
178180
))),
179181
Expr::WindowFunction(WindowFunction {
180182
fun,

datafusion/substrait/src/logical_plan/consumer.rs

+1
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,7 @@ pub async fn from_substrait_agg_func(
665665
distinct,
666666
filter,
667667
order_by,
668+
can_be_pushed_down: false,
668669
})))
669670
}
670671

datafusion/substrait/src/logical_plan/producer.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ pub fn to_substrait_agg_measure(
472472
),
473473
) -> Result<Measure> {
474474
match expr {
475-
Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => {
475+
Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by, .. }) => {
476476
let sorts = if let Some(order_by) = order_by {
477477
order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
478478
} else {

0 commit comments

Comments
 (0)