@@ -168,6 +168,101 @@ def batched(iterator: Iterator, n: int) -> Iterator:
         yield batch


+def write_read_func_and_partitions(
+    outfile: IO,
+    *,
+    reader: Union[DataSourceReader, DataSourceStreamReader],
+    data_source: DataSource,
+    schema: StructType,
+    max_arrow_batch_size: int,
+) -> None:
+    is_streaming = isinstance(reader, DataSourceStreamReader)
+
+    # Create input converter.
+    converter = ArrowTableToRowsConversion._create_converter(BinaryType())
+
+    # Create output converter.
+    return_type = schema
+
+    def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
+        partition_bytes = None
+
+        # Get the partition value from the input iterator.
+        for batch in iterator:
+            # There should be only one row/column in the batch.
+            assert batch.num_columns == 1 and batch.num_rows == 1, (
+                "Expected each batch to have exactly 1 column and 1 row, "
+                f"but found {batch.num_columns} columns and {batch.num_rows} rows."
+            )
+            columns = [column.to_pylist() for column in batch.columns]
+            partition_bytes = converter(columns[0][0])
+
+        assert (
+            partition_bytes is not None
+        ), "The input iterator for Python data source read function is empty."
+
+        # Deserialize the partition value.
+        partition = pickleSer.loads(partition_bytes)
+
+        assert partition is None or isinstance(partition, InputPartition), (
+            "Expected the partition value to be of type 'InputPartition', "
+            f"but found '{type(partition).__name__}'."
+        )
+
+        output_iter = reader.read(partition)  # type: ignore[arg-type]
+
+        # Validate the output iterator.
+        if not isinstance(output_iter, Iterator):
+            raise PySparkRuntimeError(
+                errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
+                messageParameters={
+                    "type": type(output_iter).__name__,
+                    "name": data_source.name(),
+                    "supported_types": "iterator",
+                },
+            )
+
+        return records_to_arrow_batches(output_iter, max_arrow_batch_size, return_type, data_source)
+
+    command = (data_source_read_func, return_type)
+    pickleSer._write_with_length(command, outfile)
+
+    if not is_streaming:
+        # The partitioning of python batch source read is determined before query execution.
+        try:
+            partitions = reader.partitions()  # type: ignore[call-arg]
+            if not isinstance(partitions, list):
+                raise PySparkRuntimeError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": "'partitions' to return a list",
+                        "actual": f"'{type(partitions).__name__}'",
+                    },
+                )
+            if not all(isinstance(p, InputPartition) for p in partitions):
+                partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
+                raise PySparkRuntimeError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": "elements in 'partitions' to be of type 'InputPartition'",
+                        "actual": partition_types,
+                    },
+                )
+            if len(partitions) == 0:
+                partitions = [None]  # type: ignore[list-item]
+        except NotImplementedError:
+            partitions = [None]  # type: ignore[list-item]
+
+        # Return the serialized partition values.
+        write_int(len(partitions), outfile)
+        for partition in partitions:
+            pickleSer._write_with_length(partition, outfile)
+    else:
+        # Send an empty list of partitions for the stream reader because partitions are planned
+        # in each microbatch during query execution.
+        write_int(0, outfile)
+
+
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read.
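Note: the checks performed by the new helper mirror the public Python Data Source API contract: partitions() must return a list of InputPartition objects (or raise NotImplementedError to fall back to a single partition), and read(partition) must return an iterator of rows. A minimal reader that would pass both checks is sketched below; RangePartition and RangeReader are illustrative names only and are not part of this patch.

from typing import Iterator, Tuple

from pyspark.sql.datasource import DataSourceReader, InputPartition


class RangePartition(InputPartition):
    def __init__(self, start: int, end: int) -> None:
        self.start = start
        self.end = end


class RangeReader(DataSourceReader):
    def partitions(self) -> list:
        # Must be a list of InputPartition instances; anything else makes the
        # worker raise DATA_SOURCE_TYPE_MISMATCH.
        return [RangePartition(i * 10, (i + 1) * 10) for i in range(3)]

    def read(self, partition: RangePartition) -> Iterator[Tuple[int]]:
        # Must return an iterator (a generator is fine); a non-iterator makes
        # the worker raise DATA_SOURCE_INVALID_RETURN_TYPE.
        for i in range(partition.start, partition.end):
            yield (i,)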
@@ -284,91 +379,14 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )

-        # Create input converter.
-        converter = ArrowTableToRowsConversion._create_converter(BinaryType())
-
-        # Create output converter.
-        return_type = schema
-
-        def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
-            partition_bytes = None
-
-            # Get the partition value from the input iterator.
-            for batch in iterator:
-                # There should be only one row/column in the batch.
-                assert batch.num_columns == 1 and batch.num_rows == 1, (
-                    "Expected each batch to have exactly 1 column and 1 row, "
-                    f"but found {batch.num_columns} columns and {batch.num_rows} rows."
-                )
-                columns = [column.to_pylist() for column in batch.columns]
-                partition_bytes = converter(columns[0][0])
-
-            assert (
-                partition_bytes is not None
-            ), "The input iterator for Python data source read function is empty."
-
-            # Deserialize the partition value.
-            partition = pickleSer.loads(partition_bytes)
-
-            assert partition is None or isinstance(partition, InputPartition), (
-                "Expected the partition value to be of type 'InputPartition', "
-                f"but found '{type(partition).__name__}'."
-            )
-
-            output_iter = reader.read(partition)  # type: ignore[arg-type]
-
-            # Validate the output iterator.
-            if not isinstance(output_iter, Iterator):
-                raise PySparkRuntimeError(
-                    errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
-                    messageParameters={
-                        "type": type(output_iter).__name__,
-                        "name": data_source.name(),
-                        "supported_types": "iterator",
-                    },
-                )
-
-            return records_to_arrow_batches(
-                output_iter, max_arrow_batch_size, return_type, data_source
-            )
-
-        command = (data_source_read_func, return_type)
-        pickleSer._write_with_length(command, outfile)
-
-        if not is_streaming:
-            # The partitioning of python batch source read is determined before query execution.
-            try:
-                partitions = reader.partitions()  # type: ignore[call-arg]
-                if not isinstance(partitions, list):
-                    raise PySparkRuntimeError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "'partitions' to return a list",
-                            "actual": f"'{type(partitions).__name__}'",
-                        },
-                    )
-                if not all(isinstance(p, InputPartition) for p in partitions):
-                    partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
-                    raise PySparkRuntimeError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "elements in 'partitions' to be of type 'InputPartition'",
-                            "actual": partition_types,
-                        },
-                    )
-                if len(partitions) == 0:
-                    partitions = [None]  # type: ignore[list-item]
-            except NotImplementedError:
-                partitions = [None]  # type: ignore[list-item]
-
-            # Return the serialized partition values.
-            write_int(len(partitions), outfile)
-            for partition in partitions:
-                pickleSer._write_with_length(partition, outfile)
-        else:
-            # Send an empty list of partition for stream reader because partitions are planned
-            # in each microbatch during query execution.
-            write_int(0, outfile)
+        # Send the read function and partitions to the JVM.
+        write_read_func_and_partitions(
+            outfile,
+            reader=reader,
+            data_source=data_source,
+            schema=schema,
+            max_arrow_batch_size=max_arrow_batch_size,
+        )
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
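For reference, after this change write_read_func_and_partitions is the single place that emits the planning payload the JVM reads back: the pickled read function, then the partition count, then each pickled partition. The sketch below is a rough, standard-library-only approximation of that framing; the exact byte layout of write_int and pickleSer._write_with_length (4-byte big-endian integers and length-prefixed pickled blobs) is an assumption made purely for illustration.

import io
import pickle
import struct


def _write_int(buf: io.BytesIO, value: int) -> None:
    # Stand-in for pyspark.serializers.write_int: assumed 4-byte big-endian int.
    buf.write(struct.pack("!i", value))


def _write_with_length(buf: io.BytesIO, obj) -> None:
    # Stand-in for pickleSer._write_with_length: assumed length prefix + pickled bytes.
    payload = pickle.dumps(obj)
    _write_int(buf, len(payload))
    buf.write(payload)


buf = io.BytesIO()
partitions = [None]  # a batch reader without partitions() falls back to [None]
_write_int(buf, len(partitions))  # partition count
for p in partitions:
    _write_with_length(buf, p)  # each serialized partition value
print(f"wrote {len(buf.getvalue())} bytes")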