diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index d4650e1791..3defeceead 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -81,10 +81,11 @@ ) if TYPE_CHECKING and SSL_AVAILABLE: - from ssl import TLSVersion, VerifyMode + from ssl import TLSVersion, VerifyFlags, VerifyMode else: TLSVersion = None VerifyMode = None + VerifyFlags = None PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -238,6 +239,8 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = True, @@ -347,6 +350,8 @@ def __init__( "ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile, "ssl_cert_reqs": ssl_cert_reqs, + "ssl_include_verify_flags": ssl_include_verify_flags, + "ssl_exclude_verify_flags": ssl_exclude_verify_flags, "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index dd68f388b2..4e0e06517d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -86,10 +86,11 @@ ) if SSL_AVAILABLE: - from ssl import TLSVersion, VerifyMode + from ssl import TLSVersion, VerifyFlags, VerifyMode else: TLSVersion = None VerifyMode = None + VerifyFlags = None TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -299,6 +300,8 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_include_verify_flags: Optional[List[VerifyFlags]] = None, + ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None, ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = True, ssl_keyfile: Optional[str] = None, @@ -358,6 +361,8 @@ def __init__( "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_cert_reqs": ssl_cert_reqs, + "ssl_include_verify_flags": ssl_include_verify_flags, + "ssl_exclude_verify_flags": ssl_exclude_verify_flags, "ssl_certfile": ssl_certfile, "ssl_check_hostname": ssl_check_hostname, "ssl_keyfile": ssl_keyfile, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4efd868f6f..e3eb3bd9f1 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -30,11 +30,12 @@ if SSL_AVAILABLE: import ssl - from ssl import SSLContext, TLSVersion + from ssl import SSLContext, TLSVersion, VerifyFlags else: ssl = None TLSVersion = None SSLContext = None + VerifyFlags = None from ..auth.token import TokenInterface from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher @@ -793,6 +794,8 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", + ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = True, @@ -807,6 +810,8 @@ def __init__( keyfile=ssl_keyfile, certfile=ssl_certfile, cert_reqs=ssl_cert_reqs, + include_verify_flags=ssl_include_verify_flags, + exclude_verify_flags=ssl_exclude_verify_flags, ca_certs=ssl_ca_certs, ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, @@ -832,6 +837,14 @@ def certfile(self): def cert_reqs(self): return self.ssl_context.cert_reqs + @property + def include_verify_flags(self): + return self.ssl_context.include_verify_flags + + @property + def exclude_verify_flags(self): + return self.ssl_context.exclude_verify_flags + @property def ca_certs(self): return self.ssl_context.ca_certs @@ -854,6 +867,8 @@ class RedisSSLContext: "keyfile", "certfile", "cert_reqs", + "include_verify_flags", + "exclude_verify_flags", "ca_certs", "ca_data", "context", @@ -867,6 +882,8 @@ def __init__( keyfile: Optional[str] = None, certfile: Optional[str] = None, cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, + include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: bool = False, @@ -892,6 +909,8 @@ def __init__( ) cert_reqs = CERT_REQS[cert_reqs] self.cert_reqs = cert_reqs + self.include_verify_flags = include_verify_flags + self.exclude_verify_flags = exclude_verify_flags self.ca_certs = ca_certs self.ca_data = ca_data self.check_hostname = ( @@ -906,6 +925,12 @@ def get(self) -> SSLContext: context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs + if self.include_verify_flags: + for flag in self.include_verify_flags: + context.verify_flags |= flag + if self.exclude_verify_flags: + for flag in self.exclude_verify_flags: + context.verify_flags &= ~flag if self.certfile and self.keyfile: context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs or self.ca_data: @@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]: return bool(value) +def parse_ssl_verify_flags(value): + # flags are passed in as a string representation of a list, + # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN + verify_flags_str = value.replace("[", "").replace("]", "") + + verify_flags = [] + for flag in verify_flags_str.split(","): + flag = flag.strip() + if not hasattr(VerifyFlags, flag): + raise ValueError(f"Invalid ssl verify flag: {flag}") + verify_flags.append(getattr(VerifyFlags, flag)) + return verify_flags + + URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( { "db": int, @@ -963,6 +1002,8 @@ def to_bool(value) -> Optional[bool]: "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, + "ssl_include_verify_flags": parse_ssl_verify_flags, + "ssl_exclude_verify_flags": parse_ssl_verify_flags, "timeout": float, } ) @@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs: if parsed.scheme == "rediss": kwargs["connection_class"] = SSLConnection + else: valid_schemes = "redis://, rediss://, unix://" raise ValueError( diff --git a/redis/client.py b/redis/client.py index 163ef3fedc..cf4d77950f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -224,6 +224,8 @@ def __init__( ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", + ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None, ssl_ca_certs: Optional[str] = None, ssl_ca_path: Optional[str] = None, ssl_ca_data: Optional[str] = None, @@ -330,6 +332,8 @@ def __init__( "ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile, "ssl_cert_reqs": ssl_cert_reqs, + "ssl_include_verify_flags": ssl_include_verify_flags, + "ssl_exclude_verify_flags": ssl_exclude_verify_flags, "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, diff --git a/redis/cluster.py b/redis/cluster.py index 7c645be755..839721edf1 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -184,6 +184,8 @@ def parse_cluster_myshardid(resp, **options): "ssl_ca_data", "ssl_certfile", "ssl_cert_reqs", + "ssl_include_verify_flags", + "ssl_exclude_verify_flags", "ssl_keyfile", "ssl_password", "ssl_check_hostname", diff --git a/redis/connection.py b/redis/connection.py index 7c7071f635..a09156b0f3 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -68,8 +68,10 @@ if SSL_AVAILABLE: import ssl + from ssl import VerifyFlags else: ssl = None + VerifyFlags = None if HIREDIS_AVAILABLE: import hiredis @@ -1360,6 +1362,8 @@ def __init__( ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs="required", + ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None, + ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None, ssl_ca_certs=None, ssl_ca_data=None, ssl_check_hostname=True, @@ -1378,7 +1382,10 @@ def __init__( Args: ssl_keyfile: Path to an ssl private key. Defaults to None. ssl_certfile: Path to an ssl certificate. Defaults to None. - ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required". + ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), + or an ssl.VerifyMode. Defaults to "required". + ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None. + ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None. ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. @@ -1414,6 +1421,8 @@ def __init__( ) ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] self.cert_reqs = ssl_cert_reqs + self.ssl_include_verify_flags = ssl_include_verify_flags + self.ssl_exclude_verify_flags = ssl_exclude_verify_flags self.ca_certs = ssl_ca_certs self.ca_data = ssl_ca_data self.ca_path = ssl_ca_path @@ -1453,6 +1462,12 @@ def _wrap_socket_with_ssl(self, sock): context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs + if self.ssl_include_verify_flags: + for flag in self.ssl_include_verify_flags: + context.verify_flags |= flag + if self.ssl_exclude_verify_flags: + for flag in self.ssl_exclude_verify_flags: + context.verify_flags &= ~flag if self.certfile or self.keyfile: context.load_cert_chain( certfile=self.certfile, @@ -1566,6 +1581,20 @@ def to_bool(value): return bool(value) +def parse_ssl_verify_flags(value): + # flags are passed in as a string representation of a list, + # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN + verify_flags_str = value.replace("[", "").replace("]", "") + + verify_flags = [] + for flag in verify_flags_str.split(","): + flag = flag.strip() + if not hasattr(VerifyFlags, flag): + raise ValueError(f"Invalid ssl verify flag: {flag}") + verify_flags.append(getattr(VerifyFlags, flag)) + return verify_flags + + URL_QUERY_ARGUMENT_PARSERS = { "db": int, "socket_timeout": float, @@ -1576,6 +1605,8 @@ def to_bool(value): "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, + "ssl_include_verify_flags": parse_ssl_verify_flags, + "ssl_exclude_verify_flags": parse_ssl_verify_flags, "timeout": float, } diff --git a/tests/test_asyncio/test_ssl.py b/tests/test_asyncio/test_ssl.py index 75800f22de..154d20a9ea 100644 --- a/tests/test_asyncio/test_ssl.py +++ b/tests/test_asyncio/test_ssl.py @@ -1,3 +1,5 @@ +import ssl +import unittest.mock from urllib.parse import urlparse import pytest import pytest_asyncio @@ -54,3 +56,88 @@ async def test_cert_reqs_none_with_check_hostname(self, request): assert conn.check_hostname is False finally: await r.aclose() + + async def test_ssl_flags_applied_to_context(self, request): + """ + Test that ssl_include_verify_flags and ssl_exclude_verify_flags + are properly applied to the SSL context + """ + ssl_url = request.config.option.redis_ssl_url + parsed_url = urlparse(ssl_url) + + # Test with specific SSL verify flags + ssl_include_verify_flags = [ + ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, # Disable strict verification + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, # Enable partial chain + ] + + ssl_exclude_verify_flags = [ + ssl.VerifyFlags.VERIFY_X509_STRICT, # Disable trusted first + ] + + r = redis.Redis( + host=parsed_url.hostname, + port=parsed_url.port, + ssl=True, + ssl_cert_reqs="none", + ssl_include_verify_flags=ssl_include_verify_flags, + ssl_exclude_verify_flags=ssl_exclude_verify_flags, + ) + + try: + # Get the connection to trigger SSL context creation + conn = r.connection_pool.make_connection() + assert isinstance(conn, redis.SSLConnection) + + # Verify the flags were processed by checking they're stored in connection + assert conn.include_verify_flags is not None + assert len(conn.include_verify_flags) == 2 + + assert conn.exclude_verify_flags is not None + assert len(conn.exclude_verify_flags) == 1 + + # Check each flag individually + for flag in ssl_include_verify_flags: + assert flag in conn.include_verify_flags, ( + f"Flag {flag} not found in stored ssl_include_verify_flags" + ) + for flag in ssl_exclude_verify_flags: + assert flag in conn.exclude_verify_flags, ( + f"Flag {flag} not found in stored ssl_exclude_verify_flags" + ) + + # Test the actual SSL context created by the connection's RedisSSLContext + # We need to mock the ssl.create_default_context to capture the context + captured_context = None + original_create_default_context = ssl.create_default_context + + def capture_context_create_default(): + nonlocal captured_context + captured_context = original_create_default_context() + return captured_context + + with unittest.mock.patch( + "ssl.create_default_context", capture_context_create_default + ): + # Trigger SSL context creation by calling get() on the RedisSSLContext + ssl_context = conn.ssl_context.get() + + # Validate that we captured a context and it has the correct flags applied + assert captured_context is not None, "SSL context was not captured" + assert ssl_context is captured_context, ( + "Returned context should be the captured one" + ) + + # Verify that VERIFY_X509_STRICT was disabled (bit cleared) + assert not ( + captured_context.verify_flags & ssl.VerifyFlags.VERIFY_X509_STRICT + ), "VERIFY_X509_STRICT should be disabled but is enabled" + + # Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set) + assert ( + captured_context.verify_flags + & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN + ), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled" + + finally: + await r.aclose() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 1eb68d3775..2397f15600 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -19,6 +19,9 @@ ) from .test_pubsub import wait_for_message +if SSL_AVAILABLE: + import ssl + class DummyConnection: description_format = "DummyConnection<>" @@ -511,8 +514,6 @@ class MyConnection(redis.SSLConnection): assert pool.connection_class == MyConnection def test_cert_reqs_options(self): - import ssl - class DummyConnectionPool(redis.ConnectionPool): def get_connection(self): return self.make_connection() @@ -532,6 +533,65 @@ def get_connection(self): pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") assert pool.get_connection().check_hostname is True + def test_ssl_flags_config_parsing(self): + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self): + return self.make_connection() + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_include_verify_flags=VERIFY_X509_STRICT,VERIFY_CRL_CHECK_CHAIN" + ) + + assert pool.get_connection().ssl_include_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_include_verify_flags=[VERIFY_X509_STRICT, VERIFY_CRL_CHECK_CHAIN]" + ) + + assert pool.get_connection().ssl_include_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_exclude_verify_flags=VERIFY_X509_STRICT, VERIFY_CRL_CHECK_CHAIN" + ) + + assert pool.get_connection().ssl_exclude_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + + pool = DummyConnectionPool.from_url( + "rediss://?ssl_include_verify_flags=VERIFY_X509_STRICT, VERIFY_CRL_CHECK_CHAIN&ssl_exclude_verify_flags=VERIFY_CRL_CHECK_LEAF" + ) + + assert pool.get_connection().ssl_include_verify_flags == [ + ssl.VerifyFlags.VERIFY_X509_STRICT, + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, + ] + assert pool.get_connection().ssl_exclude_verify_flags == [ + ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, + ] + + def test_ssl_flags_config_invalid_flag(self): + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self): + return self.make_connection() + + with pytest.raises(ValueError): + DummyConnectionPool.from_url( + "rediss://?ssl_include_verify_flags=[VERIFY_X509,VERIFY_CRL_CHECK_CHAIN]" + ) + + with pytest.raises(ValueError): + DummyConnectionPool.from_url( + "rediss://?ssl_exclude_verify_flags=[VERIFY_X509_STRICT1, VERIFY_CRL_CHECK_CHAIN]" + ) + class TestConnection: def test_on_connect_error(self): diff --git a/tests/test_ssl.py b/tests/test_ssl.py index cb3f227629..96175d681f 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -328,3 +328,100 @@ def test_cert_reqs_none_with_check_hostname(self, request): assert conn.check_hostname is False finally: r.close() + + def test_ssl_verify_flags_applied_to_context(self, request): + """ + Test that ssl_include_verify_flags and ssl_exclude_verify_flags + are properly applied to the SSL context + """ + ssl_url = request.config.option.redis_ssl_url + parsed_url = urlparse(ssl_url) + + # Test with specific SSL verify flags + ssl_include_verify_flags = [ + ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, # Disable strict verification + ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, # Enable partial chain + ] + + ssl_exclude_verify_flags = [ + ssl.VerifyFlags.VERIFY_X509_STRICT, # Disable trusted first + ] + + r = redis.Redis( + host=parsed_url.hostname, + port=parsed_url.port, + ssl=True, + ssl_cert_reqs="none", + ssl_include_verify_flags=ssl_include_verify_flags, + ssl_exclude_verify_flags=ssl_exclude_verify_flags, + ) + + try: + # Get the connection to trigger SSL context creation + conn = r.connection_pool.get_connection() + assert isinstance(conn, redis.SSLConnection) + + # Verify the flags were processed by checking they're stored in connection + assert conn.ssl_include_verify_flags is not None + assert len(conn.ssl_include_verify_flags) == 2 + + assert conn.ssl_exclude_verify_flags is not None + assert len(conn.ssl_exclude_verify_flags) == 1 + + # Check each flag individually + for flag in ssl_include_verify_flags: + assert flag in conn.ssl_include_verify_flags, ( + f"Flag {flag} not found in stored ssl_include_verify_flags" + ) + for flag in ssl_exclude_verify_flags: + assert flag in conn.ssl_exclude_verify_flags, ( + f"Flag {flag} not found in stored ssl_exclude_verify_flags" + ) + + # Test the actual SSL context created by the connection + # We need to create a mock socket and call _wrap_socket_with_ssl to get the context + import socket + import unittest.mock + + mock_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + try: + # Mock the wrap_socket method to capture the context + captured_context = None + + def capture_context_wrap_socket(context_self, sock, **_kwargs): + nonlocal captured_context + captured_context = context_self + # Don't actually wrap the socket, just return the original socket + # to avoid connection errors + return sock + + with unittest.mock.patch.object( + ssl.SSLContext, "wrap_socket", capture_context_wrap_socket + ): + try: + conn._wrap_socket_with_ssl(mock_sock) + except Exception: + # We expect this to potentially fail since we're not actually connecting + # but we should have captured the context + pass + + # Validate that we captured a context and it has the correct flags applied + assert captured_context is not None, "SSL context was not captured" + + # Verify that VERIFY_X509_STRICT was disabled (bit cleared) + assert not ( + captured_context.verify_flags & ssl.VerifyFlags.VERIFY_X509_STRICT + ), "VERIFY_X509_STRICT should be disabled but is enabled" + + # Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set) + assert ( + captured_context.verify_flags + & ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN + ), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled" + + finally: + mock_sock.close() + + finally: + r.close()