From e06997aed5949a4c1ce4f7cdf579ab7950c50b81 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 31 Jan 2025 17:31:44 +0100
Subject: [PATCH 1/8] HF API Embedders: refactoring

---
 .../hugging_face_api_document_embedder.py | 21 ++++---
 .../hugging_face_api_text_embedder.py | 19 ++++--
 ...test_hugging_face_api_document_embedder.py | 59 ++++++++++++++-----
 .../test_hugging_face_api_text_embedder.py | 37 ++++++++----
 4 files changed, 96 insertions(+), 40 deletions(-)

diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py
index 459e386976..2da60204b3 100644
--- a/haystack/components/embedders/hugging_face_api_document_embedder.py
+++ b/haystack/components/embedders/hugging_face_api_document_embedder.py
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import json
 from typing import Any, Dict, List, Optional, Union
 
 from tqdm import tqdm
@@ -124,13 +123,11 @@ def __init__(
         Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
         if the backend uses Text Embeddings Inference.
         If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-        It is always set to `True` and cannot be changed.
     :param normalize:
         Normalizes the embeddings to unit length.
         Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
         if the backend uses Text Embeddings Inference.
         If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-        It is always set to `False` and cannot be changed.
     :param batch_size:
         Number of documents to process at once.
     :param progress_bar:
@@ -239,18 +236,24 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[
         """
         Embed a list of texts in batches.
         """
         all_embeddings = []
         for i in tqdm(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i : i + batch_size]
-            response = self._client.post(
-                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
-                task="feature-extraction",
+
+            embeddings = self._client.feature_extraction(
+                # this method does not officially support list of strings, but works as expected
+                text=batch,  # type: ignore[arg-type]
+                # Serverless Inference API does not support truncate and normalize, so we pass None in the request
+                truncate=self.truncate if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None,
+                normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None,
             )
-            embeddings = json.loads(response.decode())
-            all_embeddings.extend(embeddings)
+
+            if embeddings.ndim != 2 or embeddings.shape[0] != len(batch):
+                raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {embeddings.shape}")
+
+            all_embeddings.extend(embeddings.tolist())
 
         return all_embeddings
 
diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py
index 2cd68d34da..a614f488af 100644
--- a/haystack/components/embedders/hugging_face_api_text_embedder.py
+++ b/haystack/components/embedders/hugging_face_api_text_embedder.py
@@ -104,13 +104,11 @@ def __init__(
         Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
         if the backend uses Text Embeddings Inference.
         If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-        It is always set to `True` and cannot be changed.
:param normalize: Normalizes the embeddings to unit length. Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` if the backend uses Text Embeddings Inference. If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored. - It is always set to `False` and cannot be changed. """ huggingface_hub_import.check() @@ -200,10 +198,19 @@ def run(self, text: str): text_to_embed = self.prefix + text + self.suffix - response = self._client.post( - json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize}, - task="feature-extraction", + response = self._client.feature_extraction( + text=text_to_embed, + # Serverless Inference API does not support truncate and normalize, so we pass None in the request + truncate=self.truncate if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, + normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, ) - embedding = json.loads(response.decode())[0] + + if response.ndim > 2: + raise ValueError(f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {response.shape}") + + if response.ndim == 2 and response.shape[0] != 1: + raise ValueError(f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {response.shape}") + + embedding = response.flatten().tolist() return {"embedding": embedding} diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py index b9332d5363..2627e1cf4a 100644 --- a/test/components/embedders/test_hugging_face_api_document_embedder.py +++ b/test/components/embedders/test_hugging_face_api_document_embedder.py @@ -8,6 +8,8 @@ import pytest from huggingface_hub.utils import RepositoryNotFoundError +from numpy import array + from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder from haystack.dataclasses import Document from haystack.utils.auth import Secret @@ -23,8 +25,8 @@ def mock_check_valid_model(): yield mock -def mock_embedding_generation(json, **kwargs): - response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode() +def mock_embedding_generation(text, **kwargs): + response = array([[random.random() for _ in range(384)] for _ in range(len(text))]) return response @@ -204,7 +206,7 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model): def test_embed_batch(self, mock_check_valid_model): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = HuggingFaceAPIDocumentEmbedder( @@ -223,6 +225,35 @@ def test_embed_batch(self, mock_check_valid_model): assert len(embedding) == 384 assert all(isinstance(x, float) for x in embedding) + def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model): + texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] + + # embedding ndim != 2 + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([0.1, 0.2, 0.3]) + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + ) + + with pytest.raises(ValueError): + 
embedder._embed_batch(texts_to_embed=texts, batch_size=2) + + # embedding ndim == 2 but shape[0] != len(batch) + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + ) + + with pytest.raises(ValueError): + embedder._embed_batch(texts_to_embed=texts, batch_size=2) + def test_run_wrong_input_format(self, mock_check_valid_model): embedder = HuggingFaceAPIDocumentEmbedder( api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} @@ -252,7 +283,7 @@ def test_run(self, mock_check_valid_model): Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = HuggingFaceAPIDocumentEmbedder( @@ -268,16 +299,14 @@ def test_run(self, mock_check_valid_model): result = embedder.run(documents=docs) mock_embedding_patch.assert_called_once_with( - json={ - "inputs": [ - "prefix Cuisine | I love cheese suffix", - "prefix ML | A transformer is a deep learning architecture suffix", - ], - "truncate": True, - "normalize": False, - }, - task="feature-extraction", + text=[ + "prefix Cuisine | I love cheese suffix", + "prefix ML | A transformer is a deep learning architecture suffix", + ], + truncate=None, + normalize=None, ) + documents_with_embeddings = result["documents"] assert isinstance(documents_with_embeddings, list) @@ -294,7 +323,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model): Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = HuggingFaceAPIDocumentEmbedder( @@ -322,7 +351,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model): assert len(doc.embedding) == 384 assert all(isinstance(x, float) for x in doc.embedding) - @pytest.mark.flaky(reruns=5, reruns_delay=5) + # @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py index 6e699fca25..11557088be 100644 --- a/test/components/embedders/test_hugging_face_api_text_embedder.py +++ b/test/components/embedders/test_hugging_face_api_text_embedder.py @@ -7,7 +7,7 @@ import random import pytest from huggingface_hub.utils import RepositoryNotFoundError - +from numpy import array from haystack.components.embedders import HuggingFaceAPITextEmbedder from haystack.utils.auth import Secret from haystack.utils.hf import HFEmbeddingAPIType @@ -21,11 +21,6 @@ def mock_check_valid_model(): yield mock -def mock_embedding_generation(json, **kwargs): - response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode() - return response - - class TestHuggingFaceAPITextEmbedder: def 
test_init_invalid_api_type(self): with pytest.raises(ValueError): @@ -142,8 +137,8 @@ def test_run_wrong_input_format(self, mock_check_valid_model): embedder.run(text=list_integers_input) def test_run(self, mock_check_valid_model): - with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: - mock_embedding_patch.side_effect = mock_embedding_generation + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]]) embedder = HuggingFaceAPITextEmbedder( api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, @@ -156,13 +151,35 @@ def test_run(self, mock_check_valid_model): result = embedder.run(text="The food was delicious") mock_embedding_patch.assert_called_once_with( - json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False}, - task="feature-extraction", + text="prefix The food was delicious suffix", truncate=None, normalize=None ) assert len(result["embedding"]) == 384 assert all(isinstance(x, float) for x in result["embedding"]) + def test_run_wrong_embedding_shape(self, mock_check_valid_model): + # embedding ndim > 2 + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]) + + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} + ) + + with pytest.raises(ValueError): + embedder.run(text="The food was delicious") + + # embedding ndim == 2 but shape[0] != 1 + with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} + ) + + with pytest.raises(ValueError): + embedder.run(text="The food was delicious") + @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif( From d7e3fb11b6faeb26420c0183ff15e41c99af656f Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 31 Jan 2025 17:35:46 +0100 Subject: [PATCH 2/8] rename variables --- .../embedders/hugging_face_api_document_embedder.py | 8 ++++---- .../embedders/hugging_face_api_text_embedder.py | 13 +++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index 2da60204b3..d6d4153a19 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -242,7 +242,7 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[ ): batch = texts_to_embed[i : i + batch_size] - embeddings = self._client.feature_extraction( + np_embeddings = self._client.feature_extraction( # this method does not officially support list of strings, but works as expected text=batch, # type: ignore[arg-type] # Serverless Inference API does not support truncate and normalize, so we pass None in the request @@ -250,10 +250,10 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[ normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, ) - if embeddings.ndim != 2 or 
embeddings.shape[0] != len(batch): - raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {embeddings.shape}") + if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): + raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}") - all_embeddings.extend(embeddings.tolist()) + all_embeddings.extend(np_embeddings.tolist()) return all_embeddings diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index a614f488af..005f2c2543 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -198,19 +198,20 @@ def run(self, text: str): text_to_embed = self.prefix + text + self.suffix - response = self._client.feature_extraction( + np_embedding = self._client.feature_extraction( text=text_to_embed, # Serverless Inference API does not support truncate and normalize, so we pass None in the request truncate=self.truncate if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, ) - if response.ndim > 2: - raise ValueError(f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {response.shape}") + error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}" + if np_embedding.ndim > 2: + raise ValueError(error_msg) - if response.ndim == 2 and response.shape[0] != 1: - raise ValueError(f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {response.shape}") + if np_embedding.ndim == 2 and np_embedding.shape[0] != 1: + raise ValueError(error_msg) - embedding = response.flatten().tolist() + embedding = np_embedding.flatten().tolist() return {"embedding": embedding} From 3de0901022aaf944b3a9fad22470637f9cead66c Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 31 Jan 2025 17:39:19 +0100 Subject: [PATCH 3/8] rm leftovers --- haystack/components/embedders/hugging_face_api_text_embedder.py | 1 - .../embedders/test_hugging_face_api_document_embedder.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index 005f2c2543..f19506d97d 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -208,7 +208,6 @@ def run(self, text: str): error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}" if np_embedding.ndim > 2: raise ValueError(error_msg) - if np_embedding.ndim == 2 and np_embedding.shape[0] != 1: raise ValueError(error_msg) diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py index 2627e1cf4a..dc44639131 100644 --- a/test/components/embedders/test_hugging_face_api_document_embedder.py +++ b/test/components/embedders/test_hugging_face_api_document_embedder.py @@ -351,7 +351,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model): assert len(doc.embedding) == 384 assert all(isinstance(x, float) for x in doc.embedding) - # @pytest.mark.flaky(reruns=5, reruns_delay=5) + @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif( not 
os.environ.get("HF_API_TOKEN", None), From 3599e0516bbf621e3abc3ba5e68e6081d510ff20 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 31 Jan 2025 17:43:13 +0100 Subject: [PATCH 4/8] rm pin --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9a5f15070b..eda943c19a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ extra-dependencies = [ "numba>=0.54.0", # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881 "transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... - "huggingface_hub>=0.27.0, <0.28.0", # Hugging Face API Generators and Embedders + "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier "openai-whisper>=20231106", # LocalWhisperTranscriber From 6c32f8328c3e0cfd8e56737d1d90271c908a3408 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 31 Jan 2025 17:47:40 +0100 Subject: [PATCH 5/8] rm unused import --- haystack/components/embedders/hugging_face_api_text_embedder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index f19506d97d..c100f5c42e 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import json from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging From 56afdd025c380fed9a488c1932604ad53e8f75f2 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 31 Jan 2025 18:09:03 +0100 Subject: [PATCH 6/8] relnote --- .../hf-embedders-feature-extraction-ea0421a8f76052f0.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml diff --git a/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml b/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml new file mode 100644 index 0000000000..baf9a890aa --- /dev/null +++ b/releasenotes/notes/hf-embedders-feature-extraction-ea0421a8f76052f0.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + In the Hugging Face API embedders, the `InferenceClient.feature_extraction` method is now used instead of + `InferenceClient.post` to compute embeddings. This ensures a more robust and future-proof implementation. 
From f9479e0385ab216552911337d90e75f675d155ee Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 3 Feb 2025 15:28:06 +0100 Subject: [PATCH 7/8] warning with truncate/normalize and serverless inference API --- .../hugging_face_api_document_embedder.py | 23 +++++++++++++---- .../hugging_face_api_text_embedder.py | 25 +++++++++++++------ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index d6d4153a19..d3b92fb74c 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import warnings from typing import Any, Dict, List, Optional, Union from tqdm import tqdm @@ -95,8 +96,8 @@ def __init__( token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), prefix: str = "", suffix: str = "", - truncate: bool = True, - normalize: bool = False, + truncate: Optional[bool] = True, + normalize: Optional[bool] = False, batch_size: int = 32, progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, @@ -236,6 +237,19 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[ """ Embed a list of texts in batches. """ + truncate = self.truncate + normalize = self.normalize + + if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: + if truncate is not None: + msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored." + warnings.warn(msg) + truncate = None + if normalize is not None: + msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored." 
+ warnings.warn(msg) + normalize = None + all_embeddings = [] for i in tqdm( range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" @@ -245,9 +259,8 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[ np_embeddings = self._client.feature_extraction( # this method does not officially support list of strings, but works as expected text=batch, # type: ignore[arg-type] - # Serverless Inference API does not support truncate and normalize, so we pass None in the request - truncate=self.truncate if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, - normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, + truncate=truncate, + normalize=normalize, ) if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index c100f5c42e..535d3a9430 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import warnings from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -79,8 +80,8 @@ def __init__( token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), prefix: str = "", suffix: str = "", - truncate: bool = True, - normalize: bool = False, + truncate: Optional[bool] = True, + normalize: Optional[bool] = False, ): # pylint: disable=too-many-positional-arguments """ Creates a HuggingFaceAPITextEmbedder component. @@ -195,14 +196,22 @@ def run(self, text: str): "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder." ) + truncate = self.truncate + normalize = self.normalize + + if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: + if truncate is not None: + msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored." + warnings.warn(msg) + truncate = None + if normalize is not None: + msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored." 
+ warnings.warn(msg) + normalize = None + text_to_embed = self.prefix + text + self.suffix - np_embedding = self._client.feature_extraction( - text=text_to_embed, - # Serverless Inference API does not support truncate and normalize, so we pass None in the request - truncate=self.truncate if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, - normalize=self.normalize if self.api_type != HFEmbeddingAPIType.SERVERLESS_INFERENCE_API else None, - ) + np_embedding = self._client.feature_extraction(text=text_to_embed, truncate=truncate, normalize=normalize) error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}" if np_embedding.ndim > 2: From a27287d423dd2e8c61b4ad8eca3e83cd7490baee Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 3 Feb 2025 15:54:16 +0100 Subject: [PATCH 8/8] test that warnings are raised --- .../embedders/test_hugging_face_api_document_embedder.py | 7 ++++++- .../embedders/test_hugging_face_api_text_embedder.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py index dc44639131..9d452b02ca 100644 --- a/test/components/embedders/test_hugging_face_api_document_embedder.py +++ b/test/components/embedders/test_hugging_face_api_document_embedder.py @@ -203,7 +203,7 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model): "my_prefix document number 4 my_suffix", ] - def test_embed_batch(self, mock_check_valid_model): + def test_embed_batch(self, mock_check_valid_model, recwarn): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: @@ -225,6 +225,11 @@ def test_embed_batch(self, mock_check_valid_model): assert len(embedding) == 384 assert all(isinstance(x, float) for x in embedding) + # Check that warnings about ignoring truncate and normalize are raised + assert len(recwarn) == 2 + assert "truncate" in str(recwarn[0].message) + assert "normalize" in str(recwarn[1].message) + def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py index 11557088be..84b2d6e83c 100644 --- a/test/components/embedders/test_hugging_face_api_text_embedder.py +++ b/test/components/embedders/test_hugging_face_api_text_embedder.py @@ -136,7 +136,7 @@ def test_run_wrong_input_format(self, mock_check_valid_model): with pytest.raises(TypeError): embedder.run(text=list_integers_input) - def test_run(self, mock_check_valid_model): + def test_run(self, mock_check_valid_model, recwarn): with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]]) @@ -157,6 +157,11 @@ def test_run(self, mock_check_valid_model): assert len(result["embedding"]) == 384 assert all(isinstance(x, float) for x in result["embedding"]) + # Check that warnings about ignoring truncate and normalize are raised + assert len(recwarn) == 2 + assert "truncate" in str(recwarn[0].message) + assert "normalize" in str(recwarn[1].message) + def test_run_wrong_embedding_shape(self, mock_check_valid_model): # embedding ndim > 2 with 
patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: