diff --git a/custom_components/hilo/__init__.py b/custom_components/hilo/__init__.py index 0d4e4dc3..d5f8ebdf 100644 --- a/custom_components/hilo/__init__.py +++ b/custom_components/hilo/__init__.py @@ -6,7 +6,7 @@ from collections import OrderedDict from datetime import datetime, timedelta import traceback -from typing import TYPE_CHECKING, List, Optional +from typing import List, Optional from aiohttp import CookieJar, client_exceptions from homeassistant.components.select import ( @@ -33,21 +33,16 @@ ) from homeassistant.helpers.aiohttp_client import async_create_clientsession from homeassistant.helpers.dispatcher import async_dispatcher_send -from homeassistant.helpers.event import async_call_later from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from pyhilo import API from pyhilo.device import HiloDevice from pyhilo.devices import Devices from pyhilo.event import Event -from pyhilo.exceptions import ( - CannotConnectError, - HiloError, - InvalidCredentialsError, - WebsocketError, -) +from pyhilo.exceptions import HiloError from pyhilo.graphql import GraphQlHelper +from pyhilo.signalr import SignalREvent, signalr_event_from_payload from pyhilo.util import from_utc_timestamp, time_diff -from pyhilo.websocket import WebsocketEvent, websocket_event_from_payload +from pysignalr.exceptions import ServerError as SignalRServerError from .config_flow import STEP_OPTION_SCHEMA, HiloFlowHandler from .const import ( @@ -76,7 +71,7 @@ ) from .oauth2 import AuthCodeWithPKCEImplementation -DISPATCHER_TOPIC_WEBSOCKET_EVENT = "pyhilo_websocket_event" +DISPATCHER_TOPIC_SIGNALR_EVENT = "pyhilo_signalr_event" SIGNAL_UPDATE_ENTITY = "pyhilo_device_update_{}" COORDINATOR_AWARE_PLATFORMS = [Platform.SENSOR] PLATFORMS = COORDINATOR_AWARE_PLATFORMS + [ @@ -187,9 +182,9 @@ async def handle_debug_event(event: Event): LOG.debug("HILO_DEBUG: Event received: %s", event) log_traces = current_options.get(CONF_LOG_TRACES) LOG.debug("HILO_DEBUG: log_traces is %s", log_traces) - websocket_event = websocket_event_from_payload(event.data) - LOG.debug("HILO_DEBUG: Websocket event parsed: %s", websocket_event) - await hilo.on_websocket_event(websocket_event) + signalr_event = signalr_event_from_payload(event.data) + LOG.debug("HILO_DEBUG: SignalR event parsed: %s", signalr_event) + await hilo.on_signalr_event(signalr_event) log_traces = current_options.get(CONF_LOG_TRACES) if log_traces: @@ -221,9 +216,9 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hilo = hass.data[DOMAIN][entry.entry_id] - hilo.should_websocket_reconnect = False + hilo.should_signalr_reconnect = False - for task in list(hilo._websocket_reconnect_tasks): + for task in list(hilo._signalr_reconnect_tasks): if not task.done(): task.cancel() try: @@ -232,12 +227,13 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: pass try: - if hasattr(hilo, "_devicehub_ws") and hilo._devicehub_ws: - await hilo._devicehub_ws.async_disconnect() - if hasattr(hilo, "_challengehub_ws") and hilo._challengehub_ws: - await hilo._challengehub_ws.async_disconnect() + await hilo._api.signalr_devices.disconnect() + except Exception as err: + LOG.error("Error disconnecting device SignalR hub: %s", err) + try: + await hilo._api.signalr_challenges.disconnect() except Exception as err: - LOG.error("Error disconnecting websockets: %s", err) + LOG.error("Error disconnecting challenge SignalR hub: %s", err) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unload_ok: @@ -286,19 +282,10 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry, api: API) -> None: self.devices: Devices = Devices(api) self.graphql_helper: GraphQlHelper = GraphQlHelper(api, self.devices) self.challenge_id = 0 - self._should_websocket_reconnect = True - self._websocket_reconnect_tasks: list[asyncio.Task | None] = [None, None] + self._should_signalr_reconnect = True + self._signalr_reconnect_tasks: list[asyncio.Task | None] = [None, None] self._update_task: list[asyncio.Task | None] = [None, None] self.subscriptions: List[Optional[asyncio.Task]] = [None] - self.invocations = { - "device": { - 0: self.subscribe_to_location, - }, - "challenge": { - 1: self.subscribe_to_challenge, - 2: self.subscribe_to_challengelist, - }, - } self.hq_plan_name = entry.options.get(CONF_HQ_PLAN_NAME, DEFAULT_HQ_PLAN_NAME) self.appreciation = entry.options.get( CONF_APPRECIATION_PHASE, DEFAULT_APPRECIATION_PHASE @@ -322,28 +309,37 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry, api: API) -> None: self._events: dict = {} if self.track_unknown_sources: self._api._get_device_callbacks = [self._get_unknown_source_tracker] - self._websocket_listeners = [] + self._signalr_listeners = [] - def validate_heartbeat(self, event: WebsocketEvent) -> None: - """Validate heartbeat messages from the websocket.""" + async def _on_devices_connected(self) -> None: + """Trigger device subscriptions after the device hub connects.""" + await self.subscribe_to_location() + + async def _on_challenges_connected(self) -> None: + """Trigger challenge subscriptions after the challenge hub connects.""" + await self.subscribe_to_challenge() + await self.subscribe_to_challengelist() + + def validate_heartbeat(self, event: SignalREvent) -> None: + """Validate heartbeat messages from SignalR.""" heartbeat_time = from_utc_timestamp(event.arguments[0]) # type: ignore if self._api.log_traces: LOG.debug("Heartbeat: %s", time_diff(heartbeat_time, event.timestamp)) - def register_websocket_listener(self, listener): - """Register a listener for websocket events.""" - LOG.debug("Registering websocket listener: %s", listener.__class__.__name__) - self._websocket_listeners.append(listener) + def register_signalr_listener(self, listener): + """Register a listener for SignalR events.""" + LOG.debug("Registering SignalR listener: %s", listener.__class__.__name__) + self._signalr_listeners.append(listener) - async def _handle_websocket_message(self, event): - """Process websocket messages and notify listeners.""" + async def _handle_signalr_message(self, event): + """Process SignalR messages and notify listeners.""" # TODO: ic-dev21: This needs to be cleaned up and optimized - LOG.debug("Received websocket message type: %s", event) + LOG.debug("Received SignalR message type: %s", event) target = event.target - LOG.debug("handle_websocket_message_target %s", target) + LOG.debug("handle_signalr_message_target %s", target) msg_data = event - LOG.debug("handle_websocket_message_ msg_data %s", msg_data) + LOG.debug("handle_signalr_message_ msg_data %s", msg_data) if target in [ "ChallengeListInitialValuesReceived", @@ -372,13 +368,13 @@ async def _handle_websocket_message(self, event): return # ic-dev21 Notify listeners - for listener in self._websocket_listeners: + for listener in self._signalr_listeners: handler_name = f"handle_{msg_type}" if hasattr(listener, handler_name): handler = getattr(listener, handler_name) try: - # ic-dev21 Extract the arguments from the WebsocketEvent object - if isinstance(msg_data, WebsocketEvent): + # ic-dev21 Extract the arguments from the SignalREvent object + if isinstance(msg_data, SignalREvent): arguments = msg_data.arguments if arguments: # ic-dev21 check if there are arguments await handler(arguments[0]) @@ -387,16 +383,14 @@ async def _handle_websocket_message(self, event): f"SHOULD NOT HAPPEN: Received empty arguments for {msg_type}" ) else: - LOG.warning( - f"SHOULD NOT HAPPEN: Not WebsocketEvent: {msg_data}" - ) + LOG.warning(f"SHOULD NOT HAPPEN: Not SignalREvent: {msg_data}") await handler(msg_data) except Exception as e: - LOG.error("Error in websocket handler %s: %s", handler_name, e) + LOG.error("Error in SignalR handler %s: %s", handler_name, e) LOG.error(traceback.format_exc()) - async def _handle_challenge_events(self, event: WebsocketEvent) -> None: - """Handle all challenge-related websocket events.""" + async def _handle_challenge_events(self, event: SignalREvent) -> None: + """Handle all challenge-related SignalR events.""" if event.target == "ChallengeDetailsInitialValuesReceived": challenge = event.arguments[0] LOG.debug( @@ -415,7 +409,7 @@ async def _handle_challenge_events(self, event: WebsocketEvent) -> None: LOG.debug("ChallengeAdded") challenge = event.arguments[0] self.challenge_id = challenge.get("id") - await self.subscribe_to_challenge(1, self.challenge_id) + await self.subscribe_to_challenge(self.challenge_id) elif event.target == "ChallengeListInitialValuesReceived": LOG.debug("ChallengeListInitialValuesReceived") @@ -425,7 +419,7 @@ async def _handle_challenge_events(self, event: WebsocketEvent) -> None: challenge_id = challenge.get("id") self.challenge_phase = challenge.get("currentPhase") self.challenge_id = challenge.get("id") - await self.subscribe_to_challenge(1, challenge_id) + await self.subscribe_to_challenge(challenge_id) elif event.target == "EventCHDetailsUpdatedValuesReceived": LOG.debug("EventCHDetailsUpdatedValuesReceived") @@ -435,8 +429,8 @@ async def _handle_challenge_events(self, event: WebsocketEvent) -> None: event_id = data.get("id") LOG.debug("Report for event %s: %s", event_id, report) - async def _handle_device_events(self, event: WebsocketEvent) -> None: - """Handle all device-related websocket events.""" + async def _handle_device_events(self, event: SignalREvent) -> None: + """Handle all device-related SignalR events.""" if event.target == "DevicesValuesReceived": new_devices = any( self.devices.find_device(item["deviceId"]) is None @@ -447,7 +441,7 @@ async def _handle_device_events(self, event: WebsocketEvent) -> None: "Device list appears to be desynchronized, " "waiting for next DeviceListInitialValuesReceived to refresh..." ) - # Device list will refresh on next websocket reconnect/subscribe + # Device list will refresh on next SignalR reconnect/subscribe updated_devices = self.devices.parse_values_received(event.arguments[0]) # NOTE(dvd): If we don't do this, we need to wait until the coordinator @@ -492,42 +486,34 @@ async def _handle_device_events(self, event: WebsocketEvent) -> None: ) @callback - async def on_websocket_event(self, event: WebsocketEvent) -> None: - """Define a callback for receiving a websocket event.""" - async_dispatcher_send(self._hass, DISPATCHER_TOPIC_WEBSOCKET_EVENT, event) - - if event.event_type == "COMPLETE": - # Look up the callback in both device and challenge invocation groups - cb = self.invocations["device"].get(event.invocation) or self.invocations[ - "challenge" - ].get(event.invocation) - if cb: - async_call_later(self._hass, 3, cb(event.invocation)) - - elif event.target == "Heartbeat": + async def on_signalr_event(self, event: SignalREvent) -> None: + """Define a callback for receiving a SignalR event.""" + async_dispatcher_send(self._hass, DISPATCHER_TOPIC_SIGNALR_EVENT, event) + + if event.target == "Heartbeat": self.validate_heartbeat(event) elif "Challenge" in event.target or "Event" in event.target: - LOG.debug("HILO_DEBUG: Handling challenge/event websocket event: %s", event) + LOG.debug("HILO_DEBUG: Handling challenge/event SignalR event: %s", event) await self._handle_challenge_events(event) - await self._handle_websocket_message(event) + await self._handle_signalr_message(event) elif "Device" in event.target or event.target == "GatewayValuesReceived": await self._handle_device_events(event) else: - LOG.warning("Unhandled websocket event: %s", event) + LOG.warning("Unhandled SignalR event: %s", event) @callback - async def subscribe_to_location(self, inv_id: int) -> None: + async def subscribe_to_location(self) -> None: """Send the json payload to receive updates from the location.""" LOG.debug("Subscribing to location %s", self.devices.location_id) - await self._api.websocket_devices.async_invoke( - [self.devices.location_id], "SubscribeToLocation", inv_id + await self._api.signalr_devices.invoke( + "SubscribeToLocation", [self.devices.location_id] ) @callback - async def subscribe_to_challenge(self, inv_id: int, event_id: int = 0) -> None: + async def subscribe_to_challenge(self, event_id: int = 0) -> None: """Send the json payload to receive updates from the challenge.""" LOG.debug("Subscribing to challenge : %s or %s", event_id, self.challenge_id) event_id = event_id or self.challenge_id @@ -542,60 +528,52 @@ async def subscribe_to_challenge(self, inv_id: int, event_id: int = 0) -> None: "Starting legacy connection to ChallengeHub. Your tarif is %s, and will also attempt connection. This can be safely ignored. This will be deprecated", tarif_config, ) - await self._api.websocket_challenges.async_invoke( - [{"locationId": self.devices.location_id, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "SubscribeToChallenge", - inv_id, + [{"locationId": self.devices.location_id, "eventId": event_id}], ) # Subscribe to the correct challenge hub if tarif_config == "rate d": - await self._api.websocket_challenges.async_invoke( - [{"locationHiloId": self._api.urn, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "SubscribeToEventCH", - inv_id, + [{"locationHiloId": self._api.urn, "eventId": event_id}], ) elif tarif_config == "flex d": - await self._api.websocket_challenges.async_invoke( - [{"locationHiloId": self._api.urn, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "SubscribeToEventFlex", - inv_id, + [{"locationHiloId": self._api.urn, "eventId": event_id}], ) else: LOG.warning("Unknown plan name %s, falling back to default", tarif_config) - await self._api.websocket_challenges.async_invoke( - [{"locationId": self.devices.location_id, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "SubscribeToChallenge", - inv_id, + [{"locationId": self.devices.location_id, "eventId": event_id}], ) @callback - async def subscribe_to_challengelist(self, inv_id: int) -> None: + async def subscribe_to_challengelist(self) -> None: """Send the json payload to receive updates from the challenge list.""" - # TODO : Rename challegenge functions to Event, fallback on challenge for now + # TODO : Rename challenge functions to Event, fallback on challenge for now LOG.debug( "Subscribing to challenge list at location %s", self.devices.location_id ) LOG.debug("API URN is %s", self._api.urn) - await self._api.websocket_challenges.async_invoke( - [{"locationId": self.devices.location_id}], + await self._api.signalr_challenges.invoke( "SubscribeToChallengeList", - inv_id, + [{"locationId": self.devices.location_id}], ) LOG.debug("Subscribing to event list at location %s", self.devices.location_id) - await self._api.websocket_challenges.async_invoke( - [{"locationHiloId": self._api.urn}], + await self._api.signalr_challenges.invoke( "SubscribeToEventList", - inv_id, + [{"locationHiloId": self._api.urn}], ) @callback - async def request_challenge_consumption_update( - self, inv_id: int, event_id: int = 0 - ) -> None: + async def request_challenge_consumption_update(self, event_id: int = 0) -> None: """Send the json payload to receive energy consumption updates from the challenge.""" event_id = event_id or self.challenge_id @@ -605,10 +583,9 @@ async def request_challenge_consumption_update( event_id, self.devices.location_id, ) - await self._api.websocket_challenges.async_invoke( - [{"locationId": self.devices.location_id, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "RequestChallengeConsumptionUpdate", - inv_id, + [{"locationId": self.devices.location_id, "eventId": event_id}], ) # Get plan name to request the correct consumption update @@ -619,20 +596,18 @@ async def request_challenge_consumption_update( "Requesting event CH consumption update at location %s", self.devices.location_id, ) - await self._api.websocket_challenges.async_invoke( - [{"locationHiloId": self._api.urn, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "RequestEventCHConsumptionUpdate", - inv_id, + [{"locationHiloId": self._api.urn, "eventId": event_id}], ) elif tarif_config == "flex d": LOG.debug( "Requesting event Flex consumption update at location %s", self.devices.location_id, ) - await self._api.websocket_challenges.async_invoke( - [{"locationHiloId": self._api.urn, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "RequestEventFlexConsumptionUpdate", - inv_id, + [{"locationHiloId": self._api.urn, "eventId": event_id}], ) else: LOG.debug( @@ -640,28 +615,11 @@ async def request_challenge_consumption_update( event_id, self.devices.location_id, ) - await self._api.websocket_challenges.async_invoke( - [{"locationId": self.devices.location_id, "eventId": event_id}], + await self._api.signalr_challenges.invoke( "RequestChallengeConsumptionUpdate", - inv_id, + [{"locationId": self.devices.location_id, "eventId": event_id}], ) - @callback - async def request_status_update(self) -> None: - """Request a status update from the device websocket.""" - await self._api.websocket_devices.send_status() - - for inv_id, inv_cb in self.invocations["device"].items(): - await inv_cb(inv_id) - - @callback - async def request_status_update_challenge(self) -> None: - """Request a status update from the challenge websocket.""" - await self._api.websocket_challenges.send_status() - - for inv_id, inv_cb in self.invocations["challenge"].items(): - await inv_cb(inv_id) - @callback def _get_unknown_source_tracker(self) -> HiloDevice: return { @@ -729,30 +687,22 @@ async def async_init(self, scan_interval: int) -> None: 5. Build device list (websocket devices + gateway from REST) 6. Initialize GraphQL, register custom devices, start coordinator """ - if TYPE_CHECKING: - assert self._api.refresh_token - assert self._api.websocket - # Step 1: Get location IDs (still REST) await self.devices.async_init() - # Step 2: Register websocket callbacks and start connections - # The connect callback triggers subscribe_to_location, which causes - # the server to send DeviceListInitialValuesReceived - self._api.websocket_devices.add_connect_callback(self.request_status_update) - self._api.websocket_devices.add_event_callback(self.on_websocket_event) - self._api.websocket_challenges.add_connect_callback( - self.request_status_update_challenge - ) - self._api.websocket_challenges.add_event_callback(self.on_websocket_event) - self._websocket_reconnect_tasks[0] = asyncio.create_task( - self.start_websocket_loop(self._api.websocket_devices, 0) + # Step 2: Register SignalR callbacks and start connections + self._api.signalr_devices.add_connect_callback(self._on_devices_connected) + self._api.signalr_challenges.add_connect_callback(self._on_challenges_connected) + self._api.signalr_devices.add_event_callback(self.on_signalr_event) + self._api.signalr_challenges.add_event_callback(self.on_signalr_event) + self._signalr_reconnect_tasks[0] = asyncio.create_task( + self.start_signalr_loop(self._api.signalr_devices, 0) ) - self._websocket_reconnect_tasks[1] = asyncio.create_task( - self.start_websocket_loop(self._api.websocket_challenges, 1) + self._signalr_reconnect_tasks[1] = asyncio.create_task( + self.start_signalr_loop(self._api.signalr_challenges, 1) ) - # Step 3: Wait for DeviceListInitialValuesReceived from websocket + # Step 3: Wait for DeviceListInitialValuesReceived from SignalR await self._api.wait_for_device_cache(timeout=30.0) # Step 4: Build device list (websocket devices + gateway REST + callbacks) @@ -781,17 +731,14 @@ async def async_init(self, scan_interval: int) -> None: self._hass, self.entry, self.unknown_tracker_device ) - async def websocket_disconnect_listener(_: Event) -> None: - """Define an event handler to disconnect from the websocket.""" - if TYPE_CHECKING: - assert self._api.websocket_devices - - if self._api.websocket_devices.connected: - await self._api.websocket_devices.async_disconnect() + async def signalr_disconnect_listener(_: Event) -> None: + """Define an event handler to disconnect from the SignalR hubs.""" + await self._api.signalr_devices.disconnect() + await self._api.signalr_challenges.disconnect() self.entry.async_on_unload( self._hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_STOP, websocket_disconnect_listener + EVENT_HOMEASSISTANT_STOP, signalr_disconnect_listener ) ) self.coordinator = DataUpdateCoordinator( @@ -802,47 +749,49 @@ async def websocket_disconnect_listener(_: Event) -> None: update_method=self.async_update, ) - async def start_websocket_loop(self, websocket, id) -> None: - """Start a websocket reconnection loop.""" - if TYPE_CHECKING: - assert websocket - - try: - await websocket.async_connect() - await websocket.async_listen() - except asyncio.CancelledError: - LOG.debug("Request to cancel websocket loop received") - raise - except CannotConnectError as err: - if "Session is closed" in str(err): + async def start_signalr_loop(self, hub, id) -> None: + """Start a SignalR reconnection loop that retries forever until HA stops.""" + backoff = 5 # seconds; doubles on each error, reset to 5 after a clean run + while self.should_signalr_reconnect: + try: + LOG.info("SignalRHub[%s]: connecting", id) + await hub.run() + # hub.run() returned without raising — server closed the connection. + # That's a normal disconnect; reset backoff and reconnect quickly. + LOG.warning( + "SignalRHub[%s]: connection closed by server; reconnecting in %ss", + id, + backoff, + ) + backoff = 5 + except asyncio.CancelledError: + LOG.debug("SignalRHub[%s]: loop cancelled — stopping", id) + return + except SignalRServerError as err: LOG.warning( - "Session is closed, Home Assistant is probably shutting down" + "SignalRHub[%s]: server-initiated close; reconnecting in %ss — %s", + id, + backoff, + err, ) - self.should_websocket_reconnect = False + except Exception as err: # pylint: disable=broad-except + LOG.warning( + "SignalRHub[%s]: connection error; reconnecting in %ss — %s", + id, + backoff, + err, + ) + + if not self.should_signalr_reconnect: return - except WebsocketError as err: - LOG.error("Failed to connect to websocket: %s", err, exc_info=err) - await self.cancel_websocket_loop(websocket, id) - except InvalidCredentialsError: - LOG.warning("Invalid credentials? Refreshing websocket infos") - await self.cancel_websocket_loop(websocket, id) - try: - await self._api.refresh_ws_token() - except Exception as err: - LOG.error("Exception while refreshing the token: %s", err, exc_info=err) - except Exception as err: # pylint: disable=broad-except - LOG.error( - "Unknown exception while connecting to websocket: %s", err, exc_info=err - ) - await self.cancel_websocket_loop(websocket, id) + try: + await asyncio.sleep(backoff) + except asyncio.CancelledError: + LOG.debug("SignalRHub[%s]: sleep cancelled — stopping", id) + return - if self.should_websocket_reconnect: - LOG.info("Disconnected from websocket; reconnecting in 5 seconds.") - await asyncio.sleep(5) - self._websocket_reconnect_tasks[id] = self._hass.async_create_task( - self.start_websocket_loop(websocket, id) - ) + backoff = min(backoff * 2, 300) async def cancel_task(self, task) -> None: """Cancel a task.""" @@ -856,28 +805,18 @@ async def cancel_task(self, task) -> None: task = None return task - async def cancel_websocket_loop(self, websocket, id) -> None: - """Stop any existing websocket reconnection loop.""" - self._websocket_reconnect_tasks[id] = await self.cancel_task( - self._websocket_reconnect_tasks[id] - ) - self._update_task[id] = await self.cancel_task(self._update_task[id]) - if TYPE_CHECKING: - assert websocket - await websocket.async_disconnect() - @property - def should_websocket_reconnect(self) -> bool: - """Determine if a websocket should reconnect when the connection is lost. + def should_signalr_reconnect(self) -> bool: + """Determine if a SignalR hub should reconnect when the connection is lost. - Currently only used to disable websockets in the unit tests. + Currently only used to disable SignalR in the unit tests. """ - return self._should_websocket_reconnect + return self._should_signalr_reconnect - @should_websocket_reconnect.setter - def should_websocket_reconnect(self, value: bool) -> None: - """Set if websocket should reconnect on disconnection.""" - self._should_websocket_reconnect = value + @should_signalr_reconnect.setter + def should_signalr_reconnect(self, value: bool) -> None: + """Set if SignalR hub should reconnect on disconnection.""" + self._should_signalr_reconnect = value async def async_update(self) -> None: """Update tarif periodically.""" diff --git a/custom_components/hilo/entity.py b/custom_components/hilo/entity.py index 3d374f78..86c1a2f0 100644 --- a/custom_components/hilo/entity.py +++ b/custom_components/hilo/entity.py @@ -11,7 +11,7 @@ from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.update_coordinator import CoordinatorEntity from pyhilo.device import HiloDevice -from pyhilo.websocket import WebsocketEvent +from pyhilo.signalr import SignalREvent from . import SIGNAL_UPDATE_ENTITY, Hilo from .const import DOMAIN @@ -75,8 +75,8 @@ def _handle_coordinator_update(self) -> None: self.async_write_ha_state() @callback - def async_update_from_websocket_event(self, event: WebsocketEvent) -> None: - """Update the entity when new data comes from the websocket.""" + def async_update_from_signalr_event(self, event: SignalREvent) -> None: + """Update the entity when new data comes from SignalR.""" raise NotImplementedError() async def async_added_to_hass(self): diff --git a/custom_components/hilo/sensor.py b/custom_components/hilo/sensor.py index f47c2d08..d3e75b08 100755 --- a/custom_components/hilo/sensor.py +++ b/custom_components/hilo/sensor.py @@ -660,7 +660,7 @@ def __init__(self, hilo, device, scan_interval): self._history = [] self._events_to_poll = dict() self.async_update = Throttle(self.scan_interval)(self._async_update) - hilo.register_websocket_listener(self) + hilo.register_signalr_listener(self) # When we update the list of reward history, we can end up making # hundreds of calls to _save_history in a very short amount of time. @@ -833,7 +833,7 @@ async def _async_update(self): self._history = new_history await self._save_history_debouncer.async_call() for eventId in self._events_to_poll: - await self._hilo.subscribe_to_challenge(1, eventId) + await self._hilo.subscribe_to_challenge(eventId) async def _load_history(self) -> list: history: list = [] @@ -907,7 +907,7 @@ def __init__(self, hilo, device, scan_interval): self.async_update = Throttle(timedelta(seconds=MIN_SCAN_INTERVAL))( self._async_update ) - hilo.register_websocket_listener(self) + hilo.register_signalr_listener(self) async def handle_challenge_added(self, event_data): """Handle new challenge event.""" @@ -1003,7 +1003,7 @@ async def handle_challenge_details_update(self, challenge): if baseline_points: baselinewH = baseline_points[-1]["wh"] else: - baselinewH = challenge.get("baselineWh", 0) + baselinewH = 0 allowed_kwh = baselinewH / 1000 if baselinewH > 0 else 0 used_wH = challenge.get("currentWh", 0) @@ -1098,7 +1098,7 @@ async def async_added_to_hass(self): """Handle entity about to be added to hass event.""" await super().async_added_to_hass() - await self._hilo.subscribe_to_challengelist(2) + await self._hilo.subscribe_to_challengelist() async def _async_update(self): """Update fallback, but not needed with websockets.""" @@ -1106,11 +1106,11 @@ async def _async_update(self): event = self._events.get(event_id) if event.should_check_for_allowed_wh(): LOG.debug("ASYNC UPDATE SUB: EVENT: %s", event_id) - await self._hilo.subscribe_to_challenge(1, event_id) - await self._hilo.request_challenge_consumption_update(1, event_id) + await self._hilo.subscribe_to_challenge(event_id) + await self._hilo.request_challenge_consumption_update(event_id) elif self.state == "reduction": LOG.debug("ASYNC UPDATE: EVENT: %s", event_id) - await self._hilo.request_challenge_consumption_update(1, event_id) + await self._hilo.request_challenge_consumption_update(event_id) class DeviceSensor(HiloEntity, SensorEntity): diff --git a/tests/__init__.py b/tests/__init__.py index a099e265..e9cf8960 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -21,10 +21,10 @@ async def setup_with_selected_platforms( patch("custom_components.hilo.PLATFORMS", platforms), patch("custom_components.hilo.API.async_create", return_value=mock_api), patch( - "custom_components.hilo.Hilo.should_websocket_reconnect", + "custom_components.hilo.Hilo.should_signalr_reconnect", new_callable=PropertyMock, - ) as mock_should_websocket_reconnect, + ) as mock_should_signalr_reconnect, ): - mock_should_websocket_reconnect.return_value = False + mock_should_signalr_reconnect.return_value = False assert await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() diff --git a/tests/conftest.py b/tests/conftest.py index bca66012..d85f40a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,13 @@ """Fixtures for testing.""" +import json +from collections.abc import Generator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + import pytest from homeassistant.core import HomeAssistant -from pyhilo.websocket import WebsocketClient +from pyhilo.signalr import SignalRHub from pytest_homeassistant_custom_component.common import ( MockConfigEntry, load_fixture, @@ -61,13 +66,15 @@ def mock_onboarding() -> Generator[MagicMock]: def mock_api() -> Generator[MagicMock]: """Return a mocked Hilo API""" with patch("pyhilo.API", autospec=True) as api_mock: - # Mock websocket methods to prevent indefinite blocking - api_mock.websocket_devices = AsyncMock(spec=WebsocketClient) - api_mock.websocket_devices.async_connect = AsyncMock(return_value=None) - api_mock.websocket_devices.async_listen = AsyncMock(return_value=None) - api_mock.websocket_challenges = AsyncMock(spec=WebsocketClient) - api_mock.websocket_challenges.async_connect = AsyncMock(return_value=None) - api_mock.websocket_challenges.async_listen = AsyncMock(return_value=None) + # Mock SignalR hubs to prevent indefinite blocking + api_mock.signalr_devices = AsyncMock(spec=SignalRHub) + api_mock.signalr_devices.run = AsyncMock(return_value=None) + api_mock.signalr_devices.connected = False + api_mock.signalr_devices.disconnect = AsyncMock(return_value=None) + api_mock.signalr_challenges = AsyncMock(spec=SignalRHub) + api_mock.signalr_challenges.run = AsyncMock(return_value=None) + api_mock.signalr_challenges.connected = False + api_mock.signalr_challenges.disconnect = AsyncMock(return_value=None) api_mock.log_traces = True api_mock.get_devices.return_value = json.loads(load_fixture("all_devices.json")) @@ -87,11 +94,11 @@ async def init_integration( with ( patch("custom_components.hilo.API.async_create", return_value=mock_api), patch( - "custom_components.hilo.Hilo.should_websocket_reconnect", + "custom_components.hilo.Hilo.should_signalr_reconnect", new_callable=PropertyMock, - ) as mock_should_websocket_reconnect, + ) as mock_should_signalr_reconnect, ): - mock_should_websocket_reconnect.return_value = False + mock_should_signalr_reconnect.return_value = False await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done()