@@ -620,13 +620,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
620
620
621
621
622
622
class SchemaFromExistingGraphExtractor (BaseSchemaBuilder ):
623
- """A class to build a GraphSchema object from an existing graph."""
623
+ """A class to build a GraphSchema object from an existing graph.
624
624
625
- def __init__ (self , driver : neo4j .Driver ) -> None :
625
+ Uses the get_structured_schema function to extract existing node labels,
626
+ relationship types, properties and existence constraints.
627
+
628
+ By default, the built schema does not allow any additional item (property,
629
+ node label, relationship type or pattern).
630
+
631
+ Args:
632
+ driver (neo4j.Driver): connection to the neo4j database.
633
+ additional_properties (bool, default False): see GraphSchema
634
+ additional_node_types (bool, default False): see GraphSchema
635
+ additional_relationship_types (bool, default False): see GraphSchema:
636
+ additional_patterns (bool, default False): see GraphSchema:
637
+ neo4j_database (Optional | str): name of the neo4j database to use
638
+ """
639
+
640
+ def __init__ (
641
+ self ,
642
+ driver : neo4j .Driver ,
643
+ additional_properties : bool = False ,
644
+ additional_node_types : bool = False ,
645
+ additional_relationship_types : bool = False ,
646
+ additional_patterns : bool = False ,
647
+ neo4j_database : Optional [str ] = None ,
648
+ ) -> None :
626
649
self .driver = driver
650
+ self .database = neo4j_database
651
+
652
+ self .additional_properties = additional_properties
653
+ self .additional_node_types = additional_node_types
654
+ self .additional_relationship_types = additional_relationship_types
655
+ self .additional_patterns = additional_patterns
656
+
657
+ @staticmethod
658
+ def _extract_required_properties (
659
+ structured_schema : dict [str , Any ],
660
+ ) -> list [tuple [str , str ]]:
661
+ """Extract a list of (node label (or rel type), property name) for which
662
+ an "EXISTENCE" or "KEY" constraint is defined in the DB.
663
+
664
+ Args:
665
+
666
+ structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.
667
+
668
+ Returns:
669
+
670
+ list of tuples of (node label (or rel type), property name)
671
+
672
+ """
673
+ schema_metadata = structured_schema .get ("metadata" , {})
674
+ existence_constraint = [] # list of (node label, property name)
675
+ for constraint in schema_metadata .get ("constraints" , []):
676
+ if constraint ["type" ] in (
677
+ "NODE_PROPERTY_EXISTENCE" ,
678
+ "NODE_KEY" ,
679
+ "RELATIONSHIP_PROPERTY_EXISTENCE" ,
680
+ "RELATIONSHIP_KEY" ,
681
+ ):
682
+ properties = constraint ["properties" ]
683
+ labels = constraint ["labelsOrTypes" ]
684
+ # note: existence constraint only apply to a single property
685
+ # and a single label
686
+ prop = properties [0 ]
687
+ lab = labels [0 ]
688
+ existence_constraint .append ((lab , prop ))
689
+ return existence_constraint
690
+
691
+ async def run (self ) -> GraphSchema :
692
+ structured_schema = get_structured_schema (self .driver , database = self .database )
693
+ existence_constraint = self ._extract_required_properties (structured_schema )
627
694
628
- async def run (self , ** kwargs : Any ) -> GraphSchema :
629
- structured_schema = get_structured_schema (self .driver )
630
695
node_labels = set (structured_schema ["node_props" ].keys ())
631
696
node_types = [
632
697
{
@@ -635,9 +700,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
635
700
{
636
701
"name" : p ["property" ],
637
702
"type" : p ["type" ],
703
+ "required" : (key , p ["property" ]) in existence_constraint ,
638
704
}
639
705
for p in properties
640
706
],
707
+ "additional_properties" : self .additional_properties ,
641
708
}
642
709
for key , properties in structured_schema ["node_props" ].items ()
643
710
]
@@ -649,6 +716,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
649
716
{
650
717
"name" : p ["property" ],
651
718
"type" : p ["type" ],
719
+ "required" : (key , p ["property" ]) in existence_constraint ,
652
720
}
653
721
for p in properties
654
722
],
@@ -687,5 +755,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
687
755
"node_types" : node_types ,
688
756
"relationship_types" : relationship_types ,
689
757
"patterns" : patterns ,
758
+ "additional_node_types" : self .additional_node_types ,
759
+ "additional_relationship_types" : self .additional_relationship_types ,
760
+ "additional_patterns" : self .additional_patterns ,
690
761
}
691
762
)
0 commit comments