From 35668549f647a37c42545453bf9431e84f5cdb69 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 21 Aug 2025 16:18:34 +0300 Subject: [PATCH 1/4] Adding ssl_verify_flags_config argument for ssl connection configuration --- redis/asyncio/client.py | 5 +- redis/asyncio/cluster.py | 5 +- redis/asyncio/connection.py | 48 ++++++++++++++++++- redis/client.py | 3 ++ redis/cluster.py | 1 + redis/connection.py | 55 +++++++++++++++++++++- tests/test_asyncio/test_ssl.py | 78 +++++++++++++++++++++++++++++++ tests/test_connection_pool.py | 34 +++++++++++++- tests/test_ssl.py | 85 ++++++++++++++++++++++++++++++++++ 9 files changed, 308 insertions(+), 6 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index d4650e1791..ddaead5f8e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -81,10 +81,11 @@ ) if TYPE_CHECKING and SSL_AVAILABLE: - from ssl import TLSVersion, VerifyMode + from ssl import TLSVersion, VerifyFlags, VerifyMode else: TLSVersion = None VerifyMode = None + VerifyFlags = None PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -238,6 +239,7 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = True, @@ -347,6 +349,7 @@ def __init__( "ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile, "ssl_cert_reqs": ssl_cert_reqs, + "ssl_verify_flags_config": ssl_verify_flags_config, "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index dd68f388b2..40a77dd858 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -86,10 +86,11 @@ ) if SSL_AVAILABLE: - from ssl import TLSVersion, VerifyMode + from ssl import TLSVersion, VerifyFlags, VerifyMode else: TLSVersion = None VerifyMode = None + VerifyFlags = None TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -299,6 +300,7 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None, ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = True, ssl_keyfile: Optional[str] = None, @@ -358,6 +360,7 @@ def __init__( "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_cert_reqs": ssl_cert_reqs, + "ssl_verify_flags_config": ssl_verify_flags_config, "ssl_certfile": ssl_certfile, "ssl_check_hostname": ssl_check_hostname, "ssl_keyfile": ssl_keyfile, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4efd868f6f..ddbc5939d2 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1,7 +1,9 @@ +import ast import asyncio import copy import enum import inspect +import re import socket import sys import warnings @@ -30,11 +32,12 @@ if SSL_AVAILABLE: import ssl - from ssl import SSLContext, TLSVersion + from ssl import SSLContext, TLSVersion, VerifyFlags else: ssl = None TLSVersion = None SSLContext = None + VerifyFlags = None from ..auth.token import TokenInterface from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher @@ -793,6 +796,7 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", + ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = True, @@ -807,6 +811,7 @@ def __init__( keyfile=ssl_keyfile, certfile=ssl_certfile, cert_reqs=ssl_cert_reqs, + verify_flags_config=ssl_verify_flags_config, ca_certs=ssl_ca_certs, ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, @@ -832,6 +837,10 @@ def certfile(self): def cert_reqs(self): return self.ssl_context.cert_reqs + @property + def verify_flags_config(self): + return self.ssl_context.verify_flags_config + @property def ca_certs(self): return self.ssl_context.ca_certs @@ -854,6 +863,7 @@ class RedisSSLContext: "keyfile", "certfile", "cert_reqs", + "verify_flags_config", "ca_certs", "ca_data", "context", @@ -867,6 +877,7 @@ def __init__( keyfile: Optional[str] = None, certfile: Optional[str] = None, cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, + verify_flags_config: Optional[List[Tuple[ssl.VerifyFlags, bool]]] = None, ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: bool = False, @@ -892,6 +903,7 @@ def __init__( ) cert_reqs = CERT_REQS[cert_reqs] self.cert_reqs = cert_reqs + self.verify_flags_config = verify_flags_config self.ca_certs = ca_certs self.ca_data = ca_data self.check_hostname = ( @@ -906,6 +918,12 @@ def get(self) -> SSLContext: context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs + if self.verify_flags_config: + for flag, enabled in self.verify_flags_config: + if enabled: + context.options |= flag + else: + context.options &= ~flag if self.certfile and self.keyfile: context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs or self.ca_data: @@ -1021,6 +1039,34 @@ def parse_url(url: str) -> ConnectKwargs: if parsed.scheme == "rediss": kwargs["connection_class"] = SSLConnection + + if "ssl_verify_flags_config" in kwargs: + # flags are passed in as a string representation of a list, + # e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)] + # To parse it sucessfully, we need transform the flags to strings with quotes. + verify_flags_config_str = kwargs.pop("ssl_verify_flags_config") + # First wrap any VERIFY_* name in quotes + verify_flags_config_str = re.sub( + r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str + ) + + # transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag, + # and the second is a boolean that indicates if the flad should be enabled or disabled + verify_flags_config = ast.literal_eval(verify_flags_config_str) + + verify_flags_config_config_parsed = [] + for flag, enabled in verify_flags_config: + if not hasattr(VerifyFlags, flag): + raise ValueError(f"Invalid verify flag: {flag}") + if not isinstance(enabled, bool): + raise ValueError( + f"Invalid verify flag enabled/disabled value: {enabled}" + ) + verify_flags_config_config_parsed.append( + (getattr(VerifyFlags, flag), enabled) + ) + + kwargs["ssl_verify_flags_config"] = verify_flags_config_config_parsed else: valid_schemes = "redis://, rediss://, unix://" raise ValueError( diff --git a/redis/client.py b/redis/client.py index 163ef3fedc..e0173edd2f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -12,6 +12,7 @@ Mapping, Optional, Set, + Tuple, Type, Union, ) @@ -224,6 +225,7 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", + ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_path: Optional[str] = None, ssl_ca_data: Optional[str] = None, @@ -330,6 +332,7 @@ def __init__( "ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile, "ssl_cert_reqs": ssl_cert_reqs, + "ssl_verify_flags_config": ssl_verify_flags_config, "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, diff --git a/redis/cluster.py b/redis/cluster.py index 7c645be755..b0f6f87e07 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -184,6 +184,7 @@ def parse_cluster_myshardid(resp, **options): "ssl_ca_data", "ssl_certfile", "ssl_cert_reqs", + "ssl_verify_flags_config", "ssl_keyfile", "ssl_password", "ssl_check_hostname", diff --git a/redis/connection.py b/redis/connection.py index 7c7071f635..2de5fa2e92 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,5 +1,7 @@ +import ast import copy import os +import re import socket import sys import threading @@ -16,6 +18,7 @@ List, Literal, Optional, + Tuple, Type, TypeVar, Union, @@ -68,8 +71,10 @@ if SSL_AVAILABLE: import ssl + from ssl import VerifyFlags else: ssl = None + VerifyFlags = None if HIREDIS_AVAILABLE: import hiredis @@ -1360,6 +1365,7 @@ def __init__( ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs="required", + ssl_verify_flags_config: Optional[List[Tuple["VerifyFlags", bool]]] = None, ssl_ca_certs=None, ssl_ca_data=None, ssl_check_hostname=True, @@ -1378,7 +1384,19 @@ def __init__( Args: ssl_keyfile: Path to an ssl private key. Defaults to None. ssl_certfile: Path to an ssl certificate. Defaults to None. - ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required". + ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), + or an ssl.VerifyMode. Defaults to "required". + ssl_verify_flags_config: A list with flags configuration to be set on the SSLContext. Defaults to None. + Valid format is as follows: + [ + (config_flag, enabled/disabled), + ... + ] + Example: + [ + (ssl.VERIFY_X509_STRICT, False), # disable strict + (ssl.VERIFY_X509_PARTIAL_CHAIN, True), # ensure partial chain is enabled + ] ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. @@ -1414,6 +1432,7 @@ def __init__( ) ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] self.cert_reqs = ssl_cert_reqs + self.ssl_verify_flags_config = ssl_verify_flags_config self.ca_certs = ssl_ca_certs self.ca_data = ssl_ca_data self.ca_path = ssl_ca_path @@ -1453,6 +1472,12 @@ def _wrap_socket_with_ssl(self, sock): context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs + if self.ssl_verify_flags_config: + for flag, enabled in self.ssl_verify_flags_config: + if enabled: + context.options |= flag + else: + context.options &= ~flag if self.certfile or self.keyfile: context.load_cert_chain( certfile=self.certfile, @@ -1634,6 +1659,34 @@ def parse_url(url): if url.scheme == "rediss": kwargs["connection_class"] = SSLConnection + if "ssl_verify_flags_config" in kwargs: + # flags are passed in as a string representation of a list, + # e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)] + # To parse it sucessfully, we need transform the flags to strings with quotes. + verify_flags_config_str = kwargs.pop("ssl_verify_flags_config") + # First wrap any VERIFY_* name in quotes + verify_flags_config_str = re.sub( + r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str + ) + + # transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag, + # and the second is a boolean that indicates if the flad should be enabled or disabled + verify_flags_config = ast.literal_eval(verify_flags_config_str) + + ssl_verify_flags_config_parsed = [] + for flag, enabled in verify_flags_config: + if not hasattr(VerifyFlags, flag): + raise ValueError(f"Invalid ssl verify flag: {flag}") + if not isinstance(enabled, bool): + raise ValueError( + f"Invalid ssl verify flag enabled/disabled value: {enabled}" + ) + ssl_verify_flags_config_parsed.append( + (getattr(VerifyFlags, flag), enabled) + ) + + kwargs["ssl_verify_flags_config"] = ssl_verify_flags_config_parsed + return kwargs diff --git a/tests/test_asyncio/test_ssl.py b/tests/test_asyncio/test_ssl.py index 75800f22de..14f774b76f 100644 --- a/tests/test_asyncio/test_ssl.py +++ b/tests/test_asyncio/test_ssl.py @@ -1,3 +1,5 @@ +import ssl +import unittest.mock from urllib.parse import urlparse import pytest import pytest_asyncio @@ -54,3 +56,79 @@ async def test_cert_reqs_none_with_check_hostname(self, request): assert conn.check_hostname is False finally: await r.aclose() + + async def test_ssl_flags_config_applied_to_context(self, request): + """Test that ssl_flags_config is properly applied to the SSL context in async connections""" + ssl_url = request.config.option.redis_ssl_url + parsed_url = urlparse(ssl_url) + + # Test with specific SSL verify flags + ssl_verify_flags_config = [ + (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # Disable strict verification + (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), # Enable partial chain + ] + + r = redis.Redis( + host=parsed_url.hostname, + port=parsed_url.port, + ssl=True, + ssl_cert_reqs="none", + ssl_verify_flags_config=ssl_verify_flags_config, + ) + + try: + # Get the connection to trigger SSL context creation + conn = r.connection_pool.make_connection() + assert isinstance(conn, redis.SSLConnection) + + # Verify that ssl_verify_flags was stored correctly in the RedisSSLContext + assert conn.ssl_context.verify_flags_config == ssl_verify_flags_config + + # Verify the flags were processed by checking they're stored in connection + assert conn.verify_flags_config is not None + assert len(conn.verify_flags_config) == 2 + + # Check each flag individually + for flag, expected_enabled in ssl_verify_flags_config: + found = False + for stored_flag, stored_enabled in conn.verify_flags_config: + if stored_flag == flag: + assert stored_enabled == expected_enabled + found = True + break + assert found, f"Flag {flag} not found in stored ssl_verify_flags" + + # Test the actual SSL context created by the connection's RedisSSLContext + # We need to mock the ssl.create_default_context to capture the context + captured_context = None + original_create_default_context = ssl.create_default_context + + def capture_context_create_default(): + nonlocal captured_context + captured_context = original_create_default_context() + return captured_context + + with unittest.mock.patch( + "ssl.create_default_context", capture_context_create_default + ): + # Trigger SSL context creation by calling get() on the RedisSSLContext + ssl_context = conn.ssl_context.get() + + # Validate that we captured a context and it has the correct flags applied + assert captured_context is not None, "SSL context was not captured" + assert ssl_context is captured_context, ( + "Returned context should be the captured one" + ) + + # Verify that VERIFY_X509_STRICT was disabled (bit cleared) + assert not ( + captured_context.options & ssl.VerifyFlags.VERIFY_X509_STRICT + ), "VERIFY_X509_STRICT should be disabled but is enabled" + + # Verify that VERIFY_X509_PARTIAL_CHAIN was enabled (bit set) + assert ( + captured_context.options & ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN + ), "VERIFY_X509_PARTIAL_CHAIN should be enabled but is disabled" + + finally: + await r.aclose() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 1eb68d3775..6322f4aa4b 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -19,6 +19,9 @@ ) from .test_pubsub import wait_for_message +if SSL_AVAILABLE: + import ssl + class DummyConnection: description_format = "DummyConnection<>" @@ -511,8 +514,6 @@ class MyConnection(redis.SSLConnection): assert pool.connection_class == MyConnection def test_cert_reqs_options(self): - import ssl - class DummyConnectionPool(redis.ConnectionPool): def get_connection(self): return self.make_connection() @@ -532,6 +533,35 @@ def get_connection(self): pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") assert pool.get_connection().check_hostname is True + def test_ssl_flags_config_parsing(self): + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self): + return self.make_connection() + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,False), (VERIFY_X509_PARTIAL_CHAIN,True)]" + ) + + assert pool.get_connection().ssl_verify_flags_config == [ + (ssl.VerifyFlags.VERIFY_X509_STRICT, False), + (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), + ] + + def test_ssl_flags_config_invalid_flag(self): + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self): + return self.make_connection() + + with pytest.raises(ValueError): + DummyConnectionPool.from_url( + "rediss://?ssl_verify_flags_config=[(VERIFY_X509,False), (VERIFY_X509_PARTIAL_CHAIN,True)]" + ) + + with pytest.raises(ValueError): + DummyConnectionPool.from_url( + "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,Ok), (VERIFY_X509_PARTIAL_CHAIN,True)]" + ) + class TestConnection: def test_on_connect_error(self): diff --git a/tests/test_ssl.py b/tests/test_ssl.py index cb3f227629..78a6449db4 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -328,3 +328,88 @@ def test_cert_reqs_none_with_check_hostname(self, request): assert conn.check_hostname is False finally: r.close() + + def test_ssl_verify_flags_config_applied_to_context(self, request): + """Test that ssl_verify_flags_config is properly applied to the SSL context""" + ssl_url = request.config.option.redis_ssl_url + parsed_url = urlparse(ssl_url) + + # Test with specific SSL verify flags + ssl_verify_flags_config = [ + (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # Disable strict verification + (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), # Enable partial chain + ] + + r = redis.Redis( + host=parsed_url.hostname, + port=parsed_url.port, + ssl=True, + ssl_cert_reqs="none", + ssl_verify_flags_config=ssl_verify_flags_config, + ) + + try: + # Get the connection to trigger SSL context creation + conn = r.connection_pool.get_connection() + assert isinstance(conn, redis.SSLConnection) + + # Verify the flags were processed by checking they're stored in connection + assert conn.ssl_verify_flags_config is not None + assert len(conn.ssl_verify_flags_config) == 2 + + # Check each flag individually + for flag, expected_enabled in ssl_verify_flags_config: + found = False + for stored_flag, stored_enabled in conn.ssl_verify_flags_config: + if stored_flag == flag: + assert stored_enabled == expected_enabled + found = True + break + assert found, f"Flag {flag} not found in stored ssl_verify_flags_config" + + # Test the actual SSL context created by the connection + # We need to create a mock socket and call _wrap_socket_with_ssl to get the context + import socket + import unittest.mock + + mock_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + try: + # Mock the wrap_socket method to capture the context + captured_context = None + + def capture_context_wrap_socket(context_self, sock, **_kwargs): + nonlocal captured_context + captured_context = context_self + # Don't actually wrap the socket, just return the original socket + # to avoid connection errors + return sock + + with unittest.mock.patch.object( + ssl.SSLContext, "wrap_socket", capture_context_wrap_socket + ): + try: + conn._wrap_socket_with_ssl(mock_sock) + except Exception: + # We expect this to potentially fail since we're not actually connecting + # but we should have captured the context + pass + + # Validate that we captured a context and it has the correct flags applied + assert captured_context is not None, "SSL context was not captured" + + # Verify that VERIFY_X509_STRICT was disabled (bit cleared) + assert not ( + captured_context.options & ssl.VerifyFlags.VERIFY_X509_STRICT + ), "VERIFY_X509_STRICT should be disabled but is enabled" + + # Verify that VERIFY_X509_PARTIAL_CHAIN was enabled (bit set) + assert ( + captured_context.options & ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN + ), "VERIFY_X509_PARTIAL_CHAIN should be enabled but is disabled" + + finally: + mock_sock.close() + + finally: + r.close() From cb8767d4b7f6ba0691527df16d9a2c8c399fe978 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 18 Sep 2025 14:00:36 +0300 Subject: [PATCH 2/4] Changing ssl verify flag used for testing - the current one was not available in python 3.9 --- tests/test_asyncio/test_ssl.py | 8 ++++---- tests/test_connection_pool.py | 8 ++++---- tests/test_ssl.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/test_asyncio/test_ssl.py b/tests/test_asyncio/test_ssl.py index 14f774b76f..a5c8ac4b56 100644 --- a/tests/test_asyncio/test_ssl.py +++ b/tests/test_asyncio/test_ssl.py @@ -65,7 +65,7 @@ async def test_ssl_flags_config_applied_to_context(self, request): # Test with specific SSL verify flags ssl_verify_flags_config = [ (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # Disable strict verification - (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), # Enable partial chain + (ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, True), # Enable partial chain ] r = redis.Redis( @@ -125,10 +125,10 @@ def capture_context_create_default(): captured_context.options & ssl.VerifyFlags.VERIFY_X509_STRICT ), "VERIFY_X509_STRICT should be disabled but is enabled" - # Verify that VERIFY_X509_PARTIAL_CHAIN was enabled (bit set) + # Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set) assert ( - captured_context.options & ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN - ), "VERIFY_X509_PARTIAL_CHAIN should be enabled but is disabled" + captured_context.options & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN + ), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled" finally: await r.aclose() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 6322f4aa4b..2af4b4bdc8 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -539,12 +539,12 @@ def get_connection(self): return self.make_connection() pool = DummyConnectionPool.from_url( - "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,False), (VERIFY_X509_PARTIAL_CHAIN,True)]" + "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,False), (VERIFY_CRL_CHECK_CHAIN,True)]" ) assert pool.get_connection().ssl_verify_flags_config == [ (ssl.VerifyFlags.VERIFY_X509_STRICT, False), - (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), + (ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, True), ] def test_ssl_flags_config_invalid_flag(self): @@ -554,12 +554,12 @@ def get_connection(self): with pytest.raises(ValueError): DummyConnectionPool.from_url( - "rediss://?ssl_verify_flags_config=[(VERIFY_X509,False), (VERIFY_X509_PARTIAL_CHAIN,True)]" + "rediss://?ssl_verify_flags_config=[(VERIFY_X509,False), (VERIFY_CRL_CHECK_CHAIN,True)]" ) with pytest.raises(ValueError): DummyConnectionPool.from_url( - "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,Ok), (VERIFY_X509_PARTIAL_CHAIN,True)]" + "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,Ok), (VERIFY_CRL_CHECK_CHAIN,True)]" ) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 78a6449db4..666ffeac23 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -337,7 +337,7 @@ def test_ssl_verify_flags_config_applied_to_context(self, request): # Test with specific SSL verify flags ssl_verify_flags_config = [ (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # Disable strict verification - (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), # Enable partial chain + (ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, True), # Enable partial chain ] r = redis.Redis( @@ -403,10 +403,10 @@ def capture_context_wrap_socket(context_self, sock, **_kwargs): captured_context.options & ssl.VerifyFlags.VERIFY_X509_STRICT ), "VERIFY_X509_STRICT should be disabled but is enabled" - # Verify that VERIFY_X509_PARTIAL_CHAIN was enabled (bit set) + # Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set) assert ( - captured_context.options & ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN - ), "VERIFY_X509_PARTIAL_CHAIN should be enabled but is disabled" + captured_context.options & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN + ), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled" finally: mock_sock.close() From 88fed8d1fd35ed4ad84da9c8af8d1e19ff0ce155 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 18 Sep 2025 14:07:14 +0300 Subject: [PATCH 3/4] Applying Copilot's review comments --- redis/asyncio/connection.py | 4 ++-- redis/connection.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index ddbc5939d2..70cb5a6c09 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1043,7 +1043,7 @@ def parse_url(url: str) -> ConnectKwargs: if "ssl_verify_flags_config" in kwargs: # flags are passed in as a string representation of a list, # e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)] - # To parse it sucessfully, we need transform the flags to strings with quotes. + # To parse it successfully, we need to transform the flags to strings with quotes. verify_flags_config_str = kwargs.pop("ssl_verify_flags_config") # First wrap any VERIFY_* name in quotes verify_flags_config_str = re.sub( @@ -1051,7 +1051,7 @@ def parse_url(url: str) -> ConnectKwargs: ) # transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag, - # and the second is a boolean that indicates if the flad should be enabled or disabled + # and the second is a boolean that indicates if the flag should be enabled or disabled verify_flags_config = ast.literal_eval(verify_flags_config_str) verify_flags_config_config_parsed = [] diff --git a/redis/connection.py b/redis/connection.py index 2de5fa2e92..7a1a9107ef 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1394,8 +1394,8 @@ def __init__( ] Example: [ - (ssl.VERIFY_X509_STRICT, False), # disable strict - (ssl.VERIFY_X509_PARTIAL_CHAIN, True), # ensure partial chain is enabled + (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # disable strict + (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), # ensure partial chain is enabled ] ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. @@ -1662,7 +1662,7 @@ def parse_url(url): if "ssl_verify_flags_config" in kwargs: # flags are passed in as a string representation of a list, # e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)] - # To parse it sucessfully, we need transform the flags to strings with quotes. + # To parse it successfully, we need to transform the flags to strings with quotes. verify_flags_config_str = kwargs.pop("ssl_verify_flags_config") # First wrap any VERIFY_* name in quotes verify_flags_config_str = re.sub( @@ -1670,7 +1670,7 @@ def parse_url(url): ) # transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag, - # and the second is a boolean that indicates if the flad should be enabled or disabled + # and the second is a boolean that indicates if the flag should be enabled or disabled verify_flags_config = ast.literal_eval(verify_flags_config_str) ssl_verify_flags_config_parsed = [] From 91bca0816366a737a0d1e8bffd84109cac008d87 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 25 Sep 2025 13:38:59 +0300 Subject: [PATCH 4/4] Applying review comments --- redis/asyncio/client.py | 6 ++- redis/asyncio/cluster.py | 6 ++- redis/asyncio/connection.py | 80 ++++++++++++++++------------------ redis/client.py | 7 +-- redis/cluster.py | 3 +- redis/connection.py | 78 ++++++++++++--------------------- tests/test_asyncio/test_ssl.py | 51 +++++++++++++--------- tests/test_connection_pool.py | 42 +++++++++++++++--- tests/test_ssl.py | 48 ++++++++++++-------- 9 files changed, 176 insertions(+), 145 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index ddaead5f8e..3defeceead 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -239,7 +239,8 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, VerifyMode] = "required", - ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None, + ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = True, @@ -349,7 +350,8 @@ def __init__( "ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile, "ssl_cert_reqs": ssl_cert_reqs, - "ssl_verify_flags_config": ssl_verify_flags_config, + "ssl_include_verify_flags": ssl_include_verify_flags, + "ssl_exclude_verify_flags": ssl_exclude_verify_flags, "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 40a77dd858..4e0e06517d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -300,7 +300,8 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_cert_reqs: Union[str, VerifyMode] = "required", - ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None, + ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = True, ssl_keyfile: Optional[str] = None, @@ -360,7 +361,8 @@ def __init__( "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_cert_reqs": ssl_cert_reqs, - "ssl_verify_flags_config": ssl_verify_flags_config, + "ssl_include_verify_flags": ssl_include_verify_flags, + "ssl_exclude_verify_flags": ssl_exclude_verify_flags, "ssl_certfile": ssl_certfile, "ssl_check_hostname": ssl_check_hostname, "ssl_keyfile": ssl_keyfile, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 70cb5a6c09..e3eb3bd9f1 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1,9 +1,7 @@ -import ast import asyncio import copy import enum import inspect -import re import socket import sys import warnings @@ -796,7 +794,8 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", - ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None, + ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = True, @@ -811,7 +810,8 @@ def __init__( keyfile=ssl_keyfile, certfile=ssl_certfile, cert_reqs=ssl_cert_reqs, - verify_flags_config=ssl_verify_flags_config, + include_verify_flags=ssl_include_verify_flags, + exclude_verify_flags=ssl_exclude_verify_flags, ca_certs=ssl_ca_certs, ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, @@ -838,8 +838,12 @@ def cert_reqs(self): return self.ssl_context.cert_reqs @property - def verify_flags_config(self): - return self.ssl_context.verify_flags_config + def include_verify_flags(self): + return self.ssl_context.include_verify_flags + + @property + def exclude_verify_flags(self): + return self.ssl_context.exclude_verify_flags @property def ca_certs(self): @@ -863,7 +867,8 @@ class RedisSSLContext: "keyfile", "certfile", "cert_reqs", - "verify_flags_config", + "include_verify_flags", + "exclude_verify_flags", "ca_certs", "ca_data", "context", @@ -877,7 +882,8 @@ def __init__( keyfile: Optional[str] = None, certfile: Optional[str] = None, cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, - verify_flags_config: Optional[List[Tuple[ssl.VerifyFlags, bool]]] = None, + include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: bool = False, @@ -903,7 +909,8 @@ def __init__( ) cert_reqs = CERT_REQS[cert_reqs] self.cert_reqs = cert_reqs - self.verify_flags_config = verify_flags_config + self.include_verify_flags = include_verify_flags + self.exclude_verify_flags = exclude_verify_flags self.ca_certs = ca_certs self.ca_data = ca_data self.check_hostname = ( @@ -918,12 +925,12 @@ def get(self) -> SSLContext: context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs - if self.verify_flags_config: - for flag, enabled in self.verify_flags_config: - if enabled: - context.options |= flag - else: - context.options &= ~flag + if self.include_verify_flags: + for flag in self.include_verify_flags: + context.verify_flags |= flag + if self.exclude_verify_flags: + for flag in self.exclude_verify_flags: + context.verify_flags &= ~flag if self.certfile and self.keyfile: context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs or self.ca_data: @@ -971,6 +978,20 @@ def to_bool(value) -> Optional[bool]: return bool(value) +def parse_ssl_verify_flags(value): + # flags are passed in as a string representation of a list, + # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN + verify_flags_str = value.replace("[", "").replace("]", "") + + verify_flags = [] + for flag in verify_flags_str.split(","): + flag = flag.strip() + if not hasattr(VerifyFlags, flag): + raise ValueError(f"Invalid ssl verify flag: {flag}") + verify_flags.append(getattr(VerifyFlags, flag)) + return verify_flags + + URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( { "db": int, @@ -981,6 +1002,8 @@ def to_bool(value) -> Optional[bool]: "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, + "ssl_include_verify_flags": parse_ssl_verify_flags, + "ssl_exclude_verify_flags": parse_ssl_verify_flags, "timeout": float, } ) @@ -1040,33 +1063,6 @@ def parse_url(url: str) -> ConnectKwargs: if parsed.scheme == "rediss": kwargs["connection_class"] = SSLConnection - if "ssl_verify_flags_config" in kwargs: - # flags are passed in as a string representation of a list, - # e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)] - # To parse it successfully, we need to transform the flags to strings with quotes. - verify_flags_config_str = kwargs.pop("ssl_verify_flags_config") - # First wrap any VERIFY_* name in quotes - verify_flags_config_str = re.sub( - r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str - ) - - # transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag, - # and the second is a boolean that indicates if the flag should be enabled or disabled - verify_flags_config = ast.literal_eval(verify_flags_config_str) - - verify_flags_config_config_parsed = [] - for flag, enabled in verify_flags_config: - if not hasattr(VerifyFlags, flag): - raise ValueError(f"Invalid verify flag: {flag}") - if not isinstance(enabled, bool): - raise ValueError( - f"Invalid verify flag enabled/disabled value: {enabled}" - ) - verify_flags_config_config_parsed.append( - (getattr(VerifyFlags, flag), enabled) - ) - - kwargs["ssl_verify_flags_config"] = verify_flags_config_config_parsed else: valid_schemes = "redis://, rediss://, unix://" raise ValueError( diff --git a/redis/client.py b/redis/client.py index e0173edd2f..cf4d77950f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -12,7 +12,6 @@ Mapping, Optional, Set, - Tuple, Type, Union, ) @@ -225,7 +224,8 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", - ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None, + ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_path: Optional[str] = None, ssl_ca_data: Optional[str] = None, @@ -332,7 +332,8 @@ def __init__( "ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile, "ssl_cert_reqs": ssl_cert_reqs, - "ssl_verify_flags_config": ssl_verify_flags_config, + "ssl_include_verify_flags": ssl_include_verify_flags, + "ssl_exclude_verify_flags": ssl_exclude_verify_flags, "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, diff --git a/redis/cluster.py b/redis/cluster.py index b0f6f87e07..839721edf1 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -184,7 +184,8 @@ def parse_cluster_myshardid(resp, **options): "ssl_ca_data", "ssl_certfile", "ssl_cert_reqs", - "ssl_verify_flags_config", + "ssl_include_verify_flags", + "ssl_exclude_verify_flags", "ssl_keyfile", "ssl_password", "ssl_check_hostname", diff --git a/redis/connection.py b/redis/connection.py index 7a1a9107ef..a09156b0f3 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,7 +1,5 @@ -import ast import copy import os -import re import socket import sys import threading @@ -18,7 +16,6 @@ List, Literal, Optional, - Tuple, Type, TypeVar, Union, @@ -1365,7 +1362,8 @@ def __init__( ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs="required", - ssl_verify_flags_config: Optional[List[Tuple["VerifyFlags", bool]]] = None, + ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None, ssl_ca_certs=None, ssl_ca_data=None, ssl_check_hostname=True, @@ -1386,17 +1384,8 @@ def __init__( ssl_certfile: Path to an ssl certificate. Defaults to None. ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required". - ssl_verify_flags_config: A list with flags configuration to be set on the SSLContext. Defaults to None. - Valid format is as follows: - [ - (config_flag, enabled/disabled), - ... - ] - Example: - [ - (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # disable strict - (ssl.VerifyFlags.VERIFY_X509_PARTIAL_CHAIN, True), # ensure partial chain is enabled - ] + ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None. + ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None. ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. @@ -1432,7 +1421,8 @@ def __init__( ) ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] self.cert_reqs = ssl_cert_reqs - self.ssl_verify_flags_config = ssl_verify_flags_config + self.ssl_include_verify_flags = ssl_include_verify_flags + self.ssl_exclude_verify_flags = ssl_exclude_verify_flags self.ca_certs = ssl_ca_certs self.ca_data = ssl_ca_data self.ca_path = ssl_ca_path @@ -1472,12 +1462,12 @@ def _wrap_socket_with_ssl(self, sock): context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs - if self.ssl_verify_flags_config: - for flag, enabled in self.ssl_verify_flags_config: - if enabled: - context.options |= flag - else: - context.options &= ~flag + if self.ssl_include_verify_flags: + for flag in self.ssl_include_verify_flags: + context.verify_flags |= flag + if self.ssl_exclude_verify_flags: + for flag in self.ssl_exclude_verify_flags: + context.verify_flags &= ~flag if self.certfile or self.keyfile: context.load_cert_chain( certfile=self.certfile, @@ -1591,6 +1581,20 @@ def to_bool(value): return bool(value) +def parse_ssl_verify_flags(value): + # flags are passed in as a string representation of a list, + # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN + verify_flags_str = value.replace("[", "").replace("]", "") + + verify_flags = [] + for flag in verify_flags_str.split(","): + flag = flag.strip() + if not hasattr(VerifyFlags, flag): + raise ValueError(f"Invalid ssl verify flag: {flag}") + verify_flags.append(getattr(VerifyFlags, flag)) + return verify_flags + + URL_QUERY_ARGUMENT_PARSERS = { "db": int, "socket_timeout": float, @@ -1601,6 +1605,8 @@ def to_bool(value): "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, + "ssl_include_verify_flags": parse_ssl_verify_flags, + "ssl_exclude_verify_flags": parse_ssl_verify_flags, "timeout": float, } @@ -1659,34 +1665,6 @@ def parse_url(url): if url.scheme == "rediss": kwargs["connection_class"] = SSLConnection - if "ssl_verify_flags_config" in kwargs: - # flags are passed in as a string representation of a list, - # e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)] - # To parse it successfully, we need to transform the flags to strings with quotes. - verify_flags_config_str = kwargs.pop("ssl_verify_flags_config") - # First wrap any VERIFY_* name in quotes - verify_flags_config_str = re.sub( - r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str - ) - - # transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag, - # and the second is a boolean that indicates if the flag should be enabled or disabled - verify_flags_config = ast.literal_eval(verify_flags_config_str) - - ssl_verify_flags_config_parsed = [] - for flag, enabled in verify_flags_config: - if not hasattr(VerifyFlags, flag): - raise ValueError(f"Invalid ssl verify flag: {flag}") - if not isinstance(enabled, bool): - raise ValueError( - f"Invalid ssl verify flag enabled/disabled value: {enabled}" - ) - ssl_verify_flags_config_parsed.append( - (getattr(VerifyFlags, flag), enabled) - ) - - kwargs["ssl_verify_flags_config"] = ssl_verify_flags_config_parsed - return kwargs diff --git a/tests/test_asyncio/test_ssl.py b/tests/test_asyncio/test_ssl.py index a5c8ac4b56..154d20a9ea 100644 --- a/tests/test_asyncio/test_ssl.py +++ b/tests/test_asyncio/test_ssl.py @@ -57,15 +57,22 @@ async def test_cert_reqs_none_with_check_hostname(self, request): finally: await r.aclose() - async def test_ssl_flags_config_applied_to_context(self, request): - """Test that ssl_flags_config is properly applied to the SSL context in async connections""" + async def test_ssl_flags_applied_to_context(self, request): + """ + Test that ssl_include_verify_flags and ssl_exclude_verify_flags + are properly applied to the SSL context + """ ssl_url = request.config.option.redis_ssl_url parsed_url = urlparse(ssl_url) # Test with specific SSL verify flags - ssl_verify_flags_config = [ - (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # Disable strict verification - (ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, True), # Enable partial chain + ssl_include_verify_flags = [ + ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, # Disable strict verification + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, # Enable partial chain + ] + + ssl_exclude_verify_flags = [ + ssl.VerifyFlags.VERIFY_X509_STRICT, # Disable trusted first ] r = redis.Redis( @@ -73,7 +80,8 @@ async def test_ssl_flags_config_applied_to_context(self, request): port=parsed_url.port, ssl=True, ssl_cert_reqs="none", - ssl_verify_flags_config=ssl_verify_flags_config, + ssl_include_verify_flags=ssl_include_verify_flags, + ssl_exclude_verify_flags=ssl_exclude_verify_flags, ) try: @@ -81,22 +89,22 @@ async def test_ssl_flags_config_applied_to_context(self, request): conn = r.connection_pool.make_connection() assert isinstance(conn, redis.SSLConnection) - # Verify that ssl_verify_flags was stored correctly in the RedisSSLContext - assert conn.ssl_context.verify_flags_config == ssl_verify_flags_config - # Verify the flags were processed by checking they're stored in connection - assert conn.verify_flags_config is not None - assert len(conn.verify_flags_config) == 2 + assert conn.include_verify_flags is not None + assert len(conn.include_verify_flags) == 2 + + assert conn.exclude_verify_flags is not None + assert len(conn.exclude_verify_flags) == 1 # Check each flag individually - for flag, expected_enabled in ssl_verify_flags_config: - found = False - for stored_flag, stored_enabled in conn.verify_flags_config: - if stored_flag == flag: - assert stored_enabled == expected_enabled - found = True - break - assert found, f"Flag {flag} not found in stored ssl_verify_flags" + for flag in ssl_include_verify_flags: + assert flag in conn.include_verify_flags, ( + f"Flag {flag} not found in stored ssl_include_verify_flags" + ) + for flag in ssl_exclude_verify_flags: + assert flag in conn.exclude_verify_flags, ( + f"Flag {flag} not found in stored ssl_exclude_verify_flags" + ) # Test the actual SSL context created by the connection's RedisSSLContext # We need to mock the ssl.create_default_context to capture the context @@ -122,12 +130,13 @@ def capture_context_create_default(): # Verify that VERIFY_X509_STRICT was disabled (bit cleared) assert not ( - captured_context.options & ssl.VerifyFlags.VERIFY_X509_STRICT + captured_context.verify_flags & ssl.VerifyFlags.VERIFY_X509_STRICT ), "VERIFY_X509_STRICT should be disabled but is enabled" # Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set) assert ( - captured_context.options & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN + captured_context.verify_flags + & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN ), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled" finally: diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 2af4b4bdc8..2397f15600 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -539,12 +539,42 @@ def get_connection(self): return self.make_connection() pool = DummyConnectionPool.from_url( - "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,False), (VERIFY_CRL_CHECK_CHAIN,True)]" + "rediss://?ssl_include_verify_flags=VERIFY_X509_STRICT,VERIFY_CRL_CHECK_CHAIN" ) - assert pool.get_connection().ssl_verify_flags_config == [ - (ssl.VerifyFlags.VERIFY_X509_STRICT, False), - (ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, True), + assert pool.get_connection().ssl_include_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_include_verify_flags=[VERIFY_X509_STRICT, VERIFY_CRL_CHECK_CHAIN]" + ) + + assert pool.get_connection().ssl_include_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_exclude_verify_flags=VERIFY_X509_STRICT, VERIFY_CRL_CHECK_CHAIN" + ) + + assert pool.get_connection().ssl_exclude_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_include_verify_flags=VERIFY_X509_STRICT, VERIFY_CRL_CHECK_CHAIN&ssl_exclude_verify_flags=VERIFY_CRL_CHECK_LEAF" + ) + + assert pool.get_connection().ssl_include_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + assert pool.get_connection().ssl_exclude_verify_flags == [ + ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, ] def test_ssl_flags_config_invalid_flag(self): @@ -554,12 +584,12 @@ def get_connection(self): with pytest.raises(ValueError): DummyConnectionPool.from_url( - "rediss://?ssl_verify_flags_config=[(VERIFY_X509,False), (VERIFY_CRL_CHECK_CHAIN,True)]" + "rediss://?ssl_include_verify_flags=[VERIFY_X509,VERIFY_CRL_CHECK_CHAIN]" ) with pytest.raises(ValueError): DummyConnectionPool.from_url( - "rediss://?ssl_verify_flags_config=[(VERIFY_X509_STRICT,Ok), (VERIFY_CRL_CHECK_CHAIN,True)]" + "rediss://?ssl_exclude_verify_flags=[VERIFY_X509_STRICT1, VERIFY_CRL_CHECK_CHAIN]" ) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 666ffeac23..96175d681f 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -329,15 +329,22 @@ def test_cert_reqs_none_with_check_hostname(self, request): finally: r.close() - def test_ssl_verify_flags_config_applied_to_context(self, request): - """Test that ssl_verify_flags_config is properly applied to the SSL context""" + def test_ssl_verify_flags_applied_to_context(self, request): + """ + Test that ssl_include_verify_flags and ssl_exclude_verify_flags + are properly applied to the SSL context + """ ssl_url = request.config.option.redis_ssl_url parsed_url = urlparse(ssl_url) # Test with specific SSL verify flags - ssl_verify_flags_config = [ - (ssl.VerifyFlags.VERIFY_X509_STRICT, False), # Disable strict verification - (ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, True), # Enable partial chain + ssl_include_verify_flags = [ + ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, # Disable strict verification + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, # Enable partial chain + ] + + ssl_exclude_verify_flags = [ + ssl.VerifyFlags.VERIFY_X509_STRICT, # Disable trusted first ] r = redis.Redis( @@ -345,7 +352,8 @@ def test_ssl_verify_flags_config_applied_to_context(self, request): port=parsed_url.port, ssl=True, ssl_cert_reqs="none", - ssl_verify_flags_config=ssl_verify_flags_config, + ssl_include_verify_flags=ssl_include_verify_flags, + ssl_exclude_verify_flags=ssl_exclude_verify_flags, ) try: @@ -354,18 +362,21 @@ def test_ssl_verify_flags_config_applied_to_context(self, request): assert isinstance(conn, redis.SSLConnection) # Verify the flags were processed by checking they're stored in connection - assert conn.ssl_verify_flags_config is not None - assert len(conn.ssl_verify_flags_config) == 2 + assert conn.ssl_include_verify_flags is not None + assert len(conn.ssl_include_verify_flags) == 2 + + assert conn.ssl_exclude_verify_flags is not None + assert len(conn.ssl_exclude_verify_flags) == 1 # Check each flag individually - for flag, expected_enabled in ssl_verify_flags_config: - found = False - for stored_flag, stored_enabled in conn.ssl_verify_flags_config: - if stored_flag == flag: - assert stored_enabled == expected_enabled - found = True - break - assert found, f"Flag {flag} not found in stored ssl_verify_flags_config" + for flag in ssl_include_verify_flags: + assert flag in conn.ssl_include_verify_flags, ( + f"Flag {flag} not found in stored ssl_include_verify_flags" + ) + for flag in ssl_exclude_verify_flags: + assert flag in conn.ssl_exclude_verify_flags, ( + f"Flag {flag} not found in stored ssl_exclude_verify_flags" + ) # Test the actual SSL context created by the connection # We need to create a mock socket and call _wrap_socket_with_ssl to get the context @@ -400,12 +411,13 @@ def capture_context_wrap_socket(context_self, sock, **_kwargs): # Verify that VERIFY_X509_STRICT was disabled (bit cleared) assert not ( - captured_context.options & ssl.VerifyFlags.VERIFY_X509_STRICT + captured_context.verify_flags & ssl.VerifyFlags.VERIFY_X509_STRICT ), "VERIFY_X509_STRICT should be disabled but is enabled" # Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set) assert ( - captured_context.options & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN + captured_context.verify_flags + & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN ), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled" finally: