Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,22 @@
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
Expand All @@ -38,6 +43,7 @@
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
Expand Down Expand Up @@ -501,13 +507,51 @@ public Void visitInputRef(RexInputRef inputRef) {
return selectedColumns;
}

// `RelDecorrelator` may generate a Project with duplicated fields, e.g. Project($0,$0).
// There will be problem if pushing down the pattern like `Aggregate(AGG($0),{1})-Project($0,$0)`,
// as it will lead to field-name conflict.
// We should wait and rely on `AggregateProjectMergeRule` to mitigate it by having this constraint
// Nevertheless, that rule cannot handle all cases if there is RexCall in the Project,
// e.g. Project($0, $0, +($0,1)). We cannot push down the Aggregate for this corner case.
// TODO: Simplify the Project where there is RexCall by adding a new rule.
static boolean distinctProjectList(LogicalProject project) {
// Change to Set<Pair<RexNode, String>> to resolve
// https://github.com/opensearch-project/sql/issues/4347
Set<Pair<RexNode, String>> rexSet = new HashSet<>();
return project.getNamedProjects().stream().allMatch(rexSet::add);
}

static boolean containsRexOver(LogicalProject project) {
return project.getProjects().stream().anyMatch(RexOver::containsOver);
}

/**
* The LogicalSort is a LIMIT that should be pushed down when its fetch field is not null and its
* collation is empty. For example: <code>sort name | head 5</code> should not be pushed down
* because it has a field collation.
*
* @param sort The LogicalSort to check.
* @return True if the LogicalSort is a LIMIT, false otherwise.
*/
static boolean isLogicalSortLimit(LogicalSort sort) {
return sort.fetch != null;
}

static boolean projectContainsExpr(Project project) {
return project.getProjects().stream().anyMatch(p -> p instanceof RexCall);
}

static boolean sortByFieldsOnly(Sort sort) {
return !sort.getCollation().getFieldCollations().isEmpty() && sort.fetch == null;
}

