diff --git a/src/expert/agent.py b/src/expert/agent.py index b1e463a..dcc0272 100644 --- a/src/expert/agent.py +++ b/src/expert/agent.py @@ -6,7 +6,6 @@ from langchain_core.tools import StructuredTool from langchain_mcp_adapters.client import MultiServerMCPClient -from langgraph.checkpoint.memory import InMemorySaver from langgraph.prebuilt import create_react_agent from opik.integrations.langchain import OpikTracer @@ -29,7 +28,6 @@ def __init__(self, config: Config): """Initialize the agent.""" self.config = config self.agent = None - self.checkpointer = InMemorySaver() self._load_prompts() self.opik_config = None self.mcp_client = None @@ -95,11 +93,12 @@ async def setup(self): provider = llm_config.get("provider", "openai") model = llm_config.get("model", "gpt-4.1") model_name = f"{provider}:{model}" + checkpointer = self.config.checkpointer self.agent = create_react_agent( model_name, renamed_tools, - checkpointer=self.checkpointer, + checkpointer=checkpointer, prompt=self.system_prompt, version="v2", ) @@ -127,6 +126,10 @@ def _create_opik_tracer(self, user_thread_id: str) -> OpikTracer | None: return tracer async def get_response(self, user_input: str, thread_id: str) -> str: + messages = [{"role": "user", "content": user_input}] + return await self.get_response_with_context(messages, thread_id) + + async def get_response_with_context(self, messages: list[dict], thread_id: str) -> str: """Get response from the agent for a given input.""" if not self.agent: raise RuntimeError("Agent not initialized. Call setup() first.") @@ -141,9 +144,7 @@ async def get_response(self, user_input: str, thread_id: str) -> str: if tracer: thread_config["callbacks"] = [tracer] - async for chunk in self.agent.astream( - {"messages": [{"role": "user", "content": user_input}]}, thread_config - ): + async for chunk in self.agent.astream({"messages": messages}, thread_config): if "agent" in chunk: for message in chunk["agent"]["messages"]: if "tool_calls" in message.additional_kwargs: diff --git a/src/expert/config.py b/src/expert/config.py index 6c0ee4d..126bbf9 100644 --- a/src/expert/config.py +++ b/src/expert/config.py @@ -6,6 +6,8 @@ import yaml from dotenv import load_dotenv +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.checkpoint.memory import InMemorySaver class Config: @@ -98,3 +100,15 @@ def mcp_config(self) -> dict[str, Any]: def chat_config(self) -> dict[str, Any]: """Get chat configuration.""" return self.config.get("chat", {}) + + @property + def checkpointer(self) -> BaseCheckpointSaver | None: + """Get checkpointer configuration.""" + checkpointer = self.config.get("checkpointer", None) + match checkpointer: + case "memory": + return InMemorySaver() + case None: + return None + case _: + raise ValueError(f"Invalid checkpointer configuration: {checkpointer}")