11import asyncio
2+ import logging
23import os
34import typing
45from concurrent .futures import ThreadPoolExecutor
5- from tokenizers import Tokenizer # type: ignore
6- import logging
76
87import httpx
9-
10- from cohere .types .detokenize_response import DetokenizeResponse
11- from cohere .types .tokenize_response import TokenizeResponse
12-
13- from . import EmbedResponse , EmbedInputType , EmbeddingType , EmbedRequestTruncate
14- from .base_client import BaseCohere , AsyncBaseCohere , OMIT
8+ from . import EmbeddingType , EmbedInputType , EmbedRequestTruncate , EmbedResponse
9+ from .base_client import OMIT , AsyncBaseCohere , BaseCohere
1510from .config import embed_batch_size
1611from .core import RequestOptions
1712from .environment import ClientEnvironment
18- from .manually_maintained .cache import CacheMixin
1913from .manually_maintained import tokenizers as local_tokenizers
14+ from .manually_maintained .cache import CacheMixin
2015from .overrides import run_overrides
21- from .utils import wait , async_wait , merge_embed_responses , SyncSdkUtils , AsyncSdkUtils
16+ from .utils import AsyncSdkUtils , SyncSdkUtils , async_wait , merge_embed_responses , wait
17+ from tokenizers import Tokenizer # type: ignore
18+
19+ from cohere .types .detokenize_response import DetokenizeResponse
20+ from cohere .types .tokenize_response import TokenizeResponse
2221
2322logger = logging .getLogger (__name__ )
2423run_overrides ()
@@ -188,6 +187,8 @@ def embed(
188187 truncate : typing .Optional [EmbedRequestTruncate ] = OMIT ,
189188 request_options : typing .Optional [RequestOptions ] = None ,
190189 batching : typing .Optional [bool ] = True ,
190+ batch_size : typing .Optional [int ] = None ,
191+ max_workers : typing .Optional [int ] = None ,
191192 ) -> EmbedResponse :
192193 # skip batching for images for now
193194 if batching is False or images is not OMIT :
@@ -203,23 +204,34 @@ def embed(
203204 )
204205
205206 textsarr : typing .Sequence [str ] = texts if texts is not OMIT and texts is not None else []
206- texts_batches = [textsarr [i : i + embed_batch_size ] for i in range (0 , len (textsarr ), embed_batch_size )]
207-
208- responses = [
209- response
210- for response in self ._executor .map (
211- lambda text_batch : BaseCohere .embed (
212- self ,
213- texts = text_batch ,
214- model = model ,
215- input_type = input_type ,
216- embedding_types = embedding_types ,
217- truncate = truncate ,
218- request_options = request_options ,
219- ),
220- texts_batches ,
221- )
222- ]
207+ effective_batch_size = batch_size if batch_size is not None else embed_batch_size
208+ texts_batches = [textsarr [i : i + effective_batch_size ] for i in range (0 , len (textsarr ), effective_batch_size )]
209+
210+ # Use custom executor if max_workers is specified
211+ executor = self ._executor
212+ if max_workers is not None :
213+ executor = ThreadPoolExecutor (max_workers = max_workers )
214+
215+ try :
216+ responses = [
217+ response
218+ for response in executor .map (
219+ lambda text_batch : BaseCohere .embed (
220+ self ,
221+ texts = text_batch ,
222+ model = model ,
223+ input_type = input_type ,
224+ embedding_types = embedding_types ,
225+ truncate = truncate ,
226+ request_options = request_options ,
227+ ),
228+ texts_batches ,
229+ )
230+ ]
231+ finally :
232+ # Clean up custom executor if created
233+ if max_workers is not None :
234+ executor .shutdown (wait = False )
223235
224236 return merge_embed_responses (responses )
225237
@@ -380,6 +392,8 @@ async def embed(
380392 truncate : typing .Optional [EmbedRequestTruncate ] = OMIT ,
381393 request_options : typing .Optional [RequestOptions ] = None ,
382394 batching : typing .Optional [bool ] = True ,
395+ batch_size : typing .Optional [int ] = None ,
396+ max_workers : typing .Optional [int ] = None ,
383397 ) -> EmbedResponse :
384398 # skip batching for images for now
385399 if batching is False or images is not OMIT :
@@ -395,8 +409,15 @@ async def embed(
395409 )
396410
397411 textsarr : typing .Sequence [str ] = texts if texts is not OMIT and texts is not None else []
398- texts_batches = [textsarr [i : i + embed_batch_size ] for i in range (0 , len (textsarr ), embed_batch_size )]
399-
412+ effective_batch_size = batch_size if batch_size is not None else embed_batch_size
413+ texts_batches = [textsarr [i : i + effective_batch_size ] for i in range (0 , len (textsarr ), effective_batch_size )]
414+
415+ # Note: max_workers is accepted but unused in the async version; asyncio.gather
416+ # schedules the batch coroutines itself, so no ThreadPoolExecutor is involved.
417+ if max_workers is not None :
418+ # Silently ignored: asyncio manages its own concurrency, so there is nothing to configure.
419+ pass
420+
400421 responses = typing .cast (
401422 typing .List [EmbedResponse ],
402423 await asyncio .gather (
0 commit comments