diff --git a/python/ray/data/grouped_data.py b/python/ray/data/grouped_data.py
index 42b9340d699..89398469dc0 100644
--- a/python/ray/data/grouped_data.py
+++ b/python/ray/data/grouped_data.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterator as IteratorABC
 from functools import partial
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
@@ -538,7 +539,10 @@ def std(
 
 
 def _apply_udf_to_groups(
-    udf: Callable[[DataBatch, ...], DataBatch],
+    udf: Union[
+        Callable[[DataBatch, ...], DataBatch],
+        Callable[[DataBatch, ...], Iterator[DataBatch]],
+    ],
     block: Block,
     keys: List[str],
     batch_format: Optional[str],
@@ -548,7 +552,8 @@ def _apply_udf_to_groups(
     """Apply UDF to groups of rows having the same set of values of the specified
     columns (keys).
 
-    NOTE: This function is defined at module level to avoid capturing closures and make it serializable."""
+    NOTE: This function is defined at module level to avoid capturing closures and make it serializable.
+    """
 
     block_accessor = BlockAccessor.for_block(block)
     boundaries = block_accessor._get_group_boundaries_sorted(keys)
@@ -560,7 +565,17 @@ def _apply_udf_to_groups(
         # Convert corresponding block of each group to batch format here,
         # because the block format here can be different from batch format
         # (e.g. block is Arrow format, and batch is NumPy format).
-        yield udf(group_block_accessor.to_batch_format(batch_format), *args, **kwargs)
+        result = udf(
+            group_block_accessor.to_batch_format(batch_format), *args, **kwargs
+        )
+
+        # Check if the UDF returned an iterator/generator.
+        if isinstance(result, IteratorABC):
+            # If so, yield each item from the iterator.
+            yield from result
+        else:
+            # Otherwise, yield the single result.
+            yield result
 
 
 # Backwards compatibility alias.
diff --git a/python/ray/data/tests/test_groupby_e2e.py b/python/ray/data/tests/test_groupby_e2e.py
index bd0017125e6..fb1c1795a72 100644
--- a/python/ray/data/tests/test_groupby_e2e.py
+++ b/python/ray/data/tests/test_groupby_e2e.py
@@ -1,7 +1,7 @@
 import itertools
 import random
 import time
-from typing import Optional
+from typing import Iterator, Optional
 
 import numpy as np
 import pandas as pd
@@ -1142,6 +1142,40 @@ def func(x, y):
     assert "MapBatches(func)" in ds.__repr__()
 
 
+def test_map_groups_generator_udf(ray_start_regular_shared_2_cpus):
+    """
+    Tests that map_groups supports UDFs that return generators (iterators).
+    """
+    ds = ray.data.from_items(
+        [
+            {"group": 1, "data": 10},
+            {"group": 1, "data": 20},
+            {"group": 2, "data": 30},
+        ]
+    )
+
+    def generator_udf(df: pd.DataFrame) -> Iterator[pd.DataFrame]:
+        # For each group, yield two DataFrames.
+        # 1. A DataFrame where 'data' is multiplied by 2.
+        yield df.assign(data=df["data"] * 2)
+        # 2. A DataFrame where 'data' is multiplied by 3.
+        yield df.assign(data=df["data"] * 3)
+
+    # Apply the generator UDF to the grouped data.
+    result_ds = ds.groupby("group").map_groups(generator_udf)
+
+    # The final dataset should contain all results from all yields.
+    # Group 1 -> data: [20, 40] and [30, 60]
+    # Group 2 -> data: [60] and [90]
+    expected_data = sorted([20, 40, 30, 60, 60, 90])
+
+    # Collect and sort the actual data to ensure correctness regardless of order.
+    actual_data = sorted([row["data"] for row in result_ds.take_all()])
+
+    assert actual_data == expected_data
+    assert result_ds.count() == 6
+
+
 if __name__ == "__main__":
     import sys