55import time
66from collections import defaultdict
77from ssl import SSLContext , PROTOCOL_TLS
8- from typing import Optional , Union , Any
8+ from typing import Optional , Union , Any , List
99from concurrent .futures import CancelledError
1010from .buildin import Buildin
1111from .protocol import Proto , Protocol , ProtocolWS
@@ -135,7 +135,7 @@ def connection_info(self) -> str:
135135 """
136136 if not self .is_connected ():
137137 return 'disconnected'
138- socket = self ._protocol .info ()
138+ socket = self ._protocol .info () # type: ignore
139139 if socket is None :
140140 return 'unknown_addr'
141141 addr , port = socket .getpeername ()[:2 ]
@@ -145,7 +145,7 @@ def connect_pool(
145145 self ,
146146 pool : list ,
147147 * auth : Union [str , tuple ]
148- ) -> asyncio .Future :
148+ ) -> asyncio .Future [ None ] :
149149 """Connect using a connection pool.
150150
151151 When using a connection pool, the client will randomly choose a node
@@ -183,21 +183,23 @@ def connect_pool(
183183 assert self ._reconnecting is False
184184 assert len (pool ), 'pool must contain at least one node'
185185 if len (auth ) == 1 :
186- auth = auth [0 ]
186+ auth = auth [0 ] # type: ignore
187187
188188 self ._pool = tuple ((
189189 (address , 9200 ) if isinstance (address , str ) else address
190190 for address in pool ))
191191 self ._auth = self ._auth_check (auth )
192192 self ._pool_idx = random .randint (0 , len (pool ) - 1 )
193- return self .reconnect ()
193+ fut = self .reconnect ()
194+ if fut is None :
195+ raise ConnectionError ('client already connecting' )
196+ return fut
194197
195- def connect (
198+ async def connect (
196199 self ,
197200 host : str ,
198201 port : int = 9200 ,
199- timeout : Optional [int ] = 5
200- ) -> asyncio .Future :
202+ timeout : Optional [int ] = 5 ):
201203 """Connect to ThingsDB.
202204
203205 This method will *only* create a connection, so the connection is not
@@ -231,9 +233,9 @@ def connect(
231233 assert self .is_connected () is False
232234 self ._pool = ((host , port ),)
233235 self ._pool_idx = 0
234- return self ._connect (timeout = timeout )
236+ await self ._connect (timeout = timeout )
235237
236- def reconnect (self ) -> Optional [asyncio .Future ]:
238+ def reconnect (self ) -> Optional [asyncio .Future [ Any ] ]:
237239 """Re-connect to ThingsDB.
238240
239241 This method can be used, even when a connection still exists. In case
@@ -286,7 +288,7 @@ async def authenticate(
286288 wait forever on a response. Defaults to 5.
287289 """
288290 if len (auth ) == 1 :
289- auth = auth [0 ]
291+ auth = auth [0 ] # type: ignore
290292 self ._auth = self ._auth_check (auth )
291293 await self ._authenticate (timeout )
292294
@@ -297,7 +299,7 @@ def query(
297299 timeout : Optional [int ] = None ,
298300 skip_strip_code : bool = False ,
299301 ** kwargs : Any
300- ) -> asyncio .Future :
302+ ) -> asyncio .Future [ Any ] :
301303 """Query ThingsDB.
302304
303305 Use this method to run `code` in a scope.
@@ -348,17 +350,18 @@ def query(
348350
349351 data = [scope , code ]
350352 if kwargs :
351- data .append (kwargs )
353+ data .append (kwargs ) # type: ignore
352354
353- return self ._write_pkg (Proto .REQ_QUERY , data , timeout = timeout )
355+ return self ._write_pkg (
356+ Proto .REQ_QUERY , data , timeout = timeout ) # type: ignore
354357
355358 async def _ensure_write (
356359 self ,
357360 tp : Proto ,
358361 data : Any = None ,
359362 is_bin : bool = False ,
360363 timeout : Optional [int ] = None
361- ) -> asyncio .Future :
364+ ) -> asyncio .Future [ Any ] :
362365 if not self ._pool :
363366 raise ConnectionError ('no connection' )
364367
@@ -372,6 +375,7 @@ async def _ensure_write(
372375 continue
373376
374377 try :
378+ assert self ._protocol # we're connected
375379 res = await self ._protocol .write (tp , data , is_bin , timeout )
376380 except (asyncio .exceptions .CancelledError ,
377381 CancelledError , NodeError , AuthError ) as e :
@@ -394,9 +398,10 @@ async def _write(
394398 data : Any = None ,
395399 is_bin : bool = False ,
396400 timeout : Optional [int ] = None
397- ) -> asyncio .Future :
401+ ) -> asyncio .Future [ Any ] :
398402 if not self .is_connected ():
399403 raise ConnectionError ('no connection' )
404+ assert self ._protocol # we are connected
400405 return await self ._protocol .write (tp , data , is_bin , timeout )
401406
402407 def run (
@@ -406,7 +411,7 @@ def run(
406411 scope : Optional [str ] = None ,
407412 timeout : Optional [int ] = None ,
408413 ** kwargs : Any ,
409- ) -> asyncio .Future :
414+ ) -> asyncio .Future [ Any ] :
410415 """Run a procedure.
411416
412417 Use this method to run a stored procedure in a scope.
@@ -449,23 +454,23 @@ def run(
449454 data = [scope , procedure ]
450455
451456 if args :
452- data .append (args )
457+ data .append (args ) # type: ignore
453458 if kwargs :
454459 raise ValueError (
455460 'it is not possible to use both keyword arguments '
456461 'and positional arguments at the same time' )
457462 elif kwargs :
458- data .append (kwargs )
463+ data .append (kwargs ) # type: ignore
459464
460- return self ._write_pkg (Proto .REQ_RUN , data , timeout = timeout )
465+ return self ._write_pkg (
466+ Proto .REQ_RUN , data , timeout = timeout ) # type: ignore
461467
462- def _emit (
468+ async def _emit (
463469 self ,
464470 room_id : Union [int , str ],
465471 event : str ,
466472 * args : Optional [Any ],
467- scope : Optional [str ] = None ,
468- ) -> asyncio .Future :
473+ scope : Optional [str ] = None ):
469474 """Emit an event.
470475
471476 Use Room(room_id, scope=scope).emit(..) instead of this function to
@@ -492,10 +497,11 @@ def _emit(
492497 """
493498 if scope is None :
494499 scope = self ._scope
495- return self ._write_pkg (Proto .REQ_EMIT , [scope , room_id , event , * args ])
500+ await self ._write_pkg (Proto .REQ_EMIT , [scope , room_id , event , * args ])
496501
497502 def _join (self , * ids : Union [int , str ],
498- scope : Optional [str ] = None ) -> asyncio .Future :
503+ scope : Optional [str ] = None
504+ ) -> asyncio .Future [List [Optional [int ]]]:
499505 """Join one or more rooms.
500506
501507 Args:
@@ -521,10 +527,11 @@ def _join(self, *ids: Union[int, str],
521527 if scope is None :
522528 scope = self ._scope
523529
524- return self ._write_pkg (Proto .REQ_JOIN , [scope , * ids ])
530+ return self ._write_pkg (Proto .REQ_JOIN , [scope , * ids ]) # type: ignore
525531
526532 def _leave (self , * ids : Union [int , str ],
527- scope : Optional [str ] = None ) -> asyncio .Future :
533+ scope : Optional [str ] = None
534+ ) -> asyncio .Future [List [Optional [int ]]]:
528535 """Leave one or more rooms.
529536
530537 Stop receiving events for the rooms given by one or more ids. It is
@@ -553,7 +560,7 @@ def _leave(self, *ids: Union[int, str],
553560 if scope is None :
554561 scope = self ._scope
555562
556- return self ._write_pkg (Proto .REQ_LEAVE , [scope , * ids ])
563+ return self ._write_pkg (Proto .REQ_LEAVE , [scope , * ids ]) # type: ignore
557564
558565 @staticmethod
559566 def _auth_check (auth ):
@@ -574,7 +581,9 @@ def _auth_check(auth):
574581 def _is_websocket_host (host ):
575582 return host .startswith ('ws://' ) or host .startswith ('wss://' )
576583
577- async def _connect (self , timeout = 5 ):
584+ async def _connect (self , timeout : Optional [int ] = 5 ):
585+ if not self ._pool :
586+ return
578587 host , port = self ._pool [self ._pool_idx ]
579588 try :
580589 if self ._is_websocket_host (host ):
@@ -646,6 +655,7 @@ def _on_connection_lost(self, protocol, exc):
646655 self .reconnect ()
647656
648657 async def _reconnect_loop (self ):
658+ assert self ._pool # only when we have a pool
649659 try :
650660 wait_time = 1
651661 timeout = 2
0 commit comments