Skip to content

Commit ecce1ba

Browse files
authored
Cleanly shut down the serial port on disconnect (#633)
* Cleanly handle connection loss * Guard disconnect * Clean up exception handling and reduce unnecessary resets * Rename `application` to `api` in EZSP UART * Ensure `enter_failed_state` passes through an exception object * Fix unit tests * Bump minimum zigpy version * Fix CLI * Drop accidental import * 100% coverage
1 parent 7e1008e commit ecce1ba

16 files changed

+295
-371
lines changed

bellows/ash.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ def __init__(self, code: t.NcpResetCode) -> None:
130130
def __repr__(self) -> str:
131131
return f"<{self.__class__.__name__}(code={self.code})>"
132132

133+
def __eq__(self, other: object) -> bool | NotImplemented:
134+
if not isinstance(other, NcpFailure):
135+
return NotImplemented
136+
137+
return self.code == other.code
138+
133139

134140
class AshFrame(abc.ABC, BaseDataclassMixin):
135141
MASK: t.uint8_t
@@ -368,7 +374,7 @@ def connection_made(self, transport):
368374
self._transport = transport
369375
self._ezsp_protocol.connection_made(self)
370376

371-
def connection_lost(self, exc):
377+
def connection_lost(self, exc: Exception | None) -> None:
372378
self._transport = None
373379
self._cancel_pending_data_frames()
374380
self._ezsp_protocol.connection_lost(exc)

bellows/cli/dump.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def dump(ctx, channel, outfile):
3737
finally:
3838
if "ezsp" in ctx.obj:
3939
loop.run_until_complete(ctx.obj["ezsp"].mfglibEnd())
40-
ctx.obj["ezsp"].close()
40+
loop.run_until_complete(ctx.obj["ezsp"].disconnect())
4141

4242

4343
def ieee_15_4_fcs(data: bytes) -> bytes:

bellows/cli/ncp.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ async def config(ctx, config, all_):
3030
if v[0] == t.EzspStatus.ERROR_INVALID_ID:
3131
continue
3232
click.echo(f"{config.name}={v[1]}")
33-
s.close()
33+
await s.disconnect()
3434
return
3535

3636
if "=" in config:
@@ -54,7 +54,7 @@ async def config(ctx, config, all_):
5454

5555
v = await s.setConfigurationValue(config, value)
5656
click.echo(v)
57-
s.close()
57+
await s.disconnect()
5858
return
5959

6060
v = await s.getConfigurationValue(config)
@@ -86,7 +86,7 @@ async def info(ctx):
8686
click.echo(f"Board name: {brd_name}")
8787
click.echo(f"EmberZNet version: {version}")
8888

89-
s.close()
89+
await s.disconnect()
9090

9191

9292
@main.command()
@@ -105,7 +105,7 @@ async def bootloader(ctx):
105105
version, plat, micro, phy = await ezsp.getStandaloneBootloaderVersionPlatMicroPhy()
106106
if version == 0xFFFF:
107107
click.echo("No boot loader installed")
108-
ezsp.close()
108+
await ezsp.disconnect()
109109
return
110110

111111
click.echo(
@@ -118,4 +118,4 @@ async def bootloader(ctx):
118118
click.echo(f"Couldn't launch bootloader: {res[0]}")
119119
else:
120120
click.echo("bootloader launched successfully")
121-
ezsp.close()
121+
await ezsp.disconnect()

bellows/cli/network.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def cb(fut, frame_name, response):
106106

107107
s.remove_callback(cbid)
108108

109-
s.close()
109+
await s.disconnect()
110110

111111

112112
@main.command()
@@ -126,7 +126,7 @@ async def leave(ctx):
126126
expected=t.EmberStatus.NETWORK_DOWN,
127127
)
128128

129-
s.close()
129+
await s.disconnect()
130130

131131

132132
@main.command()
@@ -157,4 +157,4 @@ async def scan(ctx, channels, duration_ms, energy_scan):
157157
for network in v:
158158
click.echo(network)
159159

160-
s.close()
160+
await s.disconnect()

bellows/cli/stream.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def stream(ctx, channel, power):
3535
s = ctx.obj["ezsp"]
3636
loop.run_until_complete(s.mfglibStopStream())
3737
loop.run_until_complete(s.mfglibEnd())
38-
s.close()
38+
loop.run_until_complete(s.disconnect())
3939

4040

4141
async def _stream(ctx, channel, power):

bellows/cli/tone.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def tone(ctx, channel, power):
3535
s = ctx.obj["ezsp"]
3636
loop.run_until_complete(s.mfglibStopTone())
3737
loop.run_until_complete(s.mfglibEnd())
38-
s.close()
38+
loop.run_until_complete(s.disconnect())
3939

4040

4141
async def _tone(ctx, channel, power):

bellows/cli/util.py

+7-18
Original file line numberDiff line numberDiff line change
@@ -59,28 +59,17 @@ async def async_inner(ctx, *args, **kwargs):
5959
if extra_config:
6060
app_config.update(extra_config)
6161
application = await setup_application(app_config, startup=app_startup)
62-
ctx.obj["app"] = application
63-
await f(ctx, *args, **kwargs)
64-
await asyncio.sleep(0.5)
65-
await application.shutdown()
66-
67-
def shutdown():
68-
with contextlib.suppress(Exception):
69-
application._ezsp.close()
62+
try:
63+
ctx.obj["app"] = application
64+
await f(ctx, *args, **kwargs)
65+
finally:
66+
with contextlib.suppress(Exception):
67+
await application.shutdown()
7068

7169
@functools.wraps(f)
7270
def inner(*args, **kwargs):
7371
loop = asyncio.get_event_loop()
74-
try:
75-
loop.run_until_complete(async_inner(*args, **kwargs))
76-
except: # noqa: E722
77-
# It seems that often errors like a message send will try to send
78-
# two messages, and not reading all of them will leave the NCP in
79-
# a bad state. This seems to mitigate this somewhat. Better way?
80-
loop.run_until_complete(asyncio.sleep(0.5))
81-
raise
82-
finally:
83-
shutdown()
72+
loop.run_until_complete(async_inner(*args, **kwargs))
8473

8574
return inner
8675

bellows/ezsp/__init__.py

+18-34
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from typing import Any, Callable, Generator
1313
import urllib.parse
1414

15+
from bellows.ash import NcpFailure
16+
1517
if sys.version_info[:2] < (3, 11):
1618
from async_timeout import timeout as asyncio_timeout # pragma: no cover
1719
else:
@@ -55,13 +57,14 @@ class EZSP:
5557
v14.EZSPv14.VERSION: v14.EZSPv14,
5658
}
5759

