Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch state updates for reporting power plugs #389

Draft
wants to merge 6 commits into
base: dev
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 90 additions & 2 deletions zha/application/platforms/sensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from asyncio import Task
import asyncio
from dataclasses import dataclass
from datetime import UTC, date, datetime
import enum
Expand Down Expand Up @@ -329,7 +329,7 @@ def __init__(
) -> None:
"""Init this sensor."""
super().__init__(unique_id, cluster_handlers, endpoint, device, **kwargs)
self._polling_task: Task | None = None
self._polling_task: asyncio.Task | None = None

def on_add(self) -> None:
"""Run when entity is added."""
Expand Down Expand Up @@ -604,6 +604,10 @@ class ElectricalMeasurement(PollableSensor):
_multiplier_attribute_name: str | None = "ac_power_multiplier"
_attr_max_attribute_name: str = None

# The final state is computed from up to three attributes, wait for them all to come
# in before emitting a change
_aggregate_attribute_reports_timeout: float = 2.0

def __init__(
self,
unique_id: str,
Expand All @@ -619,6 +623,90 @@ def __init__(
self._max_attribute_name,
}

self._pending_state_update_attributes: set[str] = set()
self._pending_state_update_timer: asyncio.TimerHandle | None = None

@property
def _all_state_update_attributes(self) -> set[str]:
"""Return a set of attributes that are required to compute state."""
return {
attr_name
for attr_name in (
(
self._attribute_name,
self._divisor_attribute_name,
self._multiplier_attribute_name,
)
+ tuple(self._attr_extra_state_attribute_names)
)
if (
attr_name is not None
and attr_name
not in self._cluster_handler.cluster.unsupported_attributes
)
} - {"measurement_type"}

async def on_remove(self) -> None:
"""Run when entity is removed."""
if self._pending_state_update_timer is not None:
self._pending_state_update_timer.cancel()
self._pending_state_update_timer = None

await super().on_remove()

def handle_cluster_handler_attribute_updated(
self,
event: ClusterAttributeUpdatedEvent,
) -> None:
"""Handle attribute updates from the cluster handler."""
state_update_attrs = self._all_state_update_attributes

if len(state_update_attrs) == 1 or not (
event.attribute_name == self._attribute_name
or event.attribute_name in self._attr_extra_state_attribute_names
):
super().handle_cluster_handler_attribute_updated(event)
return

# We need to wait for all of the relevant attributes to be received before we
# can emit a state change event
if not self._pending_state_update_attributes:
self._pending_state_update_attributes = state_update_attrs

loop = asyncio.get_running_loop()
self._pending_state_update_timer = loop.call_later(
self._aggregate_attribute_reports_timeout,
self._emit_state_change_after_attributes_received,
)

# If we have no attributes to wait for *or* we receive a new attribute report
# for an existing attribute during a timeout window, we need to emit immediately
if (
not self._pending_state_update_attributes
or event.attribute_name not in self._pending_state_update_attributes
):
self._emit_state_change_after_attributes_received()
else:
self._pending_state_update_attributes.discard(event.attribute_name)
_LOGGER.debug(
"Waiting for attributes to be reported before changing state: %s",
self._pending_state_update_attributes,
)

def _emit_state_change_after_attributes_received(self) -> None:
"""Emit a state change after all attributes have been received."""
self._pending_state_update_attributes.clear()

if self._pending_state_update_timer is not None:
self._pending_state_update_timer.cancel()
self._pending_state_update_timer = None

_LOGGER.debug(
"Emitting state changed event, pending attributes: %s",
self._pending_state_update_attributes,
)
self.maybe_emit_state_changed_event()

@property
def _max_attribute_name(self) -> str:
"""Return the max attribute name."""
Expand Down
Loading