11import asyncio
2- import collections
32import contextlib
43import logging
54import ssl
@@ -18,9 +17,8 @@ def __init__(self, config):
1817 self ._config = config
1918 self ._mdns_resolver = enapter .mdns .Resolver ()
2019 self ._tls_context = self ._new_tls_context (config )
21- self ._client = None
22- self ._client_ready = asyncio .Event ()
23- self ._subscribers = collections .defaultdict (int )
20+ self ._publisher = None
21+ self ._publisher_connected = asyncio .Event ()
2422
2523 @staticmethod
2624 def _new_logger (config ):
@@ -31,89 +29,58 @@ def config(self):
3129 return self ._config
3230
3331 async def publish (self , * args , ** kwargs ):
34- client = await self ._wait_client ()
35- await client .publish (* args , ** kwargs )
32+ await self ._publisher_connected . wait ()
33+ await self . _publisher .publish (* args , ** kwargs )
3634
3735 @enapter .async_ .generator
38- async def subscribe (self , topic ):
36+ async def subscribe (self , * topics ):
3937 while True :
40- client = await self ._wait_client ()
41-
4238 try :
43- async with client . messages () as messages :
44- async with self . _subscribe ( client , topic ) :
45- async for msg in messages :
46- if msg . topic . matches ( topic ):
47- yield msg
48-
39+ async with self . _connect () as subscriber :
40+ for topic in topics :
41+ await subscriber . subscribe ( topic )
42+ self . _logger . info ( "subscriber [%s] connected" , "," . join ( topics ))
43+ async for msg in subscriber . messages :
44+ yield msg
4945 except aiomqtt .MqttError as e :
5046 self ._logger .error (e )
5147 retry_interval = 5
5248 await asyncio .sleep (retry_interval )
53-
54- @contextlib .asynccontextmanager
55- async def _subscribe (self , client , topic ):
56- first_subscriber = not self ._subscribers [topic ]
57- self ._subscribers [topic ] += 1
58- try :
59- if first_subscriber :
60- await client .subscribe (topic )
61- yield
62- finally :
63- self ._subscribers [topic ] -= 1
64- assert not self ._subscribers [topic ] < 0
65- last_unsubscriber = not self ._subscribers [topic ]
66- if last_unsubscriber :
67- del self ._subscribers [topic ]
68- await client .unsubscribe (topic )
69-
70- async def _wait_client (self ):
71- await self ._client_ready .wait ()
72- assert self ._client_ready .is_set ()
73- return self ._client
49+ finally :
50+ self ._logger .info ("subscriber disconnected" )
7451
7552 async def _run (self ):
7653 self ._logger .info ("starting" )
77-
7854 self ._started .set ()
79-
8055 while True :
8156 try :
82- async with self ._connect () as client :
83- self ._client = client
84- self ._client_ready .set ()
85- self ._logger .info ("client ready" )
86-
87- # tracking disconnect
88- async with client .messages () as messages :
89- async for msg in messages :
90- pass
57+ async with self ._connect () as publisher :
58+ self ._logger .info ("publisher connected" )
59+ self ._publisher = publisher
60+ self ._publisher_connected .set ()
61+ async for msg in publisher .messages :
62+ pass
9163 except aiomqtt .MqttError as e :
9264 self ._logger .error (e )
9365 retry_interval = 5
9466 await asyncio .sleep (retry_interval )
9567 finally :
96- self ._client_ready .clear ()
97- self ._client = None
98- self ._logger .info ("client not ready " )
68+ self ._publisher_connected .clear ()
69+ self ._publisher = None
70+ self ._logger .info ("publisher disconnected " )
9971
10072 @contextlib .asynccontextmanager
10173 async def _connect (self ):
10274 host = await self ._maybe_resolve_mdns (self ._config .host )
103-
104- try :
105- async with aiomqtt .Client (
106- hostname = host ,
107- port = self ._config .port ,
108- username = self ._config .user ,
109- password = self ._config .password ,
110- logger = self ._logger ,
111- tls_context = self ._tls_context ,
112- ) as client :
113- yield client
114- except asyncio .CancelledError :
115- # FIXME: A cancelled `aiomqtt.Client.connect` leaks resources.
116- raise
75+ async with aiomqtt .Client (
76+ hostname = host ,
77+ port = self ._config .port ,
78+ username = self ._config .user ,
79+ password = self ._config .password ,
80+ logger = self ._logger ,
81+ tls_context = self ._tls_context ,
82+ ) as client :
83+ yield client
11784
11885 @staticmethod
11986 def _new_tls_context (config ):
0 commit comments