9 changes: 9 additions & 0 deletions memgraph-toolbox/pyproject.toml
@@ -22,12 +22,18 @@ classifiers = [

dependencies = [
"deepeval>=3.5.2",
"fastmcp>=2.12.4",
"httpx>=0.27.0",
"litellm>=1.77.1",
"mcp>=1.14.1",
"neo4j>=5.28.1",
"numpy>=2.2.6",
"orjson>=3.11.3",
"pydantic>=2.11.10",
"rich>=13.0.0",
"sentence-transformers>=5.1.0",
"torch>=2.8.0",
"typer>=0.9.0",
]

[project.optional-dependencies]
@@ -36,6 +42,9 @@ test = [
"pytest-asyncio>=0.26.0",
]

[project.scripts]
mcp-chat = "memgraph_toolbox.client.mcp_full:app"
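# Installing the package exposes this entry point as the mcp-chat console command.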

[project.urls]
"Homepage" = "https://github.com/memgraph/ai-toolkit"
"Source" = "https://github.com/memgraph/ai-toolkit"
Empty file.
@@ -0,0 +1,21 @@
import asyncio
from fastmcp.client import Client
from fastmcp.client.elicitation import ElicitResult


async def handle_elicitation(message, response_type, *_):
print(f"Server asks: {message}")
    comment = input("Your comment: ")
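    # Wrap the reply in the schema type the server asked for and accept.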
return ElicitResult(action="accept", content=response_type(**{"comment": comment}))


async def main():
async with Client(
"http://localhost:8000/mcp", elicitation_handler=handle_elicitation
) as c:
result = await c.call_tool("ask_feedback", {})
print("Tool result:", result.data)


if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,22 @@
from fastmcp import FastMCP, Context
from pydantic import BaseModel

mcp = FastMCP("demo-server")


class Feedback(BaseModel):
comment: str


@mcp.tool
async def ask_feedback(ctx: Context) -> str:
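    # ctx.elicit() pauses the tool and round-trips the question to the
    # connected client's elicitation handler; the reply is validated
    # against the Feedback model.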
result = await ctx.elicit("Please write a comment:", response_type=Feedback)
print(f"Result: {result}")
if result.action == "accept":
return f"You said: {result.data.comment}"
else:
return "You declined to provide feedback."


if __name__ == "__main__":
mcp.run(transport="streamable-http")
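
To try the pair, start this server first (FastMCP serves streamable HTTP at http://localhost:8000/mcp by default, which is the URL the client above expects), then run the elicitation client against it.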
186 changes: 186 additions & 0 deletions memgraph-toolbox/src/memgraph_toolbox/client/mcp_full.py
@@ -0,0 +1,186 @@
import asyncio
from fastmcp import Client
from fastmcp.exceptions import ToolError
from mcp.types import ElicitResult
from rich.console import Console

# https://github.com/prompt-toolkit/python-prompt-toolkit is also very interesting.
import typer
from litellm import acompletion
# Relative import so the mcp-chat console-script entry point can resolve it.
from .tools import atool_calls

LLM_MODEL = "openai/gpt-4o"
app = typer.Typer(help="MCP Terminal Client (Streaming HTTP)")
console = Console()


async def describe_mcp_server(client: Client):
tools = await client.list_tools()
prompts = await client.list_prompts()
tool_names = [tool.name for tool in tools]
console.print(f"[bold green]Available Tools:[/bold green] {tool_names}")
prompt_names = [prompt.name for prompt in prompts]
console.print(f"[bold green]Available Prompts:[/bold green] {prompt_names}")
console.print(f"Getting all the prompts...")
for prompt_name in prompt_names:
result = await client.get_prompt(prompt_name)
console.print(f" Prompt: [yellow]{prompt_name}[/yellow]")
for message in result.messages:
console.print(f" Role: [blue]{message.role}[/blue]")
console.print(f" Content Text:")
console.print(f" [blue]{message.content.text}[/blue]")
return tools, prompts


async def my_elicitation_handler(message: str, response_type: type, params, context):
"""
User adds additonal feedback to the failed query.
Actions are: accept, decline, cancel.
"""
console.print(f"[bold yellow]Server asks:[/bold yellow] {message}")
console.print(f"[dim]Response type: {response_type}[/dim]")
# Use an editor to get user input
content = typer.edit(message)
# Decline
if not content or content.strip() == "":
console.print("[red]Declining elicitation request[/red]")
return ElicitResult(action="decline")
# Accept
console.print(f"[green]Sending response:[/green] {content}")
# NOTE: None of the below pass the pydantic validation for some reason.
# return response_type(**{"data": content})
# return ElicitResult(action="accept", content=response_type(**{"data": content}))
# return ElicitResult(action="accept", content={"data": content})
return {"data": content}


async def my_sampling_handler(message, params, context) -> str:
"""
Iterate in a loop and try to generate valid responses.
"""
return "Generated sample response"


def convert_mcp_tools_to_litellm(mcp_tools):
"""Convert MCP tools to litellm format"""
litellm_tools = []
for tool in mcp_tools:
litellm_tool = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or f"Tool: {tool.name}",
"parameters": tool.inputSchema or {},
},
}
litellm_tools.append(litellm_tool)
return litellm_tools
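
# Illustrative shape of one converted entry (tool names and schemas come from
# the connected MCP server; run_query is just an example):
# {
#     "type": "function",
#     "function": {
#         "name": "run_query",
#         "description": "Run a Cypher query on the Memgraph database.",
#         "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
#     },
# }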


@app.command()
def chat(
base_url: str = typer.Option(
"http://localhost:8000/mcp", help="MCP server base URL"
),
interactivity_type: str = typer.Option(
"llm-use", # TODO: There should be one-shot, iterate, etc...
help="Interactivity type (llm-use, no-llm-use)",
case_sensitive=False,
metavar="INTERACTIVITY_TYPE",
),
):
config = {
"mcpServers": {
"memgraph": {"url": base_url},
}
}
client = Client(
config,
elicitation_handler=my_elicitation_handler,
sampling_handler=my_sampling_handler,
)
print(f"Interactivity type: {interactivity_type}")

async def _run():
async with client:
tools, prompts = await describe_mcp_server(client)
litellm_tools = convert_mcp_tools_to_litellm(tools)
# TODO(gitbuda): Fill this in from the templates.
chat_messages = [
{
"role": "user",
"content": """
You are an expert in Cypher and Memgraph graph databases.
You have access to the following tools:
- `run_query`: Run a Cypher query on the Memgraph database.
Input is a Cypher query string. Output is a list of records or an error message.
- `get_schema`: Get the schema information of the Memgraph database. Input is empty.
You should first call `get_schema` to understand the database structure,
then use `run_query` to execute Cypher queries.
IMPORTANT: Do not call run_query if you don't have the schema information.
""",
}
]

while True:
user_input = input("\n> ").strip()
if user_input.lower() in ("exit", "quit"):
break
# TODO(gitbuda): Expose prompts user can pick from.

if interactivity_type == "llm-use":
console.print(f"\n[bold yellow]Prompt:[/bold yellow] {user_input}")
# Convert MCP tools to litellm format
chat_messages.append({"role": "user", "content": user_input})
pick_tool_resp = await acompletion(
LLM_MODEL, chat_messages, tools=litellm_tools
)
# print(pick_tool_resp)
# TODO(gitbuda): Here should be a prompt to allow tool call.
tool_call_msg = await atool_calls(
pick_tool_resp, chat_messages, client
)
print(tool_call_msg)
chat_messages.append(
{
"role": "assistant",
"content": """Based on the given data, generate the run_query call.""",
}
)
the_query_resp = await acompletion(
LLM_MODEL, tool_call_msg, tools=litellm_tools
)
print(the_query_resp)
tool_call_msg = await atool_calls(
the_query_resp, chat_messages, client
)
mg_data = tool_call_msg[-1]["content"]
console.print(
f"\n[bold yellow]Query Result:[/bold yellow] {mg_data}"
)

if interactivity_type == "no-llm-use":
query_result = await client.call_tool(
"run_query", {"query": user_input}
)
console.print(
f"\n[bold yellow]Query Result:[/bold yellow] {query_result.data}"
)
# Print any additional content if available
for content in query_result.content:
if hasattr(content, "text"):
console.print(
f"[bold blue]Content:[/bold blue] {content.text}"
)
elif hasattr(content, "data"):
console.print(
f"[bold blue]Data:[/bold blue] {content.data}"
)

                # TODO(gitbuda): If something fails or returns empty, don't take user input -> rerun the whole thing again (make sure the full message history is preserved).

asyncio.run(_run())


if __name__ == "__main__":
app()
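
# Usage sketch (assuming the package is installed and OPENAI_API_KEY is set
# for litellm; with a single registered command, Typer runs it directly):
#   mcp-chat --base-url http://localhost:8000/mcp --interactivity-type llm-use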
@@ -68,6 +68,7 @@ async def prompt_with_tools(
messages=messages,
tools=tools,
)
# TODO(gitbuda): Replace the code below with tools.atool_calls
msg = resp["choices"][0]["message"]
# Tool calls by the MCP server.
tool_calls = msg.get("tool_calls", [])
80 changes: 80 additions & 0 deletions memgraph-toolbox/src/memgraph_toolbox/client/tools.py
@@ -0,0 +1,80 @@
import json
import logging

logger = logging.getLogger(__name__)


async def atool_calls(resp, messages, session):
"""
Execute tool calls by the MCP server.
"""
msg = resp["choices"][0]["message"]
# Tool calls by the MCP server.
tool_calls = msg.get("tool_calls", [])
if not tool_calls:
logger.info("LLM Response: %s", msg.get("content", ""))
else:
logger.info("LLM wants to use %d tools:", len(tool_calls))
for tc in tool_calls:
logger.info(
"- %s: %s",
tc["function"]["name"],
tc["function"].get("description", "No description"),
)
# Ensure all tool call IDs are within OpenAI's 40-character limit
for tc in tool_calls:
if len(tc["id"]) > 40:
logger.warning(
"Tool call ID is too long, truncating to 40 characters: %s",
tc["id"],
)
tc["id"] = tc["id"][:40]
messages.append(
{
"role": "assistant",
"content": msg.get("content", ""),
"tool_calls": tool_calls,
}
)

for tc in tool_calls:
try:
arguments = tc["function"].get("arguments", {})
logger.info("Arguments: %s", arguments)
if isinstance(arguments, str):
arguments = json.loads(arguments)
result = await session.call_tool(
name=tc["function"]["name"], arguments=arguments
)
if hasattr(result, "content") and result.content:
if isinstance(result.content, list):
content_text = []
for content_item in result.content:
if hasattr(content_item, "text"):
content_text.append(content_item.text)
elif isinstance(content_item, str):
content_text.append(content_item)
else:
content_text.append(str(content_item))
content_str = "\n".join(content_text)
elif hasattr(result.content, "text"):
content_str = result.content.text
else:
content_str = str(result.content)
else:
content_str = "No content returned"
logger.debug(
"Tool %s result: %s",
tc["function"]["name"],
content_str,
)
messages.append(
{
"role": "tool",
"tool_call_id": tc["id"],
"name": tc["function"]["name"],
"content": content_str,
}
)
except Exception as e:
logger.error(str(e))
return messages
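
For context, a minimal sketch of how atool_calls slots into a completion loop (a hypothetical caller, assuming a connected fastmcp Client and tools already converted to the litellm function format):

    from fastmcp import Client
    from litellm import acompletion
    from memgraph_toolbox.client.tools import atool_calls

    async def one_round(client: Client, messages: list, litellm_tools: list) -> list:
        # Ask the LLM; it may answer with text or request tool calls.
        resp = await acompletion("openai/gpt-4o", messages, tools=litellm_tools)
        # Execute any requested tool calls on the MCP server and append the
        # assistant/tool messages to the history for the next completion.
        return await atool_calls(resp, messages, client)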