From 422a0818df45df4e24df23653aeea64aa7a598f4 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:28:30 -0400 Subject: [PATCH 1/9] Make sure `is_supported_in_list` can handle comparisons to self --- zha/application/platforms/button/__init__.py | 2 +- zha/application/platforms/sensor/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/zha/application/platforms/button/__init__.py b/zha/application/platforms/button/__init__.py index ed8293f21..d7ed4d565 100644 --- a/zha/application/platforms/button/__init__.py +++ b/zha/application/platforms/button/__init__.py @@ -124,7 +124,7 @@ class IdentifyButton(Button): def is_supported_in_list(self, entities: list[BaseEntity]) -> bool: """Check if this button is supported given the list of entities.""" cls = type(self) - return not any(type(entity) is cls for entity in entities) + return not any(type(entity) is cls for entity in entities if entity is not self) class WriteAttributeButton(PlatformEntity): diff --git a/zha/application/platforms/sensor/__init__.py b/zha/application/platforms/sensor/__init__.py index fcfa8558e..730314d33 100644 --- a/zha/application/platforms/sensor/__init__.py +++ b/zha/application/platforms/sensor/__init__.py @@ -1641,7 +1641,7 @@ def _is_supported(self) -> bool: def is_supported_in_list(self, entities: list[BaseEntity]) -> bool: """Check if the sensor is supported given the list of entities.""" cls = type(self) - return not any(type(entity) is cls for entity in entities) + return not any(type(entity) is cls for entity in entities if entity is not self) @property def state(self) -> dict: From 7aafd88da299bb732ba32ec12c9ee7d49bd1db34 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:28:57 -0400 Subject: [PATCH 2/9] Create events for entity addition and removal --- zha/application/const.py | 2 ++ zha/zigbee/device.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/zha/application/const.py b/zha/application/const.py index 08c857afb..8ac116fc2 100644 --- a/zha/application/const.py +++ b/zha/application/const.py @@ -193,6 +193,8 @@ def pretty_name(self) -> str: ZHA_CLUSTER_HANDLER_READS_PER_REQ = 5 ZHA_EVENT = "zha_event" ZHA_DEVICE_UPDATED_EVENT = "zha_device_updated_event" +ZHA_DEVICE_ENTITY_ADDED_EVENT = "zha_device_entity_added_event" +ZHA_DEVICE_ENTITY_REMOVED_EVENT = "zha_device_entity_removed_event" ZHA_GW_MSG = "zha_gateway_message" ZHA_GW_MSG_DEVICE_FULL_INIT = "device_fully_initialized" ZHA_GW_MSG_DEVICE_INFO = "device_info" diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index af0202a2d..a3fa20d52 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -61,6 +61,8 @@ UNKNOWN_MODEL, ZHA_CLUSTER_HANDLER_CFG_DONE, ZHA_CLUSTER_HANDLER_MSG, + ZHA_DEVICE_ENTITY_ADDED_EVENT, + ZHA_DEVICE_ENTITY_REMOVED_EVENT, ZHA_DEVICE_UPDATED_EVENT, ZHA_EVENT, ) @@ -166,6 +168,27 @@ class DeviceFirmwareInfoUpdatedEvent: new_firmware_version: str | None +@dataclass(kw_only=True, frozen=True) +class DeviceEntityAddedEvent: + """Event generated when a new entity is added to a device.""" + + event_type: Final[str] = ZHA_DEVICE_ENTITY_ADDED_EVENT + event: Final[str] = ZHA_DEVICE_ENTITY_ADDED_EVENT + + # TODO: allow all entity information to be serialized and include it here + unique_id: str + + +@dataclass(kw_only=True, frozen=True) +class DeviceEntityRemovedEvent: + """Event generated when a new entity is added to a device.""" + + event_type: Final[str] = ZHA_DEVICE_ENTITY_REMOVED_EVENT + event: Final[str] = ZHA_DEVICE_ENTITY_REMOVED_EVENT + + unique_id: str + + @dataclass(kw_only=True, frozen=True) class ClusterHandlerConfigurationComplete: """Event generated when all cluster handlers are configured.""" From 02f16daa65bd4c46ea79b0c94a8b1b6c3f75bfdb Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:31:43 -0400 Subject: [PATCH 3/9] Move entity adding to a new method and add another for removal --- zha/zigbee/device.py | 46 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index a3fa20d52..07fbe0ac9 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -969,6 +969,45 @@ def _discover_new_entities(self) -> None: entity.on_add() self._pending_entities.append(entity) + def _add_entity(self, entity: PlatformEntity) -> None: + """Add an entity to the device.""" + key = (entity.PLATFORM, entity.unique_id) + + if key in self._platform_entities: + raise ValueError( + f"Cannot add entity {entity!r}, unique ID already taken by {self._platform_entities[key]!r}" + ) + + _LOGGER.debug("Discovered new entity %s", entity) + self._platform_entities[key] = entity + # entity.on_add() + self.emit( + DeviceEntityAddedEvent.event_type, + DeviceEntityAddedEvent( + unique_id=entity.unique_id, + ), + ) + + async def _remove_entity( + self, entity: BaseEntity, *, emit_event: bool = True + ) -> None: + """Remove an entity from the device.""" + key = (entity.PLATFORM, entity.unique_id) + + if key not in self._platform_entities: + raise ValueError(f"Cannot remove entity {entity!r}, unique ID not found") + + await entity.on_remove() + del self._platform_entities[key] + + if emit_event: + self.emit( + DeviceEntityRemovedEvent.event_type, + DeviceEntityRemovedEvent( + unique_id=entity.unique_id, + ), + ) + async def async_initialize(self, from_cache: bool = False) -> None: """Initialize cluster handlers.""" self.debug("started initialization") @@ -1004,15 +1043,14 @@ async def async_initialize(self, from_cache: bool = False) -> None: key = (entity.PLATFORM, entity.unique_id) # Ignore entities that already exist - if key in new_entities: + if key in new_entities or key in self._platform_entities: await entity.on_remove() continue new_entities[key] = entity - if new_entities: - _LOGGER.debug("Discovered new entities %r", new_entities) - self._platform_entities.update(new_entities) + for entity in new_entities.values(): + self._add_entity(entity) # At this point we can compute a primary entity self._compute_primary_entity() From 1ec0671e2cd1d593fd1940ebf4533455a1365714 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:35:33 -0400 Subject: [PATCH 4/9] Move pending entity initialization to a new function --- zha/zigbee/device.py | 45 ++++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index 07fbe0ac9..9802c7cbd 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -1008,26 +1008,8 @@ async def _remove_entity( ), ) - async def async_initialize(self, from_cache: bool = False) -> None: - """Initialize cluster handlers.""" - self.debug("started initialization") - - self._discover_new_entities() - - await self._zdo_handler.async_initialize(from_cache) - self._zdo_handler.debug("'async_initialize' stage succeeded") - - # We intentionally do not use `gather` here! This is so that if, for example, - # three `device.async_initialize()`s are spawned, only three concurrent requests - # will ever be in flight at once. Startup concurrency is managed at the device - # level. - for endpoint in self._endpoints.values(): - try: - await endpoint.async_initialize(from_cache) - except Exception: # pylint: disable=broad-exception-caught - self.debug("Failed to initialize endpoint", exc_info=True) - - # Compute the final entities + async def _add_pending_entities(self) -> None: + """Add pending entities to the device.""" new_entities: dict[tuple[Platform, str], PlatformEntity] = {} for entity in self._pending_entities: @@ -1055,6 +1037,29 @@ async def async_initialize(self, from_cache: bool = False) -> None: # At this point we can compute a primary entity self._compute_primary_entity() + async def async_initialize(self, from_cache: bool = False) -> None: + """Initialize cluster handlers.""" + self.debug("started initialization") + + # We discover prospective entities before initialization + self._discover_new_entities() + + await self._zdo_handler.async_initialize(from_cache) + self._zdo_handler.debug("'async_initialize' stage succeeded") + + # We intentionally do not use `gather` here! This is so that if, for example, + # three `device.async_initialize()`s are spawned, only three concurrent requests + # will ever be in flight at once. Startup concurrency is managed at the device + # level. + for endpoint in self._endpoints.values(): + try: + await endpoint.async_initialize(from_cache) + except Exception: # pylint: disable=broad-exception-caught + self.debug("Failed to initialize endpoint", exc_info=True) + + # And add them after + await self._add_pending_entities() + # Sync the device's firmware version with the first platform entity for (platform, _unique_id), entity in self.platform_entities.items(): if platform != Platform.UPDATE: From 2e3729d6ea852f3f9c02daba20b02640291f7f3d Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:37:36 -0400 Subject: [PATCH 5/9] Account for current entities when adding pending ones --- zha/zigbee/device.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index 9802c7cbd..a49774780 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -6,7 +6,7 @@ import asyncio from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator, Sequence import copy import dataclasses from dataclasses import dataclass @@ -1010,6 +1010,7 @@ async def _remove_entity( async def _add_pending_entities(self) -> None: """Add pending entities to the device.""" + all_entities = dict(self._platform_entities) new_entities: dict[tuple[Platform, str], PlatformEntity] = {} for entity in self._pending_entities: @@ -1017,7 +1018,7 @@ async def _add_pending_entities(self) -> None: # Ignore unsupported entities if not entity.is_supported() or not entity.is_supported_in_list( - new_entities.values() + all_entities.values() ): await entity.on_remove() continue @@ -1025,18 +1026,20 @@ async def _add_pending_entities(self) -> None: key = (entity.PLATFORM, entity.unique_id) # Ignore entities that already exist - if key in new_entities or key in self._platform_entities: + if key in all_entities: await entity.on_remove() continue + all_entities[key] = entity new_entities[key] = entity + # Compute a new primary entity + self._compute_primary_entity(all_entities.values()) + + # Finally, add the new entities for entity in new_entities.values(): self._add_entity(entity) - # At this point we can compute a primary entity - self._compute_primary_entity() - async def async_initialize(self, from_cache: bool = False) -> None: """Initialize cluster handlers.""" self.debug("started initialization") @@ -1395,13 +1398,11 @@ def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: args = (self.nwk, self.model) + args _LOGGER.log(level, msg, *args, **kwargs) - def _compute_primary_entity(self) -> None: - """Compute the primary entity for this device.""" + def _compute_primary_entity(self, entities: Sequence[PlatformEntity]) -> None: + """Compute the primary entity from a given set of entities.""" # First, check if any entity is explicitly primary - explicitly_primary = [ - entity for entity in self._platform_entities.values() if entity.primary - ] + explicitly_primary = [entity for entity in entities if entity.primary] if len(explicitly_primary) == 1: self.debug( @@ -1417,7 +1418,7 @@ def _compute_primary_entity(self) -> None: # not explicitly marked as not primary candidates = [ e - for e in self._platform_entities.values() + for e in entities if e.enabled and hasattr(e, "info_object") and e._attr_primary is not False ] candidates.sort(reverse=True, key=lambda e: e.primary_weight) From 25b3a25ad82a8537e91567eb9def910d4c31c96e Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:39:10 -0400 Subject: [PATCH 6/9] Create a `recompute_entities` method --- tests/test_discover.py | 3 +++ zha/zigbee/device.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/tests/test_discover.py b/tests/test_discover.py index ea8e9fd57..1e498f2b0 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -807,6 +807,9 @@ async def test_devices_from_files( await zha_gateway.async_block_till_done(wait_background_tasks=True) assert zha_device is not None + # Ensure entity recomputation is idempotent + await zha_device.recompute_entities() + unique_id_collisions = defaultdict(list) for entity in zha_device.platform_entities.values(): unique_id_collisions[entity.unique_id].append(entity) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index a49774780..5ae25e51e 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -1040,6 +1040,25 @@ async def _add_pending_entities(self) -> None: for entity in new_entities.values(): self._add_entity(entity) + async def recompute_entities(self) -> None: + """Recompute all entities for this device.""" + self.debug("Recomputing entities") + + entities = list(self._platform_entities.values()) + + # Remove all entities that are no longer supported + for entity in entities[:]: + entity.recompute_capabilities() + + if not entity.is_supported() or not entity.is_supported_in_list(entities): + self.debug("Removing unsupported entity %s", entity) + await self._remove_entity(entity) + entities.remove(entity) + + # Discover new entities + self._discover_new_entities() + await self._add_pending_entities() + async def async_initialize(self, from_cache: bool = False) -> None: """Initialize cluster handlers.""" self.debug("started initialization") From bdfe5596e9b354b421cda62ff19c9a1327a3d7e1 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:40:19 -0400 Subject: [PATCH 7/9] Use `_remove_entity` to remove entities --- tests/test_discover.py | 4 ++-- zha/zigbee/device.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_discover.py b/tests/test_discover.py index 1e498f2b0..30a85a7d0 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -844,8 +844,6 @@ async def test_devices_from_files( unique_id_migrations[key] = entity - await zha_device.on_remove() - # XXX: We re-serialize the JSON because integer enum types are converted when # serializing but will not compare properly otherwise loaded_device_data = json.loads( @@ -874,3 +872,5 @@ async def test_devices_from_files( tsn=None, ) ] + + await zha_device.on_remove() diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index 5ae25e51e..0f5541168 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -1111,8 +1111,10 @@ async def on_remove(self) -> None: for callback in self._on_remove_callbacks: callback() - for platform_entity in self._platform_entities.values(): - await platform_entity.on_remove() + for platform_entity in list(self._platform_entities.values()): + # TODO: To avoid unnecessary traffic during shutdown, we don't need to emit + # an event for every entity, just the device + await self._remove_entity(platform_entity, emit_event=False) for entity in self._pending_entities: await entity.on_remove() From 983586a644fdf2ab914dbdb6e6998a06c7ca12ac Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:58:40 -0400 Subject: [PATCH 8/9] Add a unit test --- tests/test_device.py | 61 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/test_device.py b/tests/test_device.py index 3670c1f3c..f67fe2bd5 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -18,6 +18,7 @@ from zigpy.zcl import ClusterType from zigpy.zcl.clusters import general from zigpy.zcl.clusters.general import Ota, PowerConfiguration +from zigpy.zcl.clusters.lighting import Color from zigpy.zcl.foundation import Status, WriteAttributesResponse import zigpy.zdo.types as zdo_t @@ -49,6 +50,8 @@ from zha.exceptions import ZHAException from zha.zigbee.device import ( ClusterBinding, + DeviceEntityAddedEvent, + DeviceEntityRemovedEvent, DeviceFirmwareInfoUpdatedEvent, ZHAEvent, get_device_automation_triggers, @@ -1201,3 +1204,61 @@ async def test_symfonisk_events( ) ) ] + + +async def test_entity_recomputation(zha_gateway: Gateway) -> None: + """Test entity recomputation.""" + zigpy_dev = await zigpy_device_from_json( + zha_gateway.application_controller, + "tests/data/devices/ikea-of-sweden-tradfri-bulb-gu10-ws-400lm.json", + ) + zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) + + event_listener = mock.Mock() + zha_device.on_all_events(event_listener) + + entities1 = set(zha_device.platform_entities.values()) + + # We lose track of the color temperature + zha_device._zigpy_device.endpoints[1].light_color.add_unsupported_attribute( + Color.AttributeDefs.start_up_color_temperature.id + ) + await zha_device.recompute_entities() + + entities2 = set(zha_device.platform_entities.values()) + assert entities2 - entities1 == set() + assert len(entities1 - entities2) == 1 + assert ( + list(entities1 - entities2)[0].unique_id + == "68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + assert event_listener.mock_calls == [ + call( + DeviceEntityRemovedEvent( + unique_id="68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + ) + ] + + event_listener.reset_mock() + + # We add it back + zha_device._zigpy_device.endpoints[1].light_color.remove_unsupported_attribute( + Color.AttributeDefs.start_up_color_temperature.id + ) + await zha_device.recompute_entities() + + entities3 = set(zha_device.platform_entities.values()) + assert ( + list(entities3 - entities2)[0].unique_id + == "68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + assert {e.unique_id for e in entities1} == {e.unique_id for e in entities3} + + assert event_listener.mock_calls == [ + call( + DeviceEntityAddedEvent( + unique_id="68:0a:e2:ff:fe:8f:fa:33-1-768-start_up_color_temperature" + ) + ) + ] From d4118a66b067225baafcc762fc06f44869da7748 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 14 Aug 2025 15:40:37 -0400 Subject: [PATCH 9/9] Ensure pending entities are cleaned up --- zha/zigbee/device.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index 0f5541168..df19e8333 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -979,8 +979,9 @@ def _add_entity(self, entity: PlatformEntity) -> None: ) _LOGGER.debug("Discovered new entity %s", entity) + + # `entity.on_add()` is assumed to have been called already self._platform_entities[key] = entity - # entity.on_add() self.emit( DeviceEntityAddedEvent.event_type, DeviceEntityAddedEvent( @@ -1033,6 +1034,8 @@ async def _add_pending_entities(self) -> None: all_entities[key] = entity new_entities[key] = entity + self._pending_entities.clear() + # Compute a new primary entity self._compute_primary_entity(all_entities.values())