diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index 25aa12c3f..da88b23f5 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -17,6 +17,7 @@ import asyncio import datetime import logging +import os from collections.abc import Coroutine from typing import Any @@ -26,6 +27,8 @@ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from pydantic import SecretStr + logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 @@ -109,7 +112,7 @@ class NeptuneDriver(GraphDriver): provider: GraphProvider = GraphProvider.NEPTUNE - def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443): + def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, use_https: bool = True): """This initializes a NeptuneDriver for use with Neptune as a backend Args: @@ -124,7 +127,15 @@ def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = if host.startswith('neptune-db://'): # This is a Neptune Database Cluster endpoint = host.replace('neptune-db://', '') - self.client = NeptuneGraph(endpoint, port) + aws_access_key_id = os.getenv('NEPTUNE_AWS_ACCESS_KEY_ID') + aws_secret_access_key = os.getenv('NEPTUNE_AWS_SECRET_ACCESS_KEY') + aws_region_name = os.getenv('NEPTUNE_AWS_REGION') + if aws_access_key_id and aws_secret_access_key and aws_region_name: + self.client = NeptuneGraph(endpoint, port, aws_access_key_id=SecretStr(aws_access_key_id), + aws_secret_access_key=SecretStr(aws_secret_access_key), + region_name=aws_region_name, use_https=use_https) + else: + self.client = NeptuneGraph(endpoint, port, use_https=use_https) logger.debug('Creating Neptune Database session for %s', host) elif host.startswith('neptune-graph://'): # This is a Neptune Analytics Graph @@ -143,9 +154,9 @@ def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = self.aoss_client = OpenSearch( 
hosts=[{'host': aoss_host, 'port': aoss_port}], http_auth=Urllib3AWSV4SignerAuth( - session.get_credentials(), session.region_name, 'aoss' + session.get_credentials(), os.getenv('NEPTUNE_AWS_REGION') or session.region_name, 'aoss' ), - use_ssl=True, + use_ssl=use_https, verify_certs=True, connection_class=Urllib3HttpConnection, pool_maxsize=20, @@ -191,7 +202,7 @@ def _sanitize_parameters(self, query, params: dict): async def execute_query( self, cypher_query_, **kwargs: Any ) -> tuple[dict[str, Any], None, None]: - params = dict(kwargs) + params = dict(kwargs.get('params', kwargs)) if isinstance(cypher_query_, list): for q in cypher_query_: result, _, _ = self._run_query(q[0], q[1]) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 379662d57..8d98db550 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -78,6 +78,46 @@ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> f return dot_product / (norm_vector1 * norm_vector2) +def _lucene_sanitize(query: str) -> str: + # Escape special characters from a query before passing into Lucene + # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ / + escape_map = str.maketrans( + { + '+': r'\+', + '-': r'\-', + '&': r'\&', + '|': r'\|', + '!': r'\!', + '(': r'\(', + ')': r'\)', + '{': r'\{', + '}': r'\}', + '[': r'\[', + ']': r'\]', + '^': r'\^', + '"': r'\"', + "'": r"\'", + '~': r'\~', + '*': r'\*', + '?': r'\?', + ':': r'\:', + '\\': r'\\', + '/': r'\/', + '@': r'\@', + '%': r'\%', + 'O': r'\O', + 'R': r'\R', + 'N': r'\N', + 'T': r'\T', + 'A': r'\A', + 'D': r'\D', + } + ) + + sanitized = query.translate(escape_map) + return sanitized + + def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver): if driver.provider == GraphProvider.KUZU: # Kuzu only supports simple queries. 
@@ -95,7 +135,7 @@ def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver) group_ids_filter += ' AND ' if group_ids_filter else '' - lucene_query = lucene_sanitize(query) + lucene_query = _lucene_sanitize(query) # If the lucene query is too long return no query if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH: return '' @@ -338,7 +378,7 @@ async def edge_similarity_search( # Calculate Cosine similarity then return the edge ids input_ids = [] for r in resp: - if r['embedding']: + if 'embedding' in r and r['embedding']: score = calculate_cosine_similarity( search_vector, list(map(float, r['embedding'].split(','))) ) @@ -668,7 +708,7 @@ async def node_similarity_search( # Calculate Cosine similarity then return the edge ids input_ids = [] for r in resp: - if r['embedding']: + if 'embedding' in r and r['embedding']: score = calculate_cosine_similarity( search_vector, list(map(float, r['embedding'].split(','))) ) @@ -1023,7 +1063,7 @@ async def community_similarity_search( # Calculate Cosine similarity then return the edge ids input_ids = [] for r in resp: - if r['embedding']: + if 'embedding' in r and r['embedding']: score = calculate_cosine_similarity( search_vector, list(map(float, r['embedding'].split(','))) ) @@ -1367,10 +1407,10 @@ async def get_relevant_edges( input_ids = [] for r in resp: score = calculate_cosine_similarity( - list(map(float, r['source_embedding'].split(','))), r['target_embedding'] + list(map(float, r.get('source_embedding', '0').split(','))), r.get('target_embedding', 0) ) if score > min_score: - input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']}) + input_ids.append({'id': r.get('id'), 'score': score, 'uuid': r.get('search_edge_uuid')}) # Match the edge ides and return the values query = """ @@ -1497,8 +1537,8 @@ async def get_relevant_edges( ) relevant_edges_dict: dict[str, list[EntityEdge]] = { - result['search_edge_uuid']: [ - get_entity_edge_from_record(record, driver.provider) 
for record in result['matches'] + result.get('search_edge_uuid', result.get('uuid', '')): [ + get_entity_edge_from_record(record, driver.provider) for record in result.get('matches', []) ] for result in results } @@ -1554,10 +1594,10 @@ async def get_edge_invalidation_candidates( input_ids = [] for r in resp: score = calculate_cosine_similarity( - list(map(float, r['source_embedding'].split(','))), r['target_embedding'] + list(map(float, r.get('source_embedding', '0').split(','))), r.get('target_embedding', 0) ) if score > min_score: - input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']}) + input_ids.append({'id': r.get('id'), 'score': score, 'uuid': r.get('search_edge_uuid')}) # Match the edge ides and return the values query = """ @@ -1684,8 +1724,8 @@ async def get_edge_invalidation_candidates( **filter_params, ) invalidation_edges_dict: dict[str, list[EntityEdge]] = { - result['search_edge_uuid']: [ - get_entity_edge_from_record(record, driver.provider) for record in result['matches'] + result.get('search_edge_uuid', result.get('uuid', '')): [ + get_entity_edge_from_record(record, driver.provider) for record in result.get('matches', []) ] for result in results } diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index f17706ab5..e448837fb 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -269,7 +269,10 @@ async def resolve_extracted_nodes( resolution_id: int = resolution.id duplicate_idx: int = resolution.duplicate_idx - extracted_node = extracted_nodes[resolution_id] + try: + extracted_node = extracted_nodes[resolution_id] + except IndexError: + continue resolved_node = ( existing_nodes[duplicate_idx]