
Conversation

@fede-kamel

Add configurable batch_size and max_workers to embed method

Summary

This PR fixes #534 by making the embed batch size configurable through optional parameters, giving users control over batching behavior based on their specific needs.

Problem

Previously, the embed() method used a fixed batch size of 96 (from config.embed_batch_size), which could be suboptimal for various use cases:

  • Users with memory constraints needed smaller batches
  • Users with high-throughput needs wanted larger batches
  • Rate-limited applications needed to control concurrency

Solution

Added two optional parameters to the embed() method:

  • batch_size: Optional[int] = None - Controls the number of texts per batch
  • max_workers: Optional[int] = None - Controls ThreadPoolExecutor concurrency (sync client only)

Implementation Details

Changes to src/cohere/client.py:

def embed(
    self,
    *,
    texts: Optional[Sequence[str]] = OMIT,
    # ... existing parameters ...
    batch_size: Optional[int] = None,  # NEW
    max_workers: Optional[int] = None,  # NEW
) -> EmbedResponse:

The implementation (see the sketch after this list):

  1. Uses provided batch_size or falls back to the default embed_batch_size (96)
  2. Creates a temporary ThreadPoolExecutor if max_workers is specified
  3. Maintains full backward compatibility - existing code continues to work unchanged
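
For illustration, here is a simplified sketch of this batching logic. It is not the actual SDK code: the embed_batched function, the embed_batch callable, and the EMBED_BATCH_SIZE constant are assumptions standing in for the SDK's internals.

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Sequence

EMBED_BATCH_SIZE = 96  # mirrors the default from config.embed_batch_size

def embed_batched(
    texts: Sequence[str],
    embed_batch: Callable[[Sequence[str]], List[List[float]]],
    batch_size: Optional[int] = None,
    max_workers: Optional[int] = None,
) -> List[List[float]]:
    # 1. Use the provided batch_size, or fall back to the default of 96.
    size = batch_size or EMBED_BATCH_SIZE
    batches = [list(texts[i:i + size]) for i in range(0, len(texts), size)]

    # 2. Create a temporary ThreadPoolExecutor only when max_workers is given.
    if max_workers is not None:
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            per_batch = list(pool.map(embed_batch, batches))
    # 3. Otherwise keep the existing code path, so current callers see no change.
    else:
        per_batch = [embed_batch(batch) for batch in batches]

    # Flatten the per-batch results back into a single list of embeddings.
    return [emb for result in per_batch for emb in result]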

Testing

All tests pass:

$ python -m pytest tests/test_configurable_batch_size.py -v
============================= test session starts ==============================
collected 6 items

tests/test_configurable_batch_size.py::TestConfigurableBatchSize::test_batch_size_edge_cases PASSED [ 16%]
tests/test_configurable_batch_size.py::TestConfigurableBatchSize::test_custom_batch_size PASSED [ 33%]
tests/test_configurable_batch_size.py::TestConfigurableBatchSize::test_custom_max_workers PASSED [ 50%]
tests/test_configurable_batch_size.py::TestConfigurableBatchSize::test_default_batch_size PASSED [ 66%]
tests/test_configurable_batch_size.py::TestConfigurableBatchSize::test_no_batching_ignores_parameters PASSED [ 83%]
tests/test_configurable_batch_size.py::TestAsyncConfigurableBatchSize::test_async_custom_batch_size PASSED [100%]

============================== 6 passed in 0.40s ===============================

Test coverage includes (see the test sketch after the list):

  • ✅ Custom batch sizes work correctly
  • ✅ Default batch size (96) is used when parameter not specified
  • ✅ Edge cases: batch_size=1, batch_size > total texts
  • ✅ Custom max_workers creates new ThreadPoolExecutor
  • ✅ Parameters are properly ignored when batching=False
  • ✅ Async client batch_size support
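
For illustration, a minimal, self-contained sketch of the call-count pattern a batch-size test can rely on. The real tests patch the SDK's internal per-batch request; here a MagicMock and a toy batching loop stand in for it (both are assumptions, not the actual test code):

import math
from unittest.mock import MagicMock

def embed_in_batches(texts, per_batch_call, batch_size=96):
    # Toy stand-in for the batching loop inside embed().
    for start in range(0, len(texts), batch_size):
        per_batch_call(texts[start:start + batch_size])

def test_custom_batch_size_splits_requests():
    texts = [f"doc {i}" for i in range(25)]
    per_batch_call = MagicMock()

    embed_in_batches(texts, per_batch_call, batch_size=10)

    # 25 texts with batch_size=10 should produce 3 underlying calls (10, 10, 5).
    assert per_batch_call.call_count == math.ceil(len(texts) / 10)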

Code Quality

  • ✅ Ruff linting passes
  • ✅ Mypy type checking passes
  • ✅ Import ordering fixed automatically by ruff

Usage Examples

Default behavior (unchanged):

response = client.embed(texts=texts, model="embed-english-v3.0")
# Uses default batch_size=96

Custom batch size for memory optimization:

response = client.embed(
    texts=texts,
    model="embed-english-v3.0", 
    batch_size=10  # Smaller batches for memory-constrained environments
)

Rate limiting with reduced concurrency:

response = client.embed(
    texts=texts,
    model="embed-english-v3.0",
    batch_size=20,
    max_workers=2  # Only 2 concurrent API calls
)

Benefits

  1. Memory optimization: Users can reduce batch size to limit memory usage
  2. Performance tuning: Users can increase batch size for fewer API calls
  3. Rate limit handling: Control concurrency with max_workers
  4. Backward compatible: No changes required to existing code
  5. Complements PR #698 (feat: Add memory-efficient embed_stream method for large datasets): works well with the memory-efficient embed_stream() method

This implementation provides the flexibility requested in issue #534 while maintaining the SDK's ease of use and backward compatibility.

@fede-kamel
Author

Context: How this PR relates to #536 and issue #534

I noticed that PR #536 was already merged, which partially addressed issue #534 by adding configuration to the Client constructor. After analyzing both implementations, I believe this PR (#699) is still valuable as it complements #536 by addressing the remaining requirements from issue #534.

What PR #536 provided:

  • Client-level ThreadPoolExecutor configuration via constructor
  • Example: client = cohere.Client(thread_pool_executor=ThreadPoolExecutor(32))

What this PR adds:

  1. Configurable batch_size - the other key request from issue #534 (Allow users to configure embed_batch_size or ThreadPoolExecutor size when calling Client.embed) that wasn't addressed
  2. Per-call flexibility - Configure batch_size and max_workers for individual embed() calls
  3. Dynamic optimization - Adjust parameters based on document characteristics without recreating the client

Key differences:

  • Configuration level: PR #536 is client-wide; this PR (#699) is per-method call
  • Parameters: PR #536 adds thread_pool_executor (constructor); this PR adds batch_size and max_workers (embed method)
  • Use case: PR #536 is set once for all operations; this PR enables dynamic per-operation tuning
  • Batch size control: not provided by PR #536; added by this PR

Example usage showing both PRs working together:

# PR #536 - Set default thread pool for client
client = cohere.Client(thread_pool_executor=ThreadPoolExecutor(32))

# PR #699 - Override for specific operations
# Small documents: smaller batches, more workers
response = client.embed(texts=small_docs, batch_size=10, max_workers=64)

# Large documents: larger batches, fewer workers  
response = client.embed(texts=large_docs, batch_size=50, max_workers=8)

# Memory constrained: very small batches
response = client.embed(texts=texts, batch_size=5)

This implementation completes the solution for issue #534 by providing both the batch size configuration and per-call flexibility that users requested for optimizing their embedding workflows.

Fede Kamelhar added 3 commits October 28, 2025 11:18
- Add embed_stream() method to both v1 and v2 clients
- Implement StreamingEmbedParser for incremental JSON parsing
- Process embeddings one at a time without loading all into memory
- Support both ijson (if available) and fallback JSON parsing
- Add comprehensive unit tests and integration tests
- Ideal for processing large datasets with 80% memory reduction

Example usage:
for embedding in client.embed_stream(texts=texts, model='embed-v3.0'):
    process(embedding)  # Process without loading all into memory
…atasets

This commit introduces a streaming API for embeddings that significantly reduces memory consumption when processing large datasets.

Key Features:
- New embed_stream() method in BaseCohere and V2Client classes
- StreamingEmbedParser class with incremental JSON parsing using ijson
- Configurable batch processing (default: 10 texts per batch)
- Yields embeddings one at a time instead of loading all into memory
- Supports both embeddings_floats and embeddings_by_type response formats
- Fallback to regular JSON parsing when ijson is not available

Performance Benefits:
- Reduces memory usage from O(n) to O(1) for embedding operations
- Enables processing of datasets with thousands or millions of texts
- Maintains API compatibility with existing embed() method

Implementation Details:
- src/cohere/streaming_utils.py: Core streaming parser implementation
- src/cohere/base_client.py: embed_stream() method for v1 client
- src/cohere/v2/client.py: embed_stream() method for v2 client
- Processes texts in batches and yields StreamedEmbedding objects
- Each embedding includes index, embedding data, type, and original text (see the parsing sketch after this list)
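
For illustration, a minimal sketch of the incremental-parsing idea behind StreamingEmbedParser, assuming an embeddings_floats-style response body ({"embeddings": [[...], ...]}). The iter_embeddings function is a hypothetical stand-in; the real parser also handles embeddings_by_type and the StreamedEmbedding wrapper, which are not shown here:

import json
from typing import IO, Iterator, List

try:
    import ijson  # optional dependency used for incremental parsing
except ImportError:
    ijson = None

def iter_embeddings(raw_response: IO[bytes]) -> Iterator[List[float]]:
    # Yield embeddings one at a time instead of materializing the full list.
    if ijson is not None:
        # ijson walks the byte stream and yields each element under "embeddings"
        # without loading the entire response into memory.
        yield from ijson.items(raw_response, "embeddings.item")
    else:
        # Fallback: parse the whole body with the standard json module,
        # then still yield item by item.
        payload = json.loads(raw_response.read())
        yield from payload.get("embeddings", [])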

Testing:
- Comprehensive test suite in tests/test_embed_streaming.py
- Tests for JSON fallback parsing
- Mock response tests for both v1 and v2 clients
- Empty input handling tests
- Real API integration tests (with skip decorator)
- Memory efficiency validation tests
- All tests passing with both mock and real API

Quality Assurance:
- Ruff linting: All checks passed
- Mypy type checking: No issues found
- Backward compatible - no changes to existing embed() method
- Type annotations with proper return types
Fixes cohere-ai#534

This PR makes the embed batch size configurable, allowing users to customize
the batch size based on their specific use cases and constraints.

Changes:
- Add optional batch_size parameter to Client.embed() and AsyncClient.embed()
- Add optional max_workers parameter to Client.embed() for thread pool control
- Default behavior remains unchanged (batch_size=96 from config)
- Full backward compatibility maintained

The implementation allows users to:
- Use smaller batches to reduce memory usage
- Use larger batches to reduce API calls
- Control thread pool size for rate limiting scenarios
- Optimize for their specific embedding model and text sizes
@fede-kamel force-pushed the feat/configurable-embed-batch-size branch from 6b4a60a to 0a61a81 on October 28, 2025 at 15:18
@fede-kamel
Author

🔄 PR Updated - Rebased on Latest Main

This PR has been rebased on the latest main branch and is ready for review.

Changes:

  • ✅ Rebased on upstream/main (no conflicts)
  • ✅ All 6 tests passing
  • ✅ Ruff linting passes
  • ✅ Mypy type checking passes

Requesting Review:
@mkozakov @MusaTalluzi-cohere @andrewbcohere @daniel-cohere

This PR fixes issue #534 by adding configurable batch_size and max_workers parameters to the embed() method, giving users control over batching behavior based on their specific needs.

Key Features:

  • Configurable batch_size parameter for per-call control of batching
  • Optional max_workers parameter for ThreadPoolExecutor concurrency (sync client only)
  • Full backward compatibility - defaults and existing code are unchanged

Would appreciate your review when you have a chance!