58-
def __init__(self, device_config: dict):
60+
def __init__(self, device_config: dict, application: Any | None = None):
5961
self._config = device_config
6062
self._callbacks = {}
6163
self._ezsp_event = asyncio.Event()
6264
self._ezsp_version = v4.EZSPv4.VERSION
6365
self._gw = None
6466
self._protocol = None
67+
self._application = application
6568

6669
self._stack_status_listeners: collections.defaultdict[
6770
t.sl_Status, list[asyncio.Future]
@@ -122,25 +125,17 @@ async def startup_reset(self) -> None:
122125

123126
await self.version()
124127

125-
@classmethod
126-
async def initialize(cls, zigpy_config: dict) -> EZSP:
127-
"""Return initialized EZSP instance."""
128-
ezsp = cls(zigpy_config[conf.CONF_DEVICE])
129-
await ezsp.connect(use_thread=zigpy_config[conf.CONF_USE_THREAD])
128+
async def connect(self, *, use_thread: bool = True) -> None:
129+
assert self._gw is None
130+
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)
130131

131132
try:
132-
await ezsp.startup_reset()
133+
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)
134+
await self.startup_reset()
133135
except Exception:
134-
ezsp.close()
136+
await self.disconnect()
135137
raise
136138

137-
return ezsp
138-
139-
async def connect(self, *, use_thread: bool = True) -> None:
140-
assert self._gw is None
141-
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)
142-
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)
143-
144139
async def reset(self):
145140
LOGGER.debug("Resetting EZSP")
146141
self.stop_ezsp()
@@ -179,10 +174,10 @@ async def version(self):
179174
ver,
180175
)
181176

182-
def close(self):
177+
async def disconnect(self):
183178
self.stop_ezsp()
184179
if self._gw:
185-
self._gw.close()
180+
await self._gw.disconnect()
186181
self._gw = None
187182

188183
async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
@@ -264,23 +259,12 @@ async def leaveNetwork(self, timeout: float | int = NETWORK_OPS_TIMEOUT) -> None
264259

265260
def connection_lost(self, exc):
266261
"""Lost serial connection."""
267-
LOGGER.debug(
268-
"%s connection lost unexpectedly: %s",
269-
self._config[conf.CONF_DEVICE_PATH],
270-
exc,
271-
)
272-
self.enter_failed_state(f"Serial connection loss: {exc!r}")
273-
274-
def enter_failed_state(self, error):
275-
"""UART received error frame."""
276-
if len(self._callbacks) > 1:
277-
LOGGER.error("NCP entered failed state. Requesting APP controller restart")
278-
self.close()
279-
self.handle_callback("_reset_controller_application", (error,))
280-
else:
281-
LOGGER.info(
282-
"NCP entered failed state. No application handler registered, ignoring..."
283-
)
262+
if self._application is not None:
263+
self._application.connection_lost(exc)
264+
265+
def enter_failed_state(self, code: t.NcpResetCode) -> None:
266+
"""UART received reset code."""
267+
self.connection_lost(NcpFailure(code=code))
284268

285269
def __getattr__(self, name: str) -> Callable:
286270
if name not in self._protocol.COMMANDS:

bellows/uart.py

