Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/expert/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from langchain_core.tools import StructuredTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.prebuilt import create_react_agent
from opik.integrations.langchain import OpikTracer

Expand All @@ -29,7 +28,6 @@ def __init__(self, config: Config):
"""Initialize the agent."""
self.config = config
self.agent = None
self.checkpointer = InMemorySaver()
self._load_prompts()
self.opik_config = None
self.mcp_client = None
Expand Down Expand Up @@ -95,11 +93,12 @@ async def setup(self):
provider = llm_config.get("provider", "openai")
model = llm_config.get("model", "gpt-4.1")
model_name = f"{provider}:{model}"
checkpointer = self.config.checkpointer

self.agent = create_react_agent(
model_name,
renamed_tools,
checkpointer=self.checkpointer,
checkpointer=checkpointer,
prompt=self.system_prompt,
version="v2",
)
Expand Down Expand Up @@ -127,6 +126,10 @@ def _create_opik_tracer(self, user_thread_id: str) -> OpikTracer | None:
return tracer

async def get_response(self, user_input: str, thread_id: str) -> str:
    """Get the agent's reply to a single user message.

    Wraps *user_input* as one user-role message dict and delegates to
    get_response_with_context, which performs the actual agent invocation
    for the given *thread_id*.
    """
    messages = [{"role": "user", "content": user_input}]
    return await self.get_response_with_context(messages, thread_id)
Comment thread
pmenendz marked this conversation as resolved.

async def get_response_with_context(self, messages: list[dict], thread_id: str) -> str:
"""Get response from the agent for a given input."""
if not self.agent:
raise RuntimeError("Agent not initialized. Call setup() first.")
Expand All @@ -141,9 +144,7 @@ async def get_response(self, user_input: str, thread_id: str) -> str:
if tracer:
thread_config["callbacks"] = [tracer]

async for chunk in self.agent.astream(
{"messages": [{"role": "user", "content": user_input}]}, thread_config
):
async for chunk in self.agent.astream({"messages": messages}, thread_config):
if "agent" in chunk:
for message in chunk["agent"]["messages"]:
if "tool_calls" in message.additional_kwargs:
Expand Down
14 changes: 14 additions & 0 deletions src/expert/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import yaml
from dotenv import load_dotenv
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver


class Config:
Expand Down Expand Up @@ -98,3 +100,15 @@ def mcp_config(self) -> dict[str, Any]:
def chat_config(self) -> dict[str, Any]:
"""Get chat configuration."""
return self.config.get("chat", {})

@property
def checkpointer(self) -> BaseCheckpointSaver | None:
    """Build the checkpoint saver selected by the "checkpointer" config key.

    "memory" yields an InMemorySaver, a missing/None value yields None,
    and any other value raises ValueError.

    NOTE(review): a fresh InMemorySaver is constructed on every property
    access, so callers must hold on to one instance to share state.
    """
    checkpointer = self.config.get("checkpointer")
    if checkpointer == "memory":
        return InMemorySaver()
    if checkpointer is None:
        return None
    raise ValueError(f"Invalid checkpointer configuration: {checkpointer}")
Loading