diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index e8c95e3a625a2..4e2727a87585d 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -189,11 +189,21 @@ "Remote client cannot create a SparkContext. Create SparkSession instead." ] }, + "DATA_SOURCE_EXTRANEOUS_FILTERS": { + "message": [ + ".pushFilters() returned filters that are not part of the input. Make sure that each returned filter is one of the input filters by reference." + ] + }, "DATA_SOURCE_INVALID_RETURN_TYPE": { "message": [ "Unsupported return type ('') from Python data source ''. Expected types: ." ] }, + "DATA_SOURCE_PUSHDOWN_DISABLED": { + "message": [ + " implements pushFilters() but filter pushdown is disabled because configuration '' is false. Set it to true to enable filter pushdown." + ] + }, "DATA_SOURCE_RETURN_SCHEMA_MISMATCH": { "message": [ "Return schema mismatch in the result from 'read' method. Expected: columns, Found: columns. Make sure the returned values match the required output schema." @@ -204,6 +214,11 @@ "Expected , but got ." ] }, + "DATA_SOURCE_UNSUPPORTED_FILTER": { + "message": [ + "Unexpected filter ." + ] + }, "DIFFERENT_PANDAS_DATAFRAME": { "message": [ "DataFrames are not almost equal:", diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index 651e84e84390e..2c6aad48bcb3f 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -16,7 +16,20 @@ # from abc import ABC, abstractmethod from collections import UserDict -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING +from dataclasses import dataclass +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + TYPE_CHECKING, +) from pyspark.sql import Row from pyspark.sql.types import StructType @@ -38,6 +51,8 @@ "InputPartition", "SimpleDataSourceStreamReader", "WriterCommitMessage", + "Filter", + "EqualTo", ] @@ -234,6 +249,69 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader": ) +ColumnPath = Tuple[str, ...] +""" +A tuple of strings representing a column reference. + +For example, `("a", "b", "c")` represents the column `a.b.c`. + +.. versionadded: 4.1.0 +""" + + +@dataclass(frozen=True) +class Filter(ABC): + """ + The base class for filters used for filter pushdown. + + .. versionadded: 4.1.0 + + Notes + ----- + Column references are represented as a tuple of strings. For example: + + +----------------+----------------------+ + | Column | Representation | + +----------------+----------------------+ + | `col1` | `("col1",)` | + | `a.b.c` | `("a", "b", "c")` | + +----------------+----------------------+ + + Literal values are represented as Python objects of types such as + `int`, `float`, `str`, `bool`, `datetime`, etc. + See `Data Types `_ + for more information about how values are represented in Python. + + Currently only the equality of attribute and literal value is supported for + filter pushdown. Other types of filters cannot be pushed down. 
+ + Examples + -------- + Supported filters + + +---------------------+--------------------------------------------+ + | SQL filter | Representation | + +---------------------+--------------------------------------------+ + | `a.b.c = 1` | `EqualTo(("a", "b", "c"), 1)` | + | `a = 1` | `EqualTo(("a", "b", "c"), 1)` | + | `a = 'hi'` | `EqualTo(("a",), "hi")` | + | `a = array(1, 2)` | `EqualTo(("a",), [1, 2])` | + +---------------------+--------------------------------------------+ + + Unsupported filters + - `a = b` + - `f(a, b) = 1` + - `a % 2 = 1` + - `a[0] = 1` + """ + + +@dataclass(frozen=True) +class EqualTo(Filter): + attribute: ColumnPath + value: Any + + class InputPartition: """ A base class representing an input partition returned by the `partitions()` @@ -280,6 +358,67 @@ class DataSourceReader(ABC): .. versionadded: 4.0.0 """ + def pushFilters(self, filters: List["Filter"]) -> Iterable["Filter"]: + """ + Called with the list of filters that can be pushed down to the data source. + + The list of filters should be interpreted as the AND of the elements. + + Filter pushdown allows the data source to handle a subset of filters. This + can improve performance by reducing the amount of data that needs to be + processed by Spark. + + This method is called once during query planning. By default, it returns + all filters, indicating that no filters can be pushed down. Subclasses can + override this method to implement filter pushdown. + + It's recommended to implement this method only for data sources that natively + support filtering, such as databases and GraphQL APIs. + + .. versionadded: 4.1.0 + + Parameters + ---------- + filters : list of :class:`Filter`\\s + + Returns + ------- + iterable of :class:`Filter`\\s + Filters that still need to be evaluated by Spark post the data source + scan. This includes unsupported filters and partially pushed filters. + Every returned filter must be one of the input filters by reference. + + Side effects + ------------ + This method is allowed to modify `self`. The object must remain picklable. + Modifications to `self` are visible to the `partitions()` and `read()` methods. + + Examples + -------- + Example filters and the resulting arguments passed to pushFilters: + + +-------------------------------+---------------------------------------------+ + | Filters | Pushdown Arguments | + +-------------------------------+---------------------------------------------+ + | `a = 1 and b = 2` | `[EqualTo(("a",), 1), EqualTo(("b",), 2)]` | + | `a = 1 or b = 2` | `[]` | + | `a = 1 or (b = 2 and c = 3)` | `[]` | + | `a = 1 and (b = 2 or c = 3)` | `[EqualTo(("a",), 1)]` | + +-------------------------------+---------------------------------------------+ + + Implement pushFilters to support EqualTo filters only: + + >>> def pushFilters(self, filters): + ... for filter in filters: + ... if isinstance(filter, EqualTo): + ... # Save supported filter for handling in partitions() and read() + ... self.filters.append(filter) + ... else: + ... # Unsupported filter + ... yield filter + """ + return filters + def partitions(self) -> Sequence[InputPartition]: """ Returns an iterator of partitions for this data source. 
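To make the new reader-side contract above concrete, here is a minimal sketch of a data source that consumes it: keep the filters you can evaluate, yield back the ones you cannot. The class names, column names, and in-memory rows are illustrative only and are not part of this patch.

```python
from typing import Iterable, Iterator, List, Tuple

from pyspark.sql.datasource import DataSource, DataSourceReader, EqualTo, Filter


class CommentsReader(DataSourceReader):  # hypothetical example reader
    def __init__(self) -> None:
        # EqualTo filters this reader has agreed to evaluate itself.
        self.pushed: List[EqualTo] = []

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            if isinstance(f, EqualTo) and len(f.attribute) == 1:
                # Remember the filter; the mutated reader is pickled after pushdown,
                # so self.pushed is visible later in read().
                self.pushed.append(f)
            else:
                # Everything else goes back to Spark and stays in the post-scan Filter.
                yield f

    def read(self, partition) -> Iterator[Tuple]:
        names = ["id", "label"]
        rows = [(1, "a"), (1, "b"), (2, "c")]  # illustrative in-memory data
        for row in rows:
            # Apply the filters we claimed to support before handing rows to Spark.
            if all(row[names.index(f.attribute[0])] == f.value for f in self.pushed):
                yield row


class CommentsDataSource(DataSource):  # hypothetical
    @classmethod
    def name(cls) -> str:
        return "comments"

    def schema(self) -> str:
        return "id int, label string"

    def reader(self, schema) -> DataSourceReader:
        return CommentsReader()
```

After `spark.dataSource.register(CommentsDataSource)` and with `spark.sql.python.filterPushdown.enabled` set to `true`, a query such as `spark.read.format("comments").load().filter("id = 1")` hands `EqualTo(("id",), 1)` to `pushFilters`; the pushed predicate then no longer appears in the post-scan `FilterExec` and shows up under `PushedFilters` in the scan's explain output instead.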
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index 34299bdb7740c..c5723c5c8b846 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -18,12 +18,14 @@ import platform import tempfile import unittest -from typing import Callable, Union +from typing import Callable, Iterable, List, Union from pyspark.errors import PythonException, AnalysisException from pyspark.sql.datasource import ( DataSource, DataSourceReader, + EqualTo, + Filter, InputPartition, DataSourceWriter, DataSourceArrowWriter, @@ -31,6 +33,7 @@ CaseInsensitiveDict, ) from pyspark.sql.functions import spark_partition_id +from pyspark.sql.session import SparkSession from pyspark.sql.types import Row, StructType from pyspark.testing.sqlutils import ( have_pyarrow, @@ -42,6 +45,8 @@ @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) class BasePythonDataSourceTestsMixin: + spark: SparkSession + def test_basic_data_source_class(self): class MyDataSource(DataSource): ... @@ -246,6 +251,161 @@ def reader(self, schema) -> "DataSourceReader": assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")]) self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2) + def test_filter_pushdown(self): + class TestDataSourceReader(DataSourceReader): + def __init__(self): + self.has_filter = False + + def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]: + assert set(filters) == { + EqualTo(("x",), 1), + EqualTo(("y",), 2), + }, filters + self.has_filter = True + # pretend we support x = 1 filter but in fact we don't + # so we only return y = 2 filter + yield filters[filters.index(EqualTo(("y",), 2))] + + def partitions(self): + assert self.has_filter + return super().partitions() + + def read(self, partition): + assert self.has_filter + yield [1, 1] + yield [1, 2] + yield [2, 2] + + class TestDataSource(DataSource): + @classmethod + def name(cls): + return "test" + + def schema(self): + return "x int, y int" + + def reader(self, schema) -> "DataSourceReader": + return TestDataSourceReader() + + with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}): + self.spark.dataSource.register(TestDataSource) + df = self.spark.read.format("test").load().filter("x = 1 and y = 2") + # only the y = 2 filter is applied post scan + assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)]) + + def test_extraneous_filter(self): + class TestDataSourceReader(DataSourceReader): + def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]: + yield EqualTo(("x",), 1) + + def partitions(self): + assert False + + def read(self, partition): + assert False + + class TestDataSource(DataSource): + @classmethod + def name(cls): + return "test" + + def schema(self): + return "x int" + + def reader(self, schema) -> "DataSourceReader": + return TestDataSourceReader() + + with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}): + self.spark.dataSource.register(TestDataSource) + with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"): + self.spark.read.format("test").load().filter("x = 1").show() + + def test_filter_pushdown_error(self): + error_str = "dummy error" + + class TestDataSourceReader(DataSourceReader): + def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]: + raise Exception(error_str) + + def read(self, partition): + yield [1] + + class TestDataSource(DataSource): + def schema(self): + return "x int" + + def reader(self, schema) 
-> "DataSourceReader": + return TestDataSourceReader() + + with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}): + self.spark.dataSource.register(TestDataSource) + df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null") + assertDataFrameEqual(df, [Row(x=1)]) # works when not pushing down filters + with self.assertRaisesRegex(Exception, error_str): + df.filter("x = 1").explain() + + def test_filter_pushdown_disabled(self): + class TestDataSourceReader(DataSourceReader): + def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]: + assert False + + def read(self, partition): + assert False + + class TestDataSource(DataSource): + def reader(self, schema) -> "DataSourceReader": + return TestDataSourceReader() + + with self.sql_conf({"spark.sql.python.filterPushdown.enabled": False}): + self.spark.dataSource.register(TestDataSource) + df = self.spark.read.format("TestDataSource").schema("x int").load() + with self.assertRaisesRegex(Exception, "DATA_SOURCE_PUSHDOWN_DISABLED"): + df.show() + + def _check_filters(self, sql_type, sql_filter, python_filters): + """ + Parameters + ---------- + sql_type: str + The SQL type of the column x. + sql_filter: str + A SQL filter using the column x. + python_filters: List[Filter] + The expected python filters to be pushed down. + """ + + class TestDataSourceReader(DataSourceReader): + def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]: + expected = python_filters + assert filters == expected, (filters, expected) + return filters + + def read(self, partition): + yield from [] + + class TestDataSource(DataSource): + def schema(self): + return f"x {sql_type}" + + def reader(self, schema) -> "DataSourceReader": + return TestDataSourceReader() + + with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}): + self.spark.dataSource.register(TestDataSource) + df = self.spark.read.format("TestDataSource").load().filter(sql_filter) + df.count() + + def test_unsupported_filter(self): + self._check_filters( + "struct", "x.a = 1 and x.b = x.c", [EqualTo(("x", "a"), 1)] + ) + self._check_filters("int", "x <> 0", []) + self._check_filters("int", "x = 1 or x > 2", []) + self._check_filters("int", "(0 < x and x < 1) or x = 2", []) + self._check_filters("int", "x % 5 = 1", []) + self._check_filters("boolean", "not x", []) + self._check_filters("array", "x[0] = 1", []) + def _get_test_json_data_source(self): import json import os diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py b/python/pyspark/sql/worker/data_source_pushdown_filters.py new file mode 100644 index 0000000000000..3aceb1b1ff08e --- /dev/null +++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py @@ -0,0 +1,206 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import faulthandler +import os +import sys +from dataclasses import dataclass, field +from typing import IO, List + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.errors import PySparkAssertionError, PySparkValueError +from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, write_int +from pyspark.sql.datasource import DataSource, DataSourceReader, EqualTo, Filter +from pyspark.sql.types import StructType, _parse_datatype_json_string +from pyspark.util import handle_worker_exception, local_connect_and_auth +from pyspark.worker_util import ( + check_python_version, + pickleSer, + read_command, + send_accumulator_updates, + setup_broadcasts, + setup_memory_limits, + setup_spark_files, +) + +utf8_deserializer = UTF8Deserializer() + + +@dataclass(frozen=True) +class FilterRef: + filter: Filter = field(compare=False) + id: int = field(init=False) # only id is used for comparison + + def __post_init__(self) -> None: + object.__setattr__(self, "id", id(self.filter)) + + +def main(infile: IO, outfile: IO) -> None: + """ + Main method for planning a data source read with filter pushdown. + + This process is invoked from the `UserDefinedPythonDataSourceReadRunner.runInPython` + method in the optimizer rule `PlanPythonDataSourceScan` in JVM. This process is responsible + for creating a `DataSourceReader` object, applying filter pushdown, and sending the + information needed back to the JVM. + + The infile and outfile are connected to the JVM via a socket. The JVM sends the following + information to this process via the socket: + - a `DataSource` instance representing the data source + - a `StructType` instance representing the output schema of the data source + - a list of filters to be pushed down + + This process then creates a `DataSourceReader` instance by calling the `reader` method + on the `DataSource` instance. It applies the filters by calling the `pushFilters` method + on the reader and determines which filters are supported. The data source with updated reader + is then sent back to the JVM along with the indices of the supported filters. + """ + faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) + try: + if faulthandler_log_path: + faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid())) + faulthandler_log_file = open(faulthandler_log_path, "w") + faulthandler.enable(file=faulthandler_log_file) + + check_python_version(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + + setup_spark_files(infile) + setup_broadcasts(infile) + + _accumulatorRegistry.clear() + + # ---------------------------------------------------------------------- + # Start of worker logic + # ---------------------------------------------------------------------- + + # Receive the data source instance. + data_source = read_command(pickleSer, infile) + if not isinstance(data_source, DataSource): + raise PySparkAssertionError( + errorClass="DATA_SOURCE_TYPE_MISMATCH", + messageParameters={ + "expected": "a Python data source instance of type 'DataSource'", + "actual": f"'{type(data_source).__name__}'", + }, + ) + + # Receive the data source output schema. 
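+        # The JVM sends `schema.json`; parse the JSON string back into a StructType.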
+ schema_json = utf8_deserializer.loads(infile) + schema = _parse_datatype_json_string(schema_json) + if not isinstance(schema, StructType): + raise PySparkAssertionError( + errorClass="DATA_SOURCE_TYPE_MISMATCH", + messageParameters={ + "expected": "an output schema of type 'StructType'", + "actual": f"'{type(schema).__name__}'", + }, + ) + + # Get the reader. + reader = data_source.reader(schema=schema) + # Validate the reader. + if not isinstance(reader, DataSourceReader): + raise PySparkAssertionError( + errorClass="DATA_SOURCE_TYPE_MISMATCH", + messageParameters={ + "expected": "an instance of DataSourceReader", + "actual": f"'{type(reader).__name__}'", + }, + ) + + # Receive the pushdown filters. + num_filters = read_int(infile) + filters: List[FilterRef] = [] + for _ in range(num_filters): + name = utf8_deserializer.loads(infile) + if name == "EqualTo": + num_parts = read_int(infile) + column_path = tuple(utf8_deserializer.loads(infile) for _ in range(num_parts)) + value = read_int(infile) + filters.append(FilterRef(EqualTo(column_path, value))) + else: + raise PySparkAssertionError( + errorClass="DATA_SOURCE_UNSUPPORTED_FILTER", + messageParameters={ + "name": name, + }, + ) + + # Push down the filters and get the indices of the unsupported filters. + unsupported_filters = set( + FilterRef(f) for f in reader.pushFilters([ref.filter for ref in filters]) + ) + supported_filter_indices = [] + for i, filter in enumerate(filters): + if filter in unsupported_filters: + unsupported_filters.remove(filter) + else: + supported_filter_indices.append(i) + + # If it returned any filters that are not in the original filters, raise an error. + if len(unsupported_filters) > 0: + raise PySparkValueError( + errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS", + messageParameters={ + "type": type(reader).__name__, + "input": str(list(filters)), + "extraneous": str(list(unsupported_filters)), + }, + ) + + # Monkey patch the data source instance + # to return the existing reader with the pushed down filters. + data_source.reader = lambda schema: reader # type: ignore[method-assign] + pickleSer._write_with_length(data_source, outfile) + + # Return the supported filter indices. + write_int(len(supported_filter_indices), outfile) + for index in supported_filter_indices: + write_int(index, outfile) + + # ---------------------------------------------------------------------- + # End of worker logic + # ---------------------------------------------------------------------- + except BaseException as e: + handle_worker_exception(e, outfile) + sys.exit(-1) + finally: + if faulthandler_log_path: + faulthandler.disable() + faulthandler_log_file.close() + os.remove(faulthandler_log_path) + + send_accumulator_updates(outfile) + + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + sys.exit(-1) + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. 
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index 4c6fd4c0a77c3..8e6dd1626785a 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -250,6 +250,7 @@ def main(infile: IO, outfile: IO) -> None: "The maximum arrow batch size should be greater than 0, but got " f"'{max_arrow_batch_size}'" ) + enable_pushdown = read_bool(infile) is_streaming = read_bool(infile) @@ -269,6 +270,19 @@ def main(infile: IO, outfile: IO) -> None: "actual": f"'{type(reader).__name__}'", }, ) + is_pushdown_implemented = ( + getattr(reader.pushFilters, "__func__", None) is not DataSourceReader.pushFilters + ) + if is_pushdown_implemented and not enable_pushdown: + # Do not silently ignore pushFilters when pushdown is disabled. + # Raise an error to ask the user to enable pushdown. + raise PySparkAssertionError( + errorClass="DATA_SOURCE_PUSHDOWN_DISABLED", + messageParameters={ + "type": type(reader).__name__, + "conf": "spark.sql.python.filterPushdown.enabled", + }, + ) # Create input converter. converter = ArrowTableToRowsConversion._create_converter(BinaryType()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ca7d8ce037931..34ec24098748c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4716,6 +4716,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PYTHON_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.python.filterPushdown.enabled") + .doc("When true, enable filter pushdown to Python datasource, at the cost of running " + + "Python worker one additional time during planning.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + val CSV_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.csv.filterPushdown.enabled") .doc("When true, enable filter pushdown to CSV datasource.") .version("3.0.0") @@ -6501,6 +6508,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def useListFilesFileSystemList: String = getConf(SQLConf.USE_LISTFILES_FILESYSTEM_LIST) + def pythonFilterPushDown: Boolean = getConf(PYTHON_FILTER_PUSHDOWN_ENABLED) + def csvFilterPushDown: Boolean = getConf(CSV_FILTER_PUSHDOWN_ENABLED) def jsonFilterPushDown: Boolean = getConf(JSON_FILTER_PUSHDOWN_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala index edea702587791..f5a1d41e52e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala @@ -52,6 +52,10 @@ class PythonDataSourceV2 extends TableProvider { dataSourceInPython } + def setDataSourceInPython(dataSourceInPython: PythonDataSourceCreationResult): Unit = { + this.dataSourceInPython = dataSourceInPython + } + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { getOrCreateDataSourceInPython(shortName, options, 
None).schema } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala index 8ebb91c01fc5c..d53b47bf462ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala @@ -20,16 +20,18 @@ import org.apache.spark.JobArtifactSet import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.streaming.MicroBatchStream +import org.apache.spark.sql.internal.connector.SupportsMetadata +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap - class PythonScan( - ds: PythonDataSourceV2, - shortName: String, - outputSchema: StructType, - options: CaseInsensitiveStringMap) extends Scan { - + ds: PythonDataSourceV2, + shortName: String, + outputSchema: StructType, + options: CaseInsensitiveStringMap, + supportedFilters: Array[Filter] +) extends Scan with SupportsMetadata { override def toBatch: Batch = new PythonBatch(ds, shortName, outputSchema, options) override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = @@ -44,6 +46,13 @@ class PythonScan( override def columnarSupportMode(): Scan.ColumnarSupportMode = Scan.ColumnarSupportMode.UNSUPPORTED + + override def getMetaData(): Map[String, String] = { + Map( + "PushedFilters" -> supportedFilters.mkString("[", ", ", "]"), + "ReadSchema" -> outputSchema.simpleString + ) + } } class PythonBatch( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala index e30fc9f7978cb..ad9ffd25cc45b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.sql.execution.datasources.v2.python -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -25,6 +27,34 @@ class PythonScanBuilder( ds: PythonDataSourceV2, shortName: String, outputSchema: StructType, - options: CaseInsensitiveStringMap) extends ScanBuilder { - override def build(): Scan = new PythonScan(ds, shortName, outputSchema, options) + options: CaseInsensitiveStringMap) + extends ScanBuilder + with SupportsPushDownFilters { + private var supportedFilters: Array[Filter] = Array.empty + + override def build(): Scan = + new PythonScan(ds, shortName, outputSchema, options, supportedFilters) + + // Optionally called by DSv2 once to push down filters before the scan is built. 
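+  // Returns the filters that could not be pushed down; Spark evaluates those on top of the scan.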
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = { + if (!SQLConf.get.pythonFilterPushDown) { + return filters + } + + val dataSource = ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)) + val result = ds.source.pushdownFiltersInPython(dataSource, outputSchema, filters) + + // The Data Source instance state changes after pushdown to remember the reader instance + // created and the filters pushed down. So pushdownFiltersInPython returns a new pickled + // Data Source instance. We need to use that new instance for further operations. + ds.setDataSourceInPython(dataSource.copy(dataSource = result.dataSource)) + + // Partition the filters into supported and unsupported ones. + val isPushed = result.isFilterPushed.zip(filters) + supportedFilters = isPushed.collect { case (true, filter) => filter }.toArray + val unsupported = isPushed.collect { case (false, filter) => filter }.toArray + unsupported + } + + override def pushedFilters(): Array[Filter] = supportedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala index b3fd8479bda0d..f70ffa8e95653 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.{BinaryType, DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -61,6 +62,26 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython() } + /** + * (Driver-side) Run Python process to push down filters, get the updated + * data source instance and the filter pushdown result. + */ + def pushdownFiltersInPython( + pythonResult: PythonDataSourceCreationResult, + outputSchema: StructType, + filters: Array[Filter]): PythonFilterPushdownResult = { + val runner = new UserDefinedPythonDataSourceFilterPushdownRunner( + createPythonFunction(pythonResult.dataSource), + outputSchema, + filters + ) + if (runner.isAnyFilterSupported) { + runner.runInPython() + } else { + PythonFilterPushdownResult(pythonResult.dataSource, filters.map(_ => false)) + } + } + /** * (Driver-side) Run Python process, and get the partition read functions, and * partition information. @@ -300,6 +321,97 @@ private class UserDefinedPythonDataSourceRunner( } } +/** + * @param isFilterPushed A sequence of bools indicating whether each filter is pushed down. + */ +case class PythonFilterPushdownResult( + dataSource: Array[Byte], + isFilterPushed: collection.Seq[Boolean]) + +/** + * Push down filters to a Python data source. 
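+ * The Python worker calls `DataSourceReader.pushFilters` on the reader and reports back which
+ * filters were pushed down.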
+ * + * @param dataSource + * a Python data source instance + * @param schema + * output schema of the Python data source + * @param filters + * all filters to be pushed down + */ +private class UserDefinedPythonDataSourceFilterPushdownRunner( + dataSource: PythonFunction, + schema: StructType, + filters: collection.Seq[Filter]) + extends PythonPlannerRunner[PythonFilterPushdownResult](dataSource) { + + private case class SerializedFilter( + name: String, + columnPath: collection.Seq[String], + value: Int, + index: Int) + + private val serializedFilters = filters.zipWithIndex.flatMap { + case (filter, i) => + filter match { + case filter @ org.apache.spark.sql.sources.EqualTo(_, value: Int) => + val columnPath = filter.v2references.head + Some(SerializedFilter("EqualTo", columnPath, value, i)) + case _ => + None + } + } + + // See the logic in `pyspark.sql.worker.data_source_pushdown_filters.py`. + override val workerModule = "pyspark.sql.worker.data_source_pushdown_filters" + + def isAnyFilterSupported: Boolean = serializedFilters.nonEmpty + + override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = { + // Send Python data source + PythonWorkerUtils.writePythonFunction(dataSource, dataOut) + + // Send output schema + PythonWorkerUtils.writeUTF(schema.json, dataOut) + + // Send the filters + // For now only handle EqualTo filter on int + dataOut.writeInt(serializedFilters.length) + for (f <- serializedFilters) { + PythonWorkerUtils.writeUTF(f.name, dataOut) + dataOut.writeInt(f.columnPath.length) + for (path <- f.columnPath) { + PythonWorkerUtils.writeUTF(path, dataOut) + } + dataOut.writeInt(f.value) + } + } + + override protected def receiveFromPython(dataIn: DataInputStream): PythonFilterPushdownResult = { + // Receive the picked data source or an exception raised in Python worker. + val length = dataIn.readInt() + if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg) + } + + // Receive the pickled data source. + val pickledDataSourceInstance: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn) + + // Receive the pushed filters as a list of indices. 
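+    // Each received index refers to `serializedFilters`; map it back to its position in `filters`.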
+ val numFiltersPushed = dataIn.readInt() + val isFilterPushed = ArrayBuffer.fill(filters.length)(false) + for (_ <- 0 until numFiltersPushed) { + val i = dataIn.readInt() + isFilterPushed(serializedFilters(i).index) = true + } + + PythonFilterPushdownResult( + dataSource = pickledDataSourceInstance, + isFilterPushed = isFilterPushed + ) + } +} + case class PythonDataSourceReadInfo( func: Array[Byte], partitions: Seq[Array[Byte]]) @@ -332,6 +444,7 @@ private class UserDefinedPythonDataSourceReadRunner( // Send configurations dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch) + dataOut.writeBoolean(SQLConf.get.pythonFilterPushDown) dataOut.writeBoolean(isStreaming) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 73c05ff0e0b58..d349c5bdf7005 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -22,13 +22,20 @@ import java.io.{File, FileWriter} import org.apache.spark.SparkException import org.apache.spark.api.python.PythonUtils import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} +import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.DataSourceManager import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} +import org.apache.spark.sql.execution.datasources.v2.python.PythonScan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -abstract class PythonDataSourceSuiteBase extends QueryTest with SharedSparkSession { +abstract class PythonDataSourceSuiteBase + extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { protected val simpleDataSourceReaderScript: String = """ @@ -213,6 +220,75 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\"")) } + test("data source reader with filter pushdown") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + |from pyspark.sql.datasource import ( + | DataSource, + | DataSourceReader, + | EqualTo, + | InputPartition, + |) + | + |class SimpleDataSourceReader(DataSourceReader): + | def partitions(self): + | return [InputPartition(i) for i in range(2)] + | + | def pushFilters(self, filters): + | yield filters[filters.index(EqualTo(("id",), 1))] + | + | def read(self, partition): + | yield (0, partition.value) + | yield (1, partition.value) + | yield (2, partition.value) + | + |class SimpleDataSource(DataSource): + | def schema(self): + | return "id int, partition int" + | + | def reader(self, schema): + | return SimpleDataSourceReader() + |""".stripMargin + val schema = StructType.fromDDL("id INT, partition INT") + val dataSource = + createUserDefinedPythonDataSource(name = dataSourceName, pythonScript = dataSourceScript) + withSQLConf(SQLConf.PYTHON_FILTER_PUSHDOWN_ENABLED.key -> "true") { + spark.dataSource.registerPython(dataSourceName, dataSource) + val df = + spark.read.format(dataSourceName).schema(schema).load().filter("id = 1 and partition = 0") + val plan = df.queryExecution.executedPlan + + /** + * == Physical Plan == + * *(1) 
Project [id#261, partition#262] + * +- *(1) Filter ((isnotnull(id#261) AND isnotnull(partition#262)) AND (id#261 = 1)) + * +- BatchScan SimpleDataSource[id#261, partition#262] (Python) + * PushedFilters: [EqualTo(partition,0)], + * ReadSchema: struct RuntimeFilters: [] + */ + val filter = collectFirst(df.queryExecution.executedPlan) { + case s: FilterExec => + val condition = s.condition.toString + assert(!condition.contains("= 0")) // pushed filter is not in FilterExec + assert(condition.contains("= 1")) // unsupported filter is in FilterExec + s + }.getOrElse( + fail(s"Filter not found in the plan. Actual plan:\n$plan") + ) + + collectFirst(filter) { + case s: BatchScanExec if s.scan.isInstanceOf[PythonScan] => + val p = s.scan.asInstanceOf[PythonScan] + assert(p.getMetaData().get("PushedFilters").contains("[EqualTo(partition,0)]")) + }.getOrElse( + fail(s"PythonScan not found in the plan. Actual plan:\n$plan") + ) + + checkAnswer(df, Seq(Row(1, 0), Row(1, 1))) + } + } + test("register data source") { assume(shouldTestPandasUDFs) val dataSourceScript =