diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java index 28a149f7fd7..1f755d50df1 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java @@ -14,8 +14,10 @@ import com.google.common.collect.ImmutableList; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -23,8 +25,11 @@ import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalSort; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCorrelVariable; @@ -38,6 +43,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.Pair; import org.apache.calcite.util.Util; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; @@ -501,13 +507,51 @@ public Void visitInputRef(RexInputRef inputRef) { return selectedColumns; } + // `RelDecorrelator` may generate a Project with duplicated fields, e.g. Project($0,$0). + // There will be problem if pushing down the pattern like `Aggregate(AGG($0),{1})-Project($0,$0)`, + // as it will lead to field-name conflict. 
+ // We should wait and rely on `AggregateProjectMergeRule` to mitigate it by having this constraint + // Nevertheless, that rule cannot handle all cases if there is RexCall in the Project, + // e.g. Project($0, $0, +($0,1)). We cannot push down the Aggregate for this corner case. + // TODO: Simplify the Project where there is RexCall by adding a new rule. + static boolean distinctProjectList(LogicalProject project) { + // Change to Set> to resolve + // https://github.com/opensearch-project/sql/issues/4347 + Set> rexSet = new HashSet<>(); + return project.getNamedProjects().stream().allMatch(rexSet::add); + } + + static boolean containsRexOver(LogicalProject project) { + return project.getProjects().stream().anyMatch(RexOver::containsOver); + } + + /** + * The LogicalSort is a LIMIT that should be pushed down when its fetch field is not null and its + * collation is empty. For example: sort name | head 5 should not be pushed down + * because it has a field collation. + * + * @param sort The LogicalSort to check. + * @return True if the LogicalSort is a LIMIT, false otherwise. + */ + static boolean isLogicalSortLimit(LogicalSort sort) { + return sort.fetch != null; + } + + static boolean projectContainsExpr(Project project) { + return project.getProjects().stream().anyMatch(p -> p instanceof RexCall); + } + + static boolean sortByFieldsOnly(Sort sort) { + return !sort.getCollation().getFieldCollations().isEmpty() && sort.fetch == null; + } + /** * Get a string representation of the argument types expressed in ExprType for error messages. * * @param argTypes the list of argument types as {@link RelDataType} * @return a string in the format [type1,type2,...] 
representing the argument types */ - public static String getActualSignature(List argTypes) { + static String getActualSignature(List argTypes) { return "[" + argTypes.stream() .map(OpenSearchTypeFactory::convertRelDataTypeToExprType) diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index e1c726acbbb..4fd328c9eff 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -1018,8 +1018,7 @@ public void testExplainCountsByAgg() throws IOException { } @Test - public void testExplainSortOnMetricsNoBucketNullable() throws IOException { - // TODO enhancement later: https://github.com/opensearch-project/sql/issues/4282 + public void testExplainSortOnMetrics() throws IOException { enabledOnlyWhenPushdownIsEnabled(); String expected = loadExpectedPlan("explain_agg_sort_on_metrics1.yaml"); assertYamlEqualsIgnoreId( @@ -1027,8 +1026,34 @@ public void testExplainSortOnMetricsNoBucketNullable() throws IOException { explainQueryYaml( "source=opensearch-sql_test_index_account | stats bucket_nullable=false count() by" + " state | sort `count()`")); - expected = loadExpectedPlan("explain_agg_sort_on_metrics2.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account | stats bucket_nullable=false sum(balance)" + + " as sum by state | sort - sum")); + // TODO limit should pushdown to non-composite agg + expected = loadExpectedPlan("explain_agg_sort_on_metrics3.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + String.format( + "source=%s | stats count() as cnt by span(birthdate, 1d) | sort - cnt", + TEST_INDEX_BANK))); + expected = loadExpectedPlan("explain_agg_sort_on_metrics4.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + String.format( + 
"source=%s | stats bucket_nullable=false sum(balance) by span(age, 5) | sort -" + + " `sum(balance)`", + TEST_INDEX_BANK))); + } + + @Test + public void testExplainSortOnMetricsMultiTerms() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + String expected = loadExpectedPlan("explain_agg_sort_on_metrics_multi_terms.yaml"); assertYamlEqualsIgnoreId( expected, explainQueryYaml( @@ -1036,6 +1061,75 @@ public void testExplainSortOnMetricsNoBucketNullable() throws IOException { + " gender, state | sort `count()`")); } + @Test + public void testExplainCompositeMultiBucketsAutoDateThenSortOnMetricsNotPushdown() + throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | bin timestamp bins=3 | stats bucket_nullable=false avg(value), count()" + + " as cnt by category, value, timestamp | sort cnt", + TEST_INDEX_TIME_DATA))); + } + + @Test + public void testExplainCompositeRangeThenSortOnMetricsNotPushdown() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_range_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | eval value_range = case(value < 7000, 'small'" + + " else 'great') | stats bucket_nullable=false avg(value), count() as cnt by" + + " value_range, category | sort cnt", + TEST_INDEX_TIME_DATA))); + } + + @Test + public void testExplainCompositeAutoDateThenSortOnMetricsNotPushdown() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_autodate_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | bin timestamp bins=3 | stats bucket_nullable=false avg(value), count()" + + " as cnt by timestamp, category | sort cnt", + TEST_INDEX_TIME_DATA))); + } + + @Test + public void 
testExplainCompositeRangeAutoDateThenSortOnMetricsNotPushdown() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + assertYamlEqualsIgnoreId( + loadExpectedPlan("agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml"), + explainQueryYaml( + String.format( + "source=%s | bin timestamp bins=3 | eval value_range = case(value < 7000, 'small'" + + " else 'great') | stats bucket_nullable=false avg(value), count() as cnt by" + + " timestamp, value_range, category | sort cnt", + TEST_INDEX_TIME_DATA))); + } + + @Test + public void testExplainMultipleAggregatorsWithSortOnOneMetricNotPushDown() throws IOException { + enabledOnlyWhenPushdownIsEnabled(); + String expected = + loadExpectedPlan("explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account | stats bucket_nullable=false count() as c," + + " sum(balance) as s by state | sort c")); + expected = loadExpectedPlan("explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account | stats bucket_nullable=false count() as c," + + " sum(balance) as s by state | sort c, s")); + } + @Test public void testExplainEvalMax() throws IOException { String expected = loadExpectedPlan("explain_eval_max.json"); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml new file mode 100644 index 00000000000..90e83946c38 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml @@ -0,0 +1,14 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], 
dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$3], cnt=[$4], timestamp=[$0], value_range=[$1], category=[$2]) + LogicalAggregate(group=[{0, 1, 2}], avg(value)=[AVG($3)], cnt=[COUNT()]) + LogicalProject(timestamp=[$9], value_range=[$10], category=[$1], value=[$2]) + LogicalFilter(condition=[AND(IS NOT NULL($9), IS NOT NULL($1))]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'great':VARCHAR)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2, 3},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, value_range, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_sort_agg_metric_not_push.yaml new file mode 100644 index 
00000000000..e3d4d9fba4d --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_autodate_sort_agg_metric_not_push.yaml @@ -0,0 +1,14 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$2], cnt=[$3], timestamp=[$0], category=[$1]) + LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()]) + LogicalProject(timestamp=[$9], category=[$1], value=[$2]) + LogicalFilter(condition=[AND(IS NOT NULL($9), IS NOT NULL($1))]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml 
b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml new file mode 100644 index 00000000000..6995097f878 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml @@ -0,0 +1,15 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$3], cnt=[$4], category=[$0], value=[$1], timestamp=[$2]) + LogicalAggregate(group=[{0, 1, 2}], avg(value)=[AVG($1)], cnt=[COUNT()]) + LogicalProject(category=[$1], value=[$2], timestamp=[$9]) + LogicalFilter(condition=[AND(IS NOT NULL($1), IS NOT NULL($2), IS NOT NULL($9))]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DOUBLE], avg(value)=[$t4], cnt=[$t3], category=[$t0], value=[$t1], timestamp=[$t2]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1, 2},cnt=COUNT())], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}},{"value":{"terms":{"field":"value","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null}}}}}}, 
requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_range_sort_agg_metric_not_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_range_sort_agg_metric_not_push.yaml new file mode 100644 index 00000000000..19846e9910b --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/agg_composite_range_sort_agg_metric_not_push.yaml @@ -0,0 +1,14 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$1], dir0=[ASC-nulls-first]) + LogicalProject(avg(value)=[$2], cnt=[$3], value_range=[$0], category=[$1]) + LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()]) + LogicalProject(value_range=[$10], category=[$1], value=[$2]) + LogicalFilter(condition=[IS NOT NULL($1)]) + LogicalProject(@timestamp=[$0], category=[$1], value=[$2], timestamp=[$3], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'great':VARCHAR)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, value_range, category]], 
OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.json deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml index 81082ac86e7..b837e4968d4 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics1.yaml @@ -8,6 +8,4 @@ calcite: LogicalFilter(condition=[IS NOT NULL($7)]) CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) physical: | - EnumerableLimit(fetch=[10000]) - EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), PROJECT->[count(), state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, 
opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), SORT_AGG_METRICS->[1 ASC FIRST], PROJECT->[count(), state], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"count()":"asc"},{"_key":"asc"}]},"aggregations":{"count()":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml index 8a45ecc2f92..1808eba1f08 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics2.yaml @@ -1,13 +1,11 @@ calcite: logical: | - LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) - LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) - LogicalProject(count()=[$2], gender=[$0], state=[$1]) - LogicalAggregate(group=[{0, 1}], count()=[COUNT()]) - LogicalProject(gender=[$4], state=[$7]) - LogicalFilter(condition=[AND(IS NOT NULL($4), IS NOT NULL($7))]) + LogicalSystemLimit(sort0=[$0], dir0=[DESC-nulls-last], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[DESC-nulls-last]) + LogicalProject(sum=[$1], state=[$0]) + LogicalAggregate(group=[{0}], sum=[SUM($1)]) + LogicalProject(state=[$7], balance=[$3]) + LogicalFilter(condition=[IS NOT NULL($7)]) CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) physical: | - EnumerableLimit(fetch=[10000]) - EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) - 
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},count()=COUNT()), PROJECT->[count(), gender, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":false,"order":"asc"}}},{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},sum=SUM($0)), SORT_AGG_METRICS->[1 DESC LAST], PROJECT->[sum, state], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"sum":"desc"},{"_key":"asc"}]},"aggregations":{"sum":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics3.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics3.yaml new file mode 100644 index 00000000000..a40c5cec466 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics3.yaml @@ -0,0 +1,12 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[DESC-nulls-last], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[DESC-nulls-last]) + LogicalProject(cnt=[$1], span(birthdate,1d)=[$0]) + LogicalAggregate(group=[{0}], cnt=[COUNT()]) + LogicalProject(span(birthdate,1d)=[SPAN($3, 1, 'd')]) + LogicalFilter(condition=[IS NOT 
NULL($3)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]]) + physical: | + EnumerableLimit(fetch=[10000]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[birthdate], FILTER->IS NOT NULL($0), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},cnt=COUNT()), SORT_AGG_METRICS->[1 DESC LAST], PROJECT->[cnt, span(birthdate,1d)]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","query":{"exists":{"field":"birthdate","boost":1.0}},"_source":{"includes":["birthdate"],"excludes":[]},"aggregations":{"span(birthdate,1d)":{"date_histogram":{"field":"birthdate","fixed_interval":"1d","offset":0,"order":[{"cnt":"desc"},{"_key":"asc"}],"keyed":false,"min_doc_count":0},"aggregations":{"cnt":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics4.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics4.yaml new file mode 100644 index 00000000000..74ff751bcef --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics4.yaml @@ -0,0 +1,12 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[DESC-nulls-last], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[DESC-nulls-last]) + LogicalProject(sum(balance)=[$1], span(age,5)=[$0]) + LogicalAggregate(group=[{1}], sum(balance)=[SUM($0)]) + LogicalProject(balance=[$7], span(age,5)=[SPAN($10, 5, null:NULL)]) + LogicalFilter(condition=[IS NOT NULL($10)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]]) + physical: | + EnumerableLimit(fetch=[10000]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[balance, age], FILTER->IS NOT NULL($1), 
AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},sum(balance)=SUM($0)), SORT_AGG_METRICS->[1 DESC LAST], PROJECT->[sum(balance), span(age,5)]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","query":{"exists":{"field":"age","boost":1.0}},"_source":{"includes":["balance","age"],"excludes":[]},"aggregations":{"span(age,5)":{"histogram":{"field":"age","interval":5.0,"offset":0.0,"order":[{"sum(balance)":"desc"},{"_key":"asc"}],"keyed":false,"min_doc_count":0},"aggregations":{"sum(balance)":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms.yaml new file mode 100644 index 00000000000..a7a2bbad9db --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_agg_sort_on_metrics_multi_terms.yaml @@ -0,0 +1,11 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) + LogicalProject(count()=[$2], gender=[$0], state=[$1]) + LogicalAggregate(group=[{0, 1}], count()=[COUNT()]) + LogicalProject(gender=[$4], state=[$7]) + LogicalFilter(condition=[AND(IS NOT NULL($4), IS NOT NULL($7))]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},count()=COUNT()), SORT_AGG_METRICS->[2 ASC FIRST], PROJECT->[count(), gender, state], LIMIT->10000], 
OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"multi_terms_buckets":{"multi_terms":{"terms":[{"field":"gender.keyword"},{"field":"state.keyword"}],"size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"_count":"desc"},{"_key":"asc"}]},"aggregations":{"count()":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml new file mode 100644 index 00000000000..6a5bc8ea0f5 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml @@ -0,0 +1,13 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], dir0=[ASC-nulls-first]) + LogicalProject(c=[$1], s=[$2], state=[$0]) + LogicalAggregate(group=[{0}], c=[COUNT()], s=[SUM($1)]) + LogicalProject(state=[$7], balance=[$3]) + LogicalFilter(condition=[IS NOT NULL($7)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},c=COUNT(),s=SUM($0)), PROJECT->[c, s, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"s":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, 
pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml new file mode 100644 index 00000000000..d1651f464a6 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml @@ -0,0 +1,13 @@ +calcite: + logical: | + LogicalSystemLimit(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[ASC-nulls-first]) + LogicalProject(c=[$1], s=[$2], state=[$0]) + LogicalAggregate(group=[{0}], c=[COUNT()], s=[SUM($1)]) + LogicalProject(state=[$7], balance=[$3]) + LogicalFilter(condition=[IS NOT NULL($7)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC-nulls-first], dir1=[ASC-nulls-first]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},c=COUNT(),s=SUM($0)), PROJECT->[c, s, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"s":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java index 
36acc3b0dab..57db35b092a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/ExpandCollationOnProjectExprRule.java @@ -19,6 +19,7 @@ import org.apache.calcite.rel.core.Project; import org.apache.commons.lang3.tuple.Pair; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.util.OpenSearchRelOptUtil; /** @@ -108,7 +109,7 @@ public interface Config extends RelRule.Config { .oneInput( b1 -> b1.operand(EnumerableProject.class) - .predicate(OpenSearchIndexScanRule::projectContainsExpr) + .predicate(PlanUtils::projectContainsExpr) .predicate(p -> !p.containsOver()) .anyInputs())); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java index 079616456e1..2beef0f2ea9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java @@ -22,7 +22,9 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.immutables.value.Value; import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.expression.function.udf.binning.WidthBucketFunction; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; /** Planner rule that push a {@link LogicalAggregate} down to {@link CalciteLogicalIndexScan} */ @@ -105,17 +107,17 @@ public interface Config extends RelRule.Config { // 1. No RexOver and no duplicate projection // 2. 
Contains width_bucket function on date field referring // to bin command with parameter bins - Predicate.not(OpenSearchIndexScanRule::containsRexOver) - .and(OpenSearchIndexScanRule::distinctProjectList) + Predicate.not(PlanUtils::containsRexOver) + .and(PlanUtils::distinctProjectList) .or(Config::containsWidthBucketFuncOnDate)) .oneInput( b2 -> b2.operand(CalciteLogicalIndexScan.class) .predicate( Predicate.not( - OpenSearchIndexScanRule::isLimitPushed) + AbstractCalciteIndexScan::isLimitPushed) .and( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::noAggregatePushed)) .noInputs()))); Config COUNT_STAR = @@ -137,8 +139,8 @@ public interface Config extends RelRule.Config { b1 -> b1.operand(CalciteLogicalIndexScan.class) .predicate( - Predicate.not(OpenSearchIndexScanRule::isLimitPushed) - .and(OpenSearchIndexScanRule::noAggregatePushed)) + Predicate.not(AbstractCalciteIndexScan::isLimitPushed) + .and(AbstractCalciteIndexScan::noAggregatePushed)) .noInputs())); // TODO: No need this rule once https://github.com/opensearch-project/sql/issues/4403 is // addressed @@ -172,22 +174,18 @@ public interface Config extends RelRule.Config { // 2. 
Contains width_bucket function on date // field referring // to bin command with parameter bins - Predicate.not( - OpenSearchIndexScanRule - ::containsRexOver) - .and( - OpenSearchIndexScanRule - ::distinctProjectList) + Predicate.not(PlanUtils::containsRexOver) + .and(PlanUtils::distinctProjectList) .or(Config::containsWidthBucketFuncOnDate)) .oneInput( b3 -> b3.operand(CalciteLogicalIndexScan.class) .predicate( Predicate.not( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::isLimitPushed) .and( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::noAggregatePushed)) .noInputs())))); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java index 0650d7ed44a..a0917fda4eb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchDedupPushdownRule.java @@ -22,6 +22,7 @@ import org.apache.logging.log4j.Logger; import org.immutables.value.Value; import org.opensearch.sql.calcite.utils.PlanUtils; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; @Value.Enclosing @@ -124,10 +125,10 @@ public interface Config extends RelRule.Config { b3.operand(CalciteLogicalIndexScan.class) .predicate( Predicate.not( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::isLimitPushed) .and( - OpenSearchIndexScanRule + AbstractCalciteIndexScan ::noAggregatePushed)) .noInputs())))); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java index 4ee5148b826..686da70b37c 100644 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java @@ -11,6 +11,7 @@ import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.logical.LogicalFilter; import org.immutables.value.Value; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; /** Planner rule that push a {@link LogicalFilter} down to {@link CalciteLogicalIndexScan} */ @@ -63,8 +64,8 @@ public interface Config extends RelRule.Config { // handle filter pushdown after limit. Both "limit after // filter" and "filter after limit" result in the same // limit-after-filter DSL. - Predicate.not(OpenSearchIndexScanRule::isLimitPushed) - .and(OpenSearchIndexScanRule::noAggregatePushed)) + Predicate.not(AbstractCalciteIndexScan::isLimitPushed) + .and(AbstractCalciteIndexScan::noAggregatePushed)) .noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java index 4c16893a89c..0d8e0d0d5d9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java @@ -31,6 +31,8 @@ public class OpenSearchIndexRules { SortProjectExprTransposeRule.Config.DEFAULT.toRule(); private static final ExpandCollationOnProjectExprRule EXPAND_COLLATION_ON_PROJECT_EXPR = ExpandCollationOnProjectExprRule.Config.DEFAULT.toRule(); + private static final SortAggregationMetricsRule SORT_AGGREGATION_METRICS_RULE = + SortAggregationMetricsRule.Config.DEFAULT.toRule(); // Rule that always pushes down relevance functions regardless of pushdown settings 
public static final OpenSearchRelevanceFunctionPushdownRule RELEVANCE_FUNCTION_PUSHDOWN = @@ -48,6 +50,7 @@ public class OpenSearchIndexRules { // DEDUP_PUSH_DOWN, SORT_INDEX_SCAN, SORT_PROJECT_EXPR_TRANSPOSE, + SORT_AGGREGATION_METRICS_RULE, EXPAND_COLLATION_ON_PROJECT_EXPR); // prevent instantiation diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java deleted file mode 100644 index 24abb3c3bc9..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.planner.physical; - -import java.util.HashSet; -import java.util.Set; -import org.apache.calcite.plan.RelOptTable; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rel.core.Sort; -import org.apache.calcite.rel.logical.LogicalProject; -import org.apache.calcite.rel.logical.LogicalSort; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexOver; -import org.apache.calcite.util.Pair; -import org.opensearch.sql.opensearch.storage.OpenSearchIndex; -import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; - -public interface OpenSearchIndexScanRule { - /** - * CalciteOpenSearchIndexScan doesn't allow push-down anymore (except Sort under some strict - * condition) after Aggregate push-down. 
- */ - static boolean noAggregatePushed(AbstractCalciteIndexScan scan) { - if (scan.getPushDownContext().isAggregatePushed()) return false; - final RelOptTable table = scan.getTable(); - return table.unwrap(OpenSearchIndex.class) != null; - } - - static boolean isLimitPushed(AbstractCalciteIndexScan scan) { - return scan.getPushDownContext().isLimitPushed(); - } - - // `RelDecorrelator` may generate a Project with duplicated fields, e.g. Project($0,$0). - // There will be problem if pushing down the pattern like `Aggregate(AGG($0),{1})-Project($0,$0)`, - // as it will lead to field-name conflict. - // We should wait and rely on `AggregateProjectMergeRule` to mitigate it by having this constraint - // Nevertheless, that rule cannot handle all cases if there is RexCall in the Project, - // e.g. Project($0, $0, +($0,1)). We cannot push down the Aggregate for this corner case. - // TODO: Simplify the Project where there is RexCall by adding a new rule. - static boolean distinctProjectList(LogicalProject project) { - // Change to Set> to resolve - // https://github.com/opensearch-project/sql/issues/4347 - Set> rexSet = new HashSet<>(); - return project.getNamedProjects().stream().allMatch(rexSet::add); - } - - static boolean containsRexOver(LogicalProject project) { - return project.getProjects().stream().anyMatch(RexOver::containsOver); - } - - /** - * The LogicalSort is a LIMIT that should be pushed down when its fetch field is not null and its - * collation is empty. For example: sort name | head 5 should not be pushed down - * because it has a field collation. - * - * @param sort The LogicalSort to check. - * @return True if the LogicalSort is a LIMIT, false otherwise. 
- */ - static boolean isLogicalSortLimit(LogicalSort sort) { - return sort.fetch != null; - } - - static boolean projectContainsExpr(Project project) { - return project.getProjects().stream().anyMatch(p -> p instanceof RexCall); - } - - static boolean sortByFieldsOnly(Sort sort) { - return !sort.getCollation().getFieldCollations().isEmpty() && sort.fetch == null; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java index a6832b06a24..31a7bf233d2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchLimitIndexScanRule.java @@ -13,6 +13,7 @@ import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; /** @@ -88,7 +89,7 @@ public interface Config extends RelRule.Config { .withOperandSupplier( b0 -> b0.operand(LogicalSort.class) - .predicate(OpenSearchIndexScanRule::isLogicalSortLimit) + .predicate(PlanUtils::isLogicalSortLimit) .oneInput(b1 -> b1.operand(CalciteLogicalIndexScan.class).noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java index 47274f467fc..b519c50a1ec 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java @@ -10,6 +10,7 @@ import org.apache.calcite.plan.RelRule; import 
org.apache.calcite.rel.core.Sort; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; @Value.Enclosing @@ -43,7 +44,7 @@ public interface Config extends RelRule.Config { .withOperandSupplier( b0 -> b0.operand(Sort.class) - .predicate(OpenSearchIndexScanRule::sortByFieldsOnly) + .predicate(PlanUtils::sortByFieldsOnly) .oneInput( b1 -> b1.operand(AbstractCalciteIndexScan.class) @@ -51,7 +52,11 @@ public interface Config extends RelRule.Config { // because pushing down a sort after a limit will be treated // as sort-then-limit by OpenSearch DSL. .predicate( - Predicate.not(OpenSearchIndexScanRule::isLimitPushed)) + Predicate.not(AbstractCalciteIndexScan::isLimitPushed) + .and( + Predicate.not( + AbstractCalciteIndexScan + ::isMetricsOrderPushed))) .noInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java new file mode 100644 index 00000000000..63b04e8c099 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortAggregationMetricsRule.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.physical; + +import java.util.function.Predicate; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.logical.LogicalSort; +import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; +import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan; +import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; + +@Value.Enclosing +public class SortAggregationMetricsRule extends RelRule { + + 
protected SortAggregationMetricsRule(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final LogicalSort sort = call.rel(0); + final CalciteLogicalIndexScan scan = call.rel(1); + CalciteLogicalIndexScan newScan = scan.pushDownSortAggregateMetrics(sort); + if (newScan != null) { + call.transformTo(newScan); + } + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + // TODO support multiple metrics, only support single metric sort + Predicate hasOneFieldCollation = + sort -> sort.getCollation().getFieldCollations().size() == 1; + SortAggregationMetricsRule.Config DEFAULT = + ImmutableSortAggregationMetricsRule.Config.builder() + .build() + .withDescription("Sort-TableScan(agg-pushed)") + .withOperandSupplier( + b0 -> + b0.operand(LogicalSort.class) + .predicate(hasOneFieldCollation.and(PlanUtils::sortByFieldsOnly)) + .oneInput( + b1 -> + b1.operand(CalciteLogicalIndexScan.class) + .predicate( + Predicate.not(AbstractCalciteIndexScan::noAggregatePushed)) + .noInputs())); + + @Override + default SortAggregationMetricsRule toRule() { + return new SortAggregationMetricsRule(this); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java index dfec6754908..2ed94292096 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/SortProjectExprTransposeRule.java @@ -26,6 +26,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.commons.lang3.tuple.Pair; import org.immutables.value.Value; +import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.opensearch.util.OpenSearchRelOptUtil; /** @@ -132,7 +133,7 @@ public interface Config 
extends RelRule.Config { b1.operand(LogicalProject.class) .predicate( Predicate.not(LogicalProject::containsOver) - .and(OpenSearchIndexScanRule::projectContainsExpr)) + .and(PlanUtils::projectContainsExpr)) .anyInputs())); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index 8e50ef572be..b9129ced8eb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -74,7 +74,6 @@ import org.opensearch.search.aggregations.support.ValueType; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.search.sort.SortOrder; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.data.type.ExprCoreType; @@ -129,7 +128,7 @@ public static class ExpressionNotAnalyzableException extends Exception { private AggregateAnalyzer() {} @RequiredArgsConstructor - static class AggregateBuilderHelper { + public static class AggregateBuilderHelper { final RelDataType rowType; final Map fieldTypes; final RelOptCluster cluster; @@ -182,24 +181,12 @@ T inferValue(RexNode node, Class clazz) { public static Pair, OpenSearchAggregationResponseParser> analyze( Aggregate aggregate, Project project, - RelDataType rowType, - Map fieldTypes, List outputFields, - RelOptCluster cluster, - int bucketSize) + AggregateBuilderHelper helper) throws ExpressionNotAnalyzableException { requireNonNull(aggregate, "aggregate"); try { - boolean bucketNullable = - Boolean.parseBoolean( - aggregate.getHints().stream() - .filter(hits -> hits.hintName.equals("stats_args")) - .map(hint -> hint.kvOptions.getOrDefault(Argument.BUCKET_NULLABLE, "true")) - 
.findFirst() - .orElseGet(() -> "true")); List groupList = aggregate.getGroupSet().asList(); - AggregateBuilderHelper helper = - new AggregateBuilderHelper(rowType, fieldTypes, cluster, bucketNullable, bucketSize); List aggFieldNames = outputFields.subList(groupList.size(), outputFields.size()); // Process all aggregate calls Pair> builderAndParser = @@ -272,7 +259,7 @@ public static Pair, OpenSearchAggregationResponseParser + " aggregation"); } AggregationBuilder compositeBuilder = - AggregationBuilders.composite("composite_buckets", buckets).size(bucketSize); + AggregationBuilders.composite("composite_buckets", buckets).size(helper.bucketSize); if (subBuilder != null) { compositeBuilder.subAggregations(subBuilder); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java index 17b82ffd4a5..9f877e29168 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java @@ -7,10 +7,10 @@ import static java.util.Objects.requireNonNull; import static org.opensearch.sql.common.setting.Settings.Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR; -import static org.opensearch.sql.opensearch.storage.scan.PushDownType.AGGREGATION; -import static org.opensearch.sql.opensearch.storage.scan.PushDownType.PROJECT; +import static org.opensearch.sql.opensearch.storage.scan.context.PushDownType.AGGREGATION; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -48,6 +48,15 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import 
org.opensearch.sql.opensearch.storage.scan.context.AbstractAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggPushDownAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggregationBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.FilterDigest; +import org.opensearch.sql.opensearch.storage.scan.context.LimitDigest; +import org.opensearch.sql.opensearch.storage.scan.context.OSRequestBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownContext; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownOperation; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownType; /** An abstract relational operator representing a scan of an OpenSearchIndex type. */ @Getter @@ -105,21 +114,24 @@ public double estimateRowCount(RelMetadataQuery mq) { .reduce( osIndex.getMaxResultWindow().doubleValue(), (rowCount, operation) -> { - switch (operation.getType()) { + switch (operation.type()) { case AGGREGATION: - return mq.getRowCount((RelNode) operation.getDigest()); + return mq.getRowCount((RelNode) operation.digest()); case PROJECT: case SORT: return rowCount; + case SORT_AGG_METRICS: + return NumberUtil.min( + rowCount, osIndex.getBucketSize().doubleValue()); case COLLAPSE: return rowCount / 10; case FILTER: case SCRIPT: return NumberUtil.multiply( rowCount, - RelMdUtil.guessSelectivity(((FilterDigest) operation.getDigest()).getCondition())); + RelMdUtil.guessSelectivity(((FilterDigest) operation.digest()).condition())); case LIMIT: - return Math.min(rowCount, ((LimitDigest) operation.getDigest()).getLimit()); + return Math.min(rowCount, ((LimitDigest) operation.digest()).limit()); default: return rowCount; } @@ -145,9 +157,9 @@ public double estimateRowCount(RelMetadataQuery mq) { public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { double dRows = osIndex.getMaxResultWindow().doubleValue(), dCpu = 0.0d; for (PushDownOperation 
operation : pushDownContext) { - switch (operation.getType()) { + switch (operation.type()) { case AGGREGATION: - dRows = mq.getRowCount((RelNode) operation.getDigest()); + dRows = mq.getRowCount((RelNode) operation.digest()); dCpu += dRows * getAggMultiplier(operation); break; // Ignored Project in cost accumulation, but it will affect the external cost @@ -156,6 +168,10 @@ public double estimateRowCount(RelMetadataQuery mq) { case SORT: dCpu += dRows; break; + case SORT_AGG_METRICS: + dRows = dRows * .9 / 10; + dCpu += dRows; + break; // Refer the org.apache.calcite.rel.metadata.RelMdRowCount.getRowCount(Aggregate rel,...) case COLLAPSE: dRows = dRows / 10; @@ -165,21 +181,21 @@ public double estimateRowCount(RelMetadataQuery mq) { case FILTER: dRows = NumberUtil.multiply( - dRows, RelMdUtil.guessSelectivity(((FilterDigest) operation.getDigest()).getCondition())); + dRows, RelMdUtil.guessSelectivity(((FilterDigest) operation.digest()).condition())); break; case SCRIPT: - FilterDigest filterDigest = (FilterDigest) operation.getDigest(); - dRows = NumberUtil.multiply(dRows, RelMdUtil.guessSelectivity(filterDigest.getCondition())); + FilterDigest filterDigest = (FilterDigest) operation.digest(); + dRows = NumberUtil.multiply(dRows, RelMdUtil.guessSelectivity(filterDigest.condition())); // Calculate the cost of script filter by multiplying the selectivity of the filter and // the factor amplified by script count. - dCpu += NumberUtil.multiply(dRows, Math.pow(1.1, filterDigest.getScriptCount())); + dCpu += NumberUtil.multiply(dRows, Math.pow(1.1, filterDigest.scriptCount())); break; // Ignore cost the LIMIT but it will affect the rows count. // Try to reduce the rows count by 1 to make the cost cheaper slightly than non-push down. // Because we'd like to push down LIMIT even when the fetch in LIMIT is greater than // dRows. 
case LIMIT: - dRows = Math.min(dRows, ((LimitDigest) operation.getDigest()).getLimit()) - 1; + dRows = Math.min(dRows, ((LimitDigest) operation.digest()).limit()) - 1; break; default: // No-op for unhandled cases @@ -203,7 +219,7 @@ public double estimateRowCount(RelMetadataQuery mq) { /** See source in {@link org.apache.calcite.rel.core.Aggregate::computeSelfCost} */ private static float getAggMultiplier(PushDownOperation operation) { // START CALCITE - List aggCalls = ((Aggregate) operation.getDigest()).getAggCallList(); + List aggCalls = ((Aggregate) operation.digest()).getAggCallList(); float multiplier = 1f + (float) aggCalls.size() * 0.125f; for (AggregateCall aggCall : aggCalls) { if (aggCall.getAggregation().getName().equals("SUM")) { @@ -216,7 +232,7 @@ private static float getAggMultiplier(PushDownOperation operation) { // For script aggregation, we need to multiply the multiplier by 1.1 to make up the cost. As we // prefer to have non-script agg push down after optimized by {@link PPLAggregateConvertRule} - multiplier *= (float) Math.pow(1.1f, ((AggPushDownAction) operation.getAction()).getScriptCount()); + multiplier *= (float) Math.pow(1.1f, ((AggPushDownAction) operation.action()).getScriptCount()); return multiplier; } @@ -229,31 +245,60 @@ protected abstract AbstractCalciteIndexScan buildScan( RelDataType schema, PushDownContext pushDownContext); - private List getCollationNames(List collations) { + protected List getCollationNames(List collations) { return collations.stream() .map(collation -> getRowType().getFieldNames().get(collation.getFieldIndex())) .collect(Collectors.toList()); } /** - * Check if the sort by collations contains any aggregators that are pushed down. E.g. In `stats - * avg(age) as avg_age by state | sort avg_age`, the sort clause has `avg_age` which is an - * aggregator. The function will return true in this case. + * Check if all sort-by collations equal aggregators that are pushed down. E.g. 
In `stats avg(age) + * as avg_age, sum(age) as sum_age by state | sort avg_age, sum_age`, the sort keys `avg_age`, + * `sum_age` which equal the pushed down aggregators `avg(age)`, `sum(age)`. + * + * @param collations List of collation names to check against aggregators. + * @return True if all collation names match all aggregator output, false otherwise. + */ + protected boolean isAllCollationNamesEqualAggregators(List collations) { + Stream aggregates = + pushDownContext.stream() + .filter(action -> action.type() == AGGREGATION) + .map(action -> ((LogicalAggregate) action.digest())); + return aggregates + .map(aggregate -> isAllCollationNamesEqualAggregators(aggregate, collations)) + .reduce(false, Boolean::logicalOr); + } + + private boolean isAllCollationNamesEqualAggregators( + LogicalAggregate aggregate, List collations) { + List fieldNames = aggregate.getRowType().getFieldNames(); + // The output fields of the aggregate are in the format of + // [...grouping fields, ...aggregator fields], so we set an offset to skip + // the grouping fields. + int groupOffset = aggregate.getGroupSet().cardinality(); + List fieldsWithoutGrouping = fieldNames.subList(groupOffset, fieldNames.size()); + return new HashSet<>(collations).equals(new HashSet<>(fieldsWithoutGrouping)); + } + + /** + * Check if any sort-by collations is in aggregators that are pushed down. E.g. In `stats avg(age) + * as avg_age by state | sort avg_age`, the sort clause has `avg_age` which is an aggregator. The + * function will return true in this case. * * @param collations List of collation names to check against aggregators. * @return True if any collation name matches an aggregator output, false otherwise. 
*/ - private boolean hasAggregatorInSortBy(List collations) { + protected boolean isAnyCollationNameInAggregators(List collations) { Stream aggregates = pushDownContext.stream() - .filter(action -> action.getType() == AGGREGATION) - .map(action -> ((LogicalAggregate) action.getDigest())); + .filter(action -> action.type() == AGGREGATION) + .map(action -> ((LogicalAggregate) action.digest())); return aggregates - .map(aggregate -> isAnyCollationNameInAggregateOutput(aggregate, collations)) + .map(aggregate -> isAnyCollationNameInAggregators(aggregate, collations)) .reduce(false, Boolean::logicalOr); } - private static boolean isAnyCollationNameInAggregateOutput( + private boolean isAnyCollationNameInAggregators( LogicalAggregate aggregate, List collations) { List fieldNames = aggregate.getRowType().getFieldNames(); // The output fields of the aggregate are in the format of @@ -274,7 +319,8 @@ private static boolean isAnyCollationNameInAggregateOutput( public AbstractCalciteIndexScan pushDownSort(List collations) { try { List collationNames = getCollationNames(collations); - if (getPushDownContext().isAggregatePushed() && hasAggregatorInSortBy(collationNames)) { + if (getPushDownContext().isAggregatePushed() + && isAnyCollationNameInAggregators(collationNames)) { // If aggregation is pushed down, we cannot push down sorts where its by fields contain // aggregators. return null; @@ -357,4 +403,22 @@ public AbstractCalciteIndexScan pushDownSort(List collations) } return null; } + + /** + * CalciteOpenSearchIndexScan doesn't allow push-down anymore (except Sort under some strict + * condition) after Aggregate push-down. 
+ */ + public boolean noAggregatePushed() { + if (this.getPushDownContext().isAggregatePushed()) return false; + final RelOptTable table = this.getTable(); + return table.unwrap(OpenSearchIndex.class) != null; + } + + public boolean isLimitPushed() { + return this.getPushDownContext().isLimitPushed(); + } + + public boolean isMetricsOrderPushed() { + return this.getPushDownContext().isMetricOrderPushed(); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java index 5fd104cc794..ca2370dfa98 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java @@ -34,6 +34,7 @@ import org.opensearch.sql.calcite.plan.Scannable; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownContext; import org.opensearch.sql.opensearch.util.OpenSearchRelOptUtil; /** The physical relational operator representing a scan of an OpenSearchIndex type. 
*/ diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index e74a1efcb57..eaefc921ca3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -25,6 +25,7 @@ import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.rel.logical.LogicalSort; import org.apache.calcite.rel.type.RelDataType; @@ -37,6 +38,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.type.ExprCoreType; @@ -50,6 +53,14 @@ import org.opensearch.sql.opensearch.request.PredicateAnalyzer.QueryExpression; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.context.AbstractAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggPushDownAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggregationBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.FilterDigest; +import org.opensearch.sql.opensearch.storage.scan.context.LimitDigest; +import org.opensearch.sql.opensearch.storage.scan.context.OSRequestBuilderAction; +import 
org.opensearch.sql.opensearch.storage.scan.context.PushDownContext; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownType; /** The logical relational operator representing a scan of an OpenSearchIndex type. */ @Getter @@ -99,6 +110,11 @@ public CalciteLogicalIndexScan copyWithNewSchema(RelDataType schema) { getCluster(), traitSet, hints, table, osIndex, schema, pushDownContext.clone()); } + public CalciteLogicalIndexScan copyWithNewTraitSet(RelTraitSet traitSet) { + return new CalciteLogicalIndexScan( + getCluster(), traitSet, hints, table, osIndex, schema, pushDownContext.clone()); + } + @Override public void register(RelOptPlanner planner) { super.register(planner); @@ -269,6 +285,37 @@ private RelTraitSet reIndexCollations(List selectedColumns) { return newTraitSet; } + public CalciteLogicalIndexScan pushDownSortAggregateMetrics(Sort sort) { + try { + if (!pushDownContext.isAggregatePushed()) return null; + List aggregationBuilders = + pushDownContext.getAggPushDownAction().getAggregationBuilder().getLeft(); + if (aggregationBuilders.size() != 1) { + return null; + } + if (!(aggregationBuilders.get(0) instanceof CompositeAggregationBuilder)) { + return null; + } + List collationNames = getCollationNames(sort.getCollation().getFieldCollations()); + if (!isAllCollationNamesEqualAggregators(collationNames)) { + return null; + } + AbstractAction newAction = + (AggregationBuilderAction) + aggAction -> + aggAction.pushDownSortAggMetrics( + sort.getCollation().getFieldCollations(), rowType.getFieldNames()); + Object digest = sort.getCollation().getFieldCollations(); + pushDownContext.add(PushDownType.SORT_AGG_METRICS, digest, newAction); + return copyWithNewTraitSet(sort.getTraitSet()); + } catch (Exception e) { + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot pushdown the sort aggregate {}", sort, e); + } + } + return null; + } + public AbstractRelNode pushDownAggregate(Aggregate aggregate, Project project) { try { CalciteLogicalIndexScan newScan = 
@@ -288,9 +335,18 @@ public AbstractRelNode pushDownAggregate(Aggregate aggregate, Project project) { .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); List outputFields = aggregate.getRowType().getFieldNames(); int bucketSize = osIndex.getBucketSize(); + boolean bucketNullable = + Boolean.parseBoolean( + aggregate.getHints().stream() + .filter(hits -> hits.hintName.equals("stats_args")) + .map(hint -> hint.kvOptions.getOrDefault(Argument.BUCKET_NULLABLE, "true")) + .findFirst() + .orElseGet(() -> "true")); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper( + getRowType(), fieldTypes, getCluster(), bucketNullable, bucketSize); final Pair, OpenSearchAggregationResponseParser> aggregationBuilder = - AggregateAnalyzer.analyze( - aggregate, project, getRowType(), fieldTypes, outputFields, getCluster(), bucketSize); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); Map extendedTypeMapping = aggregate.getRowType().getFieldList().stream() .collect( diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java deleted file mode 100644 index 0397ffad3db..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownContext.java +++ /dev/null @@ -1,381 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.opensearch.storage.scan; - -import com.google.common.collect.Iterators; -import java.util.AbstractCollection; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import lombok.Getter; -import org.apache.calcite.rel.RelFieldCollation; -import org.apache.calcite.rel.RelFieldCollation.Direction; -import 
org.apache.calcite.rel.RelFieldCollation.NullDirection; -import org.apache.calcite.rex.RexNode; -import org.apache.commons.lang3.tuple.Pair; -import org.jetbrains.annotations.NotNull; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.AggregationBuilders; -import org.opensearch.search.aggregations.AggregatorFactories.Builder; -import org.opensearch.search.aggregations.BucketOrder; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; -import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; -import org.opensearch.search.aggregations.bucket.missing.MissingOrder; -import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; -import org.opensearch.search.sort.SortOrder; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; -import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; -import org.opensearch.sql.opensearch.storage.OpenSearchIndex; - -@Getter -public class PushDownContext extends AbstractCollection { - private final OpenSearchIndex osIndex; - private final OpenSearchRequestBuilder requestBuilder; - private ArrayDeque operationsForRequestBuilder; - - private boolean isAggregatePushed = false; - private AggPushDownAction aggPushDownAction; - private ArrayDeque operationsForAgg; - - private boolean isLimitPushed = false; - private boolean isProjectPushed = false; - - public PushDownContext(OpenSearchIndex osIndex) { - this.osIndex = osIndex; - this.requestBuilder = osIndex.createRequestBuilder(); - } - - @Override - public PushDownContext clone() { - PushDownContext newContext = new PushDownContext(osIndex); - newContext.addAll(this); - return 
newContext; - } - - /** - * Create a new {@link PushDownContext} without the collation action. - * - * @return A new push-down context without the collation action. - */ - public PushDownContext cloneWithoutSort() { - PushDownContext newContext = new PushDownContext(osIndex); - for (PushDownOperation action : this) { - if (action.getType() != PushDownType.SORT) { - newContext.add(action); - } - } - return newContext; - } - - @NotNull - @Override - public Iterator iterator() { - if (operationsForRequestBuilder == null) { - return Collections.emptyIterator(); - } else if (operationsForAgg == null) { - return operationsForRequestBuilder.iterator(); - } else { - return Iterators.concat(operationsForRequestBuilder.iterator(), operationsForAgg.iterator()); - } - } - - @Override - public int size() { - return (operationsForRequestBuilder == null ? 0 : operationsForRequestBuilder.size()) - + (operationsForAgg == null ? 0 : operationsForAgg.size()); - } - - ArrayDeque getOperationsForRequestBuilder() { - if (operationsForRequestBuilder == null) { - this.operationsForRequestBuilder = new ArrayDeque<>(); - } - return operationsForRequestBuilder; - } - - ArrayDeque getOperationsForAgg() { - if (operationsForAgg == null) { - this.operationsForAgg = new ArrayDeque<>(); - } - return operationsForAgg; - } - - @Override - public boolean add(PushDownOperation operation) { - if (operation.getType() == PushDownType.AGGREGATION) { - isAggregatePushed = true; - this.aggPushDownAction = (AggPushDownAction) operation.getAction(); - } - if (operation.getType() == PushDownType.LIMIT) { - isLimitPushed = true; - } - if (operation.getType() == PushDownType.PROJECT) { - isProjectPushed = true; - } - operation.getAction().transform(this, operation); - return true; - } - - void add(PushDownType type, Object digest, AbstractAction action) { - add(new PushDownOperation(type, digest, action)); - } - - public boolean containsDigest(Object digest) { - return this.stream().anyMatch(action -> 
action.getDigest().equals(digest)); - } - - public OpenSearchRequestBuilder createRequestBuilder() { - OpenSearchRequestBuilder newRequestBuilder = osIndex.createRequestBuilder(); - if (operationsForRequestBuilder != null) { - operationsForRequestBuilder.forEach( - operation -> ((OSRequestBuilderAction) operation.getAction()).apply(newRequestBuilder)); - } - return newRequestBuilder; - } -} - -enum PushDownType { - FILTER, - PROJECT, - AGGREGATION, - SORT, - LIMIT, - SCRIPT, - COLLAPSE - // HIGHLIGHT, - // NESTED -} - -/** - * Represents a push down operation that can be applied to an OpenSearchRequestBuilder. - * - * type PushDownType enum - * digest the digest of the pushed down operator - * action the lambda action to apply on the OpenSearchRequestBuilder - */ -@Getter -class PushDownOperation { - private final PushDownType type; - private final Object digest; - private final AbstractAction action; - - public PushDownOperation(PushDownType type, Object digest, AbstractAction action) { - this.type = type; - this.digest = digest; - this.action = action; - } - - @Override - public String toString() { - return type + "->" + digest; - } -} - -interface AbstractAction { - void apply(T target); - - void transform(PushDownContext context, PushDownOperation operation); -} - -interface OSRequestBuilderAction extends AbstractAction { - default void transform(PushDownContext context, PushDownOperation operation) { - apply(context.getRequestBuilder()); - context.getOperationsForRequestBuilder().add(operation); - } -} - -interface AggregationBuilderAction extends AbstractAction { - default void transform(PushDownContext context, PushDownOperation operation) { - apply(context.getAggPushDownAction()); - context.getOperationsForAgg().add(operation); - } -} - -@Getter -class FilterDigest { - private final int scriptCount; - private final RexNode condition; - - public FilterDigest(int scriptCount, RexNode condition) { - this.scriptCount = scriptCount; - this.condition = condition; 
- } - - @Override - public String toString() { - return condition.toString(); - } -} - -@Getter -class LimitDigest { - private final int limit; - private final int offset; - - public LimitDigest(int limit, int offset) { - this.limit = limit; - this.offset = offset; - } - - @Override - public String toString() { - return offset == 0 ? String.valueOf(limit) : "[" + limit + " from " + offset + "]"; - } -} - -// TODO: shall we do deep copy for this action since it's mutable? -class AggPushDownAction implements OSRequestBuilderAction { - - private Pair, OpenSearchAggregationResponseParser> aggregationBuilder; - private final Map extendedTypeMapping; - @Getter private final long scriptCount; - // Record the output field names of all buckets as the sequence of buckets - private List bucketNames; - - public AggPushDownAction( - Pair, OpenSearchAggregationResponseParser> aggregationBuilder, - Map extendedTypeMapping, - List bucketNames) { - this.aggregationBuilder = aggregationBuilder; - this.extendedTypeMapping = extendedTypeMapping; - this.scriptCount = - aggregationBuilder.getLeft().stream().filter(this::isScriptAggBuilder).count(); - this.bucketNames = bucketNames; - } - - private boolean isScriptAggBuilder(AggregationBuilder aggBuilder) { - return aggBuilder instanceof ValuesSourceAggregationBuilder - && ((ValuesSourceAggregationBuilder)aggBuilder).script() != null; - } - - @Override - public void apply(OpenSearchRequestBuilder requestBuilder) { - requestBuilder.pushDownAggregation(aggregationBuilder); - requestBuilder.pushTypeMapping(extendedTypeMapping); - } - - public void pushDownSortIntoAggBucket( - List collations, List fieldNames) { - // aggregationBuilder.getLeft() could be empty when count agg optimization works - if (aggregationBuilder.getLeft().isEmpty()) return; - AggregationBuilder builder = aggregationBuilder.getLeft().get(0); - List selected = new ArrayList<>(collations.size()); - if (builder instanceof CompositeAggregationBuilder) { - // It will always 
use a single CompositeAggregationBuilder for the aggregation with GroupBy - // See {@link AggregateAnalyzer} - CompositeAggregationBuilder compositeAggBuilder = (CompositeAggregationBuilder) builder; - List> buckets = compositeAggBuilder.sources(); - List> newBuckets = new ArrayList<>(buckets.size()); - List newBucketNames = new ArrayList<>(buckets.size()); - // Have to put the collation required buckets first, then the rest of buckets. - collations.forEach( - collation -> { - /* - Must find the bucket by field name because: - 1. The sequence of buckets may have changed after sort push-down. - 2. The schema of scan operator may be inconsistent with the sequence of buckets - after project push-down. - */ - String bucketName = fieldNames.get(collation.getFieldIndex()); - CompositeValuesSourceBuilder bucket = buckets.get(bucketNames.indexOf(bucketName)); - Direction direction = collation.getDirection(); - NullDirection nullDirection = collation.nullDirection; - SortOrder order = - Direction.DESCENDING.equals(direction) ? 
SortOrder.DESC : SortOrder.ASC; - if (bucket.missingBucket()) { - MissingOrder missingOrder; - switch (nullDirection) { - case FIRST: - missingOrder = MissingOrder.FIRST; - break; - case LAST: - missingOrder = MissingOrder.LAST; - break; - default: - missingOrder = MissingOrder.DEFAULT; - break; - } - bucket.missingOrder(missingOrder); - } - newBuckets.add(bucket.order(order)); - newBucketNames.add(bucketName); - selected.add(bucketName); - }); - buckets.stream() - .map(CompositeValuesSourceBuilder::name) - .filter(name -> !selected.contains(name)) - .forEach( - name -> { - newBuckets.add(buckets.get(bucketNames.indexOf(name))); - newBucketNames.add(name); - }); - Builder newAggBuilder = new Builder(); - compositeAggBuilder.getSubAggregations().forEach(newAggBuilder::addAggregator); - aggregationBuilder = - Pair.of( - Collections.singletonList( - AggregationBuilders.composite("composite_buckets", newBuckets) - .subAggregations(newAggBuilder) - .size(compositeAggBuilder.size())), - aggregationBuilder.getRight()); - bucketNames = newBucketNames; - } - if (builder instanceof TermsAggregationBuilder) { - ((TermsAggregationBuilder)builder).order(BucketOrder.key(!collations.get(0).getDirection().isDescending())); - } - // TODO for MultiTermsAggregationBuilder - } - - /** - * Check if the limit can be pushed down into aggregation bucket when the limit size is less than - * bucket number. 
- */ - public boolean pushDownLimitIntoBucketSize(Integer size) { - // aggregationBuilder.getLeft() could be empty when count agg optimization works - if (aggregationBuilder.getLeft().isEmpty()) return false; - AggregationBuilder builder = aggregationBuilder.getLeft().get(0); - if (builder instanceof CompositeAggregationBuilder) { - CompositeAggregationBuilder compositeAggBuilder = (CompositeAggregationBuilder)builder; - if (size < compositeAggBuilder.size()) { - compositeAggBuilder.size(size); - return true; - } else { - return false; - } - } - if (builder instanceof TermsAggregationBuilder) { - TermsAggregationBuilder termsAggBuilder = (TermsAggregationBuilder) builder; - if (size < termsAggBuilder.size()) { - termsAggBuilder.size(size); - return true; - } else { - return false; - } - } - if (builder instanceof MultiTermsAggregationBuilder) { - MultiTermsAggregationBuilder multiTermsAggBuilder = (MultiTermsAggregationBuilder) builder; - if (size < multiTermsAggBuilder.size()) { - multiTermsAggBuilder.size(size); - return true; - } else { - return false; - } - } - // now we only have Composite, Terms and MultiTerms bucket aggregations, - // add code here when we could support more in the future. - if (builder instanceof ValuesSourceAggregationBuilder.LeafOnly) { - // Note: all metric aggregations will be treated as pushed since it generates only one row. 
- return true; - } - throw new OpenSearchRequestBuilder.PushDownUnSupportedException( - "Unknown aggregation builder " + builder.getClass().getSimpleName()); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java new file mode 100644 index 00000000000..65ef6233ffb --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AbstractAction.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +/** + * A lambda action to apply on the target T + * + * @param the target type + */ +public interface AbstractAction { + void apply(T target); + + /** + * Apply the action on the target T and add the operation to the context + * + * @param context the context to add the operation to + * @param operation the operation to add to the context + */ + void transform(PushDownContext context, PushDownOperation operation); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java new file mode 100644 index 00000000000..23206a0d3da --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggPushDownAction.java @@ -0,0 +1,339 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.apache.calcite.rel.RelFieldCollation; +import 
org.apache.commons.lang3.tuple.Pair; +import org.opensearch.search.aggregations.AbstractAggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.DateHistogramValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.HistogramValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; +import org.opensearch.search.aggregations.bucket.missing.MissingOrder; +import org.opensearch.search.aggregations.bucket.terms.MultiTermsAggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig; +import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.response.agg.BucketAggregationParser; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.MetricParserHelper; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; + +/** A lambda aggregation pushdown action to apply on the {@link OpenSearchRequestBuilder} */ +@Getter 
+@EqualsAndHashCode +public class AggPushDownAction implements OSRequestBuilderAction { + + private Pair, OpenSearchAggregationResponseParser> aggregationBuilder; + private final Map extendedTypeMapping; + private final long scriptCount; + // Record the output field names of all buckets as the sequence of buckets + private List bucketNames; + + public AggPushDownAction( + Pair, OpenSearchAggregationResponseParser> aggregationBuilder, + Map extendedTypeMapping, + List bucketNames) { + this.aggregationBuilder = aggregationBuilder; + this.extendedTypeMapping = extendedTypeMapping; + this.scriptCount = + aggregationBuilder.getLeft().stream().filter(this::isScriptAggBuilder).count(); + this.bucketNames = bucketNames; + } + + private boolean isScriptAggBuilder(AggregationBuilder aggBuilder) { + return aggBuilder instanceof ValuesSourceAggregationBuilder + && ((ValuesSourceAggregationBuilder) aggBuilder).script() != null; + } + + @Override + public void apply(OpenSearchRequestBuilder requestBuilder) { + requestBuilder.pushDownAggregation(aggregationBuilder); + requestBuilder.pushTypeMapping(extendedTypeMapping); + } + + private BucketAggregationParser convertTo(OpenSearchAggregationResponseParser parser) { + if (parser instanceof BucketAggregationParser) { + return (BucketAggregationParser) parser; + } else if (parser instanceof CompositeAggregationParser) { + MetricParserHelper helper = ((CompositeAggregationParser) parser).getMetricsParser(); + return new BucketAggregationParser( + helper.getMetricParserMap().values().stream().collect(Collectors.toList()), helper.getCountAggNameList()); + } else { + throw new IllegalStateException("Unexpected parser type: " + parser.getClass()); + } + } + + public void pushDownSortAggMetrics(List collations, List fieldNames) { + if (aggregationBuilder.getLeft().isEmpty()) return; + AggregationBuilder builder = aggregationBuilder.getLeft().get(0); + if (builder instanceof CompositeAggregationBuilder) { + CompositeAggregationBuilder 
composite = (CompositeAggregationBuilder) builder; + String path = getAggregationPath(collations, fieldNames, composite); + BucketOrder bucketOrder = + collations.get(0).getDirection() == RelFieldCollation.Direction.ASCENDING + ? BucketOrder.aggregation(path, true) + : BucketOrder.aggregation(path, false); + + if (composite.sources().size() == 1) { + if (composite.sources().get(0) instanceof TermsValuesSourceBuilder) { + TermsValuesSourceBuilder terms = (TermsValuesSourceBuilder) composite.sources().get(0); + if (!terms.missingBucket()) { + TermsAggregationBuilder termsBuilder = new TermsAggregationBuilder(terms.name()); + termsBuilder.size(composite.size()); + termsBuilder.field(terms.field()); + if (terms.userValuetypeHint() != null) { + termsBuilder.userValueTypeHint(terms.userValuetypeHint()); + } + termsBuilder.order(bucketOrder); + attachSubAggregations(composite.getSubAggregations(), path, termsBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(termsBuilder), + convertTo(aggregationBuilder.getRight())); + return; + } + } else if (composite.sources().get(0) instanceof DateHistogramValuesSourceBuilder) { + DateHistogramValuesSourceBuilder dateHisto = (DateHistogramValuesSourceBuilder) composite.sources().get(0); + DateHistogramAggregationBuilder dateHistoBuilder = + new DateHistogramAggregationBuilder(dateHisto.name()); + dateHistoBuilder.field(dateHisto.field()); + try { + dateHistoBuilder.fixedInterval(dateHisto.getIntervalAsFixed()); + } catch (IllegalArgumentException e) { + dateHistoBuilder.calendarInterval(dateHisto.getIntervalAsCalendar()); + } + if (dateHisto.userValuetypeHint() != null) { + dateHistoBuilder.userValueTypeHint(dateHisto.userValuetypeHint()); + } + dateHistoBuilder.order(bucketOrder); + attachSubAggregations(composite.getSubAggregations(), path, dateHistoBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(dateHistoBuilder), + convertTo(aggregationBuilder.getRight())); + return; + } else if 
(composite.sources().get(0) instanceof HistogramValuesSourceBuilder) { + HistogramValuesSourceBuilder histo = (HistogramValuesSourceBuilder) composite.sources().get(0); + if (!histo.missingBucket()) { + HistogramAggregationBuilder histoBuilder = new HistogramAggregationBuilder(histo.name()); + histoBuilder.field(histo.field()); + histoBuilder.interval(histo.interval()); + if (histo.userValuetypeHint() != null) { + histoBuilder.userValueTypeHint(histo.userValuetypeHint()); + } + histoBuilder.order(bucketOrder); + attachSubAggregations(composite.getSubAggregations(), path, histoBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(histoBuilder), + convertTo(aggregationBuilder.getRight())); + return; + } + } + } else { + if (composite.sources().stream() + .allMatch( + src -> src instanceof TermsValuesSourceBuilder && !((TermsValuesSourceBuilder) src).missingBucket())) { + // multi-term agg + MultiTermsAggregationBuilder multiTermsBuilder = + new MultiTermsAggregationBuilder("multi_terms_buckets"); + multiTermsBuilder.size(composite.size()); + multiTermsBuilder.terms( + composite.sources().stream() + .map(TermsValuesSourceBuilder.class::cast) + .map( + termValue -> { + MultiTermsValuesSourceConfig.Builder config = + new MultiTermsValuesSourceConfig.Builder(); + config.setFieldName(termValue.field()); + config.setUserValueTypeHint(termValue.userValuetypeHint()); + return config.build(); + }) + .collect(Collectors.toList())); + attachSubAggregations(composite.getSubAggregations(), path, multiTermsBuilder); + aggregationBuilder = + Pair.of( + Collections.singletonList(multiTermsBuilder), + convertTo(aggregationBuilder.getRight())); + return; + } + } + throw new OpenSearchRequestBuilder.PushDownUnSupportedException( + "Cannot pushdown sort aggregate metrics"); + } + } + + private String getAggregationPath( + List collations, + List fieldNames, + CompositeAggregationBuilder composite) { + String path; + AggregationBuilder metric = 
composite.getSubAggregations().stream().findFirst().orElse(null); + if (metric == null) { + // count agg optimized, get the path name from field names + path = fieldNames.get(collations.get(0).getFieldIndex()); + } else if (metric instanceof ValuesSourceAggregationBuilder.LeafOnly) { + path = metric.getName(); + } else { + // we do not support pushdown sort aggregate metrics for nested aggregation + throw new OpenSearchRequestBuilder.PushDownUnSupportedException( + "Cannot pushdown sort aggregate metrics, composite.getSubAggregations() is not a" + + " LeafOnly"); + } + return path; + } + + private > T attachSubAggregations( + Collection subAggregations, String path, T aggregationBuilder) { + AggregatorFactories.Builder metricBuilder = new AggregatorFactories.Builder(); + if (subAggregations.isEmpty()) { + metricBuilder.addAggregator(AggregationBuilders.count(path).field("_index")); + } else { + metricBuilder.addAggregator(subAggregations.stream().collect(Collectors.toList()).get(0)); + } + aggregationBuilder.subAggregations(metricBuilder); + return aggregationBuilder; + } + + public void pushDownSortIntoAggBucket( + List collations, List fieldNames) { + // aggregationBuilder.getLeft() could be empty when count agg optimization works + if (aggregationBuilder.getLeft().isEmpty()) return; + AggregationBuilder builder = aggregationBuilder.getLeft().get(0); + List selected = new ArrayList<>(collations.size()); + if (builder instanceof CompositeAggregationBuilder) { + CompositeAggregationBuilder compositeAggBuilder = (CompositeAggregationBuilder) builder; + // It will always use a single CompositeAggregationBuilder for the aggregation with GroupBy + // See {@link AggregateAnalyzer} + List> buckets = compositeAggBuilder.sources(); + List> newBuckets = new ArrayList<>(buckets.size()); + List newBucketNames = new ArrayList<>(buckets.size()); + // Have to put the collation required buckets first, then the rest of buckets. 
+ collations.forEach( + collation -> { + /* + Must find the bucket by field name because: + 1. The sequence of buckets may have changed after sort push-down. + 2. The schema of scan operator may be inconsistent with the sequence of buckets + after project push-down. + */ + String bucketName = fieldNames.get(collation.getFieldIndex()); + CompositeValuesSourceBuilder bucket = buckets.get(bucketNames.indexOf(bucketName)); + RelFieldCollation.Direction direction = collation.getDirection(); + RelFieldCollation.NullDirection nullDirection = collation.nullDirection; + SortOrder order = + RelFieldCollation.Direction.DESCENDING.equals(direction) + ? SortOrder.DESC + : SortOrder.ASC; + if (bucket.missingBucket()) { + MissingOrder missingOrder; + switch (nullDirection) { + case FIRST: + missingOrder = MissingOrder.FIRST; + break; + case LAST: + missingOrder = MissingOrder.LAST; + break; + default: + missingOrder = MissingOrder.DEFAULT; + } + bucket.missingOrder(missingOrder); + } + newBuckets.add(bucket.order(order)); + newBucketNames.add(bucketName); + selected.add(bucketName); + }); + buckets.stream() + .map(CompositeValuesSourceBuilder::name) + .filter(name -> !selected.contains(name)) + .forEach( + name -> { + newBuckets.add(buckets.get(bucketNames.indexOf(name))); + newBucketNames.add(name); + }); + AggregatorFactories.Builder newAggBuilder = new AggregatorFactories.Builder(); + compositeAggBuilder.getSubAggregations().forEach(newAggBuilder::addAggregator); + aggregationBuilder = + Pair.of( + Collections.singletonList( + AggregationBuilders.composite("composite_buckets", newBuckets) + .subAggregations(newAggBuilder) + .size(compositeAggBuilder.size())), + aggregationBuilder.getRight()); + bucketNames = newBucketNames; + } + if (builder instanceof TermsAggregationBuilder) { + TermsAggregationBuilder termsAggBuilder = (TermsAggregationBuilder) builder; + termsAggBuilder.order(BucketOrder.key(!collations.get(0).getDirection().isDescending())); + } + // TODO for 
MultiTermsAggregationBuilder + } + + /** + * Check if the limit can be pushed down into aggregation bucket when the limit size is less than + * bucket number. + */ + public boolean pushDownLimitIntoBucketSize(Integer size) { + // aggregationBuilder.getLeft() could be empty when count agg optimization works + if (aggregationBuilder.getLeft().isEmpty()) return false; + AggregationBuilder builder = aggregationBuilder.getLeft().get(0); + if (builder instanceof CompositeAggregationBuilder) { + CompositeAggregationBuilder compositeAggBuilder = (CompositeAggregationBuilder) builder; + if (size < compositeAggBuilder.size()) { + compositeAggBuilder.size(size); + return true; + } else { + return false; + } + } + if (builder instanceof TermsAggregationBuilder) { + TermsAggregationBuilder termsAggBuilder = (TermsAggregationBuilder) builder; + if (size < termsAggBuilder.size()) { + termsAggBuilder.size(size); + return true; + } else { + return false; + } + } + if (builder instanceof MultiTermsAggregationBuilder) { + MultiTermsAggregationBuilder multiTermsAggBuilder = (MultiTermsAggregationBuilder) builder; + if (size < multiTermsAggBuilder.size()) { + multiTermsAggBuilder.size(size); + return true; + } else { + return false; + } + } + // now we only have Composite, Terms and MultiTerms bucket aggregations, + // add code here when we could support more in the future. + if (builder instanceof ValuesSourceAggregationBuilder.LeafOnly) { + // Note: all metric aggregations will be treated as pushed since it generates only one row. 
+ return true; + } + throw new OpenSearchRequestBuilder.PushDownUnSupportedException( + "Unknown aggregation builder " + builder.getClass().getSimpleName()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java new file mode 100644 index 00000000000..cd3e84bf7cf --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/AggregationBuilderAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +/** A lambda action to apply on the {@link AggPushDownAction} */ +public interface AggregationBuilderAction extends AbstractAction { + /** + * Apply the action on the target {@link AggPushDownAction} and add the operation to the context + * + * @param context the context to add the operation to + * @param operation the operation to add to the context + */ + default void transform(PushDownContext context, PushDownOperation operation) { + apply(context.getAggPushDownAction()); + context.getOperationsForAgg().add(operation); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/FilterDigest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/FilterDigest.java new file mode 100644 index 00000000000..49b018540bc --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/FilterDigest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import lombok.EqualsAndHashCode; +import lombok.ToString; +import org.apache.calcite.rex.RexNode; + + +@EqualsAndHashCode +public class FilterDigest { + private final int scriptCount; + 
private final RexNode condition; + + public FilterDigest(int scriptCount, RexNode condition) { + this.scriptCount = scriptCount; + this.condition = condition; + } + + public int scriptCount() { + return scriptCount; + } + + public RexNode condition() { + return condition; + } + + @Override + public String toString() { + return condition.toString(); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/LimitDigest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/LimitDigest.java new file mode 100644 index 00000000000..c36566cffd5 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/LimitDigest.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import lombok.EqualsAndHashCode; +import lombok.ToString; + +@EqualsAndHashCode +public class LimitDigest { + private final int limit; + private final int offset; + + public LimitDigest(int limit, int offset) { + this.limit = limit; + this.offset = offset; + } + + public int limit() { + return limit; + } + + public int offset() { + return offset; + } + + @Override + public String toString() { + return offset == 0 ? 
String.valueOf(limit) : "[" + limit + " from " + offset + "]"; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java new file mode 100644 index 00000000000..bba33883b49 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/OSRequestBuilderAction.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; + +/** A lambda action to apply on the {@link OpenSearchRequestBuilder} */ +public interface OSRequestBuilderAction extends AbstractAction { + /** + * Apply the action on the target {@link OpenSearchRequestBuilder} and add the operation to the + * context + * + * @param context the context to add the operation to + * @param operation the operation to add to the context + */ + default void transform(PushDownContext context, PushDownOperation operation) { + apply(context.getRequestBuilder()); + context.getOperationsForRequestBuilder().add(operation); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java new file mode 100644 index 00000000000..dd36c2090b9 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownContext.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import com.google.common.collect.Iterators; +import java.util.AbstractCollection; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Iterator; +import 
lombok.Getter; +import org.jetbrains.annotations.NotNull; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; + +/** Push down context is used to store all the push down operations that are applied to the query */ +@Getter +public class PushDownContext extends AbstractCollection { + private final OpenSearchIndex osIndex; + private final OpenSearchRequestBuilder requestBuilder; + private ArrayDeque operationsForRequestBuilder; + + private boolean isAggregatePushed = false; + private AggPushDownAction aggPushDownAction; + private ArrayDeque operationsForAgg; + + private boolean isLimitPushed = false; + private boolean isProjectPushed = false; + private boolean isMetricOrderPushed = false; + + public PushDownContext(OpenSearchIndex osIndex) { + this.osIndex = osIndex; + this.requestBuilder = osIndex.createRequestBuilder(); + } + + @Override + public PushDownContext clone() { + PushDownContext newContext = new PushDownContext(osIndex); + newContext.addAll(this); + return newContext; + } + + /** + * Create a new {@link PushDownContext} without any {@code SORT} operations. + * + * @return A new push-down context with all operations except those of type {@code SORT}. + */ + public PushDownContext cloneWithoutSort() { + PushDownContext newContext = new PushDownContext(osIndex); + for (PushDownOperation action : this) { + if (action.type() != PushDownType.SORT) { + newContext.add(action); + } + } + return newContext; + } + + @NotNull + @Override + public Iterator iterator() { + if (operationsForRequestBuilder == null) { + return Collections.emptyIterator(); + } else if (operationsForAgg == null) { + return operationsForRequestBuilder.iterator(); + } else { + return Iterators.concat(operationsForRequestBuilder.iterator(), operationsForAgg.iterator()); + } + } + + @Override + public int size() { + return (operationsForRequestBuilder == null ? 0 : operationsForRequestBuilder.size()) + + (operationsForAgg == null ?
0 : operationsForAgg.size()); + } + + ArrayDeque getOperationsForRequestBuilder() { + if (operationsForRequestBuilder == null) { + this.operationsForRequestBuilder = new ArrayDeque<>(); + } + return operationsForRequestBuilder; + } + + ArrayDeque getOperationsForAgg() { + if (operationsForAgg == null) { + this.operationsForAgg = new ArrayDeque<>(); + } + return operationsForAgg; + } + + @Override + public boolean add(PushDownOperation operation) { + if (operation.type() == PushDownType.AGGREGATION) { + isAggregatePushed = true; + this.aggPushDownAction = (AggPushDownAction) operation.action(); + } + if (operation.type() == PushDownType.LIMIT) { + isLimitPushed = true; + } + if (operation.type() == PushDownType.PROJECT) { + isProjectPushed = true; + } + if (operation.type() == PushDownType.SORT_AGG_METRICS) { + isMetricOrderPushed = true; + } + operation.action().transform(this, operation); + return true; + } + + public void add(PushDownType type, Object digest, AbstractAction action) { + add(new PushDownOperation(type, digest, action)); + } + + public boolean containsDigest(Object digest) { + return this.stream().anyMatch(action -> action.digest().equals(digest)); + } + + public OpenSearchRequestBuilder createRequestBuilder() { + OpenSearchRequestBuilder newRequestBuilder = osIndex.createRequestBuilder(); + if (operationsForRequestBuilder != null) { + operationsForRequestBuilder.forEach( + operation -> ((OSRequestBuilderAction) operation.action()).apply(newRequestBuilder)); + } + return newRequestBuilder; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownOperation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownOperation.java new file mode 100644 index 00000000000..a776bfcada2 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownOperation.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +import lombok.EqualsAndHashCode; +import lombok.ToString; + +/** + * Represents a single push-down operation, identified by its type and digest, whose action is + * applied to either an OpenSearchRequestBuilder or an AggPushDownAction. + */ +@EqualsAndHashCode +public class PushDownOperation { + private final PushDownType type; + private final Object digest; + private final AbstractAction action; + + public PushDownOperation(PushDownType type, Object digest, AbstractAction action) { + this.type = type; + this.digest = digest; + this.action = action; + } + + public PushDownType type() { + return type; + } + + public Object digest() { + return digest; + } + + public AbstractAction action() { + return action; + } + + @Override + public String toString() { + return type + "->" + digest; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java new file mode 100644 index 00000000000..2a9eccb7a0e --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/context/PushDownType.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan.context; + +/** Push down types.
*/ +public enum PushDownType { + FILTER, + PROJECT, + AGGREGATION, + SORT, + LIMIT, + SCRIPT, + COLLAPSE, + SORT_AGG_METRICS + // HIGHLIGHT, + // NESTED +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java index 4c78344e918..62014bf5d28 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java @@ -154,9 +154,10 @@ void analyze_aggCall_simple() throws ExpressionNotAnalyzableException { createMockAggregate( List.of(countCall, avgCall, sumCall, minCall, maxCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null, BUCKET_SIZE); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( "[{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}}," + " {\"avg\":{\"avg\":{\"field\":\"a\"}}}," @@ -236,9 +237,10 @@ void analyze_aggCall_extended() throws ExpressionNotAnalyzableException { createMockAggregate( List.of(varSampCall, varPopCall, stddevSampCall, stddevPopCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null, BUCKET_SIZE); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( 
"[{\"var_samp\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}}," + " {\"var_pop\":{\"extended_stats\":{\"field\":\"a\",\"sigma\":2.0}}}," @@ -276,9 +278,10 @@ void analyze_groupBy() throws ExpressionNotAnalyzableException { List outputFields = List.of("a", "b", "cnt"); Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0, 1)); Project project = createMockProject(List.of(0, 1)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null, BUCKET_SIZE); + AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( "[{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[" @@ -316,12 +319,12 @@ void analyze_aggCall_TextWithoutKeyword() { "sum"); Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(2)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); ExpressionNotAnalyzableException exception = assertThrows( ExpressionNotAnalyzableException.class, - () -> - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, List.of("sum"), null, BUCKET_SIZE)); + () -> AggregateAnalyzer.analyze(aggregate, project, List.of("sum"), helper)); assertEquals("[field] must not be null: [sum]", exception.getCause().getMessage()); } @@ -343,12 +346,12 @@ void analyze_groupBy_TextWithoutKeyword() { List outputFields = List.of("c", "cnt"); Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0)); Project project = createMockProject(List.of(2)); + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); 
ExpressionNotAnalyzableException exception = assertThrows( ExpressionNotAnalyzableException.class, - () -> - AggregateAnalyzer.analyze( - aggregate, project, rowType, fieldTypes, outputFields, null, BUCKET_SIZE)); + () -> AggregateAnalyzer.analyze(aggregate, project, outputFields, helper)); assertEquals("[field] must not be null", exception.getCause().getMessage()); } @@ -694,9 +697,11 @@ void verify() throws ExpressionNotAnalyzableException { if (agg.getInput(0) instanceof Project) { project = (Project) agg.getInput(0); } + AggregateAnalyzer.AggregateBuilderHelper helper = + new AggregateAnalyzer.AggregateBuilderHelper( + rowType, fieldTypes, agg.getCluster(), true, BUCKET_SIZE); Pair, OpenSearchAggregationResponseParser> result = - AggregateAnalyzer.analyze( - agg, project, rowType, fieldTypes, outputFields, agg.getCluster(), BUCKET_SIZE); + AggregateAnalyzer.analyze(agg, project, outputFields, helper); if (expectedDsl != null) { assertEquals(expectedDsl, result.getLeft().toString()); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java index f02b40b9eae..c67d7cfaa3e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/CalciteIndexScanCostTest.java @@ -49,6 +49,13 @@ import org.opensearch.sql.common.setting.Settings.Key; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.scan.context.AggPushDownAction; +import org.opensearch.sql.opensearch.storage.scan.context.AggregationBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.FilterDigest; +import org.opensearch.sql.opensearch.storage.scan.context.LimitDigest; +import 
org.opensearch.sql.opensearch.storage.scan.context.OSRequestBuilderAction; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownOperation; +import org.opensearch.sql.opensearch.storage.scan.context.PushDownType; @ExtendWith(MockitoExtension.class) public class CalciteIndexScanCostTest { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java index 0e0068b5169..1446c7b0470 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java @@ -938,4 +938,55 @@ public void testMinOnTimeField() { String expectedSparkSql = "SELECT MIN(`HIREDATE`) `min_hire_date`\nFROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + @Test + public void testSortAggregationMetrics1() { + String ppl = "source=EMP | stats bucket_nullable=false avg(SAL) as avg by DEPTNO | sort - avg"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])\n" + + " LogicalProject(avg=[$1], DEPTNO=[$0])\n" + + " LogicalAggregate(group=[{0}], avg=[AVG($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalFilter(condition=[IS NOT NULL($7)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = + "avg=2916.666666; DEPTNO=10\navg=2175.; DEPTNO=20\navg=1566.666666; DEPTNO=30\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT AVG(`SAL`) `avg`, `DEPTNO`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` IS NOT NULL\n" + + "GROUP BY `DEPTNO`\n" + + "ORDER BY 1 DESC"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testSortAggregationMetrics2() { + String ppl = + "source=EMP | stats avg(SAL) as avg by span(HIREDATE, 1year) as hiredate_span | sort" + + " avg"; + 
RelNode root = getRelNode(ppl); + String expectedLogical = + "" + + "LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])\n" + + " LogicalProject(avg=[$1], hiredate_span=[$0])\n" + + " LogicalAggregate(group=[{1}], avg=[AVG($0)])\n" + + " LogicalProject(SAL=[$5], hiredate_span=[SPAN($4, 1, 'y')])\n" + + " LogicalFilter(condition=[IS NOT NULL($4)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT AVG(`SAL`) `avg`, `SPAN`(`HIREDATE`, 1, 'y') `hiredate_span`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `HIREDATE` IS NOT NULL\n" + + "GROUP BY `SPAN`(`HIREDATE`, 1, 'y')\n" + + "ORDER BY 1"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } }