Skip to content

Commit

Permalink
Cherry-pick docstores deserialization refactoring to 2.4.x (#8227)
Browse files Browse the repository at this point in the history
* fix: deserialize Document Stores using specific `from_dict` class methods (#8207)

* use from_dict

* unused import

* improve logic

* improve reno

* refactor: utility function for docstore deserialization (#8226)

* refactor docstore deserialization

* more tests

* reno; headers

* expose key
  • Loading branch information
anakin87 authored Aug 14, 2024
1 parent 40cb53f commit 34e2412
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 61 deletions.
18 changes: 4 additions & 14 deletions haystack/components/caching/cache_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from typing import Any, Dict, List

from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,18 +71,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "CacheChecker":
:returns:
Deserialized component.
"""
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")

doc_store_data = data["init_parameters"]["document_store"]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
# deserialize the document store
data = deserialize_document_store_in_init_parameters(data)

return default_from_dict(cls, data)

Expand Down
18 changes: 4 additions & 14 deletions haystack/components/retrievers/filter_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from typing import Any, Dict, List, Optional

from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -77,18 +77,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "FilterRetriever":
:returns:
The deserialized component.
"""
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")

doc_store_data = data["init_parameters"]["document_store"]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
# deserialize the document store
data = deserialize_document_store_in_init_parameters(data)

return default_from_dict(cls, data)

Expand Down
19 changes: 3 additions & 16 deletions haystack/components/retrievers/sentence_window_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from typing import Any, Dict, List

from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.core.serialization import import_class_by_name
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import DocumentStore
from haystack.utils import deserialize_document_store_in_init_parameters


@component
Expand Down Expand Up @@ -117,21 +117,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceWindowRetriever":
:returns:
Deserialized component.
"""
init_params = data.get("init_parameters", {})

if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")

# deserialize the document store
doc_store_data = data["init_parameters"]["document_store"]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e

data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
data = deserialize_document_store_in_init_parameters(data)

# deserialize the component
return default_from_dict(cls, data)
Expand Down
19 changes: 5 additions & 14 deletions haystack/components/writers/document_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from typing import Any, Dict, List, Optional

from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.document_stores.types import DocumentStore, DuplicatePolicy
from haystack.utils import deserialize_document_store_in_init_parameters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,18 +73,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "DocumentWriter":
:raises DeserializationError:
If the document store is not properly specified in the serialization data or its type cannot be imported.
"""
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")

doc_store_data = data["init_parameters"]["document_store"]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
data["init_parameters"]["document_store"] = default_from_dict(doc_store_class, doc_store_data)
# deserialize the document store
data = deserialize_document_store_in_init_parameters(data)

data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]]

return default_from_dict(cls, data)
Expand Down
2 changes: 2 additions & 0 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .auth import Secret, deserialize_secrets_inplace
from .callable_serialization import deserialize_callable, serialize_callable
from .device import ComponentDevice, Device, DeviceMap, DeviceType
from .docstore_deserialization import deserialize_document_store_in_init_parameters
from .expit import expit
from .filters import document_matches_filter
from .jupyter import is_in_jupyter
Expand All @@ -26,4 +27,5 @@
"deserialize_callable",
"serialize_type",
"deserialize_type",
"deserialize_document_store_in_init_parameters",
]
41 changes: 41 additions & 0 deletions haystack/utils/docstore_deserialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict

from haystack import DeserializationError
from haystack.core.serialization import default_from_dict, import_class_by_name


def deserialize_document_store_in_init_parameters(data: Dict[str, Any], key: str = "document_store") -> Dict[str, Any]:
"""
Deserializes a generic document store from the init_parameters of a serialized component.
:param data:
The dictionary to deserialize from.
:param key:
The key in the `data["init_parameters"]` dictionary where the document store is specified.
:returns:
The dictionary, with the document store deserialized.
:raises DeserializationError:
If the document store is not properly specified in the serialization data or its type cannot be imported.
"""
init_params = data.get("init_parameters", {})
if key not in init_params:
raise DeserializationError(f"Missing '{key}' in serialization data")
if "type" not in init_params[key]:
raise DeserializationError(f"Missing 'type' in {key} serialization data")