+19-37
Original file line numberDiff line numberDiff line change
@@ -18,38 +18,29 @@
1818
RESET_TIMEOUT = 5
1919

2020

21-
class Gateway(asyncio.Protocol):
22-
def __init__(self, application, connected_future=None, connection_done_future=None):
23-
self._application = application
21+
class Gateway(zigpy.serial.SerialProtocol):
22+
def __init__(self, api, connection_done_future=None):
23+
super().__init__()
24+
self._api = api
2425

2526
self._reset_future = None
2627
self._startup_reset_future = None
27-
self._connected_future = connected_future
2828
self._connection_done_future = connection_done_future
2929

30-
self._transport = None
31-
32-
def close(self):
33-
self._transport.close()
34-
35-
def connection_made(self, transport):
36-
"""Callback when the uart is connected"""
37-
self._transport = transport
38-
if self._connected_future is not None:
39-
self._connected_future.set_result(True)
40-
4130
async def send_data(self, data: bytes) -> None:
4231
await self._transport.send_data(data)
4332

4433
def data_received(self, data):
4534
"""Callback when there is data received from the uart"""
46-
self._application.frame_received(data)
35+
36+
# We intentionally do not call `SerialProtocol.data_received`
37+
self._api.frame_received(data)
4738

4839
def reset_received(self, code: t.NcpResetCode) -> None:
4940
"""Reset acknowledgement frame receive handler"""
50-
# not a reset we've requested. Signal application reset
41+
# not a reset we've requested. Signal api reset
5142
if code is not t.NcpResetCode.RESET_SOFTWARE:
52-
self._application.enter_failed_state(code)
43+
self._api.enter_failed_state(code)
5344
return
5445

5546
if self._reset_future and not self._reset_future.done():
@@ -61,7 +52,7 @@ def reset_received(self, code: t.NcpResetCode) -> None:
6152

6253
def error_received(self, code: t.NcpResetCode) -> None:
6354
"""Error frame receive handler."""
64-
self._application.enter_failed_state(code)
55+
self._api.enter_failed_state(code)
6556

6657
async def wait_for_startup_reset(self) -> None:
6758
"""Wait for the first reset frame on startup."""
@@ -77,12 +68,9 @@ def _reset_cleanup(self, future):
7768
"""Delete reset future."""
7869
self._reset_future = None
7970

80-
def eof_received(self):
81-
"""Server gracefully closed its side of the connection."""
82-
self.connection_lost(ConnectionResetError("Remote server closed connection"))
83-
8471
def connection_lost(self, exc):
8572
"""Port was closed unexpectedly."""
73+
super().connection_lost(exc)
8674

8775
LOGGER.debug("Connection lost: %r", exc)
8876
reason = exc or ConnectionResetError("Remote server closed connection")
@@ -102,12 +90,7 @@ def connection_lost(self, exc):
10290
self._reset_future.set_exception(reason)
10391
self._reset_future = None
10492

105-
if exc is None:
106-
LOGGER.debug("Closed serial connection")
107-
return
108-
109-
LOGGER.error("Lost serial connection: %r", exc)
110-
self._application.connection_lost(exc)
93+
self._api.connection_lost(exc)
11194

11295
async def reset(self):
11396
"""Send a reset frame and init internal state."""
@@ -126,13 +109,12 @@ async def reset(self):
126109
return await self._reset_future
127110

128111

129-
async def _connect(config, application):
112+
async def _connect(config, api):
130113
loop = asyncio.get_event_loop()
131114

132-
connection_future = loop.create_future()
133115
connection_done_future = loop.create_future()
134116

135-
gateway = Gateway(application, connection_future, connection_done_future)
117+
gateway = Gateway(api, connection_done_future)
136118
protocol = AshProtocol(gateway)
137119

138120
if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None:
@@ -149,25 +131,25 @@ async def _connect(config, application):
149131
rtscts=rtscts,
150132
)
151133

152-
await connection_future
134+
await gateway.wait_until_connected()
153135

154136
thread_safe_protocol = ThreadsafeProxy(gateway, loop)
155137
return thread_safe_protocol, connection_done_future
156138

157139

158-
async def connect(config, application, use_thread=True):
140+
async def connect(config, api, use_thread=True):
159141
if use_thread:
160-
application = ThreadsafeProxy(application, asyncio.get_event_loop())
142+
api = ThreadsafeProxy(api, asyncio.get_event_loop())
161143
thread = EventLoopThread()
162144
await thread.start()
163145
try:
164146
protocol, connection_done = await thread.run_coroutine_threadsafe(
165-
_connect(config, application)
147+
_connect(config, api)
166148
)
167149
except Exception:
168150
thread.force_stop()
169151
raise
170152
connection_done.add_done_callback(lambda _: thread.force_stop())
171153
else:
172-
protocol, _ = await _connect(config, application)
154+
protocol, _ = await _connect(config, api)
173155
return protocol

0 commit comments

Comments
 (0)