diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 803841964..b0a280523 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -124,16 +124,29 @@ def get_plan(name: str) -> PlanModel: return PlanModel.from_plan(context().plans[name]) -def get_devices() -> list[DeviceModel]: - """Get all available devices in the BlueskyContext""" - return [DeviceModel.from_device(device) for device in context().devices.values()] - - def get_device(name: str) -> DeviceModel: """Retrieve device by name from the BlueskyContext""" return DeviceModel.from_device(context().devices[name]) +def get_devices(interface_name: str | None = None) -> list[DeviceModel]: + if interface_name is None: + return [ + DeviceModel.from_device(device) for device in context().devices.values() + ] + """Retrieve device by protocol from the BlueskyContext""" + interface_class = globals().get(interface_name) + if interface_class is None: + return [] + devices = context().devices + results: list[DeviceModel] = [] + for device in devices.values(): + if isinstance(device, interface_class): + results.append(DeviceModel.from_device(device)) + + return results + + def submit_task(task: Task) -> str: """Submit a task to be run on begin_task""" return worker().submit_task(task) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 3e779c518..9c8dd2efa 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -1,5 +1,6 @@ from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from typing import Annotated import jwt from fastapi import ( @@ -9,6 +10,7 @@ Depends, FastAPI, HTTPException, + Query, Request, Response, status, @@ -199,9 +201,17 @@ def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)): @router.get("/devices", response_model=DeviceResponse) @start_as_current_span(TRACER) -def get_devices(runner: WorkerDispatcher = Depends(_runner)): +def get_devices( + runner: WorkerDispatcher = Depends(_runner), + expected_type: Annotated[str | None, Query(alias="type")] = Query( + None, description="Filter devices by protocol name" + ), +): """Retrieve information about all available devices.""" - devices = runner.run(interface.get_devices) + if expected_type: + devices = runner.run(lambda: interface.get_devices(expected_type)) + else: + devices = runner.run(interface.get_devices) return DeviceResponse(devices=devices) diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index b4ee2d3b5..4244e743a 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -151,6 +151,35 @@ def test_get_device(context_mock: MagicMock): assert interface.get_device("non_existing_device") +@patch("blueapi.service.interface.context") +def test_get_devices_by_protocol(context_mock: MagicMock): + context = BlueskyContext() + context.register_device(SynAxis(name="my_axis")) + context_mock.return_value = context + + assert interface.get_devices("Pausable") == [ + DeviceModel( + name="my_axis", + protocols=[ + "Checkable", + "HasHints", + "HasName", + "HasParent", + "Movable", + "Pausable", + "Readable", + "Stageable", + "Stoppable", + "Subscribable", + "Configurable", + "Triggerable", + ], + ), + ] + + assert interface.get_devices("non_existing_interface") == [] + + @patch("blueapi.service.interface.context") def test_submit_task(context_mock: MagicMock): context = BlueskyContext() diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 0ccf160e3..ac5900bae 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -7,6 +7,7 @@ import pytest from fastapi import status from fastapi.testclient import TestClient +from ophyd.sim import SynAxis from pydantic import BaseModel, ValidationError from pydantic_core import InitErrorDetails from super_state_machine.errors import TransitionError @@ -160,6 +161,39 @@ class MyDevice: } +def test_get_devices_by_protocol(mock_runner: Mock, client: TestClient) -> None: + sya = SynAxis(name="my_axis") + mock_runner.run.return_value = DeviceModel.from_device(sya) + response = client.get("/devices?protocol_name=Pausable") + + mock_runner.run.assert_called_once_with(test_get_devices_by_protocol, "Pausable") + assert response.status_code == status.HTTP_200_OK + assert response.json() == { + "name": "my_axis", + "protocols": [ + "Checkable", + "HasHints", + "HasName", + "HasParent", + "Movable", + "Pausable", + "Readable", + "Stageable", + "Stoppable", + "Subscribable", + "Configurable", + "Triggerable", + ], + } + + response = client.get("/devices?protocol_name=Non_Existing_Protocol") + mock_runner.run.assert_called_once_with( + test_get_devices_by_protocol, "Non_Existing_Protocol" + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() == {} + + def test_get_non_existent_device_by_name(mock_runner: Mock, client: TestClient) -> None: mock_runner.run.side_effect = KeyError("my-device") response = client.get("/devices/my-device")