44from dataclasses import dataclass
55from datetime import datetime , timedelta
66from operator import attrgetter
7- from typing import TYPE_CHECKING , Any , Callable , List , Optional , Tuple , Union , cast
7+ from typing import TYPE_CHECKING , Any , Callable , List , Optional , Tuple , Type , TypeVar , Union , cast
88from urllib .parse import parse_qs , urlparse
99from uuid import uuid4
1010
@@ -217,6 +217,9 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
217217 return await asyncio .gather (* [self ._get_job_def (job_id , int (score )) for job_id , score in jobs ])
218218
219219
220+ TArqRedis = TypeVar ('TArqRedis' , bound = ArqRedis )
221+
222+
220223async def create_pool (
221224 settings_ : Optional [RedisSettings ] = None ,
222225 * ,
@@ -225,7 +228,8 @@ async def create_pool(
225228 job_deserializer : Optional [Deserializer ] = None ,
226229 default_queue_name : str = default_queue_name ,
227230 expires_extra_ms : int = expires_extra_ms ,
228- ) -> ArqRedis :
231+ arq_redis_cls : Type [TArqRedis ] = ArqRedis , # type: ignore[assignment]
232+ ) -> TArqRedis :
229233 """
230234 Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
231235
@@ -238,19 +242,19 @@ async def create_pool(
238242
239243 if settings .sentinel :
240244
241- def pool_factory (* args : Any , ** kwargs : Any ) -> ArqRedis :
245+ def pool_factory (* args : Any , ** kwargs : Any ) -> TArqRedis :
242246 client = Sentinel ( # type: ignore[misc]
243247 * args ,
244248 sentinels = settings .host ,
245249 ssl = settings .ssl ,
246250 ** kwargs ,
247251 )
248- redis = client .master_for (settings .sentinel_master , redis_class = ArqRedis )
249- return cast (ArqRedis , redis )
252+ redis = client .master_for (settings .sentinel_master , redis_class = arq_redis_cls )
253+ return cast (TArqRedis , redis )
250254
251255 else :
252256 pool_factory = functools .partial (
253- ArqRedis ,
257+ arq_redis_cls ,
254258 host = settings .host ,
255259 port = settings .port ,
256260 unix_socket_path = settings .unix_socket_path ,
@@ -312,8 +316,5 @@ async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any])
312316 clients_connected = info_clients .get ('connected_clients' , '?' )
313317
314318 log_func (
315- f'redis_version={ redis_version } '
316- f'mem_usage={ mem_usage } '
317- f'clients_connected={ clients_connected } '
318- f'db_keys={ key_count } '
319+ f'redis_version={ redis_version } mem_usage={ mem_usage } clients_connected={ clients_connected } db_keys={ key_count } '
319320 )
0 commit comments