Skip to content

Commit

Permalink
Enhance embed_documents to make use of Cohere's ability to request …
Browse files Browse the repository at this point in the history
…multiple embeddings at once (#350)

The [Cohere embedding
provider](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html)
supports asking for multiple embeddings at the same time.

The current implementation of `embed_documents` always iterates across
the submitted texts and attempts to embed them one-by-one.

This PR updates this `embed_documents` logic to check whether we are
using the cohere provider, and if so request all of the embeddings at
once.
  • Loading branch information
jimfingal authored Feb 11, 2025
1 parent 6355b0f commit 925cd56
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
61 changes: 48 additions & 13 deletions libs/aws/langchain_aws/embeddings/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class BedrockEmbeddings(BaseModel, Embeddings):
protected_namespaces=(),
)

@property
def provider(self) -> str:
"""Provider of the model."""
return self.model_id.split(".")[0]

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""
Expand Down Expand Up @@ -121,20 +126,38 @@ def validate_environment(self) -> Self:
return self

def _embedding_func(self, text: str) -> List[float]:
"""Call out to Bedrock embedding endpoint."""
"""Call out to Bedrock embedding endpoint with a single text."""
# replace newlines, which can negatively affect performance.
text = text.replace(os.linesep, " ")

# format input body for provider
provider = self.model_id.split(".")[0]
input_body: Dict[str, Any] = {}
if provider == "cohere":
input_body["input_type"] = "search_document"
input_body["texts"] = [text]
if self.provider == "cohere":
response_body = self._invoke_model(
input_body={
"input_type": "search_document",
"texts": [text],
}
)
return response_body.get("embeddings")[0]
else:
# includes common provider == "amazon"
input_body["inputText"] = text
response_body = self._invoke_model(
input_body={"inputText": text},
)
return response_body.get("embedding")

def _cohere_multi_embedding(self, texts: List[str]) -> List[float]:
"""Call out to Cohere Bedrock embedding endpoint with multiple inputs."""
# replace newlines, which can negatively affect performance.
texts = [text.replace(os.linesep, " ") for text in texts]

return self._invoke_model(
input_body={
"input_type": "search_document",
"texts": texts,
}
).get("embeddings")

def _invoke_model(self, input_body: Dict[str, Any] = {}) -> Dict[str, Any]:
if self.model_kwargs:
input_body = {**input_body, **self.model_kwargs}

Expand All @@ -149,11 +172,7 @@ def _embedding_func(self, text: str) -> List[float]:
)

response_body = json.loads(response.get("body").read())
if provider == "cohere":
return response_body.get("embeddings")[0]
else:
return response_body.get("embedding")

return response_body
except Exception as e:
logging.error(f"Error raised by inference endpoint: {e}")
raise e
Expand All @@ -173,6 +192,22 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
Returns:
List of embeddings, one for each text.
"""

# If we are able to make use of Cohere's multiple embeddings, use that
if self.provider == "cohere":
return self._embed_cohere_documents(texts)
else:
return self._iteratively_embed_documents(texts)

def _embed_cohere_documents(self, texts: List[str]) -> List[List[float]]:
response = self._cohere_multi_embedding(texts)

if self.normalize:
response = [self._normalize_vector(embedding) for embedding in response]

return response

def _iteratively_embed_documents(self, texts: List[str]) -> List[List[float]]:
results = []
for text in texts:
response = self._embedding_func(text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ def bedrock_embeddings_v2() -> BedrockEmbeddings:
)


@pytest.fixture
def cohere_embeddings_v3() -> BedrockEmbeddings:
return BedrockEmbeddings(
model_id="cohere.embed-english-v3",
)


@pytest.mark.scheduled
def test_bedrock_embedding_documents(bedrock_embeddings) -> None:
documents = ["foo bar"]
Expand Down Expand Up @@ -101,3 +108,21 @@ def test_embed_query_with_size(bedrock_embeddings_v2) -> None:
output = bedrock_embeddings_v2.embed_query(prompt_data)
assert len(response[0]) == 256
assert len(output) == 256


@pytest.mark.scheduled
def test_bedrock_cohere_embedding_documents(cohere_embeddings_v3) -> None:
documents = ["foo bar"]
output = cohere_embeddings_v3.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 1024


@pytest.mark.scheduled
def test_bedrock_cohere_embedding_documents_multiple(cohere_embeddings_v3) -> None:
documents = ["foo bar", "bar foo", "foo"]
output = cohere_embeddings_v3.embed_documents(documents)
assert len(output) == 3
assert len(output[0]) == 1024
assert len(output[1]) == 1024
assert len(output[2]) == 1024

0 comments on commit 925cd56

Please sign in to comment.