Skip to content

Commit f154610

Browse files
committed
Add type hints for qserv kraken
1 parent 0c2d2e3 commit f154610

File tree

8 files changed

+192
-94
lines changed

8 files changed

+192
-94
lines changed

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ exclude = [
5151
"etc/*",
5252
"extern/*",
5353
"python/lsst/utils",
54-
"python/lsst/qserv/testing",
5554
"python/lsst/qserv/.*/tests",
5655
"src/*",
5756
]

python/lsst/qserv/testing/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
from .config import Config, QueryFactory
22

3-
__all__ = [Config, QueryFactory]
3+
__all__ = [
4+
"Config",
5+
"QueryFactory",
6+
]

python/lsst/qserv/testing/config.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import contextlib
66
import logging
77
import random
8+
from collections.abc import Generator
9+
from typing import Self, TextIO
810

911
import numpy as np
1012
import yaml
@@ -13,15 +15,15 @@
1315

1416

1517
@contextlib.contextmanager
16-
def _make_file(path_or_file):
18+
def _make_file(path_or_file: str | TextIO) -> Generator[TextIO, None, None]:
1719
"""Context manager that makes a file out of argument.
1820
1921
Parameters
2022
----------
2123
path_or_file : `str` or file object
2224
Path name for a file or a file object.
2325
"""
24-
if hasattr(path_or_file, "read"):
26+
if isinstance(path_or_file, TextIO):
2527
yield path_or_file
2628
else:
2729
with open(path_or_file) as file:
@@ -37,11 +39,11 @@ class _ValueRandomUniform:
3739
Range for generated numbers.
3840
"""
3941

40-
def __init__(self, min, max):
42+
def __init__(self: Self, min: float, max: float) -> None:
4143
self._min = float(min)
4244
self._max = float(max)
4345

44-
def __call__(self):
46+
def __call__(self) -> float:
4547
return random.uniform(self._min, self._max)
4648

4749

@@ -54,11 +56,11 @@ class _ValueRandomUniformInt:
5456
Range for generated numbers.
5557
"""
5658

57-
def __init__(self, min, max):
59+
def __init__(self: Self, min: int, max: int) -> None:
5860
self._min = float(min)
5961
self._max = float(max)
6062

61-
def __call__(self):
63+
def __call__(self) -> int:
6264
return int(random.uniform(self._min, self._max))
6365

6466

@@ -73,7 +75,9 @@ class _ValueIntFromFile:
7375
One of "random" or "sequential".
7476
"""
7577

76-
def __init__(self, path, mode="random"):
78+
_array: np.ndarray | list[int]
79+
80+
def __init__(self: Self, path: str, mode: str = "random") -> None:
7781
# read all numbers from file as integers
7882
if path == "/dev/null":
7983
# for testing only
@@ -84,7 +88,7 @@ def __init__(self, path, mode="random"):
8488
assert mode in ("random", "sequential")
8589
self._seq = 0
8690

87-
def __call__(self):
91+
def __call__(self) -> int:
8892
if self._mode == "random":
8993
return random.choice(self._array)
9094
else:
@@ -108,12 +112,12 @@ class QueryFactory:
108112
with a description of how to generate variable value.
109113
"""
110114

111-
def __init__(self, txt, variables=None):
115+
def __init__(self: Self, txt: str, variables: dict[str, dict] | None = None) -> None:
112116
self._txt = txt
113117
self._vars = {}
114118
if variables is not None:
115119
for var, config in variables.items():
116-
generator = None
120+
generator: _ValueRandomUniform | _ValueRandomUniformInt | _ValueIntFromFile | None = None
117121
if "distribution" in config:
118122
if config["distribution"] == "uniform":
119123
min = config.get("min", 0.0)
@@ -131,7 +135,7 @@ def __init__(self, txt, variables=None):
131135
raise ValueError(f"Cannot parse variable configuration {var} = {config}")
132136
self._vars[var] = generator
133137

134-
def query(self):
138+
def query(self: Self) -> str:
135139
"""Return next query to execute.
136140
137141
Returns
@@ -157,7 +161,10 @@ class Config:
157161
List of dictionaries, cannot be empty.
158162
"""
159163

160-
def __init__(self, configs):
164+
_config: dict
165+
_queries: dict
166+
167+
def __init__(self: Self, configs: list[dict]) -> None:
161168
if not configs:
162169
raise ValueError("empty configurations list")
163170

