From 146553d9a3f132387f81e457fc17e587aeec1622 Mon Sep 17 00:00:00 2001 From: jimmyxie-figma Date: Tue, 25 Feb 2025 13:44:11 -0500 Subject: [PATCH 1/8] [data] add iceberg write support through pyiceberg (#50589) Add iceberg write support through pyiceberg. Ray will distributedly write the dataset using the Datasink interface through blocks to an iceberg table leveraging pyarrow as parquet files, and commit append snapshot to the iceberg table. There are couple of limitations that's worth pointing out - Supports only appending to an existing Iceberg table. - Only supports the Parquet data file format. - Ray will validate schema compatibility between its blocks and the Iceberg table and will evolve the Iceberg table schema if needed. - A single block is bin-packed into the target output size; however, the sink does not bin-pack across different blocks. - This is a limitation of [the PyIceberg API](https://github.com/apache/iceberg-python/blob/main/pyiceberg/io/pyarrow.py#L2535C5-L2535C29), and cross-block bin packing can be addressed in the future by extending the API on the Pyiceberg side --------- Signed-off-by: Jimmy Xie --- doc/source/data/api/input_output.rst | 47 ++++ python/ray/data/dataset.py | 57 +++- .../ray/data/datasource/iceberg_datasink.py | 161 +++++++++++ python/ray/data/tests/test_iceberg.py | 262 ++++++++++++++++++ 4 files changed, 526 insertions(+), 1 deletion(-) create mode 100644 python/ray/data/datasource/iceberg_datasink.py create mode 100644 python/ray/data/tests/test_iceberg.py diff --git a/doc/source/data/api/input_output.rst b/doc/source/data/api/input_output.rst index d6b496990ef6..aafc0c91c1dc 100644 --- a/doc/source/data/api/input_output.rst +++ b/doc/source/data/api/input_output.rst @@ -168,6 +168,53 @@ Databricks read_databricks_tables +Delta Sharing +------------- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_delta_sharing_tables + +Hudi +---- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_hudi + +Iceberg +------- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_iceberg + Dataset.write_iceberg + +Lance +----- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_lance + Dataset.write_lance + +ClickHouse +---------- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + read_clickhouse + Dask ---- diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 780fb4acd957..f732cbaee898 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -30,7 +30,7 @@ from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray from ray.data._internal.block_list import BlockList from ray.data._internal.compute import ComputeStrategy -from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.datasource.iceberg_datasink import IcebergDatasink from ray.data._internal.equalize import _equalize from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.execution.legacy_compat import _block_list_to_bundles @@ -2845,6 +2845,61 @@ def write_json( concurrency=concurrency, ) + @ConsumptionAPI + @PublicAPI(stability="alpha") + def write_iceberg( + self, + table_identifier: str, + catalog_kwargs: Optional[Dict[str, Any]] = None, + snapshot_properties: Optional[Dict[str, str]] = None, + ray_remote_args: Dict[str, Any] = None, + concurrency: Optional[int] = None, + ) -> None: + """Writes the :class:`~ray.data.Dataset` to an Iceberg table. + + .. 
tip::
+            For more details on PyIceberg, see
+            https://py.iceberg.apache.org/.
+
+        Examples:
+            .. testcode::
+                :skipif: True
+
+                import ray
+                import pandas as pd
+                docs = [{"title": "Iceberg data sink test"} for key in range(4)]
+                ds = ray.data.from_pandas(pd.DataFrame(docs))
+                ds.write_iceberg(
+                    table_identifier="db_name.table_name",
+                    catalog_kwargs={"name": "default", "type": "sql"}
+                )
+
+        Args:
+            table_identifier: Fully qualified table identifier (``db_name.table_name``).
+            catalog_kwargs: Optional arguments to pass to PyIceberg's
+                catalog.load_catalog() function (e.g., name, type, etc.). For the
+                function definition, see the ``pyiceberg.catalog.load_catalog``
+                documentation.
+            snapshot_properties: Custom properties to attach to the snapshot when
+                committing to the Iceberg table.
+            ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
+            concurrency: The maximum number of Ray tasks to run concurrently. This
+                doesn't change the total number of tasks run; it only caps how many
+                run at the same time. By default, concurrency is dynamically decided
+                based on the available resources.
+        """
+
+        datasink = IcebergDatasink(
+            table_identifier, catalog_kwargs, snapshot_properties
+        )
+
+        self.write_datasink(
+            datasink,
+            ray_remote_args=ray_remote_args,
+            concurrency=concurrency,
+        )
+
     @PublicAPI(stability="alpha")
     @ConsumptionAPI
     def write_images(
diff --git a/python/ray/data/datasource/iceberg_datasink.py b/python/ray/data/datasource/iceberg_datasink.py
new file mode 100644
index 000000000000..7f2ae9964f7a
--- /dev/null
+++ b/python/ray/data/datasource/iceberg_datasink.py
@@ -0,0 +1,161 @@
+"""
+Module to write a Ray Dataset into an Iceberg table using the Ray Datasink API.
+"""
+import logging
+
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
+
+from ray.data.datasource.datasink import Datasink
+from ray.util.annotations import DeveloperAPI
+from ray.data.block import BlockAccessor, Block
+from ray.data._internal.execution.interfaces import TaskContext
+from ray.data.datasource.datasink import WriteResult
+import uuid
+
+if TYPE_CHECKING:
+    from pyiceberg.catalog import Catalog
+    from pyiceberg.manifest import DataFile
+
+
+logger = logging.getLogger(__name__)
+
+
+@DeveloperAPI
+class IcebergDatasink(Datasink[List["DataFile"]]):
+    """
+    Iceberg datasink to write a Ray Dataset into an existing Iceberg table. This class
+    relies on PyIceberg to write to the Iceberg table. All the routines in this class
+    override `ray.data.Datasink`.
+
+    """
+
+    def __init__(
+        self,
+        table_identifier: str,
+        catalog_kwargs: Optional[Dict[str, Any]] = None,
+        snapshot_properties: Optional[Dict[str, str]] = None,
+    ):
+        """
+        Initialize the IcebergDatasink.
+
+        Args:
+            table_identifier: The identifier of the table to write to, e.g. `default.taxi_dataset`
+            catalog_kwargs: Optional arguments to use when setting up the Iceberg
+                catalog
+            snapshot_properties: Custom properties to attach to the snapshot when
+                committing to the Iceberg table, e.g. {"commit_time": "2021-01-01T00:00:00Z"}
+        """
+
+        from pyiceberg.io import FileIO
+        from pyiceberg.table import Transaction
+        from pyiceberg.table.metadata import TableMetadata
+
+        self.table_identifier = table_identifier
+        self._catalog_kwargs = catalog_kwargs if catalog_kwargs is not None else {}
+        self._snapshot_properties = (
+            snapshot_properties if snapshot_properties is not None else {}
+        )
+
+        if "name" in self._catalog_kwargs:
+            self._catalog_name = self._catalog_kwargs.pop("name")
+        else:
+            self._catalog_name = "default"
+
+        self._uuid: uuid.UUID = None
+        self._io: FileIO = None
+        self._txn: Transaction = None
+        self._table_metadata: TableMetadata = None
+
+    # The Iceberg transaction is not picklable because of the table and catalog
+    # properties it holds, so exclude it during pickling and restore it as None.
+    def __getstate__(self) -> dict:
+        """Exclude `_txn` during pickling."""
+        state = self.__dict__.copy()
+        del state["_txn"]
+        return state
+
+    def __setstate__(self, state: dict) -> None:
+        self.__dict__.update(state)
+        self._txn = None
+
+    def _get_catalog(self) -> "Catalog":
+        from pyiceberg import catalog
+
+        return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs)
+
+    def on_write_start(self) -> None:
+        """Prepare the Iceberg transaction before any write tasks run."""
+        from pyiceberg.table import PropertyUtil, TableProperties
+
+        catalog = self._get_catalog()
+        table = catalog.load_table(self.table_identifier)
+        self._txn = table.transaction()
+        self._io = self._txn._table.io
+        self._table_metadata = self._txn.table_metadata
+        self._uuid = uuid.uuid4()
+
+        if unsupported_partitions := [
+            field
+            for field in self._table_metadata.spec().fields
+            if not field.transform.supports_pyarrow_transform
+        ]:
+            raise ValueError(
+                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
+ ) + + self._manifest_merge_enabled = PropertyUtil.property_as_bool( + self._table_metadata.properties, + TableProperties.MANIFEST_MERGE_ENABLED, + TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, + ) + + def write( + self, blocks: Iterable[Block], ctx: TaskContext + ) -> WriteResult[List["DataFile"]]: + from pyiceberg.io.pyarrow import ( + _check_pyarrow_schema_compatible, + _dataframe_to_data_files, + ) + from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE + from pyiceberg.utils.config import Config + + data_files_list: WriteResult[List["DataFile"]] = [] + for block in blocks: + pa_table = BlockAccessor.for_block(block).to_arrow() + + downcast_ns_timestamp_to_us = ( + Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + ) + _check_pyarrow_schema_compatible( + self._table_metadata.schema(), + provided_schema=pa_table.schema, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + ) + + if pa_table.shape[0] <= 0: + continue + + data_files = _dataframe_to_data_files( + self._table_metadata, pa_table, self._io, self._uuid + ) + data_files_list.extend(data_files) + + return data_files_list + + def on_write_complete(self, write_result: WriteResult[List["DataFile"]]): + update_snapshot = self._txn.update_snapshot( + snapshot_properties=self._snapshot_properties + ) + append_method = ( + update_snapshot.merge_append + if self._manifest_merge_enabled + else update_snapshot.fast_append + ) + + with append_method() as append_files: + append_files.commit_uuid = self._uuid + for data_files in write_result.write_returns: + for data_file in data_files: + append_files.append_data_file(data_file) + + self._txn.commit_transaction() diff --git a/python/ray/data/tests/test_iceberg.py b/python/ray/data/tests/test_iceberg.py new file mode 100644 index 000000000000..f89b50193827 --- /dev/null +++ b/python/ray/data/tests/test_iceberg.py @@ -0,0 +1,262 @@ +import os +import random + +import pyarrow as pa +import pytest +from pkg_resources import parse_version +from pyiceberg import catalog as pyi_catalog +from pyiceberg import expressions as pyi_expr +from pyiceberg import schema as pyi_schema +from pyiceberg import types as pyi_types +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.transforms import IdentityTransform + +import ray +from ray._private.utils import _get_pyarrow_version +from ray.data import read_iceberg +from ray.data._internal.datasource.iceberg_datasource import IcebergDatasource + +_CATALOG_NAME = "ray_catalog" +_DB_NAME = "ray_db" +_TABLE_NAME = "ray_test" +_WAREHOUSE_PATH = "/tmp/warehouse" + +_CATALOG_KWARGS = { + "name": _CATALOG_NAME, + "type": "sql", + "uri": f"sqlite:///{_WAREHOUSE_PATH}/ray_pyiceberg_test_catalog.db", + "warehouse": f"file://{_WAREHOUSE_PATH}", +} + +_SCHEMA = pa.schema( + [ + pa.field("col_a", pa.int32()), + pa.field("col_b", pa.string()), + pa.field("col_c", pa.int16()), + ] +) + + +def create_pa_table(): + return pa.Table.from_pydict( + mapping={ + "col_a": list(range(120)), + "col_b": random.choices(["a", "b", "c", "d"], k=120), + "col_c": random.choices(list(range(10)), k=120), + }, + schema=_SCHEMA, + ) + + +@pytest.fixture(autouse=True, scope="function") +def pyiceberg_table(): + from pyiceberg.catalog.sql import SqlCatalog + + if not os.path.exists(_WAREHOUSE_PATH): + os.makedirs(_WAREHOUSE_PATH) + dummy_catalog = SqlCatalog( + _CATALOG_NAME, + **{ + "uri": f"sqlite:///{_WAREHOUSE_PATH}/ray_pyiceberg_test_catalog.db", + "warehouse": f"file://{_WAREHOUSE_PATH}", + }, + ) + + pya_table = 
create_pa_table()
+
+    if (_DB_NAME,) not in dummy_catalog.list_namespaces():
+        dummy_catalog.create_namespace(_DB_NAME)
+    if (_DB_NAME, _TABLE_NAME) in dummy_catalog.list_tables(_DB_NAME):
+        dummy_catalog.drop_table(f"{_DB_NAME}.{_TABLE_NAME}")
+
+    # Create the table, and add data to it
+    table = dummy_catalog.create_table(
+        f"{_DB_NAME}.{_TABLE_NAME}",
+        schema=pyi_schema.Schema(
+            pyi_types.NestedField(
+                field_id=1,
+                name="col_a",
+                field_type=pyi_types.IntegerType(),
+                required=False,
+            ),
+            pyi_types.NestedField(
+                field_id=2,
+                name="col_b",
+                field_type=pyi_types.StringType(),
+                required=False,
+            ),
+            pyi_types.NestedField(
+                field_id=3,
+                name="col_c",
+                field_type=pyi_types.IntegerType(),
+                required=False,
+            ),
+        ),
+        partition_spec=PartitionSpec(
+            PartitionField(
+                source_id=3, field_id=3, transform=IdentityTransform(), name="col_c"
+            )
+        ),
+    )
+    table.append(pya_table)
+
+    # Delete some data so there are delete file(s)
+    table.delete(delete_filter=pyi_expr.GreaterThanOrEqual("col_a", 101))
+
+
+@pytest.mark.skipif(
+    parse_version(_get_pyarrow_version()) < parse_version("14.0.0"),
+    reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0",
+)
+def test_get_catalog():
+
+    iceberg_ds = IcebergDatasource(
+        table_identifier=f"{_DB_NAME}.{_TABLE_NAME}",
+        catalog_kwargs=_CATALOG_KWARGS.copy(),
+    )
+    catalog = iceberg_ds._get_catalog()
+    assert catalog.name == _CATALOG_NAME
+
+
+@pytest.mark.skipif(
+    parse_version(_get_pyarrow_version()) < parse_version("14.0.0"),
+    reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0",
+)
+def test_plan_files():
+
+    iceberg_ds = IcebergDatasource(
+        table_identifier=f"{_DB_NAME}.{_TABLE_NAME}",
+        catalog_kwargs=_CATALOG_KWARGS.copy(),
+    )
+    plan_files = iceberg_ds.plan_files
+    assert len(plan_files) == 10
+
+
+@pytest.mark.skipif(
+    parse_version(_get_pyarrow_version()) < parse_version("14.0.0"),
+    reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0",
+)
+def test_chunk_plan_files():
+
+    iceberg_ds = IcebergDatasource(
+        table_identifier=f"{_DB_NAME}.{_TABLE_NAME}",
+        catalog_kwargs=_CATALOG_KWARGS.copy(),
+    )
+
+    chunks = iceberg_ds._distribute_tasks_into_equal_chunks(iceberg_ds.plan_files, 5)
+    assert all(len(c) == 2 for c in chunks), chunks
+
+    chunks = iceberg_ds._distribute_tasks_into_equal_chunks(iceberg_ds.plan_files, 20)
+    assert (
+        sum(len(c) == 1 for c in chunks) == 10
+        and sum(len(c) == 0 for c in chunks) == 10
+    ), chunks
+
+
+@pytest.mark.skipif(
+    parse_version(_get_pyarrow_version()) < parse_version("14.0.0"),
+    reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0",
+)
+def test_get_read_tasks():
+
+    iceberg_ds = IcebergDatasource(
+        table_identifier=f"{_DB_NAME}.{_TABLE_NAME}",
+        catalog_kwargs=_CATALOG_KWARGS.copy(),
+    )
+    read_tasks = iceberg_ds.get_read_tasks(5)
+    assert len(read_tasks) == 5
+    assert all(len(rt.metadata.input_files) == 2 for rt in read_tasks)
+
+
+@pytest.mark.skipif(
+    parse_version(_get_pyarrow_version()) < parse_version("14.0.0"),
+    reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0",
+)
+def test_filtered_read():
+
+    from pyiceberg import expressions as pyi_expr
+
+    iceberg_ds = IcebergDatasource(
+        table_identifier=f"{_DB_NAME}.{_TABLE_NAME}",
+        row_filter=pyi_expr.In("col_c", {1, 2, 3, 4}),
+        selected_fields=("col_b",),
+        catalog_kwargs=_CATALOG_KWARGS.copy(),
+    )
+    read_tasks = iceberg_ds.get_read_tasks(5)
+    # Should be capped to 4, as there will be only 4 files
+    assert len(read_tasks) == 4, read_tasks
+    assert all(len(rt.metadata.input_files) == 1 for rt in read_tasks)
+
+
+@pytest.mark.skipif(
+    
parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), + reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", +) +def test_read_basic(): + + row_filter = pyi_expr.In("col_c", {1, 2, 3, 4, 5, 6, 7, 8}) + + ray_ds = read_iceberg( + table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", + row_filter=row_filter, + selected_fields=("col_a", "col_b"), + catalog_kwargs=_CATALOG_KWARGS.copy(), + ) + table: pa.Table = pa.concat_tables((ray.get(ref) for ref in ray_ds.to_arrow_refs())) + + # string -> large_string because pyiceberg by default chooses large_string + expected_schema = pa.schema( + [pa.field("col_a", pa.int32()), pa.field("col_b", pa.large_string())] + ) + assert table.schema.equals(expected_schema) + + # Read the raw table from PyIceberg + sql_catalog = pyi_catalog.load_catalog(**_CATALOG_KWARGS) + orig_table_p = ( + sql_catalog.load_table(f"{_DB_NAME}.{_TABLE_NAME}") + .scan(row_filter=row_filter, selected_fields=("col_a", "col_b")) + .to_pandas() + .sort_values(["col_a", "col_b"]) + .reset_index(drop=True) + ) + + # Actually compare the tables now + table_p = ray_ds.to_pandas().sort_values(["col_a", "col_b"]).reset_index(drop=True) + assert orig_table_p.equals(table_p) + + +@pytest.mark.skipif( + parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), + reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", +) +def test_write_basic(): + + sql_catalog = pyi_catalog.load_catalog(**_CATALOG_KWARGS) + table = sql_catalog.load_table(f"{_DB_NAME}.{_TABLE_NAME}") + table.delete() + + ds = ray.data.from_arrow(create_pa_table()) + ds.write_iceberg( + table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", + catalog_kwargs=_CATALOG_KWARGS.copy(), + ) + + # Read the raw table from PyIceberg after writing + table = sql_catalog.load_table(f"{_DB_NAME}.{_TABLE_NAME}") + orig_table_p = ( + table.scan() + .to_pandas() + .sort_values(["col_a", "col_b", "col_c"]) + .reset_index(drop=True) + ) + + table_p = ( + ds.to_pandas().sort_values(["col_a", "col_b", "col_c"]).reset_index(drop=True) + ) + assert orig_table_p.equals(table_p) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) From 1a1cdd06d72d4211d794ccf8ae528608b1aa32a4 Mon Sep 17 00:00:00 2001 From: votrou Date: Mon, 31 Mar 2025 14:18:46 -0700 Subject: [PATCH 2/8] add back removed import --- python/ray/data/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index f732cbaee898..fa3ea9b08ece 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -30,6 +30,7 @@ from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray from ray.data._internal.block_list import BlockList from ray.data._internal.compute import ComputeStrategy +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.datasource.iceberg_datasink import IcebergDatasink from ray.data._internal.equalize import _equalize from ray.data._internal.execution.interfaces import RefBundle From 9b5a26a13475cec83823dd4cb9175477d3dbff06 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Thu, 10 Oct 2024 13:57:33 -0700 Subject: [PATCH 3/8] cherry pick write result --- .../data/_internal/planner/plan_write_op.py | 44 ++++++++--- python/ray/data/dataset.py | 10 ++- .../ray/data/datasource/bigquery_datasink.py | 6 +- python/ray/data/datasource/datasink.py | 79 ++++++++++++++++--- python/ray/data/datasource/file_datasink.py | 20 +++-- python/ray/data/datasource/mongo_datasink.py | 6 +- 
.../ray/data/datasource/parquet_datasink.py | 6 +- python/ray/data/datasource/sql_datasink.py | 6 +- python/ray/data/tests/test_bigquery.py | 25 ++++-- python/ray/data/tests/test_datasink.py | 4 +- python/ray/data/tests/test_formats.py | 13 +-- 11 files changed, 154 insertions(+), 65 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_write_op.py b/python/ray/data/_internal/planner/plan_write_op.py index c33e831fde0b..39184b5d1279 100644 --- a/python/ray/data/_internal/planner/plan_write_op.py +++ b/python/ray/data/_internal/planner/plan_write_op.py @@ -1,4 +1,5 @@ -from typing import Callable, Iterator, Union +import itertools +from typing import Callable, Iterator, List, Union from ray.data._internal.compute import TaskPoolStrategy from ray.data._internal.execution.interfaces import PhysicalOperator @@ -9,41 +10,60 @@ MapTransformer, ) from ray.data._internal.logical.operators.write_operator import Write -from ray.data.block import Block -from ray.data.datasource.datasink import Datasink +from ray.data.block import Block, BlockAccessor +from ray.data.datasource.datasink import Datasink, WriteResult from ray.data.datasource.datasource import Datasource def generate_write_fn( datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args ) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]: - # If the write op succeeds, the resulting Dataset is a list of - # arbitrary objects (one object per write task). Otherwise, an error will - # be raised. The Datasource can handle execution outcomes with the - # on_write_complete() and on_write_failed(). def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: + """Writes the blocks to the given datasink or legacy datasource. + + Outputs the original blocks to be written.""" + # Create a copy of the iterator, so we can return the original blocks. + it1, it2 = itertools.tee(blocks, 2) if isinstance(datasink_or_legacy_datasource, Datasink): - write_result = datasink_or_legacy_datasource.write(blocks, ctx) + datasink_or_legacy_datasource.write(it1, ctx) else: - write_result = datasink_or_legacy_datasource.write( - blocks, ctx, **write_args - ) + datasink_or_legacy_datasource.write(it1, ctx, **write_args) + return it2 + + return fn + + +def generate_collect_write_stats_fn() -> Callable[ + [Iterator[Block], TaskContext], Iterator[Block] +]: + # If the write op succeeds, the resulting Dataset is a list of + # one Block which contain stats/metrics about the write. + # Otherwise, an error will be raised. The Datasource can handle + # execution outcomes with `on_write_complete()`` and `on_write_failed()``. + def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: + """Handles stats collection for block writes.""" + block_accessors = [BlockAccessor.for_block(block) for block in blocks] + total_num_rows = sum(ba.num_rows() for ba in block_accessors) + total_size_bytes = sum(ba.size_bytes() for ba in block_accessors) # NOTE: Write tasks can return anything, so we need to wrap it in a valid block # type. 
import pandas as pd + write_result = WriteResult(num_rows=total_num_rows, size_bytes=total_size_bytes) block = pd.DataFrame({"write_result": [write_result]}) - return [block] + return iter([block]) return fn def plan_write_op(op: Write, input_physical_dag: PhysicalOperator) -> PhysicalOperator: write_fn = generate_write_fn(op._datasink_or_legacy_datasource, **op._write_args) + collect_stats_fn = generate_collect_write_stats_fn() # Create a MapTransformer for a write operator transform_fns = [ BlockMapTransformFn(write_fn), + BlockMapTransformFn(collect_stats_fn), ] map_transformer = MapTransformer(transform_fns) return MapOperator.create( diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index fa3ea9b08ece..516fafc866dc 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3636,13 +3636,15 @@ def write_datasink( datasink.on_write_start() self._write_ds = Dataset(plan, logical_plan).materialize() - blocks = ray.get(self._write_ds._plan.execute().get_blocks()) + # TODO: Get and handle the blocks with an iterator instead of getting + # everything in a blocking way, so some blocks can be freed earlier. + raw_write_results = ray.get(self._write_ds._plan.execute().block_refs) assert all( - isinstance(block, pd.DataFrame) and len(block) == 1 for block in blocks + isinstance(block, pd.DataFrame) and len(block) == 1 + for block in raw_write_results ) - write_results = [block["write_result"][0] for block in blocks] + datasink.on_write_complete(raw_write_results) - datasink.on_write_complete(write_results) except Exception as e: datasink.on_write_failed(e) raise diff --git a/python/ray/data/datasource/bigquery_datasink.py b/python/ray/data/datasource/bigquery_datasink.py index 33550f0791cb..be1341a29ae5 100644 --- a/python/ray/data/datasource/bigquery_datasink.py +++ b/python/ray/data/datasource/bigquery_datasink.py @@ -3,7 +3,7 @@ import tempfile import time import uuid -from typing import Any, Iterable, Optional +from typing import Iterable, Optional import pyarrow.parquet as pq @@ -70,7 +70,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: def _write_single_block(block: Block, project_id: str, dataset: str) -> None: from google.api_core import exceptions from google.cloud import bigquery @@ -127,5 +127,3 @@ def _write_single_block(block: Block, project_id: str, dataset: str) -> None: for block in blocks ] ) - - return "ok" diff --git a/python/ray/data/datasource/datasink.py b/python/ray/data/datasource/datasink.py index f77c3b93f3d2..5354bbe41cc3 100644 --- a/python/ray/data/datasource/datasink.py +++ b/python/ray/data/datasource/datasink.py @@ -1,10 +1,51 @@ -from typing import Any, Iterable, List, Optional +import logging +from dataclasses import dataclass, fields +from typing import Iterable, List, Optional import ray from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor from ray.util.annotations import DeveloperAPI +logger = logging.getLogger(__name__) + + +@dataclass +@DeveloperAPI +class WriteResult: + """Result of a write operation, containing stats/metrics + on the written data. + + Attributes: + total_num_rows: The total number of rows written. + total_size_bytes: The total size of the written data in bytes. + """ + + num_rows: int = 0 + size_bytes: int = 0 + + @staticmethod + def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult": + """Aggregate a list of write results. + + Args: + write_results: A list of write results. 
+ + Returns: + A single write result that aggregates the input results. + """ + total_num_rows = 0 + total_size_bytes = 0 + + for write_result in write_results: + total_num_rows += write_result.num_rows + total_size_bytes += write_result.size_bytes + + return WriteResult( + num_rows=total_num_rows, + size_bytes=total_size_bytes, + ) + @DeveloperAPI class Datasink: @@ -26,7 +67,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: """Write blocks. This is used by a single write task. Args: @@ -39,7 +80,7 @@ def write( """ raise NotImplementedError - def on_write_complete(self, write_results: List[Any]) -> None: + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: """Callback for when a write job completes. This can be used to "commit" a write output. This method must @@ -50,6 +91,27 @@ def on_write_complete(self, write_results: List[Any]) -> None: write_results: The objects returned by every :meth:`~Datasink.write` task. """ pass + write_result_blocks: The blocks resulting from executing + the Write operator, containing write results and stats. + Returns: + A ``WriteResult`` object containing the aggregated stats of all + the input write results. + """ + write_results = [ + result["write_result"].iloc[0] for result in write_result_blocks + ] + aggregated_write_results = WriteResult.aggregate_write_results(write_results) + + aggregated_results_str = "" + for k in fields(aggregated_write_results.__class__): + v = getattr(aggregated_write_results, k.name) + aggregated_results_str += f"\t{k}: {v}\n" + + logger.info( + f"Write operation succeeded. Aggregated write results:\n" + f"{aggregated_results_str}" + ) + return aggregated_write_results def on_write_failed(self, error: Exception) -> None: """Callback for when a write job fails. 
@@ -110,10 +172,9 @@ def __init__(self): self.rows_written = 0 self.enabled = True - def write(self, block: Block) -> str: + def write(self, block: Block) -> None: block = BlockAccessor.for_block(block) self.rows_written += block.num_rows() - return "ok" def get_rows_written(self): return self.rows_written @@ -127,18 +188,18 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: tasks = [] if not self.enabled: raise ValueError("disabled") for b in blocks: tasks.append(self.data_sink.write.remote(b)) ray.get(tasks) - return "ok" - def on_write_complete(self, write_results: List[Any]) -> None: - assert all(w == "ok" for w in write_results), write_results + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: self.num_ok += 1 + aggregated_results = super().on_write_complete(write_result_blocks) + return aggregated_results def on_write_failed(self, error: Exception) -> None: self.num_failed += 1 diff --git a/python/ray/data/datasource/file_datasink.py b/python/ray/data/datasource/file_datasink.py index c4bf3fb6c867..ac30e0500b5c 100644 --- a/python/ray/data/datasource/file_datasink.py +++ b/python/ray/data/datasource/file_datasink.py @@ -8,8 +8,8 @@ from ray.data._internal.util import _is_local_scheme, call_with_retry from ray.data.block import Block, BlockAccessor from ray.data.context import DataContext +from ray.data.datasource.datasink import Datasink, WriteResult from ray.data.datasource.block_path_provider import BlockWritePathProvider -from ray.data.datasource.datasink import Datasink from ray.data.datasource.filename_provider import ( FilenameProvider, _DefaultFilenameProvider, @@ -106,7 +106,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: builder = DelegatingBlockBuilder() for block in blocks: builder.add_block(block) @@ -114,23 +114,21 @@ def write( block_accessor = BlockAccessor.for_block(block) if block_accessor.num_rows() == 0: - logger.get_logger().warning(f"Skipped writing empty block to {self.path}") - return "skip" + logger.warning(f"Skipped writing empty block to {self.path}") + return self.write_block(block_accessor, 0, ctx) - # TODO: decide if we want to return richer object when the task - # succeeds. - return "ok" def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext): raise NotImplementedError - def on_write_complete(self, write_results: List[Any]) -> None: - if not self.has_created_dir: - return + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: + aggregated_results = super().on_write_complete(write_result_blocks) - if all(write_results == "skip" for write_results in write_results): + # If no rows were written, we can delete the directory. 
+ if self.has_created_dir and aggregated_results.num_rows == 0: self.filesystem.delete_dir(self.path) + return aggregated_results @property def supports_distributed_writes(self) -> bool: diff --git a/python/ray/data/datasource/mongo_datasink.py b/python/ray/data/datasource/mongo_datasink.py index f2c20355a272..11909845db0b 100644 --- a/python/ray/data/datasource/mongo_datasink.py +++ b/python/ray/data/datasource/mongo_datasink.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Iterable +from typing import Iterable from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.execution.interfaces import TaskContext @@ -24,7 +24,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: import pymongo _validate_database_collection_exist( @@ -44,5 +44,3 @@ def write_block(uri: str, database: str, collection: str, block: Block): block = builder.build() write_block(self.uri, self.database, self.collection, block) - - return "ok" diff --git a/python/ray/data/datasource/parquet_datasink.py b/python/ray/data/datasource/parquet_datasink.py index a8e085e5e0f3..834ff845be61 100644 --- a/python/ray/data/datasource/parquet_datasink.py +++ b/python/ray/data/datasource/parquet_datasink.py @@ -57,13 +57,13 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: import pyarrow.parquet as pq blocks = list(blocks) if all(BlockAccessor.for_block(block).num_rows() == 0 for block in blocks): - return "skip" + return filename = self.filename_provider.get_filename_for_block( blocks[0], ctx.task_idx, 0 @@ -90,8 +90,6 @@ def write_blocks_to_path(): max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS, ) - return "ok" - @property def num_rows_per_write(self) -> Optional[int]: return self.num_rows_per_file diff --git a/python/ray/data/datasource/sql_datasink.py b/python/ray/data/datasource/sql_datasink.py index f29480ae6b1c..8807956f4e94 100644 --- a/python/ray/data/datasource/sql_datasink.py +++ b/python/ray/data/datasource/sql_datasink.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable +from typing import Callable, Iterable from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor @@ -18,7 +18,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: with _connect(self.connection_factory) as cursor: for block in blocks: block_accessor = BlockAccessor.for_block(block) @@ -33,5 +33,3 @@ def write( if values: cursor.executemany(self.sql, values) - - return "ok" diff --git a/python/ray/data/tests/test_bigquery.py b/python/ray/data/tests/test_bigquery.py index e26d2a383437..7e571f7ec841 100644 --- a/python/ray/data/tests/test_bigquery.py +++ b/python/ray/data/tests/test_bigquery.py @@ -1,3 +1,4 @@ +from typing import Iterator from unittest import mock import pyarrow as pa @@ -8,7 +9,11 @@ from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream import ray -from ray.data.datasource import BigQueryDatasource, _BigQueryDatasink +from ray.data._internal.datasource.bigquery_datasink import BigQueryDatasink +from ray.data._internal.datasource.bigquery_datasource import BigQueryDatasource +from ray.data._internal.planner.plan_write_op import generate_collect_write_stats_fn +from ray.data.block import Block +from ray.data.datasource.datasink import WriteResult from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa from ray.tests.conftest import * # noqa @@ -196,6 
+201,9 @@ def test_create_reader_table_not_found(self): class TestWriteBigQuery: """Tests for BigQuery Write.""" + def _extract_write_result(self, stats: Iterator[Block]): + return dict(next(stats).iloc[0])["write_result"] + def test_write(self, ray_get_mock): bq_datasink = _BigQueryDatasink( project_id=_TEST_GCP_PROJECT_ID, @@ -203,11 +211,15 @@ def test_write(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) - status = bq_datasink.write( + bq_datasink.write( blocks=[block], ctx=None, ) - assert status == "ok" + + collect_stats_fn = generate_collect_write_stats_fn() + stats = collect_stats_fn([block], None) + write_result = self._extract_write_result(stats) + assert write_result == WriteResult(num_rows=4, size_bytes=32) def test_write_dataset_exists(self, ray_get_mock): bq_datasink = _BigQueryDatasink( @@ -216,11 +228,14 @@ def test_write_dataset_exists(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) - status = bq_datasink.write( + bq_datasink.write( blocks=[block], ctx=None, ) - assert status == "ok" + collect_stats_fn = generate_collect_write_stats_fn() + stats = collect_stats_fn([block], None) + write_result = self._extract_write_result(stats) + assert write_result == WriteResult(num_rows=4, size_bytes=32) if __name__ == "__main__": diff --git a/python/ray/data/tests/test_datasink.py b/python/ray/data/tests/test_datasink.py index 772078490601..714f03c6dfe3 100644 --- a/python/ray/data/tests/test_datasink.py +++ b/python/ray/data/tests/test_datasink.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable +from typing import Iterable import pytest @@ -14,7 +14,7 @@ class MockDatasink(Datasink): def __init__(self, num_rows_per_write): self._num_rows_per_write = num_rows_per_write - def write(self, blocks: Iterable[Block], ctx: TaskContext) -> Any: + def write(self, blocks: Iterable[Block], ctx: TaskContext) -> None: assert sum(len(block) for block in blocks) == self._num_rows_per_write @property diff --git a/python/ray/data/tests/test_formats.py b/python/ray/data/tests/test_formats.py index 8039c173e4b8..706b2e60bfc0 100644 --- a/python/ray/data/tests/test_formats.py +++ b/python/ray/data/tests/test_formats.py @@ -1,5 +1,6 @@ import os -from typing import Any, Iterable, List +import sys +from typing import Iterable, List import pandas as pd import pyarrow as pa @@ -14,6 +15,7 @@ from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor from ray.data.datasource import Datasink, DummyOutputDatasink +from ray.data.datasource.datasink import WriteResult from ray.data.datasource.file_meta_provider import _handle_read_os_error from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa @@ -219,7 +221,6 @@ def write(self, node_id: str, block: Block) -> str: block = BlockAccessor.for_block(block) self.rows_written += block.num_rows() self.node_ids.add(node_id) - return "ok" def get_rows_written(self): return self.rows_written @@ -235,7 +236,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: data_sink = self.data_sink def write(b): @@ -246,11 +247,11 @@ def write(b): for b in blocks: tasks.append(write(b)) ray.get(tasks) - return "ok" - def on_write_complete(self, write_results: List[Any]) -> None: - assert all(w == "ok" for w in write_results), write_results + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: self.num_ok += 1 + 
aggregated_results = super().on_write_complete(write_result_blocks) + return aggregated_results def on_write_failed(self, error: Exception) -> None: self.num_failed += 1 From 0bef7cfdbf6df18e9148c455abc1cc856859ecce Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 27 Dec 2024 16:55:17 -0800 Subject: [PATCH 4/8] backport write result --- doc/source/conf.py | 12 ++ doc/source/data/api/input_output.rst | 2 + .../data/_internal/planner/plan_write_op.py | 34 +++- python/ray/data/dataset.py | 17 +- python/ray/data/datasource/__init__.py | 9 ++ .../ray/data/datasource/bigquery_datasink.py | 2 +- python/ray/data/datasource/datasink.py | 89 +++-------- python/ray/data/datasource/file_datasink.py | 17 +- python/ray/data/datasource/mongo_datasink.py | 2 +- python/ray/data/datasource/sql_datasink.py | 2 +- python/ray/data/tests/test_bigquery.py | 40 +++-- python/ray/data/tests/test_datasink.py | 148 +++++++++++++++++- python/ray/data/tests/test_formats.py | 102 +----------- 13 files changed, 274 insertions(+), 202 deletions(-) diff --git a/doc/source/conf.py b/doc/source/conf.py index 49a062315362..ff1c99b62c49 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -92,6 +92,18 @@ myst_heading_anchors = 3 +# Make broken internal references into build time errors. +# See https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-nitpicky +# for more information. :py:class: references are ignored due to false positives +# arising from type annotations. See https://github.com/ray-project/ray/pull/46103 +# for additional context. +nitpicky = True +nitpick_ignore_regex = [ + ("py:class", ".*"), + # Workaround for https://github.com/sphinx-doc/sphinx/issues/10974 + ("py:obj", "ray\.data\.datasource\.datasink\.WriteReturnType"), +] + # Cache notebook outputs in _build/.jupyter_cache # To prevent notebook execution, set this to "off". To force re-execution, set this to # "force". To cache previous runs, set this to "cache". 
diff --git a/doc/source/data/api/input_output.rst b/doc/source/data/api/input_output.rst index aafc0c91c1dc..bd29acbe98f3 100644 --- a/doc/source/data/api/input_output.rst +++ b/doc/source/data/api/input_output.rst @@ -317,6 +317,8 @@ Datasink API datasource.RowBasedFileDatasink datasource.BlockBasedFileDatasink datasource.FileBasedDatasource + datasource.WriteResult + datasource.WriteReturnType Partitioning API ---------------- diff --git a/python/ray/data/_internal/planner/plan_write_op.py b/python/ray/data/_internal/planner/plan_write_op.py index 39184b5d1279..ab61ea90d7b6 100644 --- a/python/ray/data/_internal/planner/plan_write_op.py +++ b/python/ray/data/_internal/planner/plan_write_op.py @@ -1,6 +1,8 @@ import itertools from typing import Callable, Iterator, List, Union +from pandas import DataFrame + from ray.data._internal.compute import TaskPoolStrategy from ray.data._internal.execution.interfaces import PhysicalOperator from ray.data._internal.execution.interfaces.task_context import TaskContext @@ -15,19 +17,36 @@ from ray.data.datasource.datasource import Datasource +def gen_datasink_write_result( + write_result_blocks: List[Block], +) -> WriteResult: + assert all( + isinstance(block, DataFrame) and len(block) == 1 + for block in write_result_blocks + ) + total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks) + total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks) + + write_returns = [result["write_return"][0] for result in write_result_blocks] + return WriteResult(total_num_rows, total_size_bytes, write_returns) + + def generate_write_fn( datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args ) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]: - def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: + def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: """Writes the blocks to the given datasink or legacy datasource. Outputs the original blocks to be written.""" # Create a copy of the iterator, so we can return the original blocks. it1, it2 = itertools.tee(blocks, 2) if isinstance(datasink_or_legacy_datasource, Datasink): - datasink_or_legacy_datasource.write(it1, ctx) + ctx.kwargs["_datasink_write_return"] = datasink_or_legacy_datasource.write( + it1, ctx + ) else: datasink_or_legacy_datasource.write(it1, ctx, **write_args) + return it2 return fn @@ -40,7 +59,7 @@ def generate_collect_write_stats_fn() -> Callable[ # one Block which contain stats/metrics about the write. # Otherwise, an error will be raised. The Datasource can handle # execution outcomes with `on_write_complete()`` and `on_write_failed()``. - def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: + def fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: """Handles stats collection for block writes.""" block_accessors = [BlockAccessor.for_block(block) for block in blocks] total_num_rows = sum(ba.num_rows() for ba in block_accessors) @@ -50,8 +69,13 @@ def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: # type. 
import pandas as pd - write_result = WriteResult(num_rows=total_num_rows, size_bytes=total_size_bytes) - block = pd.DataFrame({"write_result": [write_result]}) + block = pd.DataFrame( + { + "num_rows": [total_num_rows], + "size_bytes": [total_size_bytes], + "write_return": [ctx.kwargs.get("_datasink_write_return", None)], + } + ) return iter([block]) return fn diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 516fafc866dc..6d1c3c083e92 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -35,6 +35,10 @@ from ray.data._internal.equalize import _equalize from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.execution.legacy_compat import _block_list_to_bundles +from ray.data._internal.execution.interfaces.ref_bundle import ( + _ref_bundles_iterator_to_block_refs_list, +) +from ray.data._internal.execution.util import memory_string from ray.data._internal.iterator.iterator_impl import DataIteratorImpl from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator from ray.data._internal.lazy_block_list import LazyBlockList @@ -61,6 +65,7 @@ from ray.data._internal.pandas_block import PandasBlockSchema from ray.data._internal.plan import ExecutionPlan from ray.data._internal.planner.exchange.sort_task_spec import SortKey +from ray.data._internal.planner.plan_write_op import gen_datasink_write_result from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.split import _get_num_rows, _split_at_indices from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, StatsManager @@ -3631,7 +3636,6 @@ def write_datasink( logical_plan = LogicalPlan(write_op) try: - import pandas as pd datasink.on_write_start() @@ -3639,11 +3643,14 @@ def write_datasink( # TODO: Get and handle the blocks with an iterator instead of getting # everything in a blocking way, so some blocks can be freed earlier. raw_write_results = ray.get(self._write_ds._plan.execute().block_refs) - assert all( - isinstance(block, pd.DataFrame) and len(block) == 1 - for block in raw_write_results + write_result = gen_datasink_write_result(raw_write_results) + logger.info( + "Data sink %s finished. 
%d rows and %s data written.", + datasink.get_name(), + write_result.num_rows, + memory_string(write_result.size_bytes), ) - datasink.on_write_complete(raw_write_results) + datasink.on_write_complete(write_result) except Exception as e: datasink.on_write_failed(e) diff --git a/python/ray/data/datasource/__init__.py b/python/ray/data/datasource/__init__.py index 5f950ec99001..5b444fb8d782 100644 --- a/python/ray/data/datasource/__init__.py +++ b/python/ray/data/datasource/__init__.py @@ -8,6 +8,13 @@ from ray.data.datasource.csv_datasink import _CSVDatasink from ray.data.datasource.csv_datasource import CSVDatasource from ray.data.datasource.datasink import Datasink, DummyOutputDatasink +from ray.data._internal.datasource.sql_datasource import Connection +from ray.data.datasource.datasink import ( + Datasink, + DummyOutputDatasink, + WriteResult, + WriteReturnType, +) from ray.data.datasource.datasource import ( Datasource, RandomIntRowDatasource, @@ -113,4 +120,6 @@ "_WebDatasetDatasink", "WebDatasetDatasource", "_S3FileSystemWrapper", + "WriteResult", + "WriteReturnType", ] diff --git a/python/ray/data/datasource/bigquery_datasink.py b/python/ray/data/datasource/bigquery_datasink.py index be1341a29ae5..196ba322b4f8 100644 --- a/python/ray/data/datasource/bigquery_datasink.py +++ b/python/ray/data/datasource/bigquery_datasink.py @@ -20,7 +20,7 @@ RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11 -class _BigQueryDatasink(Datasink): +class BigQueryDatasink(Datasink[None]): def __init__( self, project_id: str, diff --git a/python/ray/data/datasource/datasink.py b/python/ray/data/datasource/datasink.py index 5354bbe41cc3..666264c8f6e8 100644 --- a/python/ray/data/datasource/datasink.py +++ b/python/ray/data/datasource/datasink.py @@ -1,6 +1,6 @@ import logging -from dataclasses import dataclass, fields -from typing import Iterable, List, Optional +from dataclasses import dataclass +from typing import Generic, Iterable, List, Optional, TypeVar import ray from ray.data._internal.execution.interfaces import TaskContext @@ -10,45 +10,25 @@ logger = logging.getLogger(__name__) -@dataclass -@DeveloperAPI -class WriteResult: - """Result of a write operation, containing stats/metrics - on the written data. - - Attributes: - total_num_rows: The total number of rows written. - total_size_bytes: The total size of the written data in bytes. - """ - - num_rows: int = 0 - size_bytes: int = 0 +WriteReturnType = TypeVar("WriteReturnType") +"""Generic type for the return value of `Datasink.write`.""" - @staticmethod - def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult": - """Aggregate a list of write results. - Args: - write_results: A list of write results. - - Returns: - A single write result that aggregates the input results. - """ - total_num_rows = 0 - total_size_bytes = 0 - - for write_result in write_results: - total_num_rows += write_result.num_rows - total_size_bytes += write_result.size_bytes +@dataclass +@DeveloperAPI +class WriteResult(Generic[WriteReturnType]): + """Aggregated result of the Datasink write operations.""" - return WriteResult( - num_rows=total_num_rows, - size_bytes=total_size_bytes, - ) + # Total number of written rows. + num_rows: int + # Total size in bytes of written data. + size_bytes: int + # All returned values of `Datasink.write`. + write_returns: List[WriteReturnType] @DeveloperAPI -class Datasink: +class Datasink(Generic[WriteReturnType]): """Interface for defining write-related logic. 
If you want to write data to something that isn't built-in, subclass this class @@ -67,7 +47,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> None: + ) -> WriteReturnType: """Write blocks. This is used by a single write task. Args: @@ -75,12 +55,13 @@ def write( ctx: ``TaskContext`` for the write task. Returns: - A user-defined output. Can be anything, and the returned value is passed to - :meth:`~Datasink.on_write_complete`. + Result of this write task. When the entire write operator finishes, + All returned values will be passed as `WriteResult.write_returns` + to `Datasink.on_write_complete`. """ raise NotImplementedError - def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: + def on_write_complete(self, write_result: WriteResult[WriteReturnType]): """Callback for when a write job completes. This can be used to "commit" a write output. This method must @@ -88,30 +69,10 @@ def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: method fails, then ``on_write_failed()`` is called. Args: - write_results: The objects returned by every :meth:`~Datasink.write` task. - """ - pass - write_result_blocks: The blocks resulting from executing + write_result: Aggregated result of the the Write operator, containing write results and stats. - Returns: - A ``WriteResult`` object containing the aggregated stats of all - the input write results. """ - write_results = [ - result["write_result"].iloc[0] for result in write_result_blocks - ] - aggregated_write_results = WriteResult.aggregate_write_results(write_results) - - aggregated_results_str = "" - for k in fields(aggregated_write_results.__class__): - v = getattr(aggregated_write_results, k.name) - aggregated_results_str += f"\t{k}: {v}\n" - - logger.info( - f"Write operation succeeded. Aggregated write results:\n" - f"{aggregated_results_str}" - ) - return aggregated_write_results + pass def on_write_failed(self, error: Exception) -> None: """Callback for when a write job fails. @@ -151,7 +112,7 @@ def num_rows_per_write(self) -> Optional[int]: @DeveloperAPI -class DummyOutputDatasink(Datasink): +class DummyOutputDatasink(Datasink[None]): """An example implementation of a writable datasource for testing. 
Examples: >>> import ray @@ -196,10 +157,8 @@ def write( tasks.append(self.data_sink.write.remote(b)) ray.get(tasks) - def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: + def on_write_complete(self, write_result: WriteResult[None]): self.num_ok += 1 - aggregated_results = super().on_write_complete(write_result_blocks) - return aggregated_results def on_write_failed(self, error: Exception) -> None: self.num_failed += 1 diff --git a/python/ray/data/datasource/file_datasink.py b/python/ray/data/datasource/file_datasink.py index ac30e0500b5c..6266e1fedbfa 100644 --- a/python/ray/data/datasource/file_datasink.py +++ b/python/ray/data/datasource/file_datasink.py @@ -1,5 +1,5 @@ import posixpath -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional from ray._private.utils import _add_creatable_buckets_param_if_s3_uri from ray.data._internal.dataset_logger import DatasetLogger @@ -27,7 +27,7 @@ WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32 -class _FileDatasink(Datasink): +class _FileDatasink(Datasink[None]): def __init__( self, path: str, @@ -122,13 +122,10 @@ def write( def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext): raise NotImplementedError - def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: - aggregated_results = super().on_write_complete(write_result_blocks) - + def on_write_complete(self, write_result: WriteResult[None]): # If no rows were written, we can delete the directory. - if self.has_created_dir and aggregated_results.num_rows == 0: + if self.has_created_dir and write_result.num_rows == 0: self.filesystem.delete_dir(self.path) - return aggregated_results @property def supports_distributed_writes(self) -> bool: @@ -181,13 +178,15 @@ def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext): ) write_path = posixpath.join(self.path, filename) - def write_row_to_path(): + def write_row_to_path(row, write_path): with self.open_output_stream(write_path) as file: self.write_row_to_file(row, file) logger.get_logger(log_to_stdout=False).debug(f"Writing {write_path} file.") call_with_retry( - write_row_to_path, + lambda row=row, write_path=write_path: write_row_to_path( + row, write_path + ), description=f"write '{write_path}'", match=DataContext.get_current().write_file_retry_on_errors, max_attempts=WRITE_FILE_MAX_ATTEMPTS, diff --git a/python/ray/data/datasource/mongo_datasink.py b/python/ray/data/datasource/mongo_datasink.py index 11909845db0b..c2b0cd2d3566 100644 --- a/python/ray/data/datasource/mongo_datasink.py +++ b/python/ray/data/datasource/mongo_datasink.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class _MongoDatasink(Datasink): +class MongoDatasink(Datasink[None]): def __init__(self, uri: str, database: str, collection: str) -> None: _check_import(self, module="pymongo", package="pymongo") _check_import(self, module="pymongoarrow", package="pymongoarrow") diff --git a/python/ray/data/datasource/sql_datasink.py b/python/ray/data/datasource/sql_datasink.py index 8807956f4e94..582cbfe0101d 100644 --- a/python/ray/data/datasource/sql_datasink.py +++ b/python/ray/data/datasource/sql_datasink.py @@ -6,7 +6,7 @@ from ray.data.datasource.sql_datasource import Connection, _connect -class _SQLDatasink(Datasink): +class SQLDatasink(Datasink[None]): _MAX_ROWS_PER_WRITE = 128 diff --git a/python/ray/data/tests/test_bigquery.py b/python/ray/data/tests/test_bigquery.py index 
7e571f7ec841..c56b2f823215 100644 --- a/python/ray/data/tests/test_bigquery.py +++ b/python/ray/data/tests/test_bigquery.py @@ -1,6 +1,7 @@ from typing import Iterator from unittest import mock +import pandas as pd import pyarrow as pa import pytest from google.api_core import exceptions, operation @@ -11,9 +12,9 @@ import ray from ray.data._internal.datasource.bigquery_datasink import BigQueryDatasink from ray.data._internal.datasource.bigquery_datasource import BigQueryDatasource +from ray.data._internal.execution.interfaces.task_context import TaskContext from ray.data._internal.planner.plan_write_op import generate_collect_write_stats_fn from ray.data.block import Block -from ray.data.datasource.datasink import WriteResult from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa from ray.tests.conftest import * # noqa @@ -23,7 +24,6 @@ _TEST_BQ_TABLE_ID = "mocktable" _TEST_BQ_DATASET = _TEST_BQ_DATASET_ID + "." + _TEST_BQ_TABLE_ID _TEST_BQ_TEMP_DESTINATION = _TEST_GCP_PROJECT_ID + ".tempdataset.temptable" -_TEST_DISPLAY_NAME = "display_name" @pytest.fixture(autouse=True) @@ -202,7 +202,7 @@ class TestWriteBigQuery: """Tests for BigQuery Write.""" def _extract_write_result(self, stats: Iterator[Block]): - return dict(next(stats).iloc[0])["write_result"] + return dict(next(stats).iloc[0]) def test_write(self, ray_get_mock): bq_datasink = _BigQueryDatasink( @@ -211,15 +211,24 @@ def test_write(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) + ctx = TaskContext(1) bq_datasink.write( blocks=[block], - ctx=None, + ctx=ctx, ) collect_stats_fn = generate_collect_write_stats_fn() - stats = collect_stats_fn([block], None) - write_result = self._extract_write_result(stats) - assert write_result == WriteResult(num_rows=4, size_bytes=32) + stats = collect_stats_fn([block], ctx) + pd.testing.assert_frame_equal( + next(stats), + pd.DataFrame( + { + "num_rows": [4], + "size_bytes": [32], + "write_return": [None], + } + ), + ) def test_write_dataset_exists(self, ray_get_mock): bq_datasink = _BigQueryDatasink( @@ -228,14 +237,23 @@ def test_write_dataset_exists(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) + ctx = TaskContext(1) bq_datasink.write( blocks=[block], - ctx=None, + ctx=ctx, ) collect_stats_fn = generate_collect_write_stats_fn() - stats = collect_stats_fn([block], None) - write_result = self._extract_write_result(stats) - assert write_result == WriteResult(num_rows=4, size_bytes=32) + stats = collect_stats_fn([block], ctx) + pd.testing.assert_frame_equal( + next(stats), + pd.DataFrame( + { + "num_rows": [4], + "size_bytes": [32], + "write_return": [None], + } + ), + ) if __name__ == "__main__": diff --git a/python/ray/data/tests/test_datasink.py b/python/ray/data/tests/test_datasink.py index 714f03c6dfe3..8b5eaff9f2c1 100644 --- a/python/ray/data/tests/test_datasink.py +++ b/python/ray/data/tests/test_datasink.py @@ -1,16 +1,113 @@ -from typing import Iterable +from dataclasses import dataclass +from typing import Iterable, List +import numpy import pytest import ray from ray.data._internal.execution.interfaces import TaskContext -from ray.data.block import Block +from ray.data.block import Block, BlockAccessor from ray.data.datasource import Datasink +from ray.data.datasource.datasink import DummyOutputDatasink, WriteResult + + +def test_write_datasink(ray_start_regular_shared): + output = DummyOutputDatasink() + ds = 
ray.data.range(10, override_num_blocks=2) + ds.write_datasink(output) + assert output.num_ok == 1 + assert output.num_failed == 0 + assert ray.get(output.data_sink.get_rows_written.remote()) == 10 + + output.enabled = False + ds = ray.data.range(10, override_num_blocks=2) + with pytest.raises(ValueError): + ds.write_datasink(output, ray_remote_args={"max_retries": 0}) + assert output.num_ok == 1 + assert output.num_failed == 1 + assert ray.get(output.data_sink.get_rows_written.remote()) == 10 + + +class NodeLoggerOutputDatasink(Datasink[None]): + """A writable datasource that logs node IDs of write tasks, for testing.""" + + def __init__(self): + @ray.remote + class DataSink: + def __init__(self): + self.rows_written = 0 + self.node_ids = set() + + def write(self, node_id: str, block: Block) -> str: + block = BlockAccessor.for_block(block) + self.rows_written += block.num_rows() + self.node_ids.add(node_id) + + def get_rows_written(self): + return self.rows_written + + def get_node_ids(self): + return self.node_ids + + self.data_sink = DataSink.remote() + self.num_ok = 0 + self.num_failed = 0 + + def write( + self, + blocks: Iterable[Block], + ctx: TaskContext, + ) -> None: + data_sink = self.data_sink + + def write(b): + node_id = ray.get_runtime_context().get_node_id() + return data_sink.write.remote(node_id, b) + + tasks = [] + for b in blocks: + tasks.append(write(b)) + ray.get(tasks) + + def on_write_complete(self, write_result: WriteResult[None]): + self.num_ok += 1 + + def on_write_failed(self, error: Exception) -> None: + self.num_failed += 1 + + +def test_write_datasink_ray_remote_args(ray_start_cluster): + ray.shutdown() + cluster = ray_start_cluster + cluster.add_node( + resources={"foo": 100}, + num_cpus=1, + ) + cluster.add_node(resources={"bar": 100}, num_cpus=1) + + ray.init(cluster.address) + + @ray.remote + def get_node_id(): + return ray.get_runtime_context().get_node_id() + + bar_node_id = ray.get(get_node_id.options(resources={"bar": 1}).remote()) + + output = NodeLoggerOutputDatasink() + ds = ray.data.range(100, override_num_blocks=10) + # Pin write tasks to node with "bar" resource. 
+ ds.write_datasink(output, ray_remote_args={"resources": {"bar": 1}}) + assert output.num_ok == 1 + assert output.num_failed == 0 + assert ray.get(output.data_sink.get_rows_written.remote()) == 100 + + node_ids = ray.get(output.data_sink.get_node_ids.remote()) + assert node_ids == {bar_node_id} @pytest.mark.parametrize("num_rows_per_write", [5, 10, 50]) def test_num_rows_per_write(tmp_path, ray_start_regular_shared, num_rows_per_write): - class MockDatasink(Datasink): + class MockDatasink(Datasink[None]): def __init__(self, num_rows_per_write): self._num_rows_per_write = num_rows_per_write @@ -26,6 +123,51 @@ def num_rows_per_write(self): ) +def test_write_result(ray_start_regular_shared): + """Test the write_result argument in `on_write_complete`.""" + + @dataclass + class CustomWriteResult: + + ids: List[int] + + class CustomDatasink(Datasink[CustomWriteResult]): + def __init__(self) -> None: + self.ids = [] + self.num_rows = 0 + self.size_bytes = 0 + + def write(self, blocks: Iterable[Block], ctx: TaskContext): + ids = [] + for b in blocks: + ids.extend(b["id"].to_pylist()) + return CustomWriteResult(ids=ids) + + def on_write_complete(self, write_result: WriteResult[CustomWriteResult]): + ids = [] + for result in write_result.write_returns: + ids.extend(result.ids) + self.ids = sorted(ids) + self.num_rows = write_result.num_rows + self.size_bytes = write_result.size_bytes + + num_items = 100 + size_bytes_per_row = 1000 + + def map_fn(row): + row["data"] = numpy.zeros(size_bytes_per_row, dtype=numpy.int8) + return row + + ds = ray.data.range(num_items).map(map_fn) + + datasink = CustomDatasink() + ds.write_datasink(datasink) + + assert datasink.ids == list(range(num_items)) + assert datasink.num_rows == num_items + assert datasink.size_bytes == pytest.approx(num_items * size_bytes_per_row, rel=0.1) + + if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/test_formats.py b/python/ray/data/tests/test_formats.py index 706b2e60bfc0..63ce2052eb06 100644 --- a/python/ray/data/tests/test_formats.py +++ b/python/ray/data/tests/test_formats.py @@ -1,6 +1,5 @@ import os import sys -from typing import Iterable, List import pandas as pd import pyarrow as pa @@ -12,10 +11,7 @@ import ray from ray._private.test_utils import wait_for_condition -from ray.data._internal.execution.interfaces import TaskContext -from ray.data.block import Block, BlockAccessor -from ray.data.datasource import Datasink, DummyOutputDatasink -from ray.data.datasource.datasink import WriteResult +from ray.data.block import BlockAccessor from ray.data.datasource.file_meta_provider import _handle_read_os_error from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa @@ -146,23 +142,6 @@ def test_read_example_data(ray_start_regular_shared, tmp_path): ] -def test_write_datasink(ray_start_regular_shared): - output = DummyOutputDatasink() - ds = ray.data.range(10, override_num_blocks=2) - ds.write_datasink(output) - assert output.num_ok == 1 - assert output.num_failed == 0 - assert ray.get(output.data_sink.get_rows_written.remote()) == 10 - - output.enabled = False - ds = ray.data.range(10, override_num_blocks=2) - with pytest.raises(ValueError): - ds.write_datasink(output, ray_remote_args={"max_retries": 0}) - assert output.num_ok == 1 - assert output.num_failed == 1 - assert ray.get(output.data_sink.get_rows_written.remote()) == 10 - - def test_from_tf(ray_start_regular_shared): import tensorflow as tf import tensorflow_datasets as tfds @@ -207,85 +186,6 @@ def 
__iter__(self): assert actual_data == expected_data -class NodeLoggerOutputDatasink(Datasink): - """A writable datasource that logs node IDs of write tasks, for testing.""" - - def __init__(self): - @ray.remote - class DataSink: - def __init__(self): - self.rows_written = 0 - self.node_ids = set() - - def write(self, node_id: str, block: Block) -> str: - block = BlockAccessor.for_block(block) - self.rows_written += block.num_rows() - self.node_ids.add(node_id) - - def get_rows_written(self): - return self.rows_written - - def get_node_ids(self): - return self.node_ids - - self.data_sink = DataSink.remote() - self.num_ok = 0 - self.num_failed = 0 - - def write( - self, - blocks: Iterable[Block], - ctx: TaskContext, - ) -> None: - data_sink = self.data_sink - - def write(b): - node_id = ray.get_runtime_context().get_node_id() - return data_sink.write.remote(node_id, b) - - tasks = [] - for b in blocks: - tasks.append(write(b)) - ray.get(tasks) - - def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: - self.num_ok += 1 - aggregated_results = super().on_write_complete(write_result_blocks) - return aggregated_results - - def on_write_failed(self, error: Exception) -> None: - self.num_failed += 1 - - -def test_write_datasink_ray_remote_args(ray_start_cluster): - ray.shutdown() - cluster = ray_start_cluster - cluster.add_node( - resources={"foo": 100}, - num_cpus=1, - ) - cluster.add_node(resources={"bar": 100}, num_cpus=1) - - ray.init(cluster.address) - - @ray.remote - def get_node_id(): - return ray.get_runtime_context().get_node_id() - - bar_node_id = ray.get(get_node_id.options(resources={"bar": 1}).remote()) - - output = NodeLoggerOutputDatasink() - ds = ray.data.range(100, override_num_blocks=10) - # Pin write tasks to node with "bar" resource. - ds.write_datasink(output, ray_remote_args={"resources": {"bar": 1}}) - assert output.num_ok == 1 - assert output.num_failed == 0 - assert ray.get(output.data_sink.get_rows_written.remote()) == 100 - - node_ids = ray.get(output.data_sink.get_node_ids.remote()) - assert node_ids == {bar_node_id} - - def test_read_s3_file_error(shutdown_only, s3_path): dummy_path = s3_path + "_dummy" error_message = "Please check that file exists and has properly configured access." 
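
The refactor in the hunks above moves every built-in sink onto the generic `Datasink[WriteReturnType]` interface: `write()` may now return a per-task value, and `on_write_complete()` receives a single aggregated `WriteResult` (carrying `num_rows`, `size_bytes`, and the per-task values in `write_returns`) instead of a list of result blocks. The following is a minimal sketch of a custom sink written against that contract, mirroring the `test_write_result` test above; the `RowCount` type and `RowCountingDatasink` class are illustrative names, not part of the patch.

    from dataclasses import dataclass
    from typing import Iterable

    import ray
    from ray.data._internal.execution.interfaces import TaskContext
    from ray.data.block import Block, BlockAccessor
    from ray.data.datasource import Datasink
    from ray.data.datasource.datasink import WriteResult


    @dataclass
    class RowCount:
        # Hypothetical per-task return type; any serializable value works.
        rows: int


    class RowCountingDatasink(Datasink[RowCount]):
        def __init__(self) -> None:
            self.total_rows = 0

        def write(self, blocks: Iterable[Block], ctx: TaskContext) -> RowCount:
            # Runs inside each write task; the return value is collected per task.
            rows = sum(BlockAccessor.for_block(b).num_rows() for b in blocks)
            return RowCount(rows=rows)

        def on_write_complete(self, write_result: WriteResult[RowCount]) -> None:
            # Runs once on the driver with the aggregated result of all tasks.
            self.total_rows = sum(r.rows for r in write_result.write_returns)


    sink = RowCountingDatasink()
    ray.data.range(100).write_datasink(sink)
    assert sink.total_rows == 100
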
From 5a039155ed92983541ae2eb277fc031ad54fb561 Mon Sep 17 00:00:00 2001 From: votrou Date: Mon, 31 Mar 2025 15:09:02 -0700 Subject: [PATCH 5/8] fix imports --- python/ray/data/dataset.py | 2 +- python/ray/data/datasource/bigquery_datasink.py | 2 +- python/ray/data/datasource/mongo_datasink.py | 2 +- python/ray/data/datasource/sql_datasink.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 6d1c3c083e92..1475c1653ce5 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -31,7 +31,7 @@ from ray.data._internal.block_list import BlockList from ray.data._internal.compute import ComputeStrategy from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder -from ray.data._internal.datasource.iceberg_datasink import IcebergDatasink +from ray.data.datasource.iceberg_datasink import IcebergDatasink from ray.data._internal.equalize import _equalize from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.execution.legacy_compat import _block_list_to_bundles diff --git a/python/ray/data/datasource/bigquery_datasink.py b/python/ray/data/datasource/bigquery_datasink.py index 196ba322b4f8..000962be29be 100644 --- a/python/ray/data/datasource/bigquery_datasink.py +++ b/python/ray/data/datasource/bigquery_datasink.py @@ -20,7 +20,7 @@ RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11 -class BigQueryDatasink(Datasink[None]): +class _BigQueryDatasink(Datasink[None]): def __init__( self, project_id: str, diff --git a/python/ray/data/datasource/mongo_datasink.py b/python/ray/data/datasource/mongo_datasink.py index c2b0cd2d3566..5dca4baf189d 100644 --- a/python/ray/data/datasource/mongo_datasink.py +++ b/python/ray/data/datasource/mongo_datasink.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class MongoDatasink(Datasink[None]): +class _MongoDatasink(Datasink[None]): def __init__(self, uri: str, database: str, collection: str) -> None: _check_import(self, module="pymongo", package="pymongo") _check_import(self, module="pymongoarrow", package="pymongoarrow") diff --git a/python/ray/data/datasource/sql_datasink.py b/python/ray/data/datasource/sql_datasink.py index 582cbfe0101d..f80a8d8127b8 100644 --- a/python/ray/data/datasource/sql_datasink.py +++ b/python/ray/data/datasource/sql_datasink.py @@ -6,7 +6,7 @@ from ray.data.datasource.sql_datasource import Connection, _connect -class SQLDatasink(Datasink[None]): +class _SQLDatasink(Datasink[None]): _MAX_ROWS_PER_WRITE = 128 From b192cdf0d44785a08c32ee213149555edd10d301 Mon Sep 17 00:00:00 2001 From: votrou Date: Wed, 23 Apr 2025 11:21:42 -0700 Subject: [PATCH 6/8] fix bugs --- python/ray/data/datasource/__init__.py | 2 +- python/ray/data/tests/test_bigquery.py | 4 +- python/ray/data/tests/test_iceberg.py | 262 ------------------------- 3 files changed, 3 insertions(+), 265 deletions(-) delete mode 100644 python/ray/data/tests/test_iceberg.py diff --git a/python/ray/data/datasource/__init__.py b/python/ray/data/datasource/__init__.py index 5b444fb8d782..5176f0ea4c7c 100644 --- a/python/ray/data/datasource/__init__.py +++ b/python/ray/data/datasource/__init__.py @@ -8,7 +8,7 @@ from ray.data.datasource.csv_datasink import _CSVDatasink from ray.data.datasource.csv_datasource import CSVDatasource from ray.data.datasource.datasink import Datasink, DummyOutputDatasink -from ray.data._internal.datasource.sql_datasource import Connection +from ray.data.datasource.sql_datasource import Connection from ray.data.datasource.datasink 
import ( Datasink, DummyOutputDatasink, diff --git a/python/ray/data/tests/test_bigquery.py b/python/ray/data/tests/test_bigquery.py index c56b2f823215..69ba4cd380da 100644 --- a/python/ray/data/tests/test_bigquery.py +++ b/python/ray/data/tests/test_bigquery.py @@ -10,8 +10,8 @@ from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream import ray -from ray.data._internal.datasource.bigquery_datasink import BigQueryDatasink -from ray.data._internal.datasource.bigquery_datasource import BigQueryDatasource +from ray.data.datasource.bigquery_datasink import BigQueryDatasink +from ray.data.datasource.bigquery_datasource import BigQueryDatasource from ray.data._internal.execution.interfaces.task_context import TaskContext from ray.data._internal.planner.plan_write_op import generate_collect_write_stats_fn from ray.data.block import Block diff --git a/python/ray/data/tests/test_iceberg.py b/python/ray/data/tests/test_iceberg.py deleted file mode 100644 index f89b50193827..000000000000 --- a/python/ray/data/tests/test_iceberg.py +++ /dev/null @@ -1,262 +0,0 @@ -import os -import random - -import pyarrow as pa -import pytest -from pkg_resources import parse_version -from pyiceberg import catalog as pyi_catalog -from pyiceberg import expressions as pyi_expr -from pyiceberg import schema as pyi_schema -from pyiceberg import types as pyi_types -from pyiceberg.partitioning import PartitionField, PartitionSpec -from pyiceberg.transforms import IdentityTransform - -import ray -from ray._private.utils import _get_pyarrow_version -from ray.data import read_iceberg -from ray.data._internal.datasource.iceberg_datasource import IcebergDatasource - -_CATALOG_NAME = "ray_catalog" -_DB_NAME = "ray_db" -_TABLE_NAME = "ray_test" -_WAREHOUSE_PATH = "/tmp/warehouse" - -_CATALOG_KWARGS = { - "name": _CATALOG_NAME, - "type": "sql", - "uri": f"sqlite:///{_WAREHOUSE_PATH}/ray_pyiceberg_test_catalog.db", - "warehouse": f"file://{_WAREHOUSE_PATH}", -} - -_SCHEMA = pa.schema( - [ - pa.field("col_a", pa.int32()), - pa.field("col_b", pa.string()), - pa.field("col_c", pa.int16()), - ] -) - - -def create_pa_table(): - return pa.Table.from_pydict( - mapping={ - "col_a": list(range(120)), - "col_b": random.choices(["a", "b", "c", "d"], k=120), - "col_c": random.choices(list(range(10)), k=120), - }, - schema=_SCHEMA, - ) - - -@pytest.fixture(autouse=True, scope="function") -def pyiceberg_table(): - from pyiceberg.catalog.sql import SqlCatalog - - if not os.path.exists(_WAREHOUSE_PATH): - os.makedirs(_WAREHOUSE_PATH) - dummy_catalog = SqlCatalog( - _CATALOG_NAME, - **{ - "uri": f"sqlite:///{_WAREHOUSE_PATH}/ray_pyiceberg_test_catalog.db", - "warehouse": f"file://{_WAREHOUSE_PATH}", - }, - ) - - pya_table = create_pa_table() - - if (_DB_NAME,) not in dummy_catalog.list_namespaces(): - dummy_catalog.create_namespace(_DB_NAME) - if (_DB_NAME, _TABLE_NAME) in dummy_catalog.list_tables(_DB_NAME): - dummy_catalog.drop_table(f"{_DB_NAME}.{_TABLE_NAME}") - - # Create the table, and add data to it - table = dummy_catalog.create_table( - f"{_DB_NAME}.{_TABLE_NAME}", - schema=pyi_schema.Schema( - pyi_types.NestedField( - field_id=1, - name="col_a", - field_type=pyi_types.IntegerType(), - required=False, - ), - pyi_types.NestedField( - field_id=2, - name="col_b", - field_type=pyi_types.StringType(), - required=False, - ), - pyi_types.NestedField( - field_id=3, - name="col_c", - field_type=pyi_types.IntegerType(), - required=False, - ), - ), - partition_spec=PartitionSpec( - PartitionField( - source_id=3, field_id=3, 
transform=IdentityTransform(), name="col_c" - ) - ), - ) - table.append(pya_table) - - # Delete some data so there are delete file(s) - table.delete(delete_filter=pyi_expr.GreaterThanOrEqual("col_a", 101)) - - -@pytest.mark.skipif( - parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), - reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", -) -def test_get_catalog(): - - iceberg_ds = IcebergDatasource( - table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", - catalog_kwargs=_CATALOG_KWARGS.copy(), - ) - catalog = iceberg_ds._get_catalog() - assert catalog.name == _CATALOG_NAME - - -@pytest.mark.skipif( - parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), - reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", -) -def test_plan_files(): - - iceberg_ds = IcebergDatasource( - table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", - catalog_kwargs=_CATALOG_KWARGS.copy(), - ) - plan_files = iceberg_ds.plan_files - assert len(plan_files) == 10 - - -@pytest.mark.skipif( - parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), - reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", -) -def test_chunk_plan_files(): - - iceberg_ds = IcebergDatasource( - table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", - catalog_kwargs=_CATALOG_KWARGS.copy(), - ) - - chunks = iceberg_ds._distribute_tasks_into_equal_chunks(iceberg_ds.plan_files, 5) - assert (len(c) == 2 for c in chunks), chunks - - chunks = iceberg_ds._distribute_tasks_into_equal_chunks(iceberg_ds.plan_files, 20) - assert ( - sum(len(c) == 1 for c in chunks) == 10 - and sum(len(c) == 0 for c in chunks) == 10 - ), chunks - - -@pytest.mark.skipif( - parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), - reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", -) -def test_get_read_tasks(): - - iceberg_ds = IcebergDatasource( - table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", - catalog_kwargs=_CATALOG_KWARGS.copy(), - ) - read_tasks = iceberg_ds.get_read_tasks(5) - assert len(read_tasks) == 5 - assert all(len(rt.metadata.input_files) == 2 for rt in read_tasks) - - -@pytest.mark.skipif( - parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), - reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", -) -def test_filtered_read(): - - from pyiceberg import expressions as pyi_expr - - iceberg_ds = IcebergDatasource( - table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", - row_filter=pyi_expr.In("col_c", {1, 2, 3, 4}), - selected_fields=("col_b",), - catalog_kwargs=_CATALOG_KWARGS.copy(), - ) - read_tasks = iceberg_ds.get_read_tasks(5) - # Should be capped to 4, as there will be only 4 files - assert len(read_tasks) == 4, read_tasks - assert all(len(rt.metadata.input_files) == 1 for rt in read_tasks) - - -@pytest.mark.skipif( - parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), - reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", -) -def test_read_basic(): - - row_filter = pyi_expr.In("col_c", {1, 2, 3, 4, 5, 6, 7, 8}) - - ray_ds = read_iceberg( - table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", - row_filter=row_filter, - selected_fields=("col_a", "col_b"), - catalog_kwargs=_CATALOG_KWARGS.copy(), - ) - table: pa.Table = pa.concat_tables((ray.get(ref) for ref in ray_ds.to_arrow_refs())) - - # string -> large_string because pyiceberg by default chooses large_string - expected_schema = pa.schema( - [pa.field("col_a", pa.int32()), pa.field("col_b", pa.large_string())] - ) - assert table.schema.equals(expected_schema) - - # Read the raw table from PyIceberg - sql_catalog = pyi_catalog.load_catalog(**_CATALOG_KWARGS) - orig_table_p 
= ( - sql_catalog.load_table(f"{_DB_NAME}.{_TABLE_NAME}") - .scan(row_filter=row_filter, selected_fields=("col_a", "col_b")) - .to_pandas() - .sort_values(["col_a", "col_b"]) - .reset_index(drop=True) - ) - - # Actually compare the tables now - table_p = ray_ds.to_pandas().sort_values(["col_a", "col_b"]).reset_index(drop=True) - assert orig_table_p.equals(table_p) - - -@pytest.mark.skipif( - parse_version(_get_pyarrow_version()) < parse_version("14.0.0"), - reason="PyIceberg 0.7.0 fails on pyarrow <= 14.0.0", -) -def test_write_basic(): - - sql_catalog = pyi_catalog.load_catalog(**_CATALOG_KWARGS) - table = sql_catalog.load_table(f"{_DB_NAME}.{_TABLE_NAME}") - table.delete() - - ds = ray.data.from_arrow(create_pa_table()) - ds.write_iceberg( - table_identifier=f"{_DB_NAME}.{_TABLE_NAME}", - catalog_kwargs=_CATALOG_KWARGS.copy(), - ) - - # Read the raw table from PyIceberg after writing - table = sql_catalog.load_table(f"{_DB_NAME}.{_TABLE_NAME}") - orig_table_p = ( - table.scan() - .to_pandas() - .sort_values(["col_a", "col_b", "col_c"]) - .reset_index(drop=True) - ) - - table_p = ( - ds.to_pandas().sort_values(["col_a", "col_b", "col_c"]).reset_index(drop=True) - ) - assert orig_table_p.equals(table_p) - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-v", __file__])) From 48a8324292fdf3f6c5a22452532855156bdb0aa0 Mon Sep 17 00:00:00 2001 From: votrou Date: Wed, 23 Apr 2025 12:19:21 -0700 Subject: [PATCH 7/8] remove unused file --- python/ray/data/dataset.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 1475c1653ce5..73591229b92d 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -35,9 +35,6 @@ from ray.data._internal.equalize import _equalize from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.execution.legacy_compat import _block_list_to_bundles -from ray.data._internal.execution.interfaces.ref_bundle import ( - _ref_bundles_iterator_to_block_refs_list, -) from ray.data._internal.execution.util import memory_string from ray.data._internal.iterator.iterator_impl import DataIteratorImpl from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator From af7853e9b821d6aa05d51363b860a505795705b4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 12 Dec 2024 07:14:09 +0800 Subject: [PATCH 8/8] [Data] support passing kwargs to map tasks. (#49208) This PR enables passing kwargs to map tasks, which will be accessible via `TaskContext.kwargs`. This is a prerequisite to fixing https://github.com/ray-project/ray/issues/49207. And optimization rules can use this API to pass additional arguments to the map tasks. 
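
Concretely, the diffs below let an optimization rule register a zero-argument callback on a `MapOperator` via `add_map_task_kwargs_fn`; the returned dict is merged into the kwargs forwarded with every map task and exposed on `TaskContext.kwargs` inside the task. A rough sketch of that flow follows; `map_op` is assumed to be an existing `MapOperator`, and the `output_partition_hint` key and the transform function are illustrative only, not part of this PR.

    from typing import Any, Dict

    from ray.data._internal.execution.interfaces.task_context import TaskContext

    # Driver side: a rule attaches extra kwargs to an existing MapOperator
    # (here called `map_op`; how it was built is out of scope for this sketch).
    def _partition_hint() -> Dict[str, Any]:
        return {"output_partition_hint": 8}  # hypothetical key and value

    map_op.add_map_task_kwargs_fn(_partition_hint)

    # Task side: a block transform can read the value back from the TaskContext,
    # since _map_task copies the forwarded kwargs into ctx.kwargs before running.
    def transform_blocks(blocks, ctx: TaskContext):
        hint = ctx.kwargs.get("output_partition_hint")
        for block in blocks:
            # ... use `hint` to decide how to emit `block` ...
            yield block

Keeping the plumbing on the operator, rather than in user code, is also what lets the fusion rule below simply concatenate the callback lists of the two fused operators.
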
--------- Signed-off-by: Hao Chen --- .../execution/interfaces/task_context.py | 5 ++++- .../operators/actor_pool_map_operator.py | 9 +++++++- .../execution/operators/map_operator.py | 21 +++++++++++++++++++ .../operators/task_pool_map_operator.py | 3 ++- .../logical/rules/operator_fusion.py | 6 ++++++ python/ray/data/tests/test_operators.py | 1 + 6 files changed, 42 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/execution/interfaces/task_context.py b/python/ray/data/_internal/execution/interfaces/task_context.py index 99431125a0ad..094faf2440e0 100644 --- a/python/ray/data/_internal/execution/interfaces/task_context.py +++ b/python/ray/data/_internal/execution/interfaces/task_context.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional from ray.data._internal.progress_bar import ProgressBar @@ -39,3 +39,6 @@ class TaskContext: # The target maximum number of bytes to include in the task's output block. target_max_block_size: Optional[int] = None + + # Additional keyword arguments passed to the task. + kwargs: Dict[str, Any] = field(default_factory=dict) diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index acf811f32a73..d8c88a242549 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -212,7 +212,12 @@ def _dispatch_tasks(self): num_returns="streaming", name=self.name, **self._ray_actor_task_remote_args, - ).remote(DataContext.get_current(), ctx, *input_blocks) + ).remote( + DataContext.get_current(), + ctx, + *input_blocks, + **self.get_map_task_kwargs(), + ) def _task_done_callback(actor_to_return): # Return the actor that was running the task to the pool. @@ -401,12 +406,14 @@ def submit( data_context: DataContext, ctx: TaskContext, *blocks: Block, + **kwargs: Dict[str, Any], ) -> Iterator[Union[Block, List[BlockMetadata]]]: yield from _map_task( self._map_transformer, data_context, ctx, *blocks, + **kwargs, ) def __repr__(self): diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 6f9992faf530..29079d2c9553 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -85,6 +85,25 @@ def __init__( # too-large blocks, which may reduce parallelism for # the subsequent operator. self._additional_split_factor = None + # Callback functions that generate additional task kwargs + # for the map task. + self._map_task_kwargs_fns: List[Callable[[], Dict[str, Any]]] = [] + + def add_map_task_kwargs_fn(self, map_task_kwargs_fn: Callable[[], Dict[str, Any]]): + """Add a callback function that generates additional kwargs for the map tasks. + In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`. + """ + self._map_task_kwargs_fns.append(map_task_kwargs_fn) + + def get_map_task_kwargs(self) -> Dict[str, Any]: + """Get the kwargs for the map task. + Subclasses should pass the returned kwargs to the map tasks. + In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`. 
+ """ + kwargs = {} + for fn in self._map_task_kwargs_fns: + kwargs.update(fn()) + return kwargs def get_additional_split_factor(self) -> int: if self._additional_split_factor is None: @@ -402,6 +421,7 @@ def _map_task( data_context: DataContext, ctx: TaskContext, *blocks: Block, + **kwargs: Dict[str, Any], ) -> Iterator[Union[Block, List[BlockMetadata]]]: """Remote function for a single operator task. @@ -415,6 +435,7 @@ def _map_task( as the last generator return. """ DataContext._set_current(data_context) + ctx.kwargs.update(kwargs) stats = BlockExecStats.builder() map_transformer.set_target_max_block_size(ctx.target_max_block_size) for b_out in map_transformer.apply_transform(iter(blocks), ctx): diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 2d84dd1bc111..a0c5dc3de733 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -76,7 +76,8 @@ def _add_bundled_input(self, bundle: RefBundle): self._map_transformer_ref, data_context, ctx, - *input_blocks, + *bundle.block_refs, + **self.get_map_task_kwargs(), ) self._submit_data_task(gen, bundle) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index d555db84d6c3..79d664371aae 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -1,3 +1,4 @@ +import itertools from typing import List, Optional, Tuple # TODO(Clark): Remove compute dependency once we delete the legacy compute. @@ -311,6 +312,11 @@ def _get_fused_map_operator( min_rows_per_bundle=min_rows_per_bundled_input, ray_remote_args=ray_remote_args, ) + op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators) + for map_task_kwargs_fn in itertools.chain( + up_op._map_task_kwargs_fns, down_op._map_task_kwargs_fns + ): + op.add_map_task_kwargs_fn(map_task_kwargs_fn) # Build a map logical operator to be used as a reference for further fusion. # TODO(Scott): This is hacky, remove this once we push fusion to be purely based diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index c962a50a6b56..974d6aa1e2e5 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -16,6 +16,7 @@ PhysicalOperator, RefBundle, ) +from ray.data._internal.execution.interfaces.task_context import TaskContext from ray.data._internal.execution.operators.actor_pool_map_operator import ( ActorPoolMapOperator, )
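
One detail worth noting from `get_map_task_kwargs` and the fusion rule above: callbacks run in registration order and are merged with `dict.update`, and fusion chains the upstream operator's callbacks before the downstream operator's, so on a key conflict the downstream value wins. A tiny self-contained illustration of the same merge semantics, reimplemented on plain callables rather than a real `MapOperator`:

    # Mimics the kwargs.update(fn()) loop in MapOperator.get_map_task_kwargs.
    fns = [lambda: {"a": 1, "b": 1}, lambda: {"b": 2}]
    kwargs = {}
    for fn in fns:
        kwargs.update(fn())
    assert kwargs == {"a": 1, "b": 2}  # later callbacks win on conflicting keys
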