21 changes: 16 additions & 5 deletions graphiti_core/driver/neptune_driver.py
@@ -17,6 +17,7 @@
import asyncio
import datetime
import logging
+import os
from collections.abc import Coroutine
from typing import Any

@@ -26,6 +27,8 @@

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider

+from pydantic import SecretStr

logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10

@@ -109,7 +112,7 @@
class NeptuneDriver(GraphDriver):
    provider: GraphProvider = GraphProvider.NEPTUNE

-    def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
+    def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, use_https: bool = True):
"""This initializes a NeptuneDriver for use with Neptune as a backend

Args:
@@ -124,7 +127,15 @@ def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, use_https: bool = True):
        if host.startswith('neptune-db://'):
            # This is a Neptune Database Cluster
            endpoint = host.replace('neptune-db://', '')
-            self.client = NeptuneGraph(endpoint, port)
+            aws_access_key_id = os.getenv('NEPTUNE_AWS_ACCESS_KEY_ID')
+            aws_secret_access_key = os.getenv('NEPTUNE_AWS_SECRET_ACCESS_KEY')
+            aws_region_name = os.getenv('NEPTUNE_AWS_REGION')
+            if aws_access_key_id and aws_secret_access_key and aws_region_name:
+                # Explicit credentials from the environment, wrapped in SecretStr
+                # so they are not echoed in logs or reprs.
+                self.client = NeptuneGraph(
+                    endpoint,
+                    port,
+                    aws_access_key_id=SecretStr(aws_access_key_id),
+                    aws_secret_access_key=SecretStr(aws_secret_access_key),
+                    region_name=aws_region_name,
+                    use_https=use_https,
+                )
+            else:
+                # No explicit credentials: fall back to the default boto3 chain.
+                self.client = NeptuneGraph(endpoint, port, use_https=use_https)
            logger.debug('Creating Neptune Database session for %s', host)
        elif host.startswith('neptune-graph://'):
            # This is a Neptune Analytics Graph
@@ -143,9 +154,9 @@ def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, use_https: bool = True):
        self.aoss_client = OpenSearch(
            hosts=[{'host': aoss_host, 'port': aoss_port}],
            http_auth=Urllib3AWSV4SignerAuth(
-                session.get_credentials(), session.region_name, 'aoss'
+                # NEPTUNE_AWS_REGION takes precedence; aws_region_name is only
+                # bound on the neptune-db path, so it cannot be relied on here.
+                session.get_credentials(), os.getenv('NEPTUNE_AWS_REGION') or session.region_name, 'es'
            ),
-            use_ssl=True,
+            use_ssl=use_https,
            verify_certs=True,
            connection_class=Urllib3HttpConnection,
            pool_maxsize=20,
@@ -191,7 +202,7 @@ def _sanitize_parameters(self, query, params: dict):
    async def execute_query(
        self, cypher_query_, **kwargs: Any
    ) -> tuple[dict[str, Any], None, None]:
-        params = dict(kwargs)
+        # Accept parameters either as an explicit params=... dict or as loose
+        # keyword arguments.
+        params = kwargs.get('params', kwargs)
        if isinstance(cypher_query_, list):
            for q in cypher_query_:
                result, _, _ = self._run_query(q[0], q[1])
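Taken together, the driver changes let credentials and region come from environment variables rather than only the default AWS credential chain, and let execute_query take its parameters either inline or under a params keyword. A minimal usage sketch, assuming NeptuneDriver is importable from graphiti_core.driver.neptune_driver as the file path suggests; the endpoint and credential values are placeholders, not real resources:

    import asyncio
    import os

    from graphiti_core.driver.neptune_driver import NeptuneDriver

    # Optional: explicit credentials. When any of these are unset, the driver
    # falls back to the default boto3 credential chain (profiles, roles, etc.).
    os.environ['NEPTUNE_AWS_ACCESS_KEY_ID'] = 'AKIA...'      # placeholder
    os.environ['NEPTUNE_AWS_SECRET_ACCESS_KEY'] = '...'      # placeholder
    os.environ['NEPTUNE_AWS_REGION'] = 'us-west-2'           # placeholder

    driver = NeptuneDriver(
        host='neptune-db://my-cluster.cluster-xyz.us-west-2.neptune.amazonaws.com',  # hypothetical
        aoss_host='my-domain.us-west-2.es.amazonaws.com',                            # hypothetical
        use_https=True,
    )

    # Parameters can now be passed as an explicit dict under params=...
    records, _, _ = asyncio.run(
        driver.execute_query('MATCH (n) RETURN n LIMIT $limit', params={'limit': 5})
    )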
64 changes: 52 additions & 12 deletions graphiti_core/search/search_utils.py
@@ -78,6 +78,46 @@ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
    return dot_product / (norm_vector1 * norm_vector2)


+def _lucene_sanitize(query: str) -> str:
+    # Escape special characters from a query before passing it to Lucene:
+    # + - && || ! ( ) { } [ ] ^ " ' ~ * ? : \ / @ %
+    # The uppercase letters O, R, N, T, A and D are also escaped so that the
+    # Lucene boolean operators AND, OR and NOT are treated as literal text.
+    escape_map = str.maketrans(
+        {
+            '+': r'\+',
+            '-': r'\-',
+            '&': r'\&',
+            '|': r'\|',
+            '!': r'\!',
+            '(': r'\(',
+            ')': r'\)',
+            '{': r'\{',
+            '}': r'\}',
+            '[': r'\[',
+            ']': r'\]',
+            '^': r'\^',
+            '"': r'\"',
+            "'": r"\'",
+            '~': r'\~',
+            '*': r'\*',
+            '?': r'\?',
+            ':': r'\:',
+            '\\': r'\\',
+            '/': r'\/',
+            '@': r'\@',
+            '%': r'\%',
+            'O': r'\O',
+            'R': r'\R',
+            'N': r'\N',
+            'T': r'\T',
+            'A': r'\A',
+            'D': r'\D',
+        }
+    )
+
+    sanitized = query.translate(escape_map)
+    return sanitized


def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
    if driver.provider == GraphProvider.KUZU:
        # Kuzu only supports simple queries.
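As a quick illustration of the new sanitizer, a sketch of inputs and outputs (these follow directly from the escape map above):

    print(_lucene_sanitize('cats AND dogs'))  # cats \A\N\D dogs
    print(_lucene_sanitize('who? me!'))       # who\? me\!
    print(_lucene_sanitize('50% off: now'))   # 50\% off\: now

Escaping the individual letters O, R, N, T, A and D is blunt, but it guarantees AND, OR and NOT can never be parsed as query syntax, since Lucene reads a backslash-escaped character as the literal character.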
@@ -95,7 +135,7 @@ def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):

    group_ids_filter += ' AND ' if group_ids_filter else ''

-    lucene_query = lucene_sanitize(query)
+    lucene_query = _lucene_sanitize(query)
    # If the lucene query is too long, return no query
    if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
        return ''
@@ -338,7 +378,7 @@ async def edge_similarity_search(
    # Calculate cosine similarity, then return the edge ids
    input_ids = []
    for r in resp:
-        if r['embedding']:
+        if 'embedding' in r and r['embedding']:
            score = calculate_cosine_similarity(
                search_vector, list(map(float, r['embedding'].split(',')))
            )
@@ -668,7 +708,7 @@ async def node_similarity_search(
    # Calculate cosine similarity, then return the node ids
    input_ids = []
    for r in resp:
-        if r['embedding']:
+        if 'embedding' in r and r['embedding']:
            score = calculate_cosine_similarity(
                search_vector, list(map(float, r['embedding'].split(',')))
            )
@@ -1023,7 +1063,7 @@ async def community_similarity_search(
    # Calculate cosine similarity, then return the community ids
    input_ids = []
    for r in resp:
-        if r['embedding']:
+        if 'embedding' in r and r['embedding']:
            score = calculate_cosine_similarity(
                search_vector, list(map(float, r['embedding'].split(',')))
            )
@@ -1367,10 +1407,10 @@ async def get_relevant_edges(
    input_ids = []
    for r in resp:
+        # Skip records missing either embedding; defaulting to a zero vector
+        # would give the cosine denominator a zero norm and divide by zero.
+        if not r.get('source_embedding') or not r.get('target_embedding'):
+            continue
        score = calculate_cosine_similarity(
            list(map(float, r['source_embedding'].split(','))), r['target_embedding']
        )
        if score > min_score:
-            input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
+            input_ids.append({'id': r.get('id'), 'score': score, 'uuid': r.get('search_edge_uuid')})

    # Match the edge ids and return the values
    query = """
@@ -1497,8 +1537,8 @@ async def get_relevant_edges(
    )

    relevant_edges_dict: dict[str, list[EntityEdge]] = {
-        result['search_edge_uuid']: [
-            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
+        result.get('search_edge_uuid', result.get('uuid', '')): [
+            get_entity_edge_from_record(record, driver.provider) for record in result.get('matches', [])
        ]
        for result in results
    }
@@ -1554,10 +1594,10 @@ async def get_edge_invalidation_candidates(
    input_ids = []
    for r in resp:
+        # Same defensive skip as in get_relevant_edges: records without both
+        # embeddings cannot be scored.
+        if not r.get('source_embedding') or not r.get('target_embedding'):
+            continue
        score = calculate_cosine_similarity(
            list(map(float, r['source_embedding'].split(','))), r['target_embedding']
        )
        if score > min_score:
-            input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
+            input_ids.append({'id': r.get('id'), 'score': score, 'uuid': r.get('search_edge_uuid')})

    # Match the edge ids and return the values
    query = """
@@ -1684,8 +1724,8 @@ async def get_edge_invalidation_candidates(
        **filter_params,
    )
    invalidation_edges_dict: dict[str, list[EntityEdge]] = {
-        result['search_edge_uuid']: [
-            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
+        result.get('search_edge_uuid', result.get('uuid', '')): [
+            get_entity_edge_from_record(record, driver.provider) for record in result.get('matches', [])
        ]
        for result in results
    }
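The search_utils changes share one pattern: a result record may lack an embedding, or carry an empty one, so every consumer now guards before parsing the comma-separated vector. A minimal sketch of that pattern in isolation; the record shapes are assumptions inferred from the code above, not the module's API:

    def parse_embedding(record: dict, key: str = 'embedding') -> list[float] | None:
        """Return the stored vector, or None when the record has no usable value.

        Embeddings come back from the store as comma-separated strings, e.g.
        '0.1,0.2,0.3', so missing and empty values must be filtered out before
        the split/float conversion.
        """
        raw = record.get(key)
        if not raw:
            return None
        return [float(x) for x in raw.split(',')]

    records = [
        {'embedding': '0.1,0.2,0.3'},
        {'embedding': ''},   # empty value: skipped
        {},                  # key absent: skipped
    ]
    vectors = [v for r in records if (v := parse_embedding(r)) is not None]
    assert vectors == [[0.1, 0.2, 0.3]]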
5 changes: 4 additions & 1 deletion graphiti_core/utils/maintenance/node_operations.py
@@ -269,7 +269,10 @@ async def resolve_extracted_nodes(
        resolution_id: int = resolution.id
        duplicate_idx: int = resolution.duplicate_idx

-        extracted_node = extracted_nodes[resolution_id]
+        try:
+            extracted_node = extracted_nodes[resolution_id]
+        except IndexError:
+            # resolution_id comes from the LLM and may be out of range; skip it
+            continue

        resolved_node = (
            existing_nodes[duplicate_idx]
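The try/except above tolerates out-of-range ids coming back from the dedupe step. An equivalent explicit bounds check, shown as a standalone sketch (the function name is illustrative, not part of the module):

    def pick_node(extracted_nodes: list, resolution_id: int):
        # Guard against ids the LLM invented; mirrors the IndexError handling
        # in resolve_extracted_nodes.
        if 0 <= resolution_id < len(extracted_nodes):
            return extracted_nodes[resolution_id]
        return None

Note that the range check is slightly stricter than catching IndexError: it also rejects negative ids, which plain indexing would silently accept as offsets from the end of the list.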