Skip to content

Commit

Permalink
[opt](nereids) compare literal not convert to legacy literal and fix …
Browse files Browse the repository at this point in the history
…ip literal compareTo always equals 0 (#46482)

### What problem does this PR solve?

Problem Summary:

as #45181 mention, when sorting literals, toLegacyLiteral may cost a lot
of time, so compare literal don't use toLegacyLiteral any more.

legacy literal may have an unknown behaviour comparing two values with
different data type.

for neredis literals, different data type value compare will throw an
exception, it support valid compare with data types:
1. boolean vs boolean;
2. numeric vs numeric;
3. string like vs string like;
4. date like vs date like;
5. ipv4 vs ipv4;
6. ipv6 vs ipv6;
7. array vs array;
8. above data types vs null and max;

what's more, this pr also:
1. nereids literal remove implements Comparable<Literal>;
2. add a new interface ComparableLiteral, and the above type literals
will implement it;
3. fix  ipv4 / ipv6 / map / struct compareTo always return 0;
  • Loading branch information
yujun777 authored Feb 7, 2025
1 parent c04c9c0 commit 6b88c67
Show file tree
Hide file tree
Showing 39 changed files with 900 additions and 126 deletions.
6 changes: 6 additions & 0 deletions fe/fe-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ under the License.
<artifactId>guava-testlib</artifactId>
<scope>test</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/com.googlecode.java-ipv6/java-ipv6 -->
<dependency>
<groupId>com.googlecode.java-ipv6</groupId>
<artifactId>java-ipv6</artifactId>
<version>0.17</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-core -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ public int compareLiteral(LiteralExpr expr) {
return diff < 0 ? -1 : (diff == 0 ? 0 : 1);
}
// date time will not overflow when doing addition and subtraction
return getStringValue().compareTo(expr.getStringValue());
return Integer.signum(getStringValue().compareTo(expr.getStringValue()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ public int compareLiteral(LiteralExpr expr) {
if (expr instanceof NullLiteral) {
return 1;
}
if (expr == MaxLiteral.MAX_VALUE) {
return -1;
}
if (expr instanceof DecimalLiteral) {
return this.value.compareTo(((DecimalLiteral) expr).value);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ public int compareLiteral(LiteralExpr expr) {
if (expr instanceof NullLiteral) {
return 1;
}
if (expr == MaxLiteral.MAX_VALUE) {
return -1;
}
return Double.compare(value, expr.getDoubleValue());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,13 @@ public int compareLiteral(LiteralExpr expr) {
if (expr instanceof NullLiteral) {
return 1;
}
if (expr instanceof StringLiteral) {
return ((StringLiteral) expr).compareLiteral(this);
}
if (expr == MaxLiteral.MAX_VALUE) {
return -1;
}
if (value == expr.getLongValue()) {
return 0;
} else {
return value > expr.getLongValue() ? 1 : -1;
if (expr instanceof StringLiteral) {
return - ((StringLiteral) expr).compareLiteral(this);
}
return Long.compare(value, expr.getLongValue());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand Down Expand Up @@ -93,15 +94,15 @@ private enum MatchMinMax {

private static class MinMaxValue {
// min max range, if range = null means empty
Range<Literal> range;
Range<ComparableLiteral> range;

// expression in range is discrete value
boolean isDiscrete;

// expr relative order, for keep order after add min-max to the expression
int exprOrderIndex;

public MinMaxValue(Range<Literal> range, boolean isDiscrete, int exprOrderIndex) {
public MinMaxValue(Range<ComparableLiteral> range, boolean isDiscrete, int exprOrderIndex) {
this.range = range;
this.isDiscrete = isDiscrete;
this.exprOrderIndex = exprOrderIndex;
Expand Down Expand Up @@ -171,25 +172,27 @@ private Expression addExprMinMaxValues(Expression expr, ExpressionRewriteContext
List<Expression> addExprs = Lists.newArrayListWithExpectedSize(minMaxExprs.size() * 2);
for (Map.Entry<Expression, MinMaxValue> entry : minMaxExprs) {
Expression targetExpr = entry.getKey();
Range<Literal> range = entry.getValue().range;
Range<ComparableLiteral> range = entry.getValue().range;
if (range.hasLowerBound() && range.hasUpperBound()
&& range.lowerEndpoint().equals(range.upperEndpoint())
&& range.lowerBoundType() == BoundType.CLOSED
&& range.upperBoundType() == BoundType.CLOSED) {
Expression cmp = new EqualTo(targetExpr, range.lowerEndpoint());
Expression cmp = new EqualTo(targetExpr, (Literal) range.lowerEndpoint());
addExprs.add(cmp);
continue;
}
if (range.hasLowerBound()) {
Literal literal = range.lowerEndpoint();
ComparableLiteral literal = range.lowerEndpoint();
Expression cmp = range.lowerBoundType() == BoundType.CLOSED
? new GreaterThanEqual(targetExpr, literal) : new GreaterThan(targetExpr, literal);
? new GreaterThanEqual(targetExpr, (Literal) literal)
: new GreaterThan(targetExpr, (Literal) literal);
addExprs.add(cmp);
}
if (range.hasUpperBound()) {
Literal literal = range.upperEndpoint();
ComparableLiteral literal = range.upperEndpoint();
Expression cmp = range.upperBoundType() == BoundType.CLOSED
? new LessThanEqual(targetExpr, literal) : new LessThan(targetExpr, literal);
? new LessThanEqual(targetExpr, (Literal) literal)
: new LessThan(targetExpr, (Literal) literal);
addExprs.add(cmp);
}
}
Expand Down Expand Up @@ -243,7 +246,7 @@ private MatchMinMax getExprMatchMinMax(Expression expr,
ComparisonPredicate cp = (ComparisonPredicate) expr;
Expression left = cp.left();
Expression right = cp.right();
if (!(right instanceof Literal)) {
if (!(right instanceof ComparableLiteral)) {
return MatchMinMax.MATCH_NONE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.catalog.PartitionKey;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;

import com.google.common.base.MoreObjects;
Expand All @@ -36,7 +37,13 @@ private ColumnBound(Literal value) {

@Override
public int compareTo(ColumnBound o) {
return value.toLegacyLiteral().compareTo(o.value.toLegacyLiteral());
if (!(value instanceof ComparableLiteral)) {
throw new RuntimeException("'" + value + "' (" + value.getDataType() + ") is not comparable");
}
if (!(o.value instanceof ComparableLiteral)) {
throw new RuntimeException("'" + o.value + "' (" + o.value.getDataType() + ") is not comparable");
}
return ((ComparableLiteral) value).compareTo((ComparableLiteral) o.value);
}

public static ColumnBound of(Literal expr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
Expand Down Expand Up @@ -266,7 +267,12 @@ public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context
if (checkedExpr.isPresent()) {
return checkedExpr.get();
}
return BooleanLiteral.of(((Literal) equalTo.left()).compareTo((Literal) equalTo.right()) == 0);
if (equalTo.left() instanceof ComparableLiteral && equalTo.right() instanceof ComparableLiteral) {
return BooleanLiteral.of(((ComparableLiteral) equalTo.left())
.compareTo((ComparableLiteral) equalTo.right()) == 0);
} else {
return BooleanLiteral.of(equalTo.left().equals(equalTo.right()));
}
}

@Override
Expand All @@ -276,7 +282,8 @@ public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteCon
if (checkedExpr.isPresent()) {
return checkedExpr.get();
}
return BooleanLiteral.of(((Literal) greaterThan.left()).compareTo((Literal) greaterThan.right()) > 0);
return BooleanLiteral.of(((ComparableLiteral) greaterThan.left())
.compareTo((ComparableLiteral) greaterThan.right()) > 0);
}

@Override
Expand All @@ -286,8 +293,8 @@ public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Expre
if (checkedExpr.isPresent()) {
return checkedExpr.get();
}
return BooleanLiteral.of(((Literal) greaterThanEqual.left())
.compareTo((Literal) greaterThanEqual.right()) >= 0);
return BooleanLiteral.of(((ComparableLiteral) greaterThanEqual.left())
.compareTo((ComparableLiteral) greaterThanEqual.right()) >= 0);
}

@Override
Expand All @@ -297,7 +304,8 @@ public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext cont
if (checkedExpr.isPresent()) {
return checkedExpr.get();
}
return BooleanLiteral.of(((Literal) lessThan.left()).compareTo((Literal) lessThan.right()) < 0);
return BooleanLiteral.of(((ComparableLiteral) lessThan.left())
.compareTo((ComparableLiteral) lessThan.right()) < 0);
}

@Override
Expand All @@ -307,7 +315,8 @@ public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewr
if (checkedExpr.isPresent()) {
return checkedExpr.get();
}
return BooleanLiteral.of(((Literal) lessThanEqual.left()).compareTo((Literal) lessThanEqual.right()) <= 0);
return BooleanLiteral.of(((ComparableLiteral) lessThanEqual.left())
.compareTo((ComparableLiteral) lessThanEqual.right()) <= 0);
}

@Override
Expand All @@ -322,7 +331,13 @@ public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewr
if (l.isNullLiteral() && r.isNullLiteral()) {
return BooleanLiteral.TRUE;
} else if (!l.isNullLiteral() && !r.isNullLiteral()) {
return BooleanLiteral.of(l.compareTo(r) == 0);
if (nullSafeEqual.left() instanceof ComparableLiteral
&& nullSafeEqual.right() instanceof ComparableLiteral) {
return BooleanLiteral.of(((ComparableLiteral) nullSafeEqual.left())
.compareTo((ComparableLiteral) nullSafeEqual.right()) == 0);
} else {
return BooleanLiteral.of(l.equals(r));
}
} else {
return BooleanLiteral.FALSE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
Expand Down Expand Up @@ -76,7 +76,8 @@ private ValueDesc buildRange(ExpressionRewriteContext context, ComparisonPredica
return new UnknownValue(context, predicate);
}
// only handle `NumericType` and `DateLikeType`
if (right.isLiteral() && (right.getDataType().isNumericType() || right.getDataType().isDateLikeType())) {
if (right instanceof ComparableLiteral
&& (right.getDataType().isNumericType() || right.getDataType().isDateLikeType())) {
return ValueDesc.range(context, predicate);
}
return new UnknownValue(context, predicate);
Expand Down Expand Up @@ -111,7 +112,7 @@ public ValueDesc visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context)
public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
// only handle `NumericType` and `DateLikeType`
if (inPredicate.getOptions().size() <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE
&& ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
&& ExpressionUtils.isAllNonNullComparableLiteral(inPredicate.getOptions())
&& (ExpressionUtils.matchNumericType(inPredicate.getOptions())
|| ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) {
return ValueDesc.discrete(context, inPredicate);
Expand Down Expand Up @@ -216,18 +217,18 @@ public static ValueDesc union(ExpressionRewriteContext context,
/** merge discrete and ranges only, no merge other value desc */
public static List<ValueDesc> unionDiscreteAndRange(ExpressionRewriteContext context,
Expression reference, List<ValueDesc> valueDescs) {
Set<Literal> discreteValues = Sets.newHashSet();
Set<ComparableLiteral> discreteValues = Sets.newHashSet();
for (ValueDesc valueDesc : valueDescs) {
if (valueDesc instanceof DiscreteValue) {
discreteValues.addAll(((DiscreteValue) valueDesc).getValues());
}
}

// for 'a > 8 or a = 8', then range (8, +00) can convert to [8, +00)
RangeSet<Literal> rangeSet = TreeRangeSet.create();
RangeSet<ComparableLiteral> rangeSet = TreeRangeSet.create();
for (ValueDesc valueDesc : valueDescs) {
if (valueDesc instanceof RangeValue) {
Range<Literal> range = ((RangeValue) valueDesc).range;
Range<ComparableLiteral> range = ((RangeValue) valueDesc).range;
rangeSet.add(range);
if (range.hasLowerBound()
&& range.lowerBoundType() == BoundType.OPEN
Expand All @@ -250,7 +251,7 @@ public static List<ValueDesc> unionDiscreteAndRange(ExpressionRewriteContext con
if (!discreteValues.isEmpty()) {
result.add(new DiscreteValue(context, reference, discreteValues));
}
for (Range<Literal> range : rangeSet.asRanges()) {
for (Range<ComparableLiteral> range : rangeSet.asRanges()) {
result.add(new RangeValue(context, reference, range));
}
for (ValueDesc valueDesc : valueDescs) {
Expand All @@ -267,7 +268,7 @@ public static List<ValueDesc> unionDiscreteAndRange(ExpressionRewriteContext con

/** intersect */
public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete) {
Set<Literal> newValues = discrete.values.stream().filter(x -> range.range.contains(x))
Set<ComparableLiteral> newValues = discrete.values.stream().filter(x -> range.range.contains(x))
.collect(Collectors.toSet());
if (newValues.isEmpty()) {
return new EmptyValue(context, range.reference);
Expand All @@ -277,11 +278,11 @@ public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue r
}

private static ValueDesc range(ExpressionRewriteContext context, ComparisonPredicate predicate) {
Literal value = (Literal) predicate.right();
ComparableLiteral value = (ComparableLiteral) predicate.right();
if (predicate instanceof EqualTo) {
return new DiscreteValue(context, predicate.left(), Sets.newHashSet(value));
}
Range<Literal> range = null;
Range<ComparableLiteral> range = null;
if (predicate instanceof GreaterThanEqual) {
range = Range.atLeast(value);
} else if (predicate instanceof GreaterThan) {
Expand All @@ -296,8 +297,9 @@ private static ValueDesc range(ExpressionRewriteContext context, ComparisonPredi
}

public static ValueDesc discrete(ExpressionRewriteContext context, InPredicate in) {
// Set<Literal> literals = (Set) Utils.fastToImmutableSet(in.getOptions());
Set<Literal> literals = in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet());
// Set<ComparableLiteral> literals = (Set) Utils.fastToImmutableSet(in.getOptions());
Set<ComparableLiteral> literals = in.getOptions().stream()
.map(ComparableLiteral.class::cast).collect(Collectors.toSet());
return new DiscreteValue(context, in.getCompareExpr(), literals);
}
}
Expand Down Expand Up @@ -328,14 +330,14 @@ public ValueDesc intersect(ValueDesc other) {
* a > 1 => (1...+∞)
*/
public static class RangeValue extends ValueDesc {
Range<Literal> range;
Range<ComparableLiteral> range;

public RangeValue(ExpressionRewriteContext context, Expression reference, Range<Literal> range) {
public RangeValue(ExpressionRewriteContext context, Expression reference, Range<ComparableLiteral> range) {
super(context, reference);
this.range = range;
}

public Range<Literal> getRange() {
public Range<ComparableLiteral> getRange() {
return range;
}

Expand Down Expand Up @@ -387,15 +389,15 @@ public String toString() {
* a in (1,2,3) => [1,2,3]
*/
public static class DiscreteValue extends ValueDesc {
final Set<Literal> values;
final Set<ComparableLiteral> values;

public DiscreteValue(ExpressionRewriteContext context,
Expression reference, Set<Literal> values) {
Expression reference, Set<ComparableLiteral> values) {
super(context, reference);
this.values = values;
}

public Set<Literal> getValues() {
public Set<ComparableLiteral> getValues() {
return values;
}

Expand All @@ -405,7 +407,7 @@ public ValueDesc union(ValueDesc other) {
return other.union(this);
}
if (other instanceof DiscreteValue) {
Set<Literal> newValues = Sets.newHashSet();
Set<ComparableLiteral> newValues = Sets.newHashSet();
newValues.addAll(((DiscreteValue) other).values);
newValues.addAll(this.values);
return new DiscreteValue(context, reference, newValues);
Expand All @@ -422,7 +424,7 @@ public ValueDesc intersect(ValueDesc other) {
return other.intersect(this);
}
if (other instanceof DiscreteValue) {
Set<Literal> newValues = Sets.newHashSet();
Set<ComparableLiteral> newValues = Sets.newHashSet();
newValues.addAll(((DiscreteValue) other).values);
newValues.retainAll(this.values);
if (newValues.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondsSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksSub;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.TypeCoercionUtils;
Expand Down Expand Up @@ -124,7 +125,7 @@ private static List<Expression> tryRearrangeChildren(Expression left, Expression
if (!left.child(1).isConstant()) {
throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left));
}
Literal leftLiteral = (Literal) FoldConstantRule.evaluate(left.child(1), context);
ComparableLiteral leftLiteral = (ComparableLiteral) FoldConstantRule.evaluate(left.child(1), context);
Expression leftExpr = left.child(0);

Class<? extends Expression> oppositeOperator = REARRANGEMENT_MAP.get(left.getClass());
Expand Down
Loading

0 comments on commit 6b88c67

Please sign in to comment.