[SPARK-51271][PYTHON] Add filter pushdown API to Python Data Sources #49961

15 changes: 15 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -189,11 +189,21 @@
"Remote client cannot create a SparkContext. Create SparkSession instead."
]
},
"DATA_SOURCE_EXTRANEOUS_FILTERS": {
"message": [
"<type>.pushFilters() returned filters that are not part of the input. Make sure that each returned filter is one of the input filters by reference."
]
},
"DATA_SOURCE_INVALID_RETURN_TYPE": {
"message": [
"Unsupported return type ('<type>') from Python data source '<name>'. Expected types: <supported_types>."
]
},
"DATA_SOURCE_PUSHDOWN_DISABLED": {
"message": [
"<type> implements pushFilters() but filter pushdown is disabled because configuration '<conf>' is false. Set it to true to enable filter pushdown."
]
},
"DATA_SOURCE_RETURN_SCHEMA_MISMATCH": {
"message": [
"Return schema mismatch in the result from 'read' method. Expected: <expected> columns, Found: <actual> columns. Make sure the returned values match the required output schema."
@@ -204,6 +214,11 @@
"Expected <expected>, but got <actual>."
]
},
"DATA_SOURCE_UNSUPPORTED_FILTER": {
"message": [
"Unexpected filter <name>."
]
},
"DIFFERENT_PANDAS_DATAFRAME": {
"message": [
"DataFrames are not almost equal:",
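The "by reference" requirement in DATA_SOURCE_EXTRANEOUS_FILTERS is easy to miss: the returned filters must be the very objects Spark passed in, not reconstructed copies that merely compare equal. A minimal sketch of the intended pattern, assuming the Filter and EqualTo classes introduced later in this diff (SketchReader is a hypothetical name):

from typing import Iterable, List

from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter


class SketchReader(DataSourceReader):
    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            # Keep EqualTo filters for the reader to handle; hand everything
            # else back by yielding the original input objects.
            if not isinstance(f, EqualTo):
                yield f
        # Yielding a reconstructed filter such as EqualTo(("x",), 1) would
        # trigger DATA_SOURCE_EXTRANEOUS_FILTERS even if an equal filter was
        # in the input, because the check is by reference, not equality.

    def read(self, partition):
        # Not relevant to the error condition; a real reader would apply the
        # retained EqualTo filters here.
        yield from []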
141 changes: 140 additions & 1 deletion python/pyspark/sql/datasource.py
@@ -16,7 +16,20 @@
#
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING
from dataclasses import dataclass
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
TYPE_CHECKING,
)

from pyspark.sql import Row
from pyspark.sql.types import StructType
@@ -38,6 +51,8 @@
"InputPartition",
"SimpleDataSourceStreamReader",
"WriterCommitMessage",
"Filter",
"EqualTo",
]


@@ -234,6 +249,69 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
)


ColumnPath = Tuple[str, ...]
"""
A tuple of strings representing a column reference.

For example, `("a", "b", "c")` represents the column `a.b.c`.

.. versionadded: 4.1.0
"""


@dataclass(frozen=True)
class Filter(ABC):
"""
The base class for filters used for filter pushdown.

.. versionadded: 4.1.0

Notes
-----
Column references are represented as a tuple of strings. For example:

+----------------+----------------------+
| Column | Representation |
+----------------+----------------------+
| `col1` | `("col1",)` |
| `a.b.c` | `("a", "b", "c")` |
+----------------+----------------------+

Literal values are represented as Python objects of types such as
`int`, `float`, `str`, `bool`, `datetime`, etc.
See `Data Types <https://spark.apache.org/docs/latest/sql-ref-datatypes.html>`_
for more information about how values are represented in Python.

Currently only equality between an attribute and a literal value is supported
for filter pushdown. Other types of filters cannot be pushed down.

Examples
--------
Supported filters

+---------------------+--------------------------------------------+
| SQL filter | Representation |
+---------------------+--------------------------------------------+
| `a.b.c = 1` | `EqualTo(("a", "b", "c"), 1)` |
| `a = 1` | `EqualTo(("a",), 1)` |
| `a = 'hi'` | `EqualTo(("a",), "hi")` |
| `a = array(1, 2)` | `EqualTo(("a",), [1, 2])` |
+---------------------+--------------------------------------------+

Unsupported filters
- `a = b`
- `f(a, b) = 1`
- `a % 2 = 1`
- `a[0] = 1`
"""


@dataclass(frozen=True)
class EqualTo(Filter):
attribute: ColumnPath
value: Any


class InputPartition:
"""
A base class representing an input partition returned by the `partitions()`
@@ -280,6 +358,67 @@ class DataSourceReader(ABC):
.. versionadded: 4.0.0
"""

def pushFilters(self, filters: List["Filter"]) -> Iterable["Filter"]:
"""
Called with the list of filters that can be pushed down to the data source.

The list of filters should be interpreted as the AND of the elements.

Filter pushdown allows the data source to handle a subset of filters. This
can improve performance by reducing the amount of data that needs to be
processed by Spark.

This method is called once during query planning. By default, it returns
all filters, indicating that no filters can be pushed down. Subclasses can
override this method to implement filter pushdown.

It's recommended to implement this method only for data sources that natively
support filtering, such as databases and GraphQL APIs.

.. versionadded: 4.1.0

Parameters
----------
filters : list of :class:`Filter`\\s

Returns
-------
iterable of :class:`Filter`\\s
Filters that still need to be evaluated by Spark after the data source
scan. This includes unsupported filters and partially pushed filters.
Every returned filter must be one of the input filters by reference.

Side effects
------------
This method is allowed to modify `self`. The object must remain picklable.
Modifications to `self` are visible to the `partitions()` and `read()` methods.

Examples
--------
Example filters and the resulting arguments passed to pushFilters:

+-------------------------------+---------------------------------------------+
| Filters | Pushdown Arguments |
+-------------------------------+---------------------------------------------+
| `a = 1 and b = 2` | `[EqualTo(("a",), 1), EqualTo(("b",), 2)]` |
| `a = 1 or b = 2` | `[]` |
| `a = 1 or (b = 2 and c = 3)` | `[]` |
| `a = 1 and (b = 2 or c = 3)` | `[EqualTo(("a",), 1)]` |
+-------------------------------+---------------------------------------------+

Implement pushFilters to support EqualTo filters only:

>>> def pushFilters(self, filters):
... for filter in filters:
... if isinstance(filter, EqualTo):
... # Save supported filter for handling in partitions() and read()
... self.filters.append(filter)
... else:
... # Unsupported filter
... yield filter
"""
return filters

def partitions(self) -> Sequence[InputPartition]:
"""
Returns an iterator of partitions for this data source.
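As a complement to the pushFilters docstring above, here is a compact sketch (not part of this patch) of a reader that handles equality on a single column and leaves everything else to Spark. The InMemoryDataSource and InMemoryReader names, the hard-coded rows, and the "inmem" short name are made up for illustration:

from typing import Iterable, Iterator, List, Tuple

from pyspark.sql.datasource import DataSource, DataSourceReader, EqualTo, Filter


class InMemoryReader(DataSourceReader):
    ROWS = [(1, "a"), (2, "b"), (3, "c")]

    def __init__(self) -> None:
        self.pushed: List[EqualTo] = []

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            if isinstance(f, EqualTo) and f.attribute == ("x",):
                # Handled natively; do not return it to Spark.
                self.pushed.append(f)
            else:
                # Spark re-applies this filter after the scan.
                yield f

    def read(self, partition) -> Iterator[Tuple[int, str]]:
        for x, y in self.ROWS:
            # Apply the filters this reader claimed to handle in pushFilters().
            if all(x == f.value for f in self.pushed):
                yield (x, y)


class InMemoryDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "inmem"

    def schema(self) -> str:
        return "x int, y string"

    def reader(self, schema) -> DataSourceReader:
        return InMemoryReader()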
162 changes: 161 additions & 1 deletion python/pyspark/sql/tests/test_python_datasource.py
@@ -18,19 +18,22 @@
import platform
import tempfile
import unittest
from typing import Callable, Union
from typing import Callable, Iterable, List, Union

from pyspark.errors import PythonException, AnalysisException
from pyspark.sql.datasource import (
DataSource,
DataSourceReader,
EqualTo,
Filter,
InputPartition,
DataSourceWriter,
DataSourceArrowWriter,
WriterCommitMessage,
CaseInsensitiveDict,
)
from pyspark.sql.functions import spark_partition_id
from pyspark.sql.session import SparkSession
from pyspark.sql.types import Row, StructType
from pyspark.testing.sqlutils import (
have_pyarrow,
@@ -42,6 +45,8 @@

@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class BasePythonDataSourceTestsMixin:
spark: SparkSession

def test_basic_data_source_class(self):
class MyDataSource(DataSource):
...
@@ -246,6 +251,161 @@ def reader(self, schema) -> "DataSourceReader":
assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)

def test_filter_pushdown(self):
class TestDataSourceReader(DataSourceReader):
def __init__(self):
self.has_filter = False

def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
assert set(filters) == {
EqualTo(("x",), 1),
EqualTo(("y",), 2),
}, filters
self.has_filter = True
# Pretend we support the x = 1 filter even though we don't apply it,
# so only the y = 2 filter is returned for Spark to evaluate.
yield filters[filters.index(EqualTo(("y",), 2))]

def partitions(self):
assert self.has_filter
return super().partitions()

def read(self, partition):
assert self.has_filter
yield [1, 1]
yield [1, 2]
yield [2, 2]

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return "x int, y int"

def reader(self, schema) -> "DataSourceReader":
return TestDataSourceReader()

with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
self.spark.dataSource.register(TestDataSource)
df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
# only the y = 2 filter is applied post scan
assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])

def test_extraneous_filter(self):
class TestDataSourceReader(DataSourceReader):
def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
yield EqualTo(("x",), 1)

def partitions(self):
assert False

def read(self, partition):
assert False

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return "x int"

def reader(self, schema) -> "DataSourceReader":
return TestDataSourceReader()

with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
self.spark.dataSource.register(TestDataSource)
with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
self.spark.read.format("test").load().filter("x = 1").show()

def test_filter_pushdown_error(self):
error_str = "dummy error"

class TestDataSourceReader(DataSourceReader):
def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
raise Exception(error_str)

def read(self, partition):
yield [1]

class TestDataSource(DataSource):
def schema(self):
return "x int"

def reader(self, schema) -> "DataSourceReader":
return TestDataSourceReader()

with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
self.spark.dataSource.register(TestDataSource)
df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null")
assertDataFrameEqual(df, [Row(x=1)]) # works when not pushing down filters
with self.assertRaisesRegex(Exception, error_str):
df.filter("x = 1").explain()

def test_filter_pushdown_disabled(self):
class TestDataSourceReader(DataSourceReader):
def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
assert False

def read(self, partition):
assert False

class TestDataSource(DataSource):
def reader(self, schema) -> "DataSourceReader":
return TestDataSourceReader()

with self.sql_conf({"spark.sql.python.filterPushdown.enabled": False}):
self.spark.dataSource.register(TestDataSource)
df = self.spark.read.format("TestDataSource").schema("x int").load()
with self.assertRaisesRegex(Exception, "DATA_SOURCE_PUSHDOWN_DISABLED"):
df.show()

def _check_filters(self, sql_type, sql_filter, python_filters):
"""
Parameters
----------
sql_type: str
The SQL type of the column x.
sql_filter: str
A SQL filter using the column x.
python_filters: List[Filter]
The expected python filters to be pushed down.
"""

class TestDataSourceReader(DataSourceReader):
def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
expected = python_filters
assert filters == expected, (filters, expected)
return filters

def read(self, partition):
yield from []

class TestDataSource(DataSource):
def schema(self):
return f"x {sql_type}"

def reader(self, schema) -> "DataSourceReader":
return TestDataSourceReader()

with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
self.spark.dataSource.register(TestDataSource)
df = self.spark.read.format("TestDataSource").load().filter(sql_filter)
df.count()

def test_unsupported_filter(self):
self._check_filters(
"struct<a:int, b:int, c:int>", "x.a = 1 and x.b = x.c", [EqualTo(("x", "a"), 1)]
)
self._check_filters("int", "x <> 0", [])
self._check_filters("int", "x = 1 or x > 2", [])
self._check_filters("int", "(0 < x and x < 1) or x = 2", [])
self._check_filters("int", "x % 5 = 1", [])
self._check_filters("boolean", "not x", [])
self._check_filters("array<int>", "x[0] = 1", [])

def _get_test_json_data_source(self):
import json
import os
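Mirroring the new tests, a hypothetical end-to-end use of the sketch above; the configuration key comes from this patch, while the "inmem" source and InMemoryDataSource class are assumed from the earlier example:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Filter pushdown for Python data sources is gated behind this configuration.
spark.conf.set("spark.sql.python.filterPushdown.enabled", "true")

spark.dataSource.register(InMemoryDataSource)
df = spark.read.format("inmem").load().filter("x = 1 and y = 'a'")
# x = 1 is handled inside the reader; y = 'a' is re-applied by Spark after the scan.
df.show()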