Skip to content

Commit 1a7f76c

Browse files
authored
typing and fixes (#45)
* typing and fixes
1 parent 89c68b3 commit 1a7f76c

File tree

9 files changed

+140
-86
lines changed

9 files changed

+140
-86
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ jobs:
2424
- name: Install dependencies
2525
run: |
2626
python -m pip install --upgrade pip
27-
pip install pytest pycodestyle
27+
pip install pytest pycodestyle pyright websockets
2828
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
2929
- name: Run tests with pytest
3030
run: |
3131
pytest
3232
- name: Lint with PyCodeStyle
3333
run: |
34-
find . -name \*.py -exec pycodestyle {} +
34+
find . -name \*.py -exec pycodestyle {} +
35+
- name: Code validation using Pyright
36+
run: |
37+
pyright

examples/bookstore/webserver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import asyncio
2929
from sys import argv
3030
from functools import partial
31-
from aiohttp import web
31+
from aiohttp import web # type: ignore
3232
from thingsdb.client import Client
3333
from thingsdb.room import Room, event
3434

@@ -62,12 +62,14 @@ def on_cleanup():
6262
async def add_book(request):
6363
book = await request.json()
6464
# Use the procedure to add the book
65+
assert bookstore
6566
await bookstore.add_book(book)
6667
return web.HTTPNoContent()
6768

6869

6970
# We have the books in memory, no need for a query
7071
async def get_books(request):
72+
assert bookstore
7173
return web.json_response({
7274
"book_titles": [book['title'] for book in bookstore.books]
7375
})

thingsdb/client/buildin.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
import datetime
3+
from abc import ABC, abstractmethod
24
from typing import Union as U
35
from typing import Optional
46
from typing import Any
@@ -9,6 +11,15 @@ class Buildin:
911
#
1012
# Build-in functions from the @thingsdb scope
1113
#
14+
@abstractmethod
15+
def query(
16+
self,
17+
code: str,
18+
scope: Optional[str] = None,
19+
timeout: Optional[int] = None,
20+
skip_strip_code: bool = False,
21+
**kwargs: Any) -> asyncio.Future[Any]:
22+
...
1223

1324
async def collection_info(self, collection: U[int, str]) -> dict:
1425
"""Returns information about a specific collection.
@@ -245,16 +256,16 @@ async def new_token(
245256
expiration_time: Optional[datetime.datetime] = None,
246257
description: str = ''):
247258

248-
if expiration_time is not None:
249-
expiration_time = int(datetime.datetime.timestamp(expiration_time))
259+
ts = None if expiration_time is None \
260+
else int(expiration_time.timestamp())
250261

251262
return await self.query(
252263
"""//ti
253264
et = is_nil(expiration_time) ? nil : datetime(expiration_time);
254265
new_token(user, et, description);
255266
""",
256267
user=user,
257-
expiration_time=expiration_time,
268+
expiration_time=ts,
258269
description=description,
259270
scope='@t')
260271

@@ -334,7 +345,8 @@ async def set_module_scope(
334345
module_scope=scope,
335346
scope='@t')
336347

337-
async def set_password(self, user: str, new_password: str = None) -> None:
348+
async def set_password(self, user: str,
349+
new_password: Optional[str] = None) -> None:
338350
return await self.query(
339351
'set_password(user, new_password)',
340352
user=user,
@@ -412,16 +424,15 @@ async def new_backup(
412424
max_files: Optional[int] = 7,
413425
scope='@n'):
414426

415-
if start_ts is not None:
416-
start_ts = int(datetime.datetime.timestamp(start_ts))
427+
ts = None if start_ts is None else int(start_ts.timestamp())
417428

418429
return await self.query(
419430
"""//ti
420431
start_ts = is_nil(start_ts) ? nil : datetime(start_ts);
421432
new_backup(file_template, start_ts, repeat, max_files);
422433
""",
423434
file_template=file_template,
424-
start_ts=start_ts,
435+
start_ts=ts,
425436
repeat=repeat,
426437
max_files=max_files,
427438
scope=scope)
@@ -454,14 +465,14 @@ async def restart_module(self, name: str) -> None:
454465
return await self.query('restart_module(name)', name=name, scope='@t')
455466

456467
async def set_log_level(self, log_level: str, scope='@n') -> None:
457-
log_level = (
468+
level = (
458469
'DEBUG',
459470
'INFO',
460471
'WARNING',
461472
'ERROR',
462473
'CRITICAL').index(log_level)
463474
return await self.query(
464-
'set_log_level(log_level)', log_level=log_level, scope=scope)
475+
'set_log_level(log_level)', log_level=level, scope=scope)
465476

466477
async def shutdown(self, scope='@n') -> None:
467478
"""Shutdown the node in the selected scope.

thingsdb/client/client.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from collections import defaultdict
77
from ssl import SSLContext, PROTOCOL_TLS
8-
from typing import Optional, Union, Any
8+
from typing import Optional, Union, Any, List
99
from concurrent.futures import CancelledError
1010
from .buildin import Buildin
1111
from .protocol import Proto, Protocol, ProtocolWS
@@ -135,7 +135,7 @@ def connection_info(self) -> str:
135135
"""
136136
if not self.is_connected():
137137
return 'disconnected'
138-
socket = self._protocol.info()
138+
socket = self._protocol.info() # type: ignore
139139
if socket is None:
140140
return 'unknown_addr'
141141
addr, port = socket.getpeername()[:2]
@@ -145,7 +145,7 @@ def connect_pool(
145145
self,
146146
pool: list,
147147
*auth: Union[str, tuple]
148-
) -> asyncio.Future:
148+
) -> asyncio.Future[None]:
149149
"""Connect using a connection pool.
150150
151151
When using a connection pool, the client will randomly choose a node
@@ -183,21 +183,23 @@ def connect_pool(
183183
assert self._reconnecting is False
184184
assert len(pool), 'pool must contain at least one node'
185185
if len(auth) == 1:
186-
auth = auth[0]
186+
auth = auth[0] # type: ignore
187187

188188
self._pool = tuple((
189189
(address, 9200) if isinstance(address, str) else address
190190
for address in pool))
191191
self._auth = self._auth_check(auth)
192192
self._pool_idx = random.randint(0, len(pool) - 1)
193-
return self.reconnect()
193+
fut = self.reconnect()
194+
if fut is None:
195+
raise ConnectionError('client already connecting')
196+
return fut
194197

195-
def connect(
198+
async def connect(
196199
self,
197200
host: str,
198201
port: int = 9200,
199-
timeout: Optional[int] = 5
200-
) -> asyncio.Future:
202+
timeout: Optional[int] = 5):
201203
"""Connect to ThingsDB.
202204
203205
This method will *only* create a connection, so the connection is not
@@ -231,9 +233,9 @@ def connect(
231233
assert self.is_connected() is False
232234
self._pool = ((host, port),)
233235
self._pool_idx = 0
234-
return self._connect(timeout=timeout)
236+
await self._connect(timeout=timeout)
235237

236-
def reconnect(self) -> Optional[asyncio.Future]:
238+
def reconnect(self) -> Optional[asyncio.Future[Any]]:
237239
"""Re-connect to ThingsDB.
238240
239241
This method can be used, even when a connection still exists. In case
@@ -286,7 +288,7 @@ async def authenticate(
286288
wait forever on a response. Defaults to 5.
287289
"""
288290
if len(auth) == 1:
289-
auth = auth[0]
291+
auth = auth[0] # type: ignore
290292
self._auth = self._auth_check(auth)
291293
await self._authenticate(timeout)
292294

@@ -297,7 +299,7 @@ def query(
297299
timeout: Optional[int] = None,
298300
skip_strip_code: bool = False,
299301
**kwargs: Any
300-
) -> asyncio.Future:
302+
) -> asyncio.Future[Any]:
301303
"""Query ThingsDB.
302304
303305
Use this method to run `code` in a scope.
@@ -348,17 +350,18 @@ def query(
348350

349351
data = [scope, code]
350352
if kwargs:
351-
data.append(kwargs)
353+
data.append(kwargs) # type: ignore
352354

353-
return self._write_pkg(Proto.REQ_QUERY, data, timeout=timeout)
355+
return self._write_pkg(
356+
Proto.REQ_QUERY, data, timeout=timeout) # type: ignore
354357

355358
async def _ensure_write(
356359
self,
357360
tp: Proto,
358361
data: Any = None,
359362
is_bin: bool = False,
360363
timeout: Optional[int] = None
361-
) -> asyncio.Future:
364+
) -> asyncio.Future[Any]:
362365
if not self._pool:
363366
raise ConnectionError('no connection')
364367

@@ -372,6 +375,7 @@ async def _ensure_write(
372375
continue
373376

374377
try:
378+
assert self._protocol # we're connected
375379
res = await self._protocol.write(tp, data, is_bin, timeout)
376380
except (asyncio.exceptions.CancelledError,
377381
CancelledError, NodeError, AuthError) as e:
@@ -394,9 +398,10 @@ async def _write(
394398
data: Any = None,
395399
is_bin: bool = False,
396400
timeout: Optional[int] = None
397-
) -> asyncio.Future:
401+
) -> asyncio.Future[Any]:
398402
if not self.is_connected():
399403
raise ConnectionError('no connection')
404+
assert self._protocol # we are connected
400405
return await self._protocol.write(tp, data, is_bin, timeout)
401406

402407
def run(
@@ -406,7 +411,7 @@ def run(
406411
scope: Optional[str] = None,
407412
timeout: Optional[int] = None,
408413
**kwargs: Any,
409-
) -> asyncio.Future:
414+
) -> asyncio.Future[Any]:
410415
"""Run a procedure.
411416
412417
Use this method to run a stored procedure in a scope.
@@ -449,23 +454,23 @@ def run(
449454
data = [scope, procedure]
450455

451456
if args:
452-
data.append(args)
457+
data.append(args) # type: ignore
453458
if kwargs:
454459
raise ValueError(
455460
'it is not possible to use both keyword arguments '
456461
'and positional arguments at the same time')
457462
elif kwargs:
458-
data.append(kwargs)
463+
data.append(kwargs) # type: ignore
459464

460-
return self._write_pkg(Proto.REQ_RUN, data, timeout=timeout)
465+
return self._write_pkg(
466+
Proto.REQ_RUN, data, timeout=timeout) # type: ignore
461467

462-
def _emit(
468+
async def _emit(
463469
self,
464470
room_id: Union[int, str],
465471
event: str,
466472
*args: Optional[Any],
467-
scope: Optional[str] = None,
468-
) -> asyncio.Future:
473+
scope: Optional[str] = None):
469474
"""Emit an event.
470475
471476
Use Room(room_id, scope=scope).emit(..) instead of this function to
@@ -492,10 +497,11 @@ def _emit(
492497
"""
493498
if scope is None:
494499
scope = self._scope
495-
return self._write_pkg(Proto.REQ_EMIT, [scope, room_id, event, *args])
500+
await self._write_pkg(Proto.REQ_EMIT, [scope, room_id, event, *args])
496501

497502
def _join(self, *ids: Union[int, str],
498-
scope: Optional[str] = None) -> asyncio.Future:
503+
scope: Optional[str] = None
504+
) -> asyncio.Future[List[Optional[int]]]:
499505
"""Join one or more rooms.
500506
501507
Args:
@@ -521,10 +527,11 @@ def _join(self, *ids: Union[int, str],
521527
if scope is None:
522528
scope = self._scope
523529

524-
return self._write_pkg(Proto.REQ_JOIN, [scope, *ids])
530+
return self._write_pkg(Proto.REQ_JOIN, [scope, *ids]) # type: ignore
525531

526532
def _leave(self, *ids: Union[int, str],
527-
scope: Optional[str] = None) -> asyncio.Future:
533+
scope: Optional[str] = None
534+
) -> asyncio.Future[List[Optional[int]]]:
528535
"""Leave one or more rooms.
529536
530537
Stop receiving events for the rooms given by one or more ids. It is
@@ -553,7 +560,7 @@ def _leave(self, *ids: Union[int, str],
553560
if scope is None:
554561
scope = self._scope
555562

556-
return self._write_pkg(Proto.REQ_LEAVE, [scope, *ids])
563+
return self._write_pkg(Proto.REQ_LEAVE, [scope, *ids]) # type: ignore
557564

558565
@staticmethod
559566
def _auth_check(auth):
@@ -574,7 +581,9 @@ def _auth_check(auth):
574581
def _is_websocket_host(host):
575582
return host.startswith('ws://') or host.startswith('wss://')
576583

577-
async def _connect(self, timeout=5):
584+
async def _connect(self, timeout: Optional[int] = 5):
585+
if not self._pool:
586+
return
578587
host, port = self._pool[self._pool_idx]
579588
try:
580589
if self._is_websocket_host(host):
@@ -646,6 +655,7 @@ def _on_connection_lost(self, protocol, exc):
646655
self.reconnect()
647656

648657
async def _reconnect_loop(self):
658+
assert self._pool # only when we have a pool
649659
try:
650660
wait_time = 1
651661
timeout = 2

thingsdb/client/package.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ def _handle_fail_file(self, message: bytes):
4848
def extract_data_from(self, barray: bytearray) -> None:
4949
try:
5050
self.data = msgpack.unpackb(
51-
barray[self.__class__.st_package.size:self.total],
51+
bytes(barray[self.__class__.st_package.size:self.total]),
5252
raw=False) \
5353
if self.length else None
5454
except Exception as e:
55-
self._handle_fail_file(barray)
55+
self._handle_fail_file(bytes(barray))
5656
raise e
5757
finally:
5858
del barray[:self.total]

0 commit comments

Comments
 (0)