diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java
index fff23a4cc..9f59c85c8 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java
@@ -12,6 +12,7 @@
 import javax.annotation.Nullable;
 
 import java.util.List;
+import java.util.stream.Collectors;
 
 @ToString
 @Getter
@@ -22,7 +23,7 @@ public class Trendline extends UnresolvedPlan {
     private UnresolvedPlan child;
     @Nullable
    private final Field sortByField;
-    private final List<UnresolvedExpression> computations;
+    private final List<TrendlineComputation> computations;
 
     @Override
     public UnresolvedPlan attach(UnresolvedPlan child) {
@@ -40,15 +41,21 @@ public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
         return visitor.visitTrendline(this, context);
     }
 
+    public List<TrendlineComputation> filterComputationByType(TrendlineType type) {
+        return computations.stream()
+                .filter(computation -> computation.getComputationType().equals(type))
+                .collect(Collectors.toList());
+    }
+
     @Getter
     public static class TrendlineComputation extends UnresolvedExpression {
 
         private final Integer numberOfDataPoints;
-        private final UnresolvedExpression dataField;
+        private final Field dataField;
         private final String alias;
         private final TrendlineType computationType;
 
-        public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, String computationType) {
+        public TrendlineComputation(Integer numberOfDataPoints, Field dataField, String alias, String computationType) {
             this.numberOfDataPoints = numberOfDataPoints;
             this.dataField = dataField;
             this.alias = alias;
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 876ede1db..9ff484f18 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -5,7 +5,9 @@
 
 package org.opensearch.sql.ppl;
 
+import lombok.val;
 import org.apache.spark.sql.catalyst.TableIdentifier;
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
@@ -20,6 +22,7 @@
 import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.InSubquery$;
 import org.apache.spark.sql.catalyst.expressions.LessThan;
+import org.apache.spark.sql.catalyst.expressions.LessThan;
 import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.ListQuery$;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
@@ -29,18 +32,22 @@
 import org.apache.spark.sql.catalyst.expressions.SortDirection;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
+import org.apache.spark.sql.catalyst.expressions.UnresolvedWindowExpression;
 import org.apache.spark.sql.catalyst.expressions.WindowExpression;
 import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
+import org.apache.spark.sql.catalyst.expressions.WindowSpecReference;
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
 import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
 import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
 import org.apache.spark.sql.catalyst.plans.logical.Limit;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.catalyst.plans.logical.Project$;
+import org.apache.spark.sql.catalyst.plans.logical.WithWindowDefinition;
 import org.apache.spark.sql.execution.ExplainMode;
 import org.apache.spark.sql.execution.command.DescribeTableCommand;
 import org.apache.spark.sql.execution.command.ExplainCommand;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.IntegerType$;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 import org.opensearch.sql.ast.AbstractNodeVisitor;
 import org.opensearch.sql.ast.expression.AggregateFunction;
@@ -109,6 +116,7 @@
 import org.opensearch.sql.ppl.utils.ParseStrategy;
 import org.opensearch.sql.ppl.utils.SortUtils;
 import org.opensearch.sql.ppl.utils.WindowSpecTransformer;
+import scala.None$;
 import scala.Option;
 import scala.Tuple2;
 import scala.collection.IterableLike;
@@ -119,11 +127,14 @@
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Stack;
+import java.util.UUID;
 import java.util.function.BiFunction;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static java.util.Collections.emptyList;
 import static java.util.List.of;
+import static org.apache.spark.sql.catalyst.expressions.Literal.create;
 import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL;
 import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
@@ -282,10 +293,50 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
             // Create an UnresolvedStar for all-fields projection
             context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
         }
-        visitExpressionList(node.getComputations(), context);
+        visitExpressionList(node.filterComputationByType(Trendline.TrendlineType.SMA), context);
         Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
-        return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
+        LogicalPlan logicalPlan = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
+        List<Trendline.TrendlineComputation> weightedTrendlineComputations = node.filterComputationByType(Trendline.TrendlineType.WMA);
+        if(!weightedTrendlineComputations.isEmpty()) {
+            for(Trendline.TrendlineComputation computation : weightedTrendlineComputations) {
+                String uniquePrefix = UUID.randomUUID().toString();
+                SpecifiedWindowFrame frameSpecification = new SpecifiedWindowFrame(RowFrame$.MODULE$, create(- computation.getNumberOfDataPoints() + 1, IntegerType$.MODULE$), CurrentRow$.MODULE$);
+                Expression sortFieldExpression = visitExpression(node.getSortByField(), context);
+                Seq<SortOrder> sortOrder = context
+                        .retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(sortFieldExpression, SortUtils.isSortedAscending(node.getSortByField())));
+                WindowSpecDefinition windowSpecDefinition = new WindowSpecDefinition(seq(), sortOrder, frameSpecification);
+                List<NamedExpression> aliases = new ArrayList<>();
+                aliases.add(UnresolvedStar$.MODULE$.apply((Option) None$.MODULE$));
+
+                String windowName = "wma_window-" + uniquePrefix;
+                Seq<Tuple2<String, WindowSpecDefinition>> s = seq(new Tuple2<>(windowName, windowSpecDefinition));
+                scala.collection.immutable.Map<String, WindowSpecDefinition> wmaWindow =
+                        scala.collection.immutable.Map$.MODULE$.apply(s);
+                List<String> tempFieldNames = new ArrayList<>();
+                for(int i = 1; i <= computation.getNumberOfDataPoints(); i++) {
+                    String fieldName = computation.getDataField().getField().toString();
+                    String tempFieldName = fieldName + "_wma_" + i + "_" + uniquePrefix;
+                    tempFieldNames.add(tempFieldName);
+                    UnresolvedWindowExpression windowExpression = new UnresolvedWindowExpression(new UnresolvedFunction(seq("nth_value"), seq(new UnresolvedAttribute(seq(fieldName)), new org.apache.spark.sql.catalyst.expressions.Literal(i, IntegerType$.MODULE$)), false, (Option) None$.MODULE$, false), new WindowSpecReference(windowName));
+                    val alias = new org.apache.spark.sql.catalyst.expressions.Alias(windowExpression, tempFieldName, NamedExpression.newExprId(), seq(), (Option) None$.MODULE$, seq());
+                    aliases.add(alias);
+                }
+                context.apply(p -> Project$.MODULE$.apply(seq(aliases), p));
+                context.apply(p -> new WithWindowDefinition(wmaWindow, p));
+                Optional<UnresolvedFunction> sumFunction = IntStream.range(0, tempFieldNames.size())
+                        .mapToObj(i -> new UnresolvedFunction(seq("*"), seq(new org.apache.spark.sql.catalyst.expressions.Literal(i + 1, IntegerType$.MODULE$), new UnresolvedAttribute(seq(tempFieldNames.get(i)))), false, (Option) None$.MODULE$, false))
+                        .reduce((left, right) -> new UnresolvedFunction(seq("+"), seq(left, right), false, (Option) None$.MODULE$, false));
+                Integer divideBy = IntStream.range(1, computation.getNumberOfDataPoints() + 1).sum(); // todo correct!
+                UnresolvedFunction wmaResult = new UnresolvedFunction(seq("/"), seq(sumFunction.get(), new org.apache.spark.sql.catalyst.expressions.Literal(divideBy, IntegerType$.MODULE$)), false, (Option) None$.MODULE$, false);
+                val alias = new org.apache.spark.sql.catalyst.expressions.Alias(wmaResult, computation.getAlias(), NamedExpression.newExprId(), seq(), (Option) None$.MODULE$, seq());
+                context.apply(p -> Project$.MODULE$.apply(seq(UnresolvedStar$.MODULE$.apply((Option) None$.MODULE$), alias), p));
+                visitFieldList(tempFieldNames.stream().map(name -> new Field(new QualifiedName(name))).collect(Collectors.toList()), context);
+                Seq<Expression> toDrop = context.retainAllNamedParseExpressions(p -> p);
+                logicalPlan = context.apply(p -> DataFrameDropColumns$.MODULE$.apply(toDrop, p));
+            }
+        }
+
+        return logicalPlan;
     }
 
     @Override
@@ -494,7 +545,7 @@ private void visitFieldList(List<Field> fieldList, CatalystPlanContext context)
         fieldList.forEach(field -> visitExpression(field, context));
     }
 
-    private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
+    private List<Expression> visitExpressionList(List<? extends UnresolvedExpression> expressionList, CatalystPlanContext context) {
         return expressionList.isEmpty()
                 ? emptyList()
                 : expressionList.stream().map(field -> visitExpression(field, context))
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index 0380671c3..523d3d661 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -388,9 +388,10 @@ private java.util.Map buildLookupPair(List
-        List<UnresolvedExpression> trendlineComputations = ctx.trendlineClause()
+        List<Trendline.TrendlineComputation> trendlineComputations = ctx.trendlineClause()
                 .stream()
                 .map(expressionBuilder::visit)
+                .map(Trendline.TrendlineComputation.class::cast)
                 .collect(Collectors.toList());
         return Optional.ofNullable(ctx.sortField())
                 .map(this::internalVisitExpression)
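
For reference, the WMA branch added above builds, per computation: a Project of nth_value(field, i) OVER a named wma_window for i = 1..numberOfDataPoints into temporary *_wma_* columns, a WithWindowDefinition binding wma_window to a ROWS frame from (numberOfDataPoints - 1) PRECEDING to CURRENT ROW ordered by the sort field, a Project of the weighted sum 1*v1 + 2*v2 + ... + n*vn divided by 1 + 2 + ... + n under the requested alias, and a DataFrameDropColumns that removes the temporary columns. The divisor computed via IntStream.range(1, n + 1).sum() equals n(n + 1)/2, the standard denominator of a linearly weighted moving average. The snippet below is a minimal, standalone sketch of that arithmetic only, using plain doubles in place of Catalyst expressions; the class and variable names are illustrative and not taken from this change.

    import java.util.List;

    // Sketch of the WMA arithmetic the generated plan encodes (illustrative names only).
    public class WmaSketch {
        // lastNValues holds the window frame contents in sort order, i.e. what
        // nth_value(field, i) picks out of wma_window for i = 1..n; the i-th
        // (oldest-first) value receives weight i.
        static double wma(List<Double> lastNValues) {
            int n = lastNValues.size();
            double weightedSum = 0;
            for (int i = 1; i <= n; i++) {
                weightedSum += i * lastNValues.get(i - 1);
            }
            // Same divisor the plan builds: 1 + 2 + ... + n = n(n + 1)/2.
            return weightedSum / (n * (n + 1) / 2);
        }

        public static void main(String[] args) {
            // A 3-point WMA over 10, 20, 30 in sort order: (1*10 + 2*20 + 3*30) / 6 = 23.33...
            System.out.println(wma(List.of(10.0, 20.0, 30.0)));
        }
    }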