Skip to content

Commit 1a0b8a7

Browse files
authored
Move EZSP send lock from EZSP to individual protocol handlers (#649)
* Log frames that are not ACKed * Move command locking and prioritization into the protocol handler * Rename `delay_time` to `send_time` * Cancel all pending futures when the connection is lost * Increase ACK_TIMEOUTS from 4 to 5 * Increase the EZSP command timeout to 10s * Do not count the ASH send time in the EZSP command timeout * Set the NCP state to `FAILED` when we soft fail * Always handle ACK information, even if the frame is invalid * Remove stale constants from `Gateway` * Guard to make sure we can't send data while the transport is closing * Fix unit tests * Send a NAK frame on any parsing error * Reset the random seed every ASH test invocation * Remove unnecessary `asyncio.get_running_loop()` * Add a few more unit tests for coverage * Null out the transport when we are done with it * Fix typo when setting ncp_state * Fix typo with buffer truncation * Fix unit test to account for retries after NCP failure
1 parent e160be2 commit 1a0b8a7

File tree

7 files changed

+152
-81
lines changed

7 files changed

+152
-81
lines changed

bellows/ash.py

+45-23
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import asyncio
55
import binascii
66
from collections.abc import Coroutine
7+
import contextlib
78
import dataclasses
89
import enum
910
import logging
@@ -62,7 +63,7 @@ class Reserved(enum.IntEnum):
6263
# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before
6364
# going to the FAILED state. The value 0 prevents the NCP from entering the error state
6465
# due to timeouts.
65-
ACK_TIMEOUTS = 4
66+
ACK_TIMEOUTS = 5
6667

6768

6869
def generate_random_sequence(length: int) -> bytes:
@@ -368,14 +369,26 @@ def connection_made(self, transport):
368369
self._ezsp_protocol.connection_made(self)
369370

370371
def connection_lost(self, exc):
372+
self._transport = None
373+
self._cancel_pending_data_frames()
371374
self._ezsp_protocol.connection_lost(exc)
372375

373376
def eof_received(self):
374377
self._ezsp_protocol.eof_received()
375378

379+
def _cancel_pending_data_frames(
380+
self, exc: BaseException = RuntimeError("Connection has been closed")
381+
):
382+
for fut in self._pending_data_frames.values():
383+
if not fut.done():
384+
fut.set_exception(exc)
385+
376386
def close(self):
387+
self._cancel_pending_data_frames()
388+
377389
if self._transport is not None:
378390
self._transport.close()
391+
self._transport = None
379392

380393
@staticmethod
381394
def _stuff_bytes(data: bytes) -> bytes:
@@ -399,7 +412,9 @@ def _unstuff_bytes(data: bytes) -> bytes:
399412
for c in data:
400413
if escaped:
401414
byte = c ^ 0b00100000
402-
assert byte in RESERVED_BYTES
415+
if byte not in RESERVED_BYTES:
416+
raise ParsingError(f"Invalid escaped byte: 0x{byte:02X}")
417+
403418
out.append(byte)
404419
escaped = False
405420
elif c == Reserved.ESCAPE:
@@ -417,7 +432,7 @@ def data_received(self, data: bytes) -> None:
417432
_LOGGER.debug(
418433
"Truncating buffer to %s bytes, it is growing too fast", MAX_BUFFER_SIZE
419434
)
420-
self._buffer = self._buffer[:MAX_BUFFER_SIZE]
435+
self._buffer = self._buffer[-MAX_BUFFER_SIZE:]
421436

422437
while self._buffer:
423438
if self._discarding_until_next_flag:
@@ -447,14 +462,19 @@ def data_received(self, data: bytes) -> None:
447462
if not frame_bytes:
448463
continue
449464

450-
data = self._unstuff_bytes(frame_bytes)
451-
452465
try:
466+
data = self._unstuff_bytes(frame_bytes)
453467
frame = parse_frame(data)
454468
except Exception:
455469
_LOGGER.debug(
456470
"Failed to parse frame %r", frame_bytes, exc_info=True
457471
)
472+
473+
with contextlib.suppress(NcpFailure):
474+
self._write_frame(
475+
NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq),
476+
prefix=(Reserved.CANCEL,),
477+
)
458478
else:
459479
self.frame_received(frame)
460480
elif reserved_byte == Reserved.CANCEL:
@@ -479,7 +499,7 @@ def data_received(self, data: bytes) -> None:
479499
f"Unexpected reserved byte found: 0x{reserved_byte:02X}"
480500
) # pragma: no cover
481501

