9
9
from redis .multidb .command_executor import DefaultCommandExecutor
10
10
from redis .multidb .config import MultiDbConfig , DEFAULT_GRACE_PERIOD
11
11
from redis .multidb .circuit import State as CBState , CircuitBreaker
12
- from redis .multidb .database import State as DBState , Database , AbstractDatabase , Databases
12
+ from redis .multidb .database import Database , AbstractDatabase , Databases
13
13
from redis .multidb .exception import NoValidDatabaseException
14
14
from redis .multidb .failure_detector import FailureDetector
15
15
from redis .multidb .healthcheck import HealthCheck
@@ -78,13 +78,8 @@ def raise_exception_on_failed_hc(error):
78
78
79
79
# Set states according to a weights and circuit state
80
80
if database .circuit .state == CBState .CLOSED and not is_active_db_found :
81
- database .state = DBState .ACTIVE
82
81
self .command_executor .active_database = database
83
82
is_active_db_found = True
84
- elif database .circuit .state == CBState .CLOSED and is_active_db_found :
85
- database .state = DBState .PASSIVE
86
- else :
87
- database .state = DBState .DISCONNECTED
88
83
89
84
if not is_active_db_found :
90
85
raise NoValidDatabaseException ('Initial connection failed - no active database found' )
@@ -115,8 +110,6 @@ def set_active_database(self, database: AbstractDatabase) -> None:
115
110
116
111
if database .circuit .state == CBState .CLOSED :
117
112
highest_weighted_db , _ = self ._databases .get_top_n (1 )[0 ]
118
- highest_weighted_db .state = DBState .PASSIVE
119
- database .state = DBState .ACTIVE
120
113
self .command_executor .active_database = database
121
114
return
122
115
@@ -138,9 +131,7 @@ def add_database(self, database: AbstractDatabase):
138
131
139
132
def _change_active_database (self , new_database : AbstractDatabase , highest_weight_database : AbstractDatabase ):
140
133
if new_database .weight > highest_weight_database .weight and new_database .circuit .state == CBState .CLOSED :
141
- new_database .state = DBState .ACTIVE
142
134
self .command_executor .active_database = new_database
143
- highest_weight_database .state = DBState .PASSIVE
144
135
145
136
def remove_database (self , database : Database ):
146
137
"""
@@ -150,7 +141,6 @@ def remove_database(self, database: Database):
150
141
highest_weighted_db , highest_weight = self ._databases .get_top_n (1 )[0 ]
151
142
152
143
if highest_weight <= weight and highest_weighted_db .circuit .state == CBState .CLOSED :
153
- highest_weighted_db .state = DBState .ACTIVE
154
144
self .command_executor .active_database = highest_weighted_db
155
145
156
146
def update_database_weight (self , database : AbstractDatabase , weight : float ):
@@ -240,7 +230,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep
240
230
database .circuit .state = CBState .OPEN
241
231
elif is_healthy and database .circuit .state != CBState .CLOSED :
242
232
database .circuit .state = CBState .CLOSED
243
- except (ConnectionError , TimeoutError , socket .timeout , ConnectionRefusedError ) as e :
233
+ except (ConnectionError , TimeoutError , socket .timeout , ConnectionRefusedError , ValueError ) as e :
244
234
if database .circuit .state != CBState .OPEN :
245
235
database .circuit .state = CBState .OPEN
246
236
is_healthy = False
@@ -334,7 +324,7 @@ def execute(self) -> List[Any]:
334
324
335
325
class PubSub :
336
326
"""
337
- PubSub object for multi- database client.
327
+ PubSub object for multi database client.
338
328
"""
339
329
def __init__ (self , client : MultiDBClient , ** kwargs ):
340
330
"""Initialize the PubSub object for a multi-database client.
@@ -438,18 +428,33 @@ def get_message(
438
428
ignore_subscribe_messages = ignore_subscribe_messages , timeout = timeout
439
429
)
440
430
441
- get_sharded_message = get_message
431
+ def get_sharded_message (
432
+ self , ignore_subscribe_messages : bool = False , timeout : float = 0.0
433
+ ):
434
+ """
435
+ Get the next message if one is available in a sharded channel, otherwise None.
436
+
437
+ If timeout is specified, the system will wait for `timeout` seconds
438
+ before returning. Timeout should be specified as a floating point
439
+ number, or None, to wait indefinitely.
440
+ """
441
+ return self ._client .command_executor .execute_pubsub_method (
442
+ 'get_sharded_message' ,
443
+ ignore_subscribe_messages = ignore_subscribe_messages , timeout = timeout
444
+ )
442
445
443
446
def run_in_thread (
444
447
self ,
445
448
sleep_time : float = 0.0 ,
446
449
daemon : bool = False ,
447
450
exception_handler : Optional [Callable ] = None ,
451
+ sharded_pubsub : bool = False ,
448
452
) -> "PubSubWorkerThread" :
449
453
return self ._client .command_executor .execute_pubsub_run_in_thread (
450
454
sleep_time = sleep_time ,
451
455
daemon = daemon ,
452
456
exception_handler = exception_handler ,
453
- pubsub = self
457
+ pubsub = self ,
458
+ sharded_pubsub = sharded_pubsub ,
454
459
)
455
460
0 commit comments