doc_store_data = data["init_parameters"][key]
try:
doc_store_class = import_class_by_name(doc_store_data["type"])
except ImportError as e:
raise DeserializationError(f"Class '{doc_store_data['type']}' not correctly imported") from e
if hasattr(doc_store_class, "from_dict"):
data["init_parameters"][key] = doc_store_class.from_dict(doc_store_data)
else:
data["init_parameters"][key] = default_from_dict(doc_store_class, doc_store_data)

return data
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Introduce an utility function to deserialize a generic Document Store
from the init_parameters of a serialized component.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
For components that support multiple Document Stores, prioritize using the specific `from_dict` class method
for deserialization when available. Otherwise, fall back to the generic `default_from_dict` method.
This impacts the following generic components: `CacheChecker`, `DocumentWriter`, `FilterRetriever`, and
`SentenceWindowRetriever`.
2 changes: 1 addition & 1 deletion test/components/caching/test_url_cache_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_from_dict_without_docstore_type(self):
"type": "haystack.components.caching.cache_checker.UrlCacheChecker",
"init_parameters": {"document_store": {"init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
with pytest.raises(DeserializationError):
CacheChecker.from_dict(data)

def test_from_dict_nonexisting_docstore(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_from_dict_without_docstore(self):

def test_from_dict_without_docstore_type(self):
data = {"type": "SentenceWindowRetriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
with pytest.raises(DeserializationError):
SentenceWindowRetriever.from_dict(data)

def test_from_dict_non_existing_docstore(self):
Expand Down
2 changes: 1 addition & 1 deletion test/components/writers/test_document_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_from_dict_without_docstore(self):

def test_from_dict_without_docstore_type(self):
data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
with pytest.raises(DeserializationError):
DocumentWriter.from_dict(data)

def test_from_dict_nonexisting_docstore(self):
Expand Down
89 changes: 89 additions & 0 deletions test/utils/test_docstore_deserialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch
import pytest

from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore
from haystack.utils.docstore_deserialization import deserialize_document_store_in_init_parameters
from haystack.core.errors import DeserializationError


class FakeDocumentStore:
pass


def test_deserialize_document_store_in_init_parameters():
data = {
"type": "haystack.components.writers.document_writer.DocumentWriter",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
}
},
}

result = deserialize_document_store_in_init_parameters(data)
assert isinstance(result["init_parameters"]["document_store"], InMemoryDocumentStore)


def test_from_dict_is_called():
"""If the document store provides a from_dict method, it should be called."""
data = {
"type": "haystack.components.writers.document_writer.DocumentWriter",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
}
},
}

with patch.object(InMemoryDocumentStore, "from_dict") as mock_from_dict:
deserialize_document_store_in_init_parameters(data)

mock_from_dict.assert_called_once_with(
{"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", "init_parameters": {}}
)


def test_default_from_dict_is_called():
"""If the document store does not provide a from_dict method, default_from_dict should be called."""
data = {
"type": "haystack.components.writers.document_writer.DocumentWriter",
"init_parameters": {
"document_store": {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}}
},
}

with patch("haystack.utils.docstore_deserialization.default_from_dict") as mock_default_from_dict:
deserialize_document_store_in_init_parameters(data)

mock_default_from_dict.assert_called_once_with(
FakeDocumentStore, {"type": "test_docstore_deserialization.FakeDocumentStore", "init_parameters": {}}
)


def test_missing_document_store_key():
data = {"init_parameters": {"policy": "SKIP"}}
with pytest.raises(DeserializationError):
deserialize_document_store_in_init_parameters(data)


def test_missing_type_key_in_document_store():
data = {"init_parameters": {"document_store": {"init_parameters": {}}, "policy": "SKIP"}}
with pytest.raises(DeserializationError):
deserialize_document_store_in_init_parameters(data)


def test_invalid_class_import():
data = {
"init_parameters": {
"document_store": {"type": "invalid.module.InvalidClass", "init_parameters": {}},
"policy": "SKIP",
}
}
with pytest.raises(DeserializationError):
deserialize_document_store_in_init_parameters(data)

0 comments on commit 34e2412

Please sign in to comment.