Commit b08cb22

wengh authored and allisonwang-db committed
[SPARK-51271][PYTHON] Add filter pushdown API to Python Data Sources
### What changes were proposed in this pull request?

This PR adds support for filter pushdown to Python Data Source batch read, with an API similar to the `SupportsPushDownFilters` interface. The user can implement `DataSourceReader.pushFilters` to receive filters that may be pushed down, decide which filters to push down, remember them, and return the remaining filters to be applied by Spark.

Note that filter pushdown is only supported for batch read, not for streaming read. This is also the case for the Scala API, so the new API is added to `DataSourceReader` and not to `DataSource` or `DataSourceStreamReader`.

To keep the Python API simple, we will only support V1 filters that have a column, a boolean operator, and a literal value. The filter serialization is a placeholder and will be implemented in a future PR.

```py
class DataSourceReader(ABC):
    ...

    def pushFilters(self, filters: List["Filter"]) -> Iterable["Filter"]:
        """
        Called with the list of filters that can be pushed down to the data source.

        The list of filters should be interpreted as the AND of the elements.

        Filter pushdown allows the data source to handle a subset of filters. This
        can improve performance by reducing the amount of data that needs to be
        processed by Spark.

        This method is called once during query planning. By default, it returns
        all filters, indicating that no filters can be pushed down. Subclasses can
        override this method to implement filter pushdown.

        It's recommended to implement this method only for data sources that natively
        support filtering, such as databases and GraphQL APIs.

        .. versionadded: 4.1.0

        Parameters
        ----------
        filters : list of :class:`Filter`\\s

        Returns
        -------
        iterable of :class:`Filter`\\s
            Filters that still need to be evaluated by Spark post the data source
            scan. This includes unsupported filters and partially pushed filters.
            Every returned filter must be one of the input filters by reference.

        Side effects
        ------------
        This method is allowed to modify `self`. The object must remain picklable.
        Modifications to `self` are visible to the `partitions()` and `read()` methods.

        Examples
        --------
        Example filters and the resulting arguments passed to pushFilters:

        +-------------------------------+---------------------------------------------+
        | Filters                       | Pushdown Arguments                          |
        +-------------------------------+---------------------------------------------+
        | `a = 1 and b = 2`             | `[EqualTo(("a",), 1), EqualTo(("b",), 2)]`  |
        | `a = 1 or b = 2`              | `[]`                                        |
        | `a = 1 or (b = 2 and c = 3)`  | `[]`                                        |
        | `a = 1 and (b = 2 or c = 3)`  | `[EqualTo(("a",), 1)]`                      |
        +-------------------------------+---------------------------------------------+

        Implement pushFilters to support EqualTo filters only:

        >>> def pushFilters(self, filters):
        ...     for filter in filters:
        ...         if isinstance(filter, EqualTo):
        ...             # Save supported filter for handling in partitions() and read()
        ...             self.filters.append(filter)
        ...         else:
        ...             # Unsupported filter
        ...             yield filter
        """
        return filters
```

#### Roadmap

- (this PR) Add filter pushdown API
- Implement filter serialization and more filter types
- Add column pruning API

#### Suggested review order (from high level to details)

1. `datasource.py`: add filter pushdown to Python Data Source API
2. `test_python_datasource.py`: tests for filter pushdown
3. `PythonScanBuilder.scala`: implement filter pushdown API in Scala
4. `UserDefinedPythonDataSource.scala`, `data_source_pushdown_filters.py`: communication between Python and Scala and filter pushdown logic
   - Note that the current filter serialization is a placeholder. An upcoming PR will implement the actual serialization.

#### Changes to interactions between Python and Scala

Original sequence:

```mermaid
sequenceDiagram
    # autonumber
    participant S as Data Sources API
    participant D as PythonDataSourceV2
    participant P as Python Worker
    participant U as User Implementation

    S ->> D: PythonDataSourceV2.inferSchema(options)
    D ->+ P: create_data_source.py
    D -->> P: pickled DS class, name,<br/>schema, options
    P ->+ U: unpickle DS class
    P ->> U: DataSource(options)
    U -->> P: DS instance
    U ->- P: pickle DS instance
    P -->>- D: pickled DS, schema
    D -->> S: schema

    S ->> D: PythonDataSourceV2.getTable(...)
    D -->> S: PythonTable

    S ->> D: PythonTable.newScanBuilder(options)
    D -->> S: PythonScanBuilder

    S ->> D: PythonScanBuilder.build()
    D -->> S: PythonScan

    S ->> D: PythonScan.toBatch()
    D -->> S: PythonBatch

    S ->> D: PythonBatch.planInputPartitions()
    D ->+ P: plan_data_source_read.py
    D -->> P: pickled DS, schema, ...
    P ->+ U: unpickle DS
    P ->> U: DS.reader(schema)
    U -->> P: reader
    P ->> U: reader.partitions()
    U -->> P: partitions
    U ->- P: pickle reader
    P -->>- D: pickled read,<br/>pickled partitions
    D -->> S: partitions
```

Updated sequence (new interactions are highlighted in yellow):

```mermaid
sequenceDiagram
    # autonumber
    participant S as Data Sources API
    participant D as PythonDataSourceV2
    participant P as Python Worker
    participant U as User Implementation

    S ->> D: PythonDataSourceV2.inferSchema(options)
    D ->+ P: create_data_source.py
    D -->> P: pickled DS class, name,<br/>schema, options
    P ->+ U: unpickle DS class
    P ->> U: DataSource(options)
    U -->> P: DS instance
    U ->- P: pickle DS instance
    P -->>- D: pickled DS, schema
    D -->> S: schema

    S ->> D: PythonDataSourceV2.getTable(...)
    D -->> S: PythonTable

    S ->> D: PythonTable.newScanBuilder(options)
    D -->> S: PythonScanBuilder

    rect rgb(255, 252, 238)
        S ->> D: PythonScanBuilder.pushFilters(filters)
        note right of D: Pushdown filters
        note right of D: Only simple filters are serialized<br/> and passed to Python
        note right of D: Other more complex filters are<br/> directly marked as unsupported
        D ->+ P: data_source_pushdown_filters.py
        D -->> P: pickled DS, filters, schema
        P ->+ U: unpickle DS
        P ->> U: DS.reader(schema)
        U -->> P: DataSourceReader
        P ->> U: reader.push_filters(filters)
        U -->> P: unsupported filters
        U ->- P: pickle DS with <br/>monkey patched reader()
        P -->>- D: pickled DS, supported filters
        D -->> S: unsupported filters, supported filters

        S ->> D: PythonScanBuilder.pushedFilters()
        D -->> S: supported filters
    end

    S ->> D: PythonScanBuilder.build()
    D -->> S: PythonScan

    S ->> D: PythonScan.toBatch()
    D -->> S: PythonBatch

    S ->> D: PythonBatch.planInputPartitions()
    D ->+ P: plan_data_source_read.py
    D -->> P: pickled DS, schema, ...
    P ->+ U: unpickle DS
    P ->> U: DS.reader(schema)
    U -->> P: DataSourceReader
    P ->> U: reader.partitions()
    U -->> P: partitions
    U ->- P: pickle read function
    P -->>- D: pickled read,<br/>pickled partitions
    D -->> S: partitions
```

### Why are the changes needed?

Filter pushdown reduces the amount of data produced by the reader by filtering rows directly in the data source scan, which can improve query performance. This PR implements filter pushdown for the Python Data Sources API using the existing Scala DS filter pushdown API. An upcoming PR will implement the actual filter types and the serialization of filters.

### Does this PR introduce _any_ user-facing change?

Yes. New APIs are added; see `datasource.py` for details. The new API is optional to implement. If it is not implemented, the reader behaves as before.

The feature is also controlled by the new `spark.sql.python.filterPushdown.enabled` configuration, which is disabled by default. If the conf is enabled, the new code path for filter pushdown is used. Otherwise the code path is skipped, and we throw an exception if the user implements `DataSourceReader.pushFilters()` so that it is not silently ignored.

### How was this patch tested?

Tests added to `test_python_datasource.py` to check that:
- pushed filters are not reapplied by Spark
- unsupported filters are applied by Spark
- pushdown happens before partitions
- reader state is preserved after pushdown
- pushdown is not called if no filters are present
- conf is respected
- ...

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49961 from wengh/pyds-filter-pushdown.

Authored-by: Haoyu Weng <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
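Editor's note: for readers who want to try the API described above end to end, here is a minimal sketch of a data source that pushes down equality filters. The class names, the `example` format name, and the in-memory data are illustrative only and are not part of this PR; it assumes `spark.sql.python.filterPushdown.enabled` is set to `true` as introduced here.

```py
# Illustrative sketch only; assumes a PySpark build containing this change.
from typing import Iterable, List

from pyspark.sql.datasource import DataSource, DataSourceReader, EqualTo, Filter


class ExampleReader(DataSourceReader):
    def __init__(self):
        self.pushed: List[EqualTo] = []

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            if isinstance(f, EqualTo):
                self.pushed.append(f)  # handled natively in read() below
            else:
                yield f  # Spark evaluates the rest after the scan

    def read(self, partition):
        # Pretend backing data for schema "x int": the integers 0..9.
        for x in range(10):
            if all(x == f.value for f in self.pushed if f.attribute == ("x",)):
                yield (x,)


class ExampleDataSource(DataSource):
    @classmethod
    def name(cls):
        return "example"

    def schema(self):
        return "x int"

    def reader(self, schema) -> DataSourceReader:
        return ExampleReader()


# Illustrative usage, assuming an active `spark` session:
# spark.conf.set("spark.sql.python.filterPushdown.enabled", True)
# spark.dataSource.register(ExampleDataSource)
# spark.read.format("example").load().filter("x = 3").show()
```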
1 parent 9d90507 commit b08cb22

File tree: 11 files changed, +787 −12 lines changed

python/pyspark/errors/error-conditions.json

Lines changed: 15 additions & 0 deletions
```diff
@@ -189,11 +189,21 @@
       "Remote client cannot create a SparkContext. Create SparkSession instead."
     ]
   },
+  "DATA_SOURCE_EXTRANEOUS_FILTERS": {
+    "message": [
+      "<type>.pushFilters() returned filters that are not part of the input. Make sure that each returned filter is one of the input filters by reference."
+    ]
+  },
   "DATA_SOURCE_INVALID_RETURN_TYPE": {
     "message": [
       "Unsupported return type ('<type>') from Python data source '<name>'. Expected types: <supported_types>."
     ]
   },
+  "DATA_SOURCE_PUSHDOWN_DISABLED": {
+    "message": [
+      "<type> implements pushFilters() but filter pushdown is disabled because configuration '<conf>' is false. Set it to true to enable filter pushdown."
+    ]
+  },
   "DATA_SOURCE_RETURN_SCHEMA_MISMATCH": {
     "message": [
       "Return schema mismatch in the result from 'read' method. Expected: <expected> columns, Found: <actual> columns. Make sure the returned values match the required output schema."
@@ -204,6 +214,11 @@
       "Expected <expected>, but got <actual>."
     ]
   },
+  "DATA_SOURCE_UNSUPPORTED_FILTER": {
+    "message": [
+      "Unexpected filter <name>."
+    ]
+  },
   "DIFFERENT_PANDAS_DATAFRAME": {
     "message": [
       "DataFrames are not almost equal:",
```

python/pyspark/sql/datasource.py

Lines changed: 140 additions & 1 deletion
```diff
@@ -16,7 +16,20 @@
 #
 from abc import ABC, abstractmethod
 from collections import UserDict
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING
+from dataclasses import dataclass
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    TYPE_CHECKING,
+)
 
 from pyspark.sql import Row
 from pyspark.sql.types import StructType
@@ -38,6 +51,8 @@
     "InputPartition",
     "SimpleDataSourceStreamReader",
     "WriterCommitMessage",
+    "Filter",
+    "EqualTo",
 ]
 
 
@@ -234,6 +249,69 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
         )
 
 
+ColumnPath = Tuple[str, ...]
+"""
+A tuple of strings representing a column reference.
+
+For example, `("a", "b", "c")` represents the column `a.b.c`.
+
+.. versionadded: 4.1.0
+"""
+
+
+@dataclass(frozen=True)
+class Filter(ABC):
+    """
+    The base class for filters used for filter pushdown.
+
+    .. versionadded: 4.1.0
+
+    Notes
+    -----
+    Column references are represented as a tuple of strings. For example:
+
+    +----------------+----------------------+
+    | Column         | Representation       |
+    +----------------+----------------------+
+    | `col1`         | `("col1",)`          |
+    | `a.b.c`        | `("a", "b", "c")`    |
+    +----------------+----------------------+
+
+    Literal values are represented as Python objects of types such as
+    `int`, `float`, `str`, `bool`, `datetime`, etc.
+    See `Data Types <https://spark.apache.org/docs/latest/sql-ref-datatypes.html>`_
+    for more information about how values are represented in Python.
+
+    Currently only the equality of attribute and literal value is supported for
+    filter pushdown. Other types of filters cannot be pushed down.
+
+    Examples
+    --------
+    Supported filters
+
+    +---------------------+--------------------------------------------+
+    | SQL filter          | Representation                             |
+    +---------------------+--------------------------------------------+
+    | `a.b.c = 1`         | `EqualTo(("a", "b", "c"), 1)`              |
+    | `a = 1`             | `EqualTo(("a",), 1)`                       |
+    | `a = 'hi'`          | `EqualTo(("a",), "hi")`                    |
+    | `a = array(1, 2)`   | `EqualTo(("a",), [1, 2])`                  |
+    +---------------------+--------------------------------------------+
+
+    Unsupported filters
+    - `a = b`
+    - `f(a, b) = 1`
+    - `a % 2 = 1`
+    - `a[0] = 1`
+    """
+
+
+@dataclass(frozen=True)
+class EqualTo(Filter):
+    attribute: ColumnPath
+    value: Any
+
+
 class InputPartition:
     """
     A base class representing an input partition returned by the `partitions()`
@@ -280,6 +358,67 @@ class DataSourceReader(ABC):
     .. versionadded: 4.0.0
     """
 
+    def pushFilters(self, filters: List["Filter"]) -> Iterable["Filter"]:
+        """
+        Called with the list of filters that can be pushed down to the data source.
+
+        The list of filters should be interpreted as the AND of the elements.
+
+        Filter pushdown allows the data source to handle a subset of filters. This
+        can improve performance by reducing the amount of data that needs to be
+        processed by Spark.
+
+        This method is called once during query planning. By default, it returns
+        all filters, indicating that no filters can be pushed down. Subclasses can
+        override this method to implement filter pushdown.
+
+        It's recommended to implement this method only for data sources that natively
+        support filtering, such as databases and GraphQL APIs.
+
+        .. versionadded: 4.1.0
+
+        Parameters
+        ----------
+        filters : list of :class:`Filter`\\s
+
+        Returns
+        -------
+        iterable of :class:`Filter`\\s
+            Filters that still need to be evaluated by Spark post the data source
+            scan. This includes unsupported filters and partially pushed filters.
+            Every returned filter must be one of the input filters by reference.
+
+        Side effects
+        ------------
+        This method is allowed to modify `self`. The object must remain picklable.
+        Modifications to `self` are visible to the `partitions()` and `read()` methods.
+
+        Examples
+        --------
+        Example filters and the resulting arguments passed to pushFilters:
+
+        +-------------------------------+---------------------------------------------+
+        | Filters                       | Pushdown Arguments                          |
+        +-------------------------------+---------------------------------------------+
+        | `a = 1 and b = 2`             | `[EqualTo(("a",), 1), EqualTo(("b",), 2)]`  |
+        | `a = 1 or b = 2`              | `[]`                                        |
+        | `a = 1 or (b = 2 and c = 3)`  | `[]`                                        |
+        | `a = 1 and (b = 2 or c = 3)`  | `[EqualTo(("a",), 1)]`                      |
+        +-------------------------------+---------------------------------------------+
+
+        Implement pushFilters to support EqualTo filters only:
+
+        >>> def pushFilters(self, filters):
+        ...     for filter in filters:
+        ...         if isinstance(filter, EqualTo):
+        ...             # Save supported filter for handling in partitions() and read()
+        ...             self.filters.append(filter)
+        ...         else:
+        ...             # Unsupported filter
+        ...             yield filter
+        """
+        return filters
+
     def partitions(self) -> Sequence[InputPartition]:
         """
         Returns an iterator of partitions for this data source.
```
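Editor's note: as a quick illustration of the new dataclasses (a sketch with made-up values, not part of the diff), `EqualTo` instances are frozen, compare by value, and are hashable, which is what the tests below rely on when they build sets of filters:

```py
# Illustrative only; assumes a PySpark build containing this change.
from pyspark.sql.datasource import EqualTo

top_level = EqualTo(("price",), 10)        # column `price` = 10
nested = EqualTo(("a", "b", "c"), "hi")    # nested column `a.b.c` = 'hi'

# Frozen dataclasses compare by value and are hashable.
assert top_level == EqualTo(("price",), 10)
assert {top_level, nested} == {EqualTo(("price",), 10), EqualTo(("a", "b", "c"), "hi")}
```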

python/pyspark/sql/tests/test_python_datasource.py

Lines changed: 161 additions & 1 deletion
```diff
@@ -18,19 +18,22 @@
 import platform
 import tempfile
 import unittest
-from typing import Callable, Union
+from typing import Callable, Iterable, List, Union
 
 from pyspark.errors import PythonException, AnalysisException
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceReader,
+    EqualTo,
+    Filter,
     InputPartition,
     DataSourceWriter,
     DataSourceArrowWriter,
     WriterCommitMessage,
     CaseInsensitiveDict,
 )
 from pyspark.sql.functions import spark_partition_id
+from pyspark.sql.session import SparkSession
 from pyspark.sql.types import Row, StructType
 from pyspark.testing.sqlutils import (
     have_pyarrow,
@@ -42,6 +45,8 @@
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
 class BasePythonDataSourceTestsMixin:
+    spark: SparkSession
+
     def test_basic_data_source_class(self):
         class MyDataSource(DataSource):
             ...
@@ -246,6 +251,161 @@ def reader(self, schema) -> "DataSourceReader":
         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
 
+    def test_filter_pushdown(self):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self):
+                self.has_filter = False
+
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert set(filters) == {
+                    EqualTo(("x",), 1),
+                    EqualTo(("y",), 2),
+                }, filters
+                self.has_filter = True
+                # pretend we support x = 1 filter but in fact we don't
+                # so we only return y = 2 filter
+                yield filters[filters.index(EqualTo(("y",), 2))]
+
+            def partitions(self):
+                assert self.has_filter
+                return super().partitions()
+
+            def read(self, partition):
+                assert self.has_filter
+                yield [1, 1]
+                yield [1, 2]
+                yield [2, 2]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int, y int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+            self.spark.dataSource.register(TestDataSource)
+            df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
+            # only the y = 2 filter is applied post scan
+            assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])
+
+    def test_extraneous_filter(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                yield EqualTo(("x",), 1)
+
+            def partitions(self):
+                assert False
+
+            def read(self, partition):
+                assert False
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+            self.spark.dataSource.register(TestDataSource)
+            with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
+                self.spark.read.format("test").load().filter("x = 1").show()
+
+    def test_filter_pushdown_error(self):
+        error_str = "dummy error"
+
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                raise Exception(error_str)
+
+            def read(self, partition):
+                yield [1]
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "x int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+            self.spark.dataSource.register(TestDataSource)
+            df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null")
+            assertDataFrameEqual(df, [Row(x=1)])  # works when not pushing down filters
+            with self.assertRaisesRegex(Exception, error_str):
+                df.filter("x = 1").explain()
+
+    def test_filter_pushdown_disabled(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert False
+
+            def read(self, partition):
+                assert False
+
+        class TestDataSource(DataSource):
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": False}):
+            self.spark.dataSource.register(TestDataSource)
+            df = self.spark.read.format("TestDataSource").schema("x int").load()
+            with self.assertRaisesRegex(Exception, "DATA_SOURCE_PUSHDOWN_DISABLED"):
+                df.show()
+
+    def _check_filters(self, sql_type, sql_filter, python_filters):
+        """
+        Parameters
+        ----------
+        sql_type: str
+            The SQL type of the column x.
+        sql_filter: str
+            A SQL filter using the column x.
+        python_filters: List[Filter]
+            The expected python filters to be pushed down.
+        """
+
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                expected = python_filters
+                assert filters == expected, (filters, expected)
+                return filters
+
+            def read(self, partition):
+                yield from []
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return f"x {sql_type}"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+            self.spark.dataSource.register(TestDataSource)
+            df = self.spark.read.format("TestDataSource").load().filter(sql_filter)
+            df.count()
+
+    def test_unsupported_filter(self):
+        self._check_filters(
+            "struct<a:int, b:int, c:int>", "x.a = 1 and x.b = x.c", [EqualTo(("x", "a"), 1)]
+        )
+        self._check_filters("int", "x <> 0", [])
+        self._check_filters("int", "x = 1 or x > 2", [])
+        self._check_filters("int", "(0 < x and x < 1) or x = 2", [])
+        self._check_filters("int", "x % 5 = 1", [])
+        self._check_filters("boolean", "not x", [])
+        self._check_filters("array<int>", "x[0] = 1", [])
+
     def _get_test_json_data_source(self):
         import json
         import os
```
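Editor's note: the docstring also allows returning filters that were only partially pushed. A hedged sketch of that pattern (hypothetical reader and illustrative data, assuming a schema like `p int, data string`): the reader prunes partitions using an `EqualTo` on the partitioning column `p`, but still returns the filter so Spark re-applies it after the scan.

```py
# Illustrative sketch only; assumes a PySpark build containing this change.
from typing import Iterable, Iterator, List, Optional

from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter, InputPartition


class PartitionPruningReader(DataSourceReader):
    def __init__(self):
        self.part_values = [0, 1, 2, 3]  # one partition per value of column `p`
        self.wanted: Optional[int] = None

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            if isinstance(f, EqualTo) and f.attribute == ("p",):
                self.wanted = f.value  # use the filter for partition pruning...
            yield f  # ...but return every filter, so Spark still applies them

    def partitions(self) -> List[InputPartition]:
        # pushFilters() runs before partitions(), so pruning can use self.wanted.
        values = [v for v in self.part_values if self.wanted is None or v == self.wanted]
        return [InputPartition(v) for v in values]

    def read(self, partition: InputPartition) -> Iterator:
        yield (partition.value, "payload")
```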
