
[SPARK-32444][SQL] Infer filters from DPP #29243

Closed
@@ -74,12 +74,14 @@ object SubqueryExpression {
   }
 
   /**
-   * Returns true when an expression contains a subquery that has outer reference(s). The outer
+   * Returns true when an expression contains a subquery that has outer reference(s), except
+   * for [[org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery]]. The outer
    * reference attributes are kept as children of subquery expression by
    * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]]
    */
   def hasCorrelatedSubquery(e: Expression): Boolean = {
     e.find {
+      case _: DynamicPruningSubquery => false
       case s: SubqueryExpression => s.children.nonEmpty
       case _ => false
     }.isDefined
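To see what the new case changes, here is a minimal, self-contained REPL sketch of the match (toy stand-ins for Catalyst's classes, not the real API): a DynamicPruningSubquery keeps its pruning key as a child expression, so without the dedicated case it would be misreported as correlated.

```scala
// Toy classes mirroring the shape of the pattern match above.
sealed trait Expression {
  def children: Seq[Expression]
  def find(f: Expression => Boolean): Option[Expression] =
    if (f(this)) Some(this) else children.flatMap(_.find(f)).headOption
}
case object Leaf extends Expression { val children: Seq[Expression] = Nil }
sealed abstract class SubqueryExpression extends Expression
case class ScalarSubquery(children: Seq[Expression]) extends SubqueryExpression
case class DynamicPruningSubquery(pruningKey: Expression) extends SubqueryExpression {
  val children: Seq[Expression] = Seq(pruningKey)
}

def hasCorrelatedSubquery(e: Expression): Boolean =
  e.find {
    case _: DynamicPruningSubquery => false // a DPP reference, not a correlation
    case s: SubqueryExpression => s.children.nonEmpty
    case _ => false
  }.isDefined

// The DPP subquery holds an outer attribute but is no longer flagged:
assert(!hasCorrelatedSubquery(DynamicPruningSubquery(Leaf)))
assert(hasCorrelatedSubquery(ScalarSubquery(Seq(Leaf))))
```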
@@ -58,10 +58,16 @@ trait ConstraintHelper {
   * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
   * additional constraint of the form `b = 5`.
   */
-  def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
+  def inferAdditionalConstraints(
+      constraints: ExpressionSet,
+      isInferDynamicPruning: Boolean = false): ExpressionSet = {
     var inferredConstraints = ExpressionSet()
     // IsNotNull should be constructed by `constructIsNotNullConstraints`.
-    val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
+    val predicates = if (isInferDynamicPruning) {
+      constraints.filterNot(_.isInstanceOf[IsNotNull])
+    } else {
+      constraints.filterNot(e => e.isInstanceOf[IsNotNull] || e.isInstanceOf[DynamicPruning])
+    }
     predicates.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         val candidateConstraints = predicates - eq
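The scaladoc example above ({`a = 5`, `a = b`} yields `b = 5`) can be sketched with toy types to show the substitution this loop performs (hypothetical `Attr`/`Pred` types, not Catalyst code):

```scala
case class Attr(name: String)
sealed trait Pred
case class EqAttr(l: Attr, r: Attr) extends Pred // a = b
case class EqLit(a: Attr, v: Int) extends Pred   // a = 5

def inferAdditional(preds: Set[Pred]): Set[Pred] = {
  val inferred = preds.flatMap {
    case eq @ EqAttr(l, r) =>
      // Rewrite every other predicate by substituting one side for the other.
      (preds - eq).flatMap {
        case EqLit(`l`, v) => Set[Pred](EqLit(r, v))
        case EqLit(`r`, v) => Set[Pred](EqLit(l, v))
        case _ => Set.empty[Pred]
      }
    case _ => Set.empty[Pred]
  }
  inferred -- preds // report only the genuinely new constraints
}

val (a, b) = (Attr("a"), Attr("b"))
assert(inferAdditional(Set[Pred](EqAttr(a, b), EqLit(a, 5))) == Set[Pred](EqLit(b, 5)))
```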
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.SchemaPruning
 import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes}
-import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
+import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, InferDynamicPruningFilters, PartitionPruning}
 import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
 
 class SparkOptimizer(
@@ -44,7 +44,11 @@ class SparkOptimizer(
Batch("PartitionPruning", Once,
PartitionPruning,
OptimizeSubqueries) :+
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
Batch("Pushdown Filters from PartitionPruning before Inferring Filters", fixedPoint,
PushDownPredicates) :+
Batch("Infer Filters from PartitionPruning", Once,
InferDynamicPruningFilters) :+
Batch("Pushdown Filters from PartitionPruning after Inferring Filters", fixedPoint,
PushDownPredicates) :+
Batch("Cleanup filters that cannot be pushed down", Once,
CleanupDynamicPruningFilters,
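The batch ordering matters here: the first PushDownPredicates batch moves the DPP filter produced by PartitionPruning down next to the join so it appears among the join's constraints, InferDynamicPruningFilters then copies it to the other side, and the second PushDownPredicates batch sinks the inferred copy toward its scan. As a rough sketch of the `Once` vs. `fixedPoint` strategies used above (a toy executor for a REPL, not Spark's RuleExecutor):

```scala
// A "plan" is just a List[Int]; rules rewrite it.
type Plan = List[Int]
type Rule = Plan => Plan

sealed trait Strategy
case object Once extends Strategy                      // apply each rule one time
final case class FixedPoint(maxIters: Int) extends Strategy // rerun until stable

final case class Batch(name: String, strategy: Strategy, rules: Rule*)

def execute(batches: Seq[Batch], plan: Plan): Plan =
  batches.foldLeft(plan) { (p, batch) =>
    def runAll(q: Plan): Plan = batch.rules.foldLeft(q)((acc, rule) => rule(acc))
    batch.strategy match {
      case Once => runAll(p)
      case FixedPoint(max) =>
        // Iterate the batch until the plan stops changing (or we hit max).
        var cur = p
        var iters = 0
        var changed = true
        while (changed && iters < max) {
          val next = runAll(cur)
          changed = next != cur
          cur = next
          iters += 1
        }
        cur
    }
  }

// A rule that must run to a fixed point, like repeated predicate pushdown:
val halveEvens: Rule = _.map(n => if (n % 2 == 0) n / 2 else n)
assert(execute(Seq(Batch("halve", FixedPoint(100), halveEvens)), List(8, 3)) == List(1, 3))
```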
@@ -0,0 +1,95 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.dynamicpruning

import org.apache.spark.sql.catalyst.expressions.{And, DynamicPruningSubquery, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.{InnerLike, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{ConstraintHelper, Filter, Join, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.dynamicpruning.PartitionPruning._
import org.apache.spark.sql.internal.SQLConf

/**
 * Similar to InferFiltersFromConstraints, but this rule only infers DynamicPruning filters.
 */
object InferDynamicPruningFilters extends Rule[LogicalPlan]
  with PredicateHelper with ConstraintHelper {

  def apply(plan: LogicalPlan): LogicalPlan = {
    if (SQLConf.get.constraintPropagationEnabled) {
      inferFilters(plan)
    } else {
      plan
    }
  }

  private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
    case join @ Join(left, right, joinType, _, _) =>
      joinType match {
        // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
        // inner join, it just drops the right side in the final output.
        case _: InnerLike | LeftSemi =>
          val allConstraints = inferDynamicPrunings(join)
          val newLeft = inferNewFilter(left, allConstraints)
          val newRight = inferNewFilter(right, allConstraints)
          join.copy(left = newLeft, right = newRight)

        // For right outer join, we can only infer additional filters for the left side.
        case RightOuter =>
          val allConstraints = inferDynamicPrunings(join)
          val newLeft = inferNewFilter(left, allConstraints)
          join.copy(left = newLeft)

        // For left outer and left anti join, we can only infer additional filters for the
        // right side.
        case LeftOuter | LeftAnti =>
          val allConstraints = inferDynamicPrunings(join)
          val newRight = inferNewFilter(right, allConstraints)
          join.copy(right = newRight)

        case _ => join
      }
  }

  def inferDynamicPrunings(join: Join): ExpressionSet = {
    val baseConstraints = join.left.constraints.union(join.right.constraints)
      .union(ExpressionSet(join.condition.map(splitConjunctivePredicates).getOrElse(Nil)))
    // Keep only inferred DynamicPruning filters that prune a partitioned scan
    // and are estimated to have benefit.
    inferAdditionalConstraints(baseConstraints, isInferDynamicPruning = true).filter {
      case DynamicPruningSubquery(
          pruningKey, buildQuery, buildKeys, broadcastKeyIndex, _, _) =>
        getPartitionTableScan(pruningKey, join) match {
          case Some(partScan) =>
            pruningHasBenefit(pruningKey, partScan, buildKeys(broadcastKeyIndex), buildQuery)
          case _ =>
            false
        }
      case _ => false
    }
  }

  private def inferNewFilter(plan: LogicalPlan, dynamicPrunings: ExpressionSet): LogicalPlan = {
    val newPredicates = dynamicPrunings
      .filter { c =>
        c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
      } -- plan.constraints
    if (newPredicates.isEmpty) {
      plan
    } else {
      Filter(newPredicates.reduce(And), plan)
    }
  }
}
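The guard in `inferNewFilter` is what decides which side of the join an inferred filter may be attached to. A toy model of that check (hypothetical `Predicate` type, mirroring the conditions above):

```scala
case class Predicate(references: Set[String], deterministic: Boolean = true)

def inferNewFilterSketch(
    output: Set[String],      // attributes produced by this side of the join
    known: Set[Predicate],    // constraints the side already has
    inferred: Set[Predicate]  // candidate DynamicPruning predicates
): Set[Predicate] =
  inferred.filter { p =>
    p.references.nonEmpty && p.references.subsetOf(output) && p.deterministic
  } -- known

// A pruning predicate over df2.k can be attached to the side producing df2.k...
assert(inferNewFilterSketch(Set("df2.k"), Set.empty, Set(Predicate(Set("df2.k")))).nonEmpty)
// ...but not to a side that does not produce that attribute.
assert(inferNewFilterSketch(Set("df1.k"), Set.empty, Set(Predicate(Set("df2.k")))).isEmpty)
```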
@@ -112,7 +112,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper {
    * using column statistics if they are available, otherwise we use the config value of
    * `spark.sql.optimizer.joinFilterRatio`.
    */
-  private def pruningHasBenefit(
+  private[sql] def pruningHasBenefit(
       partExpr: Expression,
       partPlan: LogicalPlan,
       otherExpr: Expression,
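As a back-of-the-envelope sketch of the trade-off `pruningHasBenefit` estimates (hypothetical signature and numbers; the real method works on Catalyst plans and their statistics):

```scala
// Hypothetical cost sketch: pruning pays off when the bytes it is expected to
// skip in the partitioned scan exceed the estimated size of the build side
// that must be evaluated to produce the pruning values.
def pruningHasBenefitSketch(
    partScanBytes: Long,  // estimated size of the partitioned table scan
    filterRatio: Double,  // expected fraction pruned, from column statistics
                          // or a configured fallback ratio
    buildSideBytes: Long  // estimated size of the filtering (build) side plan
): Boolean =
  partScanBytes * filterRatio > buildSideBytes

// Skipping ~80% of a 10 GB scan beats paying for a 100 MB build-side query:
assert(pruningHasBenefitSketch(10L << 30, 0.8, 100L << 20))
```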
@@ -1388,6 +1388,111 @@ abstract class DynamicPartitionPruningSuiteBase
      checkAnswer(df, Nil)
    }
  }

  test("Infer filters from DPP") {
    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
      withTable("df1", "df2", "df3", "df4", "df5") {
        spark.range(1000)
          .select(col("id"), col("id").as("k"))
          .write
          .partitionBy("k")
          .format(tableFormat)
          .mode("overwrite")
          .saveAsTable("df1")

        spark.range(1000)
          .select(col("id"), col("id").as("k"))
          .write
          .partitionBy("k")
          .format(tableFormat)
          .mode("overwrite")
          .saveAsTable("df2")

        spark.range(5)
          .select(col("id"), col("id").as("k"))
          .write
          .partitionBy("k")
          .format(tableFormat)
          .mode("overwrite")
          .saveAsTable("df3")

        spark.range(100)
          .select(col("id"), col("id").as("k"))
          .write
          .format(tableFormat)
          .mode("overwrite")
          .saveAsTable("df4")

        spark.range(1000)
          .select(col("id"), col("id").as("k"))
          .write
          .format(tableFormat)
          .mode("overwrite")
          .saveAsTable("df5")

        Given("Inferred DPP on partition column")
        Seq(true, false).foreach { infer =>
          withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> s"$infer") {
            val df = sql(
              """
                |SELECT t1.id,
                |       df4.k
                |FROM   (SELECT df2.id,
                |               df1.k
                |        FROM   df1
                |               JOIN df2
                |                 ON df1.k = df2.k) t1
                |       JOIN df4
                |         ON t1.k = df4.k AND df4.id < 2
                |""".stripMargin)
            if (infer) {
              assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 2)
            } else {
              assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 1)
            }
            checkAnswer(df, Row(0, 0) :: Row(1, 1) :: Nil)
          }
        }

        Given("Remove inferred DPP on partition column when it has no benefit")
        withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
          val df = sql(
            """
              |SELECT t1.id,
              |       df4.k
              |FROM   (SELECT df3.id,
              |               df1.k
              |        FROM   df1
              |               JOIN df3
              |                 ON df1.k = df3.k) t1
              |       JOIN df4
              |         ON t1.k = df4.k AND df4.id < 2
              |""".stripMargin)
          assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 1)
          checkAnswer(df, Row(0, 0) :: Row(1, 1) :: Nil)
        }

        Given("Remove inferred DPP on non-partition column")
        withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
          val df = sql(
            """
              |SELECT t1.id,
              |       df4.k
              |FROM   (SELECT df5.id,
              |               df1.k
              |        FROM   df1
              |               JOIN df5
              |                 ON df1.k = df5.k) t1
              |       JOIN df4
              |         ON t1.k = df4.k AND df4.id < 2
              |""".stripMargin)

          assert(collectDynamicPruningExpressions(df.queryExecution.executedPlan).size === 1)
          checkAnswer(df, Row(0, 0) :: Row(1, 1) :: Nil)
        }
      }
    }
  }
}

class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase {