Skip to content

Commit f291329

Browse files
authored
Merge branch 'master' into ps_fix_evalsha_typehints
2 parents 4d5d695 + 768ad9c commit f291329

File tree

6 files changed

+61
-65
lines changed

6 files changed

+61
-65
lines changed

redis/commands/core.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ def command_list(
830830

831831
return self.execute_command("COMMAND LIST", *pieces)
832832

833-
def command_getkeysandflags(self, *args: List[str]) -> List[Union[str, List[str]]]:
833+
def command_getkeysandflags(self, *args: str) -> List[Union[str, List[str]]]:
834834
"""
835835
Returns array of keys from a full Redis command and their usage flags.
836836
@@ -848,7 +848,7 @@ def command_docs(self, *args):
848848
)
849849

850850
def config_get(
851-
self, pattern: PatternT = "*", *args: List[PatternT], **kwargs
851+
self, pattern: PatternT = "*", *args: PatternT, **kwargs
852852
) -> ResponseT:
853853
"""
854854
Return a dictionary of configuration based on the ``pattern``
@@ -861,7 +861,7 @@ def config_set(
861861
self,
862862
name: KeyT,
863863
value: EncodableT,
864-
*args: List[Union[KeyT, EncodableT]],
864+
*args: Union[KeyT, EncodableT],
865865
**kwargs,
866866
) -> ResponseT:
867867
"""Set config item ``name`` with ``value``
@@ -987,9 +987,7 @@ def select(self, index: int, **kwargs) -> ResponseT:
987987
"""
988988
return self.execute_command("SELECT", index, **kwargs)
989989

990-
def info(
991-
self, section: Optional[str] = None, *args: List[str], **kwargs
992-
) -> ResponseT:
990+
def info(self, section: Optional[str] = None, *args: str, **kwargs) -> ResponseT:
993991
"""
994992
Returns a dictionary containing information about the Redis server
995993
@@ -2606,7 +2604,7 @@ def blmpop(
26062604
self,
26072605
timeout: float,
26082606
numkeys: int,
2609-
*args: List[str],
2607+
*args: str,
26102608
direction: str,
26112609
count: Optional[int] = 1,
26122610
) -> Optional[list]:
@@ -2619,14 +2617,14 @@ def blmpop(
26192617
26202618
For more information, see https://redis.io/commands/blmpop
26212619
"""
2622-
args = [timeout, numkeys, *args, direction, "COUNT", count]
2620+
cmd_args = [timeout, numkeys, *args, direction, "COUNT", count]
26232621

2624-
return self.execute_command("BLMPOP", *args)
2622+
return self.execute_command("BLMPOP", *cmd_args)
26252623

26262624
def lmpop(
26272625
self,
26282626
num_keys: int,
2629-
*args: List[str],
2627+
*args: str,
26302628
direction: str,
26312629
count: Optional[int] = 1,
26322630
) -> Union[Awaitable[list], list]:
@@ -2636,11 +2634,11 @@ def lmpop(
26362634
26372635
For more information, see https://redis.io/commands/lmpop
26382636
"""
2639-
args = [num_keys] + list(args) + [direction]
2637+
cmd_args = [num_keys] + list(args) + [direction]
26402638
if count != 1:
2641-
args.extend(["COUNT", count])
2639+
cmd_args.extend(["COUNT", count])
26422640

2643-
return self.execute_command("LMPOP", *args)
2641+
return self.execute_command("LMPOP", *cmd_args)
26442642

26452643
def lindex(
26462644
self, name: str, index: int

redis/commands/json/commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class JSONCommands:
1414
"""json commands."""
1515

1616
def arrappend(
17-
self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType]
17+
self, name: str, path: Optional[str] = Path.root_path(), *args: JsonType
1818
) -> List[Optional[int]]:
1919
"""Append the objects ``args`` to the array under the
2020
``path` in key ``name``.
@@ -52,7 +52,7 @@ def arrindex(
5252
return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])
5353

5454
def arrinsert(
55-
self, name: str, path: str, index: int, *args: List[JsonType]
55+
self, name: str, path: str, index: int, *args: JsonType
5656
) -> List[Optional[int]]:
5757
"""Insert the objects ``args`` to the array at index ``index``
5858
under the ``path` in key ``name``.

redis/commands/search/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import redis
1+
from redis.client import Pipeline as RedisPipeline
22

33
from ...asyncio.client import Pipeline as AsyncioPipeline
44
from .commands import (
@@ -181,7 +181,7 @@ def pipeline(self, transaction=True, shard_hint=None):
181181
return p
182182

183183

184-
class Pipeline(SearchCommands, redis.client.Pipeline):
184+
class Pipeline(SearchCommands, RedisPipeline):
185185
"""Pipeline for the module."""
186186

187187

redis/commands/search/aggregation.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Union
1+
from typing import List, Optional, Tuple, Union
22

33
from redis.commands.search.dialect import DEFAULT_DIALECT
44

@@ -27,9 +27,9 @@ class Reducer:
2727
NAME = None
2828

2929
def __init__(self, *args: str) -> None:
30-
self._args = args
31-
self._field = None
32-
self._alias = None
30+
self._args: Tuple[str, ...] = args
31+
self._field: Optional[str] = None
32+
self._alias: Optional[str] = None
3333

3434
def alias(self, alias: str) -> "Reducer":
3535
"""
@@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer":
4949
if alias is FIELDNAME:
5050
if not self._field:
5151
raise ValueError("Cannot use FIELDNAME alias with no field")
52-
# Chop off initial '@'
53-
alias = self._field[1:]
52+
else:
53+
# Chop off initial '@'
54+
alias = self._field[1:]
5455
self._alias = alias
5556
return self
5657

5758
@property
58-
def args(self) -> List[str]:
59+
def args(self) -> Tuple[str, ...]:
5960
return self._args
6061

6162

@@ -64,7 +65,7 @@ class SortDirection:
6465
This special class is used to indicate sort direction.
6566
"""
6667

67-
DIRSTRING = None
68+
DIRSTRING: Optional[str] = None
6869

6970
def __init__(self, field: str) -> None:
7071
self.field = field
@@ -104,17 +105,17 @@ def __init__(self, query: str = "*") -> None:
104105
All member methods (except `build_args()`)
105106
return the object itself, making them useful for chaining.
106107
"""
107-
self._query = query
108-
self._aggregateplan = []
109-
self._loadfields = []
110-
self._loadall = False
111-
self._max = 0
112-
self._with_schema = False
113-
self._verbatim = False
114-
self._cursor = []
115-
self._dialect = DEFAULT_DIALECT
116-
self._add_scores = False
117-
self._scorer = "TFIDF"
108+
self._query: str = query
109+
self._aggregateplan: List[str] = []
110+
self._loadfields: List[str] = []
111+
self._loadall: bool = False
112+
self._max: int = 0
113+
self._with_schema: bool = False
114+
self._verbatim: bool = False
115+
self._cursor: List[str] = []
116+
self._dialect: int = DEFAULT_DIALECT
117+
self._add_scores: bool = False
118+
self._scorer: str = "TFIDF"
118119

119120
def load(self, *fields: str) -> "AggregateRequest":
120121
"""
@@ -133,7 +134,7 @@ def load(self, *fields: str) -> "AggregateRequest":
133134
return self
134135

135136
def group_by(
136-
self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
137+
self, fields: Union[str, List[str]], *reducers: Reducer
137138
) -> "AggregateRequest":
138139
"""
139140
Specify by which fields to group the aggregation.
@@ -147,7 +148,6 @@ def group_by(
147148
`aggregation` module.
148149
"""
149150
fields = [fields] if isinstance(fields, str) else fields
150-
reducers = [reducers] if isinstance(reducers, Reducer) else reducers
151151

152152
ret = ["GROUPBY", str(len(fields)), *fields]
153153
for reducer in reducers:
@@ -251,12 +251,10 @@ def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
251251
.sort_by(Desc("@paid"), max=10)
252252
```
253253
"""
254-
if isinstance(fields, (str, SortDirection)):
255-
fields = [fields]
256254

257255
fields_args = []
258256
for f in fields:
259-
if isinstance(f, SortDirection):
257+
if isinstance(f, (Asc, Desc)):
260258
fields_args += [f.field, f.DIRSTRING]
261259
else:
262260
fields_args += [f]
@@ -356,7 +354,7 @@ def build_args(self) -> List[str]:
356354
ret.extend(self._loadfields)
357355

358356
if self._dialect:
359-
ret.extend(["DIALECT", self._dialect])
357+
ret.extend(["DIALECT", str(self._dialect)])
360358

361359
ret.extend(self._aggregateplan)
362360

@@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None:
393391
self.cursor = cursor
394392
self.schema = schema
395393

396-
def __repr__(self) -> (str, str):
394+
def __repr__(self) -> str:
397395
cid = self.cursor.cid if self.cursor else -1
398396
return (
399397
f"<{self.__class__.__name__} at 0x{id(self):x} "

redis/commands/search/commands.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def create_index(
221221

222222
return self.execute_command(*args)
223223

224-
def alter_schema_add(self, fields: List[str]):
224+
def alter_schema_add(self, fields: Union[Field, List[Field]]):
225225
"""
226226
Alter the existing search index by adding new fields. The index
227227
must already exist.
@@ -336,11 +336,11 @@ def add_document(
336336
doc_id: str,
337337
nosave: bool = False,
338338
score: float = 1.0,
339-
payload: bool = None,
339+
payload: Optional[bool] = None,
340340
replace: bool = False,
341341
partial: bool = False,
342342
language: Optional[str] = None,
343-
no_create: str = False,
343+
no_create: bool = False,
344344
**fields: List[str],
345345
):
346346
"""
@@ -464,7 +464,7 @@ def info(self):
464464
return self._parse_results(INFO_CMD, res)
465465

466466
def get_params_args(
467-
self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
467+
self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
468468
):
469469
if query_params is None:
470470
return []
@@ -478,7 +478,7 @@ def get_params_args(
478478
return args
479479

480480
def _mk_query_args(
481-
self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
481+
self, query, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
482482
):
483483
args = [self.index_name]
484484

@@ -528,7 +528,7 @@ def search(
528528
def explain(
529529
self,
530530
query: Union[str, Query],
531-
query_params: Dict[str, Union[str, int, float]] = None,
531+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
532532
):
533533
"""Returns the execution plan for a complex query.
534534
@@ -543,7 +543,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa
543543
def aggregate(
544544
self,
545545
query: Union[AggregateRequest, Cursor],
546-
query_params: Dict[str, Union[str, int, float]] = None,
546+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
547547
):
548548
"""
549549
Issue an aggregation query.
@@ -598,7 +598,7 @@ def profile(
598598
self,
599599
query: Union[Query, AggregateRequest],
600600
limited: bool = False,
601-
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
601+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
602602
):
603603
"""
604604
Performs a search or aggregate command and collects performance
@@ -936,7 +936,7 @@ async def info(self):
936936
async def search(
937937
self,
938938
query: Union[str, Query],
939-
query_params: Dict[str, Union[str, int, float]] = None,
939+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
940940
):
941941
"""
942942
Search the index for a given query, and return a result of documents
@@ -968,7 +968,7 @@ async def search(
968968
async def aggregate(
969969
self,
970970
query: Union[AggregateResult, Cursor],
971-
query_params: Dict[str, Union[str, int, float]] = None,
971+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
972972
):
973973
"""
974974
Issue an aggregation query.

0 commit comments

Comments
 (0)