Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterator as IteratorABC
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`collections.abc.Iterator` is the same as `typing.Iterator`, which this file already imports, so just reuse `typing.Iterator` instead of adding a new import.

from functools import partial
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -537,7 +538,10 @@ def std(


def _apply_udf_to_groups(
udf: Callable[[DataBatch, ...], DataBatch],
udf: Union[
Callable[[DataBatch, ...], DataBatch],
Callable[[DataBatch, ...], Iterator[DataBatch]],
],
block: Block,
keys: List[str],
batch_format: Optional[str],
Expand All @@ -547,7 +551,8 @@ def _apply_udf_to_groups(
"""Apply UDF to groups of rows having the same set of values of the specified
columns (keys).

NOTE: This function is defined at module level to avoid capturing closures and make it serializable."""
NOTE: This function is defined at module level to avoid capturing closures and make it serializable.
"""
block_accessor = BlockAccessor.for_block(block)

boundaries = block_accessor._get_group_boundaries_sorted(keys)
Expand All @@ -559,7 +564,17 @@ def _apply_udf_to_groups(
# Convert corresponding block of each group to batch format here,
# because the block format here can be different from batch format
# (e.g. block is Arrow format, and batch is NumPy format).
yield udf(group_block_accessor.to_batch_format(batch_format), *args, **kwargs)
result = udf(
group_block_accessor.to_batch_format(batch_format), *args, **kwargs
)

# Check if the UDF returned an iterator/generator.
if isinstance(result, IteratorABC):
# If so, yield each item from the iterator.
yield from result
else:
# Otherwise, yield the single result.
yield result


# Backwards compatibility alias.
Expand Down
36 changes: 35 additions & 1 deletion python/ray/data/tests/test_groupby_e2e.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import random
import time
from typing import Optional
from typing import Iterator, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1142,6 +1142,40 @@ def func(x, y):
assert "MapBatches(func)" in ds.__repr__()


def test_map_groups_generator_udf(ray_start_regular_shared_2_cpus):
    """Verify that map_groups supports UDFs that return generators.

    Every DataFrame yielded by the UDF for a group is expected to show up in
    the resulting dataset.
    """
    ds = ray.data.from_items(
        [
            {"group": 1, "data": 10},
            {"group": 1, "data": 20},
            {"group": 2, "data": 30},
        ]
    )

    def generator_udf(df: pd.DataFrame) -> Iterator[pd.DataFrame]:
        # Emit two DataFrames per group: 'data' doubled, then 'data' tripled.
        for factor in (2, 3):
            yield df.assign(data=df["data"] * factor)

    result_ds = ds.groupby("group").map_groups(generator_udf)

    # Group 1 contributes [20, 40] and [30, 60]; group 2 contributes [60] and [90].
    expected_data = [20, 30, 40, 60, 60, 90]

    # Output order is not asserted, so compare the sorted values.
    rows = result_ds.take_all()
    actual_data = sorted(row["data"] for row in rows)

    assert actual_data == expected_data
    assert result_ds.count() == 6


if __name__ == "__main__":
import sys

Expand Down