5
5
import contextlib
6
6
import logging
7
7
import random
8
+ from collections .abc import Generator
9
+ from typing import Self , TextIO
8
10
9
11
import numpy as np
10
12
import yaml
13
15
14
16
15
17
@contextlib .contextmanager
16
- def _make_file (path_or_file ) :
18
+ def _make_file (path_or_file : str | TextIO ) -> Generator [ TextIO , None , None ] :
17
19
"""Context manager that makes a file out of argument.
18
20
19
21
Parameters
20
22
----------
21
23
path_or_file : `str` or file object
22
24
Path name for a file or a file object.
23
25
"""
24
- if hasattr (path_or_file , "read" ):
26
+ if isinstance (path_or_file , TextIO ):
25
27
yield path_or_file
26
28
else :
27
29
with open (path_or_file ) as file :
@@ -37,11 +39,11 @@ class _ValueRandomUniform:
37
39
Range for generated numbers.
38
40
"""
39
41
40
- def __init__ (self , min , max ) :
42
+ def __init__ (self : Self , min : float , max : float ) -> None :
41
43
self ._min = float (min )
42
44
self ._max = float (max )
43
45
44
- def __call__ (self ):
46
+ def __call__ (self ) -> float :
45
47
return random .uniform (self ._min , self ._max )
46
48
47
49
@@ -54,11 +56,11 @@ class _ValueRandomUniformInt:
54
56
Range for generated numbers.
55
57
"""
56
58
57
- def __init__ (self , min , max ) :
59
+ def __init__ (self : Self , min : int , max : int ) -> None :
58
60
self ._min = float (min )
59
61
self ._max = float (max )
60
62
61
- def __call__ (self ):
63
+ def __call__ (self ) -> int :
62
64
return int (random .uniform (self ._min , self ._max ))
63
65
64
66
@@ -73,7 +75,9 @@ class _ValueIntFromFile:
73
75
One of "random" or "sequential".
74
76
"""
75
77
76
- def __init__ (self , path , mode = "random" ):
78
+ _array : np .ndarray | list [int ]
79
+
80
+ def __init__ (self : Self , path : str , mode : str = "random" ) -> None :
77
81
# read all numbers from file as integers
78
82
if path == "/dev/null" :
79
83
# for testing only
@@ -84,7 +88,7 @@ def __init__(self, path, mode="random"):
84
88
assert mode in ("random" , "sequential" )
85
89
self ._seq = 0
86
90
87
- def __call__ (self ):
91
+ def __call__ (self ) -> int :
88
92
if self ._mode == "random" :
89
93
return random .choice (self ._array )
90
94
else :
@@ -108,12 +112,12 @@ class QueryFactory:
108
112
with a description of how to generate variable value.
109
113
"""
110
114
111
- def __init__ (self , txt , variables = None ):
115
+ def __init__ (self : Self , txt : str , variables : dict [ str , dict ] | None = None ) -> None :
112
116
self ._txt = txt
113
117
self ._vars = {}
114
118
if variables is not None :
115
119
for var , config in variables .items ():
116
- generator = None
120
+ generator : _ValueRandomUniform | _ValueRandomUniformInt | _ValueIntFromFile | None = None
117
121
if "distribution" in config :
118
122
if config ["distribution" ] == "uniform" :
119
123
min = config .get ("min" , 0.0 )
@@ -131,7 +135,7 @@ def __init__(self, txt, variables=None):
131
135
raise ValueError (f"Cannot parse variable configuration { var } = { config } " )
132
136
self ._vars [var ] = generator
133
137
134
- def query (self ) :
138
+ def query (self : Self ) -> str :
135
139
"""Return next query to execute.
136
140
137
141
Returns
@@ -157,7 +161,10 @@ class Config:
157
161
List of dictionaries, cannot be empty.
158
162
"""
159
163
160
- def __init__ (self , configs ):
164
+ _config : dict
165
+ _queries : dict
166
+
167
+ def __init__ (self : Self , configs : list [dict ]) -> None :
161
168
if not configs :
162
169
raise ValueError ("empty configurations list" )
163
170
@@ -190,7 +197,7 @@ def __init__(self, configs):
190
197
raise ValueError (f"Unexpected query configuration: { qkey } : { qcfg } " )
191
198
192
199
@classmethod
193
- def from_yaml (cls , config_files ) :
200
+ def from_yaml (cls : type [ Self ] , config_files : list [ str | TextIO ]) -> Self :
194
201
"""Make configuration from bunch of YAML files
195
202
196
203
Parameters
@@ -210,7 +217,7 @@ def from_yaml(cls, config_files):
210
217
configs .append (yaml .load (file , Loader = yaml .SafeLoader ))
211
218
return cls (configs )
212
219
213
- def to_yaml (self ) :
220
+ def to_yaml (self : Self ) -> str :
214
221
"""Convert current config to YAML string.
215
222
216
223
Returns
@@ -220,7 +227,7 @@ def to_yaml(self):
220
227
"""
221
228
return yaml .dump (self ._config )
222
229
223
- def classes (self ) :
230
+ def classes (self : Self ) -> set [ str ] :
224
231
"""Return set of classes defined in configuration.
225
232
226
233
Returns
@@ -230,7 +237,7 @@ def classes(self):
230
237
"""
231
238
return self ._classes
232
239
233
- def queries (self , q_class ) :
240
+ def queries (self : Self , q_class : str ) -> dict [ str , QueryFactory ] :
234
241
"""Return queries for given class.
235
242
236
243
Parameters
@@ -246,7 +253,7 @@ def queries(self, q_class):
246
253
"""
247
254
return self ._queries [q_class ]
248
255
249
- def concurrent_queries (self , q_class ) :
256
+ def concurrent_queries (self : Self , q_class : str ) -> int :
250
257
"""Return number of concurrent queries for given class.
251
258
252
259
Parameters
@@ -261,7 +268,7 @@ def concurrent_queries(self, q_class):
261
268
"""
262
269
return self ._config ["queryClasses" ][q_class ]["concurrentQueries" ]
263
270
264
- def max_rate (self , q_class ) :
271
+ def max_rate (self : Self , q_class : str ) -> float :
265
272
"""Return maximum rate for given class.
266
273
267
274
Parameters
@@ -276,7 +283,7 @@ def max_rate(self, q_class):
276
283
"""
277
284
return self ._config ["queryClasses" ][q_class ].get ("maxRate" )
278
285
279
- def arraysize (self , q_class ) :
286
+ def arraysize (self : Self , q_class : str ) -> int :
280
287
"""Return array size for fetchmany().
281
288
282
289
Parameters
@@ -291,7 +298,7 @@ def arraysize(self, q_class):
291
298
"""
292
299
return self ._config ["queryClasses" ][q_class ].get ("arraysize" )
293
300
294
- def split (self , n_workers , i_worker ) :
301
+ def split (self : Self , n_workers : int , i_worker : int ) -> Self :
295
302
"""Divide configuration (or its workload) between number of workers.
296
303
297
304
If we want to run test with multiple workers we need to divide work
@@ -328,7 +335,7 @@ def split(self, n_workers, i_worker):
328
335
return self .__class__ ([self ._config , dict (queryClasses = overrides )])
329
336
330
337
@staticmethod
331
- def _merge (config1 , config2 ) :
338
+ def _merge (config1 : dict , config2 : dict ) -> dict :
332
339
"""Merge two config objects, return result.
333
340
334
341
If configuration present in both then second one overrides first.
0 commit comments