482-
def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
502+
def _handle_ack(self, frame: DataFrame | AckFrame | NakFrame) -> None:
483503
# Note that ackNum is the number of the next frame the receiver expects and it
484504
# is one greater than the last frame received.
485505
for ack_num_offset in range(-TX_K, 0):
@@ -494,14 +514,19 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
494514
def frame_received(self, frame: AshFrame) -> None:
495515
_LOGGER.debug("Received frame %r", frame)
496516

517+
# If a frame has ACK information (DATA, ACK, or NAK), it should be used even if
518+
# the frame is out of sequence or invalid
497519
if isinstance(frame, DataFrame):
520+
self._handle_ack(frame)
498521
self.data_frame_received(frame)
499-
elif isinstance(frame, RStackFrame):
500-
self.rstack_frame_received(frame)
501522
elif isinstance(frame, AckFrame):
523+
self._handle_ack(frame)
502524
self.ack_frame_received(frame)
503525
elif isinstance(frame, NakFrame):
526+
self._handle_ack(frame)
504527
self.nak_frame_received(frame)
528+
elif isinstance(frame, RStackFrame):
529+
self.rstack_frame_received(frame)
505530
elif isinstance(frame, RstFrame):
506531
self.rst_frame_received(frame)
507532
elif isinstance(frame, ErrorFrame):
@@ -513,7 +538,6 @@ def data_frame_received(self, frame: DataFrame) -> None:
513538
# The Host may not piggyback acknowledgments and should promptly send an ACK
514539
# frame when it receives a DATA frame.
515540
if frame.frm_num == self._rx_seq:
516-
self._handle_ack(frame)
517541
self._rx_seq = (frame.frm_num + 1) % 8
518542
self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))
519543

@@ -536,14 +560,10 @@ def rstack_frame_received(self, frame: RStackFrame) -> None:
536560
self._ezsp_protocol.reset_received(frame.reset_code)
537561

538562
def ack_frame_received(self, frame: AckFrame) -> None:
539-
self._handle_ack(frame)
563+
pass
540564

541565
def nak_frame_received(self, frame: NakFrame) -> None:
542-
err = NotAcked(frame=frame)
543-
544-
for fut in self._pending_data_frames.values():
545-
if not fut.done():
546-
fut.set_exception(err)
566+
self._cancel_pending_data_frames(NotAcked(frame=frame))
547567

548568
def rst_frame_received(self, frame: RstFrame) -> None:
549569
self._ncp_reset_code = None
@@ -558,12 +578,8 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
558578
self._enter_failed_state(self._ncp_reset_code)
559579

560580
def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
561-
exc = NcpFailure(code=reset_code)
562-
563-
for fut in self._pending_data_frames.values():
564-
if not fut.done():
565-
fut.set_exception(exc)
566-
581+
self._ncp_state = NcpState.FAILED
582+
self._cancel_pending_data_frames(NcpFailure(code=reset_code))
567583
self._ezsp_protocol.reset_received(reset_code)
568584

