
Commit 0a61a81

Fede Kamelhar (fede-kamel) authored and committed
feat: Add configurable batch_size and max_workers to embed method
Fixes #534

This PR makes the embed batch size configurable, allowing users to customize the batch size based on their specific use cases and constraints.

Changes:
- Add optional batch_size parameter to Client.embed() and AsyncClient.embed()
- Add optional max_workers parameter to Client.embed() for thread pool control
- Default behavior remains unchanged (batch_size=96 from config)
- Full backward compatibility maintained

The implementation allows users to:
- Use smaller batches to reduce memory usage
- Use larger batches to reduce API calls
- Control thread pool size for rate limiting scenarios
- Optimize for their specific embedding model and text sizes
1 parent af1aee2 commit 0a61a81
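
At a glance, the new parameters slot straight into embed(). A minimal sketch (the model name and CO_API_KEY setup are illustrative, mirroring the demo script below):

# Illustrative quick example; assumes CO_API_KEY is set in the environment.
import cohere

client = cohere.Client()
response = client.embed(
    texts=["hello", "world"],
    model="embed-english-v3.0",
    input_type="search_document",
    batch_size=5,    # texts per API call; defaults to 96 from config
    max_workers=2,   # threads used to send batches concurrently
)
print(len(response.embeddings))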

File tree

3 files changed: +386 -29 lines changed


demo_configurable_batch_size.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""
Demo script for the configurable batch size feature in Cohere SDK.

This demonstrates how to use the new batch_size and max_workers parameters
to control embedding batch processing.
"""

import os
import time
import cohere

# Initialize client (requires CO_API_KEY environment variable)
client = cohere.Client()

# Sample texts for embedding
texts = [f"Text document number {i}" for i in range(20)]

print(f"Embedding {len(texts)} texts...")
print()

# Example 1: Default behavior (batch_size=96)
print("1. Default behavior (batch_size=96):")
start = time.time()
response = client.embed(
    texts=texts,
    model="embed-english-v3.0",
    input_type="search_document"
)
print(f" Time: {time.time() - start:.2f}s")
print(f" Number of embeddings: {len(response.embeddings)}")
print()

# Example 2: Custom small batch size
print("2. Custom small batch size (batch_size=5):")
start = time.time()
response = client.embed(
    texts=texts,
    model="embed-english-v3.0",
    input_type="search_document",
    batch_size=5  # Will make 4 API calls for 20 texts
)
print(f" Time: {time.time() - start:.2f}s")
print(f" Number of embeddings: {len(response.embeddings)}")
print()

# Example 3: Custom batch size with fewer workers
print("3. Custom batch size with fewer workers (batch_size=5, max_workers=2):")
start = time.time()
response = client.embed(
    texts=texts,
    model="embed-english-v3.0",
    input_type="search_document",
    batch_size=5,
    max_workers=2  # Limit concurrency to 2 threads
)
print(f" Time: {time.time() - start:.2f}s")
print(f" Number of embeddings: {len(response.embeddings)}")
print()

# Example 4: Large batch size (all in one API call)
print("4. Large batch size (batch_size=100):")
start = time.time()
response = client.embed(
    texts=texts,
    model="embed-english-v3.0",
    input_type="search_document",
    batch_size=100  # All texts in a single API call
)
print(f" Time: {time.time() - start:.2f}s")
print(f" Number of embeddings: {len(response.embeddings)}")
print()

print("Demo completed!")
print()
print("Key benefits of configurable batch size:")
print("- batch_size: Control memory usage and API call granularity")
print("- max_workers: Control concurrency for rate limiting or resource constraints")
print("- Backward compatible: Defaults to existing behavior if not specified")

src/cohere/client.py

Lines changed: 50 additions & 29 deletions
@@ -1,24 +1,23 @@
 import asyncio
+import logging
 import os
 import typing
 from concurrent.futures import ThreadPoolExecutor
-from tokenizers import Tokenizer  # type: ignore
-import logging

 import httpx
-
-from cohere.types.detokenize_response import DetokenizeResponse
-from cohere.types.tokenize_response import TokenizeResponse
-
-from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
-from .base_client import BaseCohere, AsyncBaseCohere, OMIT
+from . import EmbeddingType, EmbedInputType, EmbedRequestTruncate, EmbedResponse
+from .base_client import OMIT, AsyncBaseCohere, BaseCohere
 from .config import embed_batch_size
 from .core import RequestOptions
 from .environment import ClientEnvironment
-from .manually_maintained.cache import CacheMixin
 from .manually_maintained import tokenizers as local_tokenizers
+from .manually_maintained.cache import CacheMixin
 from .overrides import run_overrides
-from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils
+from .utils import AsyncSdkUtils, SyncSdkUtils, async_wait, merge_embed_responses, wait
+from tokenizers import Tokenizer  # type: ignore
+
+from cohere.types.detokenize_response import DetokenizeResponse
+from cohere.types.tokenize_response import TokenizeResponse

 logger = logging.getLogger(__name__)
 run_overrides()
@@ -188,6 +187,8 @@ def embed(
         truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
         request_options: typing.Optional[RequestOptions] = None,
         batching: typing.Optional[bool] = True,
+        batch_size: typing.Optional[int] = None,
+        max_workers: typing.Optional[int] = None,
     ) -> EmbedResponse:
         # skip batching for images for now
         if batching is False or images is not OMIT:
@@ -203,23 +204,34 @@
             )

         textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
-        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]
-
-        responses = [
-            response
-            for response in self._executor.map(
-                lambda text_batch: BaseCohere.embed(
-                    self,
-                    texts=text_batch,
-                    model=model,
-                    input_type=input_type,
-                    embedding_types=embedding_types,
-                    truncate=truncate,
-                    request_options=request_options,
-                ),
-                texts_batches,
-            )
-        ]
+        effective_batch_size = batch_size if batch_size is not None else embed_batch_size
+        texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)]
+
+        # Use custom executor if max_workers is specified
+        executor = self._executor
+        if max_workers is not None:
+            executor = ThreadPoolExecutor(max_workers=max_workers)
+
+        try:
+            responses = [
+                response
+                for response in executor.map(
+                    lambda text_batch: BaseCohere.embed(
+                        self,
+                        texts=text_batch,
+                        model=model,
+                        input_type=input_type,
+                        embedding_types=embedding_types,
+                        truncate=truncate,
+                        request_options=request_options,
+                    ),
+                    texts_batches,
+                )
+            ]
+        finally:
+            # Clean up custom executor if created
+            if max_workers is not None:
+                executor.shutdown(wait=False)

         return merge_embed_responses(responses)

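Stripped of diff noise, the new sync path is: slice the texts into batch_size chunks, map the chunks over a thread pool, then merge the per-batch responses. A self-contained sketch of that pattern, where fake_embed and embed_in_batches are illustrative stand-ins rather than SDK APIs:

# Standalone sketch of the chunk-and-map pattern used above; fake_embed
# stands in for BaseCohere.embed and is purely illustrative.
from concurrent.futures import ThreadPoolExecutor
from typing import List, Sequence

def fake_embed(batch: Sequence[str]) -> List[List[float]]:
    # Placeholder for one embed API call over a single batch.
    return [[0.0] * 4 for _ in batch]

def embed_in_batches(texts: Sequence[str], batch_size: int, max_workers: int) -> List[List[float]]:
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(fake_embed, batches))  # map preserves batch order
    merged: List[List[float]] = []
    for result in results:
        merged.extend(result)  # analogous to merge_embed_responses
    return merged

print(len(embed_in_batches([f"doc {i}" for i in range(20)], batch_size=5, max_workers=2)))  # 20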
@@ -380,6 +392,8 @@ async def embed(
         truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
         request_options: typing.Optional[RequestOptions] = None,
         batching: typing.Optional[bool] = True,
+        batch_size: typing.Optional[int] = None,
+        max_workers: typing.Optional[int] = None,
     ) -> EmbedResponse:
         # skip batching for images for now
         if batching is False or images is not OMIT:
@@ -395,8 +409,15 @@
             )

         textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
-        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]
-
+        effective_batch_size = batch_size if batch_size is not None else embed_batch_size
+        texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)]
+
+        # Note: max_workers parameter is not used in async version since asyncio.gather
+        # handles concurrency differently than ThreadPoolExecutor
+        if max_workers is not None:
+            # Log a warning or silently ignore - asyncio manages its own concurrency
+            pass
+
         responses = typing.cast(
             typing.List[EmbedResponse],
             await asyncio.gather(
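
As the comment in this hunk notes, the async path accepts max_workers but ignores it, because asyncio.gather already runs the batch requests concurrently. For readers wondering what bounding concurrency would look like in asyncio, the usual pattern is a semaphore; the sketch below is illustrative only and not part of this commit:

# Not part of this commit: if bounded concurrency were ever wanted on the
# async path, an asyncio.Semaphore is the usual pattern. All names here
# (fake_embed_async, bounded_gather) are illustrative.
import asyncio
from typing import List, Sequence

async def fake_embed_async(batch: Sequence[str]) -> List[List[float]]:
    await asyncio.sleep(0.01)  # stand-in for one embed API call
    return [[0.0] * 4 for _ in batch]

async def bounded_gather(batches: Sequence[Sequence[str]], max_concurrency: int) -> List[List[List[float]]]:
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run(batch: Sequence[str]) -> List[List[float]]:
        async with semaphore:  # at most max_concurrency batches in flight
            return await fake_embed_async(batch)

    return await asyncio.gather(*(run(b) for b in batches))

texts = [f"doc {i}" for i in range(20)]
batches = [texts[i : i + 5] for i in range(0, len(texts), 5)]
print(len(asyncio.run(bounded_gather(batches, max_concurrency=2))))  # 4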
