refactor: HF API Embedders - use InferenceClient.feature_extraction instead of InferenceClient.post #8794

Merged · 9 commits · Feb 3, 2025
haystack/components/embedders/hugging_face_api_document_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union

 from tqdm import tqdm
@@ -96,8 +96,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
         batch_size: int = 32,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -124,13 +124,11 @@ def __init__(
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         :param batch_size:
             Number of documents to process at once.
         :param progress_bar:
@@ -239,18 +237,36 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[
"""
Embed a list of texts in batches.
"""
truncate = self.truncate
normalize = self.normalize

if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
if truncate is not None:
msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg)
truncate = None
if normalize is not None:
msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg)
normalize = None

all_embeddings = []
for i in tqdm(
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
response = self._client.post(
json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
task="feature-extraction",

np_embeddings = self._client.feature_extraction(
# this method does not officially support list of strings, but works as expected
text=batch, # type: ignore[arg-type]
Member Author: As explained in the code comment, the method does what we need, but this usage is not officially supported.

Contributor @Amnah199 (Feb 3, 2025): For my own understanding, I looked into the API. From this discussion, am I correct to deduce that both types, str and List[str], are supported for `text`? We are unsure because the docs don't officially mention List[str], but the underlying models do expect lists and return correct results.
In that case, it would make sense to introduce this change.

Member Author: I think you are basically right.
I reached out to the huggingface_hub maintainers here: huggingface/huggingface_hub#2824
You can read this message to get a better understanding.

+                truncate=truncate,
+                normalize=normalize,
             )
-            embeddings = json.loads(response.decode())
-            all_embeddings.extend(embeddings)
+
+            if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
+                raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}")
+
+            all_embeddings.extend(np_embeddings.tolist())

         return all_embeddings
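An aside on the thread above: this is a minimal standalone sketch of the pattern the hunk relies on, calling `InferenceClient.feature_extraction` with a list of strings even though the method is only typed for a single string. The model name and env var are assumptions borrowed from this PR's tests; a valid Hugging Face token and network access are required.

```python
import os

from huggingface_hub import InferenceClient

# Model and env var are assumptions taken from the tests in this PR.
client = InferenceClient(model="BAAI/bge-small-en-v1.5", token=os.environ.get("HF_API_TOKEN"))

# feature_extraction is typed for a single string, but a list works in practice
# (see huggingface/huggingface_hub#2824 referenced above).
np_embeddings = client.feature_extraction(text=["first text", "second text"])

# One row per input text: shape (2, embedding_dim), e.g. (2, 384) for this model.
print(np_embeddings.shape)
```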

35 changes: 25 additions & 10 deletions haystack/components/embedders/hugging_face_api_text_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union

 from haystack import component, default_from_dict, default_to_dict, logging
@@ -80,8 +80,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
     ):  # pylint: disable=too-many-positional-arguments
         """
         Creates a HuggingFaceAPITextEmbedder component.
@@ -104,13 +104,11 @@
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         """
         huggingface_hub_import.check()

@@ -198,12 +196,29 @@ def run(self, text: str):
"In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
)

truncate = self.truncate
normalize = self.normalize

if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
if truncate is not None:
msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg)
truncate = None
if normalize is not None:
msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg)
normalize = None

text_to_embed = self.prefix + text + self.suffix

response = self._client.post(
json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
task="feature-extraction",
)
embedding = json.loads(response.decode())[0]
np_embedding = self._client.feature_extraction(text=text_to_embed, truncate=truncate, normalize=normalize)

error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
if np_embedding.ndim > 2:
raise ValueError(error_msg)
if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
raise ValueError(error_msg)

embedding = np_embedding.flatten().tolist()

return {"embedding": embedding}
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -87,7 +87,7 @@ extra-dependencies = [
"numba>=0.54.0", # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881

"transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
"huggingface_hub>=0.27.0, <0.28.0", # Hugging Face API Generators and Embedders
"huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders
Member Author: Based on my local tests, this solution works with both the new and the old version, so we can safely remove the pin.

"sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
"langdetect", # TextLanguageRouter and DocumentLanguageClassifier
"openai-whisper>=20231106", # LocalWhisperTranscriber
New file (release note)
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    In the Hugging Face API embedders, the `InferenceClient.feature_extraction` method is now used instead of
+    `InferenceClient.post` to compute embeddings. This ensures a more robust and future-proof implementation.
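A minimal usage sketch of the user-facing behavior, which this refactor leaves unchanged; it is adapted from the tests below, uses the same model, and assumes `HF_API_TOKEN` or `HF_TOKEN` is set in the environment:

```python
from haystack.components.embedders import HuggingFaceAPITextEmbedder
from haystack.utils.hf import HFEmbeddingAPIType

# Serverless Inference API: the default truncate/normalize values are ignored
# there and now trigger warnings instead of being silently sent.
embedder = HuggingFaceAPITextEmbedder(
    api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "BAAI/bge-small-en-v1.5"},
)

result = embedder.run(text="The food was delicious")
print(len(result["embedding"]))  # embedding dimension, e.g. 384 for this model
```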
test/components/embedders/test_hugging_face_api_document_embedder.py
@@ -8,6 +8,8 @@
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError

+from numpy import array
+
 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
 from haystack.dataclasses import Document
 from haystack.utils.auth import Secret
@@ -23,8 +25,8 @@ def mock_check_valid_model():
     yield mock


-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
+def mock_embedding_generation(text, **kwargs):
+    response = array([[random.random() for _ in range(384)] for _ in range(len(text))])
     return response


@@ -201,10 +203,10 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
"my_prefix document number 4 my_suffix",
]

def test_embed_batch(self, mock_check_valid_model):
def test_embed_batch(self, mock_check_valid_model, recwarn):
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
mock_embedding_patch.side_effect = mock_embedding_generation

embedder = HuggingFaceAPIDocumentEmbedder(
@@ -223,6 +225,40 @@ def test_embed_batch(self, mock_check_valid_model):
             assert len(embedding) == 384
             assert all(isinstance(x, float) for x in embedding)

+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model):
+        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
+
+        # embedding ndim != 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([0.1, 0.2, 0.3])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
+        # embedding ndim == 2 but shape[0] != len(batch)
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
     def test_run_wrong_input_format(self, mock_check_valid_model):
         embedder = HuggingFaceAPIDocumentEmbedder(
             api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
@@ -252,7 +288,7 @@ def test_run(self, mock_check_valid_model):
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]

with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
mock_embedding_patch.side_effect = mock_embedding_generation

embedder = HuggingFaceAPIDocumentEmbedder(
@@ -268,16 +304,14 @@ def test_run(self, mock_check_valid_model):
             result = embedder.run(documents=docs)

             mock_embedding_patch.assert_called_once_with(
-                json={
-                    "inputs": [
-                        "prefix Cuisine | I love cheese suffix",
-                        "prefix ML | A transformer is a deep learning architecture suffix",
-                    ],
-                    "truncate": True,
-                    "normalize": False,
-                },
-                task="feature-extraction",
+                text=[
+                    "prefix Cuisine | I love cheese suffix",
+                    "prefix ML | A transformer is a deep learning architecture suffix",
+                ],
+                truncate=None,
+                normalize=None,
             )

             documents_with_embeddings = result["documents"]

             assert isinstance(documents_with_embeddings, list)
@@ -294,7 +328,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model):
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]

with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
mock_embedding_patch.side_effect = mock_embedding_generation

embedder = HuggingFaceAPIDocumentEmbedder(
44 changes: 33 additions & 11 deletions test/components/embedders/test_hugging_face_api_text_embedder.py
@@ -7,7 +7,7 @@
 import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
-
+from numpy import array
 from haystack.components.embedders import HuggingFaceAPITextEmbedder
 from haystack.utils.auth import Secret
 from haystack.utils.hf import HFEmbeddingAPIType
@@ -21,11 +21,6 @@ def mock_check_valid_model():
     yield mock


-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
-    return response
-
-
 class TestHuggingFaceAPITextEmbedder:
     def test_init_invalid_api_type(self):
         with pytest.raises(ValueError):
@@ -141,9 +136,9 @@ def test_run_wrong_input_format(self, mock_check_valid_model):
         with pytest.raises(TypeError):
             embedder.run(text=list_integers_input)

-    def test_run(self, mock_check_valid_model):
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
-            mock_embedding_patch.side_effect = mock_embedding_generation
+    def test_run(self, mock_check_valid_model, recwarn):
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])

             embedder = HuggingFaceAPITextEmbedder(
                 api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
@@ -156,13 +151,40 @@ def test_run(self, mock_check_valid_model):
             result = embedder.run(text="The food was delicious")

             mock_embedding_patch.assert_called_once_with(
-                json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
-                task="feature-extraction",
+                text="prefix The food was delicious suffix", truncate=None, normalize=None
             )

         assert len(result["embedding"]) == 384
         assert all(isinstance(x, float) for x in result["embedding"])

+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_run_wrong_embedding_shape(self, mock_check_valid_model):
+        # embedding ndim > 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
+        # embedding ndim == 2 but shape[0] != 1
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
     @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.skipif(