@@ -71,6 +71,8 @@ def fn(*args, **kwargs):
7171
7272
7373class Client (BaseCohere , CacheMixin ):
74+ _executor : ThreadPoolExecutor
75+
7476 def __init__ (
7577 self ,
7678 api_key : typing .Optional [typing .Union [str , typing .Callable [[], str ]]] = None ,
@@ -80,10 +82,13 @@ def __init__(
8082 client_name : typing .Optional [str ] = None ,
8183 timeout : typing .Optional [float ] = None ,
8284 httpx_client : typing .Optional [httpx .Client ] = None ,
85+ thread_pool_executor : ThreadPoolExecutor = ThreadPoolExecutor (64 )
8386 ):
8487 if api_key is None :
8588 api_key = _get_api_key_from_environment ()
8689
90+ self ._executor = thread_pool_executor
91+
8792 BaseCohere .__init__ (
8893 self ,
8994 base_url = base_url ,
@@ -108,8 +113,6 @@ def __exit__(self, exc_type, exc_value, traceback):
108113
109114 wait = wait
110115
111- _executor = ThreadPoolExecutor (64 )
112-
113116 def embed (
114117 self ,
115118 * ,
@@ -250,6 +253,8 @@ def fetch_tokenizer(self, *, model: str) -> Tokenizer:
250253
251254
252255class AsyncClient (AsyncBaseCohere , CacheMixin ):
256+ _executor : ThreadPoolExecutor
257+
253258 def __init__ (
254259 self ,
255260 api_key : typing .Optional [typing .Union [str , typing .Callable [[], str ]]] = None ,
@@ -259,10 +264,13 @@ def __init__(
259264 client_name : typing .Optional [str ] = None ,
260265 timeout : typing .Optional [float ] = None ,
261266 httpx_client : typing .Optional [httpx .AsyncClient ] = None ,
267+ thread_pool_executor : ThreadPoolExecutor = ThreadPoolExecutor (64 )
262268 ):
263269 if api_key is None :
264270 api_key = _get_api_key_from_environment ()
265271
272+ self ._executor = thread_pool_executor
273+
266274 AsyncBaseCohere .__init__ (
267275 self ,
268276 base_url = base_url ,
@@ -287,8 +295,6 @@ async def __aexit__(self, exc_type, exc_value, traceback):
287295
288296 wait = async_wait
289297
290- _executor = ThreadPoolExecutor (64 )
291-
292298 async def embed (
293299 self ,
294300 * ,
0 commit comments