diff --git a/ucapi/api.py b/ucapi/api.py index 0204db4..6cf76f0 100644 --- a/ucapi/api.py +++ b/ucapi/api.py @@ -78,6 +78,18 @@ class _VoiceSessionContext: handler_task: asyncio.Task | None = None +@dataclass(slots=True) +class _WsContext: + """Websocket context.""" + + incoming: asyncio.Queue[str | bytes | None] + outgoing: asyncio.Queue[str | None] + pending: dict[int, asyncio.Future] + consumer_task: asyncio.Task | None = None + producer_task: asyncio.Task | None = None + router_task: asyncio.Task | None = None + + # pylint: disable=too-many-public-methods, too-many-lines class IntegrationAPI: """Integration API to communicate with Remote Two/3.""" @@ -107,12 +119,18 @@ def __init__(self, loop: AbstractEventLoop | None = None): self._available_entities = Entities("available", self._loop) self._configured_entities = Entities("configured", self._loop) + self._req_id = 1 # Request ID counter for outgoing requests + self._voice_handler: VoiceStreamHandler | None = None self._voice_session_timeout: int = self.DEFAULT_VOICE_SESSION_TIMEOUT_S # Active voice sessions self._voice_sessions: dict[VoiceSessionKey, _VoiceSessionContext] = {} # Enforce: at most one active session per entity_id (across all websockets) self._voice_session_by_entity: dict[str, VoiceSessionKey] = {} + # Websocket context with incoming & outgoing queues and handlers + self._ws_contexts: dict[Any, _WsContext] = {} + # Supported entity types + self._supported_entity_types: list[str] | None = None # Setup event loop asyncio.set_event_loop(self._loop) @@ -125,9 +143,7 @@ def _resolve_config_dir() -> str: def _voice_key(websocket: Any, session_id: int) -> VoiceSessionKey: return websocket, int(session_id) - async def init( - self, driver_path: str, setup_handler: uc.SetupHandler | None = None - ): + async def init(self, driver_path: str, setup_handler: uc.SetupHandler | None = None): """ Load driver configuration and start integration-API WebSocket server. @@ -138,9 +154,7 @@ async def init( self._driver_path = driver_path self._setup_handler = setup_handler - self._configured_entities.add_listener( - uc.Events.ENTITY_ATTRIBUTES_UPDATED, self._on_entity_attributes_updated - ) + self._configured_entities.add_listener(uc.Events.ENTITY_ATTRIBUTES_UPDATED, self._on_entity_attributes_updated) # Load driver config with open(self._driver_path, "r", encoding="utf-8") as file: @@ -155,17 +169,13 @@ async def init( _adjust_driver_url(self._driver_info, port) - disable_mdns_publish = os.getenv( - "UC_DISABLE_MDNS_PUBLISH", "false" - ).lower() in ("true", "1") + disable_mdns_publish = os.getenv("UC_DISABLE_MDNS_PUBLISH", "false").lower() in ("true", "1") if disable_mdns_publish is False: # Setup zeroconf service info name = f'{self._driver_info["driver_id"]}._uc-integration._tcp.local.' hostname = local_hostname() - driver_name = _get_default_language_string( - self._driver_info["name"], "Unknown driver" - ) + driver_name = _get_default_language_string(self._driver_info["name"], "Unknown driver") _LOG.debug("Publishing driver: name=%s, host=%s:%d", name, hostname, port) @@ -185,9 +195,7 @@ async def init( await zeroconf.async_register_service(info) host = interface if interface is not None else "0.0.0.0" - self._server_task = self._loop.create_task( - self._start_web_socket_server(host, port) - ) + self._server_task = self._loop.create_task(self._start_web_socket_server(host, port)) _LOG.info( "Driver is up: %s, version: %s, api: %s, listening on: %s:%d", @@ -205,49 +213,62 @@ async def _on_entity_attributes_updated(self, entity_id, entity_type, attributes "attributes": attributes, } - await self._broadcast_ws_event( - uc.WsMsgEvents.ENTITY_CHANGE, data, uc.EventCategory.ENTITY - ) + await self._broadcast_ws_event(uc.WsMsgEvents.ENTITY_CHANGE, data, uc.EventCategory.ENTITY) async def _start_web_socket_server(self, host: str, port: int) -> None: async with serve(self._handle_ws, host, port): await asyncio.Future() async def _handle_ws(self, websocket) -> None: + # Initialize incoming and outgoing queues + incoming: asyncio.Queue[str | bytes | None] = asyncio.Queue(maxsize=100) + outgoing: asyncio.Queue[str | None] = asyncio.Queue(maxsize=100) + + ctx = _WsContext( + incoming=incoming, + outgoing=outgoing, + pending={}, + ) + + self._clients.add(websocket) + self._ws_contexts[websocket] = ctx + try: - self._clients.add(websocket) _LOG.info("WS: Client added: %s", websocket.remote_address) + ctx.consumer_task = self._loop.create_task(self._ws_consumer(websocket, ctx)) + ctx.producer_task = self._loop.create_task(self._ws_producer(websocket, ctx)) + ctx.router_task = self._loop.create_task(self._ws_router(websocket, ctx)) + # authenticate on connection await self._authenticate(websocket, True) - self._events.emit(uc.Events.CLIENT_CONNECTED, websocket=websocket) + tasks = [t for t in [ctx.consumer_task, ctx.producer_task, ctx.router_task] if t is not None] + done, pending = await asyncio.wait( + tasks, + return_when=asyncio.FIRST_COMPLETED, + ) - async for message in websocket: - # Distinguish between text (str) and binary (bytes-like) messages - if isinstance(message, str): - # JSON text message - await self._process_ws_message(websocket, message) - elif isinstance(message, (bytes, bytearray, memoryview)): - # Binary message (protobuf in future) - await self._process_ws_binary_message(websocket, bytes(message)) - else: - _LOG.warning( - "[%s] WS: Unsupported message type %s", - websocket.remote_address, - type(message).__name__, - ) + for task in pending: + task.cancel() + + results = await asyncio.gather(*done, *pending, return_exceptions=True) + for result in results: + if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError): + raise result except ConnectionClosedOK: _LOG.info("[%s] WS: Connection closed", websocket.remote_address) except websockets.exceptions.ConnectionClosedError as e: - # no idea why they made code & reason deprecated... + close = e.rcvd or e.sent + code = getattr(close, "code", None) + reason = getattr(close, "reason", None) _LOG.info( - "[%s] WS: Connection closed with error %d: %s", + "[%s] WS: Connection closed with error %s: %s", websocket.remote_address, - e.code, - e.reason, + code, + reason, ) except websockets.exceptions.WebSocketException as e: @@ -258,26 +279,87 @@ async def _handle_ws(self, websocket) -> None: ) finally: - # Cleanup any active voice sessions associated with this websocket - keys_to_cleanup = [k for k in self._voice_sessions if k[0] is websocket] - for key in keys_to_cleanup: - try: - await self._cleanup_voice_session(key, VoiceEndReason.REMOTE) - except Exception as ex: # pylint: disable=W0718 - _LOG.exception( - "[%s] WS: Error during voice session cleanup for session_id=%s: %s", + await self._cleanup_ws(websocket) + + async def _ws_consumer(self, websocket, ctx: _WsContext) -> None: + """Route incoming message (requests or events from remote or responses to driver).""" + try: + async for raw_message in websocket: + if isinstance(raw_message, str): + try: + data = json.loads(raw_message) + except json.JSONDecodeError: + _LOG.warning( + "[%s] WS: Invalid JSON message: %s", + websocket.remote_address, + raw_message, + ) + continue + + kind = data.get("kind") + + # Handle the response to a previous driver request + if kind == "resp": + self._handle_pending_response(websocket, data) + # Otherwise handle the json request + else: + await ctx.incoming.put(data) + # Handle the binary message + elif isinstance(raw_message, (bytes, bytearray, memoryview)): + await ctx.incoming.put(bytes(raw_message)) + else: + _LOG.warning( + "[%s] WS: Unsupported message type %s", websocket.remote_address, - key[1], - ex, + type(raw_message).__name__, ) + finally: + await ctx.incoming.put(None) + await ctx.outgoing.put(None) - self._clients.remove(websocket) - _LOG.info("[%s] WS: Client removed", websocket.remote_address) - self._events.emit(uc.Events.CLIENT_DISCONNECTED, websocket=websocket) + async def _ws_producer(self, websocket, ctx: _WsContext) -> None: + """Route outgoing messages.""" + try: + while True: + msg = await ctx.outgoing.get() + if msg is None: + break + await websocket.send(msg) + except (ConnectionClosedOK, websockets.exceptions.ConnectionClosedError): + pass + + async def _ws_router(self, websocket, ctx: _WsContext) -> None: + """Route incoming requests.""" + while True: + message = await ctx.incoming.get() + if message is None: + break + if isinstance(message, dict): + await self._process_ws_message(websocket, message) + elif isinstance(message, bytes): + await self._process_ws_binary_message(websocket, message) + else: + _LOG.warning( + "[%s] WS: Unsupported routed message type %s", + websocket.remote_address, + type(message).__name__, + ) - async def _send_ok_result( - self, websocket, req_id: int, msg_data: dict[str, Any] | list | None = None - ) -> None: + def _get_ws_context(self, websocket) -> _WsContext | None: + return self._ws_contexts.get(websocket) + + async def _enqueue_ws_payload(self, websocket, payload: dict[str, Any]) -> None: + ctx = self._get_ws_context(websocket) + if ctx is None or websocket not in self._clients: + _LOG.error("Error sending payload: connection no longer established") + return + + if _LOG.isEnabledFor(logging.DEBUG): + _LOG.debug("[%s] ->: %s", websocket.remote_address, filter_log_msg_data(payload)) + + await ctx.outgoing.put(json.dumps(payload)) + + async def _send_ok_result(self, websocket, req_id: int, msg_data: dict[str, Any] | list | None = None) -> None: """ Send a WebSocket success message with status code OK. @@ -288,9 +370,7 @@ async def _send_ok_result( Raises: websockets.ConnectionClosed: When the connection is closed. """ - await self._send_ws_response( - websocket, req_id, "result", msg_data, uc.StatusCodes.OK - ) + await self._send_ws_response(websocket, req_id, "result", msg_data, uc.StatusCodes.OK) async def _send_error_result( self, @@ -312,7 +392,6 @@ async def _send_error_result( """ await self._send_ws_response(websocket, req_id, "result", msg_data, status_code) - # pylint: disable=R0917 async def _send_ws_response( self, websocket, @@ -340,20 +419,9 @@ async def _send_ws_response( "msg": msg, "msg_data": msg_data if msg_data is not None else {}, } + await self._enqueue_ws_payload(websocket, data) - if websocket in self._clients: - data_dump = json.dumps(data) - if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "[%s] ->: %s", websocket.remote_address, filter_log_msg_data(data) - ) - await websocket.send(data_dump) - else: - _LOG.error("Error sending response: connection no longer established") - - async def _broadcast_ws_event( - self, msg: str, msg_data: dict[str, Any], category: uc.EventCategory - ) -> None: + async def _broadcast_ws_event(self, msg: str, msg_data: dict[str, Any], category: uc.EventCategory) -> None: """ Send the given event-message to all connected WebSocket clients. @@ -365,21 +433,13 @@ async def _broadcast_ws_event( :param category: event category """ data = {"kind": "event", "msg": msg, "msg_data": msg_data, "cat": category} - data_dump = json.dumps(data) - for websocket in self._clients.copy(): - if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "[%s] =>: %s", websocket.remote_address, filter_log_msg_data(data) - ) try: - await websocket.send(data_dump) - except websockets.exceptions.WebSocketException: - pass + await self._enqueue_ws_payload(websocket, data) + except Exception: + _LOG.exception("Failed to enqueue broadcast for %s", websocket.remote_address) - async def _send_ws_event( - self, websocket, msg: str, msg_data: dict[str, Any], category: uc.EventCategory - ) -> None: + async def _send_ws_event(self, websocket, msg: str, msg_data: dict[str, Any], category: uc.EventCategory) -> None: """ Send an event-message to the given WebSocket client. @@ -392,35 +452,117 @@ async def _send_ws_event( websockets.ConnectionClosed: When the connection is closed. """ data = {"kind": "event", "msg": msg, "msg_data": msg_data, "cat": category} - data_dump = json.dumps(data) + await self._enqueue_ws_payload(websocket, data) - if websocket in self._clients: - if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "[%s] ->: %s", websocket.remote_address, filter_log_msg_data(data) - ) - await websocket.send(data_dump) - else: - _LOG.error("Error sending event: connection no longer established") - - async def _process_ws_message(self, websocket, message) -> None: - _LOG.debug("[%s] <-: %s", websocket.remote_address, message) + async def _process_ws_message(self, websocket, data: dict[str, Any]) -> None: + _LOG.debug("[%s] <-: %s", websocket.remote_address, data) - data = json.loads(message) kind = data["kind"] - req_id = data["id"] if "id" in data else None + req_id = data.get("id") msg = data["msg"] - msg_data = data["msg_data"] if "msg_data" in data else None + msg_data = data.get("msg_data") if kind == "req": if req_id is None: _LOG.warning( - "Ignoring request message with missing 'req_id': %s", message + "Ignoring request message with missing 'id': %s", + data, ) - else: - await self._handle_ws_request_msg(websocket, msg, req_id, msg_data) + return + await self._handle_ws_request_msg(websocket, msg, req_id, msg_data) elif kind == "event": await self._handle_ws_event_msg(websocket, msg, msg_data) + else: + _LOG.warning( + "[%s] WS: Unsupported routed message kind %s", + websocket.remote_address, + kind, + ) + + def _handle_pending_response(self, websocket, data: dict[str, Any]) -> None: + """Resolve the response message that corresponds to a pending request from the driver.""" + + resp_id = data.get("req_id", data.get("id")) + if resp_id is None: + _LOG.warning( + "[%s] WS: Received resp without req_id/id: %s", + websocket.remote_address, + data, + ) + return + + ctx = self._get_ws_context(websocket) + if ctx is None: + _LOG.debug("[%s] WS: No context for resp", websocket.remote_address) + return + + fut = ctx.pending.get(int(resp_id)) + if fut is None: + _LOG.debug( + "[%s] WS: Unmatched resp_id=%s (not pending). msg=%s", + websocket.remote_address, + resp_id, + data.get("msg"), + ) + return + + if not fut.done(): + fut.set_result(data) + + async def _ws_request( + self, + websocket, + msg: str, + msg_data: dict[str, Any] | None = None, + *, + timeout: float = 10.0, + ) -> dict[str, Any]: + """ + Send a request over websocket and await the matching response. + + - Uses a Future stored in self._ws_pending[websocket][req_id] + - Reader task (_handle_ws -> _process_ws_message) completes the future on 'resp' + - Raises TimeoutError on timeout + :param websocket: client connection + :param msg: event message name + :param msg_data: message data payload + :param timeout: timeout for message + """ + # Ensure per-socket structures exist (in case you call before _handle_ws init) + ctx = self._get_ws_context(websocket) + if ctx is None: + raise ConnectionError("WebSocket context not found") + + # Allocate req_id safely + req_id = self._req_id + self._req_id += 1 + + fut = self._loop.create_future() + ctx.pending[req_id] = fut + + try: + payload: dict[str, Any] = {"kind": "req", "id": req_id, "msg": msg} + if msg_data is not None: + payload["msg_data"] = msg_data + + await self._enqueue_ws_payload(websocket, payload) + + # Await response from client until given timeout + resp = await asyncio.wait_for(fut, timeout=timeout) + return resp + + except asyncio.TimeoutError as ex: + _LOG.error( + "[%s] Timeout waiting for response to %s (req_id=%s) %s", + websocket.remote_address, + msg, + req_id, + ex, + ) + raise ex + finally: + # Cleanup pending future entry + ctx.pending.pop(req_id, None) async def _process_ws_binary_message(self, websocket, data: bytes) -> None: """Process a binary WebSocket message using protobuf IntegrationMessage. @@ -430,9 +572,7 @@ async def _process_ws_binary_message(self, websocket, data: bytes) -> None: - Logs errors on deserialization failures and unknown message kinds. """ if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "[%s] <-: ", websocket.remote_address, len(data) - ) + _LOG.debug("[%s] <-: ", websocket.remote_address, len(data)) # Parse IntegrationMessage from bytes try: @@ -462,6 +602,30 @@ async def _process_ws_binary_message(self, websocket, data: bytes) -> None: kind, ) + async def _cleanup_ws(self, websocket) -> None: + ctx = self._ws_contexts.pop(websocket, None) + + keys_to_cleanup = [k for k in self._voice_sessions if k[0] is websocket] + for key in keys_to_cleanup: + try: + await self._cleanup_voice_session(key, VoiceEndReason.REMOTE) + except Exception as ex: + _LOG.exception( + "[%s] WS: Error during voice session cleanup for session_id=%s: %s", + websocket.remote_address, + key[1], + ex, + ) + + if ctx is not None: + for fut in ctx.pending.values(): + if not fut.done(): + fut.set_exception(ConnectionError("WebSocket disconnected")) + + self._clients.discard(websocket) + _LOG.info("[%s] WS: Client removed", websocket.remote_address) + self._events.emit(uc.Events.CLIENT_DISCONNECTED, websocket=websocket) + async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None: """Handle a RemoteVoiceBegin protobuf message. @@ -670,9 +834,7 @@ async def _voice_session_timeout_task(self, key: VoiceSessionKey) -> None: # If handler not started yet, start it now (best effort) if ctx.handler_task is None and self._voice_handler is not None: try: - ctx.handler_task = self._loop.create_task( - self._run_voice_handler(ctx.session) - ) + ctx.handler_task = self._loop.create_task(self._run_voice_handler(ctx.session)) except Exception: # pylint: disable=W0718 _LOG.exception( "Failed to start voice handler on timeout for session %s", @@ -684,9 +846,7 @@ async def _voice_session_timeout_task(self, key: VoiceSessionKey) -> None: await self._cleanup_voice_session(key) # pylint: disable=R0912 - async def _handle_ws_request_msg( - self, websocket, msg: str, req_id: int, msg_data: dict[str, Any] | None - ) -> None: + async def _handle_ws_request_msg(self, websocket, msg: str, req_id: int, msg_data: dict[str, Any] | None) -> None: if msg == uc.WsMessages.GET_DRIVER_VERSION: await self._send_ws_response( websocket, @@ -702,13 +862,7 @@ async def _handle_ws_request_msg( {"state": self.device_state}, ) elif msg == uc.WsMessages.GET_AVAILABLE_ENTITIES: - available_entities = self._available_entities.get_all() - await self._send_ws_response( - websocket, - req_id, - uc.WsMsgEvents.AVAILABLE_ENTITIES, - {"available_entities": available_entities}, - ) + await self._get_available_entities(websocket, req_id) elif msg == uc.WsMessages.GET_ENTITY_STATES: entity_states = await self._configured_entities.get_states() await self._send_ws_response( @@ -730,9 +884,7 @@ async def _handle_ws_request_msg( await self._unsubscribe_events(websocket, msg_data) await self._send_ok_result(websocket, req_id) elif msg == uc.WsMessages.GET_DRIVER_METADATA: - await self._send_ws_response( - websocket, req_id, uc.WsMsgEvents.DRIVER_METADATA, self._driver_info - ) + await self._send_ws_response(websocket, req_id, uc.WsMsgEvents.DRIVER_METADATA, self._driver_info) elif msg == uc.WsMessages.SETUP_DRIVER: if not await self._setup_driver(websocket, req_id, msg_data): # sleep for web-configurator quirks... @@ -743,9 +895,7 @@ async def _handle_ws_request_msg( await asyncio.sleep(0.5) await self.driver_setup_error(websocket) - async def _handle_ws_event_msg( - self, websocket: Any, msg: str, msg_data: dict[str, Any] | None - ) -> None: + async def _handle_ws_event_msg(self, websocket: Any, msg: str, msg_data: dict[str, Any] | None) -> None: if msg == uc.WsMsgEvents.CONNECT: self._events.emit(uc.Events.CONNECT, websocket=websocket) elif msg == uc.WsMsgEvents.DISCONNECT: @@ -756,9 +906,7 @@ async def _handle_ws_event_msg( self._events.emit(uc.Events.EXIT_STANDBY, websocket=websocket) elif msg == uc.WsMsgEvents.ABORT_DRIVER_SETUP: if not self._setup_handler: - _LOG.warning( - "Received abort_driver_setup event, but no setup handler provided by the driver!" - ) # noqa + _LOG.warning("Received abort_driver_setup event, but no setup handler provided by the driver!") # noqa return if "error" in msg_data: @@ -768,9 +916,7 @@ async def _handle_ws_event_msg( error = uc.IntegrationSetupError.OTHER await self._setup_handler(uc.AbortDriverSetup(error)) else: - _LOG.warning( - "Unsupported abort_driver_setup payload received: %s", msg_data - ) + _LOG.warning("Unsupported abort_driver_setup payload received: %s", msg_data) async def _authenticate(self, websocket, success: bool) -> None: await self._send_ws_response( @@ -806,9 +952,7 @@ async def set_device_state(self, state: uc.DeviceStates) -> None: uc.EventCategory.DEVICE, ) - async def _subscribe_events( - self, websocket: Any, msg_data: dict[str, Any] | None - ) -> None: + async def _subscribe_events(self, websocket: Any, msg_data: dict[str, Any] | None) -> None: if msg_data is None: _LOG.warning("Ignoring _subscribe_events: called with empty msg_data") return @@ -828,9 +972,7 @@ async def _subscribe_events( websocket=websocket, ) - async def _unsubscribe_events( - self, websocket: Any, msg_data: dict[str, Any] | None - ) -> bool: + async def _unsubscribe_events(self, websocket: Any, msg_data: dict[str, Any] | None) -> bool: if msg_data is None: _LOG.warning("Ignoring _unsubscribe_events: called with empty msg_data") return False @@ -849,23 +991,17 @@ async def _unsubscribe_events( return res - async def _entity_command( - self, websocket, req_id: int, msg_data: dict[str, Any] | None - ) -> None: + async def _entity_command(self, websocket, req_id: int, msg_data: dict[str, Any] | None) -> None: if not msg_data: _LOG.warning("Ignoring entity command: called with empty msg_data") - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return entity_id = msg_data["entity_id"] if "entity_id" in msg_data else None cmd_id = msg_data["cmd_id"] if "cmd_id" in msg_data else None if entity_id is None or cmd_id is None: _LOG.warning("Ignoring command: missing entity_id or cmd_id") - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return entity = self.configured_entities.get(entity_id) @@ -928,28 +1064,20 @@ async def _entity_command( "Old Entity.command signature detected for %s, trying old signature. Please update the command signature.", entity.id, ) - result = await entity.command( - cmd_id, msg_data["params"] if "params" in msg_data else None - ) + result = await entity.command(cmd_id, msg_data["params"] if "params" in msg_data else None) await self.acknowledge_command(websocket, req_id, result) - async def _browse_media( - self, websocket, req_id: int, msg_data: dict[str, Any] | None - ) -> None: + async def _browse_media(self, websocket, req_id: int, msg_data: dict[str, Any] | None) -> None: if not msg_data: _LOG.warning("Ignoring browse_media command: called with empty msg_data") - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return entity_id = msg_data["entity_id"] if "entity_id" in msg_data else None if entity_id is None: _LOG.warning("Ignoring browse_media command: missing entity_id") - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return entity = self.configured_entities.get(entity_id) @@ -965,12 +1093,8 @@ async def _browse_media( try: data = BrowseMediaMsgData(**msg_data) except (TypeError, ValueError): - _LOG.error( - "Cannot browse media for '%s': wrong format %s", entity_id, msg_data - ) - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + _LOG.error("Cannot browse media for '%s': wrong format %s", entity_id, msg_data) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return # call integration driver to handle browse request @@ -978,9 +1102,7 @@ async def _browse_media( result = await entity.browse(data) except Exception: # pylint: disable=W0718 _LOG.exception("Failed to call MediaPlayer.browse for '%s'", entity_id) - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.SERVER_ERROR - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.SERVER_ERROR) return if isinstance(result, BrowseResults): @@ -994,22 +1116,16 @@ async def _browse_media( else: await self.acknowledge_command(websocket, req_id, result) - async def _search_media( - self, websocket, req_id: int, msg_data: dict[str, Any] | None - ) -> None: + async def _search_media(self, websocket, req_id: int, msg_data: dict[str, Any] | None) -> None: if not msg_data: _LOG.warning("Ignoring search_media command: called with empty msg_data") - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return entity_id = msg_data["entity_id"] if "entity_id" in msg_data else None if entity_id is None: _LOG.warning("Ignoring search_media command: missing entity_id") - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return entity = self.configured_entities.get(entity_id) @@ -1024,21 +1140,15 @@ async def _search_media( try: data = SearchMediaMsgData(**msg_data) except (TypeError, ValueError): - _LOG.error( - "Cannot search media for '%s': wrong format %s", entity_id, msg_data - ) - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.BAD_REQUEST - ) + _LOG.error("Cannot search media for '%s': wrong format %s", entity_id, msg_data) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.BAD_REQUEST) return try: result = await entity.search(data) except Exception: # pylint: disable=W0718 _LOG.exception("Failed to call MediaPlayer.search for '%s'", entity_id) - await self.acknowledge_command( - websocket, req_id, uc.StatusCodes.SERVER_ERROR - ) + await self.acknowledge_command(websocket, req_id, uc.StatusCodes.SERVER_ERROR) return if isinstance(result, SearchResults): @@ -1052,9 +1162,7 @@ async def _search_media( else: await self.acknowledge_command(websocket, req_id, result) - async def _setup_driver( - self, websocket, req_id: int, msg_data: dict[str, Any] | None - ) -> bool: + async def _setup_driver(self, websocket, req_id: int, msg_data: dict[str, Any] | None) -> bool: await self.acknowledge_command(websocket, req_id) if msg_data is None or "setup_data" not in msg_data: @@ -1064,24 +1172,18 @@ async def _setup_driver( # make sure integration driver installed a setup handler if not self._setup_handler: - _LOG.error( - "Received setup_driver request, but no setup handler provided by the driver!" - ) # noqa + _LOG.error("Received setup_driver request, but no setup handler provided by the driver!") # noqa return False result = False try: action = await self._setup_handler( - uc.DriverSetupRequest( - msg_data.get("reconfigure") or False, msg_data["setup_data"] - ) + uc.DriverSetupRequest(msg_data.get("reconfigure") or False, msg_data["setup_data"]) ) if isinstance(action, uc.RequestUserInput): await self.driver_setup_progress(websocket) - await self.request_driver_setup_user_input( - websocket, action.title, action.settings - ) + await self.request_driver_setup_user_input(websocket, action.title, action.settings) result = True elif isinstance(action, uc.RequestUserConfirmation): await self.driver_setup_progress(websocket) @@ -1102,15 +1204,11 @@ async def _setup_driver( return result - async def _set_driver_user_data( - self, websocket, req_id: int, msg_data: dict[str, Any] | None - ) -> bool: + async def _set_driver_user_data(self, websocket, req_id: int, msg_data: dict[str, Any] | None) -> bool: await self.acknowledge_command(websocket, req_id) if not self._setup_handler: - _LOG.error( - "Received set_driver_user_data request, but no setup handler provided by the driver!" - ) # noqa + _LOG.error("Received set_driver_user_data request, but no setup handler provided by the driver!") # noqa return False if "input_values" in msg_data or "confirm" in msg_data: @@ -1118,27 +1216,19 @@ async def _set_driver_user_data( await asyncio.sleep(0.5) await self.driver_setup_progress(websocket) else: - _LOG.warning( - "Unsupported set_driver_user_data payload received: %s", msg_data - ) + _LOG.warning("Unsupported set_driver_user_data payload received: %s", msg_data) return False result = False try: action = uc.SetupError() if "input_values" in msg_data: - action = await self._setup_handler( - uc.UserDataResponse(msg_data["input_values"]) - ) + action = await self._setup_handler(uc.UserDataResponse(msg_data["input_values"])) elif "confirm" in msg_data: - action = await self._setup_handler( - uc.UserConfirmationResponse(msg_data["confirm"]) - ) + action = await self._setup_handler(uc.UserConfirmationResponse(msg_data["confirm"])) if isinstance(action, uc.RequestUserInput): - await self.request_driver_setup_user_input( - websocket, action.title, action.settings - ) + await self.request_driver_setup_user_input(websocket, action.title, action.settings) result = True elif isinstance(action, uc.RequestUserConfirmation): await self.request_driver_setup_user_confirmation( @@ -1184,9 +1274,7 @@ async def driver_setup_progress(self, websocket) -> None: """ data = {"event_type": "SETUP", "state": "SETUP"} - await self._send_ws_event( - websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE - ) + await self._send_ws_event(websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE) # pylint: disable=R0917 async def request_driver_setup_user_confirmation( @@ -1222,9 +1310,7 @@ async def request_driver_setup_user_confirmation( }, } - await self._send_ws_event( - websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE - ) + await self._send_ws_event(websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE) async def request_driver_setup_user_input( self, websocket, title: str | dict[str, str], settings: dict[str, Any] | list @@ -1233,30 +1319,22 @@ async def request_driver_setup_user_input( data = { "event_type": "SETUP", "state": "WAIT_USER_ACTION", - "require_user_action": { - "input": {"title": _to_language_object(title), "settings": settings} - }, + "require_user_action": {"input": {"title": _to_language_object(title), "settings": settings}}, } - await self._send_ws_event( - websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE - ) + await self._send_ws_event(websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE) async def driver_setup_complete(self, websocket) -> None: """Send a driver setup complete event to Remote Two/3.""" data = {"event_type": "STOP", "state": "OK"} - await self._send_ws_event( - websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE - ) + await self._send_ws_event(websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE) async def driver_setup_error(self, websocket, error="OTHER") -> None: """Send a driver setup error event to Remote Two/3.""" data = {"event_type": "STOP", "state": "ERROR", "error": error} - await self._send_ws_event( - websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE - ) + await self._send_ws_event(websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE) @staticmethod def _wrap_event_listener(listener: Callable) -> Callable: @@ -1277,9 +1355,7 @@ def _wrap_event_listener(listener: Callable) -> Callable: params = list(sig.parameters.values()) - accepts_varargs = any( - p.kind == inspect.Parameter.VAR_POSITIONAL for p in params - ) + accepts_varargs = any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params) accepts_varkw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params) # How many positional args can the listener accept (excluding *args/**kwargs)? @@ -1293,18 +1369,13 @@ def _wrap_event_listener(listener: Callable) -> Callable: accepted_kw = { p.name for p in params - if p.kind - in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) } @wraps(listener) def wrapper(*args: Any, **kwargs: Any): call_args = args if accepts_varargs else args[:max_positional] - call_kwargs = ( - kwargs - if accepts_varkw - else {k: v for k, v in kwargs.items() if k in accepted_kw} - ) + call_kwargs = kwargs if accepts_varkw else {k: v for k, v in kwargs.items() if k in accepted_kw} return listener(*call_args, **call_kwargs) return wrapper @@ -1351,10 +1422,96 @@ def remove_all_listeners(self, event: uc.Events | None) -> None: """ self._events.remove_all_listeners(event) + async def get_supported_entity_types(self, websocket, *, timeout: float = 5.0) -> list[str]: + """Request supported entity types from client and return msg_data.""" + resp = await self._ws_request( + websocket, + "get_supported_entity_types", + timeout=timeout, + ) + if resp.get("msg") != "supported_entity_types": + _LOG.debug( + "[%s] Unexpected resp msg for get_supported_entity_types: %s", + websocket.remote_address, + resp.get("msg"), + ) + return resp.get("msg_data", []) + + async def get_version(self, websocket, *, timeout: float = 5.0) -> dict[str, Any] | None: + """Request client version and return msg_data.""" + resp = await self._ws_request( + websocket, + "get_version", + timeout=timeout, + ) + if resp.get("msg") != "version": + _LOG.debug( + "[%s] Unexpected resp msg for get_version: %s", + websocket.remote_address, + resp.get("msg"), + ) + + return resp.get("msg_data") + + async def get_localization_cfg(self, websocket, *, timeout: float = 5.0) -> dict[str, Any] | None: + """Request localization config and return msg_data.""" + resp = await self._ws_request( + websocket, + "get_localization_cfg", + timeout=timeout, + ) + + if resp.get("msg") != "localization_cfg": + _LOG.debug( + "[%s] Unexpected resp msg for get_localization_cfg: %s", + websocket.remote_address, + resp.get("msg"), + ) + + return resp.get("msg_data") + + async def _update_supported_entity_types(self, websocket, *, timeout: float = 5.0) -> None: + """Update supported entity types by remote.""" + await asyncio.sleep(0) + try: + self._supported_entity_types = await self.get_supported_entity_types(websocket, timeout=timeout) + _LOG.debug( + "[%s] Supported entity types %s", + websocket.remote_address, + self._supported_entity_types, + ) + except Exception as ex: # pylint: disable=W0718 + _LOG.error( + "[%s] Unable to retrieve entity types %s", + websocket.remote_address, + ex, + ) + + async def _get_available_entities(self, websocket, req_id) -> None: + if self._supported_entity_types is None: + # Request supported entity types from remote + await self._update_supported_entity_types(websocket) + available_entities = self._available_entities.get_all() + if self._supported_entity_types: + available_entities = [ + entity for entity in available_entities if entity.get("entity_type") in self._supported_entity_types + ] + await self._send_ws_response( + websocket, + req_id, + uc.WsMsgEvents.AVAILABLE_ENTITIES, + {"available_entities": available_entities}, + ) + ############## # Properties # ############## + @property + def clients(self) -> set: + """Return all clients.""" + return self._clients.copy() + @property def client_count(self) -> int: """Return number of WebSocket clients.""" @@ -1394,9 +1551,7 @@ def _to_language_object(text: str | dict[str, str] | None) -> dict[str, str] | N return text -def _get_default_language_string( - text: str | dict[str, str] | None, default_text="Undefined" -) -> str: +def _get_default_language_string(text: str | dict[str, str] | None, default_text="Undefined") -> str: if text is None: return default_text @@ -1456,10 +1611,7 @@ def local_hostname() -> str: # local hostname keeps on changing with a increasing number suffix! # https://apple.stackexchange.com/questions/189350/my-macs-hostname-keeps-adding-a-2-to-the-end - return ( - os.getenv("UC_MDNS_LOCAL_HOSTNAME") - or f'{socket.gethostname().split(".", 1)[0]}.local.' - ) + return os.getenv("UC_MDNS_LOCAL_HOSTNAME") or f'{socket.gethostname().split(".", 1)[0]}.local.' def filter_log_msg_data(data: dict[str, Any]) -> dict[str, Any]: @@ -1484,11 +1636,7 @@ def filter_log_msg_data(data: dict[str, Any]) -> dict[str, Any]: if ( "attributes" in log_upd["msg_data"] and MediaAttr.MEDIA_IMAGE_URL in log_upd["msg_data"]["attributes"] - and ( - media_image_url := log_upd["msg_data"]["attributes"][ - MediaAttr.MEDIA_IMAGE_URL - ] - ) + and (media_image_url := log_upd["msg_data"]["attributes"][MediaAttr.MEDIA_IMAGE_URL]) and media_image_url.startswith("data:") ): log_upd["msg_data"]["attributes"][MediaAttr.MEDIA_IMAGE_URL] = "data:***" @@ -1497,9 +1645,7 @@ def filter_log_msg_data(data: dict[str, Any]) -> dict[str, Any]: if ( "attributes" in item and MediaAttr.MEDIA_IMAGE_URL in item["attributes"] - and ( - media_image_url := item["attributes"][MediaAttr.MEDIA_IMAGE_URL] - ) + and (media_image_url := item["attributes"][MediaAttr.MEDIA_IMAGE_URL]) and media_image_url.startswith("data:") ): item["attributes"][MediaAttr.MEDIA_IMAGE_URL] = "data:***"