Skip to content

Commit 11707a1

Browse files
Add thread_pool_executor to constructor (#536)
1 parent add37bf commit 11707a1

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/cohere/client.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def fn(*args, **kwargs):
7171

7272

7373
class Client(BaseCohere, CacheMixin):
74+
_executor: ThreadPoolExecutor
75+
7476
def __init__(
7577
self,
7678
api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
@@ -80,10 +82,13 @@ def __init__(
8082
client_name: typing.Optional[str] = None,
8183
timeout: typing.Optional[float] = None,
8284
httpx_client: typing.Optional[httpx.Client] = None,
85+
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64)
8386
):
8487
if api_key is None:
8588
api_key = _get_api_key_from_environment()
8689

90+
self._executor = thread_pool_executor
91+
8792
BaseCohere.__init__(
8893
self,
8994
base_url=base_url,
@@ -108,8 +113,6 @@ def __exit__(self, exc_type, exc_value, traceback):
108113

109114
wait = wait
110115

111-
_executor = ThreadPoolExecutor(64)
112-
113116
def embed(
114117
self,
115118
*,
@@ -250,6 +253,8 @@ def fetch_tokenizer(self, *, model: str) -> Tokenizer:
250253

251254

252255
class AsyncClient(AsyncBaseCohere, CacheMixin):
256+
_executor: ThreadPoolExecutor
257+
253258
def __init__(
254259
self,
255260
api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
@@ -259,10 +264,13 @@ def __init__(
259264
client_name: typing.Optional[str] = None,
260265
timeout: typing.Optional[float] = None,
261266
httpx_client: typing.Optional[httpx.AsyncClient] = None,
267+
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64)
262268
):
263269
if api_key is None:
264270
api_key = _get_api_key_from_environment()
265271

272+
self._executor = thread_pool_executor
273+
266274
AsyncBaseCohere.__init__(
267275
self,
268276
base_url=base_url,
@@ -287,8 +295,6 @@ async def __aexit__(self, exc_type, exc_value, traceback):
287295

288296
wait = async_wait
289297

290-
_executor = ThreadPoolExecutor(64)
291-
292298
async def embed(
293299
self,
294300
*,

0 commit comments

Comments
 (0)