Skip to content

Commit a2bfbff

Browse files
committed
Applying review comments
1 parent 88fed8d commit a2bfbff

File tree

9 files changed

+177
-145
lines changed

9 files changed

+177
-145
lines changed

redis/asyncio/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ def __init__(
239239
ssl_keyfile: Optional[str] = None,
240240
ssl_certfile: Optional[str] = None,
241241
ssl_cert_reqs: Union[str, VerifyMode] = "required",
242-
ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None,
242+
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
243+
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
243244
ssl_ca_certs: Optional[str] = None,
244245
ssl_ca_data: Optional[str] = None,
245246
ssl_check_hostname: bool = True,
@@ -349,7 +350,8 @@ def __init__(
349350
"ssl_keyfile": ssl_keyfile,
350351
"ssl_certfile": ssl_certfile,
351352
"ssl_cert_reqs": ssl_cert_reqs,
352-
"ssl_verify_flags_config": ssl_verify_flags_config,
353+
"ssl_include_verify_flags": ssl_include_verify_flags,
354+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
353355
"ssl_ca_certs": ssl_ca_certs,
354356
"ssl_ca_data": ssl_ca_data,
355357
"ssl_check_hostname": ssl_check_hostname,

redis/asyncio/cluster.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,8 @@ def __init__(
300300
ssl_ca_certs: Optional[str] = None,
301301
ssl_ca_data: Optional[str] = None,
302302
ssl_cert_reqs: Union[str, VerifyMode] = "required",
303-
ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None,
303+
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
304+
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
304305
ssl_certfile: Optional[str] = None,
305306
ssl_check_hostname: bool = True,
306307
ssl_keyfile: Optional[str] = None,
@@ -360,7 +361,8 @@ def __init__(
360361
"ssl_ca_certs": ssl_ca_certs,
361362
"ssl_ca_data": ssl_ca_data,
362363
"ssl_cert_reqs": ssl_cert_reqs,
363-
"ssl_verify_flags_config": ssl_verify_flags_config,
364+
"ssl_include_verify_flags": ssl_include_verify_flags,
365+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
364366
"ssl_certfile": ssl_certfile,
365367
"ssl_check_hostname": ssl_check_hostname,
366368
"ssl_keyfile": ssl_keyfile,

redis/asyncio/connection.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import ast
21
import asyncio
32
import copy
43
import enum
54
import inspect
6-
import re
75
import socket
86
import sys
97
import warnings
@@ -796,7 +794,8 @@ def __init__(
796794
ssl_keyfile: Optional[str] = None,
797795
ssl_certfile: Optional[str] = None,
798796
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
799-
ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None,
797+
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
798+
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
800799
ssl_ca_certs: Optional[str] = None,
801800
ssl_ca_data: Optional[str] = None,
802801
ssl_check_hostname: bool = True,
@@ -811,7 +810,8 @@ def __init__(
811810
keyfile=ssl_keyfile,
812811
certfile=ssl_certfile,
813812
cert_reqs=ssl_cert_reqs,
814-
verify_flags_config=ssl_verify_flags_config,
813+
include_verify_flags=ssl_include_verify_flags,
814+
exclude_verify_flags=ssl_exclude_verify_flags,
815815
ca_certs=ssl_ca_certs,
816816
ca_data=ssl_ca_data,
817817
check_hostname=ssl_check_hostname,
@@ -838,8 +838,12 @@ def cert_reqs(self):
838838
return self.ssl_context.cert_reqs
839839

840840
@property
841-
def verify_flags_config(self):
842-
return self.ssl_context.verify_flags_config
841+
def include_verify_flags(self):
842+
return self.ssl_context.include_verify_flags
843+
844+
@property
845+
def exclude_verify_flags(self):
846+
return self.ssl_context.exclude_verify_flags
843847

844848
@property
845849
def ca_certs(self):
@@ -863,7 +867,8 @@ class RedisSSLContext:
863867
"keyfile",
864868
"certfile",
865869
"cert_reqs",
866-
"verify_flags_config",
870+
"include_verify_flags",
871+
"exclude_verify_flags",
867872
"ca_certs",
868873
"ca_data",
869874
"context",
@@ -877,7 +882,8 @@ def __init__(
877882
keyfile: Optional[str] = None,
878883
certfile: Optional[str] = None,
879884
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
880-
verify_flags_config: Optional[List[Tuple[ssl.VerifyFlags, bool]]] = None,
885+
include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
886+
exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
881887
ca_certs: Optional[str] = None,
882888
ca_data: Optional[str] = None,
883889
check_hostname: bool = False,
@@ -903,7 +909,8 @@ def __init__(
903909
)
904910
cert_reqs = CERT_REQS[cert_reqs]
905911
self.cert_reqs = cert_reqs
906-
self.verify_flags_config = verify_flags_config
912+
self.include_verify_flags = include_verify_flags
913+
self.exclude_verify_flags = exclude_verify_flags
907914
self.ca_certs = ca_certs
908915
self.ca_data = ca_data
909916
self.check_hostname = (
@@ -918,12 +925,12 @@ def get(self) -> SSLContext:
918925
context = ssl.create_default_context()
919926
context.check_hostname = self.check_hostname
920927
context.verify_mode = self.cert_reqs
921-
if self.verify_flags_config:
922-
for flag, enabled in self.verify_flags_config:
923-
if enabled:
924-
context.options |= flag
925-
else:
926-
context.options &= ~flag
928+
if self.include_verify_flags:
929+
for flag in self.include_verify_flags:
930+
context.verify_flags |= flag
931+
if self.exclude_verify_flags:
932+
for flag in self.exclude_verify_flags:
933+
context.verify_flags &= ~flag
927934
if self.certfile and self.keyfile:
928935
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
929936
if self.ca_certs or self.ca_data:
@@ -971,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
971978
return bool(value)
972979

973980

981+
def parse_ssl_verify_flags(value):
982+
# flags are passed in as a string representation of a list,
983+
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
984+
verify_flags_str = value.replace("[", "").replace("]", "")
985+
986+
verify_flags = []
987+
for flag in verify_flags_str.split(","):
988+
flag = flag.strip()
989+
if not hasattr(VerifyFlags, flag):
990+
raise ValueError(f"Invalid ssl verify flag: {flag}")
991+
verify_flags.append(getattr(VerifyFlags, flag))
992+
return verify_flags
993+
994+
974995
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
975996
{
976997
"db": int,
@@ -981,6 +1002,8 @@ def to_bool(value) -> Optional[bool]:
9811002
"max_connections": int,
9821003
"health_check_interval": int,
9831004
"ssl_check_hostname": to_bool,
1005+
"ssl_include_verify_flags": parse_ssl_verify_flags,
1006+
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
9841007
"timeout": float,
9851008
}
9861009
)
@@ -1040,33 +1063,6 @@ def parse_url(url: str) -> ConnectKwargs:
10401063
if parsed.scheme == "rediss":
10411064
kwargs["connection_class"] = SSLConnection
10421065

1043-
if "ssl_verify_flags_config" in kwargs:
1044-
# flags are passed in as a string representation of a list,
1045-
# e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)]
1046-
# To parse it successfully, we need to transform the flags to strings with quotes.
1047-
verify_flags_config_str = kwargs.pop("ssl_verify_flags_config")
1048-
# First wrap any VERIFY_* name in quotes
1049-
verify_flags_config_str = re.sub(
1050-
r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str
1051-
)
1052-
1053-
# transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag,
1054-
# and the second is a boolean that indicates if the flag should be enabled or disabled
1055-
verify_flags_config = ast.literal_eval(verify_flags_config_str)
1056-
1057-
verify_flags_config_config_parsed = []
1058-
for flag, enabled in verify_flags_config:
1059-
if not hasattr(VerifyFlags, flag):
1060-
raise ValueError(f"Invalid verify flag: {flag}")
1061-
if not isinstance(enabled, bool):
1062-
raise ValueError(
1063-
f"Invalid verify flag enabled/disabled value: {enabled}"
1064-
)
1065-
verify_flags_config_config_parsed.append(
1066-
(getattr(VerifyFlags, flag), enabled)
1067-
)
1068-
1069-
kwargs["ssl_verify_flags_config"] = verify_flags_config_config_parsed
10701066
else:
10711067
valid_schemes = "redis://, rediss://, unix://"
10721068
raise ValueError(

redis/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
Mapping,
1313
Optional,
1414
Set,
15-
Tuple,
1615
Type,
1716
Union,
1817
)
@@ -225,7 +224,8 @@ def __init__(
225224
ssl_keyfile: Optional[str] = None,
226225
ssl_certfile: Optional[str] = None,
227226
ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
228-
ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None,
227+
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
228+
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
229229
ssl_ca_certs: Optional[str] = None,
230230
ssl_ca_path: Optional[str] = None,
231231
ssl_ca_data: Optional[str] = None,
@@ -332,7 +332,8 @@ def __init__(
332332
"ssl_keyfile": ssl_keyfile,
333333
"ssl_certfile": ssl_certfile,
334334
"ssl_cert_reqs": ssl_cert_reqs,
335-
"ssl_verify_flags_config": ssl_verify_flags_config,
335+
"ssl_include_verify_flags": ssl_include_verify_flags,
336+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
336337
"ssl_ca_certs": ssl_ca_certs,
337338
"ssl_ca_data": ssl_ca_data,
338339
"ssl_check_hostname": ssl_check_hostname,

redis/cluster.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def parse_cluster_myshardid(resp, **options):
184184
"ssl_ca_data",
185185
"ssl_certfile",
186186
"ssl_cert_reqs",
187-
"ssl_verify_flags_config",
187+
"ssl_include_verify_flags",
188+
"ssl_exclude_verify_flags",
188189
"ssl_keyfile",
189190
"ssl_password",
190191
"ssl_check_hostname",

redis/connection.py

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import ast
21
import copy
32
import os
4-
import re
53
import socket
64
import sys
75
import threading
@@ -18,7 +16,6 @@
1816
List,
1917
Literal,
2018
Optional,
21-
Tuple,
2219
Type,
2320
TypeVar,
2421
Union,
@@ -1365,7 +1362,8 @@ def __init__(
13651362
ssl_keyfile=None,
13661363
ssl_certfile=None,
13671364
ssl_cert_reqs="required",
1368-
ssl_verify_flags_config: Optional[List[Tuple["VerifyFlags", bool]]] = None,
1365+
ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
1366+
ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
13691367
ssl_ca_certs=None,
13701368
ssl_ca_data=None,
13711369
ssl_check_hostname=True,
@@ -1386,17 +1384,8 @@ def __init__(
13861384
ssl_certfile: Path to an ssl certificate. Defaults to None.
13871385
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
13881386
or an ssl.VerifyMode. Defaults to "required".
1389-
ssl_verify_flags_config: A list with flags configuration to be set on the SSLContext. Defaults to None.
1390-
Valid format is as follows:
1391-
[
1392-
(config_flag, enabled/disabled),
1393-
...
1394-
]
1395-
Example:
1396-
[
1397-
(ssl.VerifyFlags.VERIFY_X509_STRICT, False), # disable strict
1398-
(ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), # ensure partial chain is enabled
1399-
]
1387+
ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
1388+
ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
14001389
ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
14011390
ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
14021391
ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
@@ -1432,7 +1421,8 @@ def __init__(
14321421
)
14331422
ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
14341423
self.cert_reqs = ssl_cert_reqs
1435-
self.ssl_verify_flags_config = ssl_verify_flags_config
1424+
self.ssl_include_verify_flags = ssl_include_verify_flags
1425+
self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
14361426
self.ca_certs = ssl_ca_certs
14371427
self.ca_data = ssl_ca_data
14381428
self.ca_path = ssl_ca_path
@@ -1472,12 +1462,12 @@ def _wrap_socket_with_ssl(self, sock):
14721462
context = ssl.create_default_context()
14731463
context.check_hostname = self.check_hostname
14741464
context.verify_mode = self.cert_reqs
1475-
if self.ssl_verify_flags_config:
1476-
for flag, enabled in self.ssl_verify_flags_config:
1477-
if enabled:
1478-
context.options |= flag
1479-
else:
1480-
context.options &= ~flag
1465+
if self.ssl_include_verify_flags:
1466+
for flag in self.ssl_include_verify_flags:
1467+
context.verify_flags |= flag
1468+
if self.ssl_exclude_verify_flags:
1469+
for flag in self.ssl_exclude_verify_flags:
1470+
context.verify_flags &= ~flag
14811471
if self.certfile or self.keyfile:
14821472
context.load_cert_chain(
14831473
certfile=self.certfile,
@@ -1591,6 +1581,20 @@ def to_bool(value):
15911581
return bool(value)
15921582

15931583

1584+
def parse_ssl_verify_flags(value):
1585+
# flags are passed in as a string representation of a list,
1586+
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1587+
verify_flags_str = value.replace("[", "").replace("]", "")
1588+
1589+
verify_flags = []
1590+
for flag in verify_flags_str.split(","):
1591+
flag = flag.strip()
1592+
if not hasattr(VerifyFlags, flag):
1593+
raise ValueError(f"Invalid ssl verify flag: {flag}")
1594+
verify_flags.append(getattr(VerifyFlags, flag))
1595+
return verify_flags
1596+
1597+
15941598
URL_QUERY_ARGUMENT_PARSERS = {
15951599
"db": int,
15961600
"socket_timeout": float,
@@ -1601,6 +1605,8 @@ def to_bool(value):
16011605
"max_connections": int,
16021606
"health_check_interval": int,
16031607
"ssl_check_hostname": to_bool,
1608+
"ssl_include_verify_flags": parse_ssl_verify_flags,
1609+
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
16041610
"timeout": float,
16051611
}
16061612

@@ -1659,34 +1665,6 @@ def parse_url(url):
16591665
if url.scheme == "rediss":
16601666
kwargs["connection_class"] = SSLConnection
16611667

1662-
if "ssl_verify_flags_config" in kwargs:
1663-
# flags are passed in as a string representation of a list,
1664-
# e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)]
1665-
# To parse it successfully, we need to transform the flags to strings with quotes.
1666-
verify_flags_config_str = kwargs.pop("ssl_verify_flags_config")
1667-
# First wrap any VERIFY_* name in quotes
1668-
verify_flags_config_str = re.sub(
1669-
r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str
1670-
)
1671-
1672-
# transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag,
1673-
# and the second is a boolean that indicates if the flag should be enabled or disabled
1674-
verify_flags_config = ast.literal_eval(verify_flags_config_str)
1675-
1676-
ssl_verify_flags_config_parsed = []
1677-
for flag, enabled in verify_flags_config:
1678-
if not hasattr(VerifyFlags, flag):
1679-
raise ValueError(f"Invalid ssl verify flag: {flag}")
1680-
if not isinstance(enabled, bool):
1681-
raise ValueError(
1682-
f"Invalid ssl verify flag enabled/disabled value: {enabled}"
1683-
)
1684-
ssl_verify_flags_config_parsed.append(
1685-
(getattr(VerifyFlags, flag), enabled)
1686-
)
1687-
1688-
kwargs["ssl_verify_flags_config"] = ssl_verify_flags_config_parsed
1689-
16901668
return kwargs
16911669

16921670

0 commit comments

Comments
 (0)