 18 |  18 | import platform
 19 |  19 | import tempfile
 20 |  20 | import unittest
 21 |     | -from typing import Callable, Union
    |  21 | +from typing import Callable, Iterable, List, Union
 22 |  22 |
 23 |  23 | from pyspark.errors import PythonException, AnalysisException
 24 |  24 | from pyspark.sql.datasource import (
 25 |  25 |     DataSource,
 26 |  26 |     DataSourceReader,
    |  27 | +    EqualTo,
    |  28 | +    Filter,
 27 |  29 |     InputPartition,
 28 |  30 |     DataSourceWriter,
 29 |  31 |     DataSourceArrowWriter,
 30 |  32 |     WriterCommitMessage,
 31 |  33 |     CaseInsensitiveDict,
 32 |  34 | )
 33 |  35 | from pyspark.sql.functions import spark_partition_id
    |  36 | +from pyspark.sql.session import SparkSession
 34 |  37 | from pyspark.sql.types import Row, StructType
 35 |  38 | from pyspark.testing.sqlutils import (
 36 |  39 |     have_pyarrow,

 42 |  45 |
 43 |  46 | @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
 44 |  47 | class BasePythonDataSourceTestsMixin:
    |  48 | +    spark: SparkSession
    |  49 | +
 45 |  50 |     def test_basic_data_source_class(self):
 46 |  51 |         class MyDataSource(DataSource):
 47 |  52 |             ...
@@ -246,6 +251,161 @@ def reader(self, schema) -> "DataSourceReader":
246 | 251 |         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
247 | 252 |         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
248 | 253 |
    | 254 | +    def test_filter_pushdown(self):
    | 255 | +        class TestDataSourceReader(DataSourceReader):
    | 256 | +            def __init__(self):
    | 257 | +                self.has_filter = False
    | 258 | +
    | 259 | +            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
    | 260 | +                assert set(filters) == {
    | 261 | +                    EqualTo(("x",), 1),
    | 262 | +                    EqualTo(("y",), 2),
    | 263 | +                }, filters
    | 264 | +                self.has_filter = True
    | 265 | +                # pretend we support x = 1 filter but in fact we don't
    | 266 | +                # so we only return y = 2 filter
    | 267 | +                yield filters[filters.index(EqualTo(("y",), 2))]
    | 268 | +
    | 269 | +            def partitions(self):
    | 270 | +                assert self.has_filter
    | 271 | +                return super().partitions()
    | 272 | +
    | 273 | +            def read(self, partition):
    | 274 | +                assert self.has_filter
    | 275 | +                yield [1, 1]
    | 276 | +                yield [1, 2]
    | 277 | +                yield [2, 2]
    | 278 | +
    | 279 | +        class TestDataSource(DataSource):
    | 280 | +            @classmethod
    | 281 | +            def name(cls):
    | 282 | +                return "test"
    | 283 | +
    | 284 | +            def schema(self):
    | 285 | +                return "x int, y int"
    | 286 | +
    | 287 | +            def reader(self, schema) -> "DataSourceReader":
    | 288 | +                return TestDataSourceReader()
    | 289 | +
    | 290 | +        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
    | 291 | +            self.spark.dataSource.register(TestDataSource)
    | 292 | +            df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
    | 293 | +            # only the y = 2 filter is applied post scan
    | 294 | +            assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])
    | 295 | +
    | 296 | +    def test_extraneous_filter(self):
    | 297 | +        class TestDataSourceReader(DataSourceReader):
    | 298 | +            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
    | 299 | +                yield EqualTo(("x",), 1)
    | 300 | +
    | 301 | +            def partitions(self):
    | 302 | +                assert False
    | 303 | +
    | 304 | +            def read(self, partition):
    | 305 | +                assert False
    | 306 | +
    | 307 | +        class TestDataSource(DataSource):
    | 308 | +            @classmethod
    | 309 | +            def name(cls):
    | 310 | +                return "test"
    | 311 | +
    | 312 | +            def schema(self):
    | 313 | +                return "x int"
    | 314 | +
    | 315 | +            def reader(self, schema) -> "DataSourceReader":
    | 316 | +                return TestDataSourceReader()
    | 317 | +
    | 318 | +        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
    | 319 | +            self.spark.dataSource.register(TestDataSource)
    | 320 | +            with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
    | 321 | +                self.spark.read.format("test").load().filter("x = 1").show()
    | 322 | +
    | 323 | +    def test_filter_pushdown_error(self):
    | 324 | +        error_str = "dummy error"
    | 325 | +
    | 326 | +        class TestDataSourceReader(DataSourceReader):
    | 327 | +            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
    | 328 | +                raise Exception(error_str)
    | 329 | +
    | 330 | +            def read(self, partition):
    | 331 | +                yield [1]
    | 332 | +
    | 333 | +        class TestDataSource(DataSource):
    | 334 | +            def schema(self):
    | 335 | +                return "x int"
    | 336 | +
    | 337 | +            def reader(self, schema) -> "DataSourceReader":
    | 338 | +                return TestDataSourceReader()
    | 339 | +
    | 340 | +        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
    | 341 | +            self.spark.dataSource.register(TestDataSource)
    | 342 | +            df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null")
    | 343 | +            assertDataFrameEqual(df, [Row(x=1)])  # works when not pushing down filters
    | 344 | +            with self.assertRaisesRegex(Exception, error_str):
    | 345 | +                df.filter("x = 1").explain()
    | 346 | +
    | 347 | +    def test_filter_pushdown_disabled(self):
    | 348 | +        class TestDataSourceReader(DataSourceReader):
    | 349 | +            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
    | 350 | +                assert False
    | 351 | +
    | 352 | +            def read(self, partition):
    | 353 | +                assert False
    | 354 | +
    | 355 | +        class TestDataSource(DataSource):
    | 356 | +            def reader(self, schema) -> "DataSourceReader":
    | 357 | +                return TestDataSourceReader()
    | 358 | +
    | 359 | +        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": False}):
    | 360 | +            self.spark.dataSource.register(TestDataSource)
    | 361 | +            df = self.spark.read.format("TestDataSource").schema("x int").load()
    | 362 | +            with self.assertRaisesRegex(Exception, "DATA_SOURCE_PUSHDOWN_DISABLED"):
    | 363 | +                df.show()
    | 364 | +
    | 365 | +    def _check_filters(self, sql_type, sql_filter, python_filters):
    | 366 | +        """
    | 367 | +        Parameters
    | 368 | +        ----------
    | 369 | +        sql_type: str
    | 370 | +            The SQL type of the column x.
    | 371 | +        sql_filter: str
    | 372 | +            A SQL filter using the column x.
    | 373 | +        python_filters: List[Filter]
    | 374 | +            The expected python filters to be pushed down.
    | 375 | +        """
    | 376 | +
    | 377 | +        class TestDataSourceReader(DataSourceReader):
    | 378 | +            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
    | 379 | +                expected = python_filters
    | 380 | +                assert filters == expected, (filters, expected)
    | 381 | +                return filters
    | 382 | +
    | 383 | +            def read(self, partition):
    | 384 | +                yield from []
    | 385 | +
    | 386 | +        class TestDataSource(DataSource):
    | 387 | +            def schema(self):
    | 388 | +                return f"x {sql_type}"
    | 389 | +
    | 390 | +            def reader(self, schema) -> "DataSourceReader":
    | 391 | +                return TestDataSourceReader()
    | 392 | +
    | 393 | +        with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
    | 394 | +            self.spark.dataSource.register(TestDataSource)
    | 395 | +            df = self.spark.read.format("TestDataSource").load().filter(sql_filter)
    | 396 | +            df.count()
    | 397 | +
    | 398 | +    def test_unsupported_filter(self):
    | 399 | +        self._check_filters(
    | 400 | +            "struct<a:int, b:int, c:int>", "x.a = 1 and x.b = x.c", [EqualTo(("x", "a"), 1)]
    | 401 | +        )
    | 402 | +        self._check_filters("int", "x <> 0", [])
    | 403 | +        self._check_filters("int", "x = 1 or x > 2", [])
    | 404 | +        self._check_filters("int", "(0 < x and x < 1) or x = 2", [])
    | 405 | +        self._check_filters("int", "x % 5 = 1", [])
    | 406 | +        self._check_filters("boolean", "not x", [])
    | 407 | +        self._check_filters("array<int>", "x[0] = 1", [])
    | 408 | +
249 | 409 |     def _get_test_json_data_source(self):
250 | 410 |         import json
251 | 411 |         import os
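For reviewers unfamiliar with the API these tests exercise, below is a minimal sketch of the pushdown contract, not part of this diff: Spark calls `pushFilters` with the translatable predicates as `Filter` objects before `partitions()` and `read()`, and the reader returns or yields the filters it cannot handle so that Spark re-applies them after the scan. The `ExampleReader`/`ExampleDataSource` names and the filtering logic are illustrative assumptions, not code from this change.

```python
from typing import Iterable, List

from pyspark.sql.datasource import DataSource, DataSourceReader, EqualTo, Filter


class ExampleReader(DataSourceReader):
    """Illustrative reader that only knows how to evaluate EqualTo filters."""

    def __init__(self):
        self.pushed: List[Filter] = []

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            if isinstance(f, EqualTo):
                # Accept the filter: read() is now responsible for honoring it.
                self.pushed.append(f)
            else:
                # Reject the filter: Spark re-applies it after the scan.
                yield f

    def read(self, partition):
        # A real reader would use self.pushed to skip rows at the source;
        # a fixed row keeps this sketch self-contained.
        yield (1,)


class ExampleDataSource(DataSource):
    def schema(self):
        return "x int"

    def reader(self, schema) -> "DataSourceReader":
        return ExampleReader()
```

Yielding a filter that was not in the incoming list is the mistake `test_extraneous_filter` checks for (`DATA_SOURCE_EXTRANEOUS_FILTERS`), and `test_filter_pushdown_disabled` verifies that implementing `pushFilters` while `spark.sql.python.filterPushdown.enabled` is false fails with `DATA_SOURCE_PUSHDOWN_DISABLED`.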