Skip to content

Commit 215eb12

Browse files
authored
Merge branch 'master' into ps_extract_hitless_related_connection_management_logic_in_mixin_classes
2 parents e82d5de + 768ad9c commit 215eb12

File tree

4 files changed

+48
-50
lines changed

4 files changed

+48
-50
lines changed

redis/commands/search/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import redis
1+
from redis.client import Pipeline as RedisPipeline
22

33
from ...asyncio.client import Pipeline as AsyncioPipeline
44
from .commands import (
@@ -181,7 +181,7 @@ def pipeline(self, transaction=True, shard_hint=None):
181181
return p
182182

183183

184-
class Pipeline(SearchCommands, redis.client.Pipeline):
184+
class Pipeline(SearchCommands, RedisPipeline):
185185
"""Pipeline for the module."""
186186

187187

redis/commands/search/aggregation.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Union
1+
from typing import List, Optional, Tuple, Union
22

33
from redis.commands.search.dialect import DEFAULT_DIALECT
44

@@ -27,9 +27,9 @@ class Reducer:
2727
NAME = None
2828

2929
def __init__(self, *args: str) -> None:
30-
self._args = args
31-
self._field = None
32-
self._alias = None
30+
self._args: Tuple[str, ...] = args
31+
self._field: Optional[str] = None
32+
self._alias: Optional[str] = None
3333

3434
def alias(self, alias: str) -> "Reducer":
3535
"""
@@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer":
4949
if alias is FIELDNAME:
5050
if not self._field:
5151
raise ValueError("Cannot use FIELDNAME alias with no field")
52-
# Chop off initial '@'
53-
alias = self._field[1:]
52+
else:
53+
# Chop off initial '@'
54+
alias = self._field[1:]
5455
self._alias = alias
5556
return self
5657

5758
@property
58-
def args(self) -> List[str]:
59+
def args(self) -> Tuple[str, ...]:
5960
return self._args
6061

6162

@@ -64,7 +65,7 @@ class SortDirection:
6465
This special class is used to indicate sort direction.
6566
"""
6667

67-
DIRSTRING = None
68+
DIRSTRING: Optional[str] = None
6869

6970
def __init__(self, field: str) -> None:
7071
self.field = field
@@ -104,17 +105,17 @@ def __init__(self, query: str = "*") -> None:
104105
All member methods (except `build_args()`)
105106
return the object itself, making them useful for chaining.
106107
"""
107-
self._query = query
108-
self._aggregateplan = []
109-
self._loadfields = []
110-
self._loadall = False
111-
self._max = 0
112-
self._with_schema = False
113-
self._verbatim = False
114-
self._cursor = []
115-
self._dialect = DEFAULT_DIALECT
116-
self._add_scores = False
117-
self._scorer = "TFIDF"
108+
self._query: str = query
109+
self._aggregateplan: List[str] = []
110+
self._loadfields: List[str] = []
111+
self._loadall: bool = False
112+
self._max: int = 0
113+
self._with_schema: bool = False
114+
self._verbatim: bool = False
115+
self._cursor: List[str] = []
116+
self._dialect: int = DEFAULT_DIALECT
117+
self._add_scores: bool = False
118+
self._scorer: str = "TFIDF"
118119

119120
def load(self, *fields: str) -> "AggregateRequest":
120121
"""
@@ -133,7 +134,7 @@ def load(self, *fields: str) -> "AggregateRequest":
133134
return self
134135

135136
def group_by(
136-
self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
137+
self, fields: Union[str, List[str]], *reducers: Reducer
137138
) -> "AggregateRequest":
138139
"""
139140
Specify by which fields to group the aggregation.
@@ -147,7 +148,6 @@ def group_by(
147148
`aggregation` module.
148149
"""
149150
fields = [fields] if isinstance(fields, str) else fields
150-
reducers = [reducers] if isinstance(reducers, Reducer) else reducers
151151

152152
ret = ["GROUPBY", str(len(fields)), *fields]
153153
for reducer in reducers:
@@ -251,12 +251,10 @@ def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
251251
.sort_by(Desc("@paid"), max=10)
252252
```
253253
"""
254-
if isinstance(fields, (str, SortDirection)):
255-
fields = [fields]
256254

257255
fields_args = []
258256
for f in fields:
259-
if isinstance(f, SortDirection):
257+
if isinstance(f, (Asc, Desc)):
260258
fields_args += [f.field, f.DIRSTRING]
261259
else:
262260
fields_args += [f]
@@ -356,7 +354,7 @@ def build_args(self) -> List[str]:
356354
ret.extend(self._loadfields)
357355

358356
if self._dialect:
359-
ret.extend(["DIALECT", self._dialect])
357+
ret.extend(["DIALECT", str(self._dialect)])
360358

361359
ret.extend(self._aggregateplan)
362360

@@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None:
393391
self.cursor = cursor
394392
self.schema = schema
395393

396-
def __repr__(self) -> (str, str):
394+
def __repr__(self) -> str:
397395
cid = self.cursor.cid if self.cursor else -1
398396
return (
399397
f"<{self.__class__.__name__} at 0x{id(self):x} "

redis/commands/search/commands.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def create_index(
221221

222222
return self.execute_command(*args)
223223

224-
def alter_schema_add(self, fields: List[str]):
224+
def alter_schema_add(self, fields: Union[Field, List[Field]]):
225225
"""
226226
Alter the existing search index by adding new fields. The index
227227
must already exist.
@@ -336,11 +336,11 @@ def add_document(
336336
doc_id: str,
337337
nosave: bool = False,
338338
score: float = 1.0,
339-
payload: bool = None,
339+
payload: Optional[bool] = None,
340340
replace: bool = False,
341341
partial: bool = False,
342342
language: Optional[str] = None,
343-
no_create: str = False,
343+
no_create: bool = False,
344344
**fields: List[str],
345345
):
346346
"""
@@ -464,7 +464,7 @@ def info(self):
464464
return self._parse_results(INFO_CMD, res)
465465

466466
def get_params_args(
467-
self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
467+
self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
468468
):
469469
if query_params is None:
470470
return []
@@ -478,7 +478,7 @@ def get_params_args(
478478
return args
479479

480480
def _mk_query_args(
481-
self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
481+
self, query, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
482482
):
483483
args = [self.index_name]
484484

@@ -528,7 +528,7 @@ def search(
528528
def explain(
529529
self,
530530
query: Union[str, Query],
531-
query_params: Dict[str, Union[str, int, float]] = None,
531+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
532532
):
533533
"""Returns the execution plan for a complex query.
534534
@@ -543,7 +543,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa
543543
def aggregate(
544544
self,
545545
query: Union[AggregateRequest, Cursor],
546-
query_params: Dict[str, Union[str, int, float]] = None,
546+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
547547
):
548548
"""
549549
Issue an aggregation query.
@@ -598,7 +598,7 @@ def profile(
598598
self,
599599
query: Union[Query, AggregateRequest],
600600
limited: bool = False,
601-
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
601+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
602602
):
603603
"""
604604
Performs a search or aggregate command and collects performance
@@ -936,7 +936,7 @@ async def info(self):
936936
async def search(
937937
self,
938938
query: Union[str, Query],
939-
query_params: Dict[str, Union[str, int, float]] = None,
939+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
940940
):
941941
"""
942942
Search the index for a given query, and return a result of documents
@@ -968,7 +968,7 @@ async def search(
968968
async def aggregate(
969969
self,
970970
query: Union[AggregateResult, Cursor],
971-
query_params: Dict[str, Union[str, int, float]] = None,
971+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
972972
):
973973
"""
974974
Issue an aggregation query.

redis/commands/search/query.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Union
1+
from typing import List, Optional, Tuple, Union
22

33
from redis.commands.search.dialect import DEFAULT_DIALECT
44

@@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None:
3131
self._with_scores: bool = False
3232
self._scorer: Optional[str] = None
3333
self._filters: List = list()
34-
self._ids: Optional[List[str]] = None
34+
self._ids: Optional[Tuple[str, ...]] = None
3535
self._slop: int = -1
3636
self._timeout: Optional[float] = None
3737
self._in_order: bool = False
@@ -81,7 +81,7 @@ def return_field(
8181
self._return_fields += ("AS", as_field)
8282
return self
8383

84-
def _mk_field_list(self, fields: List[str]) -> List:
84+
def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List:
8585
if not fields:
8686
return []
8787
return [fields] if isinstance(fields, str) else list(fields)
@@ -126,7 +126,7 @@ def summarize(
126126

127127
def highlight(
128128
self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
129-
) -> None:
129+
) -> "Query":
130130
"""
131131
Apply specified markup to matched term(s) within the returned field(s).
132132
@@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query":
187187
self._scorer = scorer
188188
return self
189189

190-
def get_args(self) -> List[str]:
190+
def get_args(self) -> List[Union[str, int, float]]:
191191
"""Format the redis arguments for this query and return them."""
192-
args = [self._query_string]
192+
args: List[Union[str, int, float]] = [self._query_string]
193193
args += self._get_args_tags()
194194
args += self._summarize_fields + self._highlight_fields
195195
args += ["LIMIT", self._offset, self._num]
196196
return args
197197

198-
def _get_args_tags(self) -> List[str]:
199-
args = []
198+
def _get_args_tags(self) -> List[Union[str, int, float]]:
199+
args: List[Union[str, int, float]] = []
200200
if self._no_content:
201201
args.append("NOCONTENT")
202202
if self._fields:
@@ -288,14 +288,14 @@ def with_scores(self) -> "Query":
288288
self._with_scores = True
289289
return self
290290

291-
def limit_fields(self, *fields: List[str]) -> "Query":
291+
def limit_fields(self, *fields: str) -> "Query":
292292
"""
293293
Limit the search to specific TEXT fields only.
294294
295-
- **fields**: A list of strings; case-sensitive field names
295+
- **fields**: Each element should be a string, case sensitive field name
296296
from the defined schema.
297297
"""
298-
self._fields = fields
298+
self._fields = list(fields)
299299
return self
300300

301301
def add_filter(self, flt: "Filter") -> "Query":
@@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query":
340340

341341

342342
class Filter:
343-
def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
343+
def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None:
344344
self.args = [keyword, field] + list(args)
345345

346346

0 commit comments

Comments
 (0)