@@ -95,7 +95,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       pythonResult: PythonDataSourceReader,
       outputSchema: StructType,
       isStreaming: Boolean): PythonDataSourceReadInfo = {
-    new PartitionRunner(
+    new UserDefinedPythonDataSourcePartitionRunner(
       createPythonFunction(pythonResult.reader),
       UserDefinedPythonDataSource.readInputSchema,
       outputSchema,
@@ -327,10 +327,64 @@ private class UserDefinedPythonDataSourceRunner(
   }
 }
 
+case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
+
+/**
+ * Instantiate the reader of a Python data source.
+ *
+ * @param func
+ *   a Python data source instance
+ * @param outputSchema
+ *   output schema of the Python data source
+ * @param isStreaming
+ *   whether it is a streaming read
+ */
+private class UserDefinedPythonDataSourceReaderRunner(
+    func: PythonFunction,
+    outputSchema: StructType,
+    isStreaming: Boolean)
+  extends PythonPlannerRunner[PythonDataSourceReader](func) {
+
+  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
+  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+    // Send Python data source
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    dataOut.writeBoolean(isStreaming)
+  }
+
+  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
+    // Receive the pickled reader or an exception raised in the Python worker.
+    val length = dataIn.readInt()
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
+    }
+
+    // Receive the pickled reader.
+    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
+  }
+}
+
 case class PythonFilterPushdownResult(
     reader: PythonDataSourceReader,
     isFilterPushed: collection.Seq[Boolean])
 
+/**
+ * Push down filters to a Python data source.
+ *
+ * @param reader
+ *   a Python data source reader instance
+ * @param filters
+ *   all filters to be pushed down
+ */
 private class UserDefinedPythonDataSourceFilterPushdownRunner(
     reader: PythonFunction,
     filters: collection.Seq[Filter])
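The new UserDefinedPythonDataSourceReaderRunner follows the usual planner-runner handshake: the driver writes the pickled data source, the output schema as JSON, and the streaming flag, then reads back either a length-prefixed pickled reader or an exception sentinel followed by an error message. The standalone Scala sketch below illustrates that length-prefixed framing convention only; the sentinel value and helper names are assumptions for illustration, not Spark's SpecialLengths or PythonWorkerUtils.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object FramingSketch {
  // Hypothetical stand-in for a sentinel like PYTHON_EXCEPTION_THROWN.
  val ExceptionSentinel: Int = -1

  // Encode a reply: either an error message or a pickled payload.
  def writeReply(payload: Either[String, Array[Byte]]): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    payload match {
      case Left(errorMsg) =>
        // Negative sentinel, then the UTF-8 message prefixed by its length.
        out.writeInt(ExceptionSentinel)
        val msgBytes = errorMsg.getBytes(StandardCharsets.UTF_8)
        out.writeInt(msgBytes.length)
        out.write(msgBytes)
      case Right(pickled) =>
        // Positive length, then exactly that many payload bytes.
        out.writeInt(pickled.length)
        out.write(pickled)
    }
    out.flush()
    buffer.toByteArray
  }

  // Decode a reply produced by writeReply.
  def readReply(raw: Array[Byte]): Either[String, Array[Byte]] = {
    val in = new DataInputStream(new ByteArrayInputStream(raw))
    val length = in.readInt()
    if (length == ExceptionSentinel) {
      val msgBytes = new Array[Byte](in.readInt())
      in.readFully(msgBytes)
      Left(new String(msgBytes, StandardCharsets.UTF_8))
    } else {
      val pickled = new Array[Byte](length)
      in.readFully(pickled)
      Right(pickled)
    }
  }
}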
@@ -346,7 +400,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
         case (filter, i) =>
           filter match {
             case filter @ org.apache.spark.sql.sources.EqualTo(_, value: Int) =>
-              val columnPath = filter.v2references.head
+              val columnPath = filter.v2references.head
               Some(SerializedFilter("EqualTo", columnPath, value, i))
             case _ =>
               None
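For context on the serialization above: only supported filters (here, integer EqualTo) are turned into SerializedFilter records, but each record keeps the filter's position in the original list, since the worker later refers to filters by index. The toy sketch below, with made-up names, shows why the index must be captured before unsupported filters are dropped.

object FilterIndexSketch {
  // Simplified stand-ins for Catalyst source filters; names are illustrative only.
  sealed trait ToyFilter
  final case class ToyEqualTo(column: String, value: Int) extends ToyFilter
  final case class ToyGreaterThan(column: String, value: Int) extends ToyFilter

  final case class ToySerializedFilter(name: String, column: String, value: Int, index: Int)

  // Only equality filters on Int are serialized, but the original position i is
  // preserved so later references by index still point into the full filter list.
  def serialize(filters: Seq[ToyFilter]): Seq[ToySerializedFilter] =
    filters.zipWithIndex.flatMap {
      case (ToyEqualTo(col, v), i) => Some(ToySerializedFilter("EqualTo", col, v, i))
      case _                       => None
    }

  // serialize(Seq(ToyGreaterThan("a", 1), ToyEqualTo("b", 2)))
  //   == Seq(ToySerializedFilter("EqualTo", "b", 2, 1))   // index 1, not 0
}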
@@ -381,7 +435,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
       throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
     }
 
-    // Receive the pickled 'reader'.
+    // Receive the pickled reader.
     val pickledReader: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
 
     // Receive the pushed filters as a list of indices.
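After receiving the pickled reader, the runner reads the pushed filters back as a list of indices into the originally supplied filter list; those indices are what ultimately become the per-filter booleans exposed as PythonFilterPushdownResult.isFilterPushed. A minimal sketch of that index-to-mask conversion, using a hypothetical helper name, follows.

object PushedFilterMaskSketch {
  // Hypothetical helper (not Spark's API): convert the index list reported by
  // the Python worker into a per-filter boolean mask over the original filters.
  def toPushedMask(numFilters: Int, pushedIndices: Seq[Int]): Seq[Boolean] = {
    val pushed = pushedIndices.toSet
    (0 until numFilters).map(pushed.contains)
  }

  // Example: three candidate filters, the reader accepted indices 0 and 2.
  // toPushedMask(3, Seq(0, 2)) == Seq(true, false, true)
}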
@@ -399,50 +453,6 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
   }
 }
 
-case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
-
-/**
- * Send information to a Python process to plan a Python data source read.
- *
- * @param func
- *   an Python data source instance
- * @param outputSchema
- *   output schema of the Python data source
- */
-private class UserDefinedPythonDataSourceReaderRunner(
-    func: PythonFunction,
-    outputSchema: StructType,
-    isStreaming: Boolean)
-  extends PythonPlannerRunner[PythonDataSourceReader](func) {
-
-  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
-  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
-
-  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
-    // Send Python data source
-    PythonWorkerUtils.writePythonFunction(func, dataOut)
-
-    // Send output schema
-    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
-
-    dataOut.writeBoolean(isStreaming)
-  }
-
-  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
-    // Receive the picked reader or an exception raised in Python worker.
-    val length = dataIn.readInt()
-    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
-      val msg = PythonWorkerUtils.readUTF(dataIn)
-      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
-    }
-
-    // Receive the pickled 'read' function.
-    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
-
-    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
-  }
-}
-
 case class PythonDataSourceReadInfo(
     func: Array[Byte],
     partitions: Seq[Array[Byte]])
@@ -459,7 +469,7 @@ case class PythonDataSourceReadInfo(
  * @param isStreaming
  *   whether it is a streaming read
  */
-private class PartitionRunner(
+private class UserDefinedPythonDataSourcePartitionRunner(
     reader: PythonFunction,
     inputSchema: StructType,
     outputSchema: StructType,