From 789505662c295fdd3cf8304ddd88b91b34a533c1 Mon Sep 17 00:00:00 2001 From: Evangelos Pappas Date: Mon, 18 May 2026 14:43:59 +0200 Subject: [PATCH] feat(sdk): collapse bench API to bool opt-in + lazy training.bench (refs basilica-backend#661) Part of the SDK API simplification plan (docs/plans/SDK-API-SIMPLIFICATION-PLAN.md on basilica-backend main). The bench API surface was over-exposed for an opt-in diagnostic helper: - bench: str modes ("on-start" / "off") -- two string tokens for a binary - training.wait_until_bench_complete(timeout=...) raises or returns - BenchStatus with four terminal phases + four is_* properties - Two access paths to the result (bench_status.result vs training.bench) Target after S2 (this change): - bench: bool -- True opts in, False opts out (default) - training.bench: BenchResult | None (unchanged; lazy) - training.bench_diagnostics: Optional[Dict[str, Any]] (new) -- small debug dict with phase / message / timings, for the rare case where the user wants to know WHY a probe didn't measure - bench: str ("on-start" / "off") still accepted with DeprecationWarning - wait_until_bench_complete[_async] and bench_status emit DeprecationWarning pointing at the lazy training.bench accessor - Removed in next major Changes: - python/basilica/__init__.py: _normalize_bench_param helper; bench param type Union[bool, str] with default False on deploy_distributed, deploy_distributed_async, deploy_distributed_managed, deploy_distributed_managed_async; deprecation warning emitted by helper - python/basilica/decorators.py: @distributed bench param type Union[bool, str] with default False (forwarded verbatim, normalized downstream) - python/basilica/distributed.py: new training.bench_diagnostics lazy property; bench_status, wait_until_bench_complete[_async] emit DeprecationWarning; internal _bench_status_no_warn reads the BenchStatus without warning - tests/test_bench_bool_simplification.py: 22 tests pinning the new surface (bool acceptance, str 
deprecation, diagnostics dict shape, lazy bench BenchResult|None semantics, wait_until_bench_complete + bench_status deprecation warnings) - pyproject.toml + Cargo.toml + Cargo.lock: bump 0.29.4 -> 0.29.5 All 179 existing SDK tests pass; new tests bring total to 201. Wire contract is unchanged: distributed.bench.mode is still "on-start" / "off" on the operator-facing JSON. Only the user-facing SDK parameter type narrows. Operator + CRD schema untouched. --- Cargo.lock | 2 +- crates/basilica-sdk-python/Cargo.toml | 2 +- crates/basilica-sdk-python/pyproject.toml | 2 +- .../python/basilica/__init__.py | 60 ++- .../python/basilica/decorators.py | 9 +- .../python/basilica/distributed.py | 167 ++++-- .../tests/test_bench_bool_simplification.py | 500 ++++++++++++++++++ 7 files changed, 679 insertions(+), 63 deletions(-) create mode 100644 crates/basilica-sdk-python/tests/test_bench_bool_simplification.py diff --git a/Cargo.lock b/Cargo.lock index 4284f0e0..9f534710 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1520,7 +1520,7 @@ dependencies = [ [[package]] name = "basilica-sdk-python" -version = "0.29.4" +version = "0.29.5" dependencies = [ "basilica-sdk", "basilica-validator", diff --git a/crates/basilica-sdk-python/Cargo.toml b/crates/basilica-sdk-python/Cargo.toml index 73eab16d..670647c3 100644 --- a/crates/basilica-sdk-python/Cargo.toml +++ b/crates/basilica-sdk-python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "basilica-sdk-python" -version = "0.29.4" +version = "0.29.5" edition = "2021" authors = ["Basilica Team"] description = "Python bindings for the Basilica SDK" diff --git a/crates/basilica-sdk-python/pyproject.toml b/crates/basilica-sdk-python/pyproject.toml index 6457ad55..afc9148a 100644 --- a/crates/basilica-sdk-python/pyproject.toml +++ b/crates/basilica-sdk-python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "basilica-sdk" -version = "0.29.4" +version = "0.29.5" description = "Python SDK for deploying containerized 
applications on the Basilica GPU cloud" readme = "README.md" license = { text = "MIT OR Apache-2.0" } diff --git a/crates/basilica-sdk-python/python/basilica/__init__.py b/crates/basilica-sdk-python/python/basilica/__init__.py index 6f7cd69d..a4e0e88d 100644 --- a/crates/basilica-sdk-python/python/basilica/__init__.py +++ b/crates/basilica-sdk-python/python/basilica/__init__.py @@ -212,6 +212,39 @@ def __init__( _SHELL_SCRIPT_LAUNCHERS = frozenset({"bash", "sh", "/bin/bash", "/bin/sh"}) +def _normalize_bench_param(bench: Any) -> str: + """ + basilica-backend#661 / SDK-S2: collapse the bench API surface to a + bool opt-in. + + ``bench=True`` -> ``"on-start"`` wire token (probe scheduled). + ``bench=False`` -> ``"off"`` wire token (no probe; default). + ``bench="on-start" | "off"`` -> back-compat, emits DeprecationWarning + pointing at the bool form. Any other input raises ValidationError + via the downstream wire-shape check (preserves the existing + "invalid bench mode" error path). + + The wire shape stays the operator's canonical ``"on-start" | "off"`` + string. Only the user-facing parameter type narrows. + """ + if isinstance(bench, bool): + return "on-start" if bench else "off" + if isinstance(bench, str): + if bench in ("on-start", "off"): + warnings.warn( + f"bench={bench!r} (str) is deprecated and will be removed " + f"in the next major. Use bench=" + f"{'True' if bench == 'on-start' else 'False'} instead. " + f"See basilica-backend#661 / SDK-S2.", + DeprecationWarning, + stacklevel=3, + ) + return bench + # Pass through to the downstream validator so the existing error + # path stays consistent (raises ValidationError with field="bench"). 
+ return bench # type: ignore[return-value] + + def _shell_join_preserving_vars(command: List[str]) -> str: """ Join an argv list into a single shell command string for the operator's @@ -1115,7 +1148,7 @@ def deploy_distributed( provider_filter: Optional[ProviderFilter] = None, topology_spread: str = "provider-aware", nccl_env: Optional[Dict[str, str]] = None, - bench: str = "off", + bench: Union[bool, str] = False, bench_placement: str = "preferred", rendezvous_backend: str = "etcd-v2", command: Optional[List[str]] = None, @@ -1156,10 +1189,14 @@ def deploy_distributed( | none`. Default `provider-aware`. SDK arch § 4. nccl_env: NCCL env vars merged on top of operator defaults. User values win on collision. - bench: `"on-start"` to schedule a 2-rank NCCL bench probe in the - user's namespace alongside workers (counts against the - namespace rank budget; result lands on `training.bench`). - `"off"` (default) skips the probe. SDK arch § 7. + bench: ``True`` to opt in to the per-UD NCCL bench probe; + ``False`` (default) skips the probe. Reads back as + ``training.bench`` (``BenchResult | None``) after the UD + reaches a terminal state. Replaces the ``"on-start"`` / + ``"off"`` string modes (still accepted with + ``DeprecationWarning``; removed in the next major). See + ``basilica-backend#661`` / SDK-S2 for the rationale. + SDK arch § 7. bench_placement: Placement policy for the bench Pod pair on multi-tenant clusters. `"preferred"` (default) lets the bench fall back off the worker pair when those nodes have @@ -1225,6 +1262,11 @@ def deploy_distributed( field="wait_for_bench", value=wait_for_bench, ) + # basilica-backend#661 / SDK-S2: collapse bench API to bool. + # ``True`` -> "on-start"; ``False`` -> "off"; legacy str modes + # remain accepted with DeprecationWarning, normalized to the + # wire token here. 
+ bench = _normalize_bench_param(bench) request_dict = self._build_distributed_request( name=name, @@ -1295,7 +1337,7 @@ def deploy_distributed_managed( provider_filter: Optional[ProviderFilter] = None, topology_spread: str = "provider-aware", nccl_env: Optional[Dict[str, str]] = None, - bench: str = "off", + bench: Union[bool, str] = False, bench_placement: str = "preferred", rendezvous_backend: str = "etcd-v2", command: Optional[List[str]] = None, @@ -2321,7 +2363,7 @@ async def deploy_distributed_async( provider_filter: Optional[ProviderFilter] = None, topology_spread: str = "provider-aware", nccl_env: Optional[Dict[str, str]] = None, - bench: str = "off", + bench: Union[bool, str] = False, bench_placement: str = "preferred", rendezvous_backend: str = "etcd-v2", command: Optional[List[str]] = None, @@ -2350,6 +2392,8 @@ async def deploy_distributed_async( field="wait_for_bench", value=wait_for_bench, ) + # basilica-backend#661 / SDK-S2: collapse bench API to bool. + bench = _normalize_bench_param(bench) request_dict = self._build_distributed_request( name=name, @@ -2414,7 +2458,7 @@ def deploy_distributed_managed_async( provider_filter: Optional[ProviderFilter] = None, topology_spread: str = "provider-aware", nccl_env: Optional[Dict[str, str]] = None, - bench: str = "off", + bench: Union[bool, str] = False, bench_placement: str = "preferred", rendezvous_backend: str = "etcd-v2", command: Optional[List[str]] = None, diff --git a/crates/basilica-sdk-python/python/basilica/decorators.py b/crates/basilica-sdk-python/python/basilica/decorators.py index 7371f707..0c5a5cf5 100644 --- a/crates/basilica-sdk-python/python/basilica/decorators.py +++ b/crates/basilica-sdk-python/python/basilica/decorators.py @@ -483,7 +483,7 @@ def distributed( provider_filter: Optional[Union[ProviderFilter, Dict[str, List[str]]]] = None, topology_spread: str = "provider-aware", nccl_env: Optional[Dict[str, str]] = None, - bench: str = "off", + bench: Union[bool, str] = False, 
rendezvous_backend: str = "etcd-v2", env: Optional[Dict[str, str]] = None, pip_packages: Optional[List[str]] = None, @@ -512,7 +512,12 @@ def distributed( provider_filter: ProviderFilter or `{"include": [...], "exclude": [...]}` dict. topology_spread: One of `pack | provider-aware | region-aware | none`. nccl_env: NCCL env vars merged on top of operator defaults. - bench: `on-start` to schedule a 2-rank NCCL bench probe; `off` (default). + bench: ``True`` to opt in to the per-UD NCCL bench probe; ``False`` + (default) skips it. Reads back as ``training.bench`` + (``BenchResult | None``) post-terminal. The legacy str modes + ``"on-start"`` / ``"off"`` remain accepted with + ``DeprecationWarning``; removed in the next major. See + ``basilica-backend#661`` / SDK-S2. rendezvous_backend: `etcd-v2` (default) | `c10d` | `static`. env: Environment variables passed to the worker pods. pip_packages: Additional pip packages to install. diff --git a/crates/basilica-sdk-python/python/basilica/distributed.py b/crates/basilica-sdk-python/python/basilica/distributed.py index 1291d83a..8e45dac5 100644 --- a/crates/basilica-sdk-python/python/basilica/distributed.py +++ b/crates/basilica-sdk-python/python/basilica/distributed.py @@ -14,6 +14,7 @@ import asyncio import time +import warnings from dataclasses import dataclass, field from datetime import datetime from types import TracebackType @@ -510,19 +511,38 @@ def bench(self) -> Optional[BenchResult]: @property def bench_status(self) -> Optional[BenchStatus]: """ - Issue B/N (refs #506): full bench probe state, including - explicit lifecycle [`phase`], `started_at` / `completed_at` - timing, and a human-readable `message`. + DEPRECATED (basilica-backend#661, SDK-S2): full bench probe state. - Returns `None` when bench is off (`mode != on-start`) or the - operator has not yet observed the Job at all. Otherwise the - returned [`BenchStatus`] reflects the operator's - `status.distributed.bench` exactly. 
+ Use ``training.bench`` (``BenchResult | None``) for the result + payload, and ``training.bench_diagnostics`` (``dict | None``) + for the rarely-needed debug detail (phase / message / timings). + This four-property enum surface (``is_terminal`` / ``is_successful`` + / ``is_failed`` / ``is_skipped``) is being collapsed because the + diagnostic is OPT-IN: most callers want a yes/no answer, not a + four-phase ceremony. - DECOUPLING CONTRACT: this surface is INDEPENDENT of - `WorldStatus.ready`. `wait_until_min_world` polls workers only; - bench has its own opt-in waiter, `wait_until_bench_complete`. + Returns ``None`` when bench is off (``mode != on-start``) or the + operator has not yet observed the Job at all. Otherwise the + returned :class:`BenchStatus` reflects the operator's + ``status.distributed.bench`` exactly. Removed in the next major. """ + warnings.warn( + "DistributedTraining.bench_status is deprecated and will be " + "removed in the next major. Use training.bench (BenchResult | " + "None) for the result and training.bench_diagnostics (dict | " + "None) for the rare debug detail. See " + "basilica-backend#661 / SDK-S2.", + DeprecationWarning, + stacklevel=2, + ) + return self._bench_status_no_warn + + @property + def _bench_status_no_warn(self) -> Optional[BenchStatus]: + """Internal: read the full BenchStatus without emitting the + deprecation warning. Used by ``bench_diagnostics`` and the + legacy ``wait_until_bench_complete`` waiter so they remain + callable without double-warning the user.""" if self._cached_status is None: self.refresh() bench_raw = (self._cached_status or {}).get("distributed", {}).get("bench") @@ -530,6 +550,51 @@ def bench_status(self) -> Optional[BenchStatus]: return None return BenchStatus.from_status_dict(bench_raw) + @property + def bench_diagnostics(self) -> Optional[Dict[str, Any]]: + """ + basilica-backend#661 / SDK-S2: simplified bench debug surface. 
+ + Returns ``None`` when bench was not requested (``mode != on-start``) + OR when the operator has not yet published a bench block. + Otherwise returns a ``dict`` with the small set of fields a + researcher might inspect when ``training.bench`` is ``None`` and + they want to know WHY the probe didn't measure: + + - ``mode``: ``"on-start"`` (the only non-off value the user can + set) or ``"off"``. + - ``phase``: operator's bench-Job lifecycle phase + (``"Pending"`` / ``"Running"`` / ``"Succeeded"`` / + ``"Failed"`` / ``"TimedOut"`` / ``"Skipped"``). + - ``message``: human-readable reason from the operator. + - ``started_at`` / ``completed_at``: timing. + - ``last_attempt_at`` / ``last_attempt_outcome``: most recent + attempt outcome (e.g. ``"skipped"`` when workers exited + before the bench-controller observed them). + + Most users only ever read ``training.bench``; this attribute is + the escape hatch for the rare case where they need to debug a + ``None`` result. Replaces the multi-property ``BenchStatus`` + enum exposed via ``bench_status``. + """ + bs = self._bench_status_no_warn + if bs is None: + return None + # Mode=off means "user didn't ask for a probe". Surfacing + # diagnostics in that case would be confusing -- there's nothing + # to debug. Collapse to None. + if bs.mode == "off": + return None + return { + "mode": bs.mode, + "phase": bs.phase, + "message": bs.message, + "started_at": bs.started_at, + "completed_at": bs.completed_at, + "last_attempt_at": bs.last_attempt_at, + "last_attempt_outcome": bs.last_attempt_outcome, + } + def metrics(self) -> DistributedMetrics: """Platform-side metric snapshot for this UD (SDK arch § 6).""" if self._cached_status is None: @@ -709,44 +774,36 @@ async def wait_until_min_world_async(self, timeout: int = 300) -> None: def wait_until_bench_complete(self, timeout: int = 1500) -> Optional[BenchStatus]: """ - Issue B/N (refs #506): OPT-IN waiter for the bench probe. 
- - Most callers do NOT need this; bench is best-effort per - `basilica_bench.py:75-78` and the world is considered ready - when workers reach Ready (see :py:meth:`wait_until_min_world`). - This method exists for the small set of callers that DO want - the bench measurement before continuing (e.g. a research run - that records busbw metadata into checkpoint headers). - - Returns the terminal :py:class:`BenchStatus` once `phase` - enters a terminal state. Closes #480: the four terminal phases - are: - - - `Succeeded` -- bench probe produced a measurement; the result - payload is on `BenchStatus.result` (mirrored on - `DistributedTraining.bench` for backward compatibility). - - `Failed` -- bench probe ran but errored. See `message`. - - `TimedOut` -- bench probe's own deadline elapsed before - completion. See `message`. - - `Skipped` -- the operator decided not to run the bench probe - (e.g. workers exited before the bench-controller observed - them). See `message` for the reason. NOT a failure; the - workload itself may have completed cleanly. - - If `bench.mode != on-start` returns `None` immediately (nothing - to wait on). Polls every 5s. - - Default timeout matches the operator's `BENCH_ACTIVE_DEADLINE_SECONDS` - (1500s = 25 min). Raises `TimeoutError` past the deadline if the - bench has not reached a terminal phase. Callers handling a - terminal `Skipped` should branch on `bs.is_skipped` (or - `bs.phase != "Succeeded"`); see `examples/20_distributed_diloco.py` - and `examples/22_distributed_with_bench.py`. + DEPRECATED (basilica-backend#661, SDK-S2): OPT-IN waiter for the bench probe. + + Use ``training.bench`` (returns ``BenchResult | None``) after the + UD reaches a terminal state. The four-phase explicit-wait + ceremony is being collapsed: most callers want to know whether + the probe measured (``bench is not None``), not which of the + four terminal phases it landed in. 
For the rare debug path use + ``training.bench_diagnostics`` (dict with phase / message / + timings). + + Remains functional for two minor versions. Returns the terminal + :class:`BenchStatus` once ``phase`` enters a terminal state + (``Succeeded`` / ``Failed`` / ``TimedOut`` / ``Skipped``). + ``mode != on-start`` returns ``None`` immediately. Raises + ``TimeoutError`` if the bench has not reached a terminal phase + within ``timeout`` seconds. Polls every 5s. """ + warnings.warn( + "DistributedTraining.wait_until_bench_complete is deprecated " + "and will be removed in the next major. Use training.bench " + "(BenchResult | None) after the UD reaches a terminal state, " + "and training.bench_diagnostics for debug detail. See " + "basilica-backend#661 / SDK-S2.", + DeprecationWarning, + stacklevel=2, + ) deadline = time.monotonic() + max(timeout, 0) while time.monotonic() < deadline: self.refresh() - bs = self.bench_status + bs = self._bench_status_no_warn if bs is None: # bench is Off — nothing to wait for. return None @@ -755,7 +812,7 @@ def wait_until_bench_complete(self, timeout: int = 1500) -> Optional[BenchStatus time.sleep(min(5, max(timeout // 10, 1))) # One last refresh. self.refresh() - bs = self.bench_status + bs = self._bench_status_no_warn if bs is not None and bs.is_terminal: return bs raise TimeoutError( @@ -767,22 +824,32 @@ async def wait_until_bench_complete_async( self, timeout: int = 1500 ) -> Optional[BenchStatus]: """ - Async variant of :py:meth:`wait_until_bench_complete`. + DEPRECATED (basilica-backend#661, SDK-S2): async variant of + :py:meth:`wait_until_bench_complete`. - Same terminal-phase set: `Succeeded` | `Failed` | `TimedOut` | - `Skipped` (closes #480). Same `None`-on-bench-off semantics. + Use ``training.bench`` post-terminal instead. Same + ``None``-on-bench-off semantics; removed in the next major. """ + warnings.warn( + "DistributedTraining.wait_until_bench_complete_async is " + "deprecated and will be removed in the next major. 
Use " + "training.bench (BenchResult | None) after the UD reaches " + "a terminal state, and training.bench_diagnostics for debug " + "detail. See basilica-backend#661 / SDK-S2.", + DeprecationWarning, + stacklevel=2, + ) deadline = asyncio.get_event_loop().time() + max(timeout, 0) while asyncio.get_event_loop().time() < deadline: await self.refresh_async() - bs = self.bench_status + bs = self._bench_status_no_warn if bs is None: return None if bs.is_terminal: return bs await asyncio.sleep(min(5, max(timeout // 10, 1))) await self.refresh_async() - bs = self.bench_status + bs = self._bench_status_no_warn if bs is not None and bs.is_terminal: return bs raise TimeoutError( diff --git a/crates/basilica-sdk-python/tests/test_bench_bool_simplification.py b/crates/basilica-sdk-python/tests/test_bench_bool_simplification.py new file mode 100644 index 00000000..8be2cf74 --- /dev/null +++ b/crates/basilica-sdk-python/tests/test_bench_bool_simplification.py @@ -0,0 +1,500 @@ +""" +Unit tests pinning the simplified bench API surface +(basilica-backend issue 661 / SDK-S2). + +WHY this file exists (read the issue body for the full plan): + +Today the SDK exposes too much state for an opt-in measurement helper: +- ``bench: str`` modes (``"on-start"`` / ``"off"``) -- two string tokens + to memorize for a binary opt-in. +- ``training.wait_until_bench_complete(timeout=...)`` -- raises + ``TimeoutError`` or returns a four-phase ``BenchStatus``. +- ``BenchStatus`` with four terminal phases (``Succeeded`` / ``Failed`` / + ``TimedOut`` / ``Skipped``) and ``is_terminal`` / ``is_successful`` / + ``is_failed`` / ``is_skipped`` properties. +- Two access paths to the result (``bench_status.result`` vs + ``training.bench``). + +Target after S2 (per +``docs/plans/SDK-API-SIMPLIFICATION-PLAN.md`` on basilica-backend +``main``): + +- ``bench: bool`` -- ``True`` opts in, ``False`` opts out. 
The string + values ``"on-start"`` / ``"off"`` remain accepted for backward-compat + with a ``DeprecationWarning`` pointing at the bool form. +- ``training.bench`` returns ``BenchResult | None`` (unchanged; lazy). +- ``training.bench_diagnostics`` returns ``Optional[Dict[str, Any]]`` -- + a small dict with ``phase``, ``message``, ``mode``, + ``started_at`` / ``completed_at`` / ``last_attempt_at`` / + ``last_attempt_outcome`` for the rare caller who needs to know WHY + the probe did not measure. Most users only read ``training.bench``. +- ``wait_until_bench_complete[_async]`` and direct ``BenchStatus`` use + remain functional but emit ``DeprecationWarning`` pointing at + ``training.bench`` / ``training.bench_diagnostics``. + +These tests: +1. PRE-FIX: fail (today's SDK rejects ``bench=True`` as an invalid + string, has no ``bench_diagnostics`` attribute, and emits no + deprecation warning on the string-mode form). +2. POST-FIX: pass. + +Stubbing pattern mirrors ``test_deploy_distributed_managed.py``: bypass +``BasilicaClient.__init__`` and stub the PyO3 binding so no auth / +network calls fire. +""" + +import warnings +from typing import Any, Dict +from unittest.mock import MagicMock + +import pytest + +from basilica import ( + BasilicaClient, + BenchResult, + DistributedTraining, + WorldSize, +) + + +# ============================================================================= +# Helpers. +# ============================================================================= + + +def _make_client_with_stub( + name: str = "dlc-bench-bool-test", + namespace: str = "u-test", +) -> BasilicaClient: + """Build a BasilicaClient whose PyO3 binding is fully stubbed. + + Bypasses ``BasilicaClient.__init__`` to avoid the auth bootstrap. 
+ """ + client = BasilicaClient.__new__(BasilicaClient) + inner = MagicMock() + + create_response = MagicMock() + create_response.instance_name = name + inner.create_distributed_deployment = MagicMock(return_value=create_response) + + get_response = MagicMock() + get_response.namespace = namespace + get_response.instance_name = name + get_response.image = "ghcr.io/example/trainer:latest" + get_response.phase = "ready" + get_response.message = None + get_response.share_token = None + get_response.share_url = None + get_response.public_metadata = None + # Workers Ready immediately so wait_until_min_world returns + # without blocking the unit test. + get_response.distributed = { + "worldSize": { + "ready": 2, + "target": 2, + "min": 2, + "max": 4, + "belowMinimum": False, + }, + } + inner.get_deployment = MagicMock(return_value=get_response) + inner.delete_deployment = MagicMock(return_value=None) + inner.scale_distributed_deployment = MagicMock(return_value=None) + + client._client = inner + return client + + +def _deploy_kwargs() -> Dict[str, Any]: + """Minimum kwargs that exercise the ``bench=`` parameter.""" + return { + "name": "dlc-bench-bool-test", + "image": "ghcr.io/example/trainer:latest", + "world_size": WorldSize(min=2, target=2, max=4), + "command": ["python3", "/workspace/noop.py"], + # `timeout=0` is fine: the stubbed status reports min ranks ready, + # so wait_until_min_world returns immediately. + "timeout": 0, + } + + +def _stub_bench_block( + phase: str, + *, + mode: str = "on-start", + message: str | None = None, + with_result: bool = False, +) -> Dict[str, Any]: + """Build a ``status.distributed.bench`` block mirroring the operator's + wire shape. 
Used by tests that exercise the lazy ``training.bench`` and + ``training.bench_diagnostics`` surfaces.""" + bench: Dict[str, Any] = {"mode": mode, "phase": phase} + if phase != "Pending": + bench["startedAt"] = "2026-05-18T00:46:25Z" + if phase in {"Succeeded", "Failed", "TimedOut"}: + bench["completedAt"] = "2026-05-18T01:01:31Z" + if phase == "Skipped": + bench["lastAttemptAt"] = "2026-05-18T01:15:50Z" + bench["lastAttemptOutcome"] = "skipped" + if message is not None: + bench["message"] = message + if with_result: + bench["result"] = { + "measuredAt": "2026-05-18T01:00:00Z", + "busbwGbpsP50": 12.345, + "busbwGbpsP10": 10.0, + "busbwGbpsP90": 15.0, + "algbwGbpsP50": 10.0, + "latencyUsAt1mib": 50.0, + "sizeBytesSwept": [1048576, 16777216], + "probeNodeA": "node-a", + "probeNodeB": "node-b", + } + return bench + + +def _stub_response_with_bench( + name: str, + namespace: str, + bench_block: Dict[str, Any] | None, +) -> Any: + """PyO3-shape fake DeploymentResponse with optional bench block.""" + + class FakeDeployment: + instance_name = name + user_id = namespace + image = "ghcr.io/example/trainer:latest" + state = "running" + url = "https://x" + created_at = "2026-05-18T00:46:25Z" + updated_at = "2026-05-18T01:15:50Z" + phase = "succeeded" + message = None + share_token = None + share_url = None + public_metadata = False + distributed: Dict[str, Any] = { + "worldSize": { + "ready": 2, + "target": 2, + "min": 2, + "max": 2, + "belowMinimum": False, + }, + "ranks": [], + } + + fake = FakeDeployment() + fake.namespace = namespace + if bench_block is not None: + fake.distributed = {**fake.distributed, "bench": bench_block} + return fake + + +# ============================================================================= +# A. ``bench=bool`` acceptance. 
+# ============================================================================= + + +class TestBenchBoolAcceptance: + """``bench=True`` and ``bench=False`` must be accepted without warnings or errors.""" + + def test_bench_true_accepted_without_warning(self) -> None: + client = _make_client_with_stub() + with warnings.catch_warnings(record=True) as recorded: + warnings.simplefilter("always") + training = client.deploy_distributed(bench=True, **_deploy_kwargs()) + assert isinstance(training, DistributedTraining) + # No DeprecationWarning expected for the canonical bool form. + deprecations = [ + w for w in recorded if issubclass(w.category, DeprecationWarning) + ] + assert deprecations == [], ( + f"bench=True must NOT raise DeprecationWarning, got " + f"{[str(w.message) for w in deprecations]!r}" + ) + + def test_bench_false_accepted_without_warning(self) -> None: + client = _make_client_with_stub() + with warnings.catch_warnings(record=True) as recorded: + warnings.simplefilter("always") + training = client.deploy_distributed(bench=False, **_deploy_kwargs()) + assert isinstance(training, DistributedTraining) + deprecations = [ + w for w in recorded if issubclass(w.category, DeprecationWarning) + ] + assert deprecations == [], ( + f"bench=False must NOT raise DeprecationWarning, got " + f"{[str(w.message) for w in deprecations]!r}" + ) + + def test_bench_true_emits_on_start_on_the_wire(self) -> None: + """``bench=True`` -> request body has ``distributed.bench.mode='on-start'``.""" + client = _make_client_with_stub() + client.deploy_distributed(bench=True, **_deploy_kwargs()) + sent_payload = client._client.create_distributed_deployment.call_args.args[0] + assert sent_payload["distributed"]["bench"]["mode"] == "on-start" + + def test_bench_false_emits_off_on_the_wire(self) -> None: + """``bench=False`` -> request body has ``distributed.bench.mode='off'``.""" + client = _make_client_with_stub() + client.deploy_distributed(bench=False, **_deploy_kwargs()) + 
sent_payload = client._client.create_distributed_deployment.call_args.args[0] + assert sent_payload["distributed"]["bench"]["mode"] == "off" + + def test_bench_default_is_off(self) -> None: + """Omitting ``bench`` (default) emits ``mode=off`` -- no probe scheduled.""" + client = _make_client_with_stub() + client.deploy_distributed(**_deploy_kwargs()) + sent_payload = client._client.create_distributed_deployment.call_args.args[0] + assert sent_payload["distributed"]["bench"]["mode"] == "off" + + +# ============================================================================= +# B. ``@basilica.distributed`` decorator accepts bench=bool. +# ============================================================================= + + +class TestDecoratorAcceptsBenchBool: + def test_decorator_factory_accepts_bench_true(self) -> None: + """``@basilica.distributed(bench=True)`` is constructible without raising.""" + from basilica import distributed + decorator = distributed( + name="dlc-decorator-bench-test", + world_size=WorldSize(min=2, target=2, max=2), + bench=True, + ) + # The decorator returns a callable that wraps a function; we + # do not need to deploy here -- the kwargs were captured. + assert callable(decorator) + + def test_decorator_factory_accepts_bench_false(self) -> None: + from basilica import distributed + decorator = distributed( + name="dlc-decorator-bench-test", + world_size=WorldSize(min=2, target=2, max=2), + bench=False, + ) + assert callable(decorator) + + +# ============================================================================= +# C. ``bench=str`` form is deprecated (still accepted). 
# =============================================================================


class TestBenchStrDeprecation:
    """Legacy string values for ``bench`` are still accepted but warn."""

    def test_bench_on_start_str_emits_deprecation_warning(self) -> None:
        """``bench="on-start"`` warns, and the message mentions ``bench=True``."""
        sdk = _make_client_with_stub()
        with pytest.warns(DeprecationWarning, match=r"bench=True"):
            sdk.deploy_distributed(bench="on-start", **_deploy_kwargs())

    def test_bench_off_str_emits_deprecation_warning(self) -> None:
        """``bench="off"`` warns, and the message mentions ``bench=False``."""
        sdk = _make_client_with_stub()
        with pytest.warns(DeprecationWarning, match=r"bench=False"):
            sdk.deploy_distributed(bench="off", **_deploy_kwargs())

    def test_bench_str_still_passed_through_to_wire(self) -> None:
        """The legacy ``'on-start'`` string reaches the wire payload unchanged."""
        sdk = _make_client_with_stub()
        with pytest.warns(DeprecationWarning):
            sdk.deploy_distributed(bench="on-start", **_deploy_kwargs())
        # The first positional argument of the recorded call is the payload.
        recorded_call = sdk._client.create_distributed_deployment.call_args
        payload = recorded_call.args[0]
        bench_section = payload["distributed"]["bench"]
        assert bench_section["mode"] == "on-start"


# =============================================================================
# D. ``training.bench_diagnostics`` (new simplified debug surface).
# =============================================================================


class TestBenchDiagnostics:
    """Pin the shape of the ``training.bench_diagnostics`` debug accessor.

    The accessor yields ``None`` when the probe was not requested
    (mode=off) or when the operator published no status block. In every
    other case it yields a dict carrying ``phase`` / ``message`` and
    timestamp keys, which covers the common debugging need without the
    four-property ``BenchStatus`` surface.
    """

    def test_diagnostics_attribute_exists(self) -> None:
        """The accessor is defined on the class itself, not only on instances."""
        missing_attribute_hint = (
            "training.bench_diagnostics is the simplified debug surface "
            "for SDK-S2; it must exist on the class even before instance "
            "creation. See basilica-backend#661."
        )
        assert hasattr(DistributedTraining, "bench_diagnostics"), missing_attribute_hint

    def test_diagnostics_returns_none_when_bench_off(self) -> None:
        """A ``mode=off`` block maps to ``None`` even though a phase is present."""
        deployment_id = "ud-bench-off"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block={"mode": "off", "phase": "Skipped"},
        )
        job = DistributedTraining(api, deployment_id)
        # The operator still publishes a Skipped block for bookkeeping, but
        # the user never asked for the probe, so nothing should be surfaced.
        info = job.bench_diagnostics
        assert info is None, f"bench_diagnostics must be None when mode=off, got {info!r}"

    def test_diagnostics_returns_none_when_no_bench_block(self) -> None:
        """No bench block in the response maps to ``None``."""
        deployment_id = "ud-no-bench"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id, "u-test", bench_block=None
        )
        job = DistributedTraining(api, deployment_id)
        assert job.bench_diagnostics is None

    def test_diagnostics_returns_dict_when_bench_on_start_skipped(self) -> None:
        """A Skipped probe still exposes phase, mode, message and attempt info."""
        deployment_id = "ud-bench-skipped"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block(
                "Skipped",
                message="workers exited before bench-controller observed them",
            ),
        )
        job = DistributedTraining(api, deployment_id)
        info = job.bench_diagnostics
        assert info is not None, "bench=on-start + Skipped -> diagnostics must be non-None"
        assert isinstance(info, dict)
        assert info["phase"] == "Skipped"
        assert info["mode"] == "on-start"
        assert "workers exited" in info["message"]
        assert "last_attempt_at" in info
        assert "last_attempt_outcome" in info
        assert info["last_attempt_outcome"] == "skipped"

    def test_diagnostics_returns_dict_when_bench_succeeded(self) -> None:
        """A Succeeded probe exposes phase, mode and both timing keys."""
        deployment_id = "ud-bench-succ"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("Succeeded", with_result=True),
        )
        job = DistributedTraining(api, deployment_id)
        info = job.bench_diagnostics
        assert info is not None
        assert info["phase"] == "Succeeded"
        assert info["mode"] == "on-start"
        assert "started_at" in info
        assert "completed_at" in info


# =============================================================================
# E. ``training.bench`` (unchanged) — lazy ``BenchResult | None``.
# =============================================================================


class TestTrainingBenchLazyResult:
    """``training.bench`` is a ``BenchResult`` only when the probe Succeeded.

    Every other terminal phase exercised here (Skipped, Failed, TimedOut)
    reads as ``None``, so "did we measure?" is a single
    ``if training.bench is not None`` check; the per-phase detail lives
    in ``bench_diagnostics``.
    """

    def test_bench_returns_result_on_succeeded(self) -> None:
        """A Succeeded block surfaces a populated ``BenchResult``."""
        deployment_id = "ud-bench-succ"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("Succeeded", with_result=True),
        )
        job = DistributedTraining(api, deployment_id)
        assert job.bench is not None
        assert isinstance(job.bench, BenchResult)
        assert job.bench.busbw_gbps_p50 == 12.345

    def test_bench_returns_none_on_skipped(self) -> None:
        """Skipped means "no measurement" -- the user reads ``None``."""
        deployment_id = "ud-bench-skipped"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("Skipped"),
        )
        job = DistributedTraining(api, deployment_id)
        assert job.bench is None

    def test_bench_returns_none_on_failed(self) -> None:
        """A Failed probe reads as ``None``."""
        deployment_id = "ud-bench-failed"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("Failed", message="probe crashed"),
        )
        job = DistributedTraining(api, deployment_id)
        assert job.bench is None

    def test_bench_returns_none_on_timed_out(self) -> None:
        """A TimedOut probe reads as ``None``."""
        deployment_id = "ud-bench-timeout"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("TimedOut", message="deadline elapsed"),
        )
        job = DistributedTraining(api, deployment_id)
        assert job.bench is None


# =============================================================================
# F. ``wait_until_bench_complete`` is deprecated.
# =============================================================================


class TestWaitUntilBenchCompleteDeprecated:
    """``wait_until_bench_complete`` and its async twin are deprecated.

    Both must emit a ``DeprecationWarning`` whose message points at the
    lazy ``training.bench`` accessor, while remaining callable.
    """

    def test_wait_until_bench_complete_emits_deprecation_warning(self) -> None:
        """The sync waiter warns and still hands back the bench status."""
        deployment_id = "ud-bench-deprecated"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("Succeeded", with_result=True),
        )
        job = DistributedTraining(api, deployment_id)
        with pytest.warns(DeprecationWarning, match=r"training\.bench"):
            status = job.wait_until_bench_complete(timeout=5)
        # Deprecated, not removed: the call must still produce a status.
        assert status is not None
        assert status.phase == "Succeeded"

    def test_wait_until_bench_complete_async_emits_deprecation_warning(self) -> None:
        """The async waiter warns when its coroutine is driven to completion."""
        import asyncio

        deployment_id = "ud-bench-deprecated-async"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("Succeeded", with_result=True),
        )
        job = DistributedTraining(api, deployment_id)
        with pytest.warns(DeprecationWarning, match=r"training\.bench"):
            asyncio.run(job.wait_until_bench_complete_async(timeout=5))


# =============================================================================
# G. ``bench_status`` accessor remains BUT emits DeprecationWarning.
# =============================================================================


class TestBenchStatusAccessorDeprecated:
    """``DistributedTraining.bench_status`` is kept for back-compat.

    Reading it must emit a ``DeprecationWarning`` whose message mentions
    ``bench_diagnostics``, and the value read must still be usable.
    """

    def test_bench_status_emits_deprecation_warning(self) -> None:
        """Reading ``bench_status`` warns and still yields the Succeeded status."""
        deployment_id = "ud-bench-status"
        api = MagicMock()
        api.get.return_value = _stub_response_with_bench(
            deployment_id,
            "u-test",
            bench_block=_stub_bench_block("Succeeded", with_result=True),
        )
        job = DistributedTraining(api, deployment_id)
        # Only the property read sits inside the warning scope.
        with pytest.warns(DeprecationWarning, match=r"bench_diagnostics"):
            status = job.bench_status
        assert status is not None
        assert status.phase == "Succeeded"