Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import javax.annotation.Nullable;
import java.util.List;
import java.util.stream.Collectors;

@ToString
@Getter
Expand All @@ -22,7 +23,7 @@ public class Trendline extends UnresolvedPlan {
private UnresolvedPlan child;
@Nullable
private final Field sortByField;
private final List<UnresolvedExpression> computations;
private final List<Trendline.TrendlineComputation> computations;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
Expand All @@ -40,15 +41,21 @@ public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
return visitor.visitTrendline(this, context);
}

/**
 * Selects the computations of this trendline command that use the given
 * computation type, preserving their declaration order.
 *
 * @param type the trendline computation type to match (e.g. SMA or WMA)
 * @return a new list containing only the matching computations; empty if none match
 */
public List<Trendline.TrendlineComputation> filterComputationByType(TrendlineType type) {
    // Named predicate instead of an inline lambda for readability.
    java.util.function.Predicate<Trendline.TrendlineComputation> matchesType =
            candidate -> candidate.getComputationType().equals(type);
    return computations.stream().filter(matchesType).collect(Collectors.toList());
}

@Getter
public static class TrendlineComputation extends UnresolvedExpression {

private final Integer numberOfDataPoints;
private final UnresolvedExpression dataField;
private final Field dataField;
private final String alias;
private final TrendlineType computationType;

public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, String computationType) {
public TrendlineComputation(Integer numberOfDataPoints, Field dataField, String alias, String computationType) {
this.numberOfDataPoints = numberOfDataPoints;
this.dataField = dataField;
this.alias = alias;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package org.opensearch.sql.ppl;

import lombok.val;
import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
Expand All @@ -20,6 +22,7 @@
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
Expand All @@ -29,18 +32,22 @@
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
import org.apache.spark.sql.catalyst.expressions.UnresolvedWindowExpression;
import org.apache.spark.sql.catalyst.expressions.WindowExpression;
import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
import org.apache.spark.sql.catalyst.expressions.WindowSpecReference;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
import org.apache.spark.sql.catalyst.plans.logical.Limit;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project$;
import org.apache.spark.sql.catalyst.plans.logical.WithWindowDefinition;
import org.apache.spark.sql.execution.ExplainMode;
import org.apache.spark.sql.execution.command.DescribeTableCommand;
import org.apache.spark.sql.execution.command.ExplainCommand;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
Expand Down Expand Up @@ -109,6 +116,7 @@
import org.opensearch.sql.ppl.utils.ParseStrategy;
import org.opensearch.sql.ppl.utils.SortUtils;
import org.opensearch.sql.ppl.utils.WindowSpecTransformer;
import scala.None$;
import scala.Option;
import scala.Tuple2;
import scala.collection.IterableLike;
Expand All @@ -119,11 +127,14 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;
import java.util.UUID;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.util.Collections.emptyList;
import static java.util.List.of;
import static org.apache.spark.sql.catalyst.expressions.Literal.create;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL;
import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
Expand Down Expand Up @@ -282,10 +293,50 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
// Create an UnresolvedStar for all-fields projection
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.empty()));
}
visitExpressionList(node.getComputations(), context);
visitExpressionList(node.filterComputationByType(Trendline.TrendlineType.SMA), context);
Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);

