diff --git a/CHANGELOG.md b/CHANGELOG.md index 5edbdf22..7175ec6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ ### Added - Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j. - +- Added an optional `id_property_getter` callable parameter in the Qdrant retriever to allow for custom ID retrieval. ## 1.10.1 diff --git a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py index eeccead1..14ee2c55 100644 --- a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py +++ b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py @@ -221,7 +221,7 @@ def get_search_results( ) result_tuples = [ - [f"{o[self.id_property_external]}", o["score"] or 0.0] + (f"{o[self.id_property_external]}", o["score"] or 0.0) for o in response["matches"] ] diff --git a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py index 90ec5109..b49f8c7b 100644 --- a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py +++ b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py @@ -20,6 +20,7 @@ import neo4j from pydantic import ValidationError from qdrant_client import QdrantClient +from qdrant_client.conversions.common_types import ScoredPoint from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import ( @@ -80,6 +81,7 @@ class QdrantNeo4jRetriever(ExternalRetriever): result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem. neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation `_). node_label_neo4j (Optional[str]): The label of the Neo4j node to retrieve. This label must be properly escaped if needed, eg "`Label with spaces`". 
+ id_property_getter (Optional[Callable[[ScoredPoint], Any]]): Function to get the id property from a ScoredPoint. Defaults to point.payload.get(id_property_external, point.id). Raises: RetrieverInitializationError: If validation of the input arguments fail. @@ -101,6 +103,7 @@ def __init__( ] = None, neo4j_database: Optional[str] = None, node_label_neo4j: Optional[str] = None, + id_property_getter: Optional[Callable[[ScoredPoint], Any]] = None, ): try: driver_model = Neo4jDriverModel(driver=driver) @@ -142,6 +145,14 @@ def __init__( self.return_properties = validated_data.return_properties self.retrieval_query = validated_data.retrieval_query self.result_formatter = validated_data.result_formatter + self.id_property_getter = id_property_getter + + def get_match_id_from_point(self, point: ScoredPoint) -> Any: + if self.id_property_getter: + return self.id_property_getter(point) + if point.payload is None: + raise ValueError(f"Payload is None for point {point}") + return point.payload.get(self.id_property_external, point.id) def get_search_results( self, @@ -220,10 +231,7 @@ def get_search_results( result_tuples = [] for point in points: - assert point.payload is not None - result_tuples.append( - [point.payload.get(self.id_property_external, point.id), point.score] - ) + result_tuples.append((self.get_match_id_from_point(point), point.score)) search_query = get_match_query( return_properties=self.return_properties, diff --git a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py index 99fc5171..29df9837 100644 --- a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py +++ b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py @@ -235,7 +235,7 @@ def get_search_results( logger.debug("Response: %s", response) result_tuples = [ - [f"{o.properties[self.id_property_external]}", o.metadata.certainty or 0.0] + (f"{o.properties[self.id_property_external]}", o.metadata.certainty or 0.0) for 
o in response.objects ] diff --git a/tests/unit/retrievers/external/test_pinecone.py b/tests/unit/retrievers/external/test_pinecone.py index 9c09b0fb..6e643e9a 100644 --- a/tests/unit/retrievers/external/test_pinecone.py +++ b/tests/unit/retrievers/external/test_pinecone.py @@ -95,7 +95,7 @@ def test_pinecone_retriever_search_happy_path( driver.execute_query.assert_called_once_with( search_query, { - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], "id_property": "sync_id", }, database_=None, @@ -168,7 +168,7 @@ def test_pinecone_retriever_search_return_properties( driver.execute_query.assert_called_once_with( search_query, { - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], "id_property": "sync_id", }, database_=None, @@ -228,7 +228,7 @@ def test_pinecone_retriever_search_retrieval_query( driver.execute_query.assert_called_once_with( search_query, { - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], "id_property": "sync_id", }, database_=None, diff --git a/tests/unit/retrievers/external/test_qdrant.py b/tests/unit/retrievers/external/test_qdrant.py index 6b010038..4847d117 100644 --- a/tests/unit/retrievers/external/test_qdrant.py +++ b/tests/unit/retrievers/external/test_qdrant.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any from unittest import mock from unittest.mock import MagicMock @@ -70,7 +71,7 @@ def test_qdrant_retriever_search_happy_path( driver.execute_query.assert_called_once_with( search_query, { - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], "id_property": "sync_id", }, database_=None, @@ -149,7 +150,7 @@ def test_qdrant_retriever_search_return_properties( driver.execute_query.assert_called_once_with( search_query, { - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], "id_property": "sync_id", }, database_=None, @@ -215,7 +216,7 @@ def test_qdrant_retriever_search_retrieval_query( driver.execute_query.assert_called_once_with( search_query, { - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], "id_property": "sync_id", }, database_=None, @@ -267,3 +268,70 @@ def test_qdrant_retriever_invalid_retrieval_query( assert "retrieval_query" in str(exc_info.value) assert "Input should be a valid string" in str(exc_info.value) + + +def test_qdrant_retriever_search_custom_match_id_getter( + driver: MagicMock, client: MagicMock +) -> None: + def my_id_getter(point: ScoredPoint) -> Any: + if point.payload is None: + raise Exception("Payload is None") + return point.payload["data"]["id"] + + retriever = QdrantNeo4jRetriever( + driver=driver, + client=client, + collection_name="dummy-text", + id_property_neo4j="sync_id", + id_property_getter=my_id_getter, + ) + with mock.patch.object(retriever, "client") as mock_client: + top_k = 5 + mock_client.query_points.return_value = QueryResponse( + points=[ + ScoredPoint( + id=i, + version=0, + score=i / top_k, + payload={ + "data": {"id": f"node_{i}"}, + }, + ) + for i in range(top_k) + ] + ) + driver.execute_query.return_value = ( + [ + neo4j.Record({"node": {"sync_id": 
f"node_{i}"}, "score": i / top_k}) + for i in range(top_k) + ], + None, + None, + ) + query_vector = [1.0 for _ in range(1536)] + search_query = get_match_query() + records = retriever.search(query_vector=query_vector) + + driver.execute_query.assert_called_once_with( + search_query, + { + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], + "id_property": "sync_id", + }, + database_=None, + routing_=neo4j.RoutingControl.READ, + ) + + assert records == RetrieverResult( + items=[ + RetrieverResultItem( + content="", + metadata=None, + ) + for i in range(top_k) + ], + metadata={"__retriever": "QdrantNeo4jRetriever"}, + ) diff --git a/tests/unit/retrievers/external/test_weaviate.py b/tests/unit/retrievers/external/test_weaviate.py index d6091827..784b0edb 100644 --- a/tests/unit/retrievers/external/test_weaviate.py +++ b/tests/unit/retrievers/external/test_weaviate.py @@ -75,7 +75,7 @@ def test_text_search_remote_vector_store_happy_path(driver: MagicMock) -> None: search_query, { "match_params": [ - [node_id_value, node_match_score], + (node_id_value, node_match_score), ], "id_property": "sync_id", }, @@ -142,7 +142,7 @@ def test_text_search_remote_vector_store_return_properties(driver: MagicMock) -> search_query, { "match_params": [ - [node_id_value, node_match_score], + (node_id_value, node_match_score), ], "id_property": "sync_id", }, @@ -190,7 +190,7 @@ def test_text_search_remote_vector_store_retrieval_query(driver: MagicMock) -> N search_query, { "match_params": [ - [node_id_value, node_match_score], + (node_id_value, node_match_score), ], "id_property": "sync_id", },