From 07e3056b203cc9a6a32553c7e0d03de377b2f437 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Wed, 10 Jul 2024 15:13:46 +0200 Subject: [PATCH] feat: add sentence window retrieval (#7997) * initial import * adding tests * adding license and release notes * adding missing release notes * working with any type of doc store * nit * adding get_class_object to serialization package * nit * refactoring get_class_object() * refactoring get_class_object() * chaning type and var names * more refactoring * Update haystack/core/serialization.py Co-authored-by: Vladimir Blagojevic * Update haystack/core/serialization.py Co-authored-by: Vladimir Blagojevic * Update test/core/test_serialization.py Co-authored-by: Vladimir Blagojevic * more refactoring * more refactoring * Pydoc syntax --------- Co-authored-by: Vladimir Blagojevic --- haystack/components/retrievers/__init__.py | 3 +- .../retrievers/sentence_window_retrieval.py | 139 +++++++++++++++++ haystack/core/serialization.py | 25 ++- ...nce-window-retrieval-5de4b0d6b2e8b0d6.yaml | 7 + .../test_sentence_window_retriever.py | 143 ++++++++++++++++++ test/core/test_serialization.py | 22 ++- 6 files changed, 336 insertions(+), 3 deletions(-) create mode 100644 haystack/components/retrievers/sentence_window_retrieval.py create mode 100644 releasenotes/notes/add-sentence-window-retrieval-5de4b0d6b2e8b0d6.yaml create mode 100644 test/components/retrievers/test_sentence_window_retriever.py diff --git a/haystack/components/retrievers/__init__.py b/haystack/components/retrievers/__init__.py index 92c9da0c77..e86e40fbca 100644 --- a/haystack/components/retrievers/__init__.py +++ b/haystack/components/retrievers/__init__.py @@ -5,5 +5,6 @@ from haystack.components.retrievers.filter_retriever import FilterRetriever from haystack.components.retrievers.in_memory.bm25_retriever import InMemoryBM25Retriever from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever +from 
haystack.components.retrievers.sentence_window_retrieval import SentenceWindowRetrieval -__all__ = ["FilterRetriever", "InMemoryEmbeddingRetriever", "InMemoryBM25Retriever"] +__all__ = ["FilterRetriever", "InMemoryEmbeddingRetriever", "InMemoryBM25Retriever", "SentenceWindowRetrieval"] diff --git a/haystack/components/retrievers/sentence_window_retrieval.py b/haystack/components/retrievers/sentence_window_retrieval.py new file mode 100644 index 0000000000..a4e0ee23ef --- /dev/null +++ b/haystack/components/retrievers/sentence_window_retrieval.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List + +from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict +from haystack.core.serialization import import_class_by_name +from haystack.document_stores.types import DocumentStore + + +@component +class SentenceWindowRetrieval: + """ + A component that retrieves surrounding documents of a given document from the document store. + + This component is designed to work together with one of the existing retrievers, e.g. BM25Retriever, + EmbeddingRetriever. One of these retrievers can be used to retrieve documents based on a query and then use this + component to get the surrounding documents of the retrieved documents. + """ + + def __init__(self, document_store: DocumentStore, window_size: int = 3): + """ + Creates a new SentenceWindowRetrieval component. + + :param document_store: The document store to use for retrieving the surrounding documents. + :param window_size: The number of surrounding documents to retrieve. + """ + if window_size < 1: + raise ValueError("The window_size parameter must be greater than 0.") + + self.window_size = window_size + self.document_store = document_store + + @staticmethod + def merge_documents_text(documents: List[Document]) -> str: + """ + Merge a list of document text into a single string. 
+ + This functions concatenates the textual content of a list of documents into a single string, eliminating any + overlapping content. + + :param documents: List of Documents to merge. + """ + sorted_docs = sorted(documents, key=lambda doc: doc.meta["split_idx_start"]) + merged_text = "" + last_idx_end = 0 + for doc in sorted_docs: + start = doc.meta["split_idx_start"] # start of the current content + + # if the start of the current content is before the end of the last appended content, adjust it + start = max(start, last_idx_end) + + # append the non-overlapping part to the merged text + merged_text = merged_text.strip() + merged_text += doc.content[start - doc.meta["split_idx_start"] :] # type: ignore + + # update the last end index + last_idx_end = doc.meta["split_idx_start"] + len(doc.content) # type: ignore + + return merged_text + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + docstore = self.document_store.to_dict() + return default_to_dict(self, document_store=docstore, window_size=self.window_size) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SentenceWindowRetrieval": + """ + Deserializes the component from a dictionary. + + :returns: + Deserialized component. 
+ """ + init_params = data.get("init_parameters", {}) + + if "document_store" not in init_params: + raise DeserializationError("Missing 'document_store' in serialization data") + if "type" not in init_params["document_store"]: + raise DeserializationError("Missing 'type' in document store's serialization data") + + # deserialize the document store + doc_store_data = data["init_parameters"]["document_store"] + try: + doc_store_class = import_class_by_name(doc_store_data["type"]) + except ImportError as e: + raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e + + data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data) + + # deserialize the component + return default_from_dict(cls, data) + + @component.output_types(context_windows=List[str]) + def run(self, retrieved_documents: List[Document]): + """ + Based on the `source_id` and on the `doc.meta['split_id']` get surrounding documents from the document store. + + Implements the logic behind the sentence-window technique, retrieving the surrounding documents of a given + document from the document store. + + :param retrieved_documents: List of retrieved documents from the previous retriever. + :type retrieved_documents: List[Document] + :returns: + A dictionary with the following keys: + - `context_windows`: List of strings representing the context windows of the retrieved documents. 
+ """ + + if not all("split_id" in doc.meta for doc in retrieved_documents): + raise ValueError("The retrieved documents must have 'split_id' in the metadata.") + + if not all("source_id" in doc.meta for doc in retrieved_documents): + raise ValueError("The retrieved documents must have 'source_id' in the metadata.") + + context_windows = [] + for doc in retrieved_documents: + source_id = doc.meta["source_id"] + split_id = doc.meta["split_id"] + min_before = min(list(range(split_id - 1, split_id - self.window_size - 1, -1))) + max_after = max(list(range(split_id + 1, split_id + self.window_size + 1, 1))) + context_docs = self.document_store.filter_documents( + { + "operator": "AND", + "conditions": [ + {"field": "source_id", "operator": "==", "value": source_id}, + {"field": "split_id", "operator": ">=", "value": min_before}, + {"field": "split_id", "operator": "<=", "value": max_after}, + ], + } + ) + context_windows.append(self.merge_documents_text(context_docs)) + + return {"context_windows": context_windows} diff --git a/haystack/core/serialization.py b/haystack/core/serialization.py index 8426e4b4cf..15cff1e2c9 100644 --- a/haystack/core/serialization.py +++ b/haystack/core/serialization.py @@ -5,9 +5,10 @@ import inspect from collections.abc import Callable from dataclasses import dataclass +from importlib import import_module from typing import Any, Dict, Optional, Type -from haystack.core.component.component import _hook_component_init +from haystack.core.component.component import _hook_component_init, logger from haystack.core.errors import DeserializationError, SerializationError @@ -189,3 +190,25 @@ def default_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any: if data["type"] != generate_qualified_class_name(cls): raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'") return cls(**init_params) + + +def import_class_by_name(fully_qualified_name: str) -> Type[object]: + """ + Utility function to import 
(load) a class object based on its fully qualified class name. + + This function dynamically imports a class based on its string name. + It splits the name into module path and class name, imports the module, + and returns the class object. + + :param fully_qualified_name: the fully qualified class name as a string + :returns: the class object. + :raises ImportError: If the class cannot be imported or found. + """ + try: + module_path, class_name = fully_qualified_name.rsplit(".", 1) + logger.debug(f"Attempting to import class '{class_name}' from module '{module_path}'") + module = import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as error: + logger.error(f"Failed to import class '{fully_qualified_name}'") + raise ImportError(f"Could not import class '{fully_qualified_name}'") from error diff --git a/releasenotes/notes/add-sentence-window-retrieval-5de4b0d6b2e8b0d6.yaml b/releasenotes/notes/add-sentence-window-retrieval-5de4b0d6b2e8b0d6.yaml new file mode 100644 index 0000000000..6679d84757 --- /dev/null +++ b/releasenotes/notes/add-sentence-window-retrieval-5de4b0d6b2e8b0d6.yaml @@ -0,0 +1,7 @@ +--- + +features: + - | + Added a new component, SentenceWindowRetrieval, that performs sentence-window retrieval: it retrieves the + documents surrounding a given document from the document store. This is useful when a document is split into multiple chunks and you want to + retrieve the surrounding context of a given chunk.
diff --git a/test/components/retrievers/test_sentence_window_retriever.py b/test/components/retrievers/test_sentence_window_retriever.py new file mode 100644 index 0000000000..0d057fea05 --- /dev/null +++ b/test/components/retrievers/test_sentence_window_retriever.py @@ -0,0 +1,143 @@ +import pytest + +from haystack import Document, DeserializationError +from haystack.components.retrievers.sentence_window_retrieval import SentenceWindowRetrieval +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack.components.preprocessors import DocumentSplitter + + +class TestSentenceWindowRetrieval: + def test_init_default(self): + retrieval = SentenceWindowRetrieval(InMemoryDocumentStore()) + assert retrieval.window_size == 3 + + def test_init_with_parameters(self): + retrieval = SentenceWindowRetrieval(InMemoryDocumentStore(), window_size=5) + assert retrieval.window_size == 5 + + def test_init_with_invalid_window_size_parameter(self): + with pytest.raises(ValueError): + SentenceWindowRetrieval(InMemoryDocumentStore(), window_size=-2) + + def test_merge_documents(self): + docs = [ + { + "id": "doc_0", + "content": "This is a text with some words. There is a ", + "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99", + "page_number": 1, + "split_id": 0, + "split_idx_start": 0, + "_split_overlap": [{"doc_id": "doc_1", "range": (0, 22)}], + }, + { + "id": "doc_1", + "content": "some words. There is a second sentence. And there is ", + "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99", + "page_number": 1, + "split_id": 1, + "split_idx_start": 21, + "_split_overlap": [{"doc_id": "doc_0", "range": (20, 42)}, {"doc_id": "doc_2", "range": (0, 29)}], + }, + { + "id": "doc_2", + "content": "second sentence. 
And there is also a third sentence", + "source_id": "c5d7c632affc486d0cfe7b3c0f4dc1d3896ea720da2b538d6d10b104a3df5f99", + "page_number": 1, + "split_id": 2, + "split_idx_start": 45, + "_split_overlap": [{"doc_id": "doc_1", "range": (23, 52)}], + }, + ] + merged_text = SentenceWindowRetrieval.merge_documents_text([Document.from_dict(doc) for doc in docs]) + expected = "This is a text with some words. There is a second sentence. And there is also a third sentence" + assert merged_text == expected + + def test_to_dict(self): + window_retrieval = SentenceWindowRetrieval(InMemoryDocumentStore()) + data = window_retrieval.to_dict() + + assert data["type"] == "haystack.components.retrievers.sentence_window_retrieval.SentenceWindowRetrieval" + assert data["init_parameters"]["window_size"] == 3 + assert ( + data["init_parameters"]["document_store"]["type"] + == "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore" + ) + + def test_from_dict(self): + data = { + "type": "haystack.components.retrievers.sentence_window_retrieval.SentenceWindowRetrieval", + "init_parameters": { + "document_store": { + "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", + "init_parameters": {}, + }, + "window_size": 5, + }, + } + component = SentenceWindowRetrieval.from_dict(data) + assert isinstance(component.document_store, InMemoryDocumentStore) + assert component.window_size == 5 + + def test_from_dict_without_docstore(self): + data = {"type": "SentenceWindowRetrieval", "init_parameters": {}} + with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"): + SentenceWindowRetrieval.from_dict(data) + + def test_from_dict_without_docstore_type(self): + data = {"type": "SentenceWindowRetrieval", "init_parameters": {"document_store": {"init_parameters": {}}}} + with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): + SentenceWindowRetrieval.from_dict(data) + + 
def test_from_dict_non_existing_docstore(self): + data = { + "type": "SentenceWindowRetrieval", + "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}}, + } + with pytest.raises(DeserializationError): + SentenceWindowRetrieval.from_dict(data) + + def test_document_without_split_id(self): + docs = [ + Document(content="This is a text with some words. There is a ", meta={"id": "doc_0"}), + Document(content="some words. There is a second sentence. And there is ", meta={"id": "doc_1"}), + ] + with pytest.raises(ValueError): + retriever = SentenceWindowRetrieval(document_store=InMemoryDocumentStore(), window_size=3) + retriever.run(retrieved_documents=docs) + + def test_document_without_source_id(self): + docs = [ + Document(content="This is a text with some words. There is a ", meta={"id": "doc_0", "split_id": 0}), + Document( + content="some words. There is a second sentence. And there is ", meta={"id": "doc_1", "split_id": 1} + ), + ] + with pytest.raises(ValueError): + retriever = SentenceWindowRetrieval(document_store=InMemoryDocumentStore(), window_size=3) + retriever.run(retrieved_documents=docs) + + @pytest.mark.integration + def test_run_with_pipeline(self): + splitter = DocumentSplitter(split_length=10, split_overlap=5, split_by="word") + text = ( + "This is a text with some words. There is a second sentence. And there is also a third sentence. " + "It also contains a fourth sentence. And a fifth sentence. And a sixth sentence. And a seventh sentence" + ) + + doc = Document(content=text) + + docs = splitter.run([doc]) + ds = InMemoryDocumentStore() + ds.write_documents(docs["documents"]) + + retriever = SentenceWindowRetrieval(document_store=ds, window_size=3) + result = retriever.run(retrieved_documents=[list(ds.storage.values())[3]]) + expected = { + "context_windows": [ + "This is a text with some words. There is a second sentence. And there is also a third sentence. It " + "also contains a fourth sentence. 
And a fifth sentence. And a sixth sentence. And a seventh sentence" + ] + } + + assert result == expected diff --git a/test/core/test_serialization.py b/test/core/test_serialization.py index 9907c9781a..755271bdea 100644 --- a/test/core/test_serialization.py +++ b/test/core/test_serialization.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import datetime + import sys from unittest.mock import Mock @@ -10,7 +12,12 @@ from haystack.core.component import component from haystack.core.errors import DeserializationError from haystack.testing import factory -from haystack.core.serialization import default_to_dict, default_from_dict, generate_qualified_class_name +from haystack.core.serialization import ( + default_to_dict, + default_from_dict, + generate_qualified_class_name, + import_class_by_name, +) def test_default_component_to_dict(): @@ -87,3 +94,16 @@ def test_get_qualified_class_name(): comp = MyComponent() res = generate_qualified_class_name(type(comp)) assert res == "haystack.testing.factory.MyComponent" + + +def test_import_class_by_name(): + data = "haystack.core.pipeline.Pipeline" + class_object = import_class_by_name(data) + class_instance = class_object() + assert isinstance(class_instance, Pipeline) + + +def test_import_class_by_name_no_valid_class(): + data = "some.invalid.class" + with pytest.raises(ImportError): + import_class_by_name(data)