
Commit 46bd9cc

wengh authored and allisonwang-db committed
[SPARK-51575][PYTHON] Combine Python Data Source pushdown & plan read workers
[SPARK-51575][PYTHON] Combine Python Data Source pushdown & plan read workers

Follow-up of #49961.

### What changes were proposed in this pull request?

As pointed out by #49961 (comment), at the time of filter pushdown we already have enough information to also plan the read partitions. So this PR changes the filter pushdown worker to also get partitions, reducing the number of exchanges between Python and Scala.

Changes:
- Extract the part of `plan_data_source_read.py` that is responsible for sending the partitions and the read function to the JVM.
- Use the extracted logic to also send the partitions and the read function when doing filter pushdown in `data_source_pushdown_filters.py`.
- Update the Scala code accordingly.

### Why are the changes needed?

To improve Python Data Source performance when the filter pushdown configuration is enabled.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests in `test_python_datasource.py`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50340 from wengh/pyds-combine-pushdown-plan.

Authored-by: Haoyu Weng <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
1 parent b829aea commit 46bd9cc
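For context, here is a minimal sketch (not part of this commit) of a user-defined Python data source that exercises the combined path: with filter pushdown enabled, `pushFilters`, `partitions`, and the pickling of `read` now all happen in the same Python worker call. The class names follow the public `pyspark.sql.datasource` API; treat the exact filter field layout (e.g. `EqualTo.attribute` being a tuple of name parts) as an assumption.

```python
from typing import Iterator, List

from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    EqualTo,
    Filter,
    InputPartition,
)


class RangeReader(DataSourceReader):
    def __init__(self, options: dict):
        self.start = 0
        self.end = int(options.get("end", 100))

    def pushFilters(self, filters: List[Filter]) -> Iterator[Filter]:
        # Accept `id = <n>` filters; yield everything else back as unsupported.
        for f in filters:
            if isinstance(f, EqualTo) and f.attribute == ("id",):
                self.start, self.end = f.value, f.value + 1
            else:
                yield f

    def partitions(self) -> List[InputPartition]:
        # With this change, partition planning runs right after pushFilters in
        # the same worker, so start/end already reflect the pushed-down filters.
        mid = (self.start + self.end) // 2
        return [InputPartition((self.start, mid)), InputPartition((mid, self.end))]

    def read(self, partition: InputPartition) -> Iterator[tuple]:
        lo, hi = partition.value
        for i in range(lo, hi):
            yield (i,)


class RangeDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "my_range"

    def schema(self) -> str:
        return "id int"

    def reader(self, schema) -> DataSourceReader:
        return RangeReader(self.options)
```

Registering this source with `spark.dataSource.register(RangeDataSource)` and reading `spark.read.format("my_range").load().filter("id = 7")` would, with pushdown enabled, trigger a single pushdown-plus-planning worker exchange instead of two.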

File tree: 7 files changed (+204 / -156 lines)


python/pyspark/sql/worker/data_source_pushdown_filters.py

Lines changed: 20 additions & 6 deletions
@@ -47,6 +47,7 @@
     StringStartsWith,
 )
 from pyspark.sql.types import StructType, VariantVal, _parse_datatype_json_string
+from pyspark.sql.worker.plan_data_source_read import write_read_func_and_partitions
 from pyspark.util import handle_worker_exception, local_connect_and_auth
 from pyspark.worker_util import (
     check_python_version,
@@ -131,11 +132,12 @@ def main(infile: IO, outfile: IO) -> None:
     - a `DataSource` instance representing the data source
     - a `StructType` instance representing the output schema of the data source
     - a list of filters to be pushed down
+    - configuration values

     This process then creates a `DataSourceReader` instance by calling the `reader` method
     on the `DataSource` instance. It applies the filters by calling the `pushFilters` method
-    on the reader and determines which filters are supported. The data source with updated reader
-    is then sent back to the JVM along with the indices of the supported filters.
+    on the reader and determines which filters are supported. The indices of the supported
+    filters are sent back to the JVM, along with the list of partitions and the read function.
     """
     faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
@@ -220,10 +222,22 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )

-        # Monkey patch the data source instance
-        # to return the existing reader with the pushed down filters.
-        data_source.reader = lambda schema: reader  # type: ignore[method-assign]
-        pickleSer._write_with_length(data_source, outfile)
+        # Receive the max arrow batch size.
+        max_arrow_batch_size = read_int(infile)
+        assert max_arrow_batch_size > 0, (
+            "The maximum arrow batch size should be greater than 0, but got "
+            f"'{max_arrow_batch_size}'"
+        )
+
+        # Return the read function and partitions. Doing this in the same worker as filter pushdown
+        # helps reduce the number of Python worker calls.
+        write_read_func_and_partitions(
+            outfile,
+            reader=reader,
+            data_source=data_source,
+            schema=schema,
+            max_arrow_batch_size=max_arrow_batch_size,
+        )

         # Return the supported filter indices.
         write_int(len(supported_filter_indices), outfile)
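To make the combined exchange concrete, the following toy re-enactment writes into a buffer the same sequence of values that the pushdown worker now sends on its output stream: the pickled `(read_func, return_type)` command, the partition count followed by each pickled partition, and finally the supported filter indices. It uses plain `pickle`/`struct` stand-ins for PySpark's `pickleSer` and `write_int`, so it is illustrative rather than the real protocol implementation.

```python
import io
import pickle
import struct


def write_int(out: io.BytesIO, value: int) -> None:
    # Stand-in for pyspark.serializers.write_int (big-endian 4-byte int).
    out.write(struct.pack("!i", value))


def write_with_length(out: io.BytesIO, obj) -> None:
    # Stand-in for pickleSer._write_with_length: length prefix, then pickled bytes.
    payload = pickle.dumps(obj)
    write_int(out, len(payload))
    out.write(payload)


out = io.BytesIO()

# 1. The read function and its return type (written by write_read_func_and_partitions).
write_with_length(out, ("read_func_placeholder", "return_type_placeholder"))

# 2. The planned partitions (a batch source with no explicit partitions sends [None]).
partitions = [None]
write_int(out, len(partitions))
for partition in partitions:
    write_with_length(out, partition)

# 3. The supported filter indices, written by the pushdown worker itself.
supported_filter_indices = [0]
write_int(out, len(supported_filter_indices))
for index in supported_filter_indices:
    write_int(out, index)

print(len(out.getvalue()), "bytes queued for the JVM in a single exchange")
```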

python/pyspark/sql/worker/plan_data_source_read.py

Lines changed: 103 additions & 85 deletions
@@ -168,6 +168,101 @@ def batched(iterator: Iterator, n: int) -> Iterator:
         yield batch


+def write_read_func_and_partitions(
+    outfile: IO,
+    *,
+    reader: Union[DataSourceReader, DataSourceStreamReader],
+    data_source: DataSource,
+    schema: StructType,
+    max_arrow_batch_size: int,
+) -> None:
+    is_streaming = isinstance(reader, DataSourceStreamReader)
+
+    # Create input converter.
+    converter = ArrowTableToRowsConversion._create_converter(BinaryType())
+
+    # Create output converter.
+    return_type = schema
+
+    def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
+        partition_bytes = None
+
+        # Get the partition value from the input iterator.
+        for batch in iterator:
+            # There should be only one row/column in the batch.
+            assert batch.num_columns == 1 and batch.num_rows == 1, (
+                "Expected each batch to have exactly 1 column and 1 row, "
+                f"but found {batch.num_columns} columns and {batch.num_rows} rows."
+            )
+            columns = [column.to_pylist() for column in batch.columns]
+            partition_bytes = converter(columns[0][0])
+
+        assert (
+            partition_bytes is not None
+        ), "The input iterator for Python data source read function is empty."
+
+        # Deserialize the partition value.
+        partition = pickleSer.loads(partition_bytes)
+
+        assert partition is None or isinstance(partition, InputPartition), (
+            "Expected the partition value to be of type 'InputPartition', "
+            f"but found '{type(partition).__name__}'."
+        )
+
+        output_iter = reader.read(partition)  # type: ignore[arg-type]
+
+        # Validate the output iterator.
+        if not isinstance(output_iter, Iterator):
+            raise PySparkRuntimeError(
+                errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
+                messageParameters={
+                    "type": type(output_iter).__name__,
+                    "name": data_source.name(),
+                    "supported_types": "iterator",
+                },
+            )
+
+        return records_to_arrow_batches(output_iter, max_arrow_batch_size, return_type, data_source)
+
+    command = (data_source_read_func, return_type)
+    pickleSer._write_with_length(command, outfile)
+
+    if not is_streaming:
+        # The partitioning of python batch source read is determined before query execution.
+        try:
+            partitions = reader.partitions()  # type: ignore[call-arg]
+            if not isinstance(partitions, list):
+                raise PySparkRuntimeError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": "'partitions' to return a list",
+                        "actual": f"'{type(partitions).__name__}'",
+                    },
+                )
+            if not all(isinstance(p, InputPartition) for p in partitions):
+                partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
+                raise PySparkRuntimeError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": "elements in 'partitions' to be of type 'InputPartition'",
+                        "actual": partition_types,
+                    },
+                )
+            if len(partitions) == 0:
+                partitions = [None]  # type: ignore[list-item]
+        except NotImplementedError:
+            partitions = [None]  # type: ignore[list-item]
+
+        # Return the serialized partition values.
+        write_int(len(partitions), outfile)
+        for partition in partitions:
+            pickleSer._write_with_length(partition, outfile)
+    else:
+        # Send an empty list of partition for stream reader because partitions are planned
+        # in each microbatch during query execution.
+        write_int(0, outfile)
+
+
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read.
@@ -284,91 +379,14 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )

-        # Create input converter.
-        converter = ArrowTableToRowsConversion._create_converter(BinaryType())
-
-        # Create output converter.
-        return_type = schema
-
-        def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
-            partition_bytes = None
-
-            # Get the partition value from the input iterator.
-            for batch in iterator:
-                # There should be only one row/column in the batch.
-                assert batch.num_columns == 1 and batch.num_rows == 1, (
-                    "Expected each batch to have exactly 1 column and 1 row, "
-                    f"but found {batch.num_columns} columns and {batch.num_rows} rows."
-                )
-                columns = [column.to_pylist() for column in batch.columns]
-                partition_bytes = converter(columns[0][0])
-
-            assert (
-                partition_bytes is not None
-            ), "The input iterator for Python data source read function is empty."
-
-            # Deserialize the partition value.
-            partition = pickleSer.loads(partition_bytes)
-
-            assert partition is None or isinstance(partition, InputPartition), (
-                "Expected the partition value to be of type 'InputPartition', "
-                f"but found '{type(partition).__name__}'."
-            )
-
-            output_iter = reader.read(partition)  # type: ignore[arg-type]
-
-            # Validate the output iterator.
-            if not isinstance(output_iter, Iterator):
-                raise PySparkRuntimeError(
-                    errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
-                    messageParameters={
-                        "type": type(output_iter).__name__,
-                        "name": data_source.name(),
-                        "supported_types": "iterator",
-                    },
-                )
-
-            return records_to_arrow_batches(
-                output_iter, max_arrow_batch_size, return_type, data_source
-            )
-
-        command = (data_source_read_func, return_type)
-        pickleSer._write_with_length(command, outfile)
-
-        if not is_streaming:
-            # The partitioning of python batch source read is determined before query execution.
-            try:
-                partitions = reader.partitions()  # type: ignore[call-arg]
-                if not isinstance(partitions, list):
-                    raise PySparkRuntimeError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "'partitions' to return a list",
-                            "actual": f"'{type(partitions).__name__}'",
-                        },
-                    )
-                if not all(isinstance(p, InputPartition) for p in partitions):
-                    partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
-                    raise PySparkRuntimeError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "elements in 'partitions' to be of type 'InputPartition'",
-                            "actual": partition_types,
-                        },
-                    )
-                if len(partitions) == 0:
-                    partitions = [None]  # type: ignore[list-item]
-            except NotImplementedError:
-                partitions = [None]  # type: ignore[list-item]
-
-            # Return the serialized partition values.
-            write_int(len(partitions), outfile)
-            for partition in partitions:
-                pickleSer._write_with_length(partition, outfile)
-        else:
-            # Send an empty list of partition for stream reader because partitions are planned
-            # in each microbatch during query execution.
-            write_int(0, outfile)
+        # Send the read function and partitions to the JVM.
+        write_read_func_and_partitions(
+            outfile,
+            reader=reader,
+            data_source=data_source,
+            schema=schema,
+            max_arrow_batch_size=max_arrow_batch_size,
+        )
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
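As a rough sanity check, the extracted helper can be exercised on its own against an in-memory buffer. This sketch assumes a PySpark build that includes this change and that the helper keeps the keyword-only signature shown above; `DummySource` and `TwoPartitionReader` are made-up stand-ins, not classes from the codebase.

```python
import io

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import IntegerType, StructField, StructType
from pyspark.sql.worker.plan_data_source_read import write_read_func_and_partitions


class DummySource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "dummy"


class TwoPartitionReader(DataSourceReader):
    def partitions(self):
        return [InputPartition(0), InputPartition(1)]

    def read(self, partition):
        yield (partition.value,)


buf = io.BytesIO()
write_read_func_and_partitions(
    buf,
    reader=TwoPartitionReader(),
    data_source=DummySource(options={}),
    schema=StructType([StructField("id", IntegerType())]),
    max_arrow_batch_size=1000,
)
# The buffer now holds the pickled read function followed by the two pickled partitions,
# i.e. what both the pushdown worker and the plan-read worker stream to the JVM.
print(len(buf.getvalue()) > 0)
```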

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala

Lines changed: 17 additions & 2 deletions
@@ -52,8 +52,23 @@ class PythonDataSourceV2 extends TableProvider {
     dataSourceInPython
   }

-  def setDataSourceInPython(dataSourceInPython: PythonDataSourceCreationResult): Unit = {
-    this.dataSourceInPython = dataSourceInPython
+  private var readInfo: PythonDataSourceReadInfo = _
+
+  def getOrCreateReadInfo(
+      shortName: String,
+      options: CaseInsensitiveStringMap,
+      outputSchema: StructType,
+      isStreaming: Boolean
+  ): PythonDataSourceReadInfo = {
+    if (readInfo == null) {
+      val creationResult = getOrCreateDataSourceInPython(shortName, options, Some(outputSchema))
+      readInfo = source.createReadInfoInPython(creationResult, outputSchema, isStreaming)
+    }
+    readInfo
+  }
+
+  def setReadInfo(readInfo: PythonDataSourceReadInfo): Unit = {
+    this.readInfo = readInfo
   }

   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala

Lines changed: 1 addition & 4 deletions
@@ -90,10 +90,7 @@ class PythonMicroBatchStream(
   }

   private lazy val readInfo: PythonDataSourceReadInfo = {
-    ds.source.createReadInfoInPython(
-      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
-      outputSchema,
-      isStreaming = true)
+    ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = true)
   }

   override def createReaderFactory(): PartitionReaderFactory = {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala

Lines changed: 1 addition & 4 deletions
@@ -63,10 +63,7 @@ class PythonBatch(
   private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

   private lazy val infoInPython: PythonDataSourceReadInfo = {
-    ds.source.createReadInfoInPython(
-      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
-      outputSchema,
-      isStreaming = false)
+    ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = false)
   }

   override def planInputPartitions(): Array[InputPartition] =

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala

Lines changed: 13 additions & 12 deletions
@@ -42,18 +42,19 @@ class PythonScanBuilder(
     }

     val dataSource = ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema))
-    val result = ds.source.pushdownFiltersInPython(dataSource, outputSchema, filters)
-
-    // The Data Source instance state changes after pushdown to remember the reader instance
-    // created and the filters pushed down. So pushdownFiltersInPython returns a new pickled
-    // Data Source instance. We need to use that new instance for further operations.
-    ds.setDataSourceInPython(dataSource.copy(dataSource = result.dataSource))
-
-    // Partition the filters into supported and unsupported ones.
-    val isPushed = result.isFilterPushed.zip(filters)
-    supportedFilters = isPushed.collect { case (true, filter) => filter }.toArray
-    val unsupported = isPushed.collect { case (false, filter) => filter }.toArray
-    unsupported
+    ds.source.pushdownFiltersInPython(dataSource, outputSchema, filters) match {
+      case None => filters // No filters are supported.
+      case Some(result) =>
+        // Filter pushdown also returns partitions and the read function.
+        // This helps reduce the number of Python worker calls.
+        ds.setReadInfo(result.readInfo)
+
+        // Partition the filters into supported and unsupported ones.
+        val isPushed = result.isFilterPushed.zip(filters)
+        supportedFilters = isPushed.collect { case (true, filter) => filter }.toArray
+        val unsupported = isPushed.collect { case (false, filter) => filter }.toArray
+        unsupported
+    }
   }

   override def pushedFilters(): Array[Filter] = supportedFilters