569585
def _write_frame(
@@ -573,6 +589,9 @@ def _write_frame(
573589
prefix: tuple[Reserved] = (),
574590
suffix: tuple[Reserved] = (Reserved.FLAG,),
575591
) -> None:
592+
if self._transport is None or self._transport.is_closing():
593+
raise NcpFailure("Transport is closed, cannot send frame")
594+
576595
if _LOGGER.isEnabledFor(logging.DEBUG):
577596
prefix_str = "".join([f"{r.name} + " for r in prefix])
578597
suffix_str = "".join([f" + {r.name}" for r in suffix])
@@ -631,7 +650,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
631650
await ack_future
632651
except NotAcked:
633652
_LOGGER.debug(
634-
"NCP responded with NAK. Retrying (attempt %d)", attempt + 1
653+
"NCP responded with NAK to %r. Retrying (attempt %d)",
654+
frame,
655+
attempt + 1,
635656
)
636657

637658
# For timing purposes, NAK can be treated as an ACK
@@ -650,9 +671,10 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
650671
raise
651672
except asyncio.TimeoutError:
652673
_LOGGER.debug(
653-
"No ACK received in %0.2fs (attempt %d)",
674+
"No ACK received in %0.2fs (attempt %d) for %r",
654675
self._t_rx_ack,
655676
attempt + 1,
677+
frame,
656678
)
657679
# If a DATA frame acknowledgement is not received within the
658680
# current timeout value, then t_rx_ack is doubled.

bellows/ezsp/__init__.py

+1-22
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from typing import Any, Callable, Generator
1313
import urllib.parse
1414

15-
from zigpy.datastructures import PriorityDynamicBoundedSemaphore
16-
1715
if sys.version_info[:2] < (3, 11):
1816
from async_timeout import timeout as asyncio_timeout # pragma: no cover
1917
else:
@@ -41,8 +39,6 @@
4139
NETWORK_OPS_TIMEOUT = 10
4240
NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1
4341

44-
MAX_COMMAND_CONCURRENCY = 1
45-
4642

4743
class EZSP:
4844
_BY_VERSION = {
@@ -66,7 +62,6 @@ def __init__(self, device_config: dict):
6662
self._ezsp_version = v4.EZSPv4.VERSION
6763
self._gw = None
6864
self._protocol = None
69-
self._send_sem = PriorityDynamicBoundedSemaphore(value=MAX_COMMAND_CONCURRENCY)
7065

7166
self._stack_status_listeners: collections.defaultdict[
7267
t.sl_Status, list[asyncio.Future]
@@ -190,21 +185,6 @@ def close(self):
190185
self._gw.close()
191186
self._gw = None
192187

193-
def _get_command_priority(self, name: str) -> int:
194-
return {
195-
# Deprioritize any commands that send packets
196-
"set_source_route": -1,
197-
"setExtendedTimeout": -1,
198-
"send_unicast": -1,
199-
"send_multicast": -1,
200-
"send_broadcast": -1,
201-
# Prioritize watchdog commands
202-
"nop": 999,
203-
"readCounters": 999,
204-
"readAndClearCounters": 999,
205-
"getValue": 999,
206-
}.get(name, 0)
207-
208188
async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
209189
command = getattr(self._protocol, name)
210190

@@ -217,8 +197,7 @@ async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
217197
)
218198
raise EzspError("EZSP is not running")
219199

220-
async with self._send_sem(priority=self._get_command_priority(name)):
221-
return await command(*args, **kwargs)
200+
return await command(*args, **kwargs)
222201

223202
async def _list_command(
224203
self, name, item_frames, completion_frame, spos, *args, **kwargs

bellows/ezsp/protocol.py

+60-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import functools
77
import logging
88
import sys
9+
import time
910
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable
1011

1112
import zigpy.state
@@ -15,6 +16,8 @@
1516
else:
1617
from asyncio import timeout as asyncio_timeout # pragma: no cover
1718

19+
from zigpy.datastructures import PriorityDynamicBoundedSemaphore
20+
1821
from bellows.config import CONF_EZSP_POLICIES
1922
from bellows.exception import InvalidCommandError
2023
import bellows.types as t
@@ -23,7 +26,9 @@
2326
from bellows.uart import Gateway
2427

2528
LOGGER = logging.getLogger(__name__)
26-
EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2
29+
30+
EZSP_CMD_TIMEOUT = 10
31+
MAX_COMMAND_CONCURRENCY = 1
2732

2833

2934
class ProtocolHandler(abc.ABC):
@@ -42,6 +47,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
4247
for name, (cmd_id, tx_schema, rx_schema) in self.COMMANDS.items()
4348
}
4449
self.tc_policy = 0
50+
self._send_semaphore = PriorityDynamicBoundedSemaphore(
51+
value=MAX_COMMAND_CONCURRENCY
52+
)
4553

4654
# Cached by `set_extended_timeout` so subsequent calls are a little faster
4755
self._address_table_size: int | None = None
@@ -65,18 +73,60 @@ def _ezsp_frame_rx(self, data: bytes) -> tuple[int, int, bytes]:
6573
def _ezsp_frame_tx(self, name: str) -> bytes:
6674
"""Serialize the named frame."""
6775

76+
def _get_command_priority(self, name: str) -> int:
77+
return {
78+
# Deprioritize any commands that send packets
79+
"setSourceRoute": -1,
80+
"setExtendedTimeout": -1,
81+
"sendUnicast": -1,
82+
"sendMulticast": -1,
83+
"sendBroadcast": -1,
84+
# Prioritize watchdog commands
85+
"nop": 999,
86+
"readCounters": 999,
87+
"readAndClearCounters": 999,
88+
"getValue": 999,
89+
}.get(name, 0)
90+
6891
async def command(self, name, *args, **kwargs) -> Any:
6992
"""Serialize command and send it."""
70-
LOGGER.debug("Sending command %s: %s %s", name, args, kwargs)
71-
data = self._ezsp_frame(name, *args, **kwargs)
72-
cmd_id, _, rx_schema = self.COMMANDS[name]
73-
future = asyncio.get_running_loop().create_future()
74-
self._awaiting[self._seq] = (cmd_id, rx_schema, future)
75-
self._seq = (self._seq + 1) % 256
76-
77-
async with asyncio_timeout(EZSP_CMD_TIMEOUT):
93+
delayed = False
94+
send_time = None
95+
96+
if self._send_semaphore.locked():
97+
delayed = True
98+
send_time = time.monotonic()
99+
100+
LOGGER.debug(
101+
"Send semaphore is locked, delaying before sending %s(%r, %r)",
102+
name,
103+
args,
104+
kwargs,
105+
)
106+
107+
async with self._send_semaphore(priority=self._get_command_priority(name)):
108+
if delayed:
109+
LOGGER.debug(
110+
"Sending command %s: %s %s after %0.2fs delay",
111+
name,
112+
args,
113+
kwargs,
114+
time.monotonic() - send_time,
115+
)
116+
else:
117+
LOGGER.debug("Sending command %s: %s %s", name, args, kwargs)
118+
119+
data = self._ezsp_frame(name, *args, **kwargs)
120+
cmd_id, _, rx_schema = self.COMMANDS[name]
121+
122+
future = asyncio.get_running_loop().create_future()
123+
self._awaiting[self._seq] = (cmd_id, rx_schema, future)
124+
self._seq = (self._seq + 1) % 256
125+
78126
await self._gw.send_data(data)
79-
return await future
127+
128+
async with asyncio_timeout(EZSP_CMD_TIMEOUT):
129+
return await future
80130

81131
async def update_policies(self, policy_config: dict) -> None:
82132
"""Set up the policies for what the NCP should do."""

bellows/uart.py

-15
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,6 @@
1919

2020

2121
class Gateway(asyncio.Protocol):
22-
FLAG = b"\x7E" # Marks end of frame
23-
ESCAPE = b"\x7D"
24-
XON = b"\x11" # Resume transmission
25-
XOFF = b"\x13" # Stop transmission
26-
SUBSTITUTE = b"\x18"
27-
CANCEL = b"\x1A" # Terminates a frame in progress
28-
STUFF = 0x20
29-
RANDOMIZE_START = 0x42
30-
RANDOMIZE_SEQ = 0xB8
31-
32-
RESERVED = FLAG + ESCAPE + XON + XOFF + SUBSTITUTE + CANCEL
33-
34-
class Terminator:
35-
pass
36-
3722
def __init__(self, application, connected_future=None, connection_done_future=None):
3823
self._application = application
3924

0 commit comments

Comments
 (0)