forked from apache/doris
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feature](nereids) extend infer predicates (apache#40878)
This pr refactors the PredicatePropagation module and adds support for predicate deduction, including: 1. Support for predicate deduction of like, not in, !=; 2. Support for predicate deduction of abs(b)=1 for a=b and abs(a)=1; 3. Support for transitive deduction of non-equivalent relations, for example, a>b b>1 leads to a>1. 4. Deleted useless predicates. But still has something to do in predicate inference: 1. support expr in infer predicate, e.g. abs(t1.c1)>abs(t2.c2) and abs(t1.c1)<1 2. need to add expr qualifier info, to determine whether abs(t1.c1) and abs(t2.c2) is from same table.
- Loading branch information
1 parent
4583a02
commit a205182
Showing
40 changed files
with
3,501 additions
and
406 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -561,4 +561,6 @@ default boolean isPartitionedTable() { | |
} | ||
|
||
boolean autoAnalyzeEnabled(); | ||
|
||
TableIndexes getTableIndexes(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
266 changes: 266 additions & 0 deletions
266
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
package org.apache.doris.nereids.rules.rewrite; | ||
|
||
import org.apache.doris.nereids.CascadesContext; | ||
import org.apache.doris.nereids.analyzer.Scope; | ||
import org.apache.doris.nereids.exceptions.AnalysisException; | ||
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; | ||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; | ||
import org.apache.doris.nereids.trees.expressions.Cast; | ||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; | ||
import org.apache.doris.nereids.trees.expressions.EqualTo; | ||
import org.apache.doris.nereids.trees.expressions.Expression; | ||
import org.apache.doris.nereids.trees.expressions.InPredicate; | ||
import org.apache.doris.nereids.trees.expressions.Like; | ||
import org.apache.doris.nereids.trees.expressions.Not; | ||
import org.apache.doris.nereids.trees.expressions.Or; | ||
import org.apache.doris.nereids.trees.expressions.Slot; | ||
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; | ||
import org.apache.doris.nereids.trees.expressions.literal.Literal; | ||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; | ||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; | ||
import org.apache.doris.nereids.trees.plans.Plan; | ||
import org.apache.doris.nereids.types.DecimalV2Type; | ||
import org.apache.doris.nereids.types.DecimalV3Type; | ||
import org.apache.doris.nereids.util.ExpressionUtils; | ||
import org.apache.doris.nereids.util.ImmutableEqualSet; | ||
import org.apache.doris.nereids.util.PredicateInferUtils; | ||
|
||
import com.google.common.collect.ImmutableList; | ||
import org.jetbrains.annotations.Nullable; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.LinkedHashSet; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
/**ReplacePredicate*/ | ||
public class InferPredicateByReplace { | ||
private static List<Expression> getAllSubExpressions(Expression expr) { | ||
List<Expression> subExpressions = new ArrayList<>(); | ||
getAllSubExpressions(expr, subExpressions); | ||
return subExpressions; | ||
} | ||
|
||
private static void getAllSubExpressions(Expression expr, List<Expression> res) { | ||
res.add(expr); | ||
if (expr.children().size() != 1) { | ||
Set<Slot> slots = expr.getInputSlots(); | ||
if (slots.size() == 1) { | ||
res.add(slots.iterator().next()); | ||
} | ||
return; | ||
} | ||
getAllSubExpressions(expr.child(0), res); | ||
} | ||
|
||
/** fill map exprPredicates : expression and all its corresponding predicates */ | ||
private static class PredicatesCollector extends ExpressionVisitor<Void, Map<Expression, Set<Expression>>> { | ||
public static PredicatesCollector INSTANCE = new PredicatesCollector(); | ||
|
||
@Override | ||
public Void visit(Expression expr, Map<Expression, Set<Expression>> context) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public Void visitOr(Or expr, Map<Expression, Set<Expression>> context) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public Void visitInPredicate(InPredicate inPredicate, Map<Expression, Set<Expression>> context) { | ||
if (!validInPredicate(inPredicate)) { | ||
return null; | ||
} | ||
for (Expression expr : getAllSubExpressions(inPredicate.getCompareExpr())) { | ||
context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(inPredicate); | ||
} | ||
return null; | ||
} | ||
|
||
@Override | ||
public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate, | ||
Map<Expression, Set<Expression>> context) { | ||
if (!validComparisonPredicate(comparisonPredicate)) { | ||
return null; | ||
} | ||
// It is believed that 1<a has been rewritten as a>1 | ||
for (Expression expr : getAllSubExpressions(comparisonPredicate.child(0))) { | ||
context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(comparisonPredicate); | ||
} | ||
return null; | ||
} | ||
|
||
@Override | ||
public Void visitNot(Not not, Map<Expression, Set<Expression>> context) { | ||
if (not.child(0) instanceof InPredicate && validInPredicate((InPredicate) not.child(0)) | ||
|| not.child(0) instanceof ComparisonPredicate | ||
&& validComparisonPredicate((ComparisonPredicate) not.child(0))) { | ||
for (Expression expr : getAllSubExpressions(not.child(0).child(0))) { | ||
context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(not); | ||
} | ||
} | ||
return null; | ||
} | ||
|
||
@Override | ||
public Void visitLike(Like like, Map<Expression, Set<Expression>> context) { | ||
if (!(like.child(1) instanceof Literal)) { | ||
return null; | ||
} | ||
for (Expression expr : getAllSubExpressions(like.child(0))) { | ||
context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(like); | ||
} | ||
return null; | ||
} | ||
|
||
private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) { | ||
return comparisonPredicate.right() instanceof Literal; | ||
} | ||
|
||
private boolean validInPredicate(InPredicate inPredicate) { | ||
return inPredicate.isLiteralChildren(); | ||
} | ||
} | ||
|
||
/* replaceToThis: find all predicates that replaceToThis can deduce (e.g. replaceToThis = b) | ||
equalSet: the equivalent set of replaceToThis (e.g. equalSet: a=b) | ||
exprPredicates: expression and all its corresponding predicates (e.g. such as {a: [a<10, a>1], b: [b in (1, 2)]}) | ||
return: all predicates that replaceToThis can deduce (return b<10, b>1) */ | ||
private static <T extends Expression> Set<Expression> getEqualSetAndDoReplace(T replaceToThis, Set<T> equalSet, | ||
Map<? extends Expression, Set<Expression>> exprPredicates) { | ||
ExpressionAnalyzer analyzer = new ReplaceAnalyzer(null, new Scope(ImmutableList.of()), null, false, false); | ||
Set<Expression> res = new LinkedHashSet<>(); | ||
for (T equals : equalSet) { | ||
Map<Expression, Expression> replaceMap = new HashMap<>(); | ||
replaceMap.put(equals, replaceToThis); | ||
if (!exprPredicates.containsKey(equals)) { | ||
continue; | ||
} | ||
for (Expression predicate : exprPredicates.get(equals)) { | ||
Expression newPredicates = ExpressionUtils.replace(predicate, replaceMap); | ||
try { | ||
Expression analyzed = analyzer.analyze(newPredicates); | ||
res.add(analyzed.withInferred(true)); | ||
} catch (Exception e) { | ||
// has cast error, just not infer and do nothing | ||
} | ||
} | ||
} | ||
return res; | ||
} | ||
|
||
/* Extract the equivalence relationship a=b, and when case (d_tinyint as int)=d_int is encountered, | ||
remove the cast and extract d_tinyint=d_int | ||
EqualPairs is the output parameter and the equivalent pair of predicate derivation input, | ||
which is used to ensure that the derivation | ||
does not generate repeated equivalent conditions, such as a=b and b=a */ | ||
private static ImmutableEqualSet<Expression> findEqual(Set<Expression> inputs) { | ||
ImmutableEqualSet.Builder<Expression> fromCastEqualSetBuilder = new ImmutableEqualSet.Builder<>(); | ||
for (Expression input : inputs) { | ||
if (!(input instanceof EqualTo)) { | ||
continue; | ||
} | ||
EqualTo equalTo = (EqualTo) input; | ||
Set<Slot> leftInputSlots = equalTo.left().getInputSlots(); | ||
Set<Slot> rightInputSlots = equalTo.right().getInputSlots(); | ||
if (leftInputSlots.isEmpty() && rightInputSlots.isEmpty()) { | ||
continue; | ||
} | ||
PredicateInferUtils.getPairFromCast((ComparisonPredicate) input) | ||
.filter(pair -> PredicateInferUtils.isSlotOrLiteral(pair.first) | ||
&& PredicateInferUtils.isSlotOrLiteral(pair.second)) | ||
.filter(pair -> !(pair.first instanceof NullLiteral) && !(pair.second instanceof NullLiteral)) | ||
.ifPresent(pair -> { | ||
Expression left = pair.first; | ||
Expression right = pair.second; | ||
fromCastEqualSetBuilder.addEqualPair(left, right); | ||
}); | ||
} | ||
return fromCastEqualSetBuilder.build(); | ||
} | ||
|
||
/** This is the exposed interface. Inputs are the input predicates for derivation. | ||
* The return value is the derived predicates*/ | ||
public static Set<Expression> infer(Set<Expression> inputs) { | ||
ImmutableEqualSet<Expression> hasCastEqualSet = findEqual(inputs); | ||
Set<Expression> targetExprs = hasCastEqualSet.getAllItemSet(); | ||
if (targetExprs.isEmpty()) { | ||
return new LinkedHashSet<>(inputs); | ||
} | ||
Map<Expression, Set<Expression>> exprPredicates = new HashMap<>(); | ||
for (Expression input : inputs) { | ||
if (input.anyMatch(expr -> !((ExpressionTrait) expr).isDeterministic()) | ||
|| input.getInputSlots().size() != 1) { | ||
continue; | ||
} | ||
input.accept(PredicatesCollector.INSTANCE, exprPredicates); | ||
} | ||
Set<Expression> inferPredicates = new LinkedHashSet<>(inputs); | ||
if (!exprPredicates.isEmpty()) { | ||
for (Expression expr : targetExprs) { | ||
if (expr instanceof Literal) { | ||
continue; | ||
} | ||
inferPredicates.addAll(getEqualSetAndDoReplace(expr, hasCastEqualSet.calEqualSet(expr), | ||
exprPredicates)); | ||
} | ||
} | ||
return inferPredicates; | ||
} | ||
|
||
/** ReplaceAnalyzer is to perform type conversion on the expression after replacement | ||
* and perform type check on the expression. | ||
* If there is a cast that will cause an error during execution, an exception should be thrown. */ | ||
private static class ReplaceAnalyzer extends ExpressionAnalyzer { | ||
private ReplaceAnalyzer(Plan currentPlan, Scope scope, | ||
@Nullable CascadesContext cascadesContext, | ||
boolean enableExactMatch, boolean bindSlotInOuterScope) { | ||
super(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope); | ||
} | ||
|
||
@Override | ||
public Expression visitCast(Cast cast, ExpressionRewriteContext context) { | ||
cast = (Cast) super.visitCast(cast, context); | ||
if (cast.getDataType().isDecimalV3Type()) { | ||
DecimalV3Type targetType = (DecimalV3Type) cast.getDataType(); | ||
DecimalV3Type childType = DecimalV3Type.forType(cast.child().getDataType()); | ||
if ((childType.getPrecision() - childType.getScale()) | ||
> (targetType.getPrecision() - targetType.getScale()) | ||
|| childType.getScale() > targetType.getScale()) { | ||
throw new AnalysisException("can not cast from origin type " + cast.child().getDataType() | ||
+ " to target type=" + targetType); | ||
} | ||
} else if (cast.getDataType().isDecimalV2Type()) { | ||
DecimalV2Type targetType = (DecimalV2Type) cast.getDataType(); | ||
DecimalV2Type childType = DecimalV2Type.forType(cast.child().getDataType()); | ||
if ((childType.getPrecision() - childType.getScale()) | ||
> (targetType.getPrecision() - targetType.getScale()) | ||
|| childType.getScale() > targetType.getScale()) { | ||
throw new AnalysisException("can not cast from origin type " + cast.child().getDataType() | ||
+ " to target type=" + targetType); | ||
} | ||
} | ||
return cast; | ||
} | ||
} | ||
} |
Oops, something went wrong.