return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
LogicalPlan logicalPlan = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
List<Trendline.TrendlineComputation> weightedTrendlineComputations = node.filterComputationByType(Trendline.TrendlineType.WMA);
if(!weightedTrendlineComputations.isEmpty()) {
for(Trendline.TrendlineComputation computation : weightedTrendlineComputations) {
String uniquePrefix = UUID.randomUUID().toString();
SpecifiedWindowFrame frameSpecification = new SpecifiedWindowFrame(RowFrame$.MODULE$, create(- computation.getNumberOfDataPoints() + 1, IntegerType$.MODULE$), CurrentRow$.MODULE$);
Expression sortFieldExpression = visitExpression(node.getSortByField(), context);
Seq<SortOrder> sortOrder = context
.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(sortFieldExpression, SortUtils.isSortedAscending(node.getSortByField())));
WindowSpecDefinition windowSpecDefinition = new WindowSpecDefinition(seq(), sortOrder, frameSpecification);
List<NamedExpression> aliases = new ArrayList<>();
aliases.add(UnresolvedStar$.MODULE$.apply((Option) None$.MODULE$));

String windowName = "wma_window-" + uniquePrefix;
Seq s = seq(new Tuple2(windowName, windowSpecDefinition));
scala.collection.immutable.Map wmaWindow = scala.collection.immutable.Map$.MODULE$.apply(s);
List<String> tempFieldNames = new ArrayList<>();
for(int i = 1; i <= computation.getNumberOfDataPoints(); i++) {
String fieldName = computation.getDataField().getField().toString();
String tempFieldName = fieldName + "_wma_" + i + "_" + uniquePrefix;
tempFieldNames.add(tempFieldName);
UnresolvedWindowExpression windowExpression = new UnresolvedWindowExpression(new UnresolvedFunction(seq("nth_value"), seq(new UnresolvedAttribute(seq(fieldName)), new org.apache.spark.sql.catalyst.expressions.Literal(i, IntegerType$.MODULE$)), false, (Option)None$.MODULE$, false), new WindowSpecReference(windowName));
val alias = new org.apache.spark.sql.catalyst.expressions.Alias(windowExpression, tempFieldName, NamedExpression.newExprId(), seq(), (Option) None$.MODULE$, seq());
aliases.add(alias);
}
context.apply(p -> Project$.MODULE$.apply(seq(aliases), p));
context.apply(p -> new WithWindowDefinition(wmaWindow, p));
Optional<UnresolvedFunction> sumFunction = IntStream.range(0, tempFieldNames.size())
.mapToObj(i -> new UnresolvedFunction(seq("*"), seq(new org.apache.spark.sql.catalyst.expressions.Literal(i + 1, IntegerType$.MODULE$), new UnresolvedAttribute(seq(tempFieldNames.get(i)))), false, (Option) None$.MODULE$, false))
.reduce((left, right) -> new UnresolvedFunction(seq("+"), seq(left, right), false, (Option) None$.MODULE$, false));
Integer divideBy = IntStream.range(1, computation.getNumberOfDataPoints() + 1).sum(); // todo correct!
UnresolvedFunction wmaResult = new UnresolvedFunction(seq("/"), seq(sumFunction.get(), new org.apache.spark.sql.catalyst.expressions.Literal(divideBy, IntegerType$.MODULE$)), false, (Option) None$.MODULE$, false);
val alias = new org.apache.spark.sql.catalyst.expressions.Alias(wmaResult, computation.getAlias(), NamedExpression.newExprId(), seq(), (Option) None$.MODULE$, seq());
context.apply(p -> Project$.MODULE$.apply(seq(UnresolvedStar$.MODULE$.apply((Option) None$.MODULE$), alias), p));
visitFieldList(tempFieldNames.stream().map(name -> new Field(new QualifiedName(name))).collect(Collectors.toList()), context);
Seq<Expression> toDrop = context.retainAllNamedParseExpressions(p -> p);
logicalPlan = context.apply(p -> DataFrameDropColumns$.MODULE$.apply(toDrop, p));
}
}

return logicalPlan;
}

@Override
Expand Down Expand Up @@ -494,7 +545,7 @@ private void visitFieldList(List<Field> fieldList, CatalystPlanContext context)
fieldList.forEach(field -> visitExpression(field, context));
}

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
private List<Expression> visitExpressionList(List<? extends UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
: expressionList.stream().map(field -> visitExpression(field, context))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,10 @@ private java.util.Map<Alias, Field> buildLookupPair(List<OpenSearchPPLParser.Loo

@Override
public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) {
List<UnresolvedExpression> trendlineComputations = ctx.trendlineClause()
List<Trendline.TrendlineComputation> trendlineComputations = ctx.trendlineClause()
.stream()
.map(expressionBuilder::visit)
.map(Trendline.TrendlineComputation.class::cast)
.collect(Collectors.toList());
return Optional.ofNullable(ctx.sortField())
.map(this::internalVisitExpression)
Expand Down