Merged
37 changes: 30 additions & 7 deletions py/src/braintrust/logger.py
@@ -87,6 +87,7 @@
get_caller_location,
mask_api_key,
merge_dicts,
parse_env_var_float,
response_raise_for_status,
)

@@ -564,10 +565,6 @@ def user_info(self) -> Mapping[str, Any]:
self._user_info = self.api_conn().get_json("ping")
return self._user_info

def set_user_info_if_null(self, info: Mapping[str, Any]):
if not self._user_info:
self._user_info = info

def global_bg_logger(self) -> "_BackgroundLogger":
return getattr(self._override_bg_logger, "logger", None) or self._global_bg_logger.get()

@@ -629,14 +626,28 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
base_num_retries: Maximum number of retries before giving up and re-raising the exception.
backoff_factor: A multiplier used to determine the time to wait between retries.
The actual wait time is calculated as: backoff_factor * (2 ** retry_count).
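For example, with backoff_factor=0.5 the successive waits are 0.5s, 1s, 2s, 4s, and so on.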
default_timeout_secs: Default timeout in seconds for requests that don't specify one.
Prevents indefinite hangs on stale connections.
"""

def __init__(self, *args: Any, base_num_retries: int = 0, backoff_factor: float = 0.5, **kwargs: Any):
def __init__(
self,
*args: Any,
base_num_retries: int = 0,
backoff_factor: float = 0.5,
default_timeout_secs: float = 60,
**kwargs: Any,
):
self.base_num_retries = base_num_retries
self.backoff_factor = backoff_factor
self.default_timeout_secs = default_timeout_secs
super().__init__(*args, **kwargs)

def send(self, *args, **kwargs):
# Apply default timeout if none provided to prevent indefinite hangs
if kwargs.get("timeout") is None:
kwargs["timeout"] = self.default_timeout_secs

num_prev_retries = 0
while True:
try:
@@ -648,6 +659,14 @@ def send(self, *args, **kwargs):
return response
except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e:
if num_prev_retries < self.base_num_retries:
if isinstance(e, requests.exceptions.ReadTimeout):
# Clear all connection pools to discard stale connections. This
# fixes hangs caused by NAT gateways silently dropping idle TCP
# connections (e.g., Azure's ~4 min timeout). close() calls
# PoolManager.clear() which is thread-safe: in-flight requests
# keep their checked-out connections, and new requests create
# fresh pools on demand.
self.close()
Contributor

Are we sure that closing the HTTPAdapter will work for subsequent requests? It seems to be calling methods like this, and it's unclear to me whether that's okay. Another (admittedly weird) thing to try is to recreate the adapter, roughly self = RetryRequestExceptionsAdapter(...) (reassigning self won't strictly work, but we could reassign a member variable instead).

Collaborator (Author)

Yes, I tested with and without concurrency and it works fine.
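
For illustration, a minimal standalone check of that behavior might look like the sketch below (hypothetical; not the actual test code from this PR). It mounts a plain HTTPAdapter, calls close() mid-stream, and then issues concurrent requests through the same session; close() clears the adapter's connection pools, and subsequent requests build fresh pools on demand.

    import concurrent.futures

    import requests
    from requests.adapters import HTTPAdapter

    session = requests.Session()
    adapter = HTTPAdapter()
    session.mount("https://", adapter)

    def fetch(_: int) -> int:
        # Each call checks a connection out of the pool (or builds a fresh one).
        return session.get("https://example.com", timeout=10).status_code

    fetch(0)
    adapter.close()  # clears all connection pools, as in the retry path above
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        # New requests should transparently get fresh pools after the clear.
        assert all(code == 200 for code in pool.map(fetch, range(16)))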

# Emulates the backoff_factor sleep logic of urllib3's Retry
sleep_s = self.backoff_factor * (2**num_prev_retries)
print("Retrying request after error:", e, file=sys.stderr)
@@ -669,14 +688,16 @@ def __init__(self, base_url: str, adapter: HTTPAdapter | None = None):
def ping(self) -> bool:
try:
resp = self.get("ping")
_state.set_user_info_if_null(resp.json())
return resp.ok
except requests.exceptions.ConnectionError:
return False

def make_long_lived(self) -> None:
if not self.adapter:
self.adapter = RetryRequestExceptionsAdapter(base_num_retries=10, backoff_factor=0.5)
timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
self.adapter = RetryRequestExceptionsAdapter(
base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs
)
self._reset()

@staticmethod
@@ -721,6 +742,8 @@ def delete(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
return self.session.delete(_urljoin(self.base_url, path), *args, **kwargs)

def get_json(self, object_type: str, args: Mapping[str, Any] | None = None, retries: int = 0) -> Mapping[str, Any]:
# FIXME[matt]: the retry logic here seems to be unused, and total attempts could
# multiply (n*2) with the retry logic in RetryRequestExceptionsAdapter. We should probably remove this.
tries = retries + 1
for i in range(tries):
resp = self.get(f"/{object_type}", params=args)
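A note on the new BRAINTRUST_HTTP_TIMEOUT override: parse_env_var_float is imported alongside the other helpers at the top of the diff, and a hedged sketch of its presumed behavior (an assumption, not the actual implementation) is:

    import os

    def parse_env_var_float(name: str, default: float) -> float:
        # Assumed behavior: read the variable as a float, falling back to the
        # default when it is unset or not parseable.
        raw = os.environ.get(name)
        if raw is None:
            return default
        try:
            return float(raw)
        except ValueError:
            return default

So exporting BRAINTRUST_HTTP_TIMEOUT=5 before the client is configured would presumably lower the default request timeout from 60s to 5s.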