Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,12 @@ def on_connect_check_health(self, check_health: bool = True):
):
raise ConnectionError("Invalid RESP version")

# Send maintenance notifications handshake if RESP3 is active and maintenance notifications are enabled
# Send maintenance notifications handshake if RESP3 is active
# and maintenance notifications are enabled
# and we have a host to determine the endpoint type from
# When the maint_notifications_config enabled mode is "auto",
# we just log a warning if the handshake fails
# When the mode is enabled=True, we raise an exception in case of failure
if (
self.protocol not in [2, "2"]
and self.maint_notifications_config
Expand All @@ -711,15 +715,21 @@ def on_connect_check_health(self, check_health: bool = True):
)
response = self.read_response()
if str_if_bytes(response) != "OK":
raise ConnectionError(
raise ResponseError(
"The server doesn't support maintenance notifications"
)
except Exception as e:
# Log warning but don't fail the connection
import logging
if (
isinstance(e, ResponseError)
and self.maint_notifications_config.enabled == "auto"
):
# Log warning but don't fail the connection
import logging

logger = logging.getLogger(__name__)
logger.warning(f"Failed to enable maintenance notifications: {e}")
logger = logging.getLogger(__name__)
logger.warning(f"Failed to enable maintenance notifications: {e}")
else:
raise

# if a client_name is given, set it
if self.client_name:
Expand Down
13 changes: 9 additions & 4 deletions redis/maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Union

from redis.typing import Number

Expand Down Expand Up @@ -447,7 +447,7 @@ class MaintNotificationsConfig:

def __init__(
self,
enabled: bool = True,
enabled: Union[bool, Literal["auto"]] = "auto",
proactive_reconnect: bool = True,
relaxed_timeout: Optional[Number] = 10,
endpoint_type: Optional[EndpointType] = None,
Expand All @@ -456,8 +456,13 @@ def __init__(
Initialize a new MaintNotificationsConfig.
Args:
enabled (bool): Whether to enable maintenance notifications handling.
Defaults to False.
enabled (bool | "auto"): Controls maintenance notifications handling behavior.
- True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
otherwise a ResponseError is raised.
- "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
gracefully handled - a warning is logged and normal operation continues.
- False: Maintenance notifications are completely disabled.
Defaults to "auto".
proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
Defaults to True.
relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class TestMaintNotificationsConfig:
def test_init_defaults(self):
"""Test MaintNotificationsConfig initialization with defaults."""
config = MaintNotificationsConfig()
assert config.enabled is True
assert config.enabled == "auto"
assert config.proactive_reconnect is True
assert config.relaxed_timeout == 10

Expand Down
117 changes: 79 additions & 38 deletions tests/test_maint_notifications_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
BlockingConnectionPool,
MaintenanceState,
)
from redis.exceptions import ResponseError
from redis.maint_notifications import (
EndpointType,
MaintNotificationsConfig,
NodeMigratingNotification,
NodeMigratedNotification,
Expand Down Expand Up @@ -201,6 +203,10 @@ def send(self, data):
if b"HELLO" in data:
response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n"
self.pending_responses.append(response)
elif b"MAINT_NOTIFICATIONS" in data and b"internal-ip" in data:
# Simulate error response - activate it only for internal-ip tests
response = b"+ERROR\r\n"
self.pending_responses.append(response)
elif b"SET" in data:
response = b"+OK\r\n"

Expand Down Expand Up @@ -337,8 +343,8 @@ def shutdown(self, how):
pass


class TestMaintenanceNotificationsHandlingSingleProxy:
"""Integration tests for maintenance notifications handling with real connection pool."""
class TestMaintenanceNotificationsBase:
"""Base class for maintenance notifications handling tests."""

def setup_method(self):
"""Set up test fixtures with mocked sockets."""
Expand Down Expand Up @@ -393,7 +399,7 @@ def _get_client(
pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool)
max_connections: Maximum number of connections in the pool (default: 10)
maint_notifications_config: Optional MaintNotificationsConfig to use. If not provided,
uses self.config from setup_method (default: None)
uses self.config from setup_method (default: None)
setup_pool_handler: Whether to set up pool handler for moving notifications (default: False)

Returns:
Expand Down Expand Up @@ -425,6 +431,71 @@ def _get_client(

return test_redis_client


class TestMaintenanceNotificationsHandshake(TestMaintenanceNotificationsBase):
"""Integration tests for maintenance notifications handling with real connection pool."""

def test_handshake_success_when_enabled(self):
"""Test that handshake is performed correctly."""
maint_notifications_config = MaintNotificationsConfig(
enabled=True, endpoint_type=EndpointType.EXTERNAL_IP
)
test_redis_client = self._get_client(
ConnectionPool, maint_notifications_config=maint_notifications_config
)

try:
# Perform Redis operations that should work with our improved mock responses
result_set = test_redis_client.set("hello", "world")
result_get = test_redis_client.get("hello")

# Verify operations completed successfully
assert result_set is True
assert result_get == b"world"

finally:
test_redis_client.close()

def test_handshake_success_when_auto_and_command_not_supported(self):
"""Test that when maintenance notifications are set to 'auto', the client gracefully handles unsupported MAINT_NOTIFICATIONS commands and normal Redis operations succeed."""
maint_notifications_config = MaintNotificationsConfig(
enabled="auto", endpoint_type=EndpointType.INTERNAL_IP
)
test_redis_client = self._get_client(
ConnectionPool, maint_notifications_config=maint_notifications_config
)

try:
# Perform Redis operations that should work with our improved mock responses
result_set = test_redis_client.set("hello", "world")
result_get = test_redis_client.get("hello")

# Verify operations completed successfully
assert result_set is True
assert result_get == b"world"

finally:
test_redis_client.close()

def test_handshake_failure_when_enabled(self):
"""Test that handshake is performed correctly."""
maint_notifications_config = MaintNotificationsConfig(
enabled=True, endpoint_type=EndpointType.INTERNAL_IP
)
test_redis_client = self._get_client(
ConnectionPool, maint_notifications_config=maint_notifications_config
)
try:
with pytest.raises(ResponseError):
test_redis_client.set("hello", "world")

finally:
test_redis_client.close()


class TestMaintenanceNotificationsHandlingSingleProxy(TestMaintenanceNotificationsBase):
"""Integration tests for maintenance notifications handling with real connection pool."""

def _validate_connection_handlers(self, conn, pool_handler, config):
"""Helper method to validate connection handlers are properly set."""
# Test that the node moving handler function is correctly set
Expand Down Expand Up @@ -1891,40 +1962,16 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class):
pool.disconnect()


class TestMaintenanceNotificationsHandlingMultipleProxies:
class TestMaintenanceNotificationsHandlingMultipleProxies(
TestMaintenanceNotificationsBase
):
"""Integration tests for maintenance notifications handling with real connection pool."""

def setup_method(self):
"""Set up test fixtures with mocked sockets."""
self.mock_sockets = []
self.original_socket = socket.socket
super().setup_method()
self.orig_host = "test.address.com"

# Mock socket creation to return our mock sockets
def mock_socket_factory(*args, **kwargs):
mock_sock = MockSocket()
self.mock_sockets.append(mock_sock)
return mock_sock

self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory)
self.socket_patcher.start()

# Mock select.select to simulate data availability for reading
def mock_select(rlist, wlist, xlist, timeout=0):
# Check if any of the sockets in rlist have data available
ready_sockets = []
for sock in rlist:
if hasattr(sock, "connected") and sock.connected and not sock.closed:
# Only return socket as ready if it actually has data to read
if hasattr(sock, "pending_responses") and sock.pending_responses:
ready_sockets.append(sock)
# Don't return socket as ready just because it received commands
# Only when there are actual responses available
return (ready_sockets, [], [])

self.select_patcher = patch("select.select", side_effect=mock_select)
self.select_patcher.start()

ips = ["1.2.3.4", "5.6.7.8", "9.10.11.12"]
ips = ips * 3

Expand Down Expand Up @@ -1952,15 +1999,9 @@ def mock_socket_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
)
self.getaddrinfo_patcher.start()

# Create maintenance notifications config
self.config = MaintNotificationsConfig(
enabled=True, proactive_reconnect=True, relaxed_timeout=30
)

def teardown_method(self):
"""Clean up test fixtures."""
self.socket_patcher.stop()
self.select_patcher.stop()
super().teardown_method()
self.getaddrinfo_patcher.stop()

@pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
Expand Down