diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index eb74a614074..5d2b85ab018 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -839,6 +839,17 @@ def _deserialize_memoryview(header, frames): return out +@dask_serialize.register(PickleBuffer) +def _serialize_picklebuffer(obj): + return _serialize_memoryview(obj.raw()) + + +@dask_deserialize.register(PickleBuffer) +def _deserialize_picklebuffer(header, frames): + out = _deserialize_memoryview(header, frames) + return PickleBuffer(out) + + ######################### # Descend into __dict__ # ######################### diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py index df35588b464..1d3524761dc 100644 --- a/distributed/shuffle/__init__.py +++ b/distributed/shuffle/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -from distributed.shuffle._arrow import check_minimal_arrow_version from distributed.shuffle._merge import HashJoinP2PLayer, hash_join_p2p from distributed.shuffle._rechunk import rechunk_p2p from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin @@ -8,7 +7,6 @@ from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin __all__ = [ - "check_minimal_arrow_version", "hash_join_p2p", "HashJoinP2PLayer", "P2PShuffleLayer", diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py deleted file mode 100644 index 71317021acd..00000000000 --- a/distributed/shuffle/_arrow.py +++ /dev/null @@ -1,201 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable -from pathlib import Path -from typing import TYPE_CHECKING - -from packaging.version import parse - -from dask.utils import parse_bytes - -if TYPE_CHECKING: - import pandas as pd - import pyarrow as pa - - -_INPUT_PARTITION_ID_COLUMN = "__input_partition_id__" - - -def check_dtype_support(meta_input: pd.DataFrame) -> None: - import pandas as pd - - for name in meta_input: - column = meta_input[name] - # FIXME: PyArrow does not support complex numbers: https://issues.apache.org/jira/browse/ARROW-638 - if pd.api.types.is_complex_dtype(column): - raise TypeError( - f"p2p does not support data of type '{column.dtype}' found in column '{name}'." - ) - # FIXME: PyArrow does not support sparse data: https://issues.apache.org/jira/browse/ARROW-8679 - if isinstance(column.dtype, pd.SparseDtype): - raise TypeError("p2p does not support sparse data found in column '{name}'") - - -def check_minimal_arrow_version() -> None: - """Verify that the the correct version of pyarrow is installed to support - the P2P extension. - - Raises a ModuleNotFoundError if pyarrow is not installed or an - ImportError if the installed version is not recent enough. 
- """ - minversion = "7.0.0" - try: - import pyarrow as pa - except ModuleNotFoundError: - raise ModuleNotFoundError(f"P2P shuffling requires pyarrow>={minversion}") - if parse(pa.__version__) < parse(minversion): - raise ImportError( - f"P2P shuffling requires pyarrow>={minversion} but only found {pa.__version__}" - ) - - -def concat_tables(tables: Iterable[pa.Table]) -> pa.Table: - import pyarrow as pa - - if parse(pa.__version__) >= parse("14.0.0"): - return pa.concat_tables(tables, promote_options="permissive") - try: - return pa.concat_tables(tables, promote=True) - except pa.ArrowNotImplementedError as e: - if parse(pa.__version__) >= parse("12.0.0"): - raise e - raise - - -def convert_shards( - shards: list[pa.Table], meta: pd.DataFrame, partition_column: str, drop_column: bool -) -> pd.DataFrame: - import pandas as pd - from pandas.core.dtypes.cast import find_common_type # type: ignore[attr-defined] - - from dask.dataframe.dispatch import from_pyarrow_table_dispatch - - table = concat_tables(shards) - table = table.sort_by(_INPUT_PARTITION_ID_COLUMN) - table = table.drop([_INPUT_PARTITION_ID_COLUMN]) - - if drop_column: - meta = meta.drop(columns=partition_column) - df = from_pyarrow_table_dispatch(meta, table, self_destruct=True) - reconciled_dtypes = {} - for column, dtype in meta.dtypes.items(): - actual = df[column].dtype - if actual == dtype: - continue - # Use the specific string dtype from meta (e.g., string[pyarrow]) - if isinstance(actual, pd.StringDtype) and isinstance(dtype, pd.StringDtype): - reconciled_dtypes[column] = dtype - continue - # meta might not be aware of the actual categories so the two dtype objects are not equal - # Also, the categories_dtype does not properly roundtrip through Arrow - if isinstance(actual, pd.CategoricalDtype) and isinstance( - dtype, pd.CategoricalDtype - ): - continue - reconciled_dtypes[column] = find_common_type([actual, dtype]) - - from dask.dataframe._compat import PANDAS_GE_300 - - kwargs = {} if PANDAS_GE_300 else {"copy": False} - return df.astype(reconciled_dtypes, **kwargs) - - -def buffers_to_table(data: list[tuple[int, bytes]]) -> pa.Table: - import numpy as np - import pyarrow as pa - - """Convert a list of arrow buffers and a schema to an Arrow Table""" - - def _create_input_partition_id_array( - table: pa.Table, input_partition_id: int - ) -> pa.ChunkedArray: - arrays = ( - np.full( - (batch.num_rows,), - input_partition_id, - dtype=np.uint32(), - ) - for batch in table.to_batches() - ) - return pa.chunked_array(arrays) - - tables = ( - (input_partition_id, deserialize_table(buffer)) - for input_partition_id, buffer in data - ) - tables = ( - table.append_column( - _INPUT_PARTITION_ID_COLUMN, - _create_input_partition_id_array(table, input_partition_id), - ) - for input_partition_id, table in tables - ) - - return concat_tables(tables) - - -def serialize_table(table: pa.Table) -> bytes: - import pyarrow as pa - - stream = pa.BufferOutputStream() - with pa.ipc.new_stream(stream, table.schema) as writer: - writer.write_table(table) - return stream.getvalue().to_pybytes() - - -def deserialize_table(buffer: bytes) -> pa.Table: - import pyarrow as pa - - with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader: - return reader.read_all() - - -def read_from_disk(path: Path) -> tuple[list[pa.Table], int]: - import pyarrow as pa - - batch_size = parse_bytes("1 MiB") - batch = [] - shards = [] - - with pa.OSFile(str(path), mode="rb") as f: - size = f.seek(0, whence=2) - f.seek(0) - prev = 0 - offset = f.tell() - while offset < 
size: - sr = pa.RecordBatchStreamReader(f) - shard = sr.read_all() - offset = f.tell() - batch.append(shard) - - if offset - prev >= batch_size: - table = concat_tables(batch) - shards.append(_copy_table(table)) - batch = [] - prev = offset - if batch: - table = concat_tables(batch) - shards.append(_copy_table(table)) - return shards, size - - -def concat_arrays(arrays: Iterable[pa.Array]) -> pa.Array: - import pyarrow as pa - - try: - return pa.concat_arrays(arrays) - except pa.ArrowNotImplementedError as e: - if parse(pa.__version__) >= parse("12.0.0"): - raise - if e.args[0].startswith("concatenation of extension"): - raise RuntimeError( - "P2P shuffling requires pyarrow>=12.0.0 to support extension types." - ) from e - raise - - -def _copy_table(table: pa.Table) -> pa.Table: - import pyarrow as pa - - arrs = [concat_arrays(column.chunks) for column in table.columns] - return pa.table(data=arrs, schema=table.schema) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index f4f69266cc7..19524834ed2 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -19,7 +19,6 @@ from dataclasses import dataclass, field from enum import Enum from functools import partial -from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast from tornado.ioloop import IOLoop @@ -116,11 +115,10 @@ def __init__( if disk: self._disk_buffer = DiskShardsBuffer( directory=directory, - read=self.read, memory_limiter=memory_limiter_disk, ) else: - self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize) + self._disk_buffer = MemoryShardsBuffer() with self._capture_metrics("background-comms"): self._comm_buffer = CommShardsBuffer( @@ -216,7 +214,7 @@ async def send( # and unpickle it on the other side. 
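For reviewers, a minimal standard-library sketch (not part of the patch) of the pickle behaviour the `protocol=5` change just below relies on: the shards now contain `PickleBuffer` objects from `pickle_dataframe_shard`, which can only be pickled with protocol 5, and a `buffer_callback` keeps large buffers out-of-band instead of copying them into the stream. The payload here is made up for illustration.

```python
import pickle

buf = pickle.PickleBuffer(b"x" * 1024)

# Protocol < 5 cannot represent PickleBuffer at all.
try:
    pickle.dumps(buf, protocol=4)
except pickle.PicklingError:
    pass

# In-band (what send() does below): the buffer is copied into the pickle stream
# and comes back as plain bytes/bytearray on the receiving side.
inband = pickle.dumps(buf, protocol=5)
assert bytes(pickle.loads(inband)) == b"x" * 1024

# Out-of-band (what pickle_bytelist does): the stream stays tiny and the raw
# buffers are handed to buffer_callback, enabling zero-copy framing.
frames: list[pickle.PickleBuffer] = []
header = pickle.dumps(buf, protocol=5, buffer_callback=frames.append)
assert bytes(pickle.loads(header, buffers=[f.raw() for f in frames])) == b"x" * 1024
```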
# Performance tests informing the size threshold: # https://github.com/dask/distributed/pull/8318 - shards_or_bytes: list | bytes = pickle.dumps(shards) + shards_or_bytes: list | bytes = pickle.dumps(shards, protocol=5) else: shards_or_bytes = shards @@ -298,7 +296,7 @@ def fail(self, exception: Exception) -> None: if not self.closed: self._exception = exception - def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing + def _read_from_disk(self, id: NDIndex) -> Any: self.raise_if_closed() return self._disk_buffer.read("_".join(str(i) for i in id)) @@ -335,6 +333,7 @@ def add_partition( if self.transferred: raise RuntimeError(f"Cannot add more partitions to {self}") # Log metrics both in the "execute" and in the "p2p" contexts + context_meter.digest_metric("p2p-partitions", 1, "count") with self._capture_metrics("foreground"): with ( context_meter.meter("p2p-shard-partition-noncpu"), @@ -372,14 +371,6 @@ def _get_output_partition( ) -> _T_partition_type: """Get an output partition to the shuffle run""" - @abc.abstractmethod - def read(self, path: Path) -> tuple[Any, int]: - """Read shards from disk""" - - @abc.abstractmethod - def deserialize(self, buffer: Any) -> Any: - """Deserialize shards""" - def get_worker_plugin() -> ShuffleWorkerPlugin: from distributed import get_worker diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 327c06f3dfe..0cb9d3ff686 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -1,21 +1,21 @@ from __future__ import annotations import contextlib +import mmap import pathlib import shutil import threading -from collections.abc import Callable, Generator, Iterable -from contextlib import contextmanager +from collections.abc import Generator, Iterator +from contextlib import contextmanager, nullcontext +from pathlib import Path from typing import Any -from toolz import concat - from distributed.metrics import context_meter, thread_time from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._exceptions import DataUnavailable from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._pickle import pickle_bytelist -from distributed.utils import Deadline, empty_context, log_errors, nbytes +from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream +from distributed.utils import Deadline, log_errors, nbytes class ReadWriteLock: @@ -123,10 +123,14 @@ class DiskShardsBuffer(ShardsBuffer): implementation of this scheme. """ + directory: pathlib.Path + _closed: bool + _use_raw_buffers: bool | None + _directory_lock: ReadWriteLock + def __init__( self, directory: str | pathlib.Path, - read: Callable[[pathlib.Path], tuple[Any, int]], memory_limiter: ResourceLimiter, ): super().__init__( @@ -137,7 +141,7 @@ def __init__( self.directory = pathlib.Path(directory) self.directory.mkdir(exist_ok=True) self._closed = False - self._read = read + self._use_raw_buffers = None self._directory_lock = ReadWriteLock() @log_errors @@ -154,34 +158,46 @@ async def _process(self, id: str, shards: list[Any]) -> None: future then we should consider simplifying this considerably and dropping the write into communicate above. """ - frames: Iterable[bytes | bytearray | memoryview] - if isinstance(shards[0], bytes): - # Manually serialized dataframes - frames = shards - serialize_meter_ctx: Any = empty_context - else: - # Unserialized numpy arrays - # Note: no calls to pickle_bytelist will happen until we actually start - # writing to disk below. 
- frames = concat(pickle_bytelist(shard) for shard in shards) - serialize_meter_ctx = context_meter.meter("serialize", func=thread_time) + assert shards + if self._use_raw_buffers is None: + self._use_raw_buffers = isinstance(shards[0], list) and isinstance( + shards[0][0], (bytes, bytearray, memoryview) + ) + serialize_ctx = ( + nullcontext() + if self._use_raw_buffers + else context_meter.meter("serialize", func=thread_time) + ) + + nbytes_acc = 0 + + def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]: + nonlocal nbytes_acc + for shard in shards: + if self._use_raw_buffers: + # list[bytes | bytearray | memoryview] for dataframe shuffle + # Shard was pre-serialized before being sent over the network. + nbytes_acc += sum(map(nbytes, shard)) + yield from shard + else: + # tuple[NDIndex, ndarray] for array rechunk + frames = [s.raw() for s in pickle_bytelist(shard)] + nbytes_acc += sum(frame.nbytes for frame in frames) + yield from frames with ( self._directory_lock.read(), context_meter.meter("disk-write"), - serialize_meter_ctx, + serialize_ctx, ): - # Consider boosting total_size a bit here to account for duplication - # We only need shared (i.e., read) access to the directory to write - # to a file inside of it. if self._closed: raise RuntimeError("Already closed") with open(self.directory / str(id), mode="ab") as f: - f.writelines(frames) + f.writelines(pickle_and_tally()) context_meter.digest_metric("disk-write", 1, "count") - context_meter.digest_metric("disk-write", sum(map(nbytes, frames)), "bytes") + context_meter.digest_metric("disk-write", nbytes_acc, "bytes") def read(self, id: str) -> Any: """Read a complete file back into memory""" @@ -210,6 +226,27 @@ def read(self, id: str) -> Any: else: raise DataUnavailable(id) + def _read(self, path: Path) -> tuple[Any, int]: + """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle + all arrays. This is a fast sequence of short reads interleaved with seeks. + Do not read in memory the actual data; the arrays' buffers will point to the + memory-mapped area. + + The file descriptor will be automatically closed by the kernel when all the + returned arrays are dereferenced, which will happen after the call to + concatenate3. + """ + with path.open(mode="r+b") as fh: + buffer = memoryview(mmap.mmap(fh.fileno(), 0)) + # The file descriptor has *not* been closed! 
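A hedged, standard-library sketch of the memory-mapped read pattern used in `_read` here: map the spill file once, unpickle from the mapping, and let the kernel keep the mapping alive until the returned objects drop their views. The file name and payload are made up.

```python
import mmap
import pickle
from pathlib import Path

path = Path("spill-example.bin")  # hypothetical spill file
path.write_bytes(pickle.dumps(bytearray(b"shard data"), protocol=5))

with path.open(mode="r+b") as fh:  # r+b because mmap defaults to a writable mapping
    buffer = memoryview(mmap.mmap(fh.fileno(), 0))
# The Python file object is closed here, but the mapping (and hence the data)
# stays valid for as long as `buffer` or anything unpickled from it is referenced.
shard = pickle.loads(buffer)
assert bytes(shard) == b"shard data"

del shard, buffer  # dropping the last references releases the mapping
path.unlink()
```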
+ + assert self._use_raw_buffers is not None + if self._use_raw_buffers: + return buffer, buffer.nbytes + else: + shards = list(unpickle_bytestream(buffer)) + return shards, buffer.nbytes + async def close(self) -> None: await super().close() with self._directory_lock.write(): diff --git a/distributed/shuffle/_memory.py b/distributed/shuffle/_memory.py index 106052cc756..f26ef4c9029 100644 --- a/distributed/shuffle/_memory.py +++ b/distributed/shuffle/_memory.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict, deque -from typing import Any, Callable +from typing import Any from dask.sizeof import sizeof @@ -12,18 +12,14 @@ class MemoryShardsBuffer(ShardsBuffer): - _deserialize: Callable[[Any], Any] _shards: defaultdict[str, deque[Any]] - def __init__(self, deserialize: Callable[[Any], Any]) -> None: + def __init__(self) -> None: super().__init__(memory_limiter=ResourceLimiter(None)) - self._deserialize = deserialize self._shards = defaultdict(deque) @log_errors async def _process(self, id: str, shards: list[Any]) -> None: - # TODO: This can be greatly simplified, there's no need for - # background threads at all. self._shards[id].extend(shards) def read(self, id: str) -> Any: @@ -41,6 +37,7 @@ def read(self, id: str) -> Any: data = [] while shards: shard = shards.pop() - data.append(self._deserialize(shard)) + # TODO unpickle dataframes + data.append(shard) return data diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index eb248b4e229..de09f581737 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -9,7 +9,6 @@ from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer -from distributed.shuffle._arrow import check_minimal_arrow_version from distributed.shuffle._core import ShuffleId, barrier_key, get_worker_plugin from distributed.shuffle._shuffle import shuffle_barrier, shuffle_transfer @@ -238,7 +237,6 @@ def __init__( parts_out: Sequence | None = None, annotations: dict | None = None, ) -> None: - check_minimal_arrow_version() self.name = name self.name_input_left = name_input_left self.meta_input_left = meta_input_left diff --git a/distributed/shuffle/_pickle.py b/distributed/shuffle/_pickle.py index 5e0a76425e2..49250a0c0f4 100644 --- a/distributed/shuffle/_pickle.py +++ b/distributed/shuffle/_pickle.py @@ -2,12 +2,17 @@ import pickle from collections.abc import Iterator -from typing import Any +from typing import TYPE_CHECKING, Any + +from toolz import first from distributed.protocol.utils import pack_frames_prelude, unpack_frames +if TYPE_CHECKING: + import pandas as pd + -def pickle_bytelist(obj: object) -> list[bytes | memoryview]: +def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]: """Variant of :func:`serialize_bytelist`, that doesn't support compression, locally defined classes, or any of its other fancy features but runs 10x faster for numpy arrays @@ -18,11 +23,10 @@ def pickle_bytelist(obj: object) -> list[bytes | memoryview]: unpickle_bytestream """ frames: list = [] - pik = pickle.dumps( - obj, protocol=5, buffer_callback=lambda pb: frames.append(pb.raw()) - ) - frames.insert(0, pik) - frames.insert(0, pack_frames_prelude(frames)) + pik = pickle.dumps(obj, protocol=5, buffer_callback=frames.append) + frames.insert(0, pickle.PickleBuffer(pik)) + if prelude: + frames.insert(0, pickle.PickleBuffer(pack_frames_prelude(frames))) return frames @@ -40,3 +44,68 @@ def unpickle_bytestream(b: bytes | bytearray | memoryview) -> 
Iterator[Any]: if remainder.nbytes == 0: break b = remainder + + +def pickle_dataframe_shard( + input_part_id: int, + shard: pd.DataFrame, +) -> list[pickle.PickleBuffer]: + """Optimized pickler for pandas Dataframes. DIscard all unnecessary metadata + (like the columns header). + + Parameters: + obj: pandas + """ + return pickle_bytelist( + (input_part_id, shard.index, *shard._mgr.blocks), prelude=False + ) + + +def unpickle_and_concat_dataframe_shards( + b: bytes | bytearray | memoryview, meta: pd.DataFrame +) -> pd.DataFrame: + """Optimized unpickler for pandas Dataframes. + + Parameters + ---------- + b: + raw buffer, containing the concatenation of the outputs of + :func:`pickle_dataframe_shard`, in arbitrary order + meta: + DataFrame header + + Returns + ------- + Reconstructed output shard, sorted by input partition ID + + **Roundtrip example** + + >>> import random + >>> import pandas as pd + >>> from toolz import concat + + >>> df = pd.DataFrame(...) # Input partition + >>> meta = df.iloc[:0].copy() + >>> shards = df.iloc[0:10], df.iloc[10:20], ... + >>> frames = [pickle_dataframe_shard(i, shard) for i, shard in enumerate(shards)] + >>> random.shuffle(frames) # Simulate the frames arriving in arbitrary order + >>> blob = bytearray(b"".join(concat(frames))) # Simulate disk roundtrip + >>> df2 = unpickle_and_concat_dataframe_shards(blob, meta) + """ + import pandas as pd + from pandas.core.internals import BlockManager + + parts = list(unpickle_bytestream(b)) + # [(input_part_id, index, *blocks), ...] + parts.sort(key=first) + shards = [] + for _, idx, *blocks in parts: + axes = [meta.columns, idx] + df = pd.DataFrame._from_mgr( # type: ignore[attr-defined] + BlockManager(blocks, axes, verify_integrity=False), axes + ) + shards.append(df) + + # Actually load memory-mapped buffers into memory and close the file + # descriptors + return pd.concat(shards, copy=True) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 6bb6e3f0691..a624aa81c55 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -96,14 +96,12 @@ from __future__ import annotations -import mmap import os from collections import defaultdict from collections.abc import Callable, Generator, Hashable, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from itertools import product -from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple import toolz @@ -126,7 +124,6 @@ handle_unpack_errors, ) from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._pickle import unpickle_bytestream from distributed.shuffle._shuffle import barrier_key, shuffle_barrier from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof @@ -705,26 +702,6 @@ def _get_output_partition( # This is where we'll spend most time. return convert_chunk(data) - def deserialize(self, buffer: Any) -> Any: - return buffer - - def read(self, path: Path) -> tuple[list[list[tuple[NDIndex, np.ndarray]]], int]: - """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle - all arrays. This is a fast sequence of short reads interleaved with seeks. - Do not read in memory the actual data; the arrays' buffers will point to the - memory-mapped area. - - The file descriptor will be automatically closed by the kernel when all the - returned arrays are dereferenced, which will happen after the call to - concatenate3. 
- """ - with path.open(mode="r+b") as fh: - buffer = memoryview(mmap.mmap(fh.fileno(), 0)) - - # The file descriptor has *not* been closed! - shards = list(unpickle_bytestream(buffer)) - return shards, buffer.nbytes - def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 21e3e388369..c9db53dbefb 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -14,10 +14,9 @@ ) from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from pathlib import Path +from pickle import PickleBuffer from typing import TYPE_CHECKING, Any -import toolz from tornado.ioloop import IOLoop import dask @@ -29,15 +28,7 @@ from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule from distributed.metrics import context_meter -from distributed.shuffle._arrow import ( - buffers_to_table, - check_dtype_support, - check_minimal_arrow_version, - convert_shards, - deserialize_table, - read_from_disk, - serialize_table, -) +from distributed.protocol.utils import pack_frames_prelude from distributed.shuffle._core import ( NDIndex, ShuffleId, @@ -50,13 +41,16 @@ ) from distributed.shuffle._exceptions import DataUnavailable from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._pickle import ( + pickle_dataframe_shard, + unpickle_and_concat_dataframe_shards, +) from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin -from distributed.sizeof import sizeof +from distributed.utils import nbytes logger = logging.getLogger("distributed.shuffle") if TYPE_CHECKING: import pandas as pd - import pyarrow as pa # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias @@ -123,7 +117,6 @@ def rearrange_by_column_p2p( raise TypeError( f"Expected meta {column=} to be an integer column, is {meta[column].dtype}." 
) - check_dtype_support(meta) npartitions = npartitions or df.npartitions token = tokenize(df, column, npartitions) @@ -180,7 +173,6 @@ def __init__( annotations: dict | None = None, drop_column: bool = False, ): - check_minimal_arrow_version() self.name = name self.column = column self.npartitions = npartitions @@ -310,81 +302,45 @@ def _construct_graph(self) -> _T_LowLevelGraph: def split_by_worker( df: pd.DataFrame, column: str, - meta: pd.DataFrame, - worker_for: pd.Series, -) -> dict[Any, pa.Table]: - """ - Split data into many arrow batches, partitioned by destination worker - """ - import numpy as np - - from dask.dataframe.dispatch import to_pyarrow_table_dispatch - - # (cudf support) Avoid pd.Series - constructor = df._constructor_sliced - assert isinstance(constructor, type) - worker_for = constructor(worker_for) - df = df.merge( - right=worker_for.cat.codes.rename("_worker"), - left_on=column, - right_index=True, - how="inner", - ) - nrows = len(df) - if not nrows: - return {} - # assert len(df) == nrows # Not true if some outputs aren't wanted - # FIXME: If we do not preserve the index something is corrupting the - # bytestream such that it cannot be deserialized anymore - t = to_pyarrow_table_dispatch(df, preserve_index=True) - t = t.sort_by("_worker") - codes = np.asarray(t["_worker"]) - t = t.drop(["_worker"]) - del df - - splits = np.where(codes[1:] != codes[:-1])[0] + 1 - splits = np.concatenate([[0], splits]) - - shards = [ - t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) - ] - shards.append(t.slice(offset=splits[-1], length=None)) - - unique_codes = codes[splits] - out = { - # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43 - worker_for.cat.categories[code]: shard - for code, shard in zip(unique_codes, shards) - } - assert sum(map(len, out.values())) == nrows - return out - - -def split_by_partition( - t: pa.Table, column: str, drop_column: bool -) -> dict[int, pa.Table]: - """ - Split data into many arrow batches, partitioned by final partition + drop_column: bool, + worker_for: dict[int, str], + input_part_id: int, +) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]: + """Split data into many horizontal slices, partitioned by destination worker, + and serialize them once. + + Returns + ------- + {worker addr: (input_part_id, [(output_part_id, buffers), ...]), ...} + + where buffers is the serialized output (pickle bytes, buffer, buffer, ...) of + (input_part_id, index, *blocks) + + **Notes** + + - The pickle header, which is a bytes object, is wrapped in PickleBuffer so + that it's not unnecessarily deep-copied when it's deserialized by the network + stack. + - We are not delegating serialization to the network stack because (1) it's quicker + with plain pickle and (2) we want to avoid deserializing everything on receive() + only to re-serialize it again immediately afterwards when writing it to disk. + So we serialize it once now and deserialize it once after reading back from disk. 
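A simplified, pandas-only sketch of the grouping performed here, to make the return shape concrete; worker addresses and column names are made up, and the real function additionally pickles each slice with `pickle_dataframe_shard`.

```python
import pandas as pd

df = pd.DataFrame({"x": [10, 20, 30, 40], "_partitions": [0, 1, 0, 1]})
worker_for = {0: "tcp://worker-a:1234", 1: "tcp://worker-b:1234"}  # hypothetical
input_part_id = 7

out: dict[str, list[tuple[int, pd.DataFrame]]] = {}
for output_part_id, part in df.groupby("_partitions", observed=False):
    out.setdefault(worker_for[output_part_id], []).append((output_part_id, part))

# {worker addr: [(output_part_id, slice), ...], ...} -- the real function wraps each
# value as (input_part_id, [(output_part_id, pickled frames), ...])
assert set(out) == {"tcp://worker-a:1234", "tcp://worker-b:1234"}
assert len(out["tcp://worker-a:1234"]) == 1
```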
+ + See Also + -------- + distributed.protocol.serialize._deserialize_bytes + distributed.protocol.serialize._deserialize_picklebuffer """ - import numpy as np + out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list) - partitions = t.select([column]).to_pandas()[column].unique() - partitions.sort() - t = t.sort_by(column) + for output_part_id, part in df.groupby(column, observed=False): + assert isinstance(output_part_id, int) + if drop_column: + del part[column] + frames = pickle_dataframe_shard(input_part_id, part) + out[worker_for[output_part_id]].append((output_part_id, frames)) - partition = np.asarray(t[column]) - if drop_column: - t = t.drop([column]) - splits = np.where(partition[1:] != partition[:-1])[0] + 1 - splits = np.concatenate([[0], splits]) - - shards = [ - t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) - ] - shards.append(t.slice(offset=splits[-1], length=None)) - assert len(t) == sum(map(len, shards)) - assert len(partitions) == len(shards) - return dict(zip(partitions, shards)) + return {k: (input_part_id, v) for k, v in out.items()} class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): @@ -434,7 +390,7 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): column: str meta: pd.DataFrame partitions_of: dict[str, list[int]] - worker_for: pd.Series + worker_for: dict[int, str] drop_column: bool def __init__( @@ -457,8 +413,6 @@ def __init__( drop_column: bool, loop: IOLoop, ): - import pandas as pd - super().__init__( id=id, run_id=run_id, @@ -480,55 +434,79 @@ def __init__( for part, addr in worker_for.items(): partitions_of[addr].append(part) self.partitions_of = dict(partitions_of) - self.worker_for = pd.Series(worker_for, name="_workers").astype("category") + self.worker_for = worker_for self.drop_column = drop_column - async def _receive(self, data: list[tuple[int, bytes]]) -> None: + async def _receive( + # See split_by_worker to understand annotation of data. 
+ # PickleBuffer objects may have been converted to bytearray by the + # pickle roundtrip that is done by _core.py when buffers are too small + self, + data: list[ + tuple[int, list[tuple[int, list[PickleBuffer | bytes | bytearray]]]] + ], + ) -> None: self.raise_if_closed() - filtered = [] - for d in data: - if d[0] not in self.received: - filtered.append(d) - self.received.add(d[0]) - self.total_recvd += sizeof(d) - del data - if not filtered: + to_write: defaultdict[ + NDIndex, list[bytes | bytearray | memoryview] + ] = defaultdict(list) + + for input_part_id, parts in data: + if input_part_id not in self.received: + self.received.add(input_part_id) + for output_part_id, frames in parts: + frames_raw = [ + frame.raw() if isinstance(frame, PickleBuffer) else frame + for frame in frames + ] + self.total_recvd += sum(map(nbytes, frames_raw)) + to_write[output_part_id,] += [ + pack_frames_prelude(frames_raw), + *frames_raw, + ] + + if not to_write: return try: - groups = await self.offload(self._repartition_buffers, filtered) - del filtered - await self._write_to_disk(groups) + await self._write_to_disk(to_write) except Exception as e: self._exception = e raise - def _repartition_buffers( - self, data: list[tuple[int, bytes]] - ) -> dict[NDIndex, bytes]: - table = buffers_to_table(data) - groups = split_by_partition(table, self.column, self.drop_column) - assert len(table) == sum(map(len, groups.values())) - del data - return {(k,): serialize_table(v) for k, v in groups.items()} - def _shard_partition( self, data: pd.DataFrame, partition_id: int, - **kwargs: Any, - ) -> dict[str, tuple[int, bytes]]: + # See split_by_worker to understand annotation + ) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]: out = split_by_worker( - data, - self.column, - self.meta, - self.worker_for, + df=data, + column=self.column, + drop_column=self.drop_column, + worker_for=self.worker_for, + input_part_id=partition_id, ) - out = {k: (partition_id, serialize_table(t)) for k, t in out.items()} - nbytes = sum(len(b) for _, b in out.values()) - context_meter.digest_metric("p2p-shards", nbytes, "bytes") - context_meter.digest_metric("p2p-shards", len(out), "count") + # Log metrics + # Note: more metrics for this function call are logged by _core.add_partitiion() + overhead_nbytes = 0 + buffers_nbytes = 0 + shards_count = 0 + buffers_count = 0 + for _, shards in out.values(): + shards_count += len(shards) + for _, frames in shards: + # frames = [pickle bytes, buffer, buffer, ...] 
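To make the frame layout described in the comment above concrete, here is a hedged roundtrip through the helpers defined in `distributed/shuffle/_pickle.py`; the shard objects are made up. Note that on the dataframe path the length prelude is prepended in `_receive` via `pack_frames_prelude`, while the rechunk path relies on `pickle_bytelist`'s default `prelude=True`, as shown here.

```python
import numpy as np

from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream

# Two shards framed back to back, the same layout _process() appends to one spill
# file: [prelude, pickle header, buffer, buffer, ...] repeated once per shard.
shards = [(0, np.arange(4)), (1, np.arange(4, 8))]
frames = [f for shard in shards for f in pickle_bytelist(shard)]  # prelude=True by default
blob = bytearray(b"".join(bytes(f) for f in frames))  # simulate the disk roundtrip

restored = list(unpickle_bytestream(blob))
assert restored[0][0] == 0 and (restored[0][1] == np.arange(4)).all()
assert restored[1][0] == 1 and (restored[1][1] == np.arange(4, 8)).all()
```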
+ buffers_count += len(frames) - 2 + overhead_nbytes += frames[0].raw().nbytes + buffers_nbytes += sum(frame.raw().nbytes for frame in frames[1:]) + + context_meter.digest_metric("p2p-shards-overhead", overhead_nbytes, "bytes") + context_meter.digest_metric("p2p-shards-buffers", buffers_nbytes, "bytes") + context_meter.digest_metric("p2p-shards-buffers", buffers_count, "count") + context_meter.digest_metric("p2p-shards", shards_count, "count") + # End log metrics return out @@ -538,24 +516,20 @@ def _get_output_partition( key: Key, **kwargs: Any, ) -> pd.DataFrame: + meta = self.meta.copy() + if self.drop_column: + meta = self.meta.drop(columns=self.column) + try: - data = self._read_from_disk((partition_id,)) - return convert_shards(data, self.meta, self.column, self.drop_column) + buffer = self._read_from_disk((partition_id,)) except DataUnavailable: - result = self.meta.copy() - if self.drop_column: - result = self.meta.drop(columns=self.column) - return result + return meta + + return unpickle_and_concat_dataframe_shards(buffer, meta) def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] - def read(self, path: Path) -> tuple[pa.Table, int]: - return read_from_disk(path) - - def deserialize(self, buffer: Any) -> Any: - return deserialize_table(buffer) - @dataclass(frozen=True) class DataFrameShuffleSpec(ShuffleSpec[int]): diff --git a/distributed/shuffle/tests/test_core.py b/distributed/shuffle/tests/test_core.py index deb9d2a0bbb..3d310d492bc 100644 --- a/distributed/shuffle/tests/test_core.py +++ b/distributed/shuffle/tests/test_core.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pickle import PickleBuffer + import pytest from distributed.shuffle._core import _mean_shard_size @@ -12,7 +14,17 @@ def test_mean_shard_size(): # Don't fully iterate over large collections assert _mean_shard_size([b"12" * n for n in range(1000)]) == 9 # Support any Buffer object - assert _mean_shard_size([b"12", bytearray(b"1234"), memoryview(b"123456")]) == 4 + assert ( + _mean_shard_size( + [ + b"12", + bytearray(b"1234"), + memoryview(b"123456"), + PickleBuffer(b"12345678"), + ] + ) + == 5 + ) # Recursion into lists or tuples; ignore int assert _mean_shard_size([(1, 2, [3, b"123456"])]) == 6 # Don't blindly call sizeof() on unexpected objects diff --git a/distributed/shuffle/tests/test_disk_buffer.py b/distributed/shuffle/tests/test_disk_buffer.py index f15ccf1ab18..fc467558736 100644 --- a/distributed/shuffle/tests/test_disk_buffer.py +++ b/distributed/shuffle/tests/test_disk_buffer.py @@ -2,7 +2,6 @@ import asyncio import os -from pathlib import Path from typing import Any import pytest @@ -13,20 +12,13 @@ from distributed.utils_test import gen_test -def read_bytes(path: Path) -> tuple[bytes, int]: - with path.open("rb") as f: - data = f.read() - size = f.tell() - return data, size - - @gen_test() async def test_basic(tmp_path): async with DiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) + directory=tmp_path, memory_limiter=ResourceLimiter(None) ) as mf: - await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) - await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) + await mf.write({"x": "foo", "y": "bar"}) + await mf.write({"x": "baz", "y": "lol"}) await mf.flush() @@ -36,17 +28,17 @@ async def test_basic(tmp_path): with pytest.raises(DataUnavailable): mf.read("z") - assert x == b"0" * 2000 - assert y == b"1" * 1000 + assert x == ["foo", "baz"] + assert y == ["bar", "lol"] assert not os.path.exists(tmp_path) @gen_test() async def 
test_read_before_flush(tmp_path): - payload = {"1": b"foo"} + payload = {"1": "foo"} async with DiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) + directory=tmp_path, memory_limiter=ResourceLimiter(None) ) as mf: with pytest.raises(RuntimeError): mf.read(1) @@ -57,7 +49,7 @@ async def test_read_before_flush(tmp_path): mf.read(1) await mf.flush() - assert mf.read("1") == b"foo" + assert mf.read("1") == ["foo"] with pytest.raises(DataUnavailable): mf.read(2) @@ -66,9 +58,9 @@ async def test_read_before_flush(tmp_path): @gen_test() async def test_many(tmp_path, count): async with DiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) + directory=tmp_path, memory_limiter=ResourceLimiter(None) ) as mf: - d = {i: str(i).encode() * 100 for i in range(count)} + d = {i: str(i) * 100 for i in range(count)} for _ in range(10): await mf.write(d) @@ -77,7 +69,7 @@ async def test_many(tmp_path, count): for i in d: out = mf.read(i) - assert out == str(i).encode() * 100 * 10 + assert out == [str(i) * 100] * 10 assert not os.path.exists(tmp_path) @@ -93,7 +85,7 @@ async def _process(self, *args: Any, **kwargs: Any) -> None: @gen_test() async def test_exceptions(tmp_path): async with BrokenDiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) + directory=tmp_path, memory_limiter=ResourceLimiter(None) ) as mf: await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) @@ -124,7 +116,7 @@ async def test_high_pressure_flush_with_exception(tmp_path): payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)} async with EventuallyBrokenDiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) + directory=tmp_path, memory_limiter=ResourceLimiter(None) ) as mf: tasks = [] for _ in range(10): diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py index 7fa4d5b2ce3..88c25b430c1 100644 --- a/distributed/shuffle/tests/test_graph.py +++ b/distributed/shuffle/tests/test_graph.py @@ -6,7 +6,6 @@ pd = pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") -pytest.importorskip("pyarrow") import dask from dask.blockwise import Blockwise diff --git a/distributed/shuffle/tests/test_memory_buffer.py b/distributed/shuffle/tests/test_memory_buffer.py index 453e6bef219..4b84ec74e9d 100644 --- a/distributed/shuffle/tests/test_memory_buffer.py +++ b/distributed/shuffle/tests/test_memory_buffer.py @@ -13,7 +13,7 @@ def deserialize_bytes(buffer: bytes) -> bytes: @gen_test() async def test_basic(): - async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: + async with MemoryShardsBuffer() as mf: await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) @@ -32,7 +32,7 @@ async def test_basic(): @gen_test() async def test_read_before_flush(): payload = {"1": b"foo"} - async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: + async with MemoryShardsBuffer() as mf: with pytest.raises(RuntimeError): mf.read("1") @@ -50,7 +50,7 @@ async def test_read_before_flush(): @pytest.mark.parametrize("count", [2, 100, 1000]) @gen_test() async def test_many(count): - async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: + async with MemoryShardsBuffer() as mf: d = {str(i): str(i).encode() * 100 for i in range(count)} for _ in range(10): diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index 
27786e963c5..1dee5623aaf 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib from typing import Any from unittest import mock @@ -23,11 +22,6 @@ from distributed import get_client -try: - import pyarrow as pa -except ImportError: - pa = None - pytestmark = pytest.mark.ci1 @@ -47,27 +41,6 @@ async def list_eq(a, b): dd._compat.assert_numpy_array_equal(av, bv) -@pytest.mark.skipif(dd._dask_expr_enabled(), reason="pyarrow>=7.0.0 already required") -@gen_cluster(client=True) -async def test_minimal_version(c, s, a, b): - no_pyarrow_ctx = ( - mock.patch.dict("sys.modules", {"pyarrow": None}) - if pa is not None - else contextlib.nullcontext() - ) - with no_pyarrow_ctx: - A = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6], "y": [1, 1, 2, 2, 3, 4]}) - a = dd.repartition(A, [0, 4, 5]) - - B = pd.DataFrame({"y": [1, 3, 4, 4, 5, 6], "z": [6, 5, 4, 3, 2, 1]}) - b = dd.repartition(B, [0, 2, 5]) - - with pytest.raises( - ModuleNotFoundError, match="requires pyarrow" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): - await c.compute(dd.merge(a, b, left_on="x", right_on="z")) - - @pytest.mark.parametrize("how", ["inner", "left", "right", "outer"]) @gen_cluster(client=True) async def test_basic_merge(c, s, a, b, how): diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index c4e70e63d2f..ee76843988a 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import itertools import logging import os @@ -15,7 +14,6 @@ from unittest import mock import pytest -from packaging.version import parse as parse_version from tornado.ioloop import IOLoop from dask.utils import key_split @@ -43,18 +41,11 @@ ) from distributed.core import ConnectionPool from distributed.scheduler import TaskState as SchedulerTaskState -from distributed.shuffle._arrow import ( - buffers_to_table, - convert_shards, - read_from_disk, - serialize_table, -) from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import ( DataFrameShuffleRun, _get_worker_for_range_sharding, - split_by_partition, split_by_worker, ) from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin, _ShuffleRunManager @@ -73,13 +64,8 @@ try: import pyarrow as pa - - PYARROW_GE_12 = parse_version(pa.__version__).release >= (12,) - PYARROW_GE_14 = parse_version(pa.__version__).release >= (14,) except ImportError: pa = None - PYARROW_GE_12 = False - PYARROW_GE_14 = False @pytest.fixture(params=[0, 0.3, 1], ids=["none", "some", "all"]) @@ -124,43 +110,12 @@ async def check_scheduler_cleanup( assert not plugin.heartbeats -@pytest.mark.skipif(dd._dask_expr_enabled(), reason="pyarrow>=7.0.0 already required") -@gen_cluster(client=True) -async def test_minimal_version(c, s, a, b): - no_pyarrow_ctx = ( - mock.patch.dict("sys.modules", {"pyarrow": None}) - if pa is not None - else contextlib.nullcontext() - ) - with no_pyarrow_ctx: - df = dask.datasets.timeseries( - start="2000-01-01", - end="2000-01-10", - dtypes={"x": float, "y": float}, - freq="10 s", - ) - with pytest.raises( - ModuleNotFoundError, match="requires pyarrow" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): - await c.compute(df.shuffle("x")) - - @pytest.mark.gpu 
-@pytest.mark.filterwarnings( - "ignore:Ignoring the following arguments to `from_pyarrow_table_dispatch`." -) @gen_cluster(client=True) async def test_basic_cudf_support(c, s, a, b): cudf = pytest.importorskip("cudf") pytest.importorskip("dask_cudf") - try: - from dask.dataframe.dispatch import to_pyarrow_table_dispatch - - to_pyarrow_table_dispatch(cudf.DataFrame()) - except TypeError: - pytest.skip(reason="Newer version of dask_cudf is required.") - df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1019,8 +974,11 @@ async def test_heartbeat(c, s, a, b): await check_scheduler_cleanup(s) -@pytest.mark.skipif("not pa", reason="Requires PyArrow") -@pytest.mark.filterwarnings("ignore:DatetimeTZBlock") # pandas >=2.2 vs. pyarrow <15 +class Stub: + def __init__(self, value: int) -> None: + self.value = value + + @pytest.mark.parametrize("drop_column", [True, False]) def test_processing_chain(tmp_path, drop_column): """ @@ -1030,10 +988,6 @@ def test_processing_chain(tmp_path, drop_column): Here we verify its accuracy in a single threaded situation. """ - class Stub: - def __init__(self, value: int) -> None: - self.value = value - counter = count() workers = ["a", "b", "c"] npartitions = 5 @@ -1061,11 +1015,9 @@ def __init__(self, value: int) -> None: [np.timedelta64(1, "D") + i for i in range(100)], dtype="timedelta64[ns]", ), - # FIXME PyArrow does not support complex numbers: - # https://issues.apache.org/jira/browse/ARROW-638 - # f"col{next(counter)}": pd.array(range(100), dtype="csingle"), - # f"col{next(counter)}": pd.array(range(100), dtype="cdouble"), - # f"col{next(counter)}": pd.array(range(100), dtype="clongdouble"), + f"col{next(counter)}": pd.array(range(100), dtype="csingle"), + f"col{next(counter)}": pd.array(range(100), dtype="cdouble"), + f"col{next(counter)}": pd.array(range(100), dtype="clongdouble"), # Nullable dtypes f"col{next(counter)}": pd.array([True, False] * 50, dtype="boolean"), f"col{next(counter)}": pd.array(range(100), dtype="Int8"), @@ -1083,31 +1035,18 @@ def __init__(self, value: int) -> None: ), f"col{next(counter)}": pd.array(["x", "y"] * 50, dtype="category"), f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"), - # FIXME: PyArrow does not support sparse data: - # https://issues.apache.org/jira/browse/ARROW-8679 - # f"col{next(counter)}": pd.array( - # [np.nan, np.nan, 1.0, np.nan, np.nan] * 20, - # dtype="Sparse[float64]", - # ), - # custom objects - # FIXME: Serializing custom objects is not supported in P2P shuffling - # f"col{next(counter)}": pd.array( - # [Stub(i) for i in range(100)], dtype="object" - # ), + f"col{next(counter)}": pd.array( + [np.nan, np.nan, 1.0, np.nan, np.nan] * 20, + dtype="Sparse[float64]", + ), + # custom objects (no cloudpickle) + f"col{next(counter)}": pd.array([Stub(i) for i in range(100)], dtype="object"), + # Extension types + f"col{next(counter)}": pd.period_range("2022-01-01", periods=100, freq="D"), + f"col{next(counter)}": pd.interval_range(start=0, end=100, freq=1), } - if PYARROW_GE_12: - columns.update( - { - # Extension types - f"col{next(counter)}": pd.period_range( - "2022-01-01", periods=100, freq="D" - ), - f"col{next(counter)}": pd.interval_range(start=0, end=100, freq=1), - } - ) - - if PANDAS_GE_150: + if PANDAS_GE_150 and pa: columns.update( { # PyArrow dtypes @@ -1146,10 +1085,10 @@ def __init__(self, value: int) -> None: df = pd.DataFrame(columns) df["_partitions"] = df.col4 % npartitions worker_for = {i: random.choice(workers) for i in list(range(npartitions))} - 
worker_for = pd.Series(worker_for, name="_worker").astype("category") + worker_for = pd.Series(worker_for, name="_workers").astype("category") meta = df.head(0) - data = split_by_worker(df, "_partitions", worker_for=worker_for, meta=meta) + data = split_by_worker(df, "_partitions", worker_for=worker_for) assert set(data) == set(worker_for.cat.categories) assert sum(map(len, data.values())) == len(df) @@ -1657,7 +1596,6 @@ def new_shuffle( # 36 parametrizations # Runtime each ~0.1s -@pytest.mark.skipif(not pa, reason="Requires PyArrow") @pytest.mark.parametrize("n_workers", [1, 10]) @pytest.mark.parametrize("n_input_partitions", [1, 2, 10]) @pytest.mark.parametrize("npartitions", [1, 20]) @@ -1732,8 +1670,6 @@ async def test_basic_lowlevel_shuffle( total_bytes_recvd += metrics["disk"]["total"] total_bytes_recvd_shuffle += s.total_recvd - assert total_bytes_recvd_shuffle == total_bytes_sent - all_parts = [] for part, worker in worker_for_mapping.items(): s = local_shuffle_pool.shuffles[worker] @@ -1751,7 +1687,6 @@ async def test_basic_lowlevel_shuffle( assert len(df_after) == len(pd.concat(dfs)) -@pytest.mark.skipif(not pa, reason="Requires PyArrow") @gen_test() async def test_error_offload(tmp_path, loop_in_thread): dfs = [] @@ -1807,7 +1742,6 @@ async def offload(self, func, *args): await asyncio.gather(*[s.close() for s in [sA, sB]]) -@pytest.mark.skipif(not pa, reason="Requires PyArrow") @gen_test() async def test_error_send(tmp_path, loop_in_thread): dfs = [] @@ -1863,7 +1797,6 @@ async def send(self, *args: Any, **kwargs: Any) -> None: await asyncio.gather(*[s.close() for s in [sA, sB]]) -@pytest.mark.skipif(not pa, reason="Requires PyArrow") @gen_test() async def test_error_receive(tmp_path, loop_in_thread): dfs = [] @@ -1888,7 +1821,7 @@ async def test_error_receive(tmp_path, loop_in_thread): partitions_for_worker[w].append(part) class ErrorReceive(DataFrameShuffleRun): - async def _receive(self, data: list[tuple[int, bytes]]) -> None: + async def _receive(self, data: list[tuple[int, Any]]) -> None: raise RuntimeError("Error during receive") with DataFrameShuffleTestPool() as local_shuffle_pool: @@ -2365,18 +2298,11 @@ def make_partition(i): with dask.config.set({"dataframe.shuffle.method": "p2p"}): out = ddf.shuffle(on="a", ignore_index=True) - if PYARROW_GE_14: - result, expected = c.compute([ddf, out]) - result = await result - expected = await expected - dd.assert_eq(result, expected) - del result - else: - with raises_with_cause( - RuntimeError, r"shuffling \w+ failed", pa.ArrowInvalid, "incompatible types" - ): - await c.compute(out) - await c.close() + result, expected = c.compute([ddf, out]) + result = await result + expected = await expected + dd.assert_eq(result, expected) + del result del out await check_worker_cleanup(a) diff --git a/distributed/shuffle/tests/test_shuffle_plugins.py b/distributed/shuffle/tests/test_shuffle_plugins.py index 69983e9890e..9beb6063217 100644 --- a/distributed/shuffle/tests/test_shuffle_plugins.py +++ b/distributed/shuffle/tests/test_shuffle_plugins.py @@ -5,11 +5,7 @@ import pytest from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin -from distributed.shuffle._shuffle import ( - _get_worker_for_range_sharding, - split_by_partition, - split_by_worker, -) +from distributed.shuffle._shuffle import _get_worker_for_range_sharding, split_by_worker from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.utils_test import gen_cluster @@ -38,8 +34,6 @@ async def test_installation_on_scheduler(s, a): def 
test_split_by_worker(): - pytest.importorskip("pyarrow") - df = pd.DataFrame( { "x": [1, 2, 3, 4, 5], @@ -63,8 +57,6 @@ def test_split_by_worker(): def test_split_by_worker_empty(): - pytest.importorskip("pyarrow") - df = pd.DataFrame( { "x": [1, 2, 3, 4, 5], @@ -78,8 +70,6 @@ def test_split_by_worker_empty(): def test_split_by_worker_many_workers(): - pytest.importorskip("pyarrow") - df = pd.DataFrame( { "x": [1, 2, 3, 4, 5], @@ -102,23 +92,3 @@ def test_split_by_worker_many_workers(): assert _get_worker_for_range_sharding(npartitions, 1, workers) in out assert sum(map(len, out.values())) == len(df) - - -@pytest.mark.parametrize("drop_column", [True, False]) -def test_split_by_partition(drop_column): - pa = pytest.importorskip("pyarrow") - - df = pd.DataFrame( - { - "x": [1, 2, 3, 4, 5], - "_partition": [3, 1, 2, 3, 1], - } - ) - t = pa.Table.from_pandas(df) - - out = split_by_partition(t, "_partition", drop_column) - assert set(out) == {1, 2, 3} - if drop_column: - df = df.drop(columns="_partition") - assert out[1].column_names == list(df.columns) - assert sum(map(len, out.values())) == len(df) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0d25f0bfe4a..0b86ecde18f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -81,7 +81,6 @@ from distributed.diagnostics.plugin import WorkerPlugin from distributed.metrics import time from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler -from distributed.shuffle import check_minimal_arrow_version from distributed.sizeof import sizeof from distributed.utils import get_mp_context, is_valid_xml, open_port, sync, tmp_text from distributed.utils_test import ( @@ -3368,18 +3367,13 @@ async def test_cancel_clears_processing(c, s, *workers): def test_default_get(loop_in_thread): - has_pyarrow = False - try: - check_minimal_arrow_version() - has_pyarrow = True - except ImportError: - pass loop = loop_in_thread + distributed_default = "p2p" + local_default = "disk" + with cluster() as (s, [a, b]): pre_get = dask.base.get_scheduler() # These may change in the future but the selection below should not - distributed_default = "p2p" if has_pyarrow else "tasks" - local_default = "disk" assert get_default_shuffle_method() == local_default with Client(s["address"], set_as_default=True, loop=loop) as c: assert dask.base.get_scheduler() == c.get diff --git a/distributed/utils.py b/distributed/utils.py index 6fbbf0f2244..a7d5de8d038 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1199,6 +1199,8 @@ def nbytes(frame, _bytes_like=(bytes, bytearray)): """Number of bytes of a frame or memoryview""" if isinstance(frame, _bytes_like): return len(frame) + if isinstance(frame, PickleBuffer): + return frame.raw().nbytes else: try: return frame.nbytes
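Tying the new pieces together, a hypothetical end-to-end sketch of the arrow-free dataframe spill format introduced by this patch; column names and data are made up, and the framing mirrors what `_receive` hands to the disk buffer.

```python
import pandas as pd

from distributed.protocol.utils import pack_frames_prelude
from distributed.shuffle._pickle import (
    pickle_dataframe_shard,
    unpickle_and_concat_dataframe_shards,
)

df = pd.DataFrame({"x": range(6), "y": list("abcdef")})
meta = df.head(0)
shards = [df.iloc[0:3], df.iloc[3:6]]

# Frame each shard the way _receive() does before writing it out:
# a length prelude followed by the pickle header and its out-of-band buffers.
blob = bytearray()
for input_part_id, shard in enumerate(shards):
    frames = [f.raw() for f in pickle_dataframe_shard(input_part_id, shard)]
    blob += pack_frames_prelude(frames)
    for frame in frames:
        blob += frame

out = unpickle_and_concat_dataframe_shards(blob, meta)
assert out.equals(df)
```

Because each shard carries only its index and blocks while the column header lives in `meta`, the per-shard overhead stays small, which is what the `p2p-shards-overhead` metric added above measures.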