Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9208000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.3.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
search_shards_resolved_index_expressions,9207000
aggregation_window,9208000
Original file line number Diff line number Diff line change
Expand Up @@ -1080,13 +1080,19 @@ private LogicalPlan resolveFuse(Fuse fuse, List<Attribute> childrenOutput) {
Expression aggFilter = new Literal(source, true, DataType.BOOLEAN);

List<NamedExpression> aggregates = new ArrayList<>();
aggregates.add(new Alias(source, score.name(), new Sum(source, score, aggFilter, SummationMode.COMPENSATED_LITERAL)));
aggregates.add(
new Alias(
source,
score.name(),
new Sum(source, score, aggFilter, AggregateFunction.NO_WINDOW, SummationMode.COMPENSATED_LITERAL)
)
);

for (Attribute attr : childrenOutput) {
if (attr.name().equals(score.name())) {
continue;
}
var valuesAgg = new Values(source, attr, aggFilter);
var valuesAgg = new Values(source, attr, aggFilter, AggregateFunction.NO_WINDOW);
// Use VALUES only on supported fields.
// FuseScoreEval will check that the input contains only columns with supported data types
// and will fail with an appropriate error message if it doesn't.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,11 @@ private static FunctionDefinition[][] functions() {
def(Score.class, uni(Score::new), "score") },
// time-series functions
new FunctionDefinition[] {
defTS(Rate.class, Rate::new, "rate"),
defTS(Irate.class, Irate::new, "irate"),
defTS(Idelta.class, Idelta::new, "idelta"),
defTS(Delta.class, Delta::new, "delta"),
defTS(Increase.class, Increase::new, "increase"),
defTS(Rate.class, bi(Rate::new), "rate"),
defTS(Irate.class, bi(Irate::new), "irate"),
defTS(Idelta.class, bi(Idelta::new), "idelta"),
defTS(Delta.class, bi(Delta::new), "delta"),
defTS(Increase.class, bi(Increase::new), "increase"),
def(MaxOverTime.class, uni(MaxOverTime::new), "max_over_time"),
def(MinOverTime.class, uni(MinOverTime::new), "min_over_time"),
def(SumOverTime.class, uni(SumOverTime::new), "sum_over_time"),
Expand All @@ -546,8 +546,8 @@ private static FunctionDefinition[][] functions() {
def(PresentOverTime.class, uni(PresentOverTime::new), "present_over_time"),
def(AbsentOverTime.class, uni(AbsentOverTime::new), "absent_over_time"),
def(AvgOverTime.class, uni(AvgOverTime::new), "avg_over_time"),
defTS(LastOverTime.class, LastOverTime::new, "last_over_time"),
defTS(FirstOverTime.class, FirstOverTime::new, "first_over_time"),
defTS(LastOverTime.class, bi(LastOverTime::new), "last_over_time"),
defTS(FirstOverTime.class, bi(FirstOverTime::new), "first_over_time"),
def(PercentileOverTime.class, bi(PercentileOverTime::new), "percentile_over_time"),
// dense vector function
def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ public Absent(
description = "Expression that outputs values to be checked for absence."
) Expression field
) {
this(source, field, Literal.TRUE);
this(source, field, Literal.TRUE, NO_WINDOW);
}

public Absent(Source source, Expression field, Expression filter) {
super(source, field, filter, emptyList());
public Absent(Source source, Expression field, Expression filter, Expression window) {
super(source, field, filter, window, emptyList());
}

private Absent(StreamInput in) throws IOException {
Expand All @@ -102,17 +102,17 @@ public String getWriteableName() {

@Override
protected NodeInfo<Absent> info() {
return NodeInfo.create(this, Absent::new, field(), filter());
return NodeInfo.create(this, Absent::new, field(), filter(), window());
}

@Override
public AggregateFunction withFilter(Expression filter) {
return new Absent(source(), field(), filter);
return new Absent(source(), field(), filter, window());
}

@Override
public Absent replaceChildren(List<Expression> newChildren) {
return new Absent(source(), newChildren.get(0), newChildren.get(1));
return new Absent(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
}

@Override
Expand All @@ -138,6 +138,6 @@ protected TypeResolution resolveType() {

@Override
public Expression surrogate() {
return new Not(source(), new Present(source(), field(), filter()));
return new Not(source(), new Present(source(), field(), filter(), window()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import java.io.IOException;
import java.util.List;

import static java.util.Collections.emptyList;

/**
* Similar to {@link Absent}, but it is used to check the absence of values over a time series in the given field.
*/
Expand Down Expand Up @@ -70,11 +68,11 @@ public AbsentOverTime(
"version" }
) Expression field
) {
this(source, field, Literal.TRUE);
this(source, field, Literal.TRUE, NO_WINDOW);
}

public AbsentOverTime(Source source, Expression field, Expression filter) {
super(source, field, filter, emptyList());
public AbsentOverTime(Source source, Expression field, Expression filter, Expression window) {
super(source, field, filter, window, List.of());
}

private AbsentOverTime(StreamInput in) throws IOException {
Expand All @@ -88,17 +86,17 @@ public String getWriteableName() {

@Override
public AbsentOverTime withFilter(Expression filter) {
return new AbsentOverTime(source(), field(), filter);
return new AbsentOverTime(source(), field(), filter, window());
}

@Override
protected NodeInfo<AbsentOverTime> info() {
return NodeInfo.create(this, AbsentOverTime::new, field(), filter());
return NodeInfo.create(this, AbsentOverTime::new, field(), filter(), window());
}

@Override
public AbsentOverTime replaceChildren(List<Expression> newChildren) {
return new AbsentOverTime(source(), newChildren.get(0), newChildren.get(1));
return new AbsentOverTime(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
}

@Override
Expand All @@ -113,6 +111,6 @@ public DataType dataType() {

@Override
public Absent perTimeSeriesAggregation() {
return new Absent(source(), field(), filter());
return new Absent(source(), field(), filter(), window());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
Expand All @@ -17,15 +18,14 @@
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.function.BiConsumer;
Expand All @@ -38,25 +38,38 @@

/**
* A type of {@code Function} that takes multiple values and extracts a single value out of them. For example, {@code AVG()}.
* - Aggregate functions can have an optional filter and window, which default to {@code Literal.TRUE} and {@code NO_WINDOW}.
* - The aggregation function should be composed as: source, field, filter, window, parameters.
* Extra parameters should go to the parameters after the filter and window.
*/
public abstract class AggregateFunction extends Function implements PostAnalysisPlanVerificationAware {
public static final Literal NO_WINDOW = Literal.timeDuration(Source.EMPTY, Duration.ZERO);
public static final TransportVersion WINDOW_INTERVAL = TransportVersion.fromName("aggregation_window");

private final Expression field;
private final List<? extends Expression> parameters;
private final Expression filter;
private final Expression window;

protected AggregateFunction(Source source, Expression field) {
this(source, field, Literal.TRUE, emptyList());
this(source, field, Literal.TRUE, NO_WINDOW, emptyList());
}

protected AggregateFunction(Source source, Expression field, List<? extends Expression> parameters) {
this(source, field, Literal.TRUE, parameters);
this(source, field, Literal.TRUE, NO_WINDOW, parameters);
}

protected AggregateFunction(Source source, Expression field, Expression filter, List<? extends Expression> parameters) {
super(source, CollectionUtils.combine(asList(field, filter), parameters));
protected AggregateFunction(
Source source,
Expression field,
Expression filter,
Expression window,
List<? extends Expression> parameters
) {
super(source, CollectionUtils.combine(asList(field, filter, window), parameters));
this.field = field;
this.filter = filter;
this.window = Objects.requireNonNull(window, "[window] must be specified; use NO_WINDOW instead");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we expect functions to check against negative durations?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should, but our random tests may pass any expression for this parameter.

this.parameters = parameters;
}

Expand All @@ -65,48 +78,27 @@ protected AggregateFunction(StreamInput in) throws IOException {
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class),
readWindow(in),
in.readNamedWriteableCollectionAsList(Expression.class)
);
}

/**
* Read a generic AggregateFunction from the stream input. This is used for BWC when the subclass requires a generic instance;
* then convert the parameters to the specific ones.
*/
protected static AggregateFunction readGenericAggregateFunction(StreamInput in) throws IOException {
return new AggregateFunction(in) {
@Override
public AggregateFunction withFilter(Expression filter) {
throw new UnsupportedOperationException();
}

@Override
public DataType dataType() {
throw new UnsupportedOperationException();
}

@Override
public Expression replaceChildren(List<Expression> newChildren) {
throw new UnsupportedOperationException();
}

@Override
protected NodeInfo<? extends Expression> info() {
throw new UnsupportedOperationException();
}

@Override
public String getWriteableName() {
throw new UnsupportedOperationException();
}
};
protected static Expression readWindow(StreamInput in) throws IOException {
if (in.getTransportVersion().supports(WINDOW_INTERVAL)) {
return in.readNamedWriteable(Expression.class);
} else {
return NO_WINDOW;
}
}

@Override
public final void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(field);
out.writeNamedWriteable(filter);
if (out.getTransportVersion().supports(WINDOW_INTERVAL)) {
out.writeNamedWriteable(window);
}
out.writeNamedWriteableCollection(parameters);
}

Expand Down Expand Up @@ -144,6 +136,23 @@ public AggregateFunction withParameters(List<? extends Expression> parameters) {
return (AggregateFunction) replaceChildren(CollectionUtils.combine(asList(field, filter), parameters));
}

/**
* Return the window associated with the aggregate function.
*/
public Expression window() {
return window;
}

/**
* Whether the aggregate function has a window different than NO_WINDOW.
*/
public boolean hasWindow() {
if (window instanceof Literal lit && lit.value() instanceof Duration duration) {
return duration.isZero() == false;
}
return true;
}

/**
* Returns the set of input attributes required by this aggregate function, excluding those referenced by the filter.
*/
Expand All @@ -168,6 +177,7 @@ public boolean equals(Object obj) {
AggregateFunction other = (AggregateFunction) obj;
return Objects.equals(other.field(), field())
&& Objects.equals(other.filter(), filter())
&& Objects.equals(other.window(), window())
&& Objects.equals(other.parameters(), parameters());
}
return false;
Expand Down
Loading