diff --git a/redis/connection.py b/redis/connection.py index a09156b0f3..837fccd40e 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -688,8 +688,12 @@ def on_connect_check_health(self, check_health: bool = True): ): raise ConnectionError("Invalid RESP version") - # Send maintenance notifications handshake if RESP3 is active and maintenance notifications are enabled + # Send maintenance notifications handshake if RESP3 is active + # and maintenance notifications are enabled # and we have a host to determine the endpoint type from + # When the maint_notifications_config enabled mode is "auto", + # we just log a warning if the handshake fails + # When the mode is enabled=True, we raise an exception in case of failure if ( self.protocol not in [2, "2"] and self.maint_notifications_config @@ -711,15 +715,21 @@ def on_connect_check_health(self, check_health: bool = True): ) response = self.read_response() if str_if_bytes(response) != "OK": - raise ConnectionError( + raise ResponseError( "The server doesn't support maintenance notifications" ) except Exception as e: - # Log warning but don't fail the connection - import logging + if ( + isinstance(e, ResponseError) + and self.maint_notifications_config.enabled == "auto" + ): + # Log warning but don't fail the connection + import logging - logger = logging.getLogger(__name__) - logger.warning(f"Failed to enable maintenance notifications: {e}") + logger = logging.getLogger(__name__) + logger.warning(f"Failed to enable maintenance notifications: {e}") + else: + raise # if a client_name is given, set it if self.client_name: diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index eeb77299ef..37e4f93a3f 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -5,7 +5,7 @@ import threading import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union from redis.typing import Number @@ -447,7 +447,7 @@ class MaintNotificationsConfig: def __init__( self, - enabled: bool = True, + enabled: Union[bool, Literal["auto"]] = "auto", proactive_reconnect: bool = True, relaxed_timeout: Optional[Number] = 10, endpoint_type: Optional[EndpointType] = None, @@ -456,8 +456,13 @@ def __init__( Initialize a new MaintNotificationsConfig. Args: - enabled (bool): Whether to enable maintenance notifications handling. - Defaults to False. + enabled (bool | "auto"): Controls maintenance notifications handling behavior. + - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup, + otherwise a ResponseError is raised. + - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are + gracefully handled - a warning is logged and normal operation continues. + - False: Maintenance notifications are completely disabled. + Defaults to "auto". proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced. Defaults to True. relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance. diff --git a/tests/test_maint_notifications.py b/tests/test_maint_notifications.py index b365189b7b..08ac15368f 100644 --- a/tests/test_maint_notifications.py +++ b/tests/test_maint_notifications.py @@ -387,7 +387,7 @@ class TestMaintNotificationsConfig: def test_init_defaults(self): """Test MaintNotificationsConfig initialization with defaults.""" config = MaintNotificationsConfig() - assert config.enabled is True + assert config.enabled == "auto" assert config.proactive_reconnect is True assert config.relaxed_timeout == 10 diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index 38f614bb90..baa7d601fa 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -13,7 +13,9 @@ BlockingConnectionPool, MaintenanceState, ) +from redis.exceptions import ResponseError from redis.maint_notifications import ( + EndpointType, MaintNotificationsConfig, NodeMigratingNotification, NodeMigratedNotification, @@ -201,6 +203,10 @@ def send(self, data): if b"HELLO" in data: response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" self.pending_responses.append(response) + elif b"MAINT_NOTIFICATIONS" in data and b"internal-ip" in data: + # Simulate error response - activate it only for internal-ip tests + response = b"+ERROR\r\n" + self.pending_responses.append(response) elif b"SET" in data: response = b"+OK\r\n" @@ -337,8 +343,8 @@ def shutdown(self, how): pass -class TestMaintenanceNotificationsHandlingSingleProxy: - """Integration tests for maintenance notifications handling with real connection pool.""" +class TestMaintenanceNotificationsBase: + """Base class for maintenance notifications handling tests.""" def setup_method(self): """Set up test fixtures with mocked sockets.""" @@ -393,7 +399,7 @@ def _get_client( pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool) max_connections: Maximum number of connections in the pool (default: 10) maint_notifications_config: Optional MaintNotificationsConfig to use. If not provided, - uses self.config from setup_method (default: None) + uses self.config from setup_method (default: None) setup_pool_handler: Whether to set up pool handler for moving notifications (default: False) Returns: @@ -425,6 +431,71 @@ def _get_client( return test_redis_client + +class TestMaintenanceNotificationsHandshake(TestMaintenanceNotificationsBase): + """Integration tests for maintenance notifications handling with real connection pool.""" + + def test_handshake_success_when_enabled(self): + """Test that handshake is performed correctly.""" + maint_notifications_config = MaintNotificationsConfig( + enabled=True, endpoint_type=EndpointType.EXTERNAL_IP + ) + test_redis_client = self._get_client( + ConnectionPool, maint_notifications_config=maint_notifications_config + ) + + try: + # Perform Redis operations that should work with our improved mock responses + result_set = test_redis_client.set("hello", "world") + result_get = test_redis_client.get("hello") + + # Verify operations completed successfully + assert result_set is True + assert result_get == b"world" + + finally: + test_redis_client.close() + + def test_handshake_success_when_auto_and_command_not_supported(self): + """Test that when maintenance notifications are set to 'auto', the client gracefully handles unsupported MAINT_NOTIFICATIONS commands and normal Redis operations succeed.""" + maint_notifications_config = MaintNotificationsConfig( + enabled="auto", endpoint_type=EndpointType.INTERNAL_IP + ) + test_redis_client = self._get_client( + ConnectionPool, maint_notifications_config=maint_notifications_config + ) + + try: + # Perform Redis operations that should work with our improved mock responses + result_set = test_redis_client.set("hello", "world") + result_get = test_redis_client.get("hello") + + # Verify operations completed successfully + assert result_set is True + assert result_get == b"world" + + finally: + test_redis_client.close() + + def test_handshake_failure_when_enabled(self): + """Test that handshake is performed correctly.""" + maint_notifications_config = MaintNotificationsConfig( + enabled=True, endpoint_type=EndpointType.INTERNAL_IP + ) + test_redis_client = self._get_client( + ConnectionPool, maint_notifications_config=maint_notifications_config + ) + try: + with pytest.raises(ResponseError): + test_redis_client.set("hello", "world") + + finally: + test_redis_client.close() + + +class TestMaintenanceNotificationsHandlingSingleProxy(TestMaintenanceNotificationsBase): + """Integration tests for maintenance notifications handling with real connection pool.""" + def _validate_connection_handlers(self, conn, pool_handler, config): """Helper method to validate connection handlers are properly set.""" # Test that the node moving handler function is correctly set @@ -1891,40 +1962,16 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): pool.disconnect() -class TestMaintenanceNotificationsHandlingMultipleProxies: +class TestMaintenanceNotificationsHandlingMultipleProxies( + TestMaintenanceNotificationsBase +): """Integration tests for maintenance notifications handling with real connection pool.""" def setup_method(self): """Set up test fixtures with mocked sockets.""" - self.mock_sockets = [] - self.original_socket = socket.socket + super().setup_method() self.orig_host = "test.address.com" - # Mock socket creation to return our mock sockets - def mock_socket_factory(*args, **kwargs): - mock_sock = MockSocket() - self.mock_sockets.append(mock_sock) - return mock_sock - - self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) - self.socket_patcher.start() - - # Mock select.select to simulate data availability for reading - def mock_select(rlist, wlist, xlist, timeout=0): - # Check if any of the sockets in rlist have data available - ready_sockets = [] - for sock in rlist: - if hasattr(sock, "connected") and sock.connected and not sock.closed: - # Only return socket as ready if it actually has data to read - if hasattr(sock, "pending_responses") and sock.pending_responses: - ready_sockets.append(sock) - # Don't return socket as ready just because it received commands - # Only when there are actual responses available - return (ready_sockets, [], []) - - self.select_patcher = patch("select.select", side_effect=mock_select) - self.select_patcher.start() - ips = ["1.2.3.4", "5.6.7.8", "9.10.11.12"] ips = ips * 3 @@ -1952,15 +1999,9 @@ def mock_socket_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): ) self.getaddrinfo_patcher.start() - # Create maintenance notifications config - self.config = MaintNotificationsConfig( - enabled=True, proactive_reconnect=True, relaxed_timeout=30 - ) - def teardown_method(self): """Clean up test fixtures.""" - self.socket_patcher.stop() - self.select_patcher.stop() + super().teardown_method() self.getaddrinfo_patcher.stop() @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])