@@ -190,7 +197,7 @@ def __init__(self, configs):
190197
raise ValueError(f"Unexpected query configuration: {qkey}: {qcfg}")
191198

192199
@classmethod
193-
def from_yaml(cls, config_files):
200+
def from_yaml(cls: type[Self], config_files: list[str | TextIO]) -> Self:
194201
"""Make configuration from bunch of YAML files
195202
196203
Parameters
@@ -210,7 +217,7 @@ def from_yaml(cls, config_files):
210217
configs.append(yaml.load(file, Loader=yaml.SafeLoader))
211218
return cls(configs)
212219

213-
def to_yaml(self):
220+
def to_yaml(self: Self) -> str:
214221
"""Convert current config to YAML string.
215222
216223
Returns
@@ -220,7 +227,7 @@ def to_yaml(self):
220227
"""
221228
return yaml.dump(self._config)
222229

223-
def classes(self):
230+
def classes(self: Self) -> set[str]:
224231
"""Return set of classes defined in configuration.
225232
226233
Returns
@@ -230,7 +237,7 @@ def classes(self):
230237
"""
231238
return self._classes
232239

233-
def queries(self, q_class):
240+
def queries(self: Self, q_class: str) -> dict[str, QueryFactory]:
234241
"""Return queries for given class.
235242
236243
Parameters
@@ -246,7 +253,7 @@ def queries(self, q_class):
246253
"""
247254
return self._queries[q_class]
248255

249-
def concurrent_queries(self, q_class):
256+
def concurrent_queries(self: Self, q_class: str) -> int:
250257
"""Return number of concurrent queries for given class.
251258
252259
Parameters
@@ -261,7 +268,7 @@ def concurrent_queries(self, q_class):
261268
"""
262269
return self._config["queryClasses"][q_class]["concurrentQueries"]
263270

264-
def max_rate(self, q_class):
271+
def max_rate(self: Self, q_class: str) -> float:
265272
"""Return maximum rate for given class.
266273
267274
Parameters
@@ -276,7 +283,7 @@ def max_rate(self, q_class):
276283
"""
277284
return self._config["queryClasses"][q_class].get("maxRate")
278285

279-
def arraysize(self, q_class):
286+
def arraysize(self: Self, q_class: str) -> int:
280287
"""Return array size for fetchmany().
281288
282289
Parameters
@@ -291,7 +298,7 @@ def arraysize(self, q_class):
291298
"""
292299
return self._config["queryClasses"][q_class].get("arraysize")
293300

294-
def split(self, n_workers, i_worker):
301+
def split(self: Self, n_workers: int, i_worker: int) -> Self:
295302
"""Divide configuration (or its workload) between number of workers.
296303
297304
If we want to run test with multiple workers we need to divide work
@@ -328,7 +335,7 @@ def split(self, n_workers, i_worker):
328335
return self.__class__([self._config, dict(queryClasses=overrides)])
329336

