Skip to content

Commit 11f8e9c

Browse files
committed
Extract required properties from existing constraints
1 parent 9de9b4c commit 11f8e9c

File tree

1 file changed

+75
-4
lines changed
  • src/neo4j_graphrag/experimental/components

1 file changed

+75
-4
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -620,13 +620,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
620620

621621

622622
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
623-
"""A class to build a GraphSchema object from an existing graph."""
623+
"""A class to build a GraphSchema object from an existing graph.
624624
625-
def __init__(self, driver: neo4j.Driver) -> None:
625+
Uses the get_structured_schema function to extract existing node labels,
626+
relationship types, properties and existence constraints.
627+
628+
By default, the built schema does not allow any additional item (property,
629+
node label, relationship type or pattern).
630+
631+
Args:
632+
driver (neo4j.Driver): connection to the neo4j database.
633+
additional_properties (bool, default False): see GraphSchema
634+
additional_node_types (bool, default False): see GraphSchema
635+
additional_relationship_types (bool, default False): see GraphSchema:
636+
additional_patterns (bool, default False): see GraphSchema:
637+
neo4j_database (Optional | str): name of the neo4j database to use
638+
"""
639+
640+
def __init__(
641+
self,
642+
driver: neo4j.Driver,
643+
additional_properties: bool = False,
644+
additional_node_types: bool = False,
645+
additional_relationship_types: bool = False,
646+
additional_patterns: bool = False,
647+
neo4j_database: Optional[str] = None,
648+
) -> None:
626649
self.driver = driver
650+
self.database = neo4j_database
651+
652+
self.additional_properties = additional_properties
653+
self.additional_node_types = additional_node_types
654+
self.additional_relationship_types = additional_relationship_types
655+
self.additional_patterns = additional_patterns
656+
657+
@staticmethod
658+
def _extract_required_properties(
659+
structured_schema: dict[str, Any],
660+
) -> list[tuple[str, str]]:
661+
"""Extract a list of (node label (or rel type), property name) for which
662+
an "EXISTENCE" or "KEY" constraint is defined in the DB.
663+
664+
Args:
665+
666+
structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.
667+
668+
Returns:
669+
670+
list of tuples of (node label (or rel type), property name)
671+
672+
"""
673+
schema_metadata = structured_schema.get("metadata", {})
674+
existence_constraint = [] # list of (node label, property name)
675+
for constraint in schema_metadata.get("constraints", []):
676+
if constraint["type"] in (
677+
"NODE_PROPERTY_EXISTENCE",
678+
"NODE_KEY",
679+
"RELATIONSHIP_PROPERTY_EXISTENCE",
680+
"RELATIONSHIP_KEY",
681+
):
682+
properties = constraint["properties"]
683+
labels = constraint["labelsOrTypes"]
684+
# note: existence constraint only apply to a single property
685+
# and a single label
686+
prop = properties[0]
687+
lab = labels[0]
688+
existence_constraint.append((lab, prop))
689+
return existence_constraint
690+
691+
async def run(self) -> GraphSchema:
692+
structured_schema = get_structured_schema(self.driver, database=self.database)
693+
existence_constraint = self._extract_required_properties(structured_schema)
627694

628-
async def run(self, **kwargs: Any) -> GraphSchema:
629-
structured_schema = get_structured_schema(self.driver)
630695
node_labels = set(structured_schema["node_props"].keys())
631696
node_types = [
632697
{
@@ -635,9 +700,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
635700
{
636701
"name": p["property"],
637702
"type": p["type"],
703+
"required": (key, p["property"]) in existence_constraint,
638704
}
639705
for p in properties
640706
],
707+
"additional_properties": self.additional_properties,
641708
}
642709
for key, properties in structured_schema["node_props"].items()
643710
]
@@ -649,6 +716,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
649716
{
650717
"name": p["property"],
651718
"type": p["type"],
719+
"required": (key, p["property"]) in existence_constraint,
652720
}
653721
for p in properties
654722
],
@@ -687,5 +755,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
687755
"node_types": node_types,
688756
"relationship_types": relationship_types,
689757
"patterns": patterns,
758+
"additional_node_types": self.additional_node_types,
759+
"additional_relationship_types": self.additional_relationship_types,
760+
"additional_patterns": self.additional_patterns,
690761
}
691762
)

0 commit comments

Comments
 (0)