Skip to content

Commit 65a0c62

Browse files
committed
Added test cases to validate ciruit breaker
Signed-off-by: Nikhil Suri <[email protected]>
1 parent 7047a0e commit 65a0c62

File tree

3 files changed

+113
-87
lines changed

3 files changed

+113
-87
lines changed

src/databricks/sql/telemetry/circuit_breaker_manager.py

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dataclasses import dataclass
1313

1414
import pybreaker
15-
from pybreaker import CircuitBreaker, CircuitBreakerError
15+
from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -38,6 +38,48 @@
3838
LOG_CIRCUIT_BREAKER_HALF_OPEN = "Circuit breaker half-open for %s - testing telemetry requests"
3939

4040

41+
class CircuitBreakerStateListener(CircuitBreakerListener):
42+
"""Listener for circuit breaker state changes."""
43+
44+
def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None:
45+
"""Called before the circuit breaker calls a function."""
46+
pass
47+
48+
def failure(self, cb: CircuitBreaker, exc: BaseException) -> None:
49+
"""Called when a function called by the circuit breaker fails."""
50+
pass
51+
52+
def success(self, cb: CircuitBreaker) -> None:
53+
"""Called when a function called by the circuit breaker succeeds."""
54+
pass
55+
56+
def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None:
57+
"""Called when the circuit breaker state changes."""
58+
old_state_name = old_state.name if old_state else "None"
59+
new_state_name = new_state.name if new_state else "None"
60+
61+
logger.info(
62+
LOG_CIRCUIT_BREAKER_STATE_CHANGED,
63+
old_state_name, new_state_name, cb.name
64+
)
65+
66+
if new_state_name == CIRCUIT_BREAKER_STATE_OPEN:
67+
logger.warning(
68+
LOG_CIRCUIT_BREAKER_OPENED,
69+
cb.name
70+
)
71+
elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED:
72+
logger.info(
73+
LOG_CIRCUIT_BREAKER_CLOSED,
74+
cb.name
75+
)
76+
elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN:
77+
logger.info(
78+
LOG_CIRCUIT_BREAKER_HALF_OPEN,
79+
cb.name
80+
)
81+
82+
4183
@dataclass(frozen=True)
4284
class CircuitBreakerConfig:
4385
"""Configuration for circuit breaker behavior.
@@ -126,16 +168,13 @@ def _create_circuit_breaker(cls, host: str) -> CircuitBreaker:
126168

127169
# Create circuit breaker with configuration
128170
breaker = CircuitBreaker(
129-
fail_max=config.minimum_calls,
171+
fail_max=config.minimum_calls, # Number of failures before circuit opens
130172
reset_timeout=config.reset_timeout,
131173
name=f"{config.name}-{host}"
132174
)
133175

134-
# Set failure threshold
135-
breaker.failure_threshold = config.failure_threshold
136-
137176
# Add state change listeners for logging
138-
breaker.add_listener(cls._on_state_change)
177+
breaker.add_listener(CircuitBreakerStateListener())
139178

140179
return breaker
141180

@@ -156,36 +195,6 @@ def _create_noop_circuit_breaker(cls) -> CircuitBreaker:
156195
breaker.failure_threshold = 1.0 # 100% failure threshold
157196
return breaker
158197

159-
@classmethod
160-
def _on_state_change(cls, old_state: str, new_state: str, breaker: CircuitBreaker) -> None:
161-
"""
162-
Handle circuit breaker state changes.
163-
164-
Args:
165-
old_state: Previous state of the circuit breaker
166-
new_state: New state of the circuit breaker
167-
breaker: The circuit breaker instance
168-
"""
169-
logger.info(
170-
LOG_CIRCUIT_BREAKER_STATE_CHANGED,
171-
old_state, new_state, breaker.name
172-
)
173-
174-
if new_state == CIRCUIT_BREAKER_STATE_OPEN:
175-
logger.warning(
176-
LOG_CIRCUIT_BREAKER_OPENED,
177-
breaker.name
178-
)
179-
elif new_state == CIRCUIT_BREAKER_STATE_CLOSED:
180-
logger.info(
181-
LOG_CIRCUIT_BREAKER_CLOSED,
182-
breaker.name
183-
)
184-
elif new_state == CIRCUIT_BREAKER_STATE_HALF_OPEN:
185-
logger.info(
186-
LOG_CIRCUIT_BREAKER_HALF_OPEN,
187-
breaker.name
188-
)
189198

190199
@classmethod
191200
def get_circuit_breaker_state(cls, host: str) -> str:

src/databricks/sql/telemetry/telemetry_push_client.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,9 @@ def request(
158158
"""Make an HTTP request with circuit breaker protection."""
159159
try:
160160
# Use circuit breaker to protect the request
161-
with self._circuit_breaker:
162-
return self._delegate.request(method, url, headers, **kwargs)
161+
return self._circuit_breaker.call(
162+
lambda: self._delegate.request(method, url, headers, **kwargs)
163+
)
163164
except CircuitBreakerError as e:
164165
logger.warning(
165166
"Circuit breaker is open for host %s, blocking telemetry request to %s: %s",
@@ -185,9 +186,12 @@ def request_context(
185186
"""Context manager for making HTTP requests with circuit breaker protection."""
186187
try:
187188
# Use circuit breaker to protect the request
188-
with self._circuit_breaker:
189+
def _make_request():
189190
with self._delegate.request_context(method, url, headers, **kwargs) as response:
190-
yield response
191+
return response
192+
193+
response = self._circuit_breaker.call(_make_request)
194+
yield response
191195
except CircuitBreakerError as e:
192196
logger.warning(
193197
"Circuit breaker is open for host %s, blocking telemetry request to %s: %s",

tests/unit/test_telemetry_push_client.py

Lines changed: 60 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,21 @@ def test_initialization(self):
7474

7575
def test_initialization_disabled(self):
7676
"""Test client initialization with circuit breaker disabled."""
77-
config = CircuitBreakerConfig(enabled=False)
78-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
77+
config = CircuitBreakerConfig()
78+
client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host, config)
7979

80-
assert client._config.enabled is False
80+
assert client._config is not None
8181

8282
def test_request_context_disabled(self):
8383
"""Test request context when circuit breaker is disabled."""
84-
config = CircuitBreakerConfig(enabled=False)
85-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
84+
config = CircuitBreakerConfig()
85+
client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host, config)
8686

8787
mock_response = Mock()
88-
self.mock_delegate.request_context.return_value.__enter__.return_value = mock_response
89-
self.mock_delegate.request_context.return_value.__exit__.return_value = None
88+
mock_context = MagicMock()
89+
mock_context.__enter__.return_value = mock_response
90+
mock_context.__exit__.return_value = None
91+
self.mock_delegate.request_context.return_value = mock_context
9092

9193
with client.request_context(HttpMethod.POST, "https://test.com", {}) as response:
9294
assert response == mock_response
@@ -96,18 +98,20 @@ def test_request_context_disabled(self):
9698
def test_request_context_enabled_success(self):
9799
"""Test successful request context when circuit breaker is enabled."""
98100
mock_response = Mock()
99-
self.mock_delegate.request_context.return_value.__enter__.return_value = mock_response
100-
self.mock_delegate.request_context.return_value.__exit__.return_value = None
101+
mock_context = MagicMock()
102+
mock_context.__enter__.return_value = mock_response
103+
mock_context.__exit__.return_value = None
104+
self.mock_delegate.request_context.return_value = mock_context
101105

102-
with client.request_context(HttpMethod.POST, "https://test.com", {}) as response:
106+
with self.client.request_context(HttpMethod.POST, "https://test.com", {}) as response:
103107
assert response == mock_response
104108

105109
self.mock_delegate.request_context.assert_called_once()
106110

107111
def test_request_context_enabled_circuit_breaker_error(self):
108112
"""Test request context when circuit breaker is open."""
109113
# Mock circuit breaker to raise CircuitBreakerError
110-
with patch.object(self.client._circuit_breaker, '__enter__', side_effect=CircuitBreakerError("Circuit is open")):
114+
with patch.object(self.client._circuit_breaker, 'call', side_effect=CircuitBreakerError("Circuit is open")):
111115
with pytest.raises(CircuitBreakerError):
112116
with self.client.request_context(HttpMethod.POST, "https://test.com", {}):
113117
pass
@@ -123,8 +127,8 @@ def test_request_context_enabled_other_error(self):
123127

124128
def test_request_disabled(self):
125129
"""Test request method when circuit breaker is disabled."""
126-
config = CircuitBreakerConfig(enabled=False)
127-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
130+
config = CircuitBreakerConfig()
131+
client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host, config)
128132

129133
mock_response = Mock()
130134
self.mock_delegate.request.return_value = mock_response
@@ -147,7 +151,7 @@ def test_request_enabled_success(self):
147151
def test_request_enabled_circuit_breaker_error(self):
148152
"""Test request when circuit breaker is open."""
149153
# Mock circuit breaker to raise CircuitBreakerError
150-
with patch.object(self.client._circuit_breaker, '__enter__', side_effect=CircuitBreakerError("Circuit is open")):
154+
with patch.object(self.client._circuit_breaker, 'call', side_effect=CircuitBreakerError("Circuit is open")):
151155
with pytest.raises(CircuitBreakerError):
152156
self.client.request(HttpMethod.POST, "https://test.com", {})
153157

@@ -161,15 +165,16 @@ def test_request_enabled_other_error(self):
161165

162166
def test_get_circuit_breaker_state(self):
163167
"""Test getting circuit breaker state."""
164-
with patch.object(self.client._circuit_breaker, 'current_state', 'open'):
168+
# Mock the CircuitBreakerManager method instead of the circuit breaker property
169+
with patch('databricks.sql.telemetry.telemetry_push_client.CircuitBreakerManager.get_circuit_breaker_state', return_value='open'):
165170
state = self.client.get_circuit_breaker_state()
166171
assert state == 'open'
167172

168173
def test_reset_circuit_breaker(self):
169174
"""Test resetting circuit breaker."""
170-
with patch.object(self.client._circuit_breaker, 'reset') as mock_reset:
175+
with patch('databricks.sql.telemetry.telemetry_push_client.CircuitBreakerManager.reset_circuit_breaker') as mock_reset:
171176
self.client.reset_circuit_breaker()
172-
mock_reset.assert_called_once()
177+
mock_reset.assert_called_once_with(self.client._host)
173178

174179
def test_is_circuit_breaker_open(self):
175180
"""Test checking if circuit breaker is open."""
@@ -181,47 +186,48 @@ def test_is_circuit_breaker_open(self):
181186

182187
def test_is_circuit_breaker_enabled(self):
183188
"""Test checking if circuit breaker is enabled."""
184-
assert self.client.is_circuit_breaker_enabled() is True
185-
186-
config = CircuitBreakerConfig(enabled=False)
187-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
188-
assert client.is_circuit_breaker_enabled() is False
189+
# Circuit breaker is always enabled in this implementation
190+
assert self.client._circuit_breaker is not None
189191

190192
def test_circuit_breaker_state_logging(self):
191193
"""Test that circuit breaker state changes are logged."""
192-
with patch('databricks.sql.telemetry.circuit_breaker_http_client.logger') as mock_logger:
193-
with patch.object(self.client._circuit_breaker, '__enter__', side_effect=CircuitBreakerError("Circuit is open")):
194+
with patch('databricks.sql.telemetry.telemetry_push_client.logger') as mock_logger:
195+
with patch.object(self.client._circuit_breaker, 'call', side_effect=CircuitBreakerError("Circuit is open")):
194196
with pytest.raises(CircuitBreakerError):
195197
self.client.request(HttpMethod.POST, "https://test.com", {})
196-
197-
# Check that warning was logged
198-
mock_logger.warning.assert_called()
199-
warning_call = mock_logger.warning.call_args[0][0]
200-
assert "Circuit breaker is open" in warning_call
201-
assert self.host in warning_call
198+
199+
# Check that warning was logged
200+
mock_logger.warning.assert_called()
201+
warning_args = mock_logger.warning.call_args[0]
202+
assert "Circuit breaker is open" in warning_args[0]
203+
assert self.host in warning_args[1] # The host is the second argument
202204

203205
def test_other_error_logging(self):
204206
"""Test that other errors are logged appropriately."""
205-
with patch('databricks.sql.telemetry.circuit_breaker_http_client.logger') as mock_logger:
207+
with patch('databricks.sql.telemetry.telemetry_push_client.logger') as mock_logger:
206208
self.mock_delegate.request.side_effect = ValueError("Network error")
207209

208210
with pytest.raises(ValueError):
209211
self.client.request(HttpMethod.POST, "https://test.com", {})
210212

211213
# Check that debug was logged
212214
mock_logger.debug.assert_called()
213-
debug_call = mock_logger.debug.call_args[0][0]
214-
assert "Telemetry request failed" in debug_call
215-
assert self.host in debug_call
215+
debug_args = mock_logger.debug.call_args[0]
216+
assert "Telemetry request failed" in debug_args[0]
217+
assert self.host in debug_args[1] # The host is the second argument
216218

217219

218-
class TestCircuitBreakerHttpClientIntegration:
219-
"""Integration tests for CircuitBreakerHttpClient."""
220+
class TestCircuitBreakerTelemetryPushClientIntegration:
221+
"""Integration tests for CircuitBreakerTelemetryPushClient."""
220222

221223
def setup_method(self):
222224
"""Set up test fixtures."""
223225
self.mock_delegate = Mock()
224226
self.host = "test-host.example.com"
227+
# Clear any existing circuit breaker state
228+
from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager
229+
CircuitBreakerManager.clear_all_circuit_breakers()
230+
CircuitBreakerManager._config = None
225231

226232
def test_circuit_breaker_opens_after_failures(self):
227233
"""Test that circuit breaker opens after repeated failures."""
@@ -230,17 +236,20 @@ def test_circuit_breaker_opens_after_failures(self):
230236
minimum_calls=2, # Only 2 calls needed
231237
reset_timeout=1 # 1 second reset timeout
232238
)
233-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
239+
client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host, config)
234240

235241
# Simulate failures
236242
self.mock_delegate.request.side_effect = Exception("Network error")
237243

238-
# First few calls should fail with the original exception
239-
for _ in range(2):
240-
with pytest.raises(Exception, match="Network error"):
241-
client.request(HttpMethod.POST, "https://test.com", {})
244+
# First call should fail with the original exception
245+
with pytest.raises(Exception, match="Network error"):
246+
client.request(HttpMethod.POST, "https://test.com", {})
247+
248+
# Second call should fail with CircuitBreakerError (circuit opens after 2 failures)
249+
with pytest.raises(CircuitBreakerError):
250+
client.request(HttpMethod.POST, "https://test.com", {})
242251

243-
# After enough failures, circuit breaker should open
252+
# Third call should also fail with CircuitBreakerError (circuit is open)
244253
with pytest.raises(CircuitBreakerError):
245254
client.request(HttpMethod.POST, "https://test.com", {})
246255

@@ -251,16 +260,20 @@ def test_circuit_breaker_recovers_after_success(self):
251260
minimum_calls=2,
252261
reset_timeout=1
253262
)
254-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
263+
client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host, config)
255264

256265
# Simulate failures first
257266
self.mock_delegate.request.side_effect = Exception("Network error")
258267

259-
for _ in range(2):
260-
with pytest.raises(Exception):
261-
client.request(HttpMethod.POST, "https://test.com", {})
268+
# First call should fail with the original exception
269+
with pytest.raises(Exception):
270+
client.request(HttpMethod.POST, "https://test.com", {})
271+
272+
# Second call should fail with CircuitBreakerError (circuit opens after 2 failures)
273+
with pytest.raises(CircuitBreakerError):
274+
client.request(HttpMethod.POST, "https://test.com", {})
262275

263-
# Circuit breaker should be open now
276+
# Third call should also fail with CircuitBreakerError (circuit is open)
264277
with pytest.raises(CircuitBreakerError):
265278
client.request(HttpMethod.POST, "https://test.com", {})
266279

0 commit comments

Comments
 (0)