diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py
index 459e386976..d3b92fb74c 100644
--- a/haystack/components/embedders/hugging_face_api_document_embedder.py
+++ b/haystack/components/embedders/hugging_face_api_document_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union
 
 from tqdm import tqdm
@@ -96,8 +96,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
         batch_size: int = 32,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -124,13 +124,11 @@ def __init__(
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses
             Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses
             Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         :param batch_size:
             Number of documents to process at once.
         :param progress_bar:
@@ -239,18 +237,36 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
         """
         Embed a list of texts in batches.
         """
+        truncate = self.truncate
+        normalize = self.normalize
+
+        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
+            if truncate is not None:
+                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                truncate = None
+            if normalize is not None:
+                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                normalize = None
 
         all_embeddings = []
         for i in tqdm(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i : i + batch_size]
-            response = self._client.post(
-                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
-                task="feature-extraction",
+
+            np_embeddings = self._client.feature_extraction(
+                # this method does not officially support list of strings, but works as expected
+                text=batch,  # type: ignore[arg-type]
+                truncate=truncate,
+                normalize=normalize,
             )
-            embeddings = json.loads(response.decode())
-            all_embeddings.extend(embeddings)
+
+            if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
+                raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}")
+
+            all_embeddings.extend(np_embeddings.tolist())
 
         return all_embeddings
diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py
index 2cd68d34da..535d3a9430 100644
--- a/haystack/components/embedders/hugging_face_api_text_embedder.py
+++ b/haystack/components/embedders/hugging_face_api_text_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
@@ -80,8 +80,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
     ):  # pylint: disable=too-many-positional-arguments
         """
         Creates a HuggingFaceAPITextEmbedder component.
@@ -104,13 +104,11 @@ def __init__(
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses
             Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses
             Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         """
         huggingface_hub_import.check()
@@ -198,12 +196,29 @@ def run(self, text: str):
                 "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
             )
 
+        truncate = self.truncate
+        normalize = self.normalize
+
+        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
+            if truncate is not None:
+                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                truncate = None
+            if normalize is not None:
+                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                normalize = None
+
         text_to_embed = self.prefix + text + self.suffix
 
-        response = self._client.post(
-            json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
-            task="feature-extraction",
-        )
-        embedding = json.loads(response.decode())[0]
+        np_embedding = self._client.feature_extraction(text=text_to_embed, truncate=truncate, normalize=normalize)
+
+        error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
+        if np_embedding.ndim > 2:
+            raise ValueError(error_msg)
+        if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
+            raise ValueError(error_msg)
+
+        embedding = np_embedding.flatten().tolist()
 
         return {"embedding": embedding}
diff --git a/pyproject.toml b/pyproject.toml
index 9a5f15070b..eda943c19a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,7 +87,7 @@ extra-dependencies = [
   "numba>=0.54.0",  # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881
   "transformers[torch,sentencepiece]==4.47.1",  # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-  "huggingface_hub>=0.27.0, <0.28.0",  # Hugging Face API Generators and Embedders
+  "huggingface_hub>=0.27.0",  # Hugging Face API Generators and Embedders
   "sentence-transformers>=3.0.0",  # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
   "langdetect",  # TextLanguageRouter and DocumentLanguageClassifier
   "openai-whisper>=20231106",  # LocalWhisperTranscriber
diff --git a/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml b/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml
new file mode 100644
index 0000000000..baf9a890aa
--- /dev/null
+++ b/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    In the Hugging Face API embedders, the `InferenceClient.feature_extraction` method is now used instead of
+    `InferenceClient.post` to compute embeddings. This ensures a more robust and future-proof implementation.
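For context, a minimal sketch of the call path the embedders now use (not part of the diff; the model name is a placeholder and a valid HF token is assumed to be available in the environment):

```python
# Illustrative sketch of the new call path, assuming a public
# feature-extraction model and an HF token in the environment.
from huggingface_hub import InferenceClient

client = InferenceClient(model="BAAI/bge-small-en-v1.5")

# feature_extraction returns a numpy.ndarray directly, so the manual
# json.loads(response.decode()) step needed with the old
# InferenceClient.post call path is no longer required.
embedding = client.feature_extraction(text="The food was delicious")
print(embedding.shape)  # typically (1, embedding_dim) or (embedding_dim,)
```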
diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py
index b9332d5363..9d452b02ca 100644
--- a/test/components/embedders/test_hugging_face_api_document_embedder.py
+++ b/test/components/embedders/test_hugging_face_api_document_embedder.py
@@ -8,6 +8,8 @@
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
 
+from numpy import array
+
 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
 from haystack.dataclasses import Document
 from haystack.utils.auth import Secret
@@ -23,8 +25,8 @@ def mock_check_valid_model():
         yield mock
 
 
-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
+def mock_embedding_generation(text, **kwargs):
+    response = array([[random.random() for _ in range(384)] for _ in range(len(text))])
     return response
 
 
@@ -201,10 +203,10 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
             "my_prefix document number 4 my_suffix",
         ]
 
-    def test_embed_batch(self, mock_check_valid_model):
+    def test_embed_batch(self, mock_check_valid_model, recwarn):
         texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
@@ -223,6 +225,40 @@ def test_embed_batch(self, mock_check_valid_model):
             assert len(embedding) == 384
             assert all(isinstance(x, float) for x in embedding)
 
+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model):
+        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
+
+        # embedding ndim != 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([0.1, 0.2, 0.3])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
+        # embedding ndim == 2 but shape[0] != len(batch)
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
     def test_run_wrong_input_format(self, mock_check_valid_model):
         embedder = HuggingFaceAPIDocumentEmbedder(
             api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
@@ -252,7 +288,7 @@ def test_run(self, mock_check_valid_model):
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
@@ -268,16 +304,14 @@ def test_run(self, mock_check_valid_model):
             result = embedder.run(documents=docs)
 
             mock_embedding_patch.assert_called_once_with(
-                json={
-                    "inputs": [
-                        "prefix Cuisine | I love cheese suffix",
-                        "prefix ML | A transformer is a deep learning architecture suffix",
-                    ],
-                    "truncate": True,
-                    "normalize": False,
-                },
-                task="feature-extraction",
+                text=[
+                    "prefix Cuisine | I love cheese suffix",
+                    "prefix ML | A transformer is a deep learning architecture suffix",
+                ],
+                truncate=None,
+                normalize=None,
             )
+
         documents_with_embeddings = result["documents"]
         assert isinstance(documents_with_embeddings, list)
@@ -294,7 +328,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model):
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py
index 6e699fca25..84b2d6e83c 100644
--- a/test/components/embedders/test_hugging_face_api_text_embedder.py
+++ b/test/components/embedders/test_hugging_face_api_text_embedder.py
@@ -7,7 +7,7 @@
 import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
-
+from numpy import array
 from haystack.components.embedders import HuggingFaceAPITextEmbedder
 from haystack.utils.auth import Secret
 from haystack.utils.hf import HFEmbeddingAPIType
@@ -21,11 +21,6 @@ def mock_check_valid_model():
         yield mock
 
 
-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
-    return response
-
-
 class TestHuggingFaceAPITextEmbedder:
     def test_init_invalid_api_type(self):
         with pytest.raises(ValueError):
@@ -141,9 +136,9 @@ def test_run_wrong_input_format(self, mock_check_valid_model):
         with pytest.raises(TypeError):
             embedder.run(text=list_integers_input)
 
-    def test_run(self, mock_check_valid_model):
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
-            mock_embedding_patch.side_effect = mock_embedding_generation
+    def test_run(self, mock_check_valid_model, recwarn):
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])
 
             embedder = HuggingFaceAPITextEmbedder(
                 api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
@@ -156,13 +151,40 @@ def test_run(self, mock_check_valid_model, recwarn):
             result = embedder.run(text="The food was delicious")
 
             mock_embedding_patch.assert_called_once_with(
-                json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
-                task="feature-extraction",
+                text="prefix The food was delicious suffix", truncate=None, normalize=None
             )
 
         assert len(result["embedding"]) == 384
         assert all(isinstance(x, float) for x in result["embedding"])
 
+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_run_wrong_embedding_shape(self, mock_check_valid_model):
+        # embedding ndim > 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
+        # embedding ndim == 2 but shape[0] != 1
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
     @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.skipif(
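As a side note, the shape rule these tests exercise can be summarized in a standalone snippet (the helper name and sample values are invented for illustration; numpy is the only dependency):

```python
# Standalone illustration of the shape rule the text embedder enforces
# (helper name and values are invented; mirrors the ValueError raised
# in HuggingFaceAPITextEmbedder.run).
import numpy as np

def to_embedding_list(arr: np.ndarray) -> list:
    # Accept (embedding_dim,) or (1, embedding_dim); reject anything else.
    if arr.ndim > 2 or (arr.ndim == 2 and arr.shape[0] != 1):
        raise ValueError(f"Expected (1, embedding_dim) or (embedding_dim,), got {arr.shape}")
    return arr.flatten().tolist()

assert len(to_embedding_list(np.zeros(384))) == 384        # (embedding_dim,)
assert len(to_embedding_list(np.zeros((1, 384)))) == 384   # (1, embedding_dim)
```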