diff --git a/redis/client.py b/redis/client.py index adb57d404e..e22ca3d73d 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1217,7 +1217,8 @@ def run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, - pubsub = None + pubsub = None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: @@ -1233,7 +1234,7 @@ def run_in_thread( pubsub = self if pubsub is None else pubsub thread = PubSubWorkerThread( - pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler + pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler, sharded_pubsub=sharded_pubsub ) thread.start() return thread @@ -1248,12 +1249,14 @@ def __init__( exception_handler: Union[ Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None ] = None, + sharded_pubsub: bool = False, ): super().__init__() self.daemon = daemon self.pubsub = pubsub self.sleep_time = sleep_time self.exception_handler = exception_handler + self.sharded_pubsub = sharded_pubsub self._running = threading.Event() def run(self) -> None: @@ -1264,7 +1267,10 @@ def run(self) -> None: sleep_time = self.sleep_time while self._running.is_set(): try: - pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + if not self.sharded_pubsub: + pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + else: + pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=sleep_time) except BaseException as e: if self.exception_handler is None: raise diff --git a/redis/cluster.py b/redis/cluster.py index dbcf5cc2b7..2fd4625e6b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3154,7 +3154,8 @@ def _reinitialize_on_error(self, error): self._nodes_manager.initialize() self.reinitialize_counter = 0 else: - self._nodes_manager.update_moved_exception(error) + if type(error) == MovedError: + self._nodes_manager.update_moved_exception(error) self._executing = False diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 8183d11293..2f87024f20 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -9,7 +9,7 @@ from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.database import State as DBState, Database, AbstractDatabase, Databases +from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -78,13 +78,8 @@ def raise_exception_on_failed_hc(error): # Set states according to a weights and circuit state if database.circuit.state == CBState.CLOSED and not is_active_db_found: - database.state = DBState.ACTIVE self.command_executor.active_database = database is_active_db_found = True - elif database.circuit.state == CBState.CLOSED and is_active_db_found: - database.state = DBState.PASSIVE - else: - database.state = DBState.DISCONNECTED if not is_active_db_found: raise NoValidDatabaseException('Initial connection failed - no active database found') @@ -115,8 +110,6 @@ def set_active_database(self, database: AbstractDatabase) -> None: if database.circuit.state == CBState.CLOSED: highest_weighted_db, _ = self._databases.get_top_n(1)[0] - highest_weighted_db.state = DBState.PASSIVE - database.state = 
DBState.ACTIVE self.command_executor.active_database = database return @@ -138,9 +131,7 @@ def add_database(self, database: AbstractDatabase): def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: - new_database.state = DBState.ACTIVE self.command_executor.active_database = new_database - highest_weight_database.state = DBState.PASSIVE def remove_database(self, database: Database): """ @@ -150,7 +141,6 @@ def remove_database(self, database: Database): highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: - highest_weighted_db.state = DBState.ACTIVE self.command_executor.active_database = highest_weighted_db def update_database_weight(self, database: AbstractDatabase, weight: float): @@ -240,7 +230,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError) as e: + except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError, ValueError) as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False @@ -334,7 +324,7 @@ def execute(self) -> List[Any]: class PubSub: """ - PubSub object for multi-database client. + PubSub object for a multi-database client. """ def __init__(self, client: MultiDBClient, **kwargs): """Initialize the PubSub object for a multi-database client. @@ -438,18 +428,33 @@ def get_message( ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout ) - get_sharded_message = get_message + def get_sharded_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message from a sharded channel if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. 
+ """ + return self._client.command_executor.execute_pubsub_method( + 'get_sharded_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) def run_in_thread( self, sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": return self._client.command_executor.execute_pubsub_run_in_thread( sleep_time=sleep_time, daemon=daemon, exception_handler=exception_handler, - pubsub=self + pubsub=self, + sharded_pubsub=sharded_pubsub, ) diff --git a/redis/multidb/command_executor.py index 40370c2e18..094230a31d 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -235,10 +235,15 @@ def execute_pubsub_run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": def callback(): return self._active_pubsub.run_in_thread( - sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=pubsub + sleep_time, + daemon=daemon, + exception_handler=exception_handler, + pubsub=pubsub, + sharded_pubsub=sharded_pubsub ) return self._execute_with_failure_detection(callback) diff --git a/redis/multidb/database.py index 204b7c91f3..3253ffa093 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -8,12 +8,6 @@ from redis.multidb.circuit import CircuitBreaker from redis.typing import Number - -class State(Enum): - ACTIVE = 0 - PASSIVE = 1 - DISCONNECTED = 2 - class AbstractDatabase(ABC): @property @abstractmethod @@ -39,18 +33,6 @@ def weight(self, weight: float): """Set the weight of this database in compare to others.""" pass - @property - @abstractmethod - def state(self) -> State: - """The state of the current database.""" - pass - - @state.setter - @abstractmethod - def state(self, state: State): - """Set the state of the current database.""" - pass - @property @abstractmethod def circuit(self) -> CircuitBreaker: @@ -70,8 +52,7 @@ def __init__( self, client: Union[redis.Redis, RedisCluster], circuit: CircuitBreaker, - weight: float, - state: State = State.DISCONNECTED, + weight: float ): """ Initialize a new Database instance. @@ -86,7 +67,6 @@ def __init__( self._cb = circuit self._cb.database = self self._weight = weight - self._state = state @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -104,14 +84,6 @@ def weight(self) -> float: def weight(self, weight: float): self._weight = weight - @property - def state(self) -> State: - return self._state - - @state.setter - def state(self, state: State): - self._state = state - @property def circuit(self) -> CircuitBreaker: return self._cb diff --git a/redis/multidb/event.py index 7b16d4ba88..2598bc4d06 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -39,7 +39,7 @@ def kwargs(self): class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ - Re-subscribe the currently active pub / sub to a new active database. + Re-subscribe the currently active pub/sub to a new active database. 
""" def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub diff --git a/redis/multidb/healthcheck.py index 1396a1e997..cca220dc3f 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,4 +1,7 @@ from abc import abstractmethod, ABC + +import redis +from redis import Redis from redis.retry import Retry @@ -51,8 +54,20 @@ def check_health(self, database) -> bool: def _returns_echoed_message(self, database) -> bool: expected_message = ["healthcheck", b"healthcheck"] - actual_message = database.client.execute_command('ECHO', "healthcheck") - return actual_message in expected_message + + if isinstance(database.client, Redis): + actual_message = database.client.execute_command("ECHO", "healthcheck") + return actual_message in expected_message + else: + # For a cluster, check that all nodes return the echoed message. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = node.redis_connection.execute_command("ECHO", "healthcheck") + + if actual_message not in expected_message: + return False + + return True def _dummy_fail(self): pass \ No newline at end of file diff --git a/tests/test_multidb/conftest.py index f85e0a6fd7..a34ef01476 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -7,7 +7,7 @@ from redis.multidb.circuit import CircuitBreaker, State as CBState from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database, State, Databases +from redis.multidb.database import Database, Databases from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -38,7 +38,6 @@ def mock_hc() -> HealthCheck: def mock_db(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) @@ -53,7 +52,6 @@ def mock_db1(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) @@ -68,7 +66,6 @@ def mock_db2(request) -> Database: db = Mock(spec=Database) db.weight = request.param.get("weight", 1.0) - db.state = request.param.get("state", State.ACTIVE) db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) diff --git a/tests/test_multidb/test_client.py index c14f605c2a..37ee9b3fd3 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF -from redis.multidb.database import State as DBState, AbstractDatabase +from redis.multidb.database import AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -166,26 +166,14 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = 
MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.15) assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - sleep(0.22) assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -215,10 +203,6 @@ def test_execute_command_throws_exception_on_failed_initialization( assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.DISCONNECTED - assert mock_db1.state == DBState.DISCONNECTED - assert mock_db2.state == DBState.DISCONNECTED - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -277,18 +261,11 @@ def test_add_database_makes_new_database_active( assert client.set('key', 'value') == 'OK2' assert mock_hc.check_health.call_count == 2 - assert mock_db.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - client.add_database(mock_db1) assert mock_hc.check_health.call_count == 3 assert client.set('key', 'value') == 'OK1' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -319,17 +296,10 @@ def test_remove_highest_weighted_database( assert client.set('key', 'value') == 'OK1' assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.remove_database(mock_db1) assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -360,19 +330,11 @@ def test_update_database_weight_to_be_highest( assert client.set('key', 'value') == 'OK1' assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.update_database_weight(mock_db2, 0.8) assert mock_db2.weight == 0.8 assert client.set('key', 'value') == 'OK2' - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.PASSIVE - assert mock_db2.state == DBState.ACTIVE - @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', [ @@ -491,17 +453,9 @@ def test_set_active_database( assert client.set('key', 'value') == 'OK1' assert mock_hc.check_health.call_count == 3 - assert mock_db.state == DBState.PASSIVE - assert mock_db1.state == DBState.ACTIVE - assert mock_db2.state == DBState.PASSIVE - client.set_active_database(mock_db) assert client.set('key', 'value') == 'OK' - assert mock_db.state == DBState.ACTIVE - assert mock_db1.state == DBState.PASSIVE - assert mock_db2.state == DBState.PASSIVE - with pytest.raises(ValueError, match='Given database is not a member of database list'): client.set_active_database(Mock(spec=AbstractDatabase)) diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 9601638913..08bd8ab0c4 100644 --- a/tests/test_multidb/test_healthcheck.py +++ 
b/tests/test_multidb/test_healthcheck.py @@ -1,5 +1,5 @@ from redis.backoff import ExponentialBackoff -from redis.multidb.database import Database, State +from redis.multidb.database import Database from redis.multidb.healthcheck import EchoHealthCheck from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -14,7 +14,7 @@ def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 3 @@ -26,7 +26,7 @@ def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, moc """ mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'wrong'] hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == False assert mock_client.execute_command.call_count == 3 @@ -35,7 +35,7 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc mock_client.execute_command.side_effect = [ConnectionError, ConnectionError, 'healthcheck'] mock_cb.state = CBState.HALF_OPEN hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) - db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) + db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True assert mock_client.execute_command.call_count == 3 \ No newline at end of file diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index b347fe50ba..4182962fb1 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -3,6 +3,7 @@ import pytest +from redis import Redis from redis.backoff import NoBackoff, ExponentialBackoff from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient @@ -42,12 +43,18 @@ def fault_injector_client(): return FaultInjectorClient(url) @pytest.fixture() -def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener]: - endpoint_config = get_endpoint_config('re-active-active') +def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: + client_class = request.param.get('client_class', Redis) + + if client_class == Redis: + endpoint_config = get_endpoint_config('re-active-active') + else: + endpoint_config = get_endpoint_config('re-active-active-oss-cluster') + username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.5, base=0.05), retries=3)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. 
@@ -82,13 +89,14 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen db_configs.append(db_config1) config = MultiDbConfig( + client_class=client_class, databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, health_check_interval=health_check_interval, - health_check_backoff=ExponentialBackoff(cap=0.5, base=0.05), - health_check_retries=3, event_dispatcher=event_dispatcher, + health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + health_check_retries=3, ) - return MultiDBClient(config), listener \ No newline at end of file + return MultiDBClient(config), listener, endpoint_config \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 071babb6c0..967fa43cdb 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -5,17 +5,16 @@ import pytest +from redis import Redis, RedisCluster from redis.client import Pipeline -from tests.test_scenario.conftest import get_endpoint_config from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) -def trigger_network_failure_action(fault_injector_client, event: threading.Event = None): - endpoint_config = get_endpoint_config('re-active-active') +def trigger_network_failure_action(fault_injector_client, config, event: threading.Event = None): action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 2, "cluster_index": 0} + parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) @@ -31,29 +30,32 @@ def trigger_network_failure_action(fault_injector_client, event: threading.Event logger.info(f"Action completed. Status: {status_result['status']}") -class TestActiveActiveStandalone: +class TestActiveActive: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(3) + sleep(5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - # Client initialized on the first command. 
r_multi_db.set('key', 'value') thread.start() @@ -61,32 +63,33 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector # Execute commands before network failure while not event.is_set(): assert r_multi_db.get('key') == 'value' - sleep(0.1) + sleep(0.5) - # Execute commands after network failure - for _ in range(3): + # Execute commands until database failover + while not listener.is_changed_flag: assert r_multi_db.get('key') == 'value' - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - # Client initialized on first pipe execution. with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') @@ -109,10 +112,10 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute pipeline until database failover + for _ in range(5): with r_multi_db.pipeline() as pipe: pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -121,27 +124,28 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - # Client initialized on first pipe execution. 
pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') @@ -156,6 +160,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject # Execute pipeline before network failure while not event.is_set(): + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -163,10 +168,11 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute pipeline until database failover + for _ in range(5): + pipe = r_multi_db.pipeline() pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') pipe.set('{hash}key3', 'value3') @@ -174,27 +180,28 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - def callback(pipe: Pipeline): pipe.set('{hash}key1', 'value1') pipe.set('{hash}key2', 'value2') @@ -207,34 +214,35 @@ def callback(pipe: Pipeline): r_multi_db.transaction(callback) thread.start() - # Execute pipeline before network failure + # Execute transaction before network failure while not event.is_set(): r_multi_db.transaction(callback) - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute transaction until database failover + while not listener.is_changed_flag: r_multi_db.transaction(callback) - sleep(0.1) - - assert listener.is_changed_flag == True + sleep(0.5) @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - - r_multi_db, listener = r_multi_db data = json.dumps({'message': 'test'}) messages_count = 0 @@ -249,37 +257,38 @@ def handler(message): pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) thread.start() - # Execute pipeline before network failure + # Execute publish before network failure while not event.is_set(): r_multi_db.publish('test-channel', data) - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute publish until database failover + while not listener.is_changed_flag: 
r_multi_db.publish('test-channel', data) - sleep(0.1) + sleep(0.5) pubsub_thread.stop() - - assert listener.is_changed_flag == True assert messages_count > 5 @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - - r_multi_db, listener = r_multi_db data = json.dumps({'message': 'test'}) messages_count = 0 @@ -291,20 +300,22 @@ def handler(message): # Assign a handler and run in a separate thread. pubsub.ssubscribe(**{'test-channel': handler}) - pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) + pubsub_thread = pubsub.run_in_thread( + sleep_time=0.1, + daemon=True, + sharded_pubsub=True + ) thread.start() - # Execute pipeline before network failure + # Execute publish before network failure while not event.is_set(): r_multi_db.spublish('test-channel', data) - sleep(0.1) + sleep(0.5) - # Execute pipeline after network failure - for _ in range(3): + # Execute publish until database failover + while not listener.is_changed_flag: r_multi_db.spublish('test-channel', data) - sleep(0.1) + sleep(0.5) pubsub_thread.stop() - - assert listener.is_changed_flag == True assert messages_count > 5 \ No newline at end of file
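
A minimal usage sketch of the sharded pub/sub flow exercised by the scenario test above. It assumes an already-initialized MultiDBClient named `client` exposing a `pubsub()` factory like the standalone client; the channel name and payload are placeholders, everything else follows the APIs shown in this patch.

import json

# Assumes `client` is an initialized redis.multidb.client.MultiDBClient.
received = []

def handler(message):
    # Invoked by the worker thread for each sharded message that arrives.
    received.append(json.loads(message['data']))

pubsub = client.pubsub()
pubsub.ssubscribe(**{'test-channel': handler})

# sharded_pubsub=True makes the PubSubWorkerThread poll get_sharded_message()
# instead of get_message(), which is what SSUBSCRIBE subscriptions require.
worker = pubsub.run_in_thread(sleep_time=0.1, daemon=True, sharded_pubsub=True)

client.spublish('test-channel', json.dumps({'message': 'test'}))
worker.stop()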