Skip to content

Commit f7a669a

Browse files
committed
Test happy path
1 parent 4de00d1 commit f7a669a

File tree

2 files changed

+98
-14
lines changed

2 files changed

+98
-14
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
from neo4j_graphrag.schema import get_structured_schema
5050

5151

52+
logger = logging.getLogger(__name__)
53+
54+
5255
class PropertyType(BaseModel):
5356
"""
5457
Represents a property on a node or relationship in the graph.
@@ -622,19 +625,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
622625
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
623626
"""A class to build a GraphSchema object from an existing graph.
624627
625-
Uses the get_structured_schema function to extract existing node labels,
626-
relationship types, properties and existence constraints.
628+
Uses the get_structured_schema function to extract existing node labels,
629+
relationship types, properties and existence constraints.
627630
628-
By default, the built schema does not allow any additional item (property,
629-
node label, relationship type or pattern).
631+
By default, the built schema does not allow any additional item (property,
632+
node label, relationship type or pattern).
630633
631-
Args:
632-
driver (neo4j.Driver): connection to the neo4j database.
633-
additional_properties (bool, default False): see GraphSchema
634-
additional_node_types (bool, default False): see GraphSchema
635-
additional_relationship_types (bool, default False): see GraphSchema:
636-
additional_patterns (bool, default False): see GraphSchema:
637-
neo4j_database (Optional | str): name of the neo4j database to use
634+
Args:
635+
driver (neo4j.Driver): connection to the neo4j database.
636+
additional_properties (bool, default False): see GraphSchema
637+
additional_node_types (bool, default False): see GraphSchema
638+
additional_relationship_types (bool, default False): see GraphSchema
639+
additional_patterns (bool, default False): see GraphSchema
640+
neo4j_database (Optional[str]): name of the Neo4j database to use
638641
"""
639642

640643
def __init__(
@@ -672,7 +675,7 @@ def _extract_required_properties(
672675
"""
673676
schema_metadata = structured_schema.get("metadata", {})
674677
existence_constraint = [] # list of (node label, property name)
675-
for constraint in schema_metadata.get("constraints", []):
678+
for constraint in schema_metadata.get("constraint", []):
676679
if constraint["type"] in (
677680
"NODE_PROPERTY_EXISTENCE",
678681
"NODE_KEY",
@@ -688,10 +691,11 @@ def _extract_required_properties(
688691
existence_constraint.append((lab, prop))
689692
return existence_constraint
690693

691-
async def run(self) -> GraphSchema:
694+
async def run(self, *args, **kwargs) -> GraphSchema:
692695
structured_schema = get_structured_schema(self.driver, database=self.database)
693696
existence_constraint = self._extract_required_properties(structured_schema)
694697

698+
# node labels with properties
695699
node_labels = set(structured_schema["node_props"].keys())
696700
node_types = [
697701
{
@@ -708,6 +712,8 @@ async def run(self) -> GraphSchema:
708712
}
709713
for key, properties in structured_schema["node_props"].items()
710714
]
715+
716+
# relationships with properties
711717
rel_labels = set(structured_schema["rel_props"].keys())
712718
relationship_types = [
713719
{
@@ -723,27 +729,41 @@ async def run(self) -> GraphSchema:
723729
}
724730
for key, properties in structured_schema["rel_props"].items()
725731
]
732+
726733
patterns = [
727734
(s["start"], s["type"], s["end"])
728735
for s in structured_schema["relationships"]
729736
]
737+
730738
# deal with nodes and relationships without properties
731739
for source, rel, target in patterns:
732740
if source not in node_labels:
741+
if not self.additional_properties:
742+
logger.warning(
743+
f"SCHEMA: found node label {source} without property and additional_properties=False: this node label will always be pruned!"
744+
)
733745
node_labels.add(source)
734746
node_types.append(
735747
{
736748
"label": source,
737749
}
738750
)
739751
if target not in node_labels:
752+
if not self.additional_properties:
753+
logger.warning(
754+
f"SCHEMA: found node label {target} without property and additional_properties=False: this node label will always be pruned!"
755+
)
740756
node_labels.add(target)
741757
node_types.append(
742758
{
743759
"label": target,
744760
}
745761
)
746762
if rel not in rel_labels:
763+
if not self.additional_properties:
764+
logger.warning(
765+
f"SCHEMA: found relationship type {rel} without property and additional_properties=False: this relationship type will always be pruned!"
766+
)
747767
rel_labels.add(rel)
748768
relationship_types.append(
749769
{

tests/unit/experimental/components/test_schema.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import json
1818
from typing import Tuple, Any
19-
from unittest.mock import AsyncMock, patch
19+
from unittest.mock import AsyncMock, patch, Mock
2020

2121
import pytest
2222
from pydantic import ValidationError
@@ -29,6 +29,7 @@
2929
RelationshipType,
3030
SchemaFromTextExtractor,
3131
GraphSchema,
32+
SchemaFromExistingGraphExtractor,
3233
)
3334
import os
3435
import tempfile
@@ -957,3 +958,66 @@ async def test_schema_from_text_filters_relationships_without_labels(
957958
assert len(schema.patterns) == 2
958959
assert ("Person", "WORKS_FOR", "Organization") in schema.patterns
959960
assert ("Person", "MANAGES", "Organization") in schema.patterns
961+
962+
963+
@pytest.mark.asyncio
964+
@patch("neo4j_graphrag.experimental.components.schema.get_structured_schema")
965+
async def test_schema_from_existing_graph(mock_get_structured_schema: Mock):
966+
mock_get_structured_schema.return_value = {
967+
"node_props": {
968+
"Person": [
969+
{"property": "id", "type": "INTEGER"},
970+
{"property": "name", "type": "STRING"},
971+
]
972+
},
973+
"rel_props": {"KNOWS": [{"property": "fromDate", "type": "DATE"}]},
974+
"relationships": [
975+
{"start": "Person", "type": "KNOWS", "end": "Person"},
976+
{"start": "Person", "type": "LIVES_IN", "end": "City"},
977+
],
978+
"metadata": {
979+
"constraint": [
980+
{
981+
"id": 7,
982+
"name": "person_id",
983+
"type": "NODE_PROPERTY_EXISTENCE",
984+
"entityType": "NODE",
985+
"labelsOrTypes": ["Person"],
986+
"properties": ["id"],
987+
"ownedIndex": "person_id",
988+
"propertyType": None,
989+
},
990+
],
991+
"index": [
992+
{
993+
"label": "Person",
994+
"properties": ["name"],
995+
"size": 2,
996+
"type": "RANGE",
997+
"valuesSelectivity": 1.0,
998+
"distinctValues": 2.0,
999+
},
1000+
],
1001+
},
1002+
}
1003+
driver = Mock()
1004+
schema_builder = SchemaFromExistingGraphExtractor(
1005+
driver=driver,
1006+
)
1007+
schema = await schema_builder.run()
1008+
assert isinstance(schema, GraphSchema)
1009+
assert len(schema.node_types) == 2
1010+
person_node_type = schema.node_type_from_label("Person")
1011+
assert person_node_type is not None
1012+
id_person_property = [p for p in person_node_type.properties if p.name == "id"][0]
1013+
assert id_person_property.required is True
1014+
1015+
assert schema.node_type_from_label("City") is not None
1016+
assert len(schema.relationship_types) == 2
1017+
assert schema.relationship_type_from_label("KNOWS") is not None
1018+
assert schema.relationship_type_from_label("LIVES_IN") is not None
1019+
1020+
assert schema.patterns == (
1021+
("Person", "KNOWS", "Person"),
1022+
("Person", "LIVES_IN", "City"),
1023+
)

0 commit comments

Comments
 (0)