filter pushdown with placeholder serialization

wengh committed Feb 15, 2025
1 parent d6ad779 commit 95b2256
Showing 14 changed files with 738 additions and 129 deletions.
31 changes: 30 additions & 1 deletion python/pyspark/sql/datasource.py
@@ -16,7 +16,20 @@
#
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING
from dataclasses import dataclass
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
TYPE_CHECKING,
)

from pyspark.sql import Row
from pyspark.sql.types import StructType
@@ -38,6 +51,8 @@
"InputPartition",
"SimpleDataSourceStreamReader",
"WriterCommitMessage",
"Filter",
"EqualTo",
]


@@ -234,6 +249,17 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
)


@dataclass(frozen=True)
class Filter(ABC):
pass


@dataclass(frozen=True)
class EqualTo(Filter):
columnPath: Tuple[str, ...]
value: Any


class InputPartition:
"""
A base class representing an input partition returned by the `partitions()`
@@ -280,6 +306,9 @@ class DataSourceReader(ABC):
.. versionadded:: 4.0.0
"""

def pushdownFilters(self, filters: List["Filter"]) -> Iterable["Filter"]:
return filters

def partitions(self) -> Sequence[InputPartition]:
"""
Returns an iterator of partitions for this data source.
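The contract of the new hook is implied rather than documented in this hunk: judging by the default implementation (return all filters unchanged) and by the test added further down, Spark calls pushdownFilters before partitions() and read(), and the filters the method returns are the ones the reader does not handle, which Spark re-applies after the scan. A minimal sketch of a reader using the hook under those assumptions — the SalesReader name, columns, and data are invented for illustration and are not part of this commit:

from typing import Iterator, List

from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter


class SalesReader(DataSourceReader):
    # Hypothetical reader that pushes down an equality filter on 'region'.
    def __init__(self) -> None:
        self.region = None  # value of a pushed 'region = <value>' filter, if any

    def pushdownFilters(self, filters: List[Filter]) -> List[Filter]:
        unsupported = []
        for f in filters:
            if isinstance(f, EqualTo) and f.columnPath == ("region",):
                # Handle this filter inside read(); by not returning it,
                # we tell Spark it does not need to be re-applied post-scan.
                self.region = f.value
            else:
                unsupported.append(f)  # Spark evaluates these after the scan
        return unsupported

    def read(self, partition) -> Iterator[tuple]:
        rows = [("east", 10), ("west", 20), ("east", 30)]
        for region, amount in rows:
            if self.region is None or region == self.region:
                yield (region, amount)

Returning the unsupported filters, rather than the supported ones, mirrors the Scala SupportsPushDownFilters.pushFilters contract, which likewise returns the filters that must still be evaluated after scanning.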
41 changes: 28 additions & 13 deletions python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -35,6 +35,7 @@
_parse_datatype_json_string,
StructType,
)
from pyspark.sql.worker.internal.data_source_reader_info import DataSourceReaderInfo
from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
from pyspark.util import handle_worker_exception, local_connect_and_auth
from pyspark.worker_util import (
@@ -69,7 +70,7 @@ def latest_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:

def partitions_func(
reader: DataSourceStreamReader,
data_source: DataSource,
data_source_name: str,
schema: StructType,
max_arrow_batch_size: int,
infile: IO,
@@ -87,7 +88,7 @@ def partitions_func(
if it is None:
write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)
else:
send_batch_func(it, outfile, schema, max_arrow_batch_size, data_source)
send_batch_func(it, outfile, schema, max_arrow_batch_size, data_source_name)
else:
write_int(PREFETCHED_RECORDS_NOT_FOUND, outfile)

Expand All @@ -103,9 +104,9 @@ def send_batch_func(
outfile: IO,
schema: StructType,
max_arrow_batch_size: int,
data_source: DataSource,
data_source_name: str,
) -> None:
batches = list(records_to_arrow_batches(rows, max_arrow_batch_size, schema, data_source))
batches = list(records_to_arrow_batches(rows, max_arrow_batch_size, schema, data_source_name))
if len(batches) != 0:
write_int(NON_EMPTY_PYARROW_RECORD_BATCHES, outfile)
write_int(SpecialLengths.START_ARROW_STREAM, outfile)
@@ -125,15 +126,26 @@ def main(infile: IO, outfile: IO) -> None:

_accumulatorRegistry.clear()

# Receive the data source instance.
data_source = read_command(pickleSer, infile)
# Receive the data source reader.
reader_info = read_command(pickleSer, infile)
if not isinstance(reader_info, DataSourceReaderInfo):
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "a Python data source reader info of type 'DataSourceReaderInfo'",
"actual": f"'{type(reader_info).__name__}'",
},
)

reader = reader_info.reader
data_source_name = reader_info.data_source_name

if not isinstance(data_source, DataSource):
if not isinstance(reader, DataSourceStreamReader):
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "a Python data source instance of type 'DataSource'",
"actual": f"'{type(data_source).__name__}'",
"expected": "a Python data source reader of type 'DataSourceStreamReader'",
"actual": f"'{type(reader).__name__}'",
},
)

@@ -155,9 +167,7 @@ def main(infile: IO, outfile: IO) -> None:
f"'{max_arrow_batch_size}'"
)

# Instantiate data source reader.
try:
reader = _streamReader(data_source, schema)
# Initialization succeeded.
write_int(0, outfile)
outfile.flush()
@@ -171,7 +181,12 @@ def main(infile: IO, outfile: IO) -> None:
latest_offset_func(reader, outfile)
elif func_id == PARTITIONS_FUNC_ID:
partitions_func(
reader, data_source, schema, max_arrow_batch_size, infile, outfile
reader,
data_source_name,
schema,
max_arrow_batch_size,
infile,
outfile,
)
elif func_id == COMMIT_FUNC_ID:
commit_func(reader, infile, outfile)
@@ -184,7 +199,7 @@ def main(infile: IO, outfile: IO) -> None:
)
outfile.flush()
except Exception as e:
error_msg = "data source {} throw exception: {}".format(data_source.name, e)
error_msg = "data source {} throw exception: {}".format(data_source_name, e)
raise PySparkRuntimeError(
errorClass="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
messageParameters={"msg": error_msg},
43 changes: 42 additions & 1 deletion python/pyspark/sql/tests/test_python_datasource.py
@@ -18,12 +18,14 @@
import platform
import tempfile
import unittest
from typing import Callable, Union
from typing import Callable, Iterable, List, Union

from pyspark.errors import PythonException, AnalysisException
from pyspark.sql.datasource import (
DataSource,
DataSourceReader,
EqualTo,
Filter,
InputPartition,
DataSourceWriter,
DataSourceArrowWriter,
@@ -246,6 +248,45 @@ def reader(self, schema) -> "DataSourceReader":
assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)

def test_filter_pushdown(self):
class TestDataSourceReader(DataSourceReader):
def __init__(self):
self.has_filter = False

def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
assert len(filters) == 2
assert set(filters) == {EqualTo(("x",), 1), EqualTo(("y",), 2)}
self.has_filter = True
# Pretend we support the x = 1 filter even though we never apply it,
# so only the unsupported y = 2 filter is returned for Spark to re-apply.
yield EqualTo(("y",), 2)

def partitions(self):
assert self.has_filter
return super().partitions()

def read(self, partition):
assert self.has_filter
yield [1, 1]
yield [1, 2]
yield [2, 2]

class TestDataSource(DataSource):
@classmethod
def name(cls):
return "test"

def schema(self):
return "x int, y int"

def reader(self, schema) -> "DataSourceReader":
return TestDataSourceReader()

self.spark.dataSource.register(TestDataSource)
df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
# only the y = 2 filter is applied post scan
assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])

def _get_test_json_data_source(self):
import json
import os
87 changes: 87 additions & 0 deletions python/pyspark/sql/worker/data_source_get_reader.py
@@ -0,0 +1,87 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import IO, Union

from pyspark.errors import PySparkAssertionError
from pyspark.serializers import (
read_bool,
)
from pyspark.sql.datasource import (
DataSource,
DataSourceReader,
DataSourceStreamReader,
)
from pyspark.sql.datasource_internal import _streamReader
from pyspark.sql.types import (
_parse_datatype_json_string,
BinaryType,
StructType,
)
from pyspark.sql.worker.internal.data_source_reader_info import DataSourceReaderInfo
from pyspark.sql.worker.internal.data_source_worker import worker_main
from pyspark.worker_util import (
read_command,
pickleSer,
utf8_deserializer,
)


@worker_main
def main(infile: IO, outfile: IO) -> None:
# Receive the data source instance.
data_source = read_command(pickleSer, infile)
if not isinstance(data_source, DataSource):
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "a Python data source instance of type 'DataSource'",
"actual": f"'{type(data_source).__name__}'",
},
)

# Receive the data source output schema.
schema_json = utf8_deserializer.loads(infile)
schema = _parse_datatype_json_string(schema_json)
if not isinstance(schema, StructType):
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "an output schema of type 'StructType'",
"actual": f"'{type(schema).__name__}'",
},
)

is_streaming = read_bool(infile)

# Instantiate data source reader.
if is_streaming:
reader: Union[DataSourceReader, DataSourceStreamReader] = _streamReader(data_source, schema)
else:
reader = data_source.reader(schema=schema)
# Validate the reader.
if not isinstance(reader, DataSourceReader):
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "an instance of DataSourceReader",
"actual": f"'{type(reader).__name__}'",
},
)

reader_info = DataSourceReaderInfo(reader=reader, data_source_name=data_source.name)
pickleSer._write_with_length(reader_info, outfile)
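The DataSourceReaderInfo placeholder used here and in the streaming runner lives in python/pyspark/sql/worker/internal/data_source_reader_info.py (per the imports above), which is not shown in this excerpt. Judging from how it is constructed (DataSourceReaderInfo(reader=..., data_source_name=...)) and consumed (reader_info.reader, reader_info.data_source_name), it is presumably a small picklable dataclass along these lines — an assumption inferred from usage, not the actual file contents:

from dataclasses import dataclass
from typing import Union

from pyspark.sql.datasource import DataSourceReader, DataSourceStreamReader


@dataclass(frozen=True)
class DataSourceReaderInfo:
    # Assumed shape: bundles the instantiated reader with its data source name
    # so downstream workers no longer need the full DataSource instance.
    reader: Union[DataSourceReader, DataSourceStreamReader]
    data_source_name: str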