diff --git a/CHANGELOG.md b/CHANGELOG.md index f14f845f..47a5d33c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ - Fixed an edge case where the LLM can output a property with type 'map', which was causing errors during import as it is not a valid property type in Neo4j. +### Added + +- Added `schema_visualization` function to visualize a graph schema using neo4j-viz. + ## 1.9.1 ### Fixed diff --git a/docs/source/api.rst b/docs/source/api.rst index d8280b1c..4066348b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -83,6 +83,12 @@ SchemaFromTextExtractor .. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaFromTextExtractor :members: run +schema_visualization +-------------------- + +.. autofunction:: neo4j_graphrag.experimental.utils.schema.schema_visualization + + EntityRelationExtractor ======================= diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index f00f8189..b574171b 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -852,6 +852,28 @@ You can also save and reload the extracted schema: restored_schema = GraphSchema.from_file("my_schema.json") # or my_schema.yaml +Schema Visualization +-------------------- + +It is possible to visualize a validated schema or a schema dict using the `schema_visualization` function. This function +returns a VisualizationGraph object (from the neo4j-viz package) that can visualized like this: + +.. code:: python + + from neo4j_graphrag.experimental.utils.schema import schema_visualization + + VG = schema_visualization(schema) + html = VG.render() + + # in Jupyter: + display(html) + + # to save the generated HTML + with open("my_schema.html", "w") as f: + f.write(html.data) + + + Entity and Relation Extractor ============================= diff --git a/poetry.lock b/poetry.lock index 2c3023b1..e48e7784 100644 --- a/poetry.lock +++ b/poetry.lock @@ -276,9 +276,10 @@ idna = ">=2.5" name = "asttokens" version = "3.0.0" description = "Annotate AST trees with source code positions" -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2"}, {file = "asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7"}, @@ -797,7 +798,7 @@ files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "(platform_system == \"Windows\" or sys_platform == \"win32\" or extra == \"experimental\") and (extra == \"pinecone\" or extra == \"cohere\" or extra == \"sentence-transformers\" or extra == \"experimental\" or extra == \"openai\" or extra == \"nlp\" or sys_platform == \"win32\")", dev = "sys_platform == \"win32\" or platform_system == \"Windows\""} +markers = {main = "(sys_platform == \"win32\" or platform_system == \"Windows\" or extra == \"experimental\") and (sys_platform == \"win32\" or extra == \"pinecone\" or extra == \"cohere\" or extra == \"sentence-transformers\" or extra == \"experimental\" or extra == \"openai\" or extra == \"nlp\") and (platform_system == \"Windows\" or extra == \"kg-creation-tools\" or extra == \"experimental\" or extra == \"nlp\") and (extra == \"kg-creation-tools\" or extra == \"experimental\" or extra == \"nlp\" or extra == \"pinecone\" or extra == \"cohere\" or extra == \"sentence-transformers\" or extra == \"openai\")", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "confection" @@ -1034,9 +1035,10 @@ typing-inspect = ">=0.4.0,<1" name = "decorator" version = "5.2.1" description = "Decorators for Humans" -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a"}, {file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"}, @@ -1200,6 +1202,7 @@ files = [ {file = "enum_tools-0.12.0-py3-none-any.whl", hash = "sha256:d69b019f193c7b850b17d9ce18440db7ed62381571409af80ccc08c5218b340a"}, {file = "enum_tools-0.12.0.tar.gz", hash = "sha256:13ceb9376a4c5f574a1e7c5f9c8eb7f3d3fbfbb361cc18a738df1a58dfefd460"}, ] +markers = {main = "extra == \"kg-creation-tools\" or extra == \"experimental\""} [package.dependencies] pygments = ">=2.6.1" @@ -1235,11 +1238,11 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["main", "dev"] -markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"}, {file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"}, ] +markers = {main = "(extra == \"experimental\" or extra == \"weaviate\" or extra == \"google\" or extra == \"cohere\" or extra == \"mistralai\" or extra == \"qdrant\" or extra == \"openai\" or extra == \"anthropic\" or extra == \"ollama\" or extra == \"kg-creation-tools\") and python_version < \"3.11\"", dev = "python_version < \"3.11\""} [package.dependencies] typing-extensions = {version = ">=4.6.0", markers = "python_version < \"3.13\""} @@ -1251,9 +1254,10 @@ test = ["pytest (>=6)"] name = "executing" version = "2.2.0" description = "Get the currently executing AST node of a frame, and other information" -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa"}, {file = "executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755"}, @@ -2395,9 +2399,10 @@ files = [ name = "ipython" version = "8.18.1" description = "IPython: Productive Interactive Computing" -optional = false +optional = true python-versions = ">=3.9" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "ipython-8.18.1-py3-none-any.whl", hash = "sha256:e8267419d72d81955ec1177f8a29aaa90ac80ad647499201119e2f05e99aa397"}, {file = "ipython-8.18.1.tar.gz", hash = "sha256:ca6f079bb33457c66e233e4580ebfc4128855b4cf6370dddd73842a9563e8a27"}, @@ -2433,9 +2438,10 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pa name = "jedi" version = "0.19.2" description = "An autocompletion tool for Python that can be used for text editors." -optional = false +optional = true python-versions = ">=3.6" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9"}, {file = "jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0"}, @@ -3237,9 +3243,10 @@ tests = ["pytest", "simplejson"] name = "matplotlib-inline" version = "0.1.7" description = "Inline Matplotlib backend for Jupyter" -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca"}, {file = "matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90"}, @@ -3662,14 +3669,15 @@ pyarrow = ["pyarrow (>=1.0.0)"] [[package]] name = "neo4j-viz" -version = "0.2.6" +version = "0.4.2" description = "A simple graph visualization tool" -optional = false +optional = true python-versions = ">=3.9" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ - {file = "neo4j_viz-0.2.6-py3-none-any.whl", hash = "sha256:4adc32254d611b53cac29017e32e273320b1f9b7194d98e20db1beb80ac8a871"}, - {file = "neo4j_viz-0.2.6.tar.gz", hash = "sha256:38ab21b09b9eff8e955bb38bd8253418da35253223c47765dd69e959be2578b2"}, + {file = "neo4j_viz-0.4.2-py3-none-any.whl", hash = "sha256:b1185cfdca62359315bc9cee08d412972e9eef95b7b72a0dd2cd62c184533d6e"}, + {file = "neo4j_viz-0.4.2.tar.gz", hash = "sha256:d4bfcb10c9e01d8a3bf55178942a219cee6d1ad4490f7b7fc967b153fc329221"}, ] [package.dependencies] @@ -3679,11 +3687,11 @@ pydantic = ">=2,<3" pydantic-extra-types = ">=2,<3" [package.extras] -dev = ["ipykernel (==6.29.5)", "mypy (==1.15.0)", "nbconvert (==7.16.6)", "palettable (==3.3.3)", "pytest (==8.3.4)", "pytest-mock (==3.14.0)", "ruff (==0.9.7)", "selenium (==4.28.1)", "streamlit (==1.42.0)"] +dev = ["ipykernel (==6.29.5)", "matplotlib (>=3.9.4)", "mypy (==1.15.0)", "nbconvert (==7.16.6)", "palettable (==3.3.3)", "pytest (==8.3.4)", "pytest-mock (==3.14.0)", "ruff (==0.11.8)", "selenium (==4.32.0)", "streamlit (==1.45.0)"] docs = ["enum-tools[sphinx]", "nbsphinx (==0.9.6)", "nbsphinx-link (==1.3.1)", "sphinx (==8.1.3)"] gds = ["graphdatascience (>=1,<2)"] neo4j = ["neo4j"] -notebook = ["ipykernel (==6.29.5)", "ipywidgets (>=8.0.0)", "matplotlib (==3.10.0)", "neo4j (>=5.26.0)", "palettable (==3.3.3)", "pykernel (==0.1.6)", "snowflake-snowpark-python (==1.26.0)"] +notebook = ["ipykernel (>=6.29.5)", "ipywidgets (>=8.0.0)", "matplotlib (>=3.9.4)", "neo4j (>=5.26.0)", "palettable (>=3.3.3)", "pykernel (>=0.1.6)", "snowflake-snowpark-python (==1.26.0)"] pandas = ["pandas (>=2,<3)", "pandas-stubs (>=2,<3)"] [[package]] @@ -4332,9 +4340,10 @@ xml = ["lxml (>=4.9.2)"] name = "parso" version = "0.8.4" description = "A Python Parser" -optional = false +optional = true python-versions = ">=3.6" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"}, @@ -4360,10 +4369,10 @@ files = [ name = "pexpect" version = "4.9.0" description = "Pexpect allows easy control of interactive console applications." -optional = false +optional = true python-versions = "*" -groups = ["main", "dev"] -markers = "sys_platform != \"win32\"" +groups = ["main"] +markers = "sys_platform != \"win32\" and (extra == \"kg-creation-tools\" or extra == \"experimental\")" files = [ {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, @@ -4666,9 +4675,10 @@ murmurhash = ">=0.28.0,<1.1.0" name = "prompt-toolkit" version = "3.0.51" description = "Library for building powerful interactive command lines in Python" -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "prompt_toolkit-3.0.51-py3-none-any.whl", hash = "sha256:52742911fde84e2d423e2f9a4cf1de7d7ac4e51958f648d9540e0fb8db077b07"}, {file = "prompt_toolkit-3.0.51.tar.gz", hash = "sha256:931a162e3b27fc90c86f1b48bb1fb2c528c2761475e57c9c06de13311c7b54ed"}, @@ -4829,10 +4839,10 @@ files = [ name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" -optional = false +optional = true python-versions = "*" -groups = ["main", "dev"] -markers = "sys_platform != \"win32\"" +groups = ["main"] +markers = "sys_platform != \"win32\" and (extra == \"kg-creation-tools\" or extra == \"experimental\")" files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, @@ -4842,9 +4852,10 @@ files = [ name = "pure-eval" version = "0.2.3" description = "Safely evaluate AST nodes without side effects" -optional = false +optional = true python-versions = "*" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0"}, {file = "pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42"}, @@ -5033,9 +5044,10 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" name = "pydantic-extra-types" version = "2.10.5" description = "Extra Pydantic types." -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "pydantic_extra_types-2.10.5-py3-none-any.whl", hash = "sha256:b60c4e23d573a69a4f1a16dd92888ecc0ef34fb0e655b4f305530377fa70e7a8"}, {file = "pydantic_extra_types-2.10.5.tar.gz", hash = "sha256:1dcfa2c0cf741a422f088e0dbb4690e7bfadaaf050da3d6f80d6c3cf58a2bad8"}, @@ -5064,6 +5076,7 @@ files = [ {file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"}, {file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"}, ] +markers = {main = "extra == \"kg-creation-tools\" or extra == \"experimental\" or extra == \"nlp\""} [package.extras] windows-terminal = ["colorama (>=0.4.6)"] @@ -6611,9 +6624,10 @@ catalogue = ">=2.0.3,<2.1.0" name = "stack-data" version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" -optional = false +optional = true python-versions = "*" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, @@ -7019,9 +7033,10 @@ telegram = ["requests"] name = "traitlets" version = "5.14.3" description = "Traitlets Python configuration system" -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, @@ -7321,9 +7336,10 @@ colorama = {version = ">=0.4.6", markers = "sys_platform == \"win32\" and python name = "wcwidth" version = "0.2.13" description = "Measures the displayed width of unicode strings in a terminal" -optional = false +optional = true python-versions = "*" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"kg-creation-tools\" or extra == \"experimental\"" files = [ {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, @@ -7836,4 +7852,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.1" python-versions = ">=3.9.0,<3.14" -content-hash = "f697f8880f1f3ac8fb150dd920596939ee2e895da81c42fc5f83b6123bb20e03" +content-hash = "a49fdff6b12387ac342d57f0f569309587ceecccd659847134621351af288d2e" diff --git a/pyproject.toml b/pyproject.toml index 9e44bd13..01637a97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ pyyaml = "^6.0.2" types-pyyaml = "^6.0.12.20240917" # optional deps langchain-text-splitters = {version = "^0.3.0", optional = true } -neo4j-viz = {version = "^0.2.2", optional = true } +neo4j-viz = {version = "^0.4.2", optional = true } weaviate-client = {version = "^4.6.1", optional = true } pinecone-client = {version = "^4.1.0", optional = true } google-cloud-aiplatform = {version = "^1.66.0", optional = true } @@ -74,7 +74,6 @@ sphinx = { version = "^7.2.6", python = "^3.9" } langchain-openai = {version = "^0.2.2", optional = true } langchain-huggingface = {version = "^0.1.0", optional = true } enum-tools = {extras = ["sphinx"], version = "^0.12.0"} -neo4j-viz = "^0.2.2" [tool.poetry.extras] weaviate = ["weaviate-client"] diff --git a/src/neo4j_graphrag/experimental/utils/__init__.py b/src/neo4j_graphrag/experimental/utils/__init__.py new file mode 100644 index 00000000..c0199c14 --- /dev/null +++ b/src/neo4j_graphrag/experimental/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/neo4j_graphrag/experimental/utils/schema.py b/src/neo4j_graphrag/experimental/utils/schema.py new file mode 100644 index 00000000..6e38e1f2 --- /dev/null +++ b/src/neo4j_graphrag/experimental/utils/schema.py @@ -0,0 +1,115 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Union + +try: + from neo4j_viz import VisualizationGraph, Node, Relationship +except ImportError: + VisualizationGraph = Node = Relationship = None # type: ignore + +from neo4j_graphrag.experimental.components.schema import ( + GraphSchema, + NodeType, + PropertyType, +) + + +def schema_visualization( + schema: Union[dict[str, Any], GraphSchema], +) -> VisualizationGraph: + """Helper function to visualize a GraphSchema using the neo4j-viz library. + + Usage: + + .. code:: python + + VG = schema_visualization(schema) + html = VG.render() + + # in Jupyter: + display(html) + + # to save the generated HTML + with open("my_schema.html", "w") as f: + f.write(html.data) + """ + if VisualizationGraph is None: + raise ImportError( + "Please install neo4j-viz to use the graph schema visualization feature: pip install neo4j-viz" + ) + + schema_object = GraphSchema.model_validate(schema) + + def _format_property_name(p: PropertyType) -> str: + """ + + Args: + p (PropertyType): the property to be formatted + + Returns: + str: the property name, suffixed with '*' if the property is required + + """ + return p.name + ("*" if p.required else "") + + def _relationship_properties(rel_type: str) -> dict[str, str]: + """Returns a dict {prop_name: prop_type} for all relationship properties. + + Args: + rel_type (str): the relationship type + + Returns: + dict[str, str]: the relationship properties {name: type} mapping for display + """ + for relationship_type in schema_object.relationship_types: + if relationship_type.label != rel_type: + continue + return { + _format_property_name(p): p.type for p in relationship_type.properties + } + return {} + + def _node_properties(node_type: NodeType) -> dict[str, str]: + """Returns a dict {prop_name: prop_type} for all node properties. + + Args: + node_type (NodeType): the node type object + + Returns: + dict[str, str]: the node properties {name: type} mapping for display + """ + return {_format_property_name(p): p.type for p in node_type.properties} + + nodes = [ + Node( # type: ignore + id=node_type.label, + caption=node_type.label, + properties=_node_properties(node_type), + ) + for node_type in schema_object.node_types + ] + relationships = [ + Relationship( # type: ignore + source=pattern[0], + target=pattern[2], + caption=pattern[1], + properties=_relationship_properties(pattern[1]), + ) + for pattern in schema_object.patterns + ] + + VG = VisualizationGraph(nodes=nodes, relationships=relationships) + VG.color_nodes(field="caption") + return VG diff --git a/tests/unit/experimental/utils/__init__.py b/tests/unit/experimental/utils/__init__.py new file mode 100644 index 00000000..c0199c14 --- /dev/null +++ b/tests/unit/experimental/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/experimental/utils/test_schema.py b/tests/unit/experimental/utils/test_schema.py new file mode 100644 index 00000000..09055da0 --- /dev/null +++ b/tests/unit/experimental/utils/test_schema.py @@ -0,0 +1,107 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any +from unittest.mock import patch + +import pytest +from pydantic import ValidationError + +from neo4j_viz import VisualizationGraph +from neo4j_graphrag.experimental.components.schema import GraphSchema +from neo4j_graphrag.experimental.utils.schema import schema_visualization + + +@pytest.fixture(scope="module") +def valid_schema_dict() -> dict[str, Any]: + return { + "node_types": [ + "Location", + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": True}, + {"name": "birthYear", "type": "INTEGER"}, + ], + }, + ], + "relationship_types": [ + "BORN_IN", + { + "label": "KNOWS", + "properties": [ + {"name": "since", "type": "LOCAL_DATETIME"}, + ], + }, + ], + "patterns": [ + ("Person", "BORN_IN", "Location"), + ("Person", "KNOWS", "Person"), + ], + } + + +@pytest.fixture(scope="module") +def invalid_schema_dict() -> dict[str, Any]: + return { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING", "required": True}, + {"name": "birthYear", "type": "INTEGER"}, + ], + }, + ], + "relationship_types": [ + "BORN_IN", + ], + "patterns": [ + ( + "Person", + "BORN_IN", + "Location", + ), # invalid pattern, "Location" node type not defined + ], + } + + +@patch("neo4j_graphrag.experimental.utils.schema.VisualizationGraph", None) +def test_schema_visualization_import_error() -> None: + with pytest.raises(ImportError): + schema_visualization({}) + + +def test_schema_visualization_invalid_schema_dict( + invalid_schema_dict: dict[str, Any], +) -> None: + with pytest.raises(ValidationError): + schema_visualization(invalid_schema_dict) + + +def test_schema_visualization_valid_schema_dict( + valid_schema_dict: dict[str, Any], +) -> None: + g = schema_visualization(valid_schema_dict) + assert isinstance(g, VisualizationGraph) + assert len(g.nodes) == 2 + assert len(g.relationships) == 2 + + +def test_schema_visualization_schema_object(valid_schema_dict: dict[str, Any]) -> None: + schema = GraphSchema.model_validate(valid_schema_dict) + g = schema_visualization(schema) + assert isinstance(g, VisualizationGraph) + assert len(g.nodes) == 2 + assert len(g.relationships) == 2