|
6 | 6 | import httpx |
7 | 7 | import urllib.parse |
8 | 8 | from overrides import override |
| 9 | +from tenacity import ( |
| 10 | + RetryError, |
| 11 | + Retrying, |
| 12 | + before_sleep_log, |
| 13 | + retry_if_exception, |
| 14 | + stop_after_attempt, |
| 15 | + wait_exponential, |
| 16 | + wait_random_exponential, |
| 17 | +) |
9 | 18 |
|
10 | 19 | from chromadb.api.collection_configuration import ( |
11 | 20 | CreateCollectionConfiguration, |
|
57 | 66 | logger = logging.getLogger(__name__) |
58 | 67 |
|
59 | 68 |
|
def is_retryable_exception(exception: BaseException) -> bool:
    """Tell whether a failed request should be attempted again.

    Args:
        exception: The exception raised by a request attempt.

    Returns:
        True for the httpx connection, timeout, network and remote-protocol
        errors listed below, and for an ``httpx.HTTPStatusError`` whose
        response status code is 502, 503 or 504. False for anything else.
    """
    transient_errors = (
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.WriteTimeout,
        httpx.PoolTimeout,
        httpx.NetworkError,
        httpx.RemoteProtocolError,
    )
    if isinstance(exception, transient_errors):
        return True

    if not isinstance(exception, httpx.HTTPStatusError):
        return False

    # Only server errors that might be temporary are worth another attempt.
    return exception.response.status_code in (502, 503, 504)
| 89 | + |
| 90 | + |
60 | 91 | class FastAPI(BaseHTTPClient, ServerAPI): |
61 | 92 | def __init__(self, system: System): |
62 | 93 | super().__init__(system) |
@@ -97,20 +128,62 @@ def __init__(self, system: System): |
97 | 128 | self._session.headers[header] = value.get_secret_value() |
98 | 129 |
|
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
    """Send an HTTP request to the API and return the decoded JSON body.

    When ``self._settings.retry_config`` is ``None`` the request is sent
    exactly once. Otherwise it is retried with exponential backoff for
    failures that ``is_retryable_exception`` accepts, up to
    ``retry_config.max_attempts`` attempts.

    Args:
        method: HTTP method handed to ``self._session.request``.
        path: Path appended to ``self._api_url`` after percent-escaping.
        **kwargs: Extra arguments forwarded to ``self._session.request``.
            A ``json`` entry is serialized with orjson and sent as
            ``content`` instead.

    Returns:
        The response text parsed with ``orjson.loads``.

    Raises:
        Any exception raised by the session or by
        ``BaseHTTPClient._raise_chroma_error``. When retries are exhausted,
        the exception from the last attempt propagates unchanged.
    """
    # Prepare the request once, outside the retried callable: the payload
    # and URL are identical for every attempt, so retries only repeat the
    # network call and never touch ``kwargs`` again.
    if "json" in kwargs:
        # httpx uses a slower json serializer, so serialize with orjson and
        # pass the bytes through the ``content`` parameter instead.
        kwargs["content"] = orjson.dumps(
            kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY
        )

    # Unlike requests, httpx does not automatically escape the path
    url = self._api_url + urllib.parse.quote(
        path, safe="/", encoding=None, errors=None
    )

    def _send_request() -> Any:
        # A single attempt: send, surface errors, decode the body.
        response = self._session.request(method, url, **cast(Any, kwargs))
        BaseHTTPClient._raise_chroma_error(response)
        return orjson.loads(response.text)

    retry_config = self._settings.retry_config
    if retry_config is None:
        return _send_request()

    # Sanitize the configured delays: never negative, and max >= min.
    min_delay = max(float(retry_config.min_delay), 0.0)
    max_delay = max(float(retry_config.max_delay), min_delay)

    wait_args = {
        # Keep the multiplier strictly positive so the backoff can still
        # grow when min_delay is configured as 0.
        "multiplier": max(min_delay, 1e-3),
        "min": min_delay,
        "max": max_delay,
        # Fall back to doubling when the configured factor is not positive.
        "exp_base": retry_config.factor if retry_config.factor > 0 else 2.0,
    }
    wait_factory = (
        wait_random_exponential if retry_config.jitter else wait_exponential
    )

    retrying = Retrying(
        stop=stop_after_attempt(retry_config.max_attempts),
        wait=wait_factory(**wait_args),
        retry=retry_if_exception(is_retryable_exception),
        before_sleep=before_sleep_log(logger, logging.INFO),
        # With reraise=True tenacity re-raises the last attempt's own
        # exception instead of wrapping it in RetryError, so no RetryError
        # handler is needed around the call below.
        reraise=True,
    )
    return retrying(_send_request)
114 | 187 |
|
115 | 188 | @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION) |
116 | 189 | @override |
|
0 commit comments