From 55d6d6452f7b324bd03dd47c98c2508051fb47eb Mon Sep 17 00:00:00 2001 From: feiniaofeiafei <53502832+feiniaofeiafei@users.noreply.github.com> Date: Fri, 13 Sep 2024 17:43:40 +0800 Subject: [PATCH] [Fix](nereids) set all nullable aggregate function to alwaysnullable in window expression (#40693) --- .../ExtractAndNormalizeWindowExpression.java | 16 +- .../normalize_window_nullable_agg_test.out | 293 ++++++++++++++++++ .../normalize_window_nullable_agg_test.groovy | 71 +++++ 3 files changed, 369 insertions(+), 11 deletions(-) create mode 100644 regression-test/data/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.out create mode 100644 regression-test/suites/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java index 6f067545cee0cc..e82c3f7b416b8c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java @@ -25,11 +25,7 @@ import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.WindowExpression; -import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; -import org.apache.doris.nereids.trees.expressions.functions.agg.Max; -import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; @@ -64,13 +60,11 @@ private Plan normalize(LogicalProject project) { if (output instanceof WindowExpression) { WindowExpression windowExpression = (WindowExpression) output; Expression expression = ((WindowExpression) output).getFunction(); - if (expression instanceof Sum || expression instanceof Max - || expression instanceof Min || expression instanceof Avg) { - // sum, max, min and avg in window function should be always nullable - windowExpression = ((WindowExpression) output) - .withFunction( - ((NullableAggregateFunction) expression).withAlwaysNullable(true) - ); + if (expression instanceof NullableAggregateFunction) { + // NullableAggregateFunction in window function should be always nullable + // Because there may be no data in the window frame, null values will be generated. + windowExpression = ((WindowExpression) output).withFunction( + ((NullableAggregateFunction) expression).withAlwaysNullable(true)); } ImmutableList.Builder nonLiteralPartitionKeys = diff --git a/regression-test/data/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.out b/regression-test/data/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.out new file mode 100644 index 00000000000000..2df25bb0d3bed2 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.out @@ -0,0 +1,293 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !max -- +\N +2 +3 +3 +3 +5 +5 +5 +5 + +-- !min -- +\N +2 +2 +3 +2 +2 +3 +3 +3 + +-- !sum -- +\N +2 +5 +6 +5 +7 +8 +8 +8 + +-- !avg -- +\N +2.0 +2.5 +3.0 +2.5 +3.5 +4.0 +4.0 +4.0 + +-- !topn -- +\N +{"1":1} +{"2":1,"1":1} +{"3":1,"2":1} +{"4":1,"3":1} +{"5":1,"4":1} +{"6":1,"5":1} +{"7":1,"6":1} +{"8":1,"7":1} + +-- !topn_array -- +\N +[1] +[2, 1] +[3, 2] +[4, 3] +[5, 4] +[6, 5] +[7, 6] +[8, 7] + +-- !topn_weighted -- +\N +[1] +[2, 1] +[3, 2] +[4, 3] +[5, 4] +[6, 5] +[7, 6] +[8, 7] + +-- !max_by -- +\N +2 +3 +3 +2 +5 +3 +5 +3 + +-- !min_by -- +\N +2 +2 +3 +3 +2 +5 +3 +5 + +-- !avg_weighted -- +\N +2.0 +2.5 +3.0 +2.5 +3.5 +4.0 +4.0 +5.0 + +-- !variance -- +\N +0.0 +0.25 +0.0 +0.25 +2.25 +1.0 +1.0 +1.0 + +-- !variance_samp -- +\N +0.0 +0.5 +0.0 +0.5 +4.5 +2.0 +2.0 +2.0 + +-- !percentile -- +\N +2.0 +2.5 +3.0 +2.5 +3.5 +4.0 +4.0 +4.0 + +-- !percentile_approx -- +\N +2.0 +3.0 +3.0 +3.0 +5.0 +5.0 +5.0 +5.0 + +-- !stddev -- +\N +0.0 +0.5 +0.0 +0.5 +1.5 +1.0 +1.0 +1.0 + +-- !stddev_samp -- +\N +0.0 +0.7071067811865476 +0.0 +0.7071067811865476 +2.1213203435596424 +1.4142135623730951 +1.4142135623730951 +1.4142135623730951 + +-- !corr -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !covar -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !covar_samp -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !group_concat -- +\N +1 +1,1 +1,2 +2,2 +2,2 +2,2 +2,2 +2 + +-- !retention -- +\N +[1, 0] +[1, 1] +[1, 1] +[0, 0] +[0, 0] +[0, 0] +[0, 0] +[0, 0] + +-- !group_bit_and -- +\N +1 +1 +0 +2 +2 +2 +2 +2 + +-- !group_bit_or -- +\N +1 +1 +3 +2 +2 +2 +2 +2 + +-- !group_bit_xor -- +\N +1 +0 +3 +0 +0 +0 +0 +2 + +-- !group_bitmap_xor -- +\N +\N +\N +\N +\N +\N +\N +\N +\N + +-- !sum_foreach -- +\N +[1, 2] +[4, 4] +[4, 7] +[4, 7] +[8, 4] +[6, 4] +[2, 4] +[2, 25] + +-- !sequence_match -- +\N +false +false +false + diff --git a/regression-test/suites/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.groovy b/regression-test/suites/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.groovy new file mode 100644 index 00000000000000..915a4d025a89a1 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/normalize_window/normalize_window_nullable_agg_test.groovy @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("normalize_window_nullable_agg") { + sql "drop table if exists normalize_window_nullable_agg" + sql """create table normalize_window_nullable_agg (a int, b int,c int,d array) distributed by hash(a) + properties("replication_num"="1"); + """ + sql """insert into normalize_window_nullable_agg values(1,2,1,[1,2]),(1,3,2,[3,2]),(2,3,3,[1,5]),(2,2,4,[3,2]),(2,5,5,[5,2]) + ,(2,3,6,[1,2]),(2,5,7,[1,2]),(null,3,8,[1,23]),(null,6,9,[3,2]);""" + qt_max "select max(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_min "select min(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_sum "select sum(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_avg "select avg(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_topn "select topn(c,3) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_topn_array "select topn_array(c,3) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_topn_weighted "select topn_weighted(c,c,3) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_max_by "select max_by(b,c) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_min_by "select min_by(b,c) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_avg_weighted "select avg_weighted(b,a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_variance "select variance(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_variance_samp "select variance_samp(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_percentile "select PERCENTILE(b,0.5) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_percentile_approx "select PERCENTILE_approx(b,0.99) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_stddev "select stddev(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_stddev_samp "select stddev_samp(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_corr "select corr(a,b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_covar "select covar(a,b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_covar_samp "select covar_samp(a,b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_group_concat "select group_concat(cast(a as varchar(10)),',') over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_retention "select retention(a=1,b>2) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_group_bit_and "select group_bit_and(a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_group_bit_or "select group_bit_or(a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_group_bit_xor "select group_bit_xor(a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_group_bitmap_xor "select group_bitmap_xor(to_bitmap(a)) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + qt_sum_foreach "select sum_foreach(d) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg" + + sql "drop table if exists windowfunnel_test_normalize_window" + sql """CREATE TABLE windowfunnel_test_normalize_window ( + `xwho` varchar(50) NULL COMMENT 'xwho', + `xwhen` datetime COMMENT 'xwhen', + `xwhat` int NULL COMMENT 'xwhat' + ) + DUPLICATE KEY(xwho) + DISTRIBUTED BY HASH(xwho) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + );""" + + sql """INSERT into windowfunnel_test_normalize_window (xwho, xwhen, xwhat) values ('1', '2022-03-12 10:41:00', 1), + ('1', '2022-03-12 13:28:02', 2), + ('1', '2022-03-12 16:15:01', 3), + ('1', '2022-03-12 19:05:04', 4);""" + //这个目前会core +// qt_window_funnel """select window_funnel(3600 * 3, 'default', t.xwhen, t.xwhat = 1, t.xwhat = 2 ) over (order by xwhat rows +// between 2 preceding and 1 preceding) AS level from windowfunnel_test_normalize_window t;""" + qt_sequence_match "SELECT sequence_match('(?1)(?2)', xwhen, xwhat = 1, xwhat = 3) over (order by xwhat rows between 2 preceding and 1 preceding) FROM windowfunnel_test_normalize_window;" +} \ No newline at end of file