From ceb7deaf1e9662b003955259160def46da91134d Mon Sep 17 00:00:00 2001 From: Menendez6 Date: Fri, 11 Jul 2025 13:35:02 +0000 Subject: [PATCH 1/2] Support for chat context --- src/expert/agent.py | 13 ++++++++----- src/expert/config.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/expert/agent.py b/src/expert/agent.py index b1e463a..4db9fa9 100644 --- a/src/expert/agent.py +++ b/src/expert/agent.py @@ -6,7 +6,6 @@ from langchain_core.tools import StructuredTool from langchain_mcp_adapters.client import MultiServerMCPClient -from langgraph.checkpoint.memory import InMemorySaver from langgraph.prebuilt import create_react_agent from opik.integrations.langchain import OpikTracer @@ -29,7 +28,6 @@ def __init__(self, config: Config): """Initialize the agent.""" self.config = config self.agent = None - self.checkpointer = InMemorySaver() self._load_prompts() self.opik_config = None self.mcp_client = None @@ -95,11 +93,12 @@ async def setup(self): provider = llm_config.get("provider", "openai") model = llm_config.get("model", "gpt-4.1") model_name = f"{provider}:{model}" + checkpointer = self.config.checkpointer self.agent = create_react_agent( model_name, renamed_tools, - checkpointer=self.checkpointer, + checkpointer= checkpointer, prompt=self.system_prompt, version="v2", ) @@ -125,8 +124,12 @@ def _create_opik_tracer(self, user_thread_id: str) -> OpikTracer | None: graph=self.agent.get_graph(xray=True), project_name=project_name, tags=[user_thread_id] ) return tracer - + async def get_response(self, user_input: str, thread_id: str) -> str: + messages = [{"role": "user", "content": user_input}] + return await self.get_response_with_context(messages, thread_id) + + async def get_response_with_context(self, messages: list[dict], thread_id: str) -> str: """Get response from the agent for a given input.""" if not self.agent: raise RuntimeError("Agent not initialized. Call setup() first.") @@ -142,7 +145,7 @@ async def get_response(self, user_input: str, thread_id: str) -> str: thread_config["callbacks"] = [tracer] async for chunk in self.agent.astream( - {"messages": [{"role": "user", "content": user_input}]}, thread_config + {"messages": messages}, thread_config ): if "agent" in chunk: for message in chunk["agent"]["messages"]: diff --git a/src/expert/config.py b/src/expert/config.py index 6c0ee4d..815a3cc 100644 --- a/src/expert/config.py +++ b/src/expert/config.py @@ -3,6 +3,8 @@ import os from pathlib import Path from typing import Any +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.checkpoint.memory import InMemorySaver import yaml from dotenv import load_dotenv @@ -98,3 +100,15 @@ def mcp_config(self) -> dict[str, Any]: def chat_config(self) -> dict[str, Any]: """Get chat configuration.""" return self.config.get("chat", {}) + + @property + def checkpointer(self) -> BaseCheckpointSaver: + """Get checkpointer configuration.""" + checkpointer = self.config.get("checkpointer", None) + match checkpointer: + case "memory": + return InMemorySaver() + case None: + return None + case _: + raise ValueError(f"Invalid checkpointer configuration: {checkpointer}") From 3700b58493135eab32b0c566184064fe822c630a Mon Sep 17 00:00:00 2001 From: Menendez6 Date: Fri, 11 Jul 2025 14:22:36 +0000 Subject: [PATCH 2/2] format code --- src/expert/agent.py | 8 +++----- src/expert/config.py | 8 ++++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/expert/agent.py b/src/expert/agent.py index 4db9fa9..dcc0272 100644 --- a/src/expert/agent.py +++ b/src/expert/agent.py @@ -98,7 +98,7 @@ async def setup(self): self.agent = create_react_agent( model_name, renamed_tools, - checkpointer= checkpointer, + checkpointer=checkpointer, prompt=self.system_prompt, version="v2", ) @@ -124,7 +124,7 @@ def _create_opik_tracer(self, user_thread_id: str) -> OpikTracer | None: graph=self.agent.get_graph(xray=True), project_name=project_name, tags=[user_thread_id] ) return tracer - + async def get_response(self, user_input: str, thread_id: str) -> str: messages = [{"role": "user", "content": user_input}] return await self.get_response_with_context(messages, thread_id) @@ -144,9 +144,7 @@ async def get_response_with_context(self, messages: list[dict], thread_id: str) if tracer: thread_config["callbacks"] = [tracer] - async for chunk in self.agent.astream( - {"messages": messages}, thread_config - ): + async for chunk in self.agent.astream({"messages": messages}, thread_config): if "agent" in chunk: for message in chunk["agent"]["messages"]: if "tool_calls" in message.additional_kwargs: diff --git a/src/expert/config.py b/src/expert/config.py index 815a3cc..126bbf9 100644 --- a/src/expert/config.py +++ b/src/expert/config.py @@ -3,11 +3,11 @@ import os from pathlib import Path from typing import Any -from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.checkpoint.memory import InMemorySaver import yaml from dotenv import load_dotenv +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.checkpoint.memory import InMemorySaver class Config: @@ -100,9 +100,9 @@ def mcp_config(self) -> dict[str, Any]: def chat_config(self) -> dict[str, Any]: """Get chat configuration.""" return self.config.get("chat", {}) - + @property - def checkpointer(self) -> BaseCheckpointSaver: + def checkpointer(self) -> BaseCheckpointSaver | None: """Get checkpointer configuration.""" checkpointer = self.config.get("checkpointer", None) match checkpointer: