@@ -123,6 +123,13 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
123123 get_pyarrow_version () < parse_version ("20.0.0" ),
124124 reason = "predicate expressions require PyArrow >= 20.0.0" ,
125125)
126+ @pytest .mark .parametrize (
127+ "data_source" ,
128+ [
129+ pytest .param ("from_items" , id = "arrow_blocks" ),
130+ pytest .param ("from_pandas" , id = "pandas_blocks" ),
131+ ],
132+ )
126133@pytest .mark .parametrize (
127134 "predicate_expr, test_data, expected_indices, test_description" ,
128135 [
@@ -230,6 +237,38 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
230237 [0 , 3 ],
231238 "string_exclusion_filter" ,
232239 ),
240+ # Additional comparison operations
241+ pytest .param (
242+ col ("age" ) > 25 ,
243+ [
244+ {"age" : 20 , "name" : "Alice" },
245+ {"age" : 25 , "name" : "Bob" },
246+ {"age" : 30 , "name" : "Charlie" },
247+ {"age" : 35 , "name" : "David" },
248+ ],
249+ [2 , 3 ],
250+ "greater_than_filter" ,
251+ ),
252+ pytest .param (
253+ col ("age" ) < 25 ,
254+ [
255+ {"age" : 20 , "name" : "Alice" },
256+ {"age" : 25 , "name" : "Bob" },
257+ {"age" : 30 , "name" : "Charlie" },
258+ ],
259+ [0 ],
260+ "less_than_filter" ,
261+ ),
262+ pytest .param (
263+ col ("age" ) <= 25 ,
264+ [
265+ {"age" : 20 , "name" : "Alice" },
266+ {"age" : 25 , "name" : "Bob" },
267+ {"age" : 30 , "name" : "Charlie" },
268+ ],
269+ [0 , 1 ],
270+ "less_than_equal_filter" ,
271+ ),
233272 # Membership operations
234273 pytest .param (
235274 col ("category" ).is_in (["A" , "B" ]),
@@ -241,7 +280,18 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
241280 {"category" : "A" , "value" : 5 },
242281 ],
243282 [0 , 1 , 4 ],
244- "membership_filter" ,
283+ "is_in_filter" ,
284+ ),
285+ pytest .param (
286+ col ("category" ).not_in (["A" , "B" ]),
287+ [
288+ {"category" : "A" , "value" : 1 },
289+ {"category" : "B" , "value" : 2 },
290+ {"category" : "C" , "value" : 3 },
291+ {"category" : "D" , "value" : 4 },
292+ ],
293+ [2 , 3 ], # Row positions in test_data (the "C" and "D" rows), not entries of the "value" column
294+ "not_in_filter" ,
245295 ),
246296 # Negation operations
247297 pytest .param (
@@ -271,14 +321,18 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path):
271321)
272322def test_filter_with_predicate_expressions (
273323 ray_start_regular_shared ,
324+ data_source ,
274325 predicate_expr ,
275326 test_data ,
276327 expected_indices ,
277328 test_description ,
278329):
279- """Test filter() with Ray Data predicate expressions."""
280- # Create dataset from test data
281- ds = ray .data .from_items (test_data )
330+ """Test filter() with Ray Data predicate expressions on both Arrow and pandas blocks."""
331+ # Create dataset based on data_source parameter
332+ if data_source == "from_items" :
333+ ds = ray .data .from_items (test_data )
334+ else : # from_pandas
335+ ds = ray .data .from_pandas ([pd .DataFrame (test_data )])
282336
283337 # Apply filter with predicate expression
284338 filtered_ds = ds .filter (expr = predicate_expr )
0 commit comments