
Commit 9d3bbc2

clean up UserDefinedPythonDataSource.scala
1 parent: 5064aa3

1 file changed: +58 -48 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala

Lines changed: 58 additions & 48 deletions
@@ -95,7 +95,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       pythonResult: PythonDataSourceReader,
       outputSchema: StructType,
       isStreaming: Boolean): PythonDataSourceReadInfo = {
-    new PartitionRunner(
+    new UserDefinedPythonDataSourcePartitionRunner(
       createPythonFunction(pythonResult.reader),
       UserDefinedPythonDataSource.readInputSchema,
       outputSchema,
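
This hunk only renames the runner at its call site. For orientation, here is a minimal sketch of how the construction above would be driven to completion, assuming the usual PythonPlannerRunner pattern; the trailing isStreaming argument is inferred from the @param docs in the last hunk, and the assignment itself is illustrative, not part of this commit:

// Sketch only (not in this diff): PythonPlannerRunner subclasses are executed
// with runInPython(), which ships writeToPython's payload to the Python worker
// and returns whatever receiveFromPython parses back.
val readInfo: PythonDataSourceReadInfo =
  new UserDefinedPythonDataSourcePartitionRunner(
    createPythonFunction(pythonResult.reader),
    UserDefinedPythonDataSource.readInputSchema,
    outputSchema,
    isStreaming).runInPython()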
@@ -327,10 +327,64 @@ private class UserDefinedPythonDataSourceRunner(
   }
 }
 
+case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
+
+/**
+ * Instantiate the reader of a Python data source.
+ *
+ * @param func
+ *   a Python data source instance
+ * @param outputSchema
+ *   output schema of the Python data source
+ * @param isStreaming
+ *   whether it is a streaming read
+ */
+private class UserDefinedPythonDataSourceReaderRunner(
+    func: PythonFunction,
+    outputSchema: StructType,
+    isStreaming: Boolean)
+  extends PythonPlannerRunner[PythonDataSourceReader](func) {
+
+  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
+  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+    // Send Python data source
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    dataOut.writeBoolean(isStreaming)
+  }
+
+  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
+    // Receive the picked reader or an exception raised in Python worker.
+    val length = dataIn.readInt()
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
+    }
+
+    // Receive the pickled reader.
+    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
+  }
+}
+
 case class PythonFilterPushdownResult(
     reader: PythonDataSourceReader,
     isFilterPushed: collection.Seq[Boolean])
 
+/**
+ * Push down filters to a Python data source.
+ *
+ * @param reader
+ *   a Python data source reader instance
+ * @param filters
+ *   all filters to be pushed down
+ */
 private class UserDefinedPythonDataSourceFilterPushdownRunner(
     reader: PythonFunction,
     filters: collection.Seq[Filter])
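
The relocated UserDefinedPythonDataSourceReaderRunner implements a three-frame handshake: writeToPython sends the pickled data source, the output schema as JSON, and the streaming flag, and receiveFromPython reads back either an exception sentinel or the pickled reader bytes. A minimal sketch of how its result feeds the partition runner from the first hunk; dataSource is a placeholder PythonFunction, not a name from this diff:

// Sketch only: plan the reader first, then reuse its pickled bytes.
val pythonResult: PythonDataSourceReader =
  new UserDefinedPythonDataSourceReaderRunner(dataSource, outputSchema, isStreaming)
    .runInPython()
// pythonResult.reader holds the pickled Python reader; the first hunk shows it
// re-wrapped via createPythonFunction(pythonResult.reader) for partition planning.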
@@ -346,7 +400,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
       case (filter, i) =>
         filter match {
           case filter @ org.apache.spark.sql.sources.EqualTo(_, value: Int) =>
-            val columnPath = filter.v2references.head
+            val columnPath = filter.v2references.head
             Some(SerializedFilter("EqualTo", columnPath, value, i))
           case _ =>
             None
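
In the EqualTo case above, v2references (from org.apache.spark.sql.sources.Filter) resolves the attribute name into a multipart column path, and head suffices because EqualTo references a single column. A hedged example of the mapping, with SerializedFilter's shape taken from the call in the hunk:

// Illustrative only: a filter on a nested column and the form sent to Python.
val filter = org.apache.spark.sql.sources.EqualTo("a.b", 1)
val columnPath = filter.v2references.head  // expected: Array("a", "b")
// -> SerializedFilter("EqualTo", columnPath, 1, i) for the filter's index i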
@@ -381,7 +435,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
       throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
     }
 
-    // Receive the pickled 'reader'.
+    // Receive the pickled reader.
     val pickledReader: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
 
     // Receive the pushed filters as a list of indices.
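
Past this point the worker replies with the indices of the filters it accepted, which the Scala side expands into the per-filter isFilterPushed flags carried by PythonFilterPushdownResult. A minimal sketch of that expansion; the local names are assumptions, not identifiers from this diff:

// Sketch only: turn the returned indices into one boolean per original filter.
// `filters` is the runner's constructor parameter shown in the previous hunk.
val pushedIndices: Set[Int] = Set(0, 2)  // as if read back from dataIn
val isFilterPushed: Seq[Boolean] = filters.indices.map(pushedIndices.contains)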
@@ -399,50 +453,6 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
   }
 }
 
-case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
-
-/**
- * Send information to a Python process to plan a Python data source read.
- *
- * @param func
- *   an Python data source instance
- * @param outputSchema
- *   output schema of the Python data source
- */
-private class UserDefinedPythonDataSourceReaderRunner(
-    func: PythonFunction,
-    outputSchema: StructType,
-    isStreaming: Boolean)
-  extends PythonPlannerRunner[PythonDataSourceReader](func) {
-
-  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
-  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
-
-  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
-    // Send Python data source
-    PythonWorkerUtils.writePythonFunction(func, dataOut)
-
-    // Send output schema
-    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
-
-    dataOut.writeBoolean(isStreaming)
-  }
-
-  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
-    // Receive the picked reader or an exception raised in Python worker.
-    val length = dataIn.readInt()
-    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
-      val msg = PythonWorkerUtils.readUTF(dataIn)
-      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
-    }
-
-    // Receive the pickled 'read' function.
-    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
-
-    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
-  }
-}
-
 case class PythonDataSourceReadInfo(
     func: Array[Byte],
     partitions: Seq[Array[Byte]])
@@ -459,7 +469,7 @@ case class PythonDataSourceReadInfo(
  * @param isStreaming
  *   whether it is a streaming read
  */
-private class PartitionRunner(
+private class UserDefinedPythonDataSourcePartitionRunner(
     reader: PythonFunction,
     inputSchema: StructType,
     outputSchema: StructType,
