Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/powermem/prompts/graph/graph_tools_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@
RELATIONS_STRUCT_TOOL = {
"type": "function",
"function": {
"name": "establish_relations",
"name": "establish_relationships",
"description": "Establish relationships among the entities based on the provided text.",
"strict": True,
"parameters": {
Expand Down Expand Up @@ -466,4 +466,4 @@ def get_relations_tool(self, structured: bool = False) -> Dict[str, Any]:

def get_extract_entities_tool(self, structured: bool = False) -> Dict[str, Any]:
"""Get extract entities tool."""
return EXTRACT_ENTITIES_STRUCT_TOOL if structured else EXTRACT_ENTITIES_TOOL
return EXTRACT_ENTITIES_STRUCT_TOOL if structured else EXTRACT_ENTITIES_TOOL
41 changes: 33 additions & 8 deletions src/powermem/storage/oceanbase/oceanbase_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,10 @@ def search(self, query: str, filters: Dict[str, Any], limit: int = 100) -> List[
if idx < len(search_output):
item = search_output[idx]
search_results.append({
"source": item["source"],
"relationship": item["relationship"],
"destination": item["destination"]
"source": item["source"],
"relationship": item["relationship"],
"destination": item["destination"],
"score": float(scores[idx]),
})

logger.info("Returned %d search results (from %d candidates)", len(search_results), len(search_output))
Expand Down Expand Up @@ -652,9 +653,16 @@ def _retrieve_nodes_from_data(self, data: str, filters: Dict[str, Any]) -> Dict[
Returns:
Dictionary mapping entity names to entity types.
"""
_tools = [self.graph_tools_prompts.get_extract_entities_tool()]
if constants.is_structured_llm_provider(self.llm_provider):
_tools = [self.graph_tools_prompts.get_extract_entities_tool(structured=True)]
_tools = [
self.graph_tools_prompts.get_extract_entities_tool(structured=True),
self.graph_tools_prompts.get_noop_tool(structured=True),
]
else:
_tools = [
self.graph_tools_prompts.get_extract_entities_tool(),
self.graph_tools_prompts.get_noop_tool(),
]

search_results = self.llm.generate_response(
messages=[
Expand Down Expand Up @@ -736,9 +744,16 @@ def _establish_nodes_relations_from_data(
},
]

_tools = [self.graph_tools_prompts.get_relations_tool()]
if constants.is_structured_llm_provider(self.llm_provider):
_tools = [self.graph_tools_prompts.get_relations_tool(structured=True)]
_tools = [
self.graph_tools_prompts.get_relations_tool(structured=True),
self.graph_tools_prompts.get_noop_tool(structured=True),
]
else:
_tools = [
self.graph_tools_prompts.get_relations_tool(),
self.graph_tools_prompts.get_noop_tool(),
]

extracted_entities = self.llm.generate_response(
messages=messages,
Expand Down Expand Up @@ -781,6 +796,7 @@ def _search_graph_db(
List of dictionaries containing source, relationship, destination and their IDs.
"""
result_relations = []
seen_relations = set()

for node in node_list:
n_embedding = self.embedding_model.embed(node)
Expand All @@ -798,7 +814,16 @@ def _search_graph_db(

# Use multi-hop search with early stopping
multi_hop_results = self._multi_hop_search(entity_ids, filters, limit)
result_relations.extend(multi_hop_results)
for relation in multi_hop_results:
relation_key = (
relation.get("source"),
relation.get("relationship"),
relation.get("destination"),
)
if relation_key in seen_relations:
continue
seen_relations.add(relation_key)
result_relations.append(relation)

return result_relations

Expand Down
75 changes: 75 additions & 0 deletions tests/unit/test_oceanbase_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import MagicMock, patch, Mock
import pytest
import uuid
from powermem.prompts.graph.graph_tools_prompts import GraphToolsPrompts
from powermem.storage.oceanbase.oceanbase_graph import MemoryGraph
from powermem.storage.oceanbase import constants

Expand Down Expand Up @@ -186,6 +187,16 @@ def test_coerce_tool_response_to_dict(self):
result = self.memory_graph._coerce_tool_response_to_dict(response_invalid)
self.assertEqual(result, {})

def test_structured_relations_tool_name_matches_regular_tool(self):
"""Test structured relation extraction uses the expected tool name."""
prompts = GraphToolsPrompts()

regular_name = prompts.get_relations_tool()["function"]["name"]
structured_name = prompts.get_relations_tool(structured=True)["function"]["name"]

self.assertEqual(regular_name, "establish_relationships")
self.assertEqual(structured_name, regular_name)

def test_add_method(self):
"""Test the add method with mocked components."""
# Mock the necessary methods that add() calls
Expand Down Expand Up @@ -248,6 +259,7 @@ def test_search_method(self):
self.assertEqual(result[0]["source"], "alice")
self.assertEqual(result[0]["relationship"], "knows")
self.assertEqual(result[0]["destination"], "bob")
self.assertEqual(result[0]["score"], 0.8)

def test_search_method_empty_results(self):
"""Test the search method with empty results."""
Expand All @@ -261,6 +273,69 @@ def test_search_method_empty_results(self):
# Check the result
self.assertEqual(result, [])

def test_search_graph_db_deduplicates_relations_across_seed_nodes(self):
"""Test graph search removes duplicate triples from overlapping seeds."""
self.mock_embedding_model.embed.side_effect = [[0.1], [0.2]]
self.memory_graph._search_node = MagicMock(side_effect=[
[{"id": "entity1"}],
[{"id": "entity2"}],
])
duplicate_relation = {
"source": "alice",
"relationship": "knows",
"destination": "bob",
}
self.memory_graph._multi_hop_search = MagicMock(side_effect=[
[duplicate_relation],
[dict(duplicate_relation)],
])

result = self.memory_graph._search_graph_db(
node_list=["alice", "bob"],
filters=self.test_filters,
)

self.assertEqual(result, [duplicate_relation])

def test_retrieve_nodes_includes_noop_tool(self):
"""Test entity extraction provides noop as an opt-out tool."""
extract_tool = {"function": {"name": "extract_entities"}}
noop_tool = {"function": {"name": "noop"}}
self.mock_graph_tools_prompts.get_extract_entities_tool.return_value = extract_tool
self.mock_graph_tools_prompts.get_noop_tool.return_value = noop_tool
self.mock_llm.generate_response.return_value = {"tool_calls": []}

self.memory_graph._retrieve_nodes_from_data("No memory here", self.test_filters)

self.mock_graph_tools_prompts.get_extract_entities_tool.assert_called_once_with(structured=True)
self.mock_graph_tools_prompts.get_noop_tool.assert_called_once_with(structured=True)
self.assertEqual(
self.mock_llm.generate_response.call_args.kwargs["tools"],
[extract_tool, noop_tool],
)

def test_establish_relations_includes_noop_tool(self):
"""Test relation extraction provides noop as an opt-out tool."""
relations_tool = {"function": {"name": "establish_relationships"}}
noop_tool = {"function": {"name": "noop"}}
self.mock_graph_tools_prompts.get_relations_tool.return_value = relations_tool
self.mock_graph_tools_prompts.get_noop_tool.return_value = noop_tool
self.mock_llm.generate_response.return_value = {"tool_calls": []}

result = self.memory_graph._establish_nodes_relations_from_data(
"Alice knows Bob",
self.test_filters,
{"alice": "person", "bob": "person"},
)

self.assertEqual(result, [])
self.mock_graph_tools_prompts.get_relations_tool.assert_called_once_with(structured=True)
self.mock_graph_tools_prompts.get_noop_tool.assert_called_once_with(structured=True)
self.assertEqual(
self.mock_llm.generate_response.call_args.kwargs["tools"],
[relations_tool, noop_tool],
)

def test_get_all_method(self):
"""Test the get_all method."""
# Mock relationships results
Expand Down
Loading