diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index aaf125f15..ef63f097b 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -4,6 +4,9 @@ components: additionalProperties: false description: Representation of a device properties: + address: + title: Address + type: string name: description: Name of the device title: Name @@ -17,6 +20,7 @@ components: required: - name - protocols + - address title: DeviceModel type: object DeviceResponse: @@ -340,7 +344,7 @@ components: type: object info: title: BlueAPI Control - version: 0.0.10 + version: 0.0.11 openapi: 3.1.0 paths: /config/oidc: @@ -361,6 +365,16 @@ paths: get: description: Retrieve information about all available devices. operationId: get_devices_devices_get + parameters: + - description: Maximum depth of children to return, -1 for all + in: query + name: max_depth + required: false + schema: + default: 0 + minimum: 0 + title: Max Depth + type: integer responses: '200': content: @@ -368,6 +382,12 @@ paths: schema: $ref: '#/components/schemas/DeviceResponse' description: Successful Response + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error summary: Get Devices /devices/{name}: get: diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index f2996af89..2a753801b 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -164,11 +164,19 @@ def get_plans(obj: dict) -> None: @controller.command(name="devices") @check_connection +@click.option( + "-d", + "--max_depth", + type=click.IntRange(-1), + required=False, + help="Maximum depth of children to return: -1 for all", + default=0, +) @click.pass_obj -def get_devices(obj: dict) -> None: +def get_devices(obj: dict, max_depth: int) -> None: """Get a list of devices available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_devices()) + obj["fmt"].display(client.get_devices(max_depth)) @controller.command(name="listen") diff --git a/src/blueapi/cli/format.py b/src/blueapi/cli/format.py index 490e57cfa..087a4db28 100644 --- a/src/blueapi/cli/format.py +++ b/src/blueapi/cli/format.py @@ -12,6 +12,7 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.model import ( + DeviceModel, DeviceResponse, PlanResponse, PythonEnvironmentResponse, @@ -62,7 +63,7 @@ def display_full(obj: Any, stream: Stream): print(indent(json.dumps(schema, indent=2), " ")) case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) + print(_format_name(dev)) for proto in dev.protocols: print(f" {proto}") case DataEvent(name=name, doc=doc): @@ -124,7 +125,7 @@ def display_compact(obj: Any, stream: Stream): print(f" {arg}={_describe_type(spec, req)}") case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) + print(_format_name(dev)) print( indent( textwrap.fill( @@ -164,6 +165,12 @@ def display_compact(obj: Any, stream: Stream): FALLBACK(other, stream=stream) +def _format_name(device: DeviceModel) -> str: + if not device.address or device.address == device.name: + return device.name + return f"{device.name} @ {device.address}" + + def _describe_type(spec: dict[Any, Any], required: bool = False): disp = "" match spec.get("type"): diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index bca532ed5..fac046406 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -90,8 +90,8 @@ def get_plan(self, name: str) -> PlanModel: """ return self._rest.get_plan(name) - @start_as_current_span(TRACER) - def get_devices(self) -> DeviceResponse: + @start_as_current_span(TRACER, "max_depth") + def get_devices(self, max_depth: int) -> DeviceResponse: """ List devices available @@ -99,7 +99,7 @@ def get_devices(self) -> DeviceResponse: DeviceResponse: Devices that can be used in plans """ - return self._rest.get_devices() + return self._rest.get_devices(max_depth) @start_as_current_span(TRACER, "name") def get_device(self, name: str) -> DeviceModel: diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index fcbb55073..19d605313 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -71,8 +71,10 @@ def get_plans(self) -> PlanResponse: def get_plan(self, name: str) -> PlanModel: return self._request_and_deserialize(f"/plans/{name}", PlanModel) - def get_devices(self) -> DeviceResponse: - return self._request_and_deserialize("/devices", DeviceResponse) + def get_devices(self, max_depth: int) -> DeviceResponse: + return self._request_and_deserialize( + "/devices", DeviceResponse, params={"max_depth": max_depth} + ) def get_device(self, name: str) -> DeviceModel: return self._request_and_deserialize(f"/devices/{name}", DeviceModel) diff --git a/src/blueapi/core/device_lookup.py b/src/blueapi/core/device_lookup.py index 72dff32f0..565fed60f 100644 --- a/src/blueapi/core/device_lookup.py +++ b/src/blueapi/core/device_lookup.py @@ -1,5 +1,7 @@ from typing import Any +from ophyd_async.core import DeviceVector + from .bluesky_types import Device, is_bluesky_compatible_device @@ -28,6 +30,8 @@ def find_component(obj: Any, addr: list[str]) -> Device | None: # Otherwise, we error. if isinstance(obj, dict): component = obj.get(head) + elif isinstance(obj, DeviceVector): + component = obj.get(int(head)) elif is_bluesky_compatible_device(obj): component = getattr(obj, head, None) else: diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 89d08b76b..33a056212 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -177,9 +177,14 @@ def get_plan(name: str) -> PlanModel: return PlanModel.from_plan(context().plans[name]) -def get_devices() -> list[DeviceModel]: +def get_devices(max_depth: int) -> list[DeviceModel]: """Get all available devices in the BlueskyContext""" - return [DeviceModel.from_device(device) for device in context().devices.values()] + return [ + model + for device in context().devices.values() + for model in DeviceModel.from_device_tree(device, max_depth) + if model.protocols + ] def get_device(name: str) -> DeviceModel: diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index e87f0196e..783b2cc96 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -11,6 +11,7 @@ Depends, FastAPI, HTTPException, + Query, Request, Response, status, @@ -54,7 +55,7 @@ from .runner import WorkerDispatcher #: API version to publish in OpenAPI schema -REST_API_VERSION = "0.0.10" +REST_API_VERSION = "0.0.11" RUNNER: WorkerDispatcher | None = None @@ -231,9 +232,18 @@ def get_plan_by_name( @start_as_current_span(TRACER) def get_devices( runner: Annotated[WorkerDispatcher, Depends(_runner)], + max_depth: Annotated[ + int, + Query( + description="Maximum depth of children to return, -1 for all", + ge=0, + # https://github.com/fastapi/fastapi/discussions/13473 + json_schema_extra={"description": None}, + ), + ] = 0, ) -> DeviceResponse: """Retrieve information about all available devices.""" - devices = runner.run(interface.get_devices) + devices = runner.run(interface.get_devices, max_depth) return DeviceResponse(devices=devices) diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index e49e6f4a2..41a133da7 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -4,6 +4,8 @@ from typing import Annotated, Any from bluesky.protocols import HasName +from ophyd import Device as SyncDevice +from ophyd_async.core import Device as AsyncDevice from pydantic import Field from pydantic.json_schema import SkipJsonSchema @@ -34,16 +36,69 @@ class DeviceModel(BlueapiBaseModel): protocols: list[ProtocolInfo] = Field( description="Protocols that a device conforms to, indicating its capabilities" ) + address: str @classmethod def from_device(cls, device: Device) -> "DeviceModel": name = device.name if isinstance(device, HasName) else _UNKNOWN_NAME - return cls(name=name, protocols=list(_protocol_info(device))) + return cls(name=name, protocols=list(_protocol_info(device)), address=name) + + @classmethod + def from_device_tree(cls, root: Device, max_depth: int) -> list["DeviceModel"]: + if isinstance(root, AsyncDevice): + return [ + DeviceModel( + name=device.name, + protocols=list(_protocol_info(device)), + address=address, + ) + for address, device in _from_async_device( + root, max_depth=max_depth + ).items() + ] + if isinstance(root, SyncDevice): + return [ + DeviceModel( + name=device.name, + protocols=list(_protocol_info(device)), + address=address, + ) + for address, device in _from_sync_device( + root, max_depth=max_depth + ).items() + ] + return [DeviceModel.from_device(root)] + + +def _from_async_device(root: AsyncDevice, max_depth: int) -> dict[str, AsyncDevice]: + depth = 0 + devices: dict[str, AsyncDevice] = {root.name: root} + branches: dict[str, AsyncDevice] = {root.name: root} + while branches and (max_depth == -1 or depth < max_depth): + leaves: dict[str, AsyncDevice] = {} + for addr, parent in branches.items(): + for suffix, child in parent.children(): + leaves[f"{addr}.{suffix}"] = child + devices.update(leaves) + branches = leaves + depth += 1 + return devices + + +def _from_sync_device(root: SyncDevice, max_depth: int) -> dict[str, SyncDevice]: + return { + root.name: root, + **{ + k.dotted_name: k.item + for k in root.walk_signals() + if max_depth == -1 or len(k.ancestors) <= max_depth + }, + } def _protocol_info(device: Device) -> Iterable[ProtocolInfo]: for protocol in BLUESKY_PROTOCOLS: - if isinstance(device, protocol): + if isinstance(device, protocol) and protocol is not AsyncDevice: yield ProtocolInfo( name=protocol.__name__, types=[arg.__name__ for arg in generic_bounds(device, protocol)], diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 2209ad4ff..8d330490f 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -162,7 +162,7 @@ def test_get_non_existent_plan(client: BlueapiClient): def test_get_devices(client: BlueapiClient, expected_devices: DeviceResponse): - retrieved_devices = client.get_devices() + retrieved_devices = client.get_devices(max_depth=0) retrieved_devices.devices.sort(key=lambda x: x.name) expected_devices.devices.sort(key=lambda x: x.name) diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 6337d1148..840165fb0 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -35,11 +35,26 @@ PLAN = PlanModel(name="foo") DEVICES = DeviceResponse( devices=[ - DeviceModel(name="foo", protocols=[]), - DeviceModel(name="bar", protocols=[]), + DeviceModel(name="foo", protocols=[], address="foo"), + DeviceModel(name="bar", protocols=[], address="bar"), ] ) -DEVICE = DeviceModel(name="foo", protocols=[]) +DEVICES_AND_CHILDREN = DeviceResponse( + devices=[ + DeviceModel(name="foo", protocols=[], address="foo"), + DeviceModel(name="bar", protocols=[], address="bar"), + DeviceModel(name="foo-bar", protocols=[], address="foo.bar"), + ] +) +DEVICES_AND_ALL_DESCENDENTS = DeviceResponse( + devices=[ + DeviceModel(name="foo", protocols=[], address="foo"), + DeviceModel(name="bar", protocols=[], address="bar"), + DeviceModel(name="foo-bar", protocols=[], address="foo.bar"), + DeviceModel(name="foo-bar-baz", protocols=[], address="foo.bar.baz"), + ] +) +DEVICE = DeviceModel(name="foo", protocols=[], address="foo") TASK = TrackableTask(task_id="foo", task=Task(name="bar", params={})) TASKS = TasksListResponse(tasks=[TASK]) ACTIVE_TASK = WorkerTask(task_id="bar") @@ -67,11 +82,18 @@ @pytest.fixture def mock_rest() -> BlueapiRestClient: + def get_devices(max_depth: int) -> DeviceResponse: + if max_depth == -1 or max_depth > 1: + return DEVICES_AND_ALL_DESCENDENTS + if max_depth == 1: + return DEVICES_AND_CHILDREN + return DEVICES + mock = Mock(spec=BlueapiRestClient) mock.get_plans.return_value = PLANS mock.get_plan.return_value = PLAN - mock.get_devices.return_value = DEVICES + mock.get_devices.side_effect = get_devices mock.get_device.return_value = DEVICE mock.get_state.return_value = WorkerState.IDLE mock.get_task.return_value = TASK @@ -121,7 +143,7 @@ def test_get_nonexistant_plan( def test_get_devices(client: BlueapiClient): - assert client.get_devices() == DEVICES + assert client.get_devices(max_depth=0) == DEVICES def test_get_device(client: BlueapiClient): @@ -511,7 +533,7 @@ def test_get_plan_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClien def test_get_devices_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): with asserting_span_exporter(exporter, "get_devices"): - client.get_devices() + client.get_devices(max_depth=0) def test_get_device_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): diff --git a/tests/unit_tests/service/test_device_model.py b/tests/unit_tests/service/test_device_model.py new file mode 100644 index 000000000..b5fb95ca2 --- /dev/null +++ b/tests/unit_tests/service/test_device_model.py @@ -0,0 +1,81 @@ +import pytest +from ophyd import Device as SyncDevice +from ophyd.sim import SynAxis +from ophyd_async.core import Device as AsyncDevice + +from blueapi.service.model import DeviceModel, ProtocolInfo + + +@pytest.fixture +def sync_device() -> SyncDevice: + class NestedDevice(SyncDevice): + axis = SynAxis(name="branch") + + return NestedDevice(name="root") + + +max_depth: list[int] = [0, 1, 2, 3, -1, -2] + + +root_sync_device = DeviceModel(name="root", protocols=[], address="root") +branch_sync_device = ( + DeviceModel( + name="branch", + protocols=[ + ProtocolInfo(name="Checkable"), + ProtocolInfo(name="Movable"), + ProtocolInfo(name="Pausable"), + ProtocolInfo(name="Readable"), + ProtocolInfo(name="Stageable"), + ProtocolInfo(name="Stoppable"), + ProtocolInfo(name="Subscribable"), + ProtocolInfo(name="Configurable"), + ProtocolInfo(name="Triggerable"), + ], + address="root.axis", + ), +) +leaf_sync_devices = [] + + +@pytest.mark.parametrize("max_depth", max_depth) +def test_ophyd_depth(sync_device: SyncDevice, max_depth: int): + response = DeviceModel.from_device_tree(sync_device, max_depth) + if max_depth == 0 or max_depth < -1: + assert response == [root_sync_device] + elif max_depth == 1: + assert response == [root_sync_device, branch_sync_device] + else: + assert response == [root_sync_device, branch_sync_device, *leaf_sync_devices] + + +root_async_device = DeviceModel(name="root", protocols=[], address="root") +branch_async_device = ( + DeviceModel( + name="branch", + protocols=[ + ProtocolInfo(name="Checkable"), + ProtocolInfo(name="Movable"), + ProtocolInfo(name="Pausable"), + ProtocolInfo(name="Readable"), + ProtocolInfo(name="Stageable"), + ProtocolInfo(name="Stoppable"), + ProtocolInfo(name="Subscribable"), + ProtocolInfo(name="Configurable"), + ProtocolInfo(name="Triggerable"), + ], + address="root.axis", + ), +) +leaf_async_devices = [] + + +@pytest.mark.parametrize("max_depth", max_depth) +def test_ophyd_async_depth(async_device: AsyncDevice, max_depth: int): + response = DeviceModel.from_device_tree(async_device, max_depth) + if max_depth == 0 or max_depth < -1: + assert response == [root_async_device] + elif max_depth == 1: + assert response == [root_async_device, branch_async_device] + else: + assert response == [root_async_device, branch_async_device, *leaf_async_devices] diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index 9237aaf42..efc041765 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -1,5 +1,4 @@ import uuid -from dataclasses import dataclass from unittest.mock import ANY, MagicMock, Mock, patch import pytest @@ -9,6 +8,7 @@ from dodal.common.beamlines.beamline_utils import set_path_provider from dodal.common.visit import StartDocumentPathProvider from ophyd.sim import SynAxis +from ophyd_async.core import DeviceVector from stomp.connect import StompConnection11 as Connection from blueapi.client.numtracker import NumtrackerClient @@ -124,23 +124,28 @@ def test_get_plan(context_mock: MagicMock): interface.get_plan("non_existing_plan") -@dataclass -class MyDevice(Stoppable): - name: str - +class MyDevice(DeviceVector["MyDevice"], Stoppable): def stop(self, success: bool = True) -> None: pass +@pytest.fixture +def family_device() -> MyDevice: + child = MyDevice({1: MyDevice({})}) + return MyDevice({1: child, 2: MyDevice({})}, "root") + + @patch("blueapi.service.interface.context") -def test_get_devices(context_mock: MagicMock): +def test_get_devices(context_mock: MagicMock, family_device: MyDevice): context = BlueskyContext() - context.register_device(MyDevice(name="my_device")) + context.register_device(family_device) context.register_device(SynAxis(name="my_axis")) context_mock.return_value = context - assert interface.get_devices() == [ - DeviceModel(name="my_device", protocols=[ProtocolInfo(name="Stoppable")]), + assert interface.get_devices(max_depth=0) == [ + DeviceModel( + name="root", protocols=[ProtocolInfo(name="Stoppable")], address="root" + ), DeviceModel( name="my_axis", protocols=[ @@ -154,24 +159,64 @@ def test_get_devices(context_mock: MagicMock): ProtocolInfo(name="Configurable"), ProtocolInfo(name="Triggerable"), ], + address="my_axis", ), ] @patch("blueapi.service.interface.context") -def test_get_device(context_mock: MagicMock): +def test_get_device(context_mock: MagicMock, family_device: MyDevice): context = BlueskyContext() - context.register_device(MyDevice(name="my_device")) + context.register_device(family_device) context_mock.return_value = context - assert interface.get_device("my_device") == DeviceModel( - name="my_device", protocols=[ProtocolInfo(name="Stoppable")] + assert interface.get_device("root") == DeviceModel( + name="root", protocols=[ProtocolInfo(name="Stoppable")], address="root" ) with pytest.raises(KeyError): assert interface.get_device("non_existing_device") +@patch("blueapi.service.interface.context") +def test_get_devices_depth(context_mock: MagicMock, family_device: MyDevice): + context = BlueskyContext() + context.register_device(family_device) + context_mock.return_value = context + + root_device = [ + DeviceModel( + name="root", protocols=[ProtocolInfo(name="Stoppable")], address="root" + ) + ] + + root_and_branches = root_device + [ + DeviceModel( + name="root-1", protocols=[ProtocolInfo(name="Stoppable")], address="root.1" + ), + DeviceModel( + name="root-2", protocols=[ProtocolInfo(name="Stoppable")], address="root.2" + ), + ] + + all_devices = root_and_branches + [ + DeviceModel( + name="root-1-1", + protocols=[ProtocolInfo(name="Stoppable")], + address="root.1.1", + ) + ] + + assert interface.get_devices(max_depth=0) == root_device + assert interface.get_devices(max_depth=-2) == root_device + + assert interface.get_devices(max_depth=1) == root_and_branches + + assert interface.get_devices(max_depth=2) == all_devices + assert interface.get_devices(max_depth=3) == all_devices + assert interface.get_devices(max_depth=-1) == all_devices + + @patch("blueapi.service.interface.context") def test_submit_task(context_mock: MagicMock): context = BlueskyContext() diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 50c04036f..67004e844 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -201,6 +201,7 @@ def test_get_devices(mock_runner: Mock, client: TestClient) -> None: "devices": [ { "name": "my-device", + "address": "my-device", "protocols": [{"name": "Stoppable", "types": []}], } ] @@ -217,6 +218,7 @@ def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK assert response.json() == { "name": "my-device", + "address": "my-device", "protocols": [{"name": "Stoppable", "types": []}], } diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 00ffe09f1..2f1058c1b 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -11,6 +11,7 @@ import pytest import responses +import responses.matchers import yaml from bluesky.protocols import Movable from bluesky_stomp.messaging import StompClient @@ -164,8 +165,7 @@ def test_get_plans(runner: CliRunner): def test_get_devices(runner: CliRunner): device = MyDevice(name="my-device") - response = responses.add( - responses.GET, + response = responses.get( "http://localhost:8000/devices", json=DeviceResponse(devices=[DeviceModel.from_device(device)]).model_dump(), status=200, @@ -444,7 +444,8 @@ def test_device_output_formatting(): "ComplexType" ] } - ] + ], + "address": "my-device" } ] """)