12 changes: 12 additions & 0 deletions doc/source/conf.py
@@ -92,6 +92,18 @@

myst_heading_anchors = 3

# Make broken internal references into build-time errors.
# See https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-nitpicky
# for more information. :py:class: references are ignored due to false positives
# arising from type annotations. See https://github.com/ray-project/ray/pull/46103
# for additional context.
nitpicky = True
nitpick_ignore_regex = [
("py:class", ".*"),
# Workaround for https://github.com/sphinx-doc/sphinx/issues/10974
("py:obj", "ray\.data\.datasource\.datasink\.WriteReturnType"),
]
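
As a hedged illustration (not part of this diff), the same mechanism can silence further false positives by appending more (domain, target) patterns; the entries below are made-up examples, not real Ray references.

# Hypothetical conf.py additions -- example targets only.
nitpick_ignore_regex += [
    ("py:func", r"some_external_pkg\..*"),  # ignore a whole third-party namespace by regex
]
nitpick_ignore = [
    ("py:mod", "some_external_pkg"),  # exact-match ignore entries are also supported
]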

# Cache notebook outputs in _build/.jupyter_cache
# To prevent notebook execution, set this to "off". To force re-execution, set this to
# "force". To cache previous runs, set this to "cache".
49 changes: 49 additions & 0 deletions doc/source/data/api/input_output.rst
@@ -168,6 +168,53 @@ Databricks

read_databricks_tables

Delta Sharing
-------------

.. autosummary::
:nosignatures:
:toctree: doc/

read_delta_sharing_tables

Hudi
----

.. autosummary::
:nosignatures:
:toctree: doc/

read_hudi

Iceberg
-------

.. autosummary::
:nosignatures:
:toctree: doc/

read_iceberg
Dataset.write_iceberg

Lance
-----

.. autosummary::
:nosignatures:
:toctree: doc/

read_lance
Dataset.write_lance

ClickHouse
----------

.. autosummary::
:nosignatures:
:toctree: doc/

read_clickhouse

Dask
----

@@ -270,6 +317,8 @@ Datasink API
datasource.RowBasedFileDatasink
datasource.BlockBasedFileDatasink
datasource.FileBasedDatasource
datasource.WriteResult
datasource.WriteReturnType

Partitioning API
----------------
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional

from ray.data._internal.progress_bar import ProgressBar
@@ -39,3 +39,6 @@ class TaskContext:

# The target maximum number of bytes to include in the task's output block.
target_max_block_size: Optional[int] = None

# Additional keyword arguments passed to the task.
kwargs: Dict[str, Any] = field(default_factory=dict)
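
A minimal sketch (not part of this diff) of how a block transform could read these kwargs at runtime; the key name and default below are illustrative only.

def my_transform(blocks, ctx):
    # `ctx` is the TaskContext for this task; `ctx.kwargs` carries any extra
    # keyword arguments the operator attached when submitting the task.
    tag = ctx.kwargs.get("job_tag", "untagged")
    for block in blocks:
        # `tag` could drive per-task logging or bookkeeping here.
        yield block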
@@ -212,7 +212,12 @@ def _dispatch_tasks(self):
num_returns="streaming",
name=self.name,
**self._ray_actor_task_remote_args,
).remote(DataContext.get_current(), ctx, *input_blocks)
).remote(
DataContext.get_current(),
ctx,
*input_blocks,
**self.get_map_task_kwargs(),
)

def _task_done_callback(actor_to_return):
# Return the actor that was running the task to the pool.
@@ -401,12 +406,14 @@ def submit(
data_context: DataContext,
ctx: TaskContext,
*blocks: Block,
**kwargs: Dict[str, Any],
) -> Iterator[Union[Block, List[BlockMetadata]]]:
yield from _map_task(
self._map_transformer,
data_context,
ctx,
*blocks,
**kwargs,
)

def __repr__(self):
21 changes: 21 additions & 0 deletions python/ray/data/_internal/execution/operators/map_operator.py
@@ -85,6 +85,25 @@ def __init__(
# too-large blocks, which may reduce parallelism for
# the subsequent operator.
self._additional_split_factor = None
# Callback functions that generate additional task kwargs
# for the map task.
self._map_task_kwargs_fns: List[Callable[[], Dict[str, Any]]] = []

def add_map_task_kwargs_fn(self, map_task_kwargs_fn: Callable[[], Dict[str, Any]]):
"""Add a callback function that generates additional kwargs for the map tasks.
In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`.
"""
self._map_task_kwargs_fns.append(map_task_kwargs_fn)

def get_map_task_kwargs(self) -> Dict[str, Any]:
"""Get the kwargs for the map task.
Subclasses should pass the returned kwargs to the map tasks.
In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`.
"""
kwargs = {}
for fn in self._map_task_kwargs_fns:
kwargs.update(fn())
return kwargs
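
A hedged usage sketch, assuming `map_op` is a MapOperator built elsewhere: each registered callback runs when kwargs are collected, and the resulting dicts are merged in registration order, so later callbacks win on key conflicts.

map_op.add_map_task_kwargs_fn(lambda: {"job_tag": "ingest"})
map_op.add_map_task_kwargs_fn(lambda: {"shard_index": 3})
map_op.get_map_task_kwargs()  # -> {"job_tag": "ingest", "shard_index": 3}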

def get_additional_split_factor(self) -> int:
if self._additional_split_factor is None:
@@ -402,6 +421,7 @@ def _map_task(
data_context: DataContext,
ctx: TaskContext,
*blocks: Block,
**kwargs: Dict[str, Any],
) -> Iterator[Union[Block, List[BlockMetadata]]]:
"""Remote function for a single operator task.

@@ -415,6 +435,7 @@
as the last generator return.
"""
DataContext._set_current(data_context)
ctx.kwargs.update(kwargs)
stats = BlockExecStats.builder()
map_transformer.set_target_max_block_size(ctx.target_max_block_size)
for b_out in map_transformer.apply_transform(iter(blocks), ctx):
Expand Up @@ -76,7 +76,8 @@ def _add_bundled_input(self, bundle: RefBundle):
self._map_transformer_ref,
data_context,
ctx,
*input_blocks,
*bundle.block_refs,
**self.get_map_task_kwargs(),
)
self._submit_data_task(gen, bundle)

6 changes: 6 additions & 0 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -1,3 +1,4 @@
import itertools
from typing import List, Optional, Tuple

# TODO(Clark): Remove compute dependency once we delete the legacy compute.
@@ -311,6 +312,11 @@ def _get_fused_map_operator(
min_rows_per_bundle=min_rows_per_bundled_input,
ray_remote_args=ray_remote_args,
)
op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators)
for map_task_kwargs_fn in itertools.chain(
up_op._map_task_kwargs_fns, down_op._map_task_kwargs_fns
):
op.add_map_task_kwargs_fn(map_task_kwargs_fn)

# Build a map logical operator to be used as a reference for further fusion.
# TODO(Scott): This is hacky, remove this once we push fusion to be purely based
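A small sketch of the intended effect of the itertools.chain loop added above (hypothetical operators, not code from this diff): kwargs callbacks registered on either operator before fusion should survive on the fused operator.

# Suppose these callbacks were registered before fusion:
up_op.add_map_task_kwargs_fn(lambda: {"upstream_key": 1})
down_op.add_map_task_kwargs_fn(lambda: {"downstream_key": 2})
# After _get_fused_map_operator() builds `op`, both callbacks are carried over, so:
op.get_map_task_kwargs()  # -> {"upstream_key": 1, "downstream_key": 2}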
72 changes: 58 additions & 14 deletions python/ray/data/_internal/planner/plan_write_op.py
@@ -1,4 +1,7 @@
from typing import Callable, Iterator, Union
import itertools
from typing import Callable, Iterator, List, Union

from pandas import DataFrame

from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.execution.interfaces import PhysicalOperator
@@ -9,41 +12,82 @@
MapTransformer,
)
from ray.data._internal.logical.operators.write_operator import Write
from ray.data.block import Block
from ray.data.datasource.datasink import Datasink
from ray.data.block import Block, BlockAccessor
from ray.data.datasource.datasink import Datasink, WriteResult
from ray.data.datasource.datasource import Datasource


def gen_datasink_write_result(
write_result_blocks: List[Block],
) -> WriteResult:
assert all(
isinstance(block, DataFrame) and len(block) == 1
for block in write_result_blocks
)
total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks)
total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks)

write_returns = [result["write_return"][0] for result in write_result_blocks]
return WriteResult(total_num_rows, total_size_bytes, write_returns)
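
A hedged illustration of the one-row blocks this helper expects and the aggregate it produces; the row counts, sizes, and write returns below are made up, and the WriteResult field meanings are assumed from the constructor call above.

import pandas as pd

# Each write task emits a single-row pandas block with these three columns.
block_a = pd.DataFrame({"num_rows": [100], "size_bytes": [2048], "write_return": ["part-0"]})
block_b = pd.DataFrame({"num_rows": [50], "size_bytes": [1024], "write_return": ["part-1"]})

result = gen_datasink_write_result([block_a, block_b])
# Aggregates to 150 rows and 3072 bytes, with the per-task write returns
# collected in order: ["part-0", "part-1"].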


def generate_write_fn(
datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args
) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]:
# If the write op succeeds, the resulting Dataset is a list of
# arbitrary objects (one object per write task). Otherwise, an error will
# be raised. The Datasource can handle execution outcomes with
# on_write_complete() and on_write_failed().
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
"""Writes the blocks to the given datasink or legacy datasource.

Outputs the original blocks to be written."""
# Create a copy of the iterator, so we can return the original blocks.
it1, it2 = itertools.tee(blocks, 2)
if isinstance(datasink_or_legacy_datasource, Datasink):
write_result = datasink_or_legacy_datasource.write(blocks, ctx)
else:
write_result = datasink_or_legacy_datasource.write(
blocks, ctx, **write_args
ctx.kwargs["_datasink_write_return"] = datasink_or_legacy_datasource.write(
it1, ctx
)
else:
datasink_or_legacy_datasource.write(it1, ctx, **write_args)

return it2

return fn


def generate_collect_write_stats_fn() -> Callable[
[Iterator[Block], TaskContext], Iterator[Block]
]:
# If the write op succeeds, the resulting Dataset is a list of
# one Block that contains stats/metrics about the write.
# Otherwise, an error will be raised. The Datasource can handle
# execution outcomes with `on_write_complete()` and `on_write_failed()`.
def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
"""Handles stats collection for block writes."""
block_accessors = [BlockAccessor.for_block(block) for block in blocks]
total_num_rows = sum(ba.num_rows() for ba in block_accessors)
total_size_bytes = sum(ba.size_bytes() for ba in block_accessors)

# NOTE: Write tasks can return anything, so we need to wrap it in a valid block
# type.
import pandas as pd

block = pd.DataFrame({"write_result": [write_result]})
return [block]
block = pd.DataFrame(
{
"num_rows": [total_num_rows],
"size_bytes": [total_size_bytes],
"write_return": [ctx.kwargs.get("_datasink_write_return", None)],
}
)
return iter([block])

return fn
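
A hedged, self-contained way to see what this transform emits; SimpleNamespace stands in for a TaskContext, and the block contents and write return are made up.

import pandas as pd
from types import SimpleNamespace

stats_fn = generate_collect_write_stats_fn()
blocks = [pd.DataFrame({"x": range(100)}), pd.DataFrame({"x": range(50)})]
ctx = SimpleNamespace(kwargs={"_datasink_write_return": "part-0"})
[stats_block] = list(stats_fn(iter(blocks), ctx))
# `stats_block` is a one-row DataFrame with columns num_rows (150),
# size_bytes (total in-memory size of the blocks), and write_return ("part-0").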


def plan_write_op(op: Write, input_physical_dag: PhysicalOperator) -> PhysicalOperator:
write_fn = generate_write_fn(op._datasink_or_legacy_datasource, **op._write_args)
collect_stats_fn = generate_collect_write_stats_fn()
# Create a MapTransformer for a write operator
transform_fns = [
BlockMapTransformFn(write_fn),
BlockMapTransformFn(collect_stats_fn),
]
map_transformer = MapTransformer(transform_fns)
return MapOperator.create(