330337
@staticmethod
331-
def _merge(config1, config2):
338+
def _merge(config1: dict, config2: dict) -> dict:
332339
"""Merge two config objects, return result.
333340
334341
If configuration present in both then second one overrides first.

python/lsst/qserv/testing/main.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,22 @@
88
import functools
99
import logging
1010
import os
11+
from collections.abc import Callable
1112

12-
import MySQLdb
13+
from lsst.qserv.testing import mock_db
14+
15+
import mysql.connector
16+
from mysql.connector.abstracts import MySQLConnectionAbstract
17+
from mysql.connector.pooling import PooledMySQLConnection
1318

14-
from . import mock_db
1519
from .config import Config
1620
from .monitor import InfluxDBFileMonitor, LogMonitor
1721
from .runner_mgr import RunnerManager
1822

1923
_LOG = logging.getLogger(__name__)
2024

2125

22-
def _log_config(level, slot):
26+
def _log_config(level: int, slot: int | None) -> None:
2327
levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
2428
if slot is None:
2529
simple_format = "%(asctime)s %(levelname)7s %(name)s -- %(message)s"
@@ -29,7 +33,7 @@ def _log_config(level, slot):
2933
logging.basicConfig(format=simple_format, level=level)
3034

3135

32-
def main():
36+
def main() -> None:
3337
parser = argparse.ArgumentParser(description="Test harness to generate load for QServ")
3438

3539
parser.add_argument(
@@ -141,9 +145,9 @@ def main():
141145
if num_slots is None or slot is None:
142146
parser.error("cannot determine slurm configuration, envvar is not set")
143147
return 2
144-
num_slots = int(num_slots)
145-
if num_slots > 1:
146-
args.num_slots = num_slots
148+
num_slots_int = int(num_slots)
149+
if num_slots_int > 1:
150+
args.num_slots = num_slots_int
147151
args.slot = int(slot)
148152

149153
_log_config(args.verbose, args.slot)
@@ -157,16 +161,22 @@ def main():
157161
print(cfg.to_yaml())
158162

159163
# connection factory
164+
conn_factory: Callable[..., PooledMySQLConnection | MySQLConnectionAbstract | mock_db.MockConnection]
160165
if args.dummy_db:
161166
conn_factory = mock_db.connect
162167
else:
163168
conn_factory = functools.partial(
164-
MySQLdb.connect, host=args.host, port=args.port, user=args.user, passwd=args.password, db=args.db
169+
mysql.connector.connect,
170+
host=args.host,
171+
port=args.port,
172+
user=args.user,
173+
passwd=args.password,
174+
db=args.db,
165175
)
166176

167177
# monitor
168178
tags = None if args.slot is None else {"slot": args.slot}
169-
monitor = None
179+
monitor: LogMonitor | InfluxDBFileMonitor | None = None
170180
if args.monitor == "log":
171181
monitor = LogMonitor(logging.getLogger("metrics"), tags=tags)
172182
elif args.monitor == "influxdb-file":

python/lsst/qserv/testing/mock_db.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import re
55
import time
66
from collections import namedtuple
7+
from typing import Any, Self
78

89
_LOG = logging.getLogger(__name__)
910

1011
_num_re = re.compile(r"\d+")
1112

1213

13-
ColDesriptor = namedtuple("ColDesriptor", "name type_code display_size internal_size precision scale null_ok")
14+
ColDescriptor = namedtuple(
15+
"ColDescriptor", "name type_code display_size internal_size precision scale null_ok"
16+
)
1417

1518

1619
class MockCursor:
@@ -20,13 +23,16 @@ class MockCursor:
2023
"""
2124

2225
arraysize = 1
26+
_query: str | None
27+
n_rows: int
28+
rows: list[tuple[int, str]]
2329

24-
def __init__(self):
30+
def __init__(self: Self) -> None:
2531
self._query = None
2632
self.n_rows = 0
2733
self.rows = []
2834

29-
def execute(self, query):
35+
def execute(self: Self, query: str) -> None:
3036
_LOG.debug("executing query: %s", query)
3137
self._query = query
3238
self.n_rows = 2
@@ -40,30 +46,30 @@ def execute(self, query):
4046
# spend at least few milliseconds in query
4147
time.sleep(0.01)
4248

43-
def fetchall(self):
49+
def fetchall(self: Self) -> list[tuple[int, str]]:
4450
rows = self.rows
4551
self.rows = []
4652
return rows
4753

48-
def fetchmany(self, arraysize=None):
54+
def fetchmany(self: Self, arraysize: int | None = None) -> list[tuple[int, str]]:
4955
if arraysize is None:
5056
arraysize = self.arraysize
5157
rows = self.rows[:arraysize]
5258
self.rows = self.rows[arraysize:]
5359
return rows
5460

5561
@property
56-
def rowcount(self):
62+
def rowcount(self: Self) -> int:
5763
return self.n_rows
5864

5965
@property
60-
def description(self):
66+
def description(self: Self) -> list[ColDescriptor]:
6167
# some randome codes
6268
return [
63-
ColDesriptor(
69+
ColDescriptor(
6470
name="ID", type_code=1, display_size=10, internal_size=4, precision=0, scale=1, null_ok=False
6571
),
66-
ColDesriptor(
72+
ColDescriptor(
6773
name="name",
6874
type_code=15,
6975
display_size=32,
@@ -76,11 +82,11 @@ def description(self):
7682

7783

7884
class MockConnection:
79-
def cursor(self):
85+
def cursor(self: Self) -> MockCursor:
8086
return MockCursor()
8187

8288

83-
def connect(*args, **kwargs):
89+
def connect(*args: Any, **kwargs: Any) -> MockConnection:
8490
"""Can take any parameters so it can be used as replacement for
8591
MySQLdb.connect() method.
8692
"""

0 commit comments

Comments
 (0)