/**
* Get a string representation of the argument types expressed in ExprType for error messages.
*
* @param argTypes the list of argument types as {@link RelDataType}
* @return a string in the format [type1,type2,...] representing the argument types
*/
public static String getActualSignature(List<RelDataType> argTypes) {
static String getActualSignature(List<RelDataType> argTypes) {
return "["
+ argTypes.stream()
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1018,24 +1018,118 @@ public void testExplainCountsByAgg() throws IOException {
}

@Test
public void testExplainSortOnMetricsNoBucketNullable() throws IOException {
// TODO enhancement later: https://github.com/opensearch-project/sql/issues/4282
public void testExplainSortOnMetrics() throws IOException {
enabledOnlyWhenPushdownIsEnabled();
String expected = loadExpectedPlan("explain_agg_sort_on_metrics1.yaml");
assertYamlEqualsIgnoreId(
expected,
explainQueryYaml(
"source=opensearch-sql_test_index_account | stats bucket_nullable=false count() by"
+ " state | sort `count()`"));

expected = loadExpectedPlan("explain_agg_sort_on_metrics2.yaml");
assertYamlEqualsIgnoreId(
expected,
explainQueryYaml(
"source=opensearch-sql_test_index_account | stats bucket_nullable=false sum(balance)"
+ " as sum by state | sort - sum"));
// TODO limit should pushdown to non-composite agg
expected = loadExpectedPlan("explain_agg_sort_on_metrics3.yaml");
assertYamlEqualsIgnoreId(
expected,
explainQueryYaml(
String.format(
"source=%s | stats count() as cnt by span(birthdate, 1d) | sort - cnt",
TEST_INDEX_BANK)));
expected = loadExpectedPlan("explain_agg_sort_on_metrics4.yaml");
assertYamlEqualsIgnoreId(
expected,
explainQueryYaml(
String.format(
"source=%s | stats bucket_nullable=false sum(balance) by span(age, 5) | sort -"
+ " `sum(balance)`",
TEST_INDEX_BANK)));
}

@Test
public void testExplainSortOnMetricsMultiTerms() throws IOException {
enabledOnlyWhenPushdownIsEnabled();
String expected = loadExpectedPlan("explain_agg_sort_on_metrics_multi_terms.yaml");
assertYamlEqualsIgnoreId(
expected,
explainQueryYaml(
"source=opensearch-sql_test_index_account | stats bucket_nullable=false count() by"
+ " gender, state | sort `count()`"));
}

@Test
public void testExplainCompositeMultiBucketsAutoDateThenSortOnMetricsNotPushdown()
throws IOException {
enabledOnlyWhenPushdownIsEnabled();
assertYamlEqualsIgnoreId(
loadExpectedPlan("agg_composite_multi_terms_autodate_sort_agg_metric_not_push.yaml"),
explainQueryYaml(
String.format(
"source=%s | bin timestamp bins=3 | stats bucket_nullable=false avg(value), count()"
+ " as cnt by category, value, timestamp | sort cnt",
TEST_INDEX_TIME_DATA)));
}

@Test
public void testExplainCompositeRangeThenSortOnMetricsNotPushdown() throws IOException {
enabledOnlyWhenPushdownIsEnabled();
assertYamlEqualsIgnoreId(
loadExpectedPlan("agg_composite_range_sort_agg_metric_not_push.yaml"),
explainQueryYaml(
String.format(
"source=%s | eval value_range = case(value < 7000, 'small'"
+ " else 'great') | stats bucket_nullable=false avg(value), count() as cnt by"
+ " value_range, category | sort cnt",
TEST_INDEX_TIME_DATA)));
}

@Test
public void testExplainCompositeAutoDateThenSortOnMetricsNotPushdown() throws IOException {
enabledOnlyWhenPushdownIsEnabled();
assertYamlEqualsIgnoreId(
loadExpectedPlan("agg_composite_autodate_sort_agg_metric_not_push.yaml"),
explainQueryYaml(
String.format(
"source=%s | bin timestamp bins=3 | stats bucket_nullable=false avg(value), count()"
+ " as cnt by timestamp, category | sort cnt",
TEST_INDEX_TIME_DATA)));
}

@Test
public void testExplainCompositeRangeAutoDateThenSortOnMetricsNotPushdown() throws IOException {
enabledOnlyWhenPushdownIsEnabled();
assertYamlEqualsIgnoreId(
loadExpectedPlan("agg_composite_autodate_range_metric_sort_agg_metric_not_push.yaml"),
explainQueryYaml(
String.format(
"source=%s | bin timestamp bins=3 | eval value_range = case(value < 7000, 'small'"
+ " else 'great') | stats bucket_nullable=false avg(value), count() as cnt by"
+ " timestamp, value_range, category | sort cnt",
TEST_INDEX_TIME_DATA)));
}

@Test
public void testExplainMultipleAggregatorsWithSortOnOneMetricNotPushDown() throws IOException {
enabledOnlyWhenPushdownIsEnabled();
String expected =
loadExpectedPlan("explain_multiple_agg_with_sort_on_one_metric_not_push1.yaml");
assertYamlEqualsIgnoreId(
expected,
explainQueryYaml(
"source=opensearch-sql_test_index_account | stats bucket_nullable=false count() as c,"
+ " sum(balance) as s by state | sort c"));
expected = loadExpectedPlan("explain_multiple_agg_with_sort_on_one_metric_not_push2.yaml");
assertYamlEqualsIgnoreId(
expected,
explainQueryYaml(
"source=opensearch-sql_test_index_account | stats bucket_nullable=false count() as c,"
+ " sum(balance) as s by state | sort c, s"));
}

@Test
public void testExplainEvalMax() throws IOException {
String expected = loadExpectedPlan("explain_eval_max.json");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
calcite:
logical: |
LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])
LogicalProject(avg(value)=[$3], cnt=[$4], timestamp=[$0], value_range=[$1], category=[$2])
LogicalAggregate(group=[{0, 1, 2}], avg(value)=[AVG($3)], cnt=[COUNT()])
LogicalProject(timestamp=[$9], value_range=[$10], category=[$1], value=[$2])
LogicalFilter(condition=[AND(IS NOT NULL($9), IS NOT NULL($1))])
LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'great':VARCHAR)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2, 3},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, value_range, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
calcite:
logical: |
LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])
LogicalProject(avg(value)=[$2], cnt=[$3], timestamp=[$0], category=[$1])
LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()])
LogicalProject(timestamp=[$9], category=[$1], value=[$2])
LogicalFilter(condition=[AND(IS NOT NULL($9), IS NOT NULL($1))])
LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, timestamp, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
calcite:
logical: |
LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])
LogicalProject(avg(value)=[$3], cnt=[$4], category=[$0], value=[$1], timestamp=[$2])
LogicalAggregate(group=[{0, 1, 2}], avg(value)=[AVG($1)], cnt=[COUNT()])
LogicalProject(category=[$1], value=[$2], timestamp=[$9])
LogicalFilter(condition=[AND(IS NOT NULL($1), IS NOT NULL($2), IS NOT NULL($9))])
LogicalProject(@timestamp=[$0], category=[$1], value=[$2], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], timestamp=[WIDTH_BUCKET($3, 3, -(MAX($3) OVER (), MIN($3) OVER ()), MAX($3) OVER ())])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first])
EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t1):DOUBLE], avg(value)=[$t4], cnt=[$t3], category=[$t0], value=[$t1], timestamp=[$t2])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1, 2},cnt=COUNT())], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}},{"value":{"terms":{"field":"value","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"timestamp":{"auto_date_histogram":{"field":"timestamp","buckets":3,"minimum_interval":null}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
calcite:
logical: |
LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])
LogicalProject(avg(value)=[$2], cnt=[$3], value_range=[$0], category=[$1])
LogicalAggregate(group=[{0, 1}], avg(value)=[AVG($2)], cnt=[COUNT()])
LogicalProject(value_range=[$10], category=[$1], value=[$2])
LogicalFilter(condition=[IS NOT NULL($1)])
LogicalProject(@timestamp=[$0], category=[$1], value=[$2], timestamp=[$3], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'great':VARCHAR)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2},avg(value)=AVG($1),cnt=COUNT()), PROJECT->[avg(value), cnt, value_range, category]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"category":{"terms":{"field":"category","missing_bucket":false,"order":"asc"}}}]},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"great","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,4 @@ calcite:
LogicalFilter(condition=[IS NOT NULL($7)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), PROJECT->[count(), state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), SORT_AGG_METRICS->[1 ASC FIRST], PROJECT->[count(), state], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"count()":"asc"},{"_key":"asc"}]},"aggregations":{"count()":{"value_count":{"field":"_index"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
calcite:
logical: |
LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])
LogicalProject(count()=[$2], gender=[$0], state=[$1])
LogicalAggregate(group=[{0, 1}], count()=[COUNT()])
LogicalProject(gender=[$4], state=[$7])
LogicalFilter(condition=[AND(IS NOT NULL($4), IS NOT NULL($7))])
LogicalSystemLimit(sort0=[$0], dir0=[DESC-nulls-last], fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalSort(sort0=[$0], dir0=[DESC-nulls-last])
LogicalProject(sum=[$1], state=[$0])
LogicalAggregate(group=[{0}], sum=[SUM($1)])
LogicalProject(state=[$7], balance=[$3])
LogicalFilter(condition=[IS NOT NULL($7)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},count()=COUNT()), PROJECT->[count(), gender, state]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":false,"order":"asc"}}},{"state":{"terms":{"field":"state.keyword","missing_bucket":false,"order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={1},sum=SUM($0)), SORT_AGG_METRICS->[1 DESC LAST], PROJECT->[sum, state], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"state":{"terms":{"field":"state.keyword","size":1000,"min_doc_count":1,"shard_min_doc_count":0,"show_term_doc_count_error":false,"order":[{"sum":"desc"},{"_key":"asc"}]},"aggregations":{"sum":{"sum":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Loading
Loading