diff --git a/cookbook/graphstores/neo4j_basic_examples.ipynb b/cookbook/graphstores/neo4j_basic_examples.ipynb new file mode 100644 index 0000000..ac4bcaf --- /dev/null +++ b/cookbook/graphstores/neo4j_basic_examples.ipynb @@ -0,0 +1,436 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GraphStore: Neo4j Database Basic Examples\n", + "\n", + "This notebook demonstrates how to use the `Neo4jGraphStore` in `Floki` for basic graph-based tasks. We will explore:\n", + "\n", + "* Initializing the `Neo4jGraphStore` class.\n", + "* Adding sample nodes.\n", + "* Adding one sample relationship.\n", + "* Querying the graph database.\n", + "* Resetting the database." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install Required Libraries\n", + "\n", + "Ensure floki and neo4j are installed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install floki-ai neo4j" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import Environment Variables\n", + "\n", + "Load your API keys or other configuration values using `dotenv`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dotenv import load_dotenv\n", + "load_dotenv() # Load environment variables from a `.env` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Enable Logging" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logging.basicConfig(level=logging.INFO)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deploy Neo4j Graph Database as Docker Container" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#docker run \\\n", + "#--restart always \\\n", + "#--publish=7474:7474 --publish=7687:7687 \\\n", + "#--env NEO4J_AUTH=neo4j/neo4j \\\n", + "#--volume=neo4j-data \\\n", + "#--name neo4j-apoc \\\n", + "#--env NEO4J_apoc_export_file_enabled=true \\\n", + "#--env NEO4J_apoc_import_file_enabled=true \\\n", + "#--env NEO4J_apoc_import_file_use__neo4j__config=true \\\n", + "#--env NEO4J_PLUGINS=\\[\\\"apoc\\\"\\] \\\n", + "#neo4j:latest" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize Neo4jGraphStore\n", + "\n", + "Set the `NEO4J_URI`, `NEO4J_USERNAME` and `NEO4J_PASSWORD` variables in a `.env` file. The URI can be set to `bolt://localhost:7687`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:floki.storage.graphstores.neo4j.client:Successfully created the driver for URI: bolt://localhost:7687\n", + "INFO:floki.storage.graphstores.neo4j.base:Neo4jGraphStore initialized with database neo4j\n" + ] + } + ], + "source": [ + "from floki.storage.graphstores.neo4j import Neo4jGraphStore\n", + "import os\n", + "\n", + "# Initialize Neo4jGraphStore\n", + "graph_store = Neo4jGraphStore(\n", + " uri=os.getenv(\"NEO4J_URI\"),\n", + " user=os.getenv(\"NEO4J_USERNAME\"),\n", + " password=os.getenv(\"NEO4J_PASSWORD\"),\n", + " database=\"neo4j\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:floki.storage.graphstores.neo4j.client:Connected to Neo4j Kernel version 5.15.0 (community edition)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neo4j connection successful\n" + ] + } + ], + "source": [ + "# Test the connection\n", + "assert graph_store.client.test_connection(), \"Connection to Neo4j failed\"\n", + "print(\"Neo4j connection successful\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add Sample Nodes\n", + "Create and add nodes to the graph store:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from floki.types import Node\n", + "\n", + "# Sample nodes\n", + "nodes = [\n", + " Node(\n", + " id=\"1\",\n", + " label=\"Person\",\n", + " properties={\"name\": \"Alice\", \"age\": 30},\n", + " additional_labels=[\"Employee\"]\n", + " ),\n", + " Node(\n", + " id=\"2\",\n", + " label=\"Person\",\n", + " properties={\"name\": \"Bob\", \"age\": 25},\n", + " additional_labels=[\"Contractor\"]\n", + " )\n", + "]\n", + "\n", + "# Add nodes\n", + 
"graph_store.add_nodes(nodes)\n", + "print(\"Nodes added successfully\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add Sample Relationship\n", + "Create and add a relationship to the graph store:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from floki.types import Relationship\n", + "\n", + "# Sample relationships\n", + "relationships = [\n", + " Relationship(\n", + " source_node_id=\"1\",\n", + " target_node_id=\"2\",\n", + " type=\"KNOWS\",\n", + " properties={\"since\": \"2023\"}\n", + " )\n", + "]\n", + "\n", + "# Add relationships\n", + "graph_store.add_relationships(relationships)\n", + "print(\"Relationships added successfully\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Query Graph" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:floki.storage.graphstores.neo4j.base:Query executed successfully: MATCH (n) RETURN n | Time: 0.01 seconds | Results: 2\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nodes in the database:\n", + "{'n': {'createdAt': '2025-01-13T06:56:37.251710Z', 'name': 'Alice', 'id': '1', 'age': 30, 'updatedAt': '2025-01-13T06:56:37.251710Z'}}\n", + "{'n': {'createdAt': '2025-01-13T06:56:37.251710Z', 'name': 'Bob', 'id': '2', 'age': 25, 'updatedAt': '2025-01-13T06:56:37.251710Z'}}\n" + ] + } + ], + "source": [ + "query = \"MATCH (n) RETURN n\"\n", + "results = graph_store.query(query)\n", + "print(\"Nodes in the database:\")\n", + "for record in results:\n", + " print(record)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:floki.storage.graphstores.neo4j.base:Query executed successfully: \n", + "MATCH (a)-[r]->(b)\n", + "RETURN a.id AS source, 
b.id AS target, type(r) AS type, properties(r) AS properties\n", + " | Time: 0.01 seconds | Results: 1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Relationships in the database:\n", + "{'source': '1', 'target': '2', 'type': 'KNOWS', 'properties': {'updatedAt': '2025-01-13T07:01:05.566187Z', 'createdAt': '2025-01-13T07:01:05.566187Z', 'since': '2023'}}\n" + ] + } + ], + "source": [ + "query = \"\"\"\n", + "MATCH (a)-[r]->(b)\n", + "RETURN a.id AS source, b.id AS target, type(r) AS type, properties(r) AS properties\n", + "\"\"\"\n", + "results = graph_store.query(query)\n", + "print(\"Relationships in the database:\")\n", + "for record in results:\n", + " print(record)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:floki.storage.graphstores.neo4j.base:Query executed successfully: \n", + "MATCH (n)-[r]->(m)\n", + "RETURN n, r, m\n", + " | Time: 0.09 seconds | Results: 1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nodes and relationships in the database:\n", + "{'n': {'createdAt': '2025-01-13T06:56:37.251710Z', 'name': 'Alice', 'id': '1', 'age': 30, 'updatedAt': '2025-01-13T06:56:37.251710Z'}, 'r': ({'createdAt': '2025-01-13T06:56:37.251710Z', 'name': 'Alice', 'id': '1', 'age': 30, 'updatedAt': '2025-01-13T06:56:37.251710Z'}, 'KNOWS', {'createdAt': '2025-01-13T06:56:37.251710Z', 'name': 'Bob', 'id': '2', 'age': 25, 'updatedAt': '2025-01-13T06:56:37.251710Z'}), 'm': {'createdAt': '2025-01-13T06:56:37.251710Z', 'name': 'Bob', 'id': '2', 'age': 25, 'updatedAt': '2025-01-13T06:56:37.251710Z'}}\n" + ] + } + ], + "source": [ + "query = \"\"\"\n", + "MATCH (n)-[r]->(m)\n", + "RETURN n, r, m\n", + "\"\"\"\n", + "results = graph_store.query(query)\n", + "print(\"Nodes and relationships in the database:\")\n", + "for record in results:\n", + " print(record)" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Reset Graph" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:floki.storage.graphstores.neo4j.base:Database reset successfully\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Graph database has been reset.\n" + ] + } + ], + "source": [ + "graph_store.reset()\n", + "print(\"Graph database has been reset.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:floki.storage.graphstores.neo4j.base:Query executed successfully: MATCH (n) RETURN n | Time: 0.00 seconds | Results: 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nodes in the database:\n" + ] + } + ], + "source": [ + "query = \"MATCH (n) RETURN n\"\n", + "results = graph_store.query(query)\n", + "print(\"Nodes in the database:\")\n", + "for record in results:\n", + " print(record)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/poetry.lock b/poetry.lock index 7307bdf..8ff0a06 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1481,25 +1481,6 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} -[[package]] -name = "neo4j" -version = "5.26.0" -description = "Neo4j Bolt driver for Python" -optional = false -python-versions = 
">=3.7" -files = [ - {file = "neo4j-5.26.0-py3-none-any.whl", hash = "sha256:511a6a9468ca89b521bf686f885a2070acc462b1d09821d43710bd477acdf11e"}, - {file = "neo4j-5.26.0.tar.gz", hash = "sha256:51b25ba127b7b9fdae1ddf48ae697ddfab331e60f4b6d8488d1fc1f74ec60dcc"}, -] - -[package.dependencies] -pytz = "*" - -[package.extras] -numpy = ["numpy (>=1.7.0,<2.0.0)"] -pandas = ["numpy (>=1.7.0,<2.0.0)", "pandas (>=1.1.0,<3.0.0)"] -pyarrow = ["pyarrow (>=1.0.0)"] - [[package]] name = "nest-asyncio" version = "1.6.0" @@ -1996,17 +1977,6 @@ files = [ [package.dependencies] six = ">=1.5" -[[package]] -name = "pytz" -version = "2024.2" -description = "World timezone definitions, modern and historical" -optional = false -python-versions = "*" -files = [ - {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, - {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, -] - [[package]] name = "pywin32" version = "308" @@ -2569,4 +2539,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "6ad2484e9eb6919d437ccdcf6aef9b2ea49815d244ae1e1b5740ae556168a67f" +content-hash = "a383e7b76c4007de6bcdd187712a611ba1ee9de0fe5f3d938496b47e404216d9" diff --git a/pyproject.toml b/pyproject.toml index 17dc695..e21c49f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "floki-ai" -version = "0.9.1" +version = "0.10.1" description = "Agentic Workflows Made Simple" readme = "README.md" authors = [{ name = "Roberto Rodriguez" }] @@ -31,7 +31,6 @@ dependencies = [ "openai==1.59.6", "openapi-pydantic==0.5.1", "regex>=2023.12.25", - "neo4j==5.26.0", "Jinja2==3.1.5", "azure-identity==1.19.0", "dapr==1.14.0", @@ -46,7 +45,7 @@ homepage = "https://github.com/Cyb3rWard0g/floki" [tool.poetry] name = "floki" -version = "0.9.1" +version = "0.10.1" description = "Agentic 
Workflows Made Simple" authors = ["Roberto Rodriguez"] license = "MIT" @@ -69,7 +68,6 @@ pydantic = "2.10.5" openai = "1.59.6" openapi-pydantic = "0.5.1" regex = "^2023.12.25" -neo4j = "5.26.0" Jinja2 = "3.1.5" azure-identity = "1.19.0" tornado = "^6.4.2" diff --git a/requirements.txt b/requirements.txt index ee1b478..06631cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ openai==1.59.6 openapi-pydantic==0.5.1 openapi-schema-pydantic==1.2.4 regex>=2023.12.25 -neo4j==5.26.0 Jinja2==3.1.5 azure-identity==1.19.0 dapr==1.14.0 diff --git a/setup.py b/setup.py index 8ca6eee..7fa51ae 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="floki-ai", - version="0.9.1", + version="0.10.1", author="Roberto Rodriguez", description="Agentic Workflows Made Simple", long_description=long_description, @@ -25,7 +25,6 @@ "openapi-pydantic==0.5.1", "openapi-schema-pydantic==1.2.4", "regex>=2023.12.25", - "neo4j==5.26.0", "Jinja2==3.1.5", "azure-identity==1.19.0", "dapr==1.14.0", diff --git a/src/floki/storage/graphstores/neo4j/base.py b/src/floki/storage/graphstores/neo4j/base.py index 9563f87..92d3551 100644 --- a/src/floki/storage/graphstores/neo4j/base.py +++ b/src/floki/storage/graphstores/neo4j/base.py @@ -1,12 +1,11 @@ -import logging from floki.storage.graphstores import GraphStoreBase from floki.storage.graphstores.neo4j.client import Neo4jClient from floki.storage.graphstores.neo4j.utils import value_sanitize, get_current_time from floki.types import Node, Relationship -from neo4j import Query -from neo4j.exceptions import Neo4jError, CypherSyntaxError -from pydantic import Field -from typing import Any, Dict, Optional, List +from pydantic import BaseModel, ValidationError, Field +from typing import Any, Dict, Optional, List, Literal +from collections import defaultdict +import logging logger = logging.getLogger(__name__) @@ -30,15 +29,41 @@ def model_post_init(self, __context: Any) -> None: """ Post-initialization to set up the Neo4j client 
after model instantiation. """ - self.client = Neo4jClient(self.uri, self.user, self.password, self.database) + self.client = Neo4jClient(uri=self.uri, user=self.user, password=self.password, database=self.database) logger.info(f"Neo4jGraphStore initialized with database {self.database}") # Complete post-initialization super().model_post_init(__context) + def batch_execute(self, query: str, data: List[Dict[str, Any]], batch_size: int = 1000) -> None: + """ + Execute a Cypher query in batches. + + Args: + query (str): The Cypher query to execute. + data (List[Dict[str, Any]]): The data to pass to the query. + batch_size (int): The size of each batch. Defaults to 1000. + + Raises: + ValueError: If there is an issue with the query execution. + """ + from neo4j.exceptions import Neo4jError + + total_batches = (len(data) + batch_size - 1) // batch_size + for i in range(0, len(data), batch_size): + batch = data[i:i + batch_size] + try: + with self.client.driver.session(database=self.client.database) as session: + # Pass the correct parameter name + session.run(query, {"data": batch}) + logger.info("Processed batch %d/%d", i // batch_size + 1, total_batches) + except Neo4jError as e: + logger.error("Batch execution failed: %s", str(e)) + raise ValueError(f"Batch execution failed: {str(e)}") + def add_node(self, node: Node) -> None: """ - Add a node to the Neo4j database with specified label and properties. + Add a single node to the Neo4j database. Args: node (Node): The node to add. @@ -46,75 +71,67 @@ def add_node(self, node: Node) -> None: Raises: ValueError: If there is an issue with the query execution. 
""" - query = f"""MERGE (n: {node.label} {{id: {node.id}}}) - ON CREATE SET n.createdAt = $current_time - SET n.updatedAt = $current_time, n += apoc.map.clean($props, [], []) - WITH n - CALL apoc.create.addLabels(n, $additional_labels) - YIELD node - """ - if node.embedding is not None: - query += """ - WITH node, $embedding AS embedding - CALL db.create.setNodeVectorProperty(node, 'embedding', embedding) - YIELD node - """ - query += "RETURN node" - - params = { - 'props': node.properties, - 'additional_labels': node.additional_labels, - 'embedding': node.embedding, - 'current_time': get_current_time() - } - try: - with self.client.driver.session(database=self.client.database) as session: - session.run(query, params) - logger.info("Node with label %s and properties %s added or updated successfully", node.label, node.properties) - except Neo4jError as e: - logger.error("Failed to add or update node: %s", str(e)) - raise ValueError(f"Failed to add or update node: {str(e)}") + # Encapsulate single node in a list and call `add_nodes` + self.add_nodes([node]) - def add_nodes(self, nodes: List[Node]) -> None: + def add_nodes(self, nodes: List[Node], batch_size: int = 1000) -> None: """ - Add multiple nodes to the Neo4j database with specified label and properties. + Add multiple nodes to the Neo4j database in batches, supporting different labels. + Handles cases where vector support is not available. Args: nodes (List[Node]): A list of nodes to add. + batch_size (int): The size of each batch. Defaults to 1000. Raises: ValueError: If there is an issue with the query execution. 
""" - query = """ - UNWIND $nodes AS node - MERGE (n:{node.label} {id: node.id}) - ON CREATE SET n.createdAt = $current_time - SET n.updatedAt = $current_time, n += apoc.map.clean(node.properties, [], []) - WITH n, node.additional_labels AS additional_labels, node.embedding AS embedding - CALL apoc.create.addLabels(n, additional_labels) - YIELD node - WITH node, embedding - WHERE embedding IS NOT NULL - CALL db.create.setNodeVectorProperty(node, 'embedding', embedding) - YIELD node - RETURN node - """ - - params = { - 'nodes': [n.model_dump() for n in nodes], - 'current_time': get_current_time() - } - try: - with self.client.driver.session(database=self.client.database) as session: - session.run(query, params) - logger.info("Nodes added or updated successfully") - except Neo4jError as e: - logger.error("Failed to add or update nodes: %s", str(e)) - raise ValueError(f"Failed to add or update nodes: {str(e)}") + + # Group nodes by their labels + nodes_by_label = defaultdict(list) + for node in nodes: + nodes_by_label[node.label].append(node) + + for label, grouped_nodes in nodes_by_label.items(): + query = f""" + UNWIND $data AS node + MERGE (n:`{label}` {{id: node.id}}) + ON CREATE SET n.createdAt = node.current_time + SET n.updatedAt = node.current_time, n += apoc.map.clean(node.properties, [], []) + WITH n, node.additional_labels AS additional_labels, node.embedding AS embedding + CALL apoc.create.addLabels(n, additional_labels) + YIELD node AS labeled_node + WITH labeled_node AS n, embedding + CALL apoc.do.when( + embedding IS NOT NULL, + 'CALL db.create.setNodeVectorProperty(n, "embedding", $embedding) YIELD node RETURN node', + 'RETURN n', + {{n: n, embedding: embedding}} + ) YIELD value AS final_node + RETURN final_node + """ + + # Prepare data for batch processing + current_time = get_current_time() + data = [ + { + **n.model_dump(), + "current_time": current_time + } + for n in grouped_nodes + ] + + # Execute in batches for the current label + try: + 
self.batch_execute(query, data, batch_size) + logger.info(f"Nodes with label `{label}` added successfully.") + except ValueError as e: + logger.error(f"Failed to add nodes with label `{label}`: {str(e)}") + raise def add_relationship(self, relationship: Relationship) -> None: """ - Create a relationship between two nodes in the Neo4j database. + Create a single relationship between two nodes in the Neo4j database. Args: relationship (Relationship): The relationship to create. @@ -122,65 +139,70 @@ def add_relationship(self, relationship: Relationship) -> None: Raises: ValueError: If there is an issue with the query execution. """ - query = f""" - MATCH (a {{id: $source_node_id}}), (b {{id: $target_node_id}}) - MERGE (a)-[r:{relationship.type}]->(b) - ON CREATE SET r.createdAt = $current_time - SET r.updatedAt = $current_time, r += $properties - RETURN r - """ - params = { - 'source_node_id': relationship.source_node_id, - 'target_node_id': relationship.target_node_id, - 'properties': relationship.properties or {}, - 'current_time': get_current_time() - } - try: - with self.client.driver.session(database=self.client.database) as session: - session.run(query, params) - logger.info("Relationship with label %s between %s and %s created or updated successfully", relationship.type, relationship.source_node_id, relationship.target_node_id) - except Neo4jError as e: - logger.error("Failed to create or update relationship: %s", str(e)) - raise ValueError(f"Failed to create or update relationship: {str(e)}") - - def add_relationships(self, relationships: List[Relationship]) -> None: + # Encapsulate the single relationship in a list and delegate + self.add_relationships([relationship]) + + + def add_relationships(self, relationships: List[Relationship], batch_size: int = 1000) -> None: """ - Create multiple relationships between nodes in the Neo4j database. + Create multiple relationships between nodes in the Neo4j database in batches. 
Args: relationships (List[Relationship]): A list of relationships to create. + batch_size (int): The size of each batch. Defaults to 1000. Raises: ValueError: If there is an issue with the query execution. """ - query = """ - UNWIND $relationships AS rel - MATCH (a {id: rel.source_node_id}), (b {id: rel.target_node_id}) - MERGE (a)-[r:{rel.label}]->(b) - ON CREATE SET r.createdAt = $current_time - SET r.updatedAt = $current_time, r += rel.properties - RETURN r - """ - params = { - 'relationships': [r.model_dump() for r in relationships], - 'current_time': get_current_time() - } - try: - with self.client.driver.session(database=self.client.database) as session: - session.run(query, params) - logger.info("Relationships created or updated successfully") - except Neo4jError as e: - logger.error("Failed to create or update relationships: %s", str(e)) - raise ValueError(f"Failed to create or update relationships: {str(e)}") - - def query(self, query: str, params: Dict[str, Any] = None, sanitize: bool = None) -> List[Dict[str, Any]]: + # Group relationships by their types + relationships_by_type = defaultdict(list) + for relationship in relationships: + relationships_by_type[relationship.type].append(relationship) + + # Process each relationship type separately + for rel_type, rel_group in relationships_by_type.items(): + query = f""" + UNWIND $data AS rel + MATCH (a {{id: rel.source_node_id}}), (b {{id: rel.target_node_id}}) + MERGE (a)-[r:`{rel_type}`]->(b) + ON CREATE SET r.createdAt = rel.current_time + SET r.updatedAt = rel.current_time, r += rel.properties + RETURN r + """ + + # Prepare data for batch processing + current_time = get_current_time() + data = [ + { + **rel.model_dump(), + "current_time": current_time + } + for rel in rel_group + ] + + # Execute in batches for the current relationship type + try: + self.batch_execute(query, data, batch_size) + logger.info(f"Relationships of type `{rel_type}` added successfully.") + except ValueError as e: + 
logger.error(f"Failed to add relationships of type `{rel_type}`: {str(e)}") + raise + + def query( + self, + query: str, + params: Optional[Dict[str, Any]] = None, + sanitize: Optional[bool] = None, + pagination_limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: """ - Execute a Cypher query against a Neo4j database and optionally sanitize the results. + Execute a Cypher query against the Neo4j database and optionally sanitize or paginate the results. Args: query (str): The Cypher query to execute. - params (Dict[str, Any]): Optional dictionary of parameters for the query. - sanitize (bool): Whether to sanitize the results. + params (Dict[str, Any], optional): Parameters for the query. Defaults to None. + sanitize (bool, optional): Whether to sanitize the results. Defaults to class-level setting. + pagination_limit (int, optional): Limit the number of results for pagination. Defaults to None. Returns: List[Dict[str, Any]]: A list of dictionaries representing the query results. @@ -188,26 +210,43 @@ def query(self, query: str, params: Dict[str, Any] = None, sanitize: bool = None Raises: ValueError: If there is a syntax error in the Cypher query. Neo4jError: If any other Neo4j-related error occurs. 
- - References: - `https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/graphs/neo4j_graph.py`_ """ + from neo4j import Query + from neo4j.exceptions import Neo4jError, CypherSyntaxError + import time + params = params or {} sanitize = sanitize if sanitize is not None else self.sanitize + start_time = time.time() + try: with self.client.driver.session(database=self.client.database) as session: + # Add pagination support if a limit is provided + if pagination_limit: + query = f"{query} LIMIT {pagination_limit}" + result = session.run(Query(text=query, timeout=self.timeout), parameters=params) json_data = [record.data() for record in result] + + # Optional sanitization of results if sanitize: json_data = [value_sanitize(el) for el in json_data] - logger.info("Query executed successfully: %s", query) + + execution_time = time.time() - start_time + logger.info( + "Query executed successfully: %s | Time: %.2f seconds | Results: %d", + query, + execution_time, + len(json_data), + ) return json_data + except CypherSyntaxError as e: - logger.error("Syntax error in the Cypher query: %s", str(e)) - raise ValueError(f"Syntax error in the Cypher query: {str(e)}") + logger.error("Syntax error in Cypher query: %s | Query: %s", str(e), query) + raise ValueError(f"Syntax error in Cypher query: {str(e)}") except Neo4jError as e: - logger.error("An error occurred during the query execution: %s", str(e)) - raise ValueError(f"An error occurred during the query execution: {str(e)}") + logger.error("Neo4j error: %s | Query: %s", str(e), query) + raise ValueError(f"Neo4j error: {str(e)}") def reset(self): """ @@ -216,6 +255,8 @@ def reset(self): Raises: ValueError: If there is an issue with the query execution. 
""" + from neo4j.exceptions import Neo4jError + try: with self.client.driver.session() as session: session.run("CALL apoc.schema.assert({}, {})") @@ -232,51 +273,66 @@ def refresh_schema(self) -> None: Raises: ValueError: If there is an issue with the query execution. """ + from neo4j.exceptions import Neo4jError + try: - # Refresh node properties - node_properties = self.query( - """ - CALL apoc.meta.data() - YIELD label, property, type - WHERE type <> 'RELATIONSHIP' - RETURN label, collect({property: property, type: type}) AS properties - """ - ) - - # Refresh relationship properties - relationship_properties = self.query( - """ - CALL apoc.meta.data() - YIELD label, property, type - WHERE type = 'RELATIONSHIP' - RETURN label, collect({property: property, type: type}) AS properties - """ - ) - - # Refresh constraints + # Define queries as constants for reusability + NODE_PROPERTIES_QUERY = """ + CALL apoc.meta.data() + YIELD label, property, type + WHERE type <> 'RELATIONSHIP' + RETURN label, collect({property: property, type: type}) AS properties + """ + RELATIONSHIP_PROPERTIES_QUERY = """ + CALL apoc.meta.data() + YIELD label, property, type + WHERE type = 'RELATIONSHIP' + RETURN label, collect({property: property, type: type}) AS properties + """ + INDEXES_QUERY = """ + CALL apoc.schema.nodes() + YIELD label, properties, type, size, valuesSelectivity + WHERE type = 'RANGE' + RETURN *, size * valuesSelectivity AS distinctValues + """ + + # Execute queries + logger.debug("Refreshing node properties...") + node_properties = self.query(NODE_PROPERTIES_QUERY) + + logger.debug("Refreshing relationship properties...") + relationship_properties = self.query(RELATIONSHIP_PROPERTIES_QUERY) + + logger.debug("Refreshing constraints...") constraints = self.query("SHOW CONSTRAINTS") - # Refresh indexes - indexes = self.query( - """ - CALL apoc.schema.nodes() - YIELD label, properties, type, size, valuesSelectivity - WHERE type = 'RANGE' - RETURN *, size * valuesSelectivity 
as distinctValues - """ - ) + logger.debug("Refreshing indexes...") + indexes = self.query(INDEXES_QUERY) + # Transform query results into schema dictionary self.graph_schema = { - "node_props": {record['label']: record['properties'] for record in node_properties}, - "rel_props": {record['label']: record['properties'] for record in relationship_properties}, - "constraints": constraints, - "indexes": indexes, + "node_props": { + record.get("label"): record.get("properties", []) + for record in (node_properties or []) + }, + "rel_props": { + record.get("label"): record.get("properties", []) + for record in (relationship_properties or []) + }, + "constraints": constraints or [], + "indexes": indexes or [], } + logger.info("Schema refreshed successfully") + except Neo4jError as e: logger.error("Failed to refresh schema: %s", str(e)) raise ValueError(f"Failed to refresh schema: {str(e)}") + except Exception as e: + logger.error("Unexpected error while refreshing schema: %s", str(e)) + raise ValueError(f"Unexpected error while refreshing schema: {str(e)}") + def get_schema(self, refresh: bool = False) -> Dict[str, Any]: """ Get the schema of the Neo4jGraph store. @@ -290,22 +346,65 @@ def get_schema(self, refresh: bool = False) -> Dict[str, Any]: if not self.graph_schema or refresh: self.refresh_schema() return self.graph_schema - - def create_vector_index(self, label: str, property: str, dimensions: int, similarity_function: str = 'cosine'): + + def validate_schema(self, expected_schema: BaseModel) -> bool: + """ + Validate the current graph schema against an expected Pydantic schema model. + + Args: + expected_schema (Type[BaseModel]): The Pydantic schema model to validate against. + + Returns: + bool: True if schema matches, False otherwise. 
+ """ + # Retrieve the current schema + current_schema = self.get_schema() + + try: + # Attempt to initialize the expected schema with the current schema + validated_schema = expected_schema(**current_schema) + logger.info("Schema validation passed: %s", validated_schema) + return True + except ValidationError as e: + # Handle and log validation errors + logger.error("Schema validation failed due to validation errors:") + for error in e.errors(): + logger.error(f"Field: {error['loc']}, Error: {error['msg']}") + return False + + def create_vector_index( + self, + label: str, + property: str, + dimensions: int, + similarity_function: Literal['cosine', 'dot', 'euclidean'] = 'cosine' + ) -> None: """ Creates a vector index for a specified label and property in the Neo4j database. Args: - label (str): The label of the nodes to index. - property (str): The property of the nodes to index. + label (str): The label of the nodes to index (non-empty). + property (str): The property of the nodes to index (non-empty). dimensions (int): The number of dimensions of the vector. - similarity_function (str): The similarity function to use (default is 'cosine'). + similarity_function (Literal): The similarity function to use ('cosine', 'dot', 'euclidean'). Raises: - ValueError: If there is an issue with the query execution. + ValueError: If there is an issue with the query execution or invalid arguments. 
""" + from neo4j.exceptions import Neo4jError + + # Ensure label and property are non-empty strings + if not all([label, property]): + raise ValueError("Both `label` and `property` must be non-empty strings.") + + # Ensure dimensions is valid + if not isinstance(dimensions, int) or dimensions <= 0: + raise ValueError("`dimensions` must be a positive integer.") + + # Construct index name and query + index_name = f"{label.lower()}_{property}_vector_index" query = f""" - CREATE VECTOR INDEX {label.lower()}_embedding_index IF NOT EXISTS + CREATE VECTOR INDEX {index_name} IF NOT EXISTS FOR (n:{label}) ON (n.{property}) OPTIONS {{ @@ -315,10 +414,27 @@ def create_vector_index(self, label: str, property: str, dimensions: int, simila }} }} """ + try: - with self.client.driver.session() as session: + with self.client.driver.session(database=self.database) as session: session.run(query) - logger.info("Vector index for %s on property %s created successfully", label, property) + logger.info( + "Vector index `%s` for label `%s` on property `%s` created successfully.", + index_name, label, property + ) + + # Optionally update graph schema + if "indexes" in self.graph_schema: + self.graph_schema["indexes"].append({ + "label": label, + "property": property, + "dimensions": dimensions, + "similarity_function": similarity_function + }) + except Neo4jError as e: logger.error("Failed to create vector index: %s", str(e)) - raise ValueError(f"Failed to create vector index: {str(e)}") \ No newline at end of file + raise ValueError(f"Failed to create vector index: {str(e)}") + except Exception as e: + logger.error("Unexpected error during vector index creation: %s", str(e)) + raise ValueError(f"Unexpected error: {str(e)}") \ No newline at end of file diff --git a/src/floki/storage/graphstores/neo4j/client.py b/src/floki/storage/graphstores/neo4j/client.py index b584c3d..6d5c621 100644 --- a/src/floki/storage/graphstores/neo4j/client.py +++ b/src/floki/storage/graphstores/neo4j/client.py 
@@ -1,46 +1,67 @@ -from neo4j import GraphDatabase -from typing import Optional +from pydantic import BaseModel, Field +from typing import Optional, Any import os import logging logger = logging.getLogger(__name__) -class Neo4jClient: - def __init__(self, uri: Optional[str] = None, user: Optional[str] = None, password: Optional[str] = None, database: Optional[str] = 'neo4j'): - """ - Initializes the Neo4j client with the given connection parameters. +class Neo4jClient(BaseModel): + """ + Client for interacting with a Neo4j database. + Handles connection initialization, closing, and basic testing of connectivity. + """ + + uri: str = Field(default=None, description="The URI of the Neo4j database. Defaults to the 'NEO4J_URI' environment variable.") + user: str = Field(default=None, description="The username for Neo4j authentication. Defaults to the 'NEO4J_USERNAME' environment variable.") + password: str = Field(default=None, description="The password for Neo4j authentication. Defaults to the 'NEO4J_PASSWORD' environment variable.") + database: str = Field(default="neo4j", description="The default database to use. Defaults to 'neo4j'.") + driver: Optional[Any] = Field(default=None, init=False, description="The Neo4j driver instance for database operations. Initialized in 'model_post_init'.") - Args: - uri (Optional[str]): The URI of the Neo4j database. - user (Optional[str]): The username for authentication. - password (Optional[str]): The password for authentication. - database (Optional[str]): The database to use. Defaults to 'neo4j'. + def model_post_init(self, __context: Any) -> None: """ - self.uri = os.getenv('NEO4J_URI', uri) - self.user = os.getenv('NEO4J_USERNAME', user) - self.password = os.getenv('NEO4J_PASSWORD', password) - self.database = database - self.driver = None + Post-initialization logic to handle dynamic imports and environment variable defaults. 
+ """ + try: + from neo4j import GraphDatabase + except ImportError as e: + raise ImportError( + "The 'neo4j' package is required but not installed. Install it with 'pip install neo4j'." + ) from e + + # Handle environment variable defaults + self.uri = self.uri or os.getenv("NEO4J_URI") + self.user = self.user or os.getenv("NEO4J_USERNAME") + self.password = self.password or os.getenv("NEO4J_PASSWORD") + + if not all([self.uri, self.user, self.password]): + raise ValueError("Missing required connection parameters (uri, user, password). Set them as environment variables or pass explicitly.") + # Initialize the Neo4j driver try: self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) logger.info("Successfully created the driver for URI: %s", self.uri) except Exception as e: logger.error("Failed to create the driver: %s", str(e)) - exit(1) - - def close(self): - """Closes the Neo4j driver connection.""" + raise ValueError(f"Failed to initialize the Neo4j driver: {str(e)}") + + # Complete post-initialization + super().model_post_init(__context) + + def close(self) -> None: + """ + Closes the Neo4j driver connection. + """ if self.driver is not None: self.driver.close() logger.info("Neo4j driver connection closed") - - def test_connection(self): - """Tests the connection to the Neo4j database. + + def test_connection(self) -> bool: + """ + Tests the connection to the Neo4j database. Returns: bool: True if the connection is successful, False otherwise. - + Raises: ValueError: If there is an error testing the connection. 
""" @@ -49,7 +70,10 @@ def test_connection(self): result = session.run("CALL dbms.components() YIELD name, versions, edition") record = result.single() if record: - logger.info("Connected to %s version %s (%s edition)", record['name'], record['versions'][0], record['edition']) + logger.info( + "Connected to %s version %s (%s edition)", + record["name"], record["versions"][0], record["edition"] + ) return True else: logger.warning("No record found during the connection test") diff --git a/src/floki/storage/graphstores/neo4j/utils.py b/src/floki/storage/graphstores/neo4j/utils.py index d0496cd..27486e1 100644 --- a/src/floki/storage/graphstores/neo4j/utils.py +++ b/src/floki/storage/graphstores/neo4j/utils.py @@ -1,29 +1,66 @@ from typing import Any import datetime +import logging LIST_LIMIT = 100 # Maximum number of elements in a list to be processed +logger = logging.getLogger(__name__) + def value_sanitize(data: Any) -> Any: """ - Sanitizes the input data (dictionary or list) for use in a language model context. - This function filters out large lists and simplifies nested structures to improve the - efficiency of language model processing. + Sanitizes the input data (dictionary or list) for use in a language model or database context. + This function filters out large lists, simplifies nested structures, and ensures Neo4j-specific + data types are handled efficiently. Args: data (Any): The data to sanitize, which can be a dictionary, list, or other types. Returns: - Any: The sanitized data, or None if a list exceeds the predefined size limit. + Any: The sanitized data. Returns `None` for lists exceeding the size limit or unsupported types. """ if isinstance(data, dict): - return {key: sanitized for key, value in data.items() if (sanitized := value_sanitize(value)) is not None} + # Sanitize each key-value pair in the dictionary. + sanitized_dict = {} + for key, value in data.items(): + # Preserve essential metadata keys starting with "_" (e.g., Neo4j system keys). 
+ if key.startswith("_"): + sanitized_dict[key] = value + continue + + # Recursively sanitize the value. + sanitized_value = value_sanitize(value) + if sanitized_value is not None: + sanitized_dict[key] = sanitized_value + + return sanitized_dict + elif isinstance(data, list): + # Truncate or sample large lists to avoid exceeding size limits. if len(data) > LIST_LIMIT: - return None - return [sanitized for item in data if (sanitized := value_sanitize(item)) is not None] - else: + return data[:LIST_LIMIT] # Return the first `LIST_LIMIT` elements instead of discarding the list. + + # Recursively sanitize each element in the list. + sanitized_list = [ + sanitized_item for item in data if (sanitized_item := value_sanitize(item)) is not None + ] + return sanitized_list + + elif isinstance(data, tuple): + # Sanitize tuples (e.g., Neo4j relationships) + return tuple(value_sanitize(item) for item in data) + + elif isinstance(data, datetime.datetime): + # Convert datetime objects to ISO 8601 string for consistency. + return data.isoformat() + + elif isinstance(data, (int, float, bool, str)): + # Primitive types are returned as-is. return data + else: + logger.warning(f"Unsupported data type encountered: {type(data)}. Value: {repr(data)}") + return None # Exclude the data entirely. + def get_current_time(): """Get current time in UTC for creation and modification of nodes and relationships""" return datetime.datetime.now(datetime.timezone.utc).isoformat().replace('+00:00', 'Z') \ No newline at end of file