Skip to content

Commit d05ca1c

Browse files
[Data] - Fix expression mapping for Pandas (#57868)
1 parent c60e6c1 commit d05ca1c

File tree

2 files changed

+61
-7
lines changed

2 files changed

+61
-7
lines changed

python/ray/data/_internal/planner/plan_expression/expression_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def _pa_is_in(left: Any, right: Any) -> Any:
5353
Operation.NE: operator.ne,
5454
Operation.AND: operator.and_,
5555
Operation.OR: operator.or_,
56-
Operation.NOT: operator.not_,
56+
Operation.NOT: operator.invert,
5757
Operation.IS_NULL: pd.isna,
5858
Operation.IS_NOT_NULL: pd.notna,
59-
Operation.IN: lambda left, right: left.is_in(right),
60-
Operation.NOT_IN: lambda left, right: ~left.is_in(right),
59+
Operation.IN: lambda left, right: left.isin(right),
60+
Operation.NOT_IN: lambda left, right: ~left.isin(right),
6161
}
6262

6363

python/ray/data/tests/test_filter.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
123123
get_pyarrow_version() < parse_version("20.0.0"),
124124
reason="predicate expressions require PyArrow >= 20.0.0",
125125
)
126+
@pytest.mark.parametrize(
127+
"data_source",
128+
[
129+
pytest.param("from_items", id="arrow_blocks"),
130+
pytest.param("from_pandas", id="pandas_blocks"),
131+
],
132+
)
126133
@pytest.mark.parametrize(
127134
"predicate_expr, test_data, expected_indices, test_description",
128135
[
@@ -230,6 +237,38 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
230237
[0, 3],
231238
"string_exclusion_filter",
232239
),
240+
# Additional comparison operations
241+
pytest.param(
242+
col("age") > 25,
243+
[
244+
{"age": 20, "name": "Alice"},
245+
{"age": 25, "name": "Bob"},
246+
{"age": 30, "name": "Charlie"},
247+
{"age": 35, "name": "David"},
248+
],
249+
[2, 3],
250+
"greater_than_filter",
251+
),
252+
pytest.param(
253+
col("age") < 25,
254+
[
255+
{"age": 20, "name": "Alice"},
256+
{"age": 25, "name": "Bob"},
257+
{"age": 30, "name": "Charlie"},
258+
],
259+
[0],
260+
"less_than_filter",
261+
),
262+
pytest.param(
263+
col("age") <= 25,
264+
[
265+
{"age": 20, "name": "Alice"},
266+
{"age": 25, "name": "Bob"},
267+
{"age": 30, "name": "Charlie"},
268+
],
269+
[0, 1],
270+
"less_than_equal_filter",
271+
),
233272
# Membership operations
234273
pytest.param(
235274
col("category").is_in(["A", "B"]),
@@ -241,7 +280,18 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
241280
{"category": "A", "value": 5},
242281
],
243282
[0, 1, 4],
244-
"membership_filter",
283+
"is_in_filter",
284+
),
285+
pytest.param(
286+
col("category").not_in(["A", "B"]),
287+
[
288+
{"category": "A", "value": 1},
289+
{"category": "B", "value": 2},
290+
{"category": "C", "value": 3},
291+
{"category": "D", "value": 4},
292+
],
293+
[2, 3], # These are indices not the actual values
294+
"not_in_filter",
245295
),
246296
# Negation operations
247297
pytest.param(
@@ -271,14 +321,18 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
271321
)
272322
def test_filter_with_predicate_expressions(
273323
ray_start_regular_shared,
324+
data_source,
274325
predicate_expr,
275326
test_data,
276327
expected_indices,
277328
test_description,
278329
):
279-
"""Test filter() with Ray Data predicate expressions."""
280-
# Create dataset from test data
281-
ds = ray.data.from_items(test_data)
330+
"""Test filter() with Ray Data predicate expressions on both Arrow and pandas blocks."""
331+
# Create dataset based on data_source parameter
332+
if data_source == "from_items":
333+
ds = ray.data.from_items(test_data)
334+
else: # from_pandas
335+
ds = ray.data.from_pandas([pd.DataFrame(test_data)])
282336

283337
# Apply filter with predicate expression
284338
filtered_ds = ds.filter(expr=predicate_expr)

0 commit comments

Comments
 (0)