diff --git a/mirix/agent/__init__.py b/mirix/agent/__init__.py
index f6112b4a..cf5f9806 100755
--- a/mirix/agent/__init__.py
+++ b/mirix/agent/__init__.py
@@ -2,23 +2,14 @@
# This module contains all agent-related functionality
from . import app_constants, app_utils
-from .agent_configs import AGENT_CONFIGS
from .agent_states import AgentStates
-from .message_queue import MessageQueue
from .meta_agent import MEMORY_AGENT_CONFIGS, MemoryAgentStates, MetaAgent
-from .temporary_message_accumulator import TemporaryMessageAccumulator
-from .upload_manager import UploadManager
__all__ = [
- "AgentWrapper",
"AgentStates",
- "AGENT_CONFIGS",
- "MessageQueue",
"MetaAgent",
"MemoryAgentStates",
"MEMORY_AGENT_CONFIGS",
- "TemporaryMessageAccumulator",
- "UploadManager",
"app_constants",
"app_utils",
"Agent",
diff --git a/mirix/agent/agent.py b/mirix/agent/agent.py
index 314fe0b5..d42af87b 100644
--- a/mirix/agent/agent.py
+++ b/mirix/agent/agent.py
@@ -82,6 +82,7 @@
from mirix.services.step_manager import StepManager
from mirix.services.tool_execution_sandbox import ToolExecutionSandbox
from mirix.settings import summarizer_settings
+from mirix.topic_extraction import extract_topics_with_ollama
from mirix.system import (
get_contine_chaining,
get_token_limit_warning,
@@ -110,6 +111,83 @@
logger = get_logger(__name__)
+def extract_text_from_messages(messages: List[Message]) -> List[dict]:
+ """
+ Extract text content from messages and use placeholders for non-text content.
+
+ Args:
+ messages: List of Message objects
+
+ Returns:
+ List of dictionaries containing role and processed content
+ """
+ processed_messages = []
+
+ for message in messages:
+ content_parts = []
+
+ if message.content:
+ for content_item in message.content:
+ if isinstance(content_item, TextContent):
+ content_parts.append(content_item.text)
+ elif isinstance(content_item, ImageContent):
+ content_parts.append(f"[IMAGE:{content_item.image_id}]")
+ elif isinstance(content_item, FileContent):
+ content_parts.append(f"[FILE:{content_item.file_id}]")
+ elif isinstance(content_item, CloudFileContent):
+ content_parts.append(f"[CLOUD_FILE:{content_item.cloud_file_uri}]")
+
+ # Combine all content parts into a single string
+ combined_content = " ".join(content_parts) if content_parts else ""
+
+ processed_messages.append({
+ "role": message.role,
+ "content": combined_content,
+ })
+
+ return processed_messages
+
+
+def extract_response_text(response: "ChatCompletionResponse") -> str:
+ """
+ Extract text content from LLM response, handling multiple choices.
+
+ Args:
+ response: ChatCompletionResponse object
+
+ Returns:
+ String with all choices joined by "\n\n"
+ """
+ choice_texts = []
+
+ for choice in response.choices:
+ parts = []
+
+ # Add regular content
+ if choice.message.content:
+ parts.append(choice.message.content)
+
+ # Add reasoning content if present
+ if choice.message.reasoning_content:
+ parts.append(f"[REASONING: {choice.message.reasoning_content}]")
+
+ # Add tool calls if present
+ if choice.message.tool_calls:
+ tool_call_strs = []
+ for tool_call in choice.message.tool_calls:
+ tool_call_strs.append(
+ f"[TOOL_CALL: {tool_call.function.name}({tool_call.function.arguments})]"
+ )
+ parts.append(" ".join(tool_call_strs))
+
+ # Combine parts for this choice
+ if parts:
+ choice_texts.append(" ".join(parts))
+
+ # Join all choices with double newline
+ return "\n\n".join(choice_texts)
+
+
class BaseAgent(ABC):
"""
Abstract class for all agents.
@@ -571,11 +649,40 @@ def _get_ai_reply(
elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
force_tool_call = allowed_tool_names[0]
+ from mirix.services.queue_trace_context import get_agent_trace_id
+ from mirix.services.memory_agent_tool_call_trace_manager import (
+ MemoryAgentToolCallTraceManager,
+ )
+
+ agent_trace_id = get_agent_trace_id()
+ llm_trace_manager = (
+ MemoryAgentToolCallTraceManager() if agent_trace_id else None
+ )
+
active_llm_client = llm_client or LLMClient.create(
llm_config=self.agent_state.llm_config,
)
for attempt in range(1, empty_response_retry_limit + 1):
+ response = None
+ llm_trace_id = None
+ if llm_trace_manager:
+ # Extract text from messages and use placeholders for images/files
+ processed_messages = extract_text_from_messages(message_sequence)
+
+ trace = llm_trace_manager.start_tool_call(
+ agent_trace_id,
+ function_name="llm_request",
+ function_args={
+ "step_count": step_count,
+ "attempt": attempt,
+ "force_tool_call": force_tool_call,
+ "messages": processed_messages,
+ "llm_config": self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
+ },
+ actor=self.actor,
+ )
+ llm_trace_id = trace.id
try:
log_telemetry(self.logger, "_get_ai_reply create start")
@@ -633,9 +740,59 @@ def _get_ai_reply(
raise ValueError(
f"Bad finish reason from API: {response.choices[0].finish_reason}"
)
+ if llm_trace_manager and llm_trace_id:
+ cached_tokens = response.usage.cached_tokens if response.usage else 0
+ prompt_tokens = (
+ max(response.usage.prompt_tokens - cached_tokens, 0)
+ if response.usage
+ else None
+ )
+ completion_tokens = (
+ response.usage.completion_tokens if response.usage else None
+ )
+ total_tokens = response.usage.total_tokens if response.usage else None
+ credit_cost = None
+ if response.usage:
+ try:
+ from mirix.pricing import calculate_cost
+
+ credit_cost = calculate_cost(
+ model=self.model,
+ prompt_tokens=prompt_tokens or 0,
+ completion_tokens=completion_tokens or 0,
+ cached_tokens=cached_tokens,
+ )
+ except Exception as e:
+ printv(
+ f"[Mirix.Agent.{self.agent_state.name}] WARNING: Failed to calculate LLM request credits: {e}"
+ )
+ # Extract response text from all choices
+ response_text = extract_response_text(response)
+
+ llm_trace_manager.finish_tool_call(
+ llm_trace_id,
+ success=True,
+ llm_call_id=response.id,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ cached_tokens=cached_tokens if response.usage else None,
+ total_tokens=total_tokens,
+ credit_cost=credit_cost,
+ response_text=response_text,
+ actor=self.actor,
+ )
+
log_telemetry(self.logger, "_handle_ai_response finish")
except ValueError as ve:
+ if llm_trace_manager and llm_trace_id:
+ llm_trace_manager.finish_tool_call(
+ llm_trace_id,
+ success=False,
+ error_message=str(ve),
+ llm_call_id=response.id if response else None,
+ actor=self.actor,
+ )
if attempt >= empty_response_retry_limit:
printv(
f"[Mirix.Agent.{self.agent_state.name}] ERROR: Retry limit reached. Final error: {ve}"
@@ -655,6 +812,14 @@ def _get_ai_reply(
except KeyError as ke:
# Gemini api sometimes can yield empty response
# This is a retryable error
+ if llm_trace_manager and llm_trace_id:
+ llm_trace_manager.finish_tool_call(
+ llm_trace_id,
+ success=False,
+ error_message=str(ke),
+ llm_call_id=response.id if response else None,
+ actor=self.actor,
+ )
if attempt >= empty_response_retry_limit:
printv(
f"[Mirix.Agent.{self.agent_state.name}] ERROR: Retry limit reached. Final error: {ke}"
@@ -672,6 +837,14 @@ def _get_ai_reply(
continue
except LLMError as llm_error:
+ if llm_trace_manager and llm_trace_id:
+ llm_trace_manager.finish_tool_call(
+ llm_trace_id,
+ success=False,
+ error_message=str(llm_error),
+ llm_call_id=response.id if response else None,
+ actor=self.actor,
+ )
if attempt >= empty_response_retry_limit:
printv(
f"[Mirix.Agent.{self.agent_state.name}] ERROR: Retry limit reached. Final error: {llm_error}"
@@ -708,6 +881,14 @@ def _get_ai_reply(
continue
except AssertionError as ae:
+ if llm_trace_manager and llm_trace_id:
+ llm_trace_manager.finish_tool_call(
+ llm_trace_id,
+ success=False,
+ error_message=str(ae),
+ llm_call_id=response.id if response else None,
+ actor=self.actor,
+ )
if attempt >= empty_response_retry_limit:
printv(
f"[Mirix.Agent.{self.agent_state.name}] ERROR: Retry limit reached. Final error: {ae}"
@@ -724,6 +905,14 @@ def _get_ai_reply(
continue
except requests.exceptions.HTTPError as he:
+ if llm_trace_manager and llm_trace_id:
+ llm_trace_manager.finish_tool_call(
+ llm_trace_id,
+ success=False,
+ error_message=str(he),
+ llm_call_id=response.id if response else None,
+ actor=self.actor,
+ )
if attempt >= empty_response_retry_limit:
printv(
f"[Mirix.Agent.{self.agent_state.name}] ERROR: Retry limit reached. Final error: {he}"
@@ -739,6 +928,14 @@ def _get_ai_reply(
time.sleep(delay)
except Exception as e:
+ if llm_trace_manager and llm_trace_id:
+ llm_trace_manager.finish_tool_call(
+ llm_trace_id,
+ success=False,
+ error_message=str(e),
+ llm_call_id=response.id if response else None,
+ actor=self.actor,
+ )
log_telemetry(
self.logger, "_handle_ai_response finish generic Exception"
)
@@ -777,6 +974,7 @@ def _handle_ai_response(
return_memory_types_without_update: bool = False,
message_queue: Optional[any] = None,
chaining: bool = True,
+ llm_usage: Optional[dict] = None,
) -> Tuple[List[Message], bool, bool]:
"""Handles parsing and function execution"""
@@ -1023,12 +1221,14 @@ def _record_assistant_message(
)
tool_call_trace_id = None
+ llm_call_id = llm_usage.get("llm_call_id") if llm_usage else None
if tool_call_trace_manager:
trace = tool_call_trace_manager.start_tool_call(
agent_trace_id,
function_name=function_name,
function_args=function_args_for_trace,
tool_call_id=tool_call_id,
+ llm_call_id=llm_call_id,
actor=self.actor,
)
tool_call_trace_id = trace.id
@@ -1280,11 +1480,11 @@ def _record_assistant_message(
if self.agent_state.name == "episodic_memory_agent":
memory_item = self.episodic_memory_manager.get_most_recently_updated_event(
- user=self.user,
- timezone_str=self.user.timezone,
+ actor=self.actor,
+ user_id=self.user_id,
+ timezone_str=self.user.timezone if self.user else None,
)
if memory_item:
- memory_item = memory_item[0]
memory_item_str = ""
memory_item_str += (
"[Episodic Event ID]: " + memory_item.id + "\n"
@@ -1313,11 +1513,11 @@ def _record_assistant_message(
elif self.agent_state.name == "procedural_memory_agent":
memory_item = self.procedural_memory_manager.get_most_recently_updated_item(
- actor=self.user,
- timezone_str=self.user.timezone,
+ actor=self.actor,
+ user_id=self.user_id,
+ timezone_str=self.user.timezone if self.user else None,
)
if memory_item:
- memory_item = memory_item[0]
memory_item_str = ""
memory_item_str += (
"[Procedural Memory ID]: " + memory_item.id + "\n"
@@ -1345,12 +1545,12 @@ def _record_assistant_message(
elif self.agent_state.name == "resource_memory_agent":
memory_item = (
self.resource_memory_manager.get_most_recently_updated_item(
- actor=self.user,
- timezone_str=self.user.timezone,
+ actor=self.actor,
+ user_id=self.user_id,
+ timezone_str=self.user.timezone if self.user else None,
)
)
if memory_item:
- memory_item = memory_item[0]
memory_item_str = ""
memory_item_str += (
"[Resource Memory ID]: " + memory_item.id + "\n"
@@ -1379,8 +1579,9 @@ def _record_assistant_message(
elif self.agent_state.name == "knowledge_memory_agent":
memory_item = (
self.knowledge_memory_manager.get_most_recently_updated_item(
- actor=self.user,
- timezone_str=self.user.timezone,
+ actor=self.actor,
+ user_id=self.user_id,
+ timezone_str=self.user.timezone if self.user else None,
)
)
@@ -1392,7 +1593,6 @@ def _record_assistant_message(
memory_item_str = "No new knowledge items were added."
if memory_item:
- memory_item = memory_item[0]
memory_item_str = ""
memory_item_str += (
"[Knowledge ID]: " + memory_item.id + "\n"
@@ -1424,12 +1624,12 @@ def _record_assistant_message(
elif self.agent_state.name == "semantic_memory_agent":
memory_item = (
self.semantic_memory_manager.get_most_recently_updated_item(
- actor=self.user,
- timezone_str=self.user.timezone,
+ actor=self.actor,
+ user_id=self.user_id,
+ timezone_str=self.user.timezone if self.user else None,
)
)
if memory_item:
- memory_item = memory_item[0]
memory_item_str = ""
memory_item_str += (
"[Semantic Memory ID]: " + memory_item.id + "\n"
@@ -1787,7 +1987,43 @@ def step(
llm_config=self.agent_state.llm_config,
)
+ from mirix.services.queue_trace_context import get_agent_trace_id, get_queue_trace_id
+ from mirix.services.memory_agent_trace_manager import MemoryAgentTraceManager
+ from mirix.services.memory_queue_trace_manager import MemoryQueueTraceManager
+
+ queue_trace_manager = MemoryQueueTraceManager()
+
+ def handle_interrupt_request() -> None:
+ queue_trace_id = get_queue_trace_id()
+ if not queue_trace_id:
+ return
+ if not queue_trace_manager.is_interrupt_requested(queue_trace_id):
+ return
+ interrupt_reason = (
+ queue_trace_manager.get_interrupt_reason(queue_trace_id)
+ or "Interrupted by user"
+ )
+ queue_trace_manager.mark_completed(
+ queue_trace_id,
+ success=False,
+ error_message=interrupt_reason,
+ actor=self.actor,
+ )
+ agent_trace_id = get_agent_trace_id()
+ if agent_trace_id:
+ MemoryAgentTraceManager().finish_trace(
+ agent_trace_id,
+ success=False,
+ error_message=interrupt_reason,
+ actor=self.actor,
+ )
+ printv(
+ f"[Mirix.Agent.{self.agent_state.name}] INFO: Interrupt requested. Stopping agent step."
+ )
+ raise RuntimeError(interrupt_reason)
+
while True:
+ handle_interrupt_request()
kwargs["first_message"] = False
kwargs["step_count"] = step_count
@@ -1819,7 +2055,9 @@ def step(
initial_message_count=initial_message_count,
chaining=chaining,
llm_client=llm_client,
+ **kwargs
)
+ handle_interrupt_request()
continue_chaining = step_response.continue_chaining
function_failed = step_response.function_failed
@@ -1921,290 +2159,326 @@ def build_system_prompt_with_memories(
Returns:
Tuple[str, dict]: The complete system prompt and the retrieved memories dict
"""
+ from mirix.services.queue_trace_context import get_agent_trace_id
+ from mirix.services.memory_agent_tool_call_trace_manager import (
+ MemoryAgentToolCallTraceManager,
+ )
from mirix.schemas.agent import AgentType
+ agent_trace_id = get_agent_trace_id()
+ tool_call_trace_manager = (
+ MemoryAgentToolCallTraceManager() if agent_trace_id else None
+ )
+ retrieval_trace_id = None
+ if tool_call_trace_manager:
+ trace = tool_call_trace_manager.start_tool_call(
+ agent_trace_id,
+ function_name="retrieve_memories",
+ function_args={"topics": topics},
+ actor=self.actor,
+ )
+ retrieval_trace_id = trace.id
+
timezone_str = self.user.timezone
if retrieved_memories is None:
retrieved_memories = {}
- if "key_words" in retrieved_memories:
- key_words = retrieved_memories["key_words"]
- else:
- key_words = topics if topics is not None else ""
- retrieved_memories["key_words"] = key_words
-
- search_method = "bm25"
-
- # Prepare embedding for semantic search
- if key_words != "" and search_method == "embedding":
- embedded_text = embedding_model(
- self.agent_state.embedding_config
- ).get_text_embedding(key_words)
- embedded_text = np.array(embedded_text)
- embedded_text = np.pad(
- embedded_text,
- (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]),
- mode="constant",
- ).tolist()
- else:
- embedded_text = None
+ try:
+ if "key_words" in retrieved_memories:
+ key_words = retrieved_memories["key_words"]
+ else:
+ key_words = topics if topics is not None else ""
+ retrieved_memories["key_words"] = key_words
+
+ search_method = "bm25"
+
+ # Prepare embedding for semantic search
+ if key_words != "" and search_method == "embedding":
+ embedded_text = embedding_model(
+ self.agent_state.embedding_config
+ ).get_text_embedding(key_words)
+ embedded_text = np.array(embedded_text)
+ embedded_text = np.pad(
+ embedded_text,
+ (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]),
+ mode="constant",
+ ).tolist()
+ else:
+ embedded_text = None
- # Extract fade_after_days from agent's memory_config
- fade_after_days = None
- if self.agent_state.memory_config:
- decay_config = self.agent_state.memory_config.get("decay", {})
- if decay_config:
- fade_after_days = decay_config.get("fade_after_days")
+ # Extract fade_after_days from agent's memory_config
+ fade_after_days = None
+ if self.agent_state.memory_config:
+ decay_config = self.agent_state.memory_config.get("decay", {})
+ if decay_config:
+ fade_after_days = decay_config.get("fade_after_days")
- # Retrieve core memory
- if (
- self.agent_state.agent_type == AgentType.core_memory_agent
- or "core" not in retrieved_memories
- ):
- current_persisted_memory = Memory(
- blocks=[
- b
- for block in self.block_manager.get_blocks(
+ # Retrieve core memory
+ if (
+ self.agent_state.agent_type == AgentType.core_memory_agent
+ or "core" not in retrieved_memories
+ ):
+ current_persisted_memory = Memory(
+ blocks=[
+ b
+ for block in self.block_manager.get_blocks(
+ user=self.user,
+ auto_create_from_default=False # Don't auto-create here, only in step()
+ )
+ if (
+ b := self.block_manager.get_block_by_id(
+ block.id, user=self.user
+ )
+ )
+ is not None
+ ]
+ )
+ core_memory = current_persisted_memory.compile()
+ retrieved_memories["core"] = core_memory
+
+ if (
+ self.agent_state.agent_type == AgentType.knowledge_memory_agent
+ or "knowledge" not in retrieved_memories
+ ):
+ if (
+ self.agent_state.agent_type == AgentType.knowledge_memory_agent
+ or self.agent_state.agent_type == AgentType.reflexion_agent
+ ):
+ current_knowledge = self.knowledge_memory_manager.list_knowledge(
+ agent_state=self.agent_state,
user=self.user,
- auto_create_from_default=False # Don't auto-create here, only in step()
+ embedded_text=embedded_text,
+ query=key_words,
+ search_field="caption",
+ search_method=search_method,
+ limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
+ timezone_str=timezone_str,
+ fade_after_days=fade_after_days,
)
- if (
- b := self.block_manager.get_block_by_id(
- block.id, user=self.user
- )
+ else:
+ current_knowledge = self.knowledge_memory_manager.list_knowledge(
+ agent_state=self.agent_state,
+ user=self.user,
+ embedded_text=embedded_text,
+ query=key_words,
+ search_field="caption",
+ search_method=search_method,
+ limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
+ timezone_str=timezone_str,
+ sensitivity=["low", "medium"],
+ fade_after_days=fade_after_days,
+ )
+
+ knowledge_memory = ""
+ if len(current_knowledge) > 0:
+ for idx, knowledge_item in enumerate(current_knowledge):
+ knowledge_memory += f"[{idx}] Knowledge Item ID: {knowledge_item.id}; Caption: {knowledge_item.caption}\n"
+ retrieved_memories["knowledge"] = {
+ "total_number_of_items": self.knowledge_memory_manager.get_total_number_of_items(
+ user=self.user
+ ),
+ "current_count": len(current_knowledge),
+ "text": knowledge_memory,
+ }
+
+ # Retrieve episodic memory
+ if (
+ self.agent_state.name == "episodic_memory_agent"
+ or "episodic" not in retrieved_memories
+ ):
+ current_episodic_memory = self.episodic_memory_manager.list_episodic_memory(
+ agent_state=self.agent_state,
+ user=self.user,
+ limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
+ timezone_str=timezone_str,
+ fade_after_days=fade_after_days,
+ )
+ episodic_memory = ""
+ if len(current_episodic_memory) > 0:
+ for idx, event in enumerate(current_episodic_memory):
+ # Use agent_type instead of name to handle both standalone and meta-agent child agents
+ from mirix.schemas.agent import AgentType
+
+ if (
+ self.agent_state.agent_type == AgentType.episodic_memory_agent
+ or self.agent_state.agent_type == AgentType.reflexion_agent
+ ):
+ episodic_memory += f"[Event ID: {event.id}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
+ else:
+ episodic_memory += f"[{idx}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
+
+ recent_episodic_memory = episodic_memory.strip()
+
+ most_relevant_episodic_memory = (
+ self.episodic_memory_manager.list_episodic_memory(
+ agent_state=self.agent_state,
+ user=self.user,
+ embedded_text=embedded_text,
+ query=key_words,
+ search_field="details",
+ search_method=search_method,
+ limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
+ timezone_str=timezone_str,
+ fade_after_days=fade_after_days,
)
- is not None
- ]
- )
- core_memory = current_persisted_memory.compile()
- retrieved_memories["core"] = core_memory
+ )
+ most_relevant_episodic_memory_str = ""
+ if len(most_relevant_episodic_memory) > 0:
+ for idx, event in enumerate(most_relevant_episodic_memory):
+ # Use agent_type instead of name to handle both standalone and meta-agent child agents
+ from mirix.schemas.agent import AgentType
- if (
- self.agent_state.agent_type == AgentType.knowledge_memory_agent
- or "knowledge" not in retrieved_memories
- ):
+ if (
+ self.agent_state.agent_type == AgentType.episodic_memory_agent
+ or self.agent_state.agent_type == AgentType.reflexion_agent
+ ):
+ most_relevant_episodic_memory_str += f"[Event ID: {event.id}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
+ else:
+ most_relevant_episodic_memory_str += f"[{idx}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
+ relevant_episodic_memory = most_relevant_episodic_memory_str.strip()
+ retrieved_memories["episodic"] = {
+ "total_number_of_items": self.episodic_memory_manager.get_total_number_of_items(
+ user=self.user
+ ),
+ "recent_count": len(current_episodic_memory),
+ "relevant_count": len(most_relevant_episodic_memory),
+ "recent_episodic_memory": recent_episodic_memory,
+ "relevant_episodic_memory": relevant_episodic_memory,
+ }
+
+ # Retrieve resource memory
if (
- self.agent_state.agent_type == AgentType.knowledge_memory_agent
- or self.agent_state.agent_type == AgentType.reflexion_agent
+ self.agent_state.agent_type == AgentType.resource_memory_agent
+ or "resource" not in retrieved_memories
):
- current_knowledge = self.knowledge_memory_manager.list_knowledge(
+ current_resource_memory = self.resource_memory_manager.list_resources(
agent_state=self.agent_state,
user=self.user,
- embedded_text=embedded_text,
query=key_words,
- search_field="caption",
+ embedded_text=embedded_text,
+ search_field="summary",
search_method=search_method,
limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
timezone_str=timezone_str,
fade_after_days=fade_after_days,
)
- else:
- current_knowledge = self.knowledge_memory_manager.list_knowledge(
+ resource_memory = ""
+ if len(current_resource_memory) > 0:
+ for idx, resource in enumerate(current_resource_memory):
+ if (
+ self.agent_state.agent_type == AgentType.resource_memory_agent
+ or self.agent_state.agent_type == AgentType.reflexion_agent
+ ):
+ resource_memory += f"[Resource ID: {resource.id}] Resource Title: {resource.title}; Resource Summary: {resource.summary} Resource Type: {resource.resource_type}\n"
+ else:
+ resource_memory += f"[{idx}] Resource Title: {resource.title}; Resource Summary: {resource.summary} Resource Type: {resource.resource_type}\n"
+ resource_memory = resource_memory.strip()
+ retrieved_memories["resource"] = {
+ "total_number_of_items": self.resource_memory_manager.get_total_number_of_items(
+ user=self.user
+ ),
+ "current_count": len(current_resource_memory),
+ "text": resource_memory,
+ }
+
+ # Retrieve procedural memory
+ if (
+ self.agent_state.agent_type == AgentType.procedural_memory_agent
+ or "procedural" not in retrieved_memories
+ ):
+ current_procedural_memory = self.procedural_memory_manager.list_procedures(
agent_state=self.agent_state,
user=self.user,
- embedded_text=embedded_text,
query=key_words,
- search_field="caption",
+ embedded_text=embedded_text,
+ search_field="summary",
search_method=search_method,
limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
timezone_str=timezone_str,
- sensitivity=["low", "medium"],
fade_after_days=fade_after_days,
)
+ procedural_memory = ""
+ if len(current_procedural_memory) > 0:
+ for idx, procedure in enumerate(current_procedural_memory):
+ if (
+ self.agent_state.agent_type == AgentType.procedural_memory_agent
+ or self.agent_state.agent_type == AgentType.reflexion_agent
+ ):
+ procedural_memory += f"[Procedure ID: {procedure.id}] Entry Type: {procedure.entry_type}; Summary: {procedure.summary}\n"
+ else:
+ procedural_memory += f"[{idx}] Entry Type: {procedure.entry_type}; Summary: {procedure.summary}\n"
+ procedural_memory = procedural_memory.strip()
+ retrieved_memories["procedural"] = {
+ "total_number_of_items": self.procedural_memory_manager.get_total_number_of_items(
+ user=self.user
+ ),
+ "current_count": len(current_procedural_memory),
+ "text": procedural_memory,
+ }
- knowledge_memory = ""
- if len(current_knowledge) > 0:
- for idx, knowledge_item in enumerate(current_knowledge):
- knowledge_memory += f"[{idx}] Knowledge Item ID: {knowledge_item.id}; Caption: {knowledge_item.caption}\n"
- retrieved_memories["knowledge"] = {
- "total_number_of_items": self.knowledge_memory_manager.get_total_number_of_items(
- user=self.user
- ),
- "current_count": len(current_knowledge),
- "text": knowledge_memory,
- }
-
- # Retrieve episodic memory
- if (
- self.agent_state.name == "episodic_memory_agent"
- or "episodic" not in retrieved_memories
- ):
- current_episodic_memory = self.episodic_memory_manager.list_episodic_memory(
- agent_state=self.agent_state,
- user=self.user,
- limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
- timezone_str=timezone_str,
- fade_after_days=fade_after_days,
- )
- episodic_memory = ""
- if len(current_episodic_memory) > 0:
- for idx, event in enumerate(current_episodic_memory):
- # Use agent_type instead of name to handle both standalone and meta-agent child agents
- from mirix.schemas.agent import AgentType
-
- if (
- self.agent_state.agent_type == AgentType.episodic_memory_agent
- or self.agent_state.agent_type == AgentType.reflexion_agent
- ):
- episodic_memory += f"[Event ID: {event.id}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
- else:
- episodic_memory += f"[{idx}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
-
- recent_episodic_memory = episodic_memory.strip()
-
- most_relevant_episodic_memory = (
- self.episodic_memory_manager.list_episodic_memory(
+ # Retrieve semantic memory
+ if (
+ self.agent_state.agent_type == AgentType.semantic_memory_agent
+ or "semantic" not in retrieved_memories
+ ):
+ current_semantic_memory = self.semantic_memory_manager.list_semantic_items(
agent_state=self.agent_state,
user=self.user,
- embedded_text=embedded_text,
query=key_words,
+ embedded_text=embedded_text,
search_field="details",
search_method=search_method,
limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
timezone_str=timezone_str,
fade_after_days=fade_after_days,
)
- )
- most_relevant_episodic_memory_str = ""
- if len(most_relevant_episodic_memory) > 0:
- for idx, event in enumerate(most_relevant_episodic_memory):
- # Use agent_type instead of name to handle both standalone and meta-agent child agents
- from mirix.schemas.agent import AgentType
-
- if (
- self.agent_state.agent_type == AgentType.episodic_memory_agent
- or self.agent_state.agent_type == AgentType.reflexion_agent
- ):
- most_relevant_episodic_memory_str += f"[Event ID: {event.id}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
- else:
- most_relevant_episodic_memory_str += f"[{idx}] Timestamp: {event.occurred_at.strftime('%Y-%m-%d %H:%M:%S')} - {event.summary} (Details: {len(event.details)} Characters)\n"
- relevant_episodic_memory = most_relevant_episodic_memory_str.strip()
- retrieved_memories["episodic"] = {
- "total_number_of_items": self.episodic_memory_manager.get_total_number_of_items(
- user=self.user
- ),
- "recent_count": len(current_episodic_memory),
- "relevant_count": len(most_relevant_episodic_memory),
- "recent_episodic_memory": recent_episodic_memory,
- "relevant_episodic_memory": relevant_episodic_memory,
- }
-
- # Retrieve resource memory
- if (
- self.agent_state.agent_type == AgentType.resource_memory_agent
- or "resource" not in retrieved_memories
- ):
- current_resource_memory = self.resource_memory_manager.list_resources(
- agent_state=self.agent_state,
- user=self.user,
- query=key_words,
- embedded_text=embedded_text,
- search_field="summary",
- search_method=search_method,
- limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
- timezone_str=timezone_str,
- fade_after_days=fade_after_days,
- )
- resource_memory = ""
- if len(current_resource_memory) > 0:
- for idx, resource in enumerate(current_resource_memory):
- if (
- self.agent_state.agent_type == AgentType.resource_memory_agent
- or self.agent_state.agent_type == AgentType.reflexion_agent
- ):
- resource_memory += f"[Resource ID: {resource.id}] Resource Title: {resource.title}; Resource Summary: {resource.summary} Resource Type: {resource.resource_type}\n"
- else:
- resource_memory += f"[{idx}] Resource Title: {resource.title}; Resource Summary: {resource.summary} Resource Type: {resource.resource_type}\n"
- resource_memory = resource_memory.strip()
- retrieved_memories["resource"] = {
- "total_number_of_items": self.resource_memory_manager.get_total_number_of_items(
- user=self.user
- ),
- "current_count": len(current_resource_memory),
- "text": resource_memory,
- }
-
- # Retrieve procedural memory
- if (
- self.agent_state.agent_type == AgentType.procedural_memory_agent
- or "procedural" not in retrieved_memories
- ):
- current_procedural_memory = self.procedural_memory_manager.list_procedures(
- agent_state=self.agent_state,
- user=self.user,
- query=key_words,
- embedded_text=embedded_text,
- search_field="summary",
- search_method=search_method,
- limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
- timezone_str=timezone_str,
- fade_after_days=fade_after_days,
- )
- procedural_memory = ""
- if len(current_procedural_memory) > 0:
- for idx, procedure in enumerate(current_procedural_memory):
- if (
- self.agent_state.agent_type == AgentType.procedural_memory_agent
- or self.agent_state.agent_type == AgentType.reflexion_agent
- ):
- procedural_memory += f"[Procedure ID: {procedure.id}] Entry Type: {procedure.entry_type}; Summary: {procedure.summary}\n"
- else:
- procedural_memory += f"[{idx}] Entry Type: {procedure.entry_type}; Summary: {procedure.summary}\n"
- procedural_memory = procedural_memory.strip()
- retrieved_memories["procedural"] = {
- "total_number_of_items": self.procedural_memory_manager.get_total_number_of_items(
- user=self.user
- ),
- "current_count": len(current_procedural_memory),
- "text": procedural_memory,
- }
+ semantic_memory = ""
+ if len(current_semantic_memory) > 0:
+ for idx, semantic_memory_item in enumerate(current_semantic_memory):
+ if (
+ self.agent_state.agent_type == AgentType.semantic_memory_agent
+ or self.agent_state.agent_type == AgentType.reflexion_agent
+ ):
+ semantic_memory += f"[Semantic Memory ID: {semantic_memory_item.id}] Name: {semantic_memory_item.name}; Summary: {semantic_memory_item.summary}\n"
+ else:
+ semantic_memory += f"[{idx}] Name: {semantic_memory_item.name}; Summary: {semantic_memory_item.summary}\n"
- # Retrieve semantic memory
- if (
- self.agent_state.agent_type == AgentType.semantic_memory_agent
- or "semantic" not in retrieved_memories
- ):
- current_semantic_memory = self.semantic_memory_manager.list_semantic_items(
- agent_state=self.agent_state,
- user=self.user,
- query=key_words,
- embedded_text=embedded_text,
- search_field="details",
- search_method=search_method,
- limit=MAX_RETRIEVAL_LIMIT_IN_SYSTEM,
- timezone_str=timezone_str,
- fade_after_days=fade_after_days,
- )
- semantic_memory = ""
- if len(current_semantic_memory) > 0:
- for idx, semantic_memory_item in enumerate(current_semantic_memory):
- if (
- self.agent_state.agent_type == AgentType.semantic_memory_agent
- or self.agent_state.agent_type == AgentType.reflexion_agent
- ):
- semantic_memory += f"[Semantic Memory ID: {semantic_memory_item.id}] Name: {semantic_memory_item.name}; Summary: {semantic_memory_item.summary}\n"
- else:
- semantic_memory += f"[{idx}] Name: {semantic_memory_item.name}; Summary: {semantic_memory_item.summary}\n"
+ semantic_memory = semantic_memory.strip()
+ retrieved_memories["semantic"] = {
+ "total_number_of_items": self.semantic_memory_manager.get_total_number_of_items(
+ user=self.user
+ ),
+ "current_count": len(current_semantic_memory),
+ "text": semantic_memory,
+ }
- semantic_memory = semantic_memory.strip()
- retrieved_memories["semantic"] = {
- "total_number_of_items": self.semantic_memory_manager.get_total_number_of_items(
- user=self.user
- ),
- "current_count": len(current_semantic_memory),
- "text": semantic_memory,
- }
+ # Build the complete system prompt
+ memory_system_prompt = self.build_system_prompt(retrieved_memories)
- # Build the complete system prompt
- memory_system_prompt = self.build_system_prompt(retrieved_memories)
+ complete_system_prompt = raw_system + "\n\n" + memory_system_prompt
- complete_system_prompt = raw_system + "\n\n" + memory_system_prompt
+ if key_words:
+ complete_system_prompt += "\n\nThe above memories are retrieved based on the following keywords. If some memories are empty or does not contain the content related to the keywords, it is highly likely that memory does not contain any relevant information."
- if key_words:
- complete_system_prompt += "\n\nThe above memories are retrieved based on the following keywords. If some memories are empty or does not contain the content related to the keywords, it is highly likely that memory does not contain any relevant information."
+ if tool_call_trace_manager and retrieval_trace_id:
+ tool_call_trace_manager.finish_tool_call(
+ retrieval_trace_id,
+ success=True,
+ response_text=json_dumps(retrieved_memories),
+ actor=self.actor,
+ )
- return complete_system_prompt, retrieved_memories
+ return complete_system_prompt, retrieved_memories
+ except Exception as exc:
+ if tool_call_trace_manager and retrieval_trace_id:
+ tool_call_trace_manager.finish_tool_call(
+ retrieval_trace_id,
+ success=False,
+ error_message=str(exc),
+ actor=self.actor,
+ )
+ raise
def build_system_prompt(self, retrieved_memories: dict) -> str:
"""Build the system prompt for the LLM API"""
@@ -2372,6 +2646,25 @@ def _extract_topics_from_messages(self, messages: List[Message]) -> Optional[str
Returns:
Optional[str]: Extracted topics or None if extraction fails
"""
+ from mirix.services.queue_trace_context import get_agent_trace_id
+ from mirix.services.memory_agent_tool_call_trace_manager import (
+ MemoryAgentToolCallTraceManager,
+ )
+
+ agent_trace_id = get_agent_trace_id()
+ tool_call_trace_manager = (
+ MemoryAgentToolCallTraceManager() if agent_trace_id else None
+ )
+ extraction_trace_id = None
+ if tool_call_trace_manager:
+ trace = tool_call_trace_manager.start_tool_call(
+ agent_trace_id,
+ function_name="extract_topics",
+ function_args={},
+ actor=self.actor,
+ )
+ extraction_trace_id = trace.id
+
try:
# Add instruction message for topic extraction
temporary_messages = copy.deepcopy(messages)
@@ -2417,14 +2710,43 @@ def _extract_topics_from_messages(self, messages: List[Message]) -> Optional[str
}
]
+ topic_llm_config = (
+ self.agent_state.topic_extraction_llm_config
+ if getattr(self.agent_state, "topic_extraction_llm_config", None)
+ else self.agent_state.llm_config
+ )
+
+ if topic_llm_config.model_endpoint_type == "ollama":
+ message_dicts = [
+ m.to_openai_dict() if hasattr(m, "to_openai_dict") else m
+ for m in temporary_messages
+ ]
+ topics = extract_topics_with_ollama(
+ messages=message_dicts,
+ model_name=topic_llm_config.model,
+ base_url=topic_llm_config.model_endpoint,
+ )
+ if topics:
+ printv(
+ f"[Mirix.Agent.{self.agent_state.name}] INFO: Extracted topics: {topics}"
+ )
+ if tool_call_trace_manager and extraction_trace_id:
+ tool_call_trace_manager.finish_tool_call(
+ extraction_trace_id,
+ success=True,
+ response_text=topics or "No topics extracted",
+ actor=self.actor,
+ )
+ return topics
+
# Use LLMClient to extract topics
llm_client = LLMClient.create(
- llm_config=self.agent_state.llm_config,
+ llm_config=topic_llm_config,
)
if not llm_client:
raise ValueError(
- f"No LLM client available for model endpoint type: {self.agent_state.llm_config.model_endpoint_type}"
+ f"No LLM client available for model endpoint type: {topic_llm_config.model_endpoint_type}"
)
response = llm_client.send_llm_request(
@@ -2433,6 +2755,35 @@ def _extract_topics_from_messages(self, messages: List[Message]) -> Optional[str
stream=False,
force_tool_call="update_topic",
)
+ usage_payload = None
+ if response.usage:
+ cached_tokens = response.usage.cached_tokens
+ non_cached_prompt_tokens = max(
+ response.usage.prompt_tokens - cached_tokens, 0
+ )
+ usage_payload = {
+ "prompt_tokens": non_cached_prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "cached_tokens": cached_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "credit_cost": None,
+ }
+ # try:
+ # from mirix.pricing import calculate_cost
+
+ # usage_payload["credit_cost"] = calculate_cost(
+ # model=topic_llm_config.model,
+ # prompt_tokens=non_cached_prompt_tokens,
+ # completion_tokens=response.usage.completion_tokens,
+ # cached_tokens=cached_tokens,
+ # )
+ # except Exception as e:
+ # printv(
+ # f"[Mirix.Agent.{self.agent_state.name}] WARNING: Failed to calculate topic extraction credits: {e}"
+ # )
+
+ # TODO: Extract Topics is free for now:
+ usage_payload['credit_costs'] = 0.0
# Extract topics from the response
for choice in response.choices:
@@ -2449,6 +2800,39 @@ def _extract_topics_from_messages(self, messages: List[Message]) -> Optional[str
printv(
f"[Mirix.Agent.{self.agent_state.name}] INFO: Extracted topics: {topics}"
)
+
+ if tool_call_trace_manager and extraction_trace_id:
+ tool_call_trace_manager.finish_tool_call(
+ extraction_trace_id,
+ success=True,
+ response_text=topics or "No topics extracted",
+ prompt_tokens=(
+ usage_payload.get("prompt_tokens")
+ if usage_payload
+ else None
+ ),
+ completion_tokens=(
+ usage_payload.get("completion_tokens")
+ if usage_payload
+ else None
+ ),
+ cached_tokens=(
+ usage_payload.get("cached_tokens")
+ if usage_payload
+ else None
+ ),
+ total_tokens=(
+ usage_payload.get("total_tokens")
+ if usage_payload
+ else None
+ ),
+ credit_cost=(
+ usage_payload.get("credit_cost")
+ if usage_payload
+ else None
+ ),
+ actor=self.actor,
+ )
return topics
except (json.JSONDecodeError, KeyError) as parse_error:
printv(
@@ -2460,6 +2844,37 @@ def _extract_topics_from_messages(self, messages: List[Message]) -> Optional[str
printv(
f"[Mirix.Agent.{self.agent_state.name}] INFO: Error in extracting the topic from the messages: {e}"
)
+ if tool_call_trace_manager and extraction_trace_id:
+ tool_call_trace_manager.finish_tool_call(
+ extraction_trace_id,
+ success=False,
+ error_message=str(e),
+ actor=self.actor,
+ )
+ return None
+
+ if tool_call_trace_manager and extraction_trace_id:
+ tool_call_trace_manager.finish_tool_call(
+ extraction_trace_id,
+ success=True,
+ response_text="No topics extracted",
+ prompt_tokens=(
+ usage_payload.get("prompt_tokens") if usage_payload else None
+ ),
+ completion_tokens=(
+ usage_payload.get("completion_tokens") if usage_payload else None
+ ),
+ cached_tokens=(
+ usage_payload.get("cached_tokens") if usage_payload else None
+ ),
+ total_tokens=(
+ usage_payload.get("total_tokens") if usage_payload else None
+ ),
+ credit_cost=(
+ usage_payload.get("credit_cost") if usage_payload else None
+ ),
+ actor=self.actor,
+ )
return None
@@ -2614,31 +3029,47 @@ def inner_step(
f"[Mirix.Agent.{self.agent_state.name}] INFO: AI response received - choices: {len(response.choices)}"
)
- # Deduct credits based on model-specific token pricing
- if response.usage and self.client_id:
+ llm_usage_payload = None
+ cost = None
+ if response.usage:
+ cached_tokens = response.usage.cached_tokens
+ non_cached_prompt_tokens = max(
+ response.usage.prompt_tokens - cached_tokens, 0
+ )
+ llm_usage_payload = {
+ "llm_call_id": response.id,
+ "prompt_tokens": non_cached_prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "cached_tokens": cached_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "credit_cost": None,
+ }
try:
from mirix.pricing import calculate_cost
- client_manager = ClientManager()
-
- cached_tokens = response.usage.cached_tokens
- non_cached_prompt_tokens = max(
- response.usage.prompt_tokens - cached_tokens, 0
- )
-
cost = calculate_cost(
model=self.model,
prompt_tokens=non_cached_prompt_tokens,
completion_tokens=response.usage.completion_tokens,
cached_tokens=cached_tokens,
)
+ llm_usage_payload["credit_cost"] = cost
+ except Exception as e:
+ printv(
+ f"[Mirix.Agent.{self.agent_state.name}] WARNING: Failed to calculate credits: {e}"
+ )
+
+ # Deduct credits based on model-specific token pricing
+ if response.usage and self.client_id and cost is not None:
+ try:
+ client_manager = ClientManager()
client_manager.deduct_credits(self.client_id, cost)
usage_info = (
- f"input: {non_cached_prompt_tokens}, output: {response.usage.completion_tokens}"
+ f"input: {llm_usage_payload['prompt_tokens']}, output: {llm_usage_payload['completion_tokens']}"
)
- if cached_tokens > 0:
- usage_info += f", cached: {cached_tokens}"
+ if llm_usage_payload["cached_tokens"] > 0:
+ usage_info += f", cached: {llm_usage_payload['cached_tokens']}"
printv(
f"[Mirix.Agent.{self.agent_state.name}] INFO: Deducted ${cost:.6f} from client {self.client_id} "
f"(model: {self.model}, {usage_info})"
@@ -2665,8 +3096,9 @@ def inner_step(
# (if yes) Step 4: call the function
# (if yes) Step 5: send the info on the function call and function response to LLM
all_response_messages = []
- for response_choice in response.choices:
+ for response_idx, response_choice in enumerate(response.choices):
response_message = response_choice.message
+ response_usage = llm_usage_payload if response_idx == 0 else None
tmp_response_messages, continue_chaining, function_failed = (
self._handle_ai_response(
first_input_messge, # give the last message to the function so that other agents can see this message through funciton_calls
@@ -2682,6 +3114,7 @@ def inner_step(
return_memory_types_without_update=return_memory_types_without_update,
message_queue=message_queue,
chaining=chaining,
+ llm_usage=response_usage,
)
)
all_response_messages.extend(tmp_response_messages)
diff --git a/mirix/agent/agent_configs.py b/mirix/agent/agent_configs.py
deleted file mode 100644
index ad1b4c9a..00000000
--- a/mirix/agent/agent_configs.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from mirix.schemas.agent import AgentType
-
-# Agent configuration definitions
-AGENT_CONFIGS = [
- {
- "name": "background_agent",
- "agent_type": AgentType.background_agent,
- "attr_name": "background_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "reflexion_agent",
- "agent_type": AgentType.reflexion_agent,
- "attr_name": "reflexion_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "episodic_memory_agent",
- "agent_type": AgentType.episodic_memory_agent,
- "attr_name": "episodic_memory_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "procedural_memory_agent",
- "agent_type": AgentType.procedural_memory_agent,
- "attr_name": "procedural_memory_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "knowledge_memory_agent",
- "agent_type": AgentType.knowledge_memory_agent,
- "attr_name": "knowledge_memory_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "meta_memory_agent",
- "agent_type": AgentType.meta_memory_agent,
- "attr_name": "meta_memory_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "semantic_memory_agent",
- "agent_type": AgentType.semantic_memory_agent,
- "attr_name": "semantic_memory_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "core_memory_agent",
- "agent_type": AgentType.core_memory_agent,
- "attr_name": "core_memory_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "resource_memory_agent",
- "agent_type": AgentType.resource_memory_agent,
- "attr_name": "resource_memory_agent_state",
- "include_base_tools": False,
- },
- {
- "name": "chat_agent",
- "agent_type": None, # chat_agent doesn't use a specific agent_type
- "attr_name": "agent_state",
- "include_base_tools": True,
- },
-]
diff --git a/mirix/agent/message_queue.py b/mirix/agent/message_queue.py
deleted file mode 100644
index fd85a23e..00000000
--- a/mirix/agent/message_queue.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import threading
-import time
-import traceback
-import uuid
-from mirix.log import get_logger
-
-
-logger = get_logger(__name__)
-
-class MessageQueue:
- """
- Handles queueing and ordering of messages to different agent types.
- Ensures that messages of the same type are processed in order.
- """
-
- def __init__(self):
- self.message_queue = {}
- self._message_queue_lock = threading.Lock()
-
- def send_message_in_queue(self, client, agent_id, kwargs, agent_type="chat"):
- """
- Queue a message to be sent to a specific agent type.
-
- Args:
- client: The mirix client instance
- agent_id: The ID of the agent to send the message to
- kwargs: Arguments to pass to client.send_message
- agent_type: Type of agent to send message to
-
- Returns:
- Tuple of (response, agent_type)
- """
- message_uuid = uuid.uuid4()
-
- with self._message_queue_lock:
- self.message_queue[message_uuid] = {
- "kwargs": kwargs,
- "started": False,
- "finished": False,
- "type": agent_type,
- }
-
- # Wait for earlier requests of the same type to finish
- while not self._check_if_earlier_requests_are_finished(message_uuid):
- time.sleep(0.1)
-
- with self._message_queue_lock:
- self.message_queue[message_uuid]["started"] = True
-
- try:
- response = client.send_message(
- agent_id=agent_id,
- role="user",
- **self.message_queue[message_uuid]["kwargs"],
- )
- except Exception as e:
- logger.error("Error sending message: %s", e)
- logger.error(traceback.format_exc())
- logger.debug(
- "agent_type: ", agent_type, "gets error. agent_id: ", agent_id, "ERROR"
- )
- response = "ERROR"
-
- with self._message_queue_lock:
- self.message_queue[message_uuid]["finished"] = True
- del self.message_queue[message_uuid]
-
- return response, agent_type
-
- def _check_if_earlier_requests_are_finished(self, message_uuid):
- """Check if all earlier requests of the same type have finished."""
- with self._message_queue_lock:
- if message_uuid not in self.message_queue:
- raise ValueError("Message not found in the queue.")
-
- # Get current message type
- current_message_type = self.message_queue[message_uuid]["type"]
-
- # Find index of current message
- message_keys = list(self.message_queue.keys())
- idx = message_keys.index(message_uuid)
-
- # Check earlier messages of the same type
- for i in range(idx):
- earlier_message = self.message_queue[message_keys[i]]
- if earlier_message["type"] == current_message_type:
- if not earlier_message["finished"]:
- return False
-
- return True
-
- def _get_agent_id_for_type(self, agent_states, agent_type):
- """Get the agent ID for the specified agent type."""
- agent_type_to_state_mapping = {
- "chat": "agent_state",
- "episodic_memory": "episodic_memory_agent_state",
- "procedural_memory": "procedural_memory_agent_state",
- "knowledge": "knowledge_memory_agent_state",
- "meta_memory": "meta_memory_agent_state",
- "semantic_memory": "semantic_memory_agent_state",
- "core_memory": "core_memory_agent_state",
- "resource_memory": "resource_memory_agent_state",
- "meta_memory_agent": "meta_memory_agent_state", # Alias
- }
-
- state_name = agent_type_to_state_mapping.get(agent_type)
- if not state_name:
- raise ValueError(f"Unknown agent type: {agent_type}")
-
- if not hasattr(agent_states, state_name):
- raise ValueError(f"Agent state {state_name} not found")
-
- return getattr(agent_states, state_name).id
-
- def get_queue_length(self):
- """Get the current length of the message queue."""
- with self._message_queue_lock:
- return len(self.message_queue)
diff --git a/mirix/agent/meta_agent.py b/mirix/agent/meta_agent.py
index 44d6c043..ae70c4a8 100644
--- a/mirix/agent/meta_agent.py
+++ b/mirix/agent/meta_agent.py
@@ -11,7 +11,6 @@
from mirix import EmbeddingConfig, LLMConfig
from mirix.agent.agent import Agent, BaseAgent
-from mirix.agent.message_queue import MessageQueue
from mirix.interface import AgentInterface
from mirix.orm import User
from mirix.prompts import gpt_system
@@ -220,9 +219,6 @@ def __init__(
# Initialize container for memory agent states
self.memory_agent_states = MemoryAgentStates()
- # Initialize message queue for coordinating agent operations
- self.message_queue = MessageQueue()
-
# Initialize or load memory sub-agents
self._initialize_memory_agents()
@@ -424,56 +420,6 @@ def step(
else:
raise RuntimeError("No meta_memory_agent available for coordination")
- def send_message_to_agent(
- self, agent_name: str, message: Union[str, dict], **kwargs
- ) -> tuple:
- """
- Send a message to a specific memory agent through the message queue.
-
- Args:
- agent_name: Name of the agent to send message to
- message: Message content (string or dict)
- **kwargs: Additional arguments for message processing
-
- Returns:
- Tuple of (response, usage_statistics)
- """
- # Get agent state
- agent_state = self.memory_agent_states.get_agent_state(f"{agent_name}_state")
- if agent_state is None:
- raise ValueError(f"Agent state not found for: {agent_name}")
-
- # Determine agent type for message queue
- agent_type_map = {
- "episodic_memory_agent": "episodic_memory",
- "procedural_memory_agent": "procedural_memory",
- "knowledge_memory_agent": "knowledge",
- "meta_memory_agent": "meta_memory",
- "semantic_memory_agent": "semantic_memory",
- "core_memory_agent": "core_memory",
- "resource_memory_agent": "resource_memory",
- "reflexion_agent": "reflexion",
- "background_agent": "background",
- }
-
- agent_type = agent_type_map.get(agent_name, agent_name)
-
- # Format message
- if isinstance(message, str):
- message_data = {"message": message}
- else:
- message_data = message
-
- # Send through message queue
- response, usage = self.message_queue.send_message_in_queue(
- client=self.server, # Pass server directly
- agent_id=agent_state.id,
- message_data=message_data,
- agent_type=agent_type,
- **kwargs,
- )
-
- return response, usage
def update_llm_config(self, llm_config: LLMConfig):
"""
diff --git a/mirix/agent/temporary_message_accumulator.py b/mirix/agent/temporary_message_accumulator.py
deleted file mode 100644
index c09b786a..00000000
--- a/mirix/agent/temporary_message_accumulator.py
+++ /dev/null
@@ -1,915 +0,0 @@
-import copy
-import logging
-import os
-import threading
-import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from datetime import datetime, timedelta
-
-from tqdm import tqdm
-
-from mirix.agent.app_constants import (
- GEMINI_MODELS,
- SKIP_META_MEMORY_MANAGER,
- TEMPORARY_MESSAGE_LIMIT,
-)
-from mirix.agent.app_utils import encode_image
-from mirix.constants import CHAINING_FOR_MEMORY_UPDATE
-from mirix.voice_utils import convert_base64_to_audio_segment, process_voice_files
-
-
-def get_image_mime_type(image_path):
- """Get MIME type for image files."""
- if image_path.lower().endswith((".png", ".PNG")):
- return "image/png"
- elif image_path.lower().endswith((".jpg", ".jpeg", ".JPG", ".JPEG")):
- return "image/jpeg"
- elif image_path.lower().endswith((".gif", ".GIF")):
- return "image/gif"
- elif image_path.lower().endswith((".webp", ".WEBP")):
- return "image/webp"
- else:
- return "image/png" # Default fallback
-
-
-class TemporaryMessageAccumulator:
- """
- Handles accumulation and processing of temporary messages (screenshots, voice, text)
- for memory absorption into different agent types.
- """
-
- def __init__(
- self,
- client,
- google_client,
- timezone,
- upload_manager,
- message_queue,
- model_name,
- temporary_message_limit=TEMPORARY_MESSAGE_LIMIT,
- ):
- self.client = client
- self.google_client = google_client
- self.timezone = timezone
- self.upload_manager = upload_manager
- self.message_queue = message_queue
- self.model_name = model_name
- self.temporary_message_limit = temporary_message_limit
-
- # Initialize logger
- self.logger = logging.getLogger(
- f"Mirix.TemporaryMessageAccumulator.{model_name}"
- )
- self.logger.setLevel(logging.INFO)
-
- # Determine if this model needs file uploads
- self.needs_upload = model_name in GEMINI_MODELS
-
- # Initialize locks for thread safety
- self._temporary_messages_lock = threading.Lock()
-
- # Initialize temporary message storage
- self.temporary_messages = [] # Flat list of (timestamp, item) tuples
- self.temporary_user_messages = [[]] # List of batches
-
- # URI tracking for cloud files
- self.uri_to_create_time = {}
-
- # Upload tracking for cleanup
- self.upload_start_times = {} # Track when uploads started for cleanup purposes
-
- def add_message(
- self, full_message, timestamp, delete_after_upload=True, async_upload=True
- ):
- """Add a message to temporary storage."""
- if self.needs_upload and self.upload_manager is not None:
- if "image_uris" in full_message and full_message["image_uris"]:
- # Handle image uploads with optional sources information
- if async_upload:
- image_file_ref_placeholders = [
- self.upload_manager.upload_file_async(image_uri, timestamp)
- for image_uri in full_message["image_uris"]
- ]
- else:
- image_file_ref_placeholders = [
- self.upload_manager.upload_file(image_uri, timestamp)
- for image_uri in full_message["image_uris"]
- ]
- # Track upload start times for timeout detection
- current_time = time.time()
- for placeholder in image_file_ref_placeholders:
- if isinstance(placeholder, dict) and placeholder.get("pending"):
- placeholder_id = id(
- placeholder
- ) # Use object ID as unique identifier
- self.upload_start_times[placeholder_id] = current_time
- else:
- image_file_ref_placeholders = None
-
- if "voice_files" in full_message and full_message["voice_files"]:
- audio_segment = []
- for i, voice_file in enumerate(full_message["voice_files"]):
- converted_segment = convert_base64_to_audio_segment(voice_file)
- if converted_segment is not None:
- audio_segment.append(converted_segment)
- else:
- self.logger.error(
- f"❌ Error converting voice chunk {i + 1}/{len(full_message['voice_files'])} to AudioSegment"
- )
- continue
- audio_segment = None if len(audio_segment) == 0 else audio_segment
- if audio_segment:
- self.logger.info(
- f"✅ Successfully processed {len(audio_segment)} voice segments"
- )
- else:
- self.logger.info("❌ No voice segments were successfully processed")
- else:
- audio_segment = None
-
- with self._temporary_messages_lock:
- sources = full_message.get("sources")
- self.temporary_messages.append(
- (
- timestamp,
- {
- "image_uris": image_file_ref_placeholders,
- "sources": sources,
- "audio_segments": audio_segment,
- "message": full_message["message"],
- },
- )
- )
-
- if delete_after_upload and full_message["image_uris"]:
- threading.Thread(
- target=self._cleanup_file_after_upload,
- args=(full_message["image_uris"], image_file_ref_placeholders),
- daemon=True,
- ).start()
-
- else:
- with self._temporary_messages_lock:
- sources = full_message.get("sources")
- image_uris = full_message.get("image_uris", [])
- self.temporary_messages.append(
- (
- timestamp,
- {
- "image_uris": image_uris,
- "sources": sources,
- "audio_segments": full_message.get("voice_files", []),
- "message": full_message["message"],
- "delete_after_upload": delete_after_upload, # Store delete flag for OpenAI models
- },
- )
- )
-
- # # Print accumulation statistics
- # total_messages = len(self.temporary_messages)
- # total_images = sum(len(item.get('image_uris', []) or []) for _, item in self.temporary_messages)
- # total_voice_files = sum(len(item.get('audio_segments', []) or []) for _, item in self.temporary_messages)
-
- def add_user_conversation(self, user_message, assistant_response):
- """Add user conversation to temporary storage."""
- self.temporary_user_messages[-1].extend(
- [
- {"role": "user", "content": user_message},
- {"role": "assistant", "content": assistant_response},
- ]
- )
-
- def should_absorb_content(self):
- """Check if content should be absorbed into memory and return ready messages."""
-
- if self.needs_upload:
- with self._temporary_messages_lock:
- ready_messages = []
-
- # Process messages in temporal order
- for i, (timestamp, item) in enumerate(self.temporary_messages):
- item_copy = copy.deepcopy(item)
- has_pending_uploads = False
-
- # Check if this message has any pending uploads
- if "image_uris" in item and item["image_uris"]:
- processed_image_uris = []
- pending_count = 0
- completed_count = 0
-
- for j, file_ref in enumerate(item["image_uris"]):
- if isinstance(file_ref, dict) and file_ref.get("pending"):
- # Get upload status
- upload_status = self.upload_manager.get_upload_status(
- file_ref
- )
-
- if upload_status["status"] == "completed":
- # Upload completed, use the resolved reference
- processed_image_uris.append(upload_status["result"])
- completed_count += 1
- # Note: Don't clean up here, this is just a check
- elif upload_status["status"] == "failed":
- # Note: Don't clean up here, this is just a check
- continue
- elif upload_status["status"] == "unknown":
- # Upload was cleaned up, treat as failed
- continue
- else:
- # Still pending
- has_pending_uploads = True
- pending_count += 1
- break
- else:
- # Already uploaded file reference
- processed_image_uris.append(file_ref)
- completed_count += 1
-
- if has_pending_uploads:
- # Found a pending message - we must stop here to maintain temporal order
- # We cannot process any messages beyond this point
- break
- else:
- # Update the copy with resolved image URIs
- item_copy["image_uris"] = processed_image_uris
- ready_messages.append((timestamp, item_copy))
- else:
- # No images or already processed, add to ready list
- ready_messages.append((timestamp, item_copy))
-
- # Check if we have enough ready messages to process
- if len(ready_messages) >= self.temporary_message_limit:
- return ready_messages
- else:
- return []
- else:
- # For non-GEMINI models: no uploads needed, just check message count
- with self._temporary_messages_lock:
- # Since there are no pending uploads to wait for, all messages are ready
- if len(self.temporary_messages) >= self.temporary_message_limit:
- # Return all messages as ready for processing
- ready_messages = []
- for timestamp, item in self.temporary_messages:
- item_copy = copy.deepcopy(item)
- ready_messages.append((timestamp, item_copy))
- return ready_messages
- else:
- return []
-
- def get_recent_images_for_chat(self, current_timestamp):
- """Get the most recent images for chat context (non-blocking).
-
- Returns:
- List of tuples: (timestamp, file_ref, sources) where sources may be None
- """
- with self._temporary_messages_lock:
- # Get the most recent content
- recent_limit = min(
- self.temporary_message_limit, len(self.temporary_messages)
- )
- most_recent_content = (
- self.temporary_messages[-recent_limit:] if recent_limit > 0 else []
- )
-
- # Calculate timestamp cutoff (1 minute ago)
- cutoff_time = current_timestamp - timedelta(minutes=1)
-
- # Extract only images for the current message context
- most_recent_images = []
- for timestamp, item in most_recent_content:
- # Handle different timestamp formats that might be used
- if isinstance(timestamp, str):
- # Try to parse timestamp string and make it timezone-aware
- timestamp_dt = datetime.fromisoformat(
- timestamp.replace("Z", "+00:00")
- )
- # If timezone-naive, localize it to match the cutoff_time timezone awareness
- if timestamp_dt.tzinfo is None:
- timestamp_dt = self.timezone.localize(timestamp_dt)
- elif isinstance(timestamp, datetime):
- timestamp_dt = timestamp
- # If timezone-naive, localize it to match the cutoff_time timezone awareness
- if timestamp_dt.tzinfo is None:
- timestamp_dt = self.timezone.localize(timestamp_dt)
- elif isinstance(timestamp, (int, float)):
- # Unix timestamp - make it timezone-aware
- timestamp_dt = datetime.fromtimestamp(timestamp, tz=self.timezone)
- else:
- # Skip if we can't parse the timestamp
- continue
-
- # Check if timestamp is within the past 1 minute
- if timestamp_dt < cutoff_time:
- continue
-
- # Check if this item has images
- if "image_uris" in item and item["image_uris"]:
- for j, file_ref in enumerate(item["image_uris"]):
- if self.needs_upload and self.upload_manager is not None:
- # For GEMINI models: Resolve pending uploads for immediate use (non-blocking check)
- if isinstance(file_ref, dict) and file_ref.get("pending"):
- # Get upload status
- upload_status = self.upload_manager.get_upload_status(
- file_ref
- )
-
- if upload_status["status"] == "completed":
- file_ref = upload_status["result"]
- # Note: Don't clean up here, this is just for chat context
- elif upload_status["status"] == "failed":
- # Upload failed, skip this image
- # Note: Don't clean up here, this is just for chat context
- continue
- elif upload_status["status"] == "unknown":
- # Upload was cleaned up, treat as failed
- # Note: Don't clean up here, this is just for chat context
- continue
- else:
- continue # Still pending, skip
-
- # For non-GEMINI models: file_ref is already the image URI, use as-is
- # Include sources information if available
- sources = item.get("sources")
- most_recent_images.append(
- (timestamp, file_ref, sources[j] if sources else None)
- )
-
- return most_recent_images
-
- def absorb_content_into_memory(
- self, agent_states, ready_messages=None, user_id=None
- ):
- """Process accumulated content and send to memory agents."""
-
- if ready_messages is not None:
- # Use the pre-processed ready messages
- ready_to_process = ready_messages
-
- # Remove the processed messages from temporary_messages and clean up placeholders
- with self._temporary_messages_lock:
- # Remove processed messages from the beginning (they were processed in temporal order)
- num_to_remove = len(ready_messages)
-
- # Clean up placeholders from the messages being removed
- if self.needs_upload and self.upload_manager is not None:
- for i in range(min(num_to_remove, len(self.temporary_messages))):
- timestamp, item = self.temporary_messages[i]
- if "image_uris" in item and item["image_uris"]:
- for file_ref in item["image_uris"]:
- if isinstance(file_ref, dict) and file_ref.get(
- "pending"
- ):
- placeholder_id = id(file_ref)
- # Clean up upload manager status and local tracking
- self.upload_manager.cleanup_resolved_upload(
- file_ref
- )
- self.upload_start_times.pop(placeholder_id, None)
-
- self.temporary_messages = self.temporary_messages[num_to_remove:]
- else:
- # Use the existing logic to separate and process messages
- with self._temporary_messages_lock:
- # Separate uploaded images, pending images, and text content
- ready_to_process = [] # Items that are ready to be processed
- pending_items = [] # Items that need to stay for next cycle
-
- for timestamp, item in self.temporary_messages:
- item_copy = copy.deepcopy(item)
- has_pending_uploads = False
-
- # Process image URIs if they exist
- if "image_uris" in item and item["image_uris"]:
- processed_image_uris = []
- for file_ref in item["image_uris"]:
- if self.needs_upload and self.upload_manager is not None:
- # For GEMINI models: Check if this is a pending placeholder
- if isinstance(file_ref, dict) and file_ref.get(
- "pending"
- ):
- placeholder_id = id(file_ref)
- # Get upload status
- upload_status = (
- self.upload_manager.get_upload_status(file_ref)
- )
-
- if upload_status["status"] == "completed":
- # Upload completed, use the result
- processed_image_uris.append(
- upload_status["result"]
- )
- # Clean up both upload manager and local tracking
- self.upload_manager.cleanup_resolved_upload(
- file_ref
- )
- self.upload_start_times.pop(
- placeholder_id, None
- )
- elif upload_status["status"] == "failed":
- # Upload failed, skip this image but continue processing
- # Clean up both upload manager and local tracking
- self.upload_manager.cleanup_resolved_upload(
- file_ref
- )
- self.upload_start_times.pop(
- placeholder_id, None
- )
- continue
- elif upload_status["status"] == "unknown":
- # Upload was cleaned up, treat as failed
- logger.debug(
- "Skipping unknown/cleaned upload in absorb_content_into_memory"
- )
- # Only clean up local tracking since upload manager already cleaned up
- self.upload_start_times.pop(
- placeholder_id, None
- )
- continue
- else:
- # Still pending, keep original for next cycle
- has_pending_uploads = True
- break
- else:
- # Already uploaded file reference
- processed_image_uris.append(file_ref)
- else:
- # For non-GEMINI models: store the image URI directly for base64 conversion later
- processed_image_uris.append(file_ref)
-
- if has_pending_uploads:
- # Keep for next cycle if any uploads are still pending
- pending_items.append((timestamp, item))
- else:
- # All uploads completed, update the item
- item_copy["image_uris"] = processed_image_uris
- ready_to_process.append((timestamp, item_copy))
- else:
- # No images or already processed, add to ready list
- ready_to_process.append((timestamp, item_copy))
-
- # Keep only items that are still pending (for GEMINI models) or clear all (for non-GEMINI models)
- self.temporary_messages = pending_items
-
- # Extract voice content from ready_to_process messages
- voice_content = []
- for _, item in ready_to_process:
- if "audio_segments" in item and item["audio_segments"] is not None:
- # audio_segments can be a list of audio segments that can be directly combined
- voice_content.extend(item["audio_segments"])
-
- # Save voice content to folder if any exists
- if voice_content:
- current_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[
- :-3
- ] # Include milliseconds
- voice_folder = f"tmp_voice_content_{current_timestamp}"
-
- try:
- os.makedirs(voice_folder, exist_ok=True)
- logger.info("Created voice content folder: %s", voice_folder)
-
- for i, audio_segment in enumerate(voice_content):
- try:
- # Save audio segment to file
- if hasattr(audio_segment, "export"):
- # AudioSegment object
- filename = f"voice_segment_{i + 1:03d}.wav"
- filepath = os.path.join(voice_folder, filename)
- audio_segment.export(filepath, format="wav")
- self.logger.info(
- f"Saved voice segment {i + 1} to {filepath}"
- )
- else:
- # Handle other audio formats (e.g., raw bytes)
- filename = f"voice_segment_{i + 1:03d}.dat"
- filepath = os.path.join(voice_folder, filename)
- with open(filepath, "wb") as f:
- if isinstance(audio_segment, bytes):
- f.write(audio_segment)
- else:
- # Convert to bytes if needed
- f.write(str(audio_segment).encode())
- logger.info("Saved voice data %s to %s", i + 1, filepath)
- except Exception as e:
- logger.error("Failed to save voice segment %s: %s", i + 1, e)
-
- self.logger.info(
- f"Successfully saved {len(voice_content)} voice segments to {voice_folder}"
- )
- except Exception as e:
- self.logger.error(
- f"Failed to create voice content folder {voice_folder}: {e}"
- )
-
- # Process content and build message
- message = self._build_memory_message(ready_to_process, voice_content)
-
- # Handle user conversation if exists
- message, user_message_added = self._add_user_conversation_to_message(message)
-
- if SKIP_META_MEMORY_MANAGER:
- # Add system instruction
- if user_message_added:
- system_message = "[System Message] Interpret the provided content and the conversations between the user and the chat agent, according to what the user is doing, trigger the appropriate memory update."
- else:
- system_message = "[System Message] Interpret the provided content, according to what the user is doing, extract the important information matching your memory type and save it into the memory."
- else:
- # Add system instruction for meta memory manager
- if user_message_added:
- system_message = "[System Message] As the meta memory manager, analyze the provided content and the conversations between the user and the chat agent. Based on what the user is doing, determine which memory should be updated (episodic, procedural, knowledge, semantic, core, and resource)."
- else:
- system_message = "[System Message] As the meta memory manager, analyze the provided content. Based on the content, determine what memories need to be updated (episodic, procedural, knowledge, semantic, core, and resource)"
-
- message.append({"type": "text", "text": system_message})
-
- t1 = time.time()
- if SKIP_META_MEMORY_MANAGER:
- # Send to memory agents in parallel
- self._send_to_memory_agents_separately(
- message,
- set(list(self.uri_to_create_time.keys())),
- agent_states,
- user_id=user_id,
- )
- else:
- # Send to meta memory agent
- response, agent_type = self._send_to_meta_memory_agent(
- message,
- set(list(self.uri_to_create_time.keys())),
- agent_states,
- user_id=user_id,
- )
-
- t2 = time.time()
- logger.info("Time taken to send to memory agents: %s seconds", t2 - t1)
-
- # # write the logic to send the message to all the agents one by one
- # payloads = {
- # 'message': message,
- # 'chaining': CHAINING_FOR_MEMORY_UPDATE
- # }
-
- # for agent_type in ['episodic_memory', 'procedural_memory', 'knowledge',
- # 'semantic_memory', 'core_memory', 'resource_memory']:
- # self.message_queue.send_message_in_queue(
- # self.client,
- # agent_states,
- # payloads,
- # agent_type
- # )
-
- # Clean up processed content
- self._cleanup_processed_content(ready_to_process, user_message_added)
-
- def _build_memory_message(self, ready_to_process, voice_content):
- """Build the message content for memory agents."""
-
- # Collect content organized by source
- images_by_source = {} # source_name -> [(timestamp, file_refs)]
- text_content = []
- audio_content = []
-
- for timestamp, item in ready_to_process:
- # Handle images with sources
- if "image_uris" in item and item["image_uris"]:
- sources = item.get("sources", [])
- image_uris = item["image_uris"]
-
- # If we have sources, group images by source
- if sources and len(sources) == len(image_uris):
- for source, file_ref in zip(sources, image_uris):
- if source not in images_by_source:
- images_by_source[source] = []
- images_by_source[source].append((timestamp, file_ref))
- else:
- # Fallback: if no sources or mismatch, group under generic name
- generic_source = "Screenshots"
- if generic_source not in images_by_source:
- images_by_source[generic_source] = []
- for file_ref in image_uris:
- images_by_source[generic_source].append((timestamp, file_ref))
-
- # Handle text messages
- if "message" in item and item["message"]:
- text_content.append((timestamp, item["message"]))
-
- # Handle audio segments
- if "audio_segments" in item and item["audio_segments"]:
- audio_content.extend(item["audio_segments"])
-
- # Process voice files from both sources (voice_content and audio_segments)
- all_voice_content = voice_content.copy() if voice_content else []
- all_voice_content.extend(audio_content)
-
- voice_transcription = ""
- if all_voice_content:
- voice_transcription = process_voice_files(all_voice_content)
-
- # Build the structured message for memory agents
- message_parts = []
-
- # Add screenshots grouped by source
- if images_by_source:
- # Add general introductory text
- message_parts.append(
- {
- "type": "text",
- "text": "The following are the screenshots taken from the computer of the user:",
- }
- )
-
- # Group by source application
- for source_name, source_images in images_by_source.items():
- # Add source-specific header
- message_parts.append(
- {
- "type": "text",
- "text": f"These are the screenshots from {source_name}:",
- }
- )
-
- # Add each image with its timestamp
- for timestamp, file_ref in source_images:
- message_parts.append(
- {"type": "text", "text": f"Timestamp: {timestamp}"}
- )
-
- # Handle different types of file references
- if hasattr(file_ref, "uri"):
- # GEMINI models: use Google Cloud file URI
- message_parts.append(
- {
- "type": "google_cloud_file_uri",
- "google_cloud_file_uri": file_ref.uri,
- }
- )
- else:
- # OpenAI models: convert to base64
- try:
- mime_type = get_image_mime_type(file_ref)
- base64_data = encode_image(file_ref)
- message_parts.append(
- {
- "type": "image_data",
- "image_data": {
- "data": f"data:{mime_type};base64,{base64_data}",
- "detail": "auto",
- },
- }
- )
- except Exception as e:
- logger.error("Failed to encode image %s: %s", file_ref, e)
- # Add a text message indicating the image couldn't be processed
- message_parts.append(
- {
- "type": "text",
- "text": f"[Image at {file_ref} could not be processed]",
- }
- )
-
- # Add voice transcription if any
- if voice_transcription:
- message_parts.append(
- {
- "type": "text",
- "text": f"The following are the voice recordings and their transcriptions:\n{voice_transcription}",
- }
- )
-
- # Add text content if any
- if text_content:
- message_parts.append(
- {
- "type": "text",
- "text": "The following are text messages from the user:",
- }
- )
-
- for idx, (timestamp, text) in enumerate(text_content):
- message_parts.append(
- {"type": "text", "text": f"Timestamp: {timestamp} Text:\n{text}"}
- )
-
- return message_parts
-
- def _add_user_conversation_to_message(self, message):
- """Add user conversation to the message if it exists."""
- user_message_added = False
- if len(self.temporary_user_messages[-1]) > 0:
- user_conversation = "The following are the conversations between the user and the Chat Agent while capturing this content:\n"
- for idx, user_message in enumerate(self.temporary_user_messages[-1]):
- user_conversation += f"role: {user_message['role']}; content: {user_message['content']}\n"
- user_conversation = user_conversation.strip()
-
- message.append({"type": "text", "text": user_conversation})
-
- self.temporary_user_messages.append([])
- user_message_added = True
- return message, user_message_added
-
- def _send_to_meta_memory_agent(
- self, message, existing_file_uris, agent_states, user_id=None
- ):
- """Send the processed content to the meta memory agent."""
-
- payloads = {
- "message": message,
- "existing_file_uris": existing_file_uris,
- "chaining": CHAINING_FOR_MEMORY_UPDATE,
- "message_queue": self.message_queue,
- "user_id": user_id,
- }
-
- response, agent_type = self.message_queue.send_message_in_queue(
- self.client,
- agent_states.meta_memory_agent_state.id,
- payloads,
- "meta_memory",
- )
- return response, agent_type
-
- def _send_to_memory_agents_separately(
- self, message, existing_file_uris, agent_states, user_id=None
- ):
- """Send the processed content to all memory agents in parallel."""
-
- payloads = {
- "message": message,
- "existing_file_uris": existing_file_uris,
- "chaining": CHAINING_FOR_MEMORY_UPDATE,
- "user_id": user_id,
- }
-
- responses = []
- memory_agent_types = [
- "episodic_memory",
- "procedural_memory",
- "knowledge",
- "semantic_memory",
- "core_memory",
- "resource_memory",
- ]
-
- with ThreadPoolExecutor(max_workers=6) as pool:
- futures = [
- pool.submit(
- self.message_queue.send_message_in_queue,
- self.client,
- self.message_queue._get_agent_id_for_type(agent_states, agent_type),
- payloads,
- agent_type,
- )
- for agent_type in memory_agent_types
- ]
-
- for future in tqdm(as_completed(futures), total=len(futures)):
- response, agent_type = future.result()
- responses.append(response)
-
- def _cleanup_processed_content(self, ready_to_process, user_message_added):
- """Clean up processed content and mark files as processed."""
- # Mark processed files as processed in database and cleanup upload results (only for GEMINI models)
- if self.needs_upload and self.upload_manager is not None:
- for timestamp, item in ready_to_process:
- if "image_uris" in item and item["image_uris"]:
- for file_ref in item["image_uris"]:
- if hasattr(file_ref, "name"):
- try:
- self.client.server.cloud_file_mapping_manager.set_processed(
- cloud_file_id=file_ref.name
- )
- except Exception:
- pass
-
- # Clean up upload results from memory now that they've been processed
- # We need to track which placeholders were originally used to get these file_refs
- # Since we don't have direct access to the original placeholders, we'll rely on
- # the cleanup happening in the upload manager's periodic cleanup or
- # when the same placeholder is accessed again
- else:
- # For OpenAI models: Clean up image files if delete_after_upload is True
- for timestamp, item in ready_to_process:
- # Check if this item should have its files deleted
- should_delete = item.get(
- "delete_after_upload", True
- ) # Default to True for backward compatibility
-
- if should_delete and "image_uris" in item and item["image_uris"]:
- for image_uri in item["image_uris"]:
- # Only delete if it's a local file path (string)
- if isinstance(image_uri, str):
- self._delete_local_image_file(image_uri)
-
- # Clean up user messages if added
- if user_message_added:
- if len(self.temporary_user_messages) > 1:
- self.temporary_user_messages.pop(0)
-
- def _delete_local_image_file(self, image_path):
- """Delete a local image file with retry logic."""
- try:
- max_retries = 10
- retry_count = 0
- while retry_count < max_retries:
- try:
- if os.path.exists(image_path):
- os.remove(image_path)
- logger.debug("Deleted processed image file: %s", image_path)
- if not os.path.exists(image_path):
- break
- else:
- break # File doesn't exist, nothing to do
- except Exception as e:
- retry_count += 1
- if retry_count < max_retries:
- time.sleep(0.1)
- else:
- self.logger.warning(
- f"Failed to delete image file {image_path} after {max_retries} attempts: {e}"
- )
- except Exception as e:
- self.logger.error(
- f"Error while trying to delete image file {image_path}: {e}"
- )
-
- def _cleanup_file_after_upload(self, filenames, placeholders):
- """Clean up local file after upload completes."""
-
- if self.upload_manager is None:
- return # No upload manager for non-GEMINI models
-
- for filename, placeholder in zip(filenames, placeholders):
- placeholder_id = id(placeholder) if isinstance(placeholder, dict) else None
-
- try:
- # Wait for upload to complete with timeout
- upload_successful = self.upload_manager.wait_for_upload(
- placeholder, timeout=60
- )
-
- if upload_successful:
- # Clean up tracking
- if placeholder_id:
- self.upload_start_times.pop(placeholder_id, None)
- else:
- # Don't clean up tracking here, let the timeout detection handle it
- pass
-
- # Remove file after upload attempt (successful or not)
- max_retries = 10
- retry_count = 0
- while retry_count < max_retries:
- try:
- if os.path.exists(filename):
- os.remove(filename)
- logger.info("Removed file: %s", filename)
- if not os.path.exists(filename):
- break
- else:
- pass
- else:
- break
- except Exception:
- retry_count += 1
- if retry_count < max_retries:
- time.sleep(0.1)
- else:
- pass
-
- except Exception:
- # Still try to remove the local file
- try:
- if os.path.exists(filename):
- os.remove(filename)
- except Exception:
- pass
-
- def get_message_count(self):
- """Get the current count of temporary messages."""
- with self._temporary_messages_lock:
- return len(self.temporary_messages)
-
- def get_upload_status_summary(self):
- """Get a summary of current upload statuses for debugging."""
- summary = {
- "total_messages": len(self.temporary_messages),
- }
-
- # Get upload manager status if available
- if self.upload_manager and hasattr(
- self.upload_manager, "get_upload_status_summary"
- ):
- summary["upload_manager_status"] = (
- self.upload_manager.get_upload_status_summary()
- )
-
- return summary
-
- def update_model(self, new_model_name):
- """Update the model name and related settings."""
- self.model_name = new_model_name
- self.needs_upload = new_model_name in GEMINI_MODELS
- self.logger = logging.getLogger(
- f"Mirix.TemporaryMessageAccumulator.{new_model_name}"
- )
- self.logger.setLevel(logging.INFO)
diff --git a/mirix/agent/upload_manager.py b/mirix/agent/upload_manager.py
deleted file mode 100644
index 92bca16b..00000000
--- a/mirix/agent/upload_manager.py
+++ /dev/null
@@ -1,262 +0,0 @@
-import logging
-import os
-import threading
-import time
-import uuid
-from concurrent.futures import ThreadPoolExecutor
-
-from PIL import Image
-
-
-class UploadManager:
- """
- Simplified upload manager that handles each image upload independently.
- Each upload gets a 10-second timeout and either succeeds or fails immediately.
- """
-
- def __init__(self, google_client, client, existing_files, uri_to_create_time):
- self.google_client = google_client
- self.client = client
- self.existing_files = existing_files
- self.uri_to_create_time = uri_to_create_time
-
- # Initialize logger
- self.logger = logging.getLogger("Mirix.UploadManager")
- self.logger.setLevel(logging.INFO)
-
- # Simple tracking: upload_uuid -> {'status': 'pending'/'completed'/'failed', 'result': file_ref or None}
- self._upload_status = {}
- self._upload_lock = threading.Lock()
-
- # Thread pool for concurrent uploads (max 4 simultaneous uploads)
- self._executor = ThreadPoolExecutor(
- max_workers=4, thread_name_prefix="upload_worker"
- )
-
- def _compress_image(self, image_path, quality=85, max_size=(1920, 1080)):
- """Compress image to reduce upload time while maintaining reasonable quality"""
- try:
- with Image.open(image_path) as img:
- # Convert to RGB if necessary
- if img.mode in ("RGBA", "LA", "P"):
- img = img.convert("RGB")
-
- # Resize if too large
- img.thumbnail(max_size, Image.Resampling.LANCZOS)
-
- # Create compressed version
- base_path = os.path.splitext(image_path)[0]
- compressed_path = f"{base_path}_compressed.jpg"
- img.save(compressed_path, "JPEG", quality=quality, optimize=True)
-
- return compressed_path if os.path.exists(compressed_path) else None
-
- except Exception as e:
- logger.error("Image compression failed for %s: %s", image_path, e)
- return None
-
- def _upload_single_file(self, upload_uuid, filename, timestamp, compressed_file):
- """Upload a single file with 5-second timeout"""
- try:
- # Check if file already exists in cloud
- if self.client.server.cloud_file_mapping_manager.check_if_existing(
- local_file_id=filename
- ):
- cloud_file_name = (
- self.client.server.cloud_file_mapping_manager.get_cloud_file(
- local_file_id=filename
- )
- )
- file_ref = [
- x for x in self.existing_files if x.name == cloud_file_name
- ][0]
-
- with self._upload_lock:
- self._upload_status[upload_uuid] = {
- "status": "completed",
- "result": file_ref,
- }
- return
-
- # Choose file to upload (compressed if available, otherwise original)
- upload_file = (
- compressed_file
- if compressed_file and os.path.exists(compressed_file)
- else filename
- )
-
- # Upload with 5-second timeout
- upload_start_time = time.time()
- file_ref = self.google_client.files.upload(file=upload_file)
- upload_duration = time.time() - upload_start_time
-
- self.logger.info(
- f"Upload completed in {upload_duration:.2f} seconds for file {upload_file}"
- )
-
- # Update tracking and database
- self.uri_to_create_time[file_ref.uri] = {
- "create_time": file_ref.create_time,
- "filename": file_ref.name,
- }
- self.client.server.cloud_file_mapping_manager.add_mapping(
- local_file_id=filename,
- cloud_file_id=file_ref.uri,
- timestamp=timestamp,
- force_add=True,
- )
-
- # Clean up compressed file if it was created and used
- if (
- compressed_file
- and compressed_file != filename
- and upload_file == compressed_file
- ):
- try:
- os.remove(compressed_file)
- logger.info("Removed compressed file: %s", compressed_file)
- except Exception:
- pass # Ignore cleanup errors
-
- # Mark as completed
- with self._upload_lock:
- self._upload_status[upload_uuid] = {
- "status": "completed",
- "result": file_ref,
- }
-
- except Exception as e:
- logger.error("Upload failed for %s: %s", filename, e)
- # Mark as failed
- with self._upload_lock:
- self._upload_status[upload_uuid] = {"status": "failed", "result": None}
-
- # Clean up compressed file on failure too
- if (
- compressed_file
- and compressed_file != filename
- and os.path.exists(compressed_file)
- ):
- try:
- os.remove(compressed_file)
- except Exception:
- pass
-
- def upload_file_async(self, filename, timestamp, compress=True):
- """Start an async upload and return immediately with a placeholder"""
- upload_uuid = str(uuid.uuid4())
-
- # Compress image if requested
- compressed_file = None
- if compress and filename.lower().endswith((".png", ".jpg", ".jpeg")):
- compressed_file = self._compress_image(filename)
-
- # Initialize status
- with self._upload_lock:
- self._upload_status[upload_uuid] = {"status": "pending", "result": None}
-
- # Submit upload task with 5-second timeout
- future = self._executor.submit(
- self._upload_single_file, upload_uuid, filename, timestamp, compressed_file
- )
-
- # Set up automatic timeout handling
- def timeout_handler():
- time.sleep(10.0) # Wait 10 seconds
- with self._upload_lock:
- if self._upload_status.get(upload_uuid, {}).get("status") == "pending":
- self.logger.info(
- f"Upload timeout (5s) for {filename}, marking as failed"
- )
- self._upload_status[upload_uuid] = {
- "status": "failed",
- "result": None,
- }
- future.cancel() # Try to cancel the upload
-
- # Start timeout handler in separate thread
- timeout_thread = threading.Thread(target=timeout_handler, daemon=True)
- timeout_thread.start()
-
- # Return placeholder
- return {"upload_uuid": upload_uuid, "filename": filename, "pending": True}
-
- def get_upload_status(self, placeholder):
- """Get upload status and result in one call"""
- if not isinstance(placeholder, dict) or not placeholder.get("pending"):
- return {"status": "completed", "result": placeholder} # Already resolved
-
- upload_uuid = placeholder["upload_uuid"]
-
- with self._upload_lock:
- if upload_uuid not in self._upload_status:
- # Upload was either never started or already cleaned up
- # For cleaned up uploads, we can't tell if they succeeded or failed
- return {"status": "unknown", "result": None}
-
- status_info = self._upload_status.get(upload_uuid, {})
- status = status_info.get("status", "pending")
- result = status_info.get("result")
-
- # Don't clean up here - let cleanup_resolved_upload handle it
- return {"status": status, "result": result}
-
- def try_resolve_upload(self, placeholder):
- """Legacy method for backward compatibility"""
- status_info = self.get_upload_status(placeholder)
- if status_info["status"] == "completed":
- return status_info["result"]
- else:
- return None
-
- def wait_for_upload(self, placeholder, timeout=30):
- """Wait for upload to complete (legacy method, now just polls get_upload_status)"""
- if not isinstance(placeholder, dict) or not placeholder.get("pending"):
- return placeholder
-
- start_time = time.time()
- while time.time() - start_time < timeout:
- upload_status = self.get_upload_status(placeholder)
-
- if upload_status["status"] == "completed":
- return upload_status["result"]
- elif upload_status["status"] == "failed":
- raise Exception(f"Upload failed for {placeholder['filename']}")
-
- time.sleep(0.1)
-
- raise TimeoutError(
- f"Upload timeout after {timeout}s for {placeholder['filename']}"
- )
-
- def upload_file(self, filename, timestamp):
- """Legacy synchronous upload method"""
- placeholder = self.upload_file_async(filename, timestamp)
- return self.wait_for_upload(
- placeholder, timeout=10
- ) # Reduced timeout since individual uploads timeout at 5s
-
- def cleanup_resolved_upload(self, placeholder):
- """Clean up resolved upload from tracking"""
- if not isinstance(placeholder, dict) or not placeholder.get("pending"):
- return # Not a pending placeholder
-
- upload_uuid = placeholder["upload_uuid"]
- with self._upload_lock:
- self._upload_status.pop(upload_uuid, None)
-
- def cleanup_upload_workers(self):
- """Gracefully shut down the thread pool"""
- try:
- self._executor.shutdown(wait=True, timeout=10)
- except Exception:
- pass # Ignore shutdown errors
-
- def get_upload_status_summary(self):
- """Get a summary of current upload statuses (for debugging)"""
- with self._upload_lock:
- summary = {}
- for uuid, info in self._upload_status.items():
- status = info.get("status", "unknown")
- summary[status] = summary.get(status, 0) + 1
- return summary
diff --git a/mirix/client/remote_client.py b/mirix/client/remote_client.py
index bba09d90..1da25165 100644
--- a/mirix/client/remote_client.py
+++ b/mirix/client/remote_client.py
@@ -16,33 +16,165 @@
from mirix.client.client import AbstractClient
from mirix.constants import FUNCTION_RETURN_CHAR_LIMIT
from mirix.log import get_logger
-from mirix.schemas.agent import AgentState, AgentType, CreateAgent, CreateMetaAgent
-from mirix.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona
+from mirix.schemas.agent import AgentState, AgentType
+from mirix.schemas.block import Block, Human, Persona
from mirix.schemas.embedding_config import EmbeddingConfig
from mirix.schemas.environment_variables import (
SandboxEnvironmentVariable,
- SandboxEnvironmentVariableCreate,
- SandboxEnvironmentVariableUpdate,
)
-from mirix.schemas.file import FileMetadata
from mirix.schemas.llm_config import LLMConfig
from mirix.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
-from mirix.schemas.message import Message, MessageCreate
+from mirix.schemas.message import Message
from mirix.schemas.mirix_response import MirixResponse
from mirix.schemas.organization import Organization
from mirix.schemas.sandbox_config import (
E2BSandboxConfig,
LocalSandboxConfig,
SandboxConfig,
- SandboxConfigCreate,
- SandboxConfigUpdate,
)
-from mirix.schemas.tool import Tool, ToolCreate, ToolUpdate
+from mirix.schemas.tool import Tool
from mirix.schemas.tool_rule import BaseToolRule
logger = get_logger(__name__)
+# Default configurations for different LLM providers
+PROVIDER_DEFAULTS = {
+ "openai": {
+ "llm_config": {
+ "model": "gpt-4o-mini",
+ "model_endpoint_type": "openai",
+ "model_endpoint": "https://api.openai.com/v1",
+ "context_window": 128000,
+ },
+ "topic_extraction_llm_config": {
+ "model": "gpt-4.1-nano",
+ "model_endpoint_type": "openai",
+ "model_endpoint": "https://api.openai.com/v1",
+ "context_window": 128000,
+ },
+ "build_embeddings_for_memory": True,
+ "embedding_config": {
+ "embedding_model": "text-embedding-3-small",
+ "embedding_endpoint": "https://api.openai.com/v1",
+ "embedding_endpoint_type": "openai",
+ "embedding_dim": 1536,
+ },
+ },
+ "anthropic": {
+ "llm_config": {
+ "model": "claude-sonnet-4-20250514",
+ "model_endpoint_type": "anthropic",
+ "model_endpoint": "https://api.anthropic.com/v1",
+ "context_window": 200000,
+ },
+ "topic_extraction_llm_config": {
+ "model": "claude-sonnet-4-20250514",
+ "model_endpoint_type": "anthropic",
+ "model_endpoint": "https://api.anthropic.com/v1",
+ "context_window": 200000,
+ },
+ # Anthropic doesn't provide embeddings, use without embeddings
+ "build_embeddings_for_memory": False,
+ },
+ "google_ai": {
+ "llm_config": {
+ "model": "gemini-2.0-flash",
+ "model_endpoint_type": "google_ai",
+ "model_endpoint": "https://generativelanguage.googleapis.com",
+ "context_window": 1000000,
+ },
+ "topic_extraction_llm_config": {
+ "model": "gemini-2.0-flash-lite",
+ "model_endpoint_type": "google_ai",
+ "model_endpoint": "https://generativelanguage.googleapis.com",
+ "context_window": 1000000,
+ },
+ "build_embeddings_for_memory": True,
+ "embedding_config": {
+ "embedding_model": "text-embedding-004",
+ "embedding_endpoint_type": "google_ai",
+ "embedding_endpoint": "https://generativelanguage.googleapis.com",
+ "embedding_dim": 768,
+ },
+ },
+}
+
+# Default meta agent configuration shared across all providers
+DEFAULT_META_AGENT_CONFIG = {
+ "system_prompts_folder": "mirix/prompts/system/base",
+ "agents": [
+ "core_memory_agent",
+ "resource_memory_agent",
+ "semantic_memory_agent",
+ "episodic_memory_agent",
+ "procedural_memory_agent",
+ "knowledge_memory_agent",
+ "reflexion_agent",
+ "background_agent",
+ ],
+ "memory": {
+ "core": [
+ {"label": "human", "value": ""},
+ {"label": "persona", "value": "I am a helpful assistant."},
+ ],
+ "decay": {
+ "fade_after_days": 30,
+ "expire_after_days": 90,
+ },
+ },
+}
+
+
+def _get_provider_config(
+ provider: str,
+ api_key: Optional[str] = None,
+ model: Optional[str] = None,
+) -> Dict[str, Any]:
+ """
+ Generate a configuration dictionary for a specific LLM provider.
+
+ Args:
+ provider: The LLM provider name ("openai", "anthropic", "google_ai")
+ api_key: Optional API key for the provider. If not provided, the config
+ won't include api_key (user can set via environment variables).
+ model: Optional model name to override the default
+
+ Returns:
+ A complete configuration dictionary ready for initialize_meta_agent
+ """
+ import copy
+
+ provider = provider.lower()
+ if provider not in PROVIDER_DEFAULTS:
+ raise ValueError(
+ f"Unknown provider '{provider}'. "
+ f"Supported providers: {list(PROVIDER_DEFAULTS.keys())}"
+ )
+
+ # Deep copy to avoid modifying the defaults
+ config = copy.deepcopy(PROVIDER_DEFAULTS[provider])
+
+ # Set API key for LLM configs (only if provided)
+ if api_key:
+ config["llm_config"]["api_key"] = api_key
+ config["topic_extraction_llm_config"]["api_key"] = api_key
+
+ # Override model if specified
+ if model:
+ config["llm_config"]["model"] = model
+ config["topic_extraction_llm_config"]["model"] = model
+
+ # Set embedding API key (only for providers that use embeddings)
+ if api_key and config.get("build_embeddings_for_memory") and "embedding_config" in config:
+ config["embedding_config"]["api_key"] = api_key
+
+ # Add meta agent config
+ config["meta_agent_config"] = copy.deepcopy(DEFAULT_META_AGENT_CONFIG)
+
+ return config
+
+
def _validate_occurred_at(occurred_at: Optional[str]) -> Optional[datetime]:
"""
Validate occurred_at format and convert to datetime.
@@ -203,7 +335,7 @@ def _ensure_org_and_client_exist(self, headers: Optional[Dict[str, str]] = None)
"""
try:
# Create or get organization first
- org_response = self._request(
+ self._request(
"POST",
"/organizations/create_or_get",
json={"org_id": self.org_id, "name": self.org_name},
@@ -217,7 +349,7 @@ def _ensure_org_and_client_exist(self, headers: Optional[Dict[str, str]] = None)
)
# Create or get client
- client_response = self._request(
+ self._request(
"POST",
"/clients/create_or_get",
json={
@@ -390,7 +522,7 @@ def _request(
# Try to extract error message from response
try:
error_detail = response.json().get("detail", str(e))
- except:
+ except Exception:
error_detail = str(e)
raise requests.HTTPError(f"API request failed: {error_detail}") from e
@@ -1192,8 +1324,10 @@ def _load_system_prompts(self, config: Dict[str, Any]) -> Dict[str, str]:
def initialize_meta_agent(
self,
+ provider: Optional[str] = None,
+ api_key: Optional[str] = None,
+ model: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
- config_path: Optional[str] = None,
update_agents: Optional[bool] = False,
headers: Optional[Dict[str, str]] = None,
) -> AgentState:
@@ -1203,15 +1337,51 @@ def initialize_meta_agent(
This creates a meta memory agent that manages multiple specialized memory agents
(episodic, semantic, procedural, etc.) for the current project.
+ There are two ways to provide configuration (in order of precedence):
+ 1. provider (+ optional api_key): Use default config for the specified provider
+ 2. config: Provide a complete configuration dictionary
+
Args:
+ provider: LLM provider name ("openai", "anthropic", "google_ai").
+ Uses default configuration for that provider.
+ api_key: Optional API key for the LLM provider. If not provided,
+ you can set it via environment variables (e.g., OPENAI_API_KEY).
+ model: Optional model name to override the provider's default model
config: Configuration dictionary with llm_config, embedding_config, etc.
- config_path: Path to YAML config file (alternative to config dict)
+ update_agents: Whether to update existing agents
+ headers: Optional HTTP headers for the request
Returns:
AgentState: The initialized meta agent
- Example:
+ Examples:
+ # Simplest setup - just provider (api_key from environment variable)
>>> client = MirixClient(api_key="your-api-key")
+ >>> meta_agent = client.initialize_meta_agent(provider="openai")
+
+ # With explicit API key
+ >>> meta_agent = client.initialize_meta_agent(
+ ... provider="openai",
+ ... api_key="sk-proj-xxx"
+ ... )
+
+ # With custom model
+ >>> meta_agent = client.initialize_meta_agent(
+ ... provider="openai",
+ ... api_key="sk-proj-xxx",
+ ... model="gpt-4o"
+ ... )
+
+ # Using Anthropic (no embeddings)
+ >>> meta_agent = client.initialize_meta_agent(
+ ... provider="anthropic",
+ ... api_key="sk-ant-xxx"
+ ... )
+
+ # Using Google AI (Gemini models)
+ >>> meta_agent = client.initialize_meta_agent(provider="google_ai")
+
+ # Using a config dictionary
>>> config = {
... "llm_config": {"model": "gemini-2.0-flash"},
... "embedding_config": {"model": "text-embedding-004"}
@@ -1219,19 +1389,20 @@ def initialize_meta_agent(
>>> meta_agent = client.initialize_meta_agent(config=config)
"""
- # Load config from file if provided
- if config_path:
- from pathlib import Path
-
- import yaml
-
- config_file = Path(config_path)
- if config_file.exists():
- with open(config_file, "r") as f:
- config = yaml.safe_load(f)
+ # Option 1: Generate config from provider (api_key is optional)
+ if provider:
+ config = _get_provider_config(
+ provider=provider,
+ api_key=api_key,
+ model=model,
+ )
if not config:
- raise ValueError("Either config or config_path must be provided")
+ raise ValueError(
+ "Configuration required. Provide one of:\n"
+ " - provider for quick setup (e.g., provider='openai')\n"
+ " - config for a configuration dictionary"
+ )
# Load system prompts from folder if specified and not already provided
if (
@@ -1261,7 +1432,7 @@ def add(
self,
messages: List[Dict[str, Any]],
user_id: Optional[str] = None,
- chaining: bool = True,
+ chaining: bool = False,
verbose: bool = False,
filter_tags: Optional[Dict[str, Any]] = None,
use_cache: bool = True,
diff --git a/mirix/client/utils.py b/mirix/client/utils.py
index ca59ca33..3a2602a5 100644
--- a/mirix/client/utils.py
+++ b/mirix/client/utils.py
@@ -9,11 +9,6 @@
from datetime import datetime, timezone
-def get_utc_time() -> datetime:
- """Get the current UTC time"""
- return datetime.now(timezone.utc)
-
-
def json_dumps(data, indent=2):
"""
JSON serializer that handles datetime objects.
diff --git a/mirix/configs/examples/mirix_azure.yaml b/mirix/configs/examples/mirix_azure.yaml
index f900f5da..e354dd53 100644
--- a/mirix/configs/examples/mirix_azure.yaml
+++ b/mirix/configs/examples/mirix_azure.yaml
@@ -11,6 +11,27 @@ llm_config:
context_window: 128000
# api_key: "your-azure-key" # Or use AZURE_OPENAI_API_KEY env var
+# Optional: separate model for topic extraction
+topic_extraction_llm_config:
+ model: "gpt-4o-mini"
+ model_endpoint_type: "azure_openai"
+ model_endpoint: "https://your-resource.openai.azure.com/"
+ azure_endpoint: "https://your-resource.openai.azure.com/"
+ azure_deployment: "gpt-4o-mini" # Your deployment name
+ api_version: "2024-10-01-preview"
+ context_window: 128000
+ # Local OpenAI-compatible (LM Studio / vLLM):
+ # model: "your-local-model"
+ # model_endpoint_type: "openai"
+ # model_endpoint: "http://localhost:1234/v1"
+ # is_local_model: true
+ # context_window: 8192
+ # Ollama (local):
+ # model: "llama3.2"
+ # model_endpoint_type: "ollama"
+ # model_endpoint: "http://localhost:11434"
+ # context_window: 8192
+
build_embeddings_for_memory: true
embedding_config:
diff --git a/mirix/configs/examples/mirix_claude.yaml b/mirix/configs/examples/mirix_claude.yaml
index 0d1ce19a..4730b1b7 100644
--- a/mirix/configs/examples/mirix_claude.yaml
+++ b/mirix/configs/examples/mirix_claude.yaml
@@ -8,6 +8,24 @@ llm_config:
context_window: 200000
# api_key: "your-anthropic-key" # Or use ANTHROPIC_API_KEY env var
+# Optional: separate model for topic extraction
+topic_extraction_llm_config:
+ model: "claude-3-5-sonnet-20241022"
+ model_endpoint_type: "anthropic"
+ model_endpoint: "https://api.anthropic.com/v1"
+ context_window: 200000
+ # Local OpenAI-compatible (LM Studio / vLLM):
+ # model: "your-local-model"
+ # model_endpoint_type: "openai"
+ # model_endpoint: "http://localhost:1234/v1"
+ # is_local_model: true
+ # context_window: 8192
+ # Ollama (local):
+ # model: "llama3.2"
+ # model_endpoint_type: "ollama"
+ # model_endpoint: "http://localhost:11434"
+ # context_window: 8192
+
build_embeddings_for_memory: true
# Note: Anthropic doesn't provide embeddings, so we use OpenAI or Google
diff --git a/mirix/configs/examples/mirix_gemini.yaml b/mirix/configs/examples/mirix_gemini.yaml
index f2c4db7f..dbe2e25d 100644
--- a/mirix/configs/examples/mirix_gemini.yaml
+++ b/mirix/configs/examples/mirix_gemini.yaml
@@ -7,6 +7,24 @@ llm_config:
model_endpoint: "https://generativelanguage.googleapis.com"
context_window: 1000000
+# Optional: separate model for topic extraction
+topic_extraction_llm_config:
+ model: "gemini-2.0-flash"
+ model_endpoint_type: "google_ai"
+ model_endpoint: "https://generativelanguage.googleapis.com"
+ context_window: 1000000
+ # Local OpenAI-compatible (LM Studio / vLLM):
+ # model: "your-local-model"
+ # model_endpoint_type: "openai"
+ # model_endpoint: "http://localhost:1234/v1"
+ # is_local_model: true
+ # context_window: 8192
+ # Ollama (local):
+ # model: "llama3.2"
+ # model_endpoint_type: "ollama"
+ # model_endpoint: "http://localhost:11434"
+ # context_window: 8192
+
build_embeddings_for_memory: true
embedding_config:
diff --git a/mirix/configs/examples/mirix_gemini_single_agent.yaml b/mirix/configs/examples/mirix_gemini_single_agent.yaml
index 7ed0cd6a..d70121b3 100644
--- a/mirix/configs/examples/mirix_gemini_single_agent.yaml
+++ b/mirix/configs/examples/mirix_gemini_single_agent.yaml
@@ -7,6 +7,24 @@ llm_config:
model_endpoint: "https://generativelanguage.googleapis.com"
context_window: 1000000
+# Optional: separate model for topic extraction
+topic_extraction_llm_config:
+ model: "gemini-2.0-flash"
+ model_endpoint_type: "google_ai"
+ model_endpoint: "https://generativelanguage.googleapis.com"
+ context_window: 1000000
+ # Local OpenAI-compatible (LM Studio / vLLM):
+ # model: "your-local-model"
+ # model_endpoint_type: "openai"
+ # model_endpoint: "http://localhost:1234/v1"
+ # is_local_model: true
+ # context_window: 8192
+ # Ollama (local):
+ # model: "llama3.2"
+ # model_endpoint_type: "ollama"
+ # model_endpoint: "http://localhost:11434"
+ # context_window: 8192
+
build_embeddings_for_memory: true
embedding_config:
diff --git a/mirix/configs/examples/mirix_openai.yaml b/mirix/configs/examples/mirix_openai.yaml
index 1864e3cd..3ae3ceef 100644
--- a/mirix/configs/examples/mirix_openai.yaml
+++ b/mirix/configs/examples/mirix_openai.yaml
@@ -4,14 +4,38 @@
llm_config:
model: "gpt-4o-mini"
model_endpoint_type: "openai"
+ api_key: "sk-proj-xxx"
model_endpoint: "https://api.openai.com/v1"
context_window: 128000
+# Optional: separate model for topic extraction
+topic_extraction_llm_config:
+ model: "gpt-4.1-nano"
+ model_endpoint_type: "openai"
+ api_key: "sk-proj-xxx"
+ model_endpoint: "https://api.openai.com/v1"
+ context_window: 128000
+ # is_local_model: false
+
+ # Local OpenAI-compatible (LM Studio / vLLM):
+ # model: "your-local-model"
+ # model_endpoint_type: "openai"
+ # model_endpoint: "http://localhost:1234/v1"
+ # is_local_model: true
+ # context_window: 8192
+
+ # Ollama (local):
+ # model: "llama3.2"
+ # model_endpoint_type: "ollama"
+ # model_endpoint: "http://localhost:11434"
+ # context_window: 8192
+
build_embeddings_for_memory: true
embedding_config:
embedding_model: "text-embedding-3-small"
embedding_endpoint: "https://api.openai.com/v1"
+ api_key: "sk-proj-xxx"
embedding_endpoint_type: "openai"
embedding_dim: 1536
diff --git a/mirix/configs/examples/mirix_openai_single_agent.yaml b/mirix/configs/examples/mirix_openai_single_agent.yaml
index f601e0c2..364ddada 100644
--- a/mirix/configs/examples/mirix_openai_single_agent.yaml
+++ b/mirix/configs/examples/mirix_openai_single_agent.yaml
@@ -7,6 +7,25 @@ llm_config:
model_endpoint: "https://api.openai.com/v1"
context_window: 128000
+# Optional: separate model for topic extraction
+topic_extraction_llm_config:
+ model: "gpt-4o-mini"
+ model_endpoint_type: "openai"
+ model_endpoint: "https://api.openai.com/v1"
+ context_window: 128000
+ # is_local_model: false
+ # Local OpenAI-compatible (LM Studio / vLLM):
+ # model: "your-local-model"
+ # model_endpoint_type: "openai"
+ # model_endpoint: "http://localhost:1234/v1"
+ # is_local_model: true
+ # context_window: 8192
+ # Ollama (local):
+ # model: "llama3.2"
+ # model_endpoint_type: "ollama"
+ # model_endpoint: "http://localhost:11434"
+ # context_window: 8192
+
build_embeddings_for_memory: true
embedding_config:
diff --git a/mirix/database/redis_client.py b/mirix/database/redis_client.py
index 12849489..8645b40a 100644
--- a/mirix/database/redis_client.py
+++ b/mirix/database/redis_client.py
@@ -251,7 +251,7 @@ def _create_block_index(self) -> None:
self.client.ft(self.BLOCK_INDEX).info()
logger.debug("Index %s already exists", self.BLOCK_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -286,7 +286,7 @@ def _create_message_index(self) -> None:
self.client.ft(self.MESSAGE_INDEX).info()
logger.debug("Index %s already exists", self.MESSAGE_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -322,7 +322,7 @@ def _create_organization_index(self) -> None:
self.client.ft(self.ORGANIZATION_INDEX).info()
logger.debug("Index %s already exists", self.ORGANIZATION_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -353,7 +353,7 @@ def _create_user_index(self) -> None:
self.client.ft(self.USER_INDEX).info()
logger.debug("Index %s already exists", self.USER_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -389,7 +389,7 @@ def _create_agent_index(self) -> None:
self.client.ft(self.AGENT_INDEX).info()
logger.debug("Index %s already exists", self.AGENT_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -427,7 +427,7 @@ def _create_tool_index(self) -> None:
self.client.ft(self.TOOL_INDEX).info()
logger.debug("Index %s already exists", self.TOOL_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -594,7 +594,7 @@ def _create_episodic_index(self) -> None:
self.client.ft(self.EPISODIC_INDEX).info()
logger.debug("Index %s already exists", self.EPISODIC_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -655,7 +655,7 @@ def _create_semantic_index(self) -> None:
self.client.ft(self.SEMANTIC_INDEX).info()
logger.debug("Index %s already exists", self.SEMANTIC_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -726,7 +726,7 @@ def _create_procedural_index(self) -> None:
self.client.ft(self.PROCEDURAL_INDEX).info()
logger.debug("Index %s already exists", self.PROCEDURAL_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -785,7 +785,7 @@ def _create_resource_index(self) -> None:
self.client.ft(self.RESOURCE_INDEX).info()
logger.debug("Index %s already exists", self.RESOURCE_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -835,7 +835,7 @@ def _create_knowledge_index(self) -> None:
self.client.ft(self.KNOWLEDGE_INDEX).info()
logger.debug("Index %s already exists", self.KNOWLEDGE_INDEX)
return
- except:
+ except Exception:
pass
schema = (
@@ -1015,7 +1015,6 @@ def search_text(
"""
try:
from redis.commands.search.query import Query
- from datetime import datetime
import re
# Escape special characters in query for Redis Search
@@ -1141,7 +1140,6 @@ def search_vector(
"""
try:
from redis.commands.search.query import Query
- from datetime import datetime
import numpy as np
# Convert embedding to bytes
@@ -1271,7 +1269,6 @@ def search_recent(
"""
try:
from redis.commands.search.query import Query
- from datetime import datetime
# Build query parts
query_parts = []
@@ -1367,7 +1364,6 @@ def search_recent_by_org(
"""
try:
from redis.commands.search.query import Query
- from datetime import datetime
import re
def escape_text_value(value: str) -> str:
@@ -1454,7 +1450,6 @@ def search_vector_by_org(
"""
try:
from redis.commands.search.query import Query
- from datetime import datetime
import re
def escape_text_value(value: str) -> str:
@@ -1548,7 +1543,6 @@ def search_text_by_org(
"""
try:
from redis.commands.search.query import Query
- from datetime import datetime
import re
def escape_text_value(value: str) -> str:
diff --git a/mirix/errors.py b/mirix/errors.py
index 7c87ab9f..40ff87ab 100755
--- a/mirix/errors.py
+++ b/mirix/errors.py
@@ -229,14 +229,3 @@ class InvalidToolCallError(MirixMessageError):
"The message uses an invalid tool call or has improper usage of a tool call."
)
-
-class MissingInnerMonologueError(MirixMessageError):
- """Error raised when a message is missing an inner monologue."""
-
- default_error_message = "The message is missing an inner monologue."
-
-
-class InvalidInnerMonologueError(MirixMessageError):
- """Error raised when a message has a malformed inner monologue."""
-
- default_error_message = "The message has a malformed inner monologue."
diff --git a/mirix/functions/function_sets/memory_tools.py b/mirix/functions/function_sets/memory_tools.py
index ff30f6b0..e3cc6892 100644
--- a/mirix/functions/function_sets/memory_tools.py
+++ b/mirix/functions/function_sets/memory_tools.py
@@ -116,7 +116,6 @@ def episodic_memory_insert(self: "Agent", items: List[EpisodicEventForLLM]):
# Get filter_tags, use_cache, client_id, user_id, and occurred_at from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
occurred_at_override = getattr(self, 'occurred_at', None) # Optional timestamp override from API
@@ -204,7 +203,6 @@ def episodic_memory_replace(
# Get filter_tags, use_cache, client_id, user_id, and occurred_at from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
occurred_at_override = getattr(self, 'occurred_at', None) # Optional timestamp override from API
user_id = self.user_id
@@ -301,7 +299,6 @@ def resource_memory_insert(self: "Agent", items: List[ResourceMemoryItemBase]):
# Get filter_tags, use_cache, client_id, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
inserted_count = 0
@@ -377,7 +374,6 @@ def resource_memory_update(
# Get filter_tags, use_cache, client_id, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
for old_id in old_ids:
@@ -420,7 +416,6 @@ def procedural_memory_insert(self: "Agent", items: List[ProceduralMemoryItemBase
# Get filter_tags, use_cache, client_id, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
inserted_count = 0
@@ -497,7 +492,6 @@ def procedural_memory_update(
# Get filter_tags, use_cache, client_id, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
for old_id in old_ids:
@@ -572,7 +566,6 @@ def semantic_memory_insert(self: "Agent", items: List[SemanticMemoryItemBase]):
# Get filter_tags, use_cache, client_id, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
inserted_count = 0
@@ -653,7 +646,6 @@ def semantic_memory_update(
# Get filter_tags, use_cache, client_id, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
for old_id in old_semantic_item_ids:
@@ -702,10 +694,9 @@ def knowledge_insert(self: "Agent", items: List[KnowledgeItemBase]):
else self.agent_state.id
)
- # Get filter_tags, use_cache, client_id, and user_id from agent instance
+ # Get filter_tags, use_cache, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
inserted_count = 0
@@ -782,10 +773,9 @@ def knowledge_update(
else self.agent_state.id
)
- # Get filter_tags, use_cache, client_id, and user_id from agent instance
+ # Get filter_tags, use_cache, and user_id from agent instance
filter_tags = getattr(self, 'filter_tags', None)
use_cache = getattr(self, 'use_cache', True)
- client_id = getattr(self, 'client_id', None)
user_id = getattr(self, 'user_id', None)
for old_id in old_ids:
diff --git a/mirix/functions/helpers.py b/mirix/functions/helpers.py
index 40281bdc..9851f90c 100755
--- a/mirix/functions/helpers.py
+++ b/mirix/functions/helpers.py
@@ -15,8 +15,6 @@
from mirix.log import get_logger
from mirix.schemas.enums import MessageRole
from mirix.schemas.message import MessageCreate
-
-logger = get_logger(__name__)
from mirix.schemas.mirix_message import (
AssistantMessage,
ReasoningMessage,
@@ -24,6 +22,8 @@
)
from mirix.schemas.mirix_response import MirixResponse
+logger = get_logger(__name__)
+
if TYPE_CHECKING:
try:
from langchain_core.tools import BaseTool as LangChainBaseTool
diff --git a/mirix/llm_api/helpers.py b/mirix/llm_api/helpers.py
index 79ee0e50..1b136024 100755
--- a/mirix/llm_api/helpers.py
+++ b/mirix/llm_api/helpers.py
@@ -1,8 +1,4 @@
-import copy
-import json
import logging
-import warnings
-from collections import OrderedDict
from typing import Any, List, Union
import requests
@@ -10,9 +6,8 @@
from mirix.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from mirix.schemas.enums import MessageRole
from mirix.schemas.message import Message
-from mirix.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
from mirix.settings import summarizer_settings
-from mirix.utils import count_tokens, json_dumps, printd
+from mirix.utils import count_tokens, printd
logger = logging.getLogger(__name__)
diff --git a/mirix/llm_api/llm_client_base.py b/mirix/llm_api/llm_client_base.py
index 978b73a0..ef74ff0b 100644
--- a/mirix/llm_api/llm_client_base.py
+++ b/mirix/llm_api/llm_client_base.py
@@ -26,7 +26,7 @@ def __init__(
self.use_tool_naming = use_tool_naming
self.file_manager = FileManager()
self.cloud_file_mapping_manager = CloudFileMappingManager()
- self.logger = logging.getLogger(f"Mirix.LLMClientBase")
+ self.logger = logging.getLogger("Mirix.LLMClientBase")
def send_llm_request(
self,
diff --git a/mirix/llm_api/openai_client.py b/mirix/llm_api/openai_client.py
index ae3c533b..01aa1b9f 100644
--- a/mirix/llm_api/openai_client.py
+++ b/mirix/llm_api/openai_client.py
@@ -215,19 +215,24 @@ def build_request_data(
# Determine the appropriate tool_choice based on model capabilities
tool_choice = self._get_tool_choice(tools, force_tool_call)
- data = ChatCompletionRequest(
- model=model,
- messages=self.fill_image_content_in_messages(openai_message_list),
- tools=(
+ request_kwargs = {
+ "model": model,
+ "messages": self.fill_image_content_in_messages(openai_message_list),
+ "tools": (
[OpenAITool(type="function", function=f) for f in tools]
if tools
else None
),
- tool_choice=tool_choice,
- user=str(),
- max_completion_tokens=llm_config.max_tokens,
- temperature=llm_config.temperature,
- )
+ "tool_choice": tool_choice,
+ "user": str(),
+ "max_completion_tokens": llm_config.max_tokens,
+ }
+
+ # gpt-5 does not support temperature
+ if not (model and model.startswith("gpt-5")):
+ request_kwargs["temperature"] = llm_config.temperature
+
+ data = ChatCompletionRequest(**request_kwargs)
if data.tools is not None and len(data.tools) > 0:
# Convert to structured output style (which has 'strict' and no optionals)
diff --git a/mirix/local_client/local_client.py b/mirix/local_client/local_client.py
index 88745d05..8303cf51 100644
--- a/mirix/local_client/local_client.py
+++ b/mirix/local_client/local_client.py
@@ -13,7 +13,7 @@
import os
import shutil
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import requests
@@ -1116,14 +1116,6 @@ def convert_message(m):
input_messages = [
MessageCreate(role=MessageRole(role), content=content, name=name)
]
- if extra_messages is not None:
- extra_messages = [
- MessageCreate(
- role=MessageRole(role),
- content=[convert_message(m) for m in extra_messages],
- name=name,
- )
- ]
else:
raise ValueError(f"Invalid message type: {type(message)}")
diff --git a/mirix/orm/agent.py b/mirix/orm/agent.py
index fd31b796..3a850699 100755
--- a/mirix/orm/agent.py
+++ b/mirix/orm/agent.py
@@ -70,6 +70,11 @@ class Agent(SqlalchemyBase, OrganizationMixin):
nullable=True,
doc="the LLM backend configuration object for this agent.",
)
+ topic_extraction_llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
+ LLMConfigColumn,
+ nullable=True,
+ doc="optional LLM configuration used for topic extraction.",
+ )
embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column(
EmbeddingConfigColumn, doc="the embedding configuration object for this agent."
)
@@ -125,6 +130,7 @@ def to_pydantic(self) -> PydanticAgentState:
"system": self.system,
"agent_type": self.agent_type,
"llm_config": self.llm_config,
+ "topic_extraction_llm_config": self.topic_extraction_llm_config,
"embedding_config": self.embedding_config,
"memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]),
"created_by_id": self.created_by_id,
diff --git a/mirix/orm/client.py b/mirix/orm/client.py
index 4080828a..13201994 100644
--- a/mirix/orm/client.py
+++ b/mirix/orm/client.py
@@ -51,8 +51,8 @@ class Client(SqlalchemyBase, OrganizationMixin):
# Credits for LLM usage (1 credit = 1 dollar)
credits: Mapped[float] = mapped_column(
nullable=False,
- default=100.0,
- doc="Available credits for LLM API calls. New clients start with $100 credits. 1 credit = 1 dollar.",
+ default=10.0,
+ doc="Available credits for LLM API calls. New clients start with $10 credits. 1 credit = 1 dollar.",
)
# Relationships
diff --git a/mirix/orm/memory_agent_tool_call.py b/mirix/orm/memory_agent_tool_call.py
index 7bc51196..7c50de80 100644
--- a/mirix/orm/memory_agent_tool_call.py
+++ b/mirix/orm/memory_agent_tool_call.py
@@ -2,7 +2,7 @@
from datetime import datetime
from typing import Optional
-from sqlalchemy import Boolean, DateTime, ForeignKey, JSON, String, Text
+from sqlalchemy import Boolean, DateTime, ForeignKey, JSON, String, Text, Float, Integer
from sqlalchemy.orm import Mapped, mapped_column
from mirix.orm.sqlalchemy_base import SqlalchemyBase
@@ -28,6 +28,13 @@ class MemoryAgentToolCall(SqlalchemyBase):
function_name: Mapped[str] = mapped_column(String)
function_args: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
+ llm_call_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
+ prompt_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+ completion_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+ cached_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+ total_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+ credit_cost: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+
status: Mapped[str] = mapped_column(
String, default="running", doc="running|completed|failed"
)
diff --git a/mirix/orm/memory_queue_trace.py b/mirix/orm/memory_queue_trace.py
index 88f709d4..edb6f46d 100644
--- a/mirix/orm/memory_queue_trace.py
+++ b/mirix/orm/memory_queue_trace.py
@@ -2,7 +2,7 @@
from datetime import datetime
from typing import Optional
-from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text
+from sqlalchemy import Boolean, DateTime, Integer, JSON, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from mirix.orm.mixins import OrganizationMixin
@@ -42,6 +42,10 @@ class MemoryQueueTrace(SqlalchemyBase, OrganizationMixin):
completed_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
+ interrupt_requested_at: Mapped[Optional[datetime]] = mapped_column(
+ DateTime(timezone=True), nullable=True
+ )
+ interrupt_reason: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
message_count: Mapped[int] = mapped_column(Integer, default=0)
success: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
diff --git a/mirix/orm/sqlalchemy_base.py b/mirix/orm/sqlalchemy_base.py
index 71fe7a12..f8011b2b 100755
--- a/mirix/orm/sqlalchemy_base.py
+++ b/mirix/orm/sqlalchemy_base.py
@@ -916,6 +916,10 @@ def _update_redis_cache(
data['message_ids'] = json.dumps(data['message_ids'])
if 'llm_config' in data and data['llm_config']:
data['llm_config'] = json.dumps(data['llm_config'])
+ if 'topic_extraction_llm_config' in data and data['topic_extraction_llm_config']:
+ data['topic_extraction_llm_config'] = json.dumps(
+ data['topic_extraction_llm_config']
+ )
if 'embedding_config' in data and data['embedding_config']:
data['embedding_config'] = json.dumps(data['embedding_config'])
if 'tool_rules' in data and data['tool_rules']:
@@ -1031,4 +1035,4 @@ def _update_redis_cache(
except Exception as e:
# Log but don't fail the operation if Redis fails
logger.error("Failed to update Redis cache for %s %s: %s", self.__class__.__name__, self.id, e)
- logger.info("Operation completed successfully in PostgreSQL despite Redis error")
\ No newline at end of file
+ logger.info("Operation completed successfully in PostgreSQL despite Redis error")
diff --git a/mirix/prompts/system/base/chat_agent.txt b/mirix/prompts/system/base/chat_agent.txt
index 6bd09fc7..97f49a33 100755
--- a/mirix/prompts/system/base/chat_agent.txt
+++ b/mirix/prompts/system/base/chat_agent.txt
@@ -27,7 +27,6 @@ Repository for structured, factual data including contact information, credentia
Contains conceptual knowledge about entities, concepts, and objects, including detailed understanding and contextual information.
Operational Requirements:
-Maintain concise internal monologue (maximum 50 words at all times).
Continuously update conversation topics based on user interactions without explicitly disclosing this process to users. This functions as an internal contextual mechanism to maintain natural conversation flow and demonstrate human-like conversational memory.
You have access to partial information from each memory component. Utilize the `search_in_memory` and `list_memory_within_timerange` functions to retrieve relevant information for response formulation.
diff --git a/mirix/prompts/system/screen_monitor/chat_agent.txt b/mirix/prompts/system/screen_monitor/chat_agent.txt
index 609c0f33..76dd0855 100755
--- a/mirix/prompts/system/screen_monitor/chat_agent.txt
+++ b/mirix/prompts/system/screen_monitor/chat_agent.txt
@@ -32,7 +32,6 @@ You are the Chat Agent, responsible for user communication and proactive memory
- `send_intermediate_message` does NOT end the conversation - you must continue processing
**Key Guidelines:**
-- Maintain concise internal monologue (max 50 words)
- Monitor user sentiment; update Persona Block if self-improvement needed
- Messages without function calls are internal reasoning (invisible to users)
- Use `send_intermediate_message` sparingly - only for genuine progress updates
diff --git a/mirix/prompts/system/screen_monitor/chat_agent_monitor_on.txt b/mirix/prompts/system/screen_monitor/chat_agent_monitor_on.txt
index 9ada2135..c009c918 100755
--- a/mirix/prompts/system/screen_monitor/chat_agent_monitor_on.txt
+++ b/mirix/prompts/system/screen_monitor/chat_agent_monitor_on.txt
@@ -11,7 +11,6 @@ Memory Components:
6. Semantic Memory: Conceptual knowledge about entities and objects
Requirements:
-- Maintain concise internal monologue (max 50 words)
- Continuously update conversation topics internally for natural flow
- Use `search_in_memory` and `list_memory_within_timerange` to retrieve relevant information
diff --git a/mirix/queue/manager.py b/mirix/queue/manager.py
index d2a422c8..5b381924 100644
--- a/mirix/queue/manager.py
+++ b/mirix/queue/manager.py
@@ -8,7 +8,6 @@
"""
import atexit
-import logging
import time
from typing import Any, List, Optional
diff --git a/mirix/queue/queue_util.py b/mirix/queue/queue_util.py
index ece02ad5..6e33f311 100644
--- a/mirix/queue/queue_util.py
+++ b/mirix/queue/queue_util.py
@@ -3,7 +3,6 @@
from mirix.schemas.client import Client
from mirix.schemas.message import MessageCreate
-from mirix.schemas.message import MessageCreate as PydanticMessageCreate
from mirix.schemas.enums import MessageRole
from mirix.schemas.mirix_message_content import TextContent
diff --git a/mirix/queue/worker.py b/mirix/queue/worker.py
index 3e5cff67..f382b3f3 100644
--- a/mirix/queue/worker.py
+++ b/mirix/queue/worker.py
@@ -14,7 +14,6 @@
if TYPE_CHECKING:
from .queue_interface import QueueInterface
- from mirix.schemas.user import User
from mirix.schemas.client import Client
from mirix.schemas.message import MessageCreate
diff --git a/mirix/schemas/agent.py b/mirix/schemas/agent.py
index 34a1ad26..6de8626c 100755
--- a/mirix/schemas/agent.py
+++ b/mirix/schemas/agent.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
@@ -74,13 +74,17 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
# agent configuration
agent_type: AgentType = Field(..., description="The type of agent.")
- # llm information
- llm_config: LLMConfig = Field(
- ..., description="The LLM configuration used by the agent."
- )
- embedding_config: Optional[EmbeddingConfig] = Field(
- None, description="The embedding configuration used by the agent."
- )
+ # llm information
+ llm_config: LLMConfig = Field(
+ ..., description="The LLM configuration used by the agent."
+ )
+ topic_extraction_llm_config: Optional[LLMConfig] = Field(
+ None,
+ description="Optional LLM configuration used for topic extraction.",
+ )
+ embedding_config: Optional[EmbeddingConfig] = Field(
+ None, description="The embedding configuration used by the agent."
+ )
# This is an object representing the in-process state of a running `Agent`
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
@@ -137,12 +141,16 @@ class CreateAgent(BaseModel, validate_assignment=True): #
agent_type: AgentType = Field(
default_factory=lambda: AgentType.chat_agent, description="The type of agent."
)
- llm_config: Optional[LLMConfig] = Field(
- None, description="The LLM configuration used by the agent."
- )
- embedding_config: Optional[EmbeddingConfig] = Field(
- None, description="The embedding configuration used by the agent."
- )
+ llm_config: Optional[LLMConfig] = Field(
+ None, description="The LLM configuration used by the agent."
+ )
+ topic_extraction_llm_config: Optional[LLMConfig] = Field(
+ None,
+ description="Optional LLM configuration used for topic extraction.",
+ )
+ embedding_config: Optional[EmbeddingConfig] = Field(
+ None, description="The embedding configuration used by the agent."
+ )
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
# If the client wants to make this empty, then the client can set the arg to an empty list
initial_message_sequence: Optional[List[MessageCreate]] = Field(
@@ -266,12 +274,16 @@ class UpdateAgent(BaseModel):
tool_rules: Optional[List[ToolRule]] = Field(
None, description="The tool rules governing the agent."
)
- llm_config: Optional[LLMConfig] = Field(
- None, description="The LLM configuration used by the agent."
- )
- embedding_config: Optional[EmbeddingConfig] = Field(
- None, description="The embedding configuration used by the agent."
- )
+ llm_config: Optional[LLMConfig] = Field(
+ None, description="The LLM configuration used by the agent."
+ )
+ topic_extraction_llm_config: Optional[LLMConfig] = Field(
+ None,
+ description="Optional LLM configuration used for topic extraction.",
+ )
+ embedding_config: Optional[EmbeddingConfig] = Field(
+ None, description="The embedding configuration used by the agent."
+ )
clear_embedding_config: bool = Field(
False, description="If true, clear the embedding configuration."
)
@@ -362,14 +374,18 @@ class CreateMetaAgent(BaseModel):
None,
description="Dictionary mapping agent names to their system prompt text. Takes precedence over system_prompts_folder.",
)
- llm_config: Optional[LLMConfig] = Field(
- None,
- description="LLM configuration for memory agents. Required if no default is set.",
- )
- embedding_config: Optional[EmbeddingConfig] = Field(
- None,
- description="Embedding configuration for memory agents. Required if no default is set.",
- )
+ llm_config: Optional[LLMConfig] = Field(
+ None,
+ description="LLM configuration for memory agents. Required if no default is set.",
+ )
+ topic_extraction_llm_config: Optional[LLMConfig] = Field(
+ None,
+ description="Optional LLM configuration used for topic extraction across memory agents.",
+ )
+ embedding_config: Optional[EmbeddingConfig] = Field(
+ None,
+ description="Embedding configuration for memory agents. Required if no default is set.",
+ )
class UpdateMetaAgent(BaseModel):
"""Request schema for updating a MetaAgent."""
@@ -390,14 +406,18 @@ class UpdateMetaAgent(BaseModel):
None,
description="Dictionary mapping agent names to their system prompt text. Updates only the specified agents.",
)
- llm_config: Optional[LLMConfig] = Field(
- None,
- description="LLM configuration for meta agent and its sub-agents.",
- )
- embedding_config: Optional[EmbeddingConfig] = Field(
- None,
- description="Embedding configuration for meta agent and its sub-agents.",
- )
+ llm_config: Optional[LLMConfig] = Field(
+ None,
+ description="LLM configuration for meta agent and its sub-agents.",
+ )
+ topic_extraction_llm_config: Optional[LLMConfig] = Field(
+ None,
+ description="Optional LLM configuration used for topic extraction across the meta agent.",
+ )
+ embedding_config: Optional[EmbeddingConfig] = Field(
+ None,
+ description="Embedding configuration for meta agent and its sub-agents.",
+ )
clear_embedding_config: bool = Field(
False,
description="If true, clear embedding configuration for meta agent and its sub-agents.",
diff --git a/mirix/schemas/client.py b/mirix/schemas/client.py
index 6f0fdaee..7196dbb1 100644
--- a/mirix/schemas/client.py
+++ b/mirix/schemas/client.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from typing import List, Optional
+from typing import Optional
import uuid
from pydantic import Field
@@ -51,8 +51,8 @@ class Client(ClientBase):
# Credits for LLM usage (1 credit = 1 dollar)
credits: float = Field(
- 100.0,
- description="Available credits for LLM API calls. New clients start with $100. 1 credit = 1 dollar.",
+ 10.0,
+ description="Available credits for LLM API calls. New clients start with $10. 1 credit = 1 dollar.",
)
created_at: Optional[datetime] = Field(
diff --git a/mirix/schemas/cloud_file_mapping.py b/mirix/schemas/cloud_file_mapping.py
index a469e9f6..5160c556 100644
--- a/mirix/schemas/cloud_file_mapping.py
+++ b/mirix/schemas/cloud_file_mapping.py
@@ -2,7 +2,7 @@
from pydantic import Field
-from mirix.client.utils import get_utc_time
+from mirix.utils import get_utc_time
from mirix.schemas.mirix_base import MirixBase
diff --git a/mirix/schemas/episodic_memory.py b/mirix/schemas/episodic_memory.py
index 2d3558c4..1c3290c7 100755
--- a/mirix/schemas/episodic_memory.py
+++ b/mirix/schemas/episodic_memory.py
@@ -3,7 +3,7 @@
from pydantic import Field, field_validator
-from mirix.client.utils import get_utc_time
+from mirix.utils import get_utc_time
from mirix.constants import MAX_EMBEDDING_DIM
from mirix.schemas.embedding_config import EmbeddingConfig
from mirix.schemas.mirix_base import MirixBase
diff --git a/mirix/schemas/knowledge.py b/mirix/schemas/knowledge.py
index fbf341e6..7d5b84b0 100644
--- a/mirix/schemas/knowledge.py
+++ b/mirix/schemas/knowledge.py
@@ -3,7 +3,7 @@
from pydantic import Field, field_validator
-from mirix.client.utils import get_utc_time
+from mirix.utils import get_utc_time
from mirix.constants import MAX_EMBEDDING_DIM
from mirix.schemas.embedding_config import EmbeddingConfig
from mirix.schemas.mirix_base import MirixBase
diff --git a/mirix/schemas/llm_config.py b/mirix/schemas/llm_config.py
index 4e11003f..802406a2 100755
--- a/mirix/schemas/llm_config.py
+++ b/mirix/schemas/llm_config.py
@@ -23,6 +23,7 @@ class LLMConfig(BaseModel):
api_version (str, optional): The API version for Azure OpenAI (e.g., '2024-10-01-preview').
azure_endpoint (str, optional): The Azure endpoint for the model (e.g., 'https://your-resource.openai.azure.com/').
azure_deployment (str, optional): The Azure deployment name for the model.
+ is_local_model (bool, optional): Whether the model is locally hosted (e.g., LM Studio, vLLM).
"""
# TODO: 🤮 don't default to a vendor! bug city!
@@ -94,6 +95,10 @@ class LLMConfig(BaseModel):
None,
description="Name of registered auth provider for dynamic header injection (e.g., for claims-based tickets)",
)
+ is_local_model: Optional[bool] = Field(
+ False,
+ description="Whether the model is locally hosted (e.g., LM Studio, vLLM).",
+ )
# Azure-specific fields (Azure OpenAI only)
api_version: Optional[str] = Field(
diff --git a/mirix/schemas/memory_agent_tool_call.py b/mirix/schemas/memory_agent_tool_call.py
index 56b3deae..3d677cde 100644
--- a/mirix/schemas/memory_agent_tool_call.py
+++ b/mirix/schemas/memory_agent_tool_call.py
@@ -22,6 +22,25 @@ class MemoryAgentToolCall(MemoryAgentToolCallBase):
None, description="Arguments passed to the tool"
)
+ llm_call_id: Optional[str] = Field(
+ None, description="LLM response ID that produced this tool call"
+ )
+ prompt_tokens: Optional[int] = Field(
+ None, description="Prompt tokens billed for the LLM call"
+ )
+ completion_tokens: Optional[int] = Field(
+ None, description="Completion tokens billed for the LLM call"
+ )
+ cached_tokens: Optional[int] = Field(
+ None, description="Cached prompt tokens for the LLM call"
+ )
+ total_tokens: Optional[int] = Field(
+ None, description="Total tokens reported for the LLM call"
+ )
+ credit_cost: Optional[float] = Field(
+ None, description="Credits charged for the LLM call"
+ )
+
status: str = Field(..., description="running|completed|failed")
started_at: datetime = Field(..., description="When tool execution started")
completed_at: Optional[datetime] = Field(
diff --git a/mirix/schemas/memory_queue_trace.py b/mirix/schemas/memory_queue_trace.py
index ae88eb75..94814c6f 100644
--- a/mirix/schemas/memory_queue_trace.py
+++ b/mirix/schemas/memory_queue_trace.py
@@ -34,6 +34,12 @@ class MemoryQueueTrace(MemoryQueueTraceBase):
completed_at: Optional[datetime] = Field(
None, description="When processing completed"
)
+ interrupt_requested_at: Optional[datetime] = Field(
+ None, description="When an interrupt was requested"
+ )
+ interrupt_reason: Optional[str] = Field(
+ None, description="Reason for interruption request"
+ )
message_count: int = Field(0, description="Number of input messages queued")
success: Optional[bool] = Field(
diff --git a/mirix/schemas/message.py b/mirix/schemas/message.py
index 1d641fc3..217d7e43 100644
--- a/mirix/schemas/message.py
+++ b/mirix/schemas/message.py
@@ -1,10 +1,8 @@
from __future__ import annotations
-import copy
import json
import uuid
import warnings
-from collections import OrderedDict
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional, Union
diff --git a/mirix/schemas/mirix_message.py b/mirix/schemas/mirix_message.py
index 4117120d..4bcd0243 100755
--- a/mirix/schemas/mirix_message.py
+++ b/mirix/schemas/mirix_message.py
@@ -30,7 +30,7 @@ class MessageType(str, Enum):
class MirixMessage(BaseModel):
"""
Base class for simplified Mirix message response type. This is intended to be used for developers
- who want the internal monologue, tool calls, and tool returns in a simplified format that does not
+ who want the tool calls, and tool returns in a simplified format that does not
include additional information other than the content and timestamp.
Args:
diff --git a/mirix/schemas/mirix_response.py b/mirix/schemas/mirix_response.py
index 5c1d5e02..09c20783 100755
--- a/mirix/schemas/mirix_response.py
+++ b/mirix/schemas/mirix_response.py
@@ -147,7 +147,6 @@ def format_json(json_str):
.json-key, .function-name, .json-boolean { color: #9cdcfe; }
.json-string { color: #ce9178; }
.json-number { color: #b5cea8; }
- .internal-monologue { font-style: italic; }
"""
diff --git a/mirix/schemas/openai/chat_completion_request.py b/mirix/schemas/openai/chat_completion_request.py
index 1e5c26d7..6d5c612e 100755
--- a/mirix/schemas/openai/chat_completion_request.py
+++ b/mirix/schemas/openai/chat_completion_request.py
@@ -105,6 +105,7 @@ class ChatCompletionRequest(BaseModel):
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
+ max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
response_format: Optional[ResponseFormat] = None
diff --git a/mirix/schemas/organization.py b/mirix/schemas/organization.py
index 2d96f180..f7282036 100755
--- a/mirix/schemas/organization.py
+++ b/mirix/schemas/organization.py
@@ -4,7 +4,7 @@
from pydantic import Field
-from mirix.client.utils import get_utc_time
+from mirix.utils import get_utc_time
from mirix.schemas.mirix_base import MirixBase
diff --git a/mirix/schemas/procedural_memory.py b/mirix/schemas/procedural_memory.py
index b9481af4..46e9b28f 100755
--- a/mirix/schemas/procedural_memory.py
+++ b/mirix/schemas/procedural_memory.py
@@ -3,7 +3,7 @@
from pydantic import Field, field_validator
-from mirix.client.utils import get_utc_time
+from mirix.utils import get_utc_time
from mirix.constants import MAX_EMBEDDING_DIM
from mirix.schemas.embedding_config import EmbeddingConfig
from mirix.schemas.mirix_base import MirixBase
diff --git a/mirix/schemas/resource_memory.py b/mirix/schemas/resource_memory.py
index 38966c97..5791d891 100755
--- a/mirix/schemas/resource_memory.py
+++ b/mirix/schemas/resource_memory.py
@@ -6,7 +6,7 @@
from mirix.constants import MAX_EMBEDDING_DIM
from mirix.schemas.embedding_config import EmbeddingConfig
from mirix.schemas.mirix_base import MirixBase
-from mirix.client.utils import get_utc_time
+from mirix.utils import get_utc_time
class ResourceMemoryItemBase(MirixBase):
diff --git a/mirix/schemas/semantic_memory.py b/mirix/schemas/semantic_memory.py
index f76a8097..047743ab 100755
--- a/mirix/schemas/semantic_memory.py
+++ b/mirix/schemas/semantic_memory.py
@@ -6,7 +6,7 @@
from mirix.constants import MAX_EMBEDDING_DIM
from mirix.schemas.embedding_config import EmbeddingConfig
from mirix.schemas.mirix_base import MirixBase
-from mirix.client.utils import get_utc_time
+from mirix.utils import get_utc_time
class SemanticMemoryItemBase(MirixBase):
diff --git a/mirix/sdk.py b/mirix/sdk.py
index 1d631e7f..d22eb774 100644
--- a/mirix/sdk.py
+++ b/mirix/sdk.py
@@ -685,9 +685,6 @@ def visualize_memories(self, user_id: Optional[str] = None) -> Dict[str, Any]:
if not target_user:
return {"success": False, "error": "No user found"}
- # Get the meta agent state to access memory agents
- meta_agent_state = self._client.get_agent(self._meta_agent.id)
-
memories = {}
# Get episodic memory
@@ -909,18 +906,18 @@ def visualize_memories(self, user_id: Optional[str] = None) -> Dict[str, Any]:
knowledge_memory_agent = agent
break
- if knowledge_memory_agent:
- knowledge_items = knowledge_memory_manager.list_knowledge(
- actor=target_user,
- agent_state=knowledge_memory_agent,
- limit=50,
- timezone_str=target_user.timezone,
- )
- else:
- knowledge_items = []
-
- memories["credentials"] = []
- for item in knowledge_items:
+ if knowledge_memory_agent:
+ knowledge_items = knowledge_memory_manager.list_knowledge(
+ actor=target_user,
+ agent_state=knowledge_memory_agent,
+ limit=50,
+ timezone_str=target_user.timezone,
+ )
+ else:
+ knowledge_items = []
+
+ memories["credentials"] = []
+ for item in knowledge_items:
memories["credentials"].append(
{
"caption": item.caption,
diff --git a/mirix/server/rest_api.py b/mirix/server/rest_api.py
index 695a9626..992a53fe 100644
--- a/mirix/server/rest_api.py
+++ b/mirix/server/rest_api.py
@@ -4,14 +4,12 @@
allowing MirixClient instances to communicate with a cloud-hosted server.
"""
-import copy
import json
import traceback
from contextlib import asynccontextmanager
-from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
-import requests
from fastapi import APIRouter, Body, FastAPI, Header, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@@ -23,18 +21,12 @@
from mirix.llm_api.llm_client import LLMClient
from mirix.log import get_logger
from mirix.schemas.agent import AgentState, AgentType, CreateAgent
-from mirix.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona
-from mirix.schemas.client import Client, ClientCreate, ClientUpdate
+from mirix.schemas.block import Block
+from mirix.schemas.client import Client, ClientUpdate
from mirix.schemas.embedding_config import EmbeddingConfig
from mirix.schemas.enums import MessageRole
-from mirix.schemas.environment_variables import (
- SandboxEnvironmentVariable,
- SandboxEnvironmentVariableCreate,
- SandboxEnvironmentVariableUpdate,
-)
-from mirix.schemas.file import FileMetadata
from mirix.schemas.llm_config import LLMConfig
-from mirix.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
+from mirix.schemas.memory import Memory
from mirix.schemas.message import Message, MessageCreate
from mirix.schemas.mirix_response import MirixResponse
from mirix.schemas.memory_agent_tool_call import MemoryAgentToolCall as MemoryAgentToolCallSchema
@@ -45,24 +37,24 @@
from mirix.schemas.resource_memory import ResourceMemoryItemUpdate
from mirix.schemas.agent import CreateMetaAgent, MemoryConfig, MemoryBlockConfig, MemoryDecayConfig, UpdateMetaAgent
from mirix.schemas.semantic_memory import SemanticMemoryItemUpdate
-from mirix.schemas.tool import Tool, ToolCreate, ToolUpdate
+from mirix.schemas.tool import Tool
from mirix.schemas.tool_rule import BaseToolRule
from mirix.schemas.user import User
from mirix.server.server import SyncServer
from mirix.server.server import db_context
+from mirix.services.memory_queue_trace_manager import MemoryQueueTraceManager
from mirix.settings import model_settings, settings
+from mirix.topic_extraction import extract_topics_with_ollama, flatten_messages_to_plain_text
from mirix.utils import convert_message_to_mirix_message
from mirix.orm.memory_agent_tool_call import MemoryAgentToolCall
from mirix.orm.memory_agent_trace import MemoryAgentTrace
from mirix.orm.memory_queue_trace import MemoryQueueTrace
-
-logger = get_logger(__name__)
-
-# Import queue components
from mirix.queue import initialize_queue
from mirix.queue.manager import get_manager as get_queue_manager
from mirix.queue.queue_util import put_messages
+logger = get_logger(__name__)
+
# Initialize server (single instance shared across all requests)
_server: Optional[SyncServer] = None
@@ -163,6 +155,7 @@ async def inject_client_org_headers(request: Request, call_next):
"/admin/auth/register",
"/admin/auth/login",
"/admin/auth/check-setup",
+ "/health",
}
if request.url.path in public_paths:
@@ -409,29 +402,7 @@ def _flatten_messages_to_plain_text(messages: List[Dict[str, Any]]) -> str:
"""
Flatten OpenAI-style message payloads into a simple conversation transcript.
"""
- transcript_parts: List[str] = []
-
- for msg in messages:
- role = msg.get("role", "user")
- content = msg.get("content", "")
-
- parts: List[str] = []
- if isinstance(content, list):
- for chunk in content:
- if isinstance(chunk, dict):
- text = chunk.get("text")
- if text:
- parts.append(text.strip())
- elif isinstance(chunk, str):
- parts.append(chunk.strip())
- elif isinstance(content, str):
- parts.append(content.strip())
-
- combined = " ".join(filter(None, parts)).strip()
- if combined:
- transcript_parts.append(f"{role.upper()}: {combined}")
-
- return "\n".join(transcript_parts)
+ return flatten_messages_to_plain_text(messages)
def extract_topics_with_local_model(messages: List[Dict[str, Any]], model_name: str) -> Optional[str]:
@@ -441,74 +412,11 @@ def extract_topics_with_local_model(messages: List[Dict[str, Any]], model_name:
Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#chat
"""
- base_url = model_settings.ollama_base_url
- if not base_url:
- logger.warning(
- "local_model_for_retrieval provided (%s) but MIRIX_OLLAMA_BASE_URL is not configured",
- model_name,
- )
- return None
-
- conversation = _flatten_messages_to_plain_text(messages)
- if not conversation:
- logger.debug("No text content found in messages for local topic extraction")
- return None
-
- payload = {
- "model": model_name,
- "stream": False,
- "messages": [
- {
- "role": "system",
- "content": (
- "You are a helpful assistant that extracts the topic from the user's input. "
- "Return a concise list of topics separated by ';' and nothing else."
- ),
- },
- {
- "role": "user",
- "content": (
- "Conversation transcript:\n"
- f"{conversation}\n\n"
- "Respond ONLY with the topic(s) separated by ';'."
- ),
- },
- ],
- "options": {
- "temperature": 0,
- },
- }
-
- try:
- response = requests.post(
- f"{base_url.rstrip('/')}/api/chat",
- json=payload,
- timeout=30,
- proxies={"http": None, "https": None},
- )
- response.raise_for_status()
- response_data = response.json()
- except requests.RequestException as exc:
- logger.error("Failed to extract topics with local model %s: %s", model_name, exc)
- return None
-
- message_payload = response_data.get("message") if isinstance(response_data, dict) else None
- text_response: Optional[str] = None
- if isinstance(message_payload, dict):
- text_response = message_payload.get("content")
- elif isinstance(response_data, dict):
- text_response = response_data.get("content")
-
- if isinstance(text_response, str):
- topics = text_response.strip()
- logger.debug("Extracted topics via local model %s: %s", model_name, topics)
- return topics or None
-
- logger.warning(
- "Unexpected response format from Ollama topic extraction: %s",
- response_data,
+ return extract_topics_with_ollama(
+ messages=messages,
+ model_name=model_name,
+ base_url=model_settings.ollama_base_url,
)
- return None
# ============================================================================
@@ -653,7 +561,7 @@ async def get_agent(
try:
return server.agent_manager.get_agent_by_id(agent_id, actor=client)
- except NoResultFound as e:
+ except NoResultFound:
raise HTTPException(
status_code=404,
detail=f"Agent {agent_id} not found or not accessible"
@@ -704,8 +612,8 @@ async def update_agent(
):
"""Update an agent."""
server = get_server()
- client_id, org_id = get_client_and_org(x_client_id, x_org_id)
- client = server.client_manager.get_client_by_id(client_id)
+ client_id, _org_id = get_client_and_org(x_client_id, x_org_id)
+ _client = server.client_manager.get_client_by_id(client_id)
# TODO: Implement update_agent in server
raise HTTPException(status_code=501, detail="Update agent not yet implemented")
@@ -1010,8 +918,8 @@ async def list_blocks(
):
"""List all blocks."""
server = get_server()
- client_id, org_id = get_client_and_org(x_client_id, x_org_id)
- client = server.client_manager.get_client_by_id(client_id)
+ client_id, _org_id = get_client_and_org(x_client_id, x_org_id)
+ _client = server.client_manager.get_client_by_id(client_id)
# Get default user for block queries (blocks are user-scoped, not client-scoped)
user = server.user_manager.get_admin_user()
return server.block_manager.get_blocks(user=user, label=label)
@@ -1025,8 +933,8 @@ async def get_block(
):
"""Get a block by ID."""
server = get_server()
- client_id, org_id = get_client_and_org(x_client_id, x_org_id)
- client = server.client_manager.get_client_by_id(client_id)
+ client_id, _org_id = get_client_and_org(x_client_id, x_org_id)
+ _client = server.client_manager.get_client_by_id(client_id)
# Get admin user for block queries (blocks are user-scoped, not client-scoped)
user = server.user_manager.get_admin_user()
return server.block_manager.get_block_by_id(block_id, user=user)
@@ -1724,7 +1632,7 @@ async def delete_client_api_key(
"message": f"API key {api_key_id} deleted successfully",
"id": api_key_id,
}
- except Exception as e:
+ except Exception:
raise HTTPException(status_code=404, detail=f"API key {api_key_id} not found")
@@ -1801,6 +1709,9 @@ async def initialize_meta_agent(
)
llm_config = LLMConfig(**config["llm_config"])
+ topic_extraction_llm_config = None
+ if config.get("topic_extraction_llm_config"):
+ topic_extraction_llm_config = LLMConfig(**config["topic_extraction_llm_config"])
if build_embeddings_for_memory:
if not config.get("embedding_config"):
@@ -1826,6 +1737,8 @@ async def initialize_meta_agent(
"llm_config": llm_config,
"embedding_config": embedding_config,
}
+ if topic_extraction_llm_config is not None:
+ create_params["topic_extraction_llm_config"] = topic_extraction_llm_config
# Flatten meta_agent_config fields into create_params
if "meta_agent_config" in config and config["meta_agent_config"]:
@@ -2030,6 +1943,11 @@ class MemoryQueueTraceDetailResponse(BaseModel):
trace: MemoryQueueTraceSchema
agent_traces: List[MemoryAgentTraceWithTools]
+class MemoryQueueTraceInterruptRequest(BaseModel):
+ reason: Optional[str] = Field(
+ None, description="Reason for interrupting the queue trace"
+ )
+
@router.get("/memory/queue-traces", response_model=List[MemoryQueueTraceSchema])
async def list_memory_queue_traces(
@@ -2078,6 +1996,27 @@ async def get_memory_queue_trace(
if not trace or trace.client_id != client_id:
raise HTTPException(status_code=404, detail="Trace not found")
+ if trace.status == "processing":
+ now = datetime.now(timezone.utc)
+ reference = trace.started_at or trace.queued_at
+ if reference:
+ if reference.tzinfo is None:
+ reference = reference.replace(tzinfo=timezone.utc)
+ wait_seconds = (now - reference).total_seconds()
+ logger.info(
+ "Queue trace %s processing for %.1fs (queued_at=%s, started_at=%s, now=%s)",
+ trace_id,
+ wait_seconds,
+ trace.queued_at.isoformat() if trace.queued_at else "N/A",
+ trace.started_at.isoformat() if trace.started_at else "N/A",
+ now.isoformat(),
+ )
+ else:
+ logger.info(
+ "Queue trace %s processing with no queued/started timestamp",
+ trace_id,
+ )
+
agent_traces = (
session.execute(
select(MemoryAgentTrace)
@@ -2121,6 +2060,32 @@ async def get_memory_queue_trace(
)
+@router.post("/memory/queue-traces/{trace_id}/interrupt")
+async def interrupt_memory_queue_trace(
+ trace_id: str,
+ payload: Optional[MemoryQueueTraceInterruptRequest] = Body(None),
+ x_client_id: Optional[str] = Header(None),
+ x_org_id: Optional[str] = Header(None),
+):
+ """
+ Request interruption of a running queue trace.
+ """
+ client_id, _ = get_client_and_org(x_client_id, x_org_id)
+
+ with db_context() as session:
+ trace = session.get(MemoryQueueTrace, trace_id)
+ if not trace or trace.client_id != client_id:
+ raise HTTPException(status_code=404, detail="Trace not found")
+
+ reason = payload.reason if payload else None
+ if not reason:
+ reason = "Interrupted by user"
+
+ MemoryQueueTraceManager().request_interrupt(trace_id, reason=reason)
+
+ return {"success": True, "trace_id": trace_id, "interrupt_requested": True}
+
+
class SelfReflectionRequest(BaseModel):
"""Request model for triggering self-reflection.
@@ -2142,8 +2107,8 @@ class SelfReflectionRequest(BaseModel):
procedural_ids: Optional[List[str]] = None # Procedural memory item IDs
-# Import self-reflection service functions
-from mirix.services.self_reflection_service import (
+# Import self-reflection service functions (intentionally here due to circular imports)
+from mirix.services.self_reflection_service import ( # noqa: E402
retrieve_memories_by_updated_at_range,
retrieve_specific_memories_by_ids,
build_self_reflection_prompt,
@@ -2442,7 +2407,7 @@ def retrieve_memories_by_keywords(
try:
user = server.user_manager.get_user_by_id(user_id)
timezone_str = user.timezone
- except:
+ except Exception:
timezone_str = "UTC"
memories = {}
@@ -2699,6 +2664,7 @@ async def retrieve_memory_with_conversation(
# Extract topics from the conversation
# TODO: Consider allowing custom model selection in the future
llm_config = all_agents[0].llm_config
+ topic_llm_config = getattr(all_agents[0], "topic_extraction_llm_config", None) or llm_config
# Check if messages have actual content before calling LLM
has_content = False
@@ -2729,8 +2695,17 @@ async def retrieve_memory_with_conversation(
)
if topics is None:
- # NEW: Extract both topics and temporal expression
- topics, temporal_expr = extract_topics_and_temporal_info(request.messages, llm_config)
+ if topic_llm_config.model_endpoint_type == "ollama":
+ topics = extract_topics_with_ollama(
+ messages=request.messages,
+ model_name=topic_llm_config.model,
+ base_url=topic_llm_config.model_endpoint or model_settings.ollama_base_url,
+ )
+ else:
+ # NEW: Extract both topics and temporal expression
+ topics, temporal_expr = extract_topics_and_temporal_info(
+ request.messages, topic_llm_config
+ )
logger.debug("Extracted topics: %s, temporal: %s", topics, temporal_expr)
key_words = topics if topics else ""
@@ -2978,7 +2953,7 @@ async def search_memory(
try:
user = server.user_manager.get_user_by_id(user_id)
timezone_str = user.timezone
- except:
+ except Exception:
timezone_str = "UTC"
# Parse filter_tags from JSON string to dict
@@ -4238,7 +4213,7 @@ class DashboardClientResponse(BaseModel):
admin_user_id: str # Admin user for memory operations
created_at: Optional[datetime]
last_login: Optional[datetime]
- credits: float = 100.0 # Available credits for LLM API calls (1 credit = 1 dollar)
+ credits: float = 10.0 # Available credits for LLM API calls (1 credit = 1 dollar)
class TokenResponse(BaseModel):
diff --git a/mirix/server/server.py b/mirix/server/server.py
index f659e752..affb8215 100644
--- a/mirix/server/server.py
+++ b/mirix/server/server.py
@@ -46,7 +46,7 @@
from mirix.log import get_logger
from mirix.orm import Base
from mirix.orm.errors import NoResultFound
-from mirix.schemas.agent import AgentState, AgentType, CreateAgent, CreateMetaAgent
+from mirix.schemas.agent import AgentState, AgentType, CreateAgent
from mirix.schemas.block import BlockUpdate
from mirix.schemas.embedding_config import EmbeddingConfig
@@ -156,7 +156,7 @@ def run_command(
# NOTE: hack to see if single session management works
-from mirix.settings import model_settings, settings # noqa: E402
+from mirix.settings import settings # noqa: E402
config = MirixConfig.load()
diff --git a/mirix/services/admin_user_manager.py b/mirix/services/admin_user_manager.py
index ef73bb46..bb73122e 100644
--- a/mirix/services/admin_user_manager.py
+++ b/mirix/services/admin_user_manager.py
@@ -217,7 +217,7 @@ def register_client_for_dashboard(
session.query(ClientModel)
.filter(
ClientModel.email == email.lower(),
- ClientModel.is_deleted == False
+ ClientModel.is_deleted.is_(False)
)
.first()
)
@@ -336,7 +336,7 @@ def authenticate(self, email: str, password: str) -> Tuple[Optional[PydanticClie
session.query(ClientModel)
.filter(
ClientModel.email == email.lower(),
- ClientModel.is_deleted == False
+ ClientModel.is_deleted.is_(False)
)
.first()
)
@@ -387,7 +387,7 @@ def get_client_by_email(self, email: str) -> Optional[PydanticClient]:
session.query(ClientModel)
.filter(
ClientModel.email == email.lower(),
- ClientModel.is_deleted == False
+ ClientModel.is_deleted.is_(False)
)
.first()
)
@@ -406,7 +406,7 @@ def list_dashboard_clients(
query = (
session.query(ClientModel)
.filter(
- ClientModel.is_deleted == False,
+ ClientModel.is_deleted.is_(False),
ClientModel.email.isnot(None)
)
.order_by(ClientModel.created_at.desc())
@@ -451,7 +451,7 @@ def set_client_password(
.filter(
ClientModel.email == email.lower(),
ClientModel.id != client_id,
- ClientModel.is_deleted == False
+ ClientModel.is_deleted.is_(False)
)
.first()
)
@@ -507,7 +507,7 @@ def count_dashboard_clients(self) -> int:
return (
session.query(ClientModel)
.filter(
- ClientModel.is_deleted == False,
+ ClientModel.is_deleted.is_(False),
ClientModel.email.isnot(None)
)
.count()
diff --git a/mirix/services/agent_manager.py b/mirix/services/agent_manager.py
index 52f7d138..c211c5fa 100644
--- a/mirix/services/agent_manager.py
+++ b/mirix/services/agent_manager.py
@@ -23,7 +23,6 @@
)
from mirix.log import get_logger
from mirix.orm import Agent as AgentModel
-from mirix.orm import Block as BlockModel
from mirix.orm import Tool as ToolModel
from mirix.orm.errors import NoResultFound
from mirix.schemas.agent import AgentState as PydanticAgentState
@@ -31,7 +30,6 @@
AgentType,
CreateAgent,
CreateMetaAgent,
- MemoryConfig,
UpdateAgent,
UpdateMetaAgent,
)
@@ -184,6 +182,7 @@ def create_agent(
system=system,
agent_type=agent_create.agent_type,
llm_config=agent_create.llm_config,
+ topic_extraction_llm_config=agent_create.topic_extraction_llm_config,
embedding_config=agent_create.embedding_config,
memory_config=agent_create.memory_config,
tool_ids=tool_ids,
@@ -298,20 +297,23 @@ def create_meta_agent(
# First, create the meta_memory_agent as the parent
meta_agent_name = meta_agent_create.name or "meta_memory_agent"
meta_system_prompt = None
- if (
- meta_agent_create.system_prompts
- and "meta_memory_agent" in meta_agent_create.system_prompts
- ):
- meta_system_prompt = meta_agent_create.system_prompts["meta_memory_agent"]
- elif has_child_agents:
- # Use the standard meta_memory_agent prompt with trigger_memory_update
- meta_system_prompt = default_system_prompts["meta_memory_agent"]
- else:
- # No child agents - use the direct prompt with episodic memory tools
- meta_system_prompt = default_system_prompts.get(
- "meta_memory_agent_direct",
- default_system_prompts["meta_memory_agent"] # Fallback
- )
+
+ if meta_agent_create.system_prompts:
+ if has_child_agents:
+ meta_system_prompt = meta_agent_create.system_prompts['meta_memory_agent_direct']
+ else:
+ meta_system_prompt = meta_agent_create.system_prompts['meta_memory_agent']
+
+ if meta_system_prompt is None:
+ if has_child_agents:
+ # Use the standard meta_memory_agent prompt with trigger_memory_update
+ meta_system_prompt = default_system_prompts["meta_memory_agent"]
+ else:
+ # No child agents - use the direct prompt with episodic memory tools
+ meta_system_prompt = default_system_prompts.get(
+ "meta_memory_agent_direct",
+ default_system_prompts["meta_memory_agent"] # Fallback
+ )
# Build memory_config dict from the MemoryConfig (if decay settings provided)
memory_config_dict = None
@@ -334,6 +336,7 @@ def create_meta_agent(
agent_type=AgentType.meta_memory_agent,
system=meta_system_prompt,
llm_config=meta_agent_create.llm_config,
+ topic_extraction_llm_config=meta_agent_create.topic_extraction_llm_config,
embedding_config=meta_agent_create.embedding_config,
include_base_tools=True,
)
@@ -380,10 +383,11 @@ def create_meta_agent(
# Create the agent using CreateAgent schema with parent_id
agent_create = CreateAgent(
- name=f"{meta_agent_name}_{agent_name}",
+ name=agent_name,
agent_type=agent_type,
system=custom_system, # Uses custom prompt or default from base folder
llm_config=meta_agent_create.llm_config,
+ topic_extraction_llm_config=meta_agent_create.topic_extraction_llm_config,
embedding_config=meta_agent_create.embedding_config,
include_base_tools=True,
parent_id=meta_agent_state.id, # Set the parent_id
@@ -551,6 +555,10 @@ def update_meta_agent(
meta_agent_update_fields["name"] = meta_agent_update.name
if meta_agent_update.llm_config is not None:
meta_agent_update_fields["llm_config"] = meta_agent_update.llm_config
+ if meta_agent_update.topic_extraction_llm_config is not None:
+ meta_agent_update_fields["topic_extraction_llm_config"] = (
+ meta_agent_update.topic_extraction_llm_config
+ )
if meta_agent_update.embedding_config is not None:
meta_agent_update_fields["embedding_config"] = (
meta_agent_update.embedding_config
@@ -568,18 +576,6 @@ def update_meta_agent(
)
meta_agent_state = self.get_agent_by_id(agent_id=meta_agent_id, actor=actor)
- # Update meta agent's system prompt if provided (separate call needed for rebuild_system_prompt)
- if (
- meta_agent_update.system_prompts
- and "meta_memory_agent" in meta_agent_update.system_prompts
- ):
- self.update_system_prompt(
- agent_id=meta_agent_id,
- system_prompt=meta_agent_update.system_prompts["meta_memory_agent"],
- actor=actor,
- )
- meta_agent_state = self.get_agent_by_id(agent_id=meta_agent_id, actor=actor)
-
# Get existing sub-agents
existing_children = self.list_agents(actor=actor, parent_id=meta_agent_id)
existing_agent_names = set()
@@ -642,6 +638,10 @@ def update_meta_agent(
# Use the updated configs or fall back to meta agent's configs
llm_config = meta_agent_update.llm_config or meta_agent_state.llm_config
+ topic_extraction_llm_config = (
+ meta_agent_update.topic_extraction_llm_config
+ or meta_agent_state.topic_extraction_llm_config
+ )
if meta_agent_update.clear_embedding_config:
embedding_config = None
else:
@@ -652,10 +652,11 @@ def update_meta_agent(
# Create the agent using CreateAgent schema with parent_id
agent_create = CreateAgent(
- name=f"{meta_agent_state.name}_{agent_name}",
+ name=agent_name,
agent_type=agent_type,
system=custom_system,
llm_config=llm_config,
+ topic_extraction_llm_config=topic_extraction_llm_config,
embedding_config=embedding_config,
include_base_tools=True,
parent_id=meta_agent_id,
@@ -670,12 +671,24 @@ def update_meta_agent(
logger.debug(
f"Created sub-agent: {agent_name} with id: {new_agent_state.id}, parent_id: {meta_agent_id}"
)
+ else:
+ # If agents is None, delete all existing sub-agents
+ logger.debug(
+ "agents field is None - deleting all sub-agents for meta agent %s",
+ meta_agent_id,
+ )
+ for agent_name, child_agent in existing_agents_by_name.items():
+ logger.debug(
+ "Deleting sub-agent: %s with id: %s", agent_name, child_agent.id
+ )
+ self.delete_agent(agent_id=child_agent.id, actor=actor)
+ existing_agents_by_name.clear()
# Update system prompts for existing sub-agents
if meta_agent_update.system_prompts:
for agent_name, system_prompt in meta_agent_update.system_prompts.items():
- # Skip meta_memory_agent as we already updated it
- if agent_name == "meta_memory_agent":
+ # Skip meta_memory_agent and meta_memory_agent_direct as we handle those separately
+ if agent_name in ["meta_memory_agent", "meta_memory_agent_direct"]:
continue
if agent_name in existing_agents_by_name:
@@ -693,6 +706,7 @@ def update_meta_agent(
# Update llm_config and embedding_config for all sub-agents if provided
if (
meta_agent_update.llm_config
+ or meta_agent_update.topic_extraction_llm_config
or meta_agent_update.embedding_config
or meta_agent_update.clear_embedding_config
):
@@ -700,6 +714,10 @@ def update_meta_agent(
update_fields = {}
if meta_agent_update.llm_config is not None:
update_fields["llm_config"] = meta_agent_update.llm_config
+ if meta_agent_update.topic_extraction_llm_config is not None:
+ update_fields["topic_extraction_llm_config"] = (
+ meta_agent_update.topic_extraction_llm_config
+ )
if meta_agent_update.embedding_config is not None:
update_fields["embedding_config"] = (
meta_agent_update.embedding_config
@@ -715,6 +733,84 @@ def update_meta_agent(
actor=actor,
)
+ # Update meta_memory_agent's system prompt and tools based on whether it has children
+ has_children = len(existing_agents_by_name) > 0
+
+ # Determine the appropriate system prompt
+ meta_system_prompt = None
+
+ # First check if custom system prompts are provided
+ if meta_agent_update.system_prompts:
+ if has_children:
+ meta_system_prompt = meta_agent_update.system_prompts.get('meta_memory_agent')
+ else:
+ meta_system_prompt = meta_agent_update.system_prompts.get('meta_memory_agent_direct')
+
+ # If no custom system prompt, use defaults based on children
+ if meta_system_prompt is None:
+ if has_children:
+ # Has child agents - use standard meta_memory_agent prompt
+ meta_system_prompt = default_system_prompts.get("meta_memory_agent")
+ logger.info(
+ "Meta agent has children - using meta_memory_agent.txt system prompt"
+ )
+ else:
+ # No child agents - use direct prompt with memory tools
+ meta_system_prompt = default_system_prompts.get(
+ "meta_memory_agent_direct",
+ default_system_prompts.get("meta_memory_agent") # Fallback
+ )
+ logger.info(
+ "Meta agent has no children - using meta_memory_agent_direct.txt system prompt"
+ )
+
+ # Update the system prompt
+ if meta_system_prompt:
+ self.update_system_prompt(
+ agent_id=meta_agent_id,
+ system_prompt=meta_system_prompt,
+ actor=actor,
+ )
+
+ # Update meta_memory_agent's tools based on whether it has children
+ meta_agent_state = self.get_agent_by_id(agent_id=meta_agent_id, actor=actor)
+
+ # Get the appropriate tool names based on children
+ if has_children:
+ # Has child agents - use trigger_memory_update
+ logger.info(
+ "Configuring meta_memory_agent with trigger_memory_update (has children)"
+ )
+ tool_names_to_use = META_MEMORY_TOOLS + UNIVERSAL_MEMORY_TOOLS
+ else:
+ # No child agents - use direct memory tools
+ logger.info(
+ "Configuring meta_memory_agent with direct memory tools (no children)"
+ )
+ tool_names_to_use = META_MEMORY_TOOLS_DIRECT + UNIVERSAL_MEMORY_TOOLS
+
+ # Get BASE_TOOLS + the appropriate memory tools
+ from mirix.constants import BASE_TOOLS
+ all_tool_names = BASE_TOOLS + tool_names_to_use
+
+ # Collect tool IDs for the complete tool set
+ new_tool_ids = []
+ for tool_name in all_tool_names:
+ tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
+ if tool:
+ new_tool_ids.append(tool.id)
+ else:
+ logger.debug("Tool %s not found", tool_name)
+
+ # Remove duplicates
+ new_tool_ids = list(set(new_tool_ids))
+
+ self.update_agent(
+ agent_id=meta_agent_id,
+ agent_update=UpdateAgent(tool_ids=new_tool_ids),
+ actor=actor,
+ )
+
# Refresh the meta agent state with updated children
meta_agent_state = self.get_agent_by_id(agent_id=meta_agent_id, actor=actor)
updated_children = self.list_agents(actor=actor, parent_id=meta_agent_id)
@@ -898,6 +994,7 @@ def _create_agent(
system: str,
agent_type: AgentType,
llm_config: LLMConfig,
+ topic_extraction_llm_config: Optional[LLMConfig],
embedding_config: Optional[EmbeddingConfig],
memory_config: Optional[Dict[str, Any]],
tool_ids: List[str],
@@ -916,6 +1013,7 @@ def _create_agent(
"system": system,
"agent_type": agent_type,
"llm_config": llm_config,
+ "topic_extraction_llm_config": topic_extraction_llm_config,
"embedding_config": embedding_config,
"memory_config": memory_config,
"organization_id": actor.organization_id,
@@ -1065,6 +1163,7 @@ def _update_agent(
"name",
"system",
"llm_config",
+ "topic_extraction_llm_config",
"embedding_config",
"message_ids",
"tool_rules",
@@ -1173,7 +1272,6 @@ def _reconstruct_children_from_cache(
import json
from mirix.database.redis_client import get_redis_client
- from mirix.schemas.block import Block as PydanticBlock
from mirix.schemas.memory import Memory as PydanticMemory
from mirix.schemas.tool import Tool as PydanticTool
@@ -1239,6 +1337,12 @@ def _reconstruct_children_from_cache(
if isinstance(child_data["llm_config"], (str, bytes))
else child_data["llm_config"]
)
+ if "topic_extraction_llm_config" in child_data:
+ child_data["topic_extraction_llm_config"] = (
+ json.loads(child_data["topic_extraction_llm_config"])
+ if isinstance(child_data["topic_extraction_llm_config"], (str, bytes))
+ else child_data["topic_extraction_llm_config"]
+ )
if "embedding_config" in child_data:
child_data["embedding_config"] = (
json.loads(child_data["embedding_config"])
@@ -1606,6 +1710,12 @@ def get_agent_by_id(
if isinstance(cached_data["llm_config"], str)
else cached_data["llm_config"]
)
+ if "topic_extraction_llm_config" in cached_data:
+ cached_data["topic_extraction_llm_config"] = (
+ json.loads(cached_data["topic_extraction_llm_config"])
+ if isinstance(cached_data["topic_extraction_llm_config"], str)
+ else cached_data["topic_extraction_llm_config"]
+ )
if "embedding_config" in cached_data:
cached_data["embedding_config"] = (
json.loads(cached_data["embedding_config"])
@@ -1753,6 +1863,10 @@ def get_agent_by_id(
data["message_ids"] = json.dumps(data["message_ids"])
if "llm_config" in data and data["llm_config"]:
data["llm_config"] = json.dumps(data["llm_config"])
+ if "topic_extraction_llm_config" in data and data["topic_extraction_llm_config"]:
+ data["topic_extraction_llm_config"] = json.dumps(
+ data["topic_extraction_llm_config"]
+ )
if "embedding_config" in data and data["embedding_config"]:
data["embedding_config"] = json.dumps(data["embedding_config"])
if "tool_rules" in data and data["tool_rules"]:
@@ -2119,7 +2233,7 @@ def reset_messages(
if add_default_initial_messages:
return self.append_initial_message_sequence_to_in_context_messages(
- user, agent_state
+ actor, agent_state
)
else:
# We still want to always have a system message
diff --git a/mirix/services/block_manager.py b/mirix/services/block_manager.py
index 7b713c0e..a732587a 100755
--- a/mirix/services/block_manager.py
+++ b/mirix/services/block_manager.py
@@ -1,16 +1,14 @@
-import os
from typing import List, Optional
from mirix.log import get_logger
from mirix.orm.block import Block as BlockModel
from mirix.orm.errors import NoResultFound
-from mirix.schemas.block import Block, BlockUpdate, Human, Persona
+from mirix.schemas.block import Block, BlockUpdate
from mirix.schemas.block import Block as PydanticBlock
from mirix.schemas.client import Client as PydanticClient
from mirix.schemas.user import User as PydanticUser
-from mirix.utils import enforce_types, list_human_files, list_persona_files
+from mirix.utils import enforce_types
-from mirix.orm import user
from mirix.orm.enums import AccessType
logger = get_logger(__name__)
@@ -236,7 +234,6 @@ def _copy_blocks_from_admin_user(
List of newly created BlockModel instances
"""
from mirix.services.user_manager import UserManager
- from mirix.utils import generate_unique_short_id
# ✅ NEW: Get the organization-specific default user
# This ensures we copy blocks from the correct template within the same organization
@@ -441,7 +438,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Query all non-deleted records for this user
blocks = session.query(BlockModel).filter(
BlockModel.user_id == user_id,
- BlockModel.is_deleted == False
+ BlockModel.is_deleted.is_(False)
).all()
count = len(blocks)
@@ -520,31 +517,3 @@ def delete_by_user_id(self, user_id: str) -> int:
redis_client.client.delete(*batch)
return count
-
- @enforce_types
- def add_default_blocks(self, actor: PydanticClient, user: Optional[PydanticUser] = None):
- """
- Add default persona and human blocks.
-
- Args:
- actor: Client for audit trail
- user_id: Optional user_id for block ownership (uses default if not provided)
- """
- # Use admin user if not provided
- if user is None:
- from mirix.services.user_manager import UserManager
- user = UserManager().get_admin_user()
-
- for persona_file in list_persona_files():
- text = open(persona_file, "r", encoding="utf-8").read()
- name = os.path.basename(persona_file).replace(".txt", "")
- self.create_or_update_block(
- Persona(value=text), actor=actor, user=user
- )
-
- for human_file in list_human_files():
- text = open(human_file, "r", encoding="utf-8").read()
- name = os.path.basename(human_file).replace(".txt", "")
- self.create_or_update_block(
- Human(value=text), actor=actor, user=user
- )
diff --git a/mirix/services/client_manager.py b/mirix/services/client_manager.py
index 70f8f7ba..fe5a2c9e 100644
--- a/mirix/services/client_manager.py
+++ b/mirix/services/client_manager.py
@@ -7,7 +7,6 @@
from mirix.schemas.client import Client as PydanticClient
from mirix.schemas.client import ClientUpdate
from mirix.schemas.client_api_key import ClientApiKey as PydanticClientApiKey
-from mirix.schemas.client_api_key import ClientApiKeyCreate
from mirix.services.organization_manager import OrganizationManager
from mirix.utils import enforce_types
from mirix.security.api_keys import hash_api_key
@@ -134,7 +133,7 @@ def get_client_by_api_key(self, api_key: str) -> Optional[PydanticClient]:
.filter(
ClientApiKeyModel.api_key_hash == hashed,
ClientApiKeyModel.status == "active",
- ClientApiKeyModel.is_deleted == False
+ ClientApiKeyModel.is_deleted.is_(False)
)
.first()
)
@@ -156,7 +155,7 @@ def list_client_api_keys(self, client_id: str) -> List[PydanticClientApiKey]:
session.query(ClientApiKeyModel)
.filter(
ClientApiKeyModel.client_id == client_id,
- ClientApiKeyModel.is_deleted == False
+ ClientApiKeyModel.is_deleted.is_(False)
)
.all()
)
@@ -319,7 +318,7 @@ def delete_client_by_id(self, client_id: str):
agents_created_by_client = session.query(AgentModel).filter(
AgentModel._created_by_id == client_id,
- AgentModel.is_deleted == False
+ AgentModel.is_deleted.is_(False)
).all()
agent_ids = [agent.id for agent in agents_created_by_client]
logger.debug("Found %d agents created by client %s", len(agent_ids), client_id)
@@ -333,7 +332,7 @@ def delete_client_by_id(self, client_id: str):
# Soft delete tools created by this client
tools = session.query(ToolModel).filter(
ToolModel._created_by_id == client_id,
- ToolModel.is_deleted == False
+ ToolModel.is_deleted.is_(False)
).all()
for tool in tools:
tool.is_deleted = True
@@ -343,7 +342,7 @@ def delete_client_by_id(self, client_id: str):
# Soft delete blocks created by this client
blocks = session.query(BlockModel).filter(
BlockModel._created_by_id == client_id,
- BlockModel.is_deleted == False
+ BlockModel.is_deleted.is_(False)
).all()
for block in blocks:
block.is_deleted = True
diff --git a/mirix/services/episodic_memory_manager.py b/mirix/services/episodic_memory_manager.py
index fb154a59..7fb90fa4 100755
--- a/mirix/services/episodic_memory_manager.py
+++ b/mirix/services/episodic_memory_manager.py
@@ -222,13 +222,17 @@ def get_episodic_memory_by_id(
@update_timezone
@enforce_types
def get_most_recently_updated_event(
- self, user: PydanticUser, timezone_str: str = None
+ self,
+ actor: PydanticClient,
+ user_id: str,
+ timezone_str: str = None
) -> Optional[PydanticEpisodicEvent]:
"""
Fetch the most recently updated episodic event based on last_modify timestamp.
Args:
- user: User who owns the memories to query
+ actor: Client performing the operation
+ user_id: User who owns the memories to query
timezone_str: Optional timezone string
Returns:
@@ -240,7 +244,7 @@ def get_most_recently_updated_event(
query = (
select(EpisodicEvent)
- .where(EpisodicEvent.user_id == user.id)
+ .where(EpisodicEvent.user_id == user_id)
.order_by(
cast(
text("episodic_memory.last_modify ->> 'timestamp'"), DateTime
@@ -251,7 +255,7 @@ def get_most_recently_updated_event(
result = session.execute(query.limit(1))
episodic_memory = result.scalar_one_or_none()
- return [episodic_memory.to_pydantic()] if episodic_memory else None
+ return episodic_memory.to_pydantic() if episodic_memory else None
@enforce_types
def create_episodic_memory(
@@ -423,7 +427,7 @@ def soft_delete_by_client_id(self, actor: PydanticClient) -> int:
# Query all non-deleted records for this client (use actor.id)
items = session.query(EpisodicEvent).filter(
EpisodicEvent.client_id == actor.id,
- EpisodicEvent.is_deleted == False
+ EpisodicEvent.is_deleted.is_(False)
).all()
count = len(items)
@@ -471,7 +475,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Extract IDs BEFORE bulk update (for Redis cleanup)
item_ids = [row[0] for row in session.query(EpisodicEvent.id).filter(
EpisodicEvent.user_id == user_id,
- EpisodicEvent.is_deleted == False
+ EpisodicEvent.is_deleted.is_(False)
).all()]
count = len(item_ids)
@@ -481,7 +485,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Batch soft delete in database using single SQL UPDATE
session.query(EpisodicEvent).filter(
EpisodicEvent.user_id == user_id,
- EpisodicEvent.is_deleted == False
+ EpisodicEvent.is_deleted.is_(False)
).update(
{
"is_deleted": True,
@@ -1630,7 +1634,6 @@ def list_episodic_memory_by_org(
return [event.to_pydantic() for event in episodic_memory]
if search_method == "embedding":
- embed_query = True
embedding_config = agent_state.embedding_config
# Use provided embedding or generate it
@@ -1668,7 +1671,7 @@ def list_episodic_memory_by_org(
base_query = base_query.order_by(embedding_query_field)
elif search_method == "bm25":
# Use PostgreSQL native full-text search if available
- from sqlalchemy import text, func
+ from sqlalchemy import func
# Determine search field
if not search_field or search_field == "details":
diff --git a/mirix/services/knowledge_memory_manager.py b/mirix/services/knowledge_memory_manager.py
index 8831e85c..935bfb8d 100644
--- a/mirix/services/knowledge_memory_manager.py
+++ b/mirix/services/knowledge_memory_manager.py
@@ -477,7 +477,10 @@ def get_item_by_id(
@update_timezone
@enforce_types
def get_most_recently_updated_item(
- self, user: PydanticUser, timezone_str: str = None
+ self,
+ actor: PydanticClient,
+ user_id: str,
+ timezone_str: str = None
) -> Optional[PydanticKnowledgeItem]:
"""
Fetch the most recently updated knowledge item based on last_modify timestamp.
@@ -495,12 +498,12 @@ def get_most_recently_updated_item(
)
# Filter by user_id for multi-user support
- query = query.where(KnowledgeItem.user_id == user.id)
+ query = query.where(KnowledgeItem.user_id == user_id)
result = session.execute(query.limit(1))
item = result.scalar_one_or_none()
- return [item.to_pydantic()] if item else None
+ return item.to_pydantic() if item else None
@enforce_types
def create_item(
@@ -1087,7 +1090,7 @@ def soft_delete_by_client_id(self, actor: PydanticClient) -> int:
# Query all non-deleted records for this client (use actor.id)
items = session.query(KnowledgeItem).filter(
KnowledgeItem.client_id == actor.id,
- KnowledgeItem.is_deleted == False
+ KnowledgeItem.is_deleted.is_(False)
).all()
count = len(items)
@@ -1133,7 +1136,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Query all non-deleted records for this user
items = session.query(KnowledgeItem).filter(
KnowledgeItem.user_id == user_id,
- KnowledgeItem.is_deleted == False
+ KnowledgeItem.is_deleted.is_(False)
).all()
count = len(items)
diff --git a/mirix/services/memory_agent_tool_call_trace_manager.py b/mirix/services/memory_agent_tool_call_trace_manager.py
index 9715a668..338547db 100644
--- a/mirix/services/memory_agent_tool_call_trace_manager.py
+++ b/mirix/services/memory_agent_tool_call_trace_manager.py
@@ -2,6 +2,8 @@
from datetime import datetime
from typing import Optional
+from sqlalchemy import text
+
from mirix.orm.memory_agent_tool_call import MemoryAgentToolCall
from mirix.schemas.client import Client as PydanticClient
from mirix.schemas.memory_agent_tool_call import (
@@ -15,6 +17,44 @@ def __init__(self):
from mirix.server.server import db_context
self.session_maker = db_context
+ self._ensure_usage_columns()
+
+ def _ensure_usage_columns(self) -> None:
+ with self.session_maker() as session:
+ bind = session.get_bind()
+ if not bind or bind.dialect.name != "sqlite":
+ return
+ results = session.execute(text("PRAGMA table_info(memory_agent_tool_calls)"))
+ existing = {row[1] for row in results.fetchall()}
+ statements = []
+ if "llm_call_id" not in existing:
+ statements.append(
+ "ALTER TABLE memory_agent_tool_calls ADD COLUMN llm_call_id VARCHAR"
+ )
+ if "prompt_tokens" not in existing:
+ statements.append(
+ "ALTER TABLE memory_agent_tool_calls ADD COLUMN prompt_tokens INTEGER"
+ )
+ if "completion_tokens" not in existing:
+ statements.append(
+ "ALTER TABLE memory_agent_tool_calls ADD COLUMN completion_tokens INTEGER"
+ )
+ if "cached_tokens" not in existing:
+ statements.append(
+ "ALTER TABLE memory_agent_tool_calls ADD COLUMN cached_tokens INTEGER"
+ )
+ if "total_tokens" not in existing:
+ statements.append(
+ "ALTER TABLE memory_agent_tool_calls ADD COLUMN total_tokens INTEGER"
+ )
+ if "credit_cost" not in existing:
+ statements.append(
+ "ALTER TABLE memory_agent_tool_calls ADD COLUMN credit_cost FLOAT"
+ )
+ if statements:
+ for statement in statements:
+ session.execute(text(statement))
+ session.commit()
def start_tool_call(
self,
@@ -22,6 +62,12 @@ def start_tool_call(
function_name: str,
function_args: Optional[dict],
tool_call_id: Optional[str] = None,
+ llm_call_id: Optional[str] = None,
+ prompt_tokens: Optional[int] = None,
+ completion_tokens: Optional[int] = None,
+ cached_tokens: Optional[int] = None,
+ total_tokens: Optional[int] = None,
+ credit_cost: Optional[float] = None,
actor: Optional[PydanticClient] = None,
) -> PydanticMemoryAgentToolCall:
trace_id = generate_unique_short_id(
@@ -33,6 +79,12 @@ def start_tool_call(
tool_call_id=tool_call_id,
function_name=function_name,
function_args=function_args,
+ llm_call_id=llm_call_id,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ cached_tokens=cached_tokens,
+ total_tokens=total_tokens,
+ credit_cost=credit_cost,
status="running",
started_at=datetime.now(dt.timezone.utc),
)
@@ -46,6 +98,12 @@ def finish_tool_call(
success: bool,
response_text: Optional[str] = None,
error_message: Optional[str] = None,
+ llm_call_id: Optional[str] = None,
+ prompt_tokens: Optional[int] = None,
+ completion_tokens: Optional[int] = None,
+ cached_tokens: Optional[int] = None,
+ total_tokens: Optional[int] = None,
+ credit_cost: Optional[float] = None,
actor: Optional[PydanticClient] = None,
) -> None:
with self.session_maker() as session:
@@ -56,5 +114,17 @@ def finish_tool_call(
trace.success = success
trace.response_text = response_text
trace.error_message = error_message
+ if llm_call_id is not None:
+ trace.llm_call_id = llm_call_id
+ if prompt_tokens is not None:
+ trace.prompt_tokens = prompt_tokens
+ if completion_tokens is not None:
+ trace.completion_tokens = completion_tokens
+ if cached_tokens is not None:
+ trace.cached_tokens = cached_tokens
+ if total_tokens is not None:
+ trace.total_tokens = total_tokens
+ if credit_cost is not None:
+ trace.credit_cost = credit_cost
trace.completed_at = datetime.now(dt.timezone.utc)
trace.update(session, actor=actor)
diff --git a/mirix/services/memory_queue_trace_manager.py b/mirix/services/memory_queue_trace_manager.py
index 6d0e507e..d1e32fd4 100644
--- a/mirix/services/memory_queue_trace_manager.py
+++ b/mirix/services/memory_queue_trace_manager.py
@@ -2,6 +2,8 @@
from datetime import datetime
from typing import Dict, Optional
+from sqlalchemy import text
+
from mirix.orm.memory_agent_trace import MemoryAgentTrace
from mirix.orm.memory_queue_trace import MemoryQueueTrace
from mirix.schemas.client import Client as PydanticClient
@@ -14,6 +16,28 @@ def __init__(self):
from mirix.server.server import db_context
self.session_maker = db_context
+ self._ensure_interrupt_columns()
+
+ def _ensure_interrupt_columns(self) -> None:
+ with self.session_maker() as session:
+ bind = session.get_bind()
+ if not bind or bind.dialect.name != "sqlite":
+ return
+ results = session.execute(text("PRAGMA table_info(memory_queue_traces)"))
+ existing = {row[1] for row in results.fetchall()}
+ statements = []
+ if "interrupt_requested_at" not in existing:
+ statements.append(
+ "ALTER TABLE memory_queue_traces ADD COLUMN interrupt_requested_at DATETIME"
+ )
+ if "interrupt_reason" not in existing:
+ statements.append(
+ "ALTER TABLE memory_queue_traces ADD COLUMN interrupt_reason TEXT"
+ )
+ if statements:
+ for statement in statements:
+ session.execute(text(statement))
+ session.commit()
def create_trace(
self,
@@ -60,6 +84,8 @@ def mark_completed(
trace = session.get(MemoryQueueTrace, trace_id)
if not trace:
return
+ if trace.completed_at is not None:
+ return
trace.status = "completed" if success else "failed"
trace.success = success
trace.error_message = error_message
@@ -68,6 +94,36 @@ def mark_completed(
trace.memory_update_counts = memory_update_counts
trace.update(session, actor=actor)
+ def request_interrupt(
+ self,
+ trace_id: str,
+ reason: Optional[str] = None,
+ actor: Optional[PydanticClient] = None,
+ ) -> None:
+ with self.session_maker() as session:
+ trace = session.get(MemoryQueueTrace, trace_id)
+ if not trace:
+ return
+ if trace.interrupt_requested_at is None:
+ trace.interrupt_requested_at = datetime.now(dt.timezone.utc)
+ if reason:
+ trace.interrupt_reason = reason
+ trace.update(session, actor=actor)
+
+ def is_interrupt_requested(self, trace_id: str) -> bool:
+ with self.session_maker() as session:
+ trace = session.get(MemoryQueueTrace, trace_id)
+ if not trace:
+ return False
+ return trace.interrupt_requested_at is not None
+
+ def get_interrupt_reason(self, trace_id: str) -> Optional[str]:
+ with self.session_maker() as session:
+ trace = session.get(MemoryQueueTrace, trace_id)
+ if not trace:
+ return None
+ return trace.interrupt_reason
+
def set_meta_agent_output(
self,
trace_id: str,
diff --git a/mirix/services/message_manager.py b/mirix/services/message_manager.py
index 3335c387..20db0c70 100755
--- a/mirix/services/message_manager.py
+++ b/mirix/services/message_manager.py
@@ -265,7 +265,7 @@ def soft_delete_by_client_id(self, actor: PydanticClient) -> int:
# Query all non-deleted records for this client (use actor.id)
messages = session.query(MessageModel).filter(
MessageModel.client_id == actor.id,
- MessageModel.is_deleted == False
+ MessageModel.is_deleted.is_(False)
).all()
count = len(messages)
@@ -311,7 +311,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Query all non-deleted records for this user
messages = session.query(MessageModel).filter(
MessageModel.user_id == user_id,
- MessageModel.is_deleted == False
+ MessageModel.is_deleted.is_(False)
).all()
count = len(messages)
diff --git a/mirix/services/procedural_memory_manager.py b/mirix/services/procedural_memory_manager.py
index 43a6d3df..e42e1a8d 100755
--- a/mirix/services/procedural_memory_manager.py
+++ b/mirix/services/procedural_memory_manager.py
@@ -442,7 +442,10 @@ def get_item_by_id(
@update_timezone
@enforce_types
def get_most_recently_updated_item(
- self, user: PydanticUser, timezone_str: str = None
+ self,
+ actor: PydanticClient,
+ user_id: str,
+ timezone_str: str = None
) -> Optional[PydanticProceduralMemoryItem]:
"""
Fetch the most recently updated procedural memory item based on last_modify timestamp.
@@ -460,12 +463,12 @@ def get_most_recently_updated_item(
)
# Filter by user_id for multi-user support
- query = query.where(ProceduralMemoryItem.user_id == user.id)
+ query = query.where(ProceduralMemoryItem.user_id == user_id)
result = session.execute(query.limit(1))
item = result.scalar_one_or_none()
- return [item.to_pydantic()] if item else None
+ return item.to_pydantic() if item else None
@enforce_types
def create_item(
@@ -1061,7 +1064,7 @@ def soft_delete_by_client_id(self, actor: PydanticClient) -> int:
# Query all non-deleted records for this client (use actor.id)
items = session.query(ProceduralMemoryItem).filter(
ProceduralMemoryItem.client_id == actor.id,
- ProceduralMemoryItem.is_deleted == False
+ ProceduralMemoryItem.is_deleted.is_(False)
).all()
count = len(items)
@@ -1107,7 +1110,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Query all non-deleted records for this user
items = session.query(ProceduralMemoryItem).filter(
ProceduralMemoryItem.user_id == user_id,
- ProceduralMemoryItem.is_deleted == False
+ ProceduralMemoryItem.is_deleted.is_(False)
).all()
count = len(items)
diff --git a/mirix/services/resource_memory_manager.py b/mirix/services/resource_memory_manager.py
index b167a611..d7cdeac2 100755
--- a/mirix/services/resource_memory_manager.py
+++ b/mirix/services/resource_memory_manager.py
@@ -397,7 +397,10 @@ def get_item_by_id(
@update_timezone
@enforce_types
def get_most_recently_updated_item(
- self, user: PydanticUser, timezone_str: str = None
+ self,
+ actor: PydanticClient,
+ user_id: str,
+ timezone_str: str = None
) -> Optional[PydanticResourceMemoryItem]:
"""
Fetch the most recently updated resource memory item based on last_modify timestamp.
@@ -410,7 +413,7 @@ def get_most_recently_updated_item(
query = (
select(ResourceMemoryItem)
- .where(ResourceMemoryItem.user_id == user.id)
+ .where(ResourceMemoryItem.user_id == user_id)
.order_by(
cast(
text("resource_memory.last_modify ->> 'timestamp'"), DateTime
@@ -421,7 +424,7 @@ def get_most_recently_updated_item(
result = session.execute(query.limit(1))
item = result.scalar_one_or_none()
- return [item.to_pydantic()] if item else None
+ return item.to_pydantic() if item else None
@enforce_types
def create_item(
@@ -966,7 +969,7 @@ def soft_delete_by_client_id(self, actor: PydanticClient) -> int:
# Query all non-deleted records for this client (use actor.id)
items = session.query(ResourceMemoryItem).filter(
ResourceMemoryItem.client_id == actor.id,
- ResourceMemoryItem.is_deleted == False
+ ResourceMemoryItem.is_deleted.is_(False)
).all()
count = len(items)
@@ -1012,7 +1015,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Query all non-deleted records for this user
items = session.query(ResourceMemoryItem).filter(
ResourceMemoryItem.user_id == user_id,
- ResourceMemoryItem.is_deleted == False
+ ResourceMemoryItem.is_deleted.is_(False)
).all()
count = len(items)
diff --git a/mirix/services/self_reflection_service.py b/mirix/services/self_reflection_service.py
index ba1cc469..4d3741b8 100644
--- a/mirix/services/self_reflection_service.py
+++ b/mirix/services/self_reflection_service.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime, timezone
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar
from sqlalchemy import and_, desc, select
diff --git a/mirix/services/semantic_memory_manager.py b/mirix/services/semantic_memory_manager.py
index 346429d3..d092d693 100755
--- a/mirix/services/semantic_memory_manager.py
+++ b/mirix/services/semantic_memory_manager.py
@@ -458,7 +458,10 @@ def get_semantic_item_by_id(
@update_timezone
@enforce_types
def get_most_recently_updated_item(
- self, user: PydanticUser, timezone_str: str = None
+ self,
+ actor: PydanticClient,
+ user_id: str,
+ timezone_str: str = None
) -> Optional[PydanticSemanticMemoryItem]:
"""
Fetch the most recently updated semantic memory item based on last_modify timestamp.
@@ -476,12 +479,12 @@ def get_most_recently_updated_item(
)
# Filter by user_id for multi-user support
- query = query.where(SemanticMemoryItem.user_id == user.id)
+ query = query.where(SemanticMemoryItem.user_id == user_id)
result = session.execute(query.limit(1))
item = result.scalar_one_or_none()
- return [item.to_pydantic()] if item else None
+ return item.to_pydantic() if item else None
@enforce_types
def create_item(
@@ -1078,7 +1081,7 @@ def soft_delete_by_client_id(self, actor: PydanticClient) -> int:
# Query all non-deleted records for this client (use actor.id)
items = session.query(SemanticMemoryItem).filter(
SemanticMemoryItem.client_id == actor.id,
- SemanticMemoryItem.is_deleted == False
+ SemanticMemoryItem.is_deleted.is_(False)
).all()
count = len(items)
@@ -1124,7 +1127,7 @@ def soft_delete_by_user_id(self, user_id: str) -> int:
# Query all non-deleted records for this user
items = session.query(SemanticMemoryItem).filter(
SemanticMemoryItem.user_id == user_id,
- SemanticMemoryItem.is_deleted == False
+ SemanticMemoryItem.is_deleted.is_(False)
).all()
count = len(items)
diff --git a/mirix/services/user_manager.py b/mirix/services/user_manager.py
index 77dc7490..ab9eb6b1 100755
--- a/mirix/services/user_manager.py
+++ b/mirix/services/user_manager.py
@@ -1,4 +1,5 @@
-from typing import List, Optional, Tuple
+from datetime import datetime
+from typing import List, Optional
from mirix.log import get_logger
from mirix.orm.errors import NoResultFound
@@ -135,7 +136,6 @@ def update_last_self_reflection_time(self, user_id: str, reflection_time: "datet
Returns:
Updated user object
"""
- from datetime import datetime
with self.session_maker() as session:
# Retrieve the existing user by ID
@@ -447,7 +447,7 @@ def get_or_create_org_admin_user(self, org_id: str, client_id: Optional[str] = N
user = session.query(UserModel).filter(
UserModel.name == self.ADMIN_USER_NAME,
UserModel.organization_id == org_id,
- UserModel.is_deleted == False
+ UserModel.is_deleted.is_(False)
).first()
if user:
@@ -493,8 +493,8 @@ def get_admin_user_for_client(self, client_id: str) -> Optional[PydanticUser]:
with self.session_maker() as session:
admin_user = session.query(UserModel).filter(
UserModel.client_id == client_id,
- UserModel.is_admin == True,
- UserModel.is_deleted == False,
+ UserModel.is_admin.is_(True),
+ UserModel.is_deleted.is_(False),
).first()
if admin_user:
@@ -529,7 +529,7 @@ def list_users(
organization_id: Filter by organization ID
"""
with self.session_maker() as session:
- query = session.query(UserModel).filter(UserModel.is_deleted == False)
+ query = session.query(UserModel).filter(UserModel.is_deleted.is_(False))
if client_id:
query = query.filter(UserModel.client_id == client_id)
diff --git a/mirix/services/utils.py b/mirix/services/utils.py
index 98020970..b3576c92 100644
--- a/mirix/services/utils.py
+++ b/mirix/services/utils.py
@@ -107,6 +107,8 @@ def build_query(
def update_timezone(func):
@wraps(func)
def wrapper(*args, **kwargs):
+ from datetime import datetime
+
# Access timezone_str from kwargs (it will be None if not provided)
timezone_str = kwargs.get("timezone_str")
@@ -122,43 +124,53 @@ def wrapper(*args, **kwargs):
if results is None:
return None
+ # Handle both single objects and lists
+ is_single_object = not isinstance(results, list)
+ results_list = [results] if is_single_object else results
+
+ # ALWAYS convert string timestamps to datetime objects, regardless of timezone_str
+ for result in results_list:
+ # Convert last_modify timestamp from string to datetime if needed
+ if (
+ hasattr(result, "last_modify")
+ and result.last_modify
+ and "timestamp" in result.last_modify
+ ):
+ timestamp = result.last_modify["timestamp"]
+ if isinstance(timestamp, str):
+ timestamp = datetime.fromisoformat(
+ timestamp.replace("Z", "+00:00")
+ )
+ result.last_modify["timestamp"] = timestamp
+
+ # Only do timezone conversion if timezone_str is provided
if timezone_str:
- for result in results:
+ target_tz = pytz.timezone(timezone_str.split(" (")[0])
+ for result in results_list:
if hasattr(result, "occurred_at"):
if result.occurred_at.tzinfo is None:
result.occurred_at = pytz.utc.localize(result.occurred_at)
- target_tz = pytz.timezone(timezone_str.split(" (")[0])
result.occurred_at = result.occurred_at.astimezone(target_tz)
if hasattr(result, "created_at"):
if result.created_at.tzinfo is None:
result.created_at = pytz.utc.localize(result.created_at)
- target_tz = pytz.timezone(timezone_str.split(" (")[0])
result.created_at = result.created_at.astimezone(target_tz)
if hasattr(result, "updated_at") and result.updated_at is not None:
if result.updated_at.tzinfo is None:
result.updated_at = pytz.utc.localize(result.updated_at)
- target_tz = pytz.timezone(timezone_str.split(" (")[0])
result.updated_at = result.updated_at.astimezone(target_tz)
if (
hasattr(result, "last_modify")
and result.last_modify
and "timestamp" in result.last_modify
):
- # Check if timestamp is a string (ISO format) and convert to datetime
+ # At this point, timestamp should already be a datetime object
timestamp = result.last_modify["timestamp"]
- if isinstance(timestamp, str):
- from datetime import datetime
-
- timestamp = datetime.fromisoformat(
- timestamp.replace("Z", "+00:00")
- )
-
- # Now handle timezone conversion
if timestamp.tzinfo is None:
timestamp = pytz.utc.localize(timestamp)
- target_tz = pytz.timezone(timezone_str.split(" (")[0])
result.last_modify["timestamp"] = timestamp.astimezone(target_tz)
- return results
+ # Return single object or list based on input
+ return results_list[0] if is_single_object else results_list
return wrapper
diff --git a/mirix/system.py b/mirix/system.py
index 3b834526..a4f36866 100755
--- a/mirix/system.py
+++ b/mirix/system.py
@@ -1,96 +1,13 @@
import json
-import uuid
import warnings
from typing import Optional
from .constants import (
- INITIAL_BOOT_MESSAGE,
- INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG,
- INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
MESSAGE_SUMMARY_WARNING_STR,
)
from .helpers.datetime_helpers import get_local_time
from .helpers.json_helpers import json_dumps
-
-def get_initial_boot_messages(version="startup"):
- if version == "startup":
- initial_boot_message = INITIAL_BOOT_MESSAGE
- messages = [
- {"role": "assistant", "content": initial_boot_message},
- ]
-
- elif version == "startup_with_send_message":
- tool_call_id = str(uuid.uuid4())
- messages = [
- # first message includes both inner monologue and function call to send_message
- {
- "role": "assistant",
- "content": INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
- # "function_call": {
- # "name": "send_message",
- # "arguments": '{\n "message": "' + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}" + '"\n}',
- # },
- "tool_calls": [
- {
- "id": tool_call_id,
- "type": "function",
- "function": {
- "name": "send_message",
- "arguments": '{\n "message": "'
- + f"{INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG}"
- + '"\n}',
- },
- }
- ],
- },
- # obligatory function return message
- {
- # "role": "function",
- "role": "tool",
- "name": "send_message", # NOTE: technically not up to spec, this is old functions style
- "content": package_function_response(True, None),
- "tool_call_id": tool_call_id,
- },
- ]
-
- elif version == "startup_with_send_message_gpt35":
- tool_call_id = str(uuid.uuid4())
- messages = [
- # first message includes both inner monologue and function call to send_message
- {
- "role": "assistant",
- "content": "*inner thoughts* Still waiting on the user. Sending a message with function.",
- # "function_call": {"name": "send_message", "arguments": '{\n "message": "' + f"Hi, is anyone there?" + '"\n}'},
- "tool_calls": [
- {
- "id": tool_call_id,
- "type": "function",
- "function": {
- "name": "send_message",
- "arguments": '{\n "message": "'
- + "Hi, is anyone there?"
- + '"\n}',
- },
- }
- ],
- },
- # obligatory function return message
- {
- # "role": "function",
- "role": "tool",
- "name": "send_message",
- "content": package_function_response(True, None),
- "tool_call_id": tool_call_id,
- },
- ]
-
- else:
- raise ValueError(version)
-
- return messages
-
-
def get_contine_chaining(
reason="Automated timer",
include_location=False,
diff --git a/mirix/topic_extraction.py b/mirix/topic_extraction.py
new file mode 100644
index 00000000..fc9673b0
--- /dev/null
+++ b/mirix/topic_extraction.py
@@ -0,0 +1,116 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+
+from mirix.log import get_logger
+from mirix.settings import model_settings
+
+logger = get_logger(__name__)
+
+
+def flatten_messages_to_plain_text(messages: List[Dict[str, Any]]) -> str:
+ transcript_parts = []
+ for msg in messages:
+ if not isinstance(msg, dict):
+ continue
+
+ role = msg.get("role", "user")
+ content = msg.get("content")
+ parts = []
+
+ if isinstance(content, list):
+ for chunk in content:
+ if isinstance(chunk, dict):
+ text = chunk.get("text")
+ if text:
+ parts.append(text.strip())
+ elif isinstance(chunk, str):
+ parts.append(chunk.strip())
+ elif isinstance(content, str):
+ parts.append(content.strip())
+
+ combined = " ".join(filter(None, parts)).strip()
+ if combined:
+ transcript_parts.append(f"{role.upper()}: {combined}")
+
+ return "\n".join(transcript_parts)
+
+
+def extract_topics_with_ollama(
+ messages: List[Dict[str, Any]],
+ model_name: str,
+ base_url: Optional[str] = None,
+) -> Optional[str]:
+ """
+ Extract topics using a locally hosted Ollama model via the /api/chat endpoint.
+
+ Reference: https://github.com/ollama/ollama/blob/main/docs/api.md#chat
+ """
+ base_url = base_url or model_settings.ollama_base_url
+ if not base_url:
+ logger.warning(
+ "Ollama topic extraction requested (%s) but MIRIX_OLLAMA_BASE_URL is not configured",
+ model_name,
+ )
+ return None
+
+ conversation = flatten_messages_to_plain_text(messages)
+ if not conversation:
+ logger.debug("No text content found in messages for Ollama topic extraction")
+ return None
+
+ payload = {
+ "model": model_name,
+ "stream": False,
+ "messages": [
+ {
+ "role": "system",
+ "content": (
+ "You are a helpful assistant that extracts the topic from the user's input. "
+ "Return a concise list of topics separated by ';' and nothing else."
+ ),
+ },
+ {
+ "role": "user",
+ "content": (
+ "Conversation transcript:\n"
+ f"{conversation}\n\n"
+ "Respond ONLY with the topic(s) separated by ';'."
+ ),
+ },
+ ],
+ "options": {
+ "temperature": 0,
+ },
+ }
+
+ try:
+ response = requests.post(
+ f"{base_url.rstrip('/')}/api/chat",
+ json=payload,
+ timeout=30,
+ proxies={"http": None, "https": None},
+ )
+ response.raise_for_status()
+ response_data = response.json()
+ except requests.RequestException as exc:
+ logger.error("Failed to extract topics with Ollama model %s: %s", model_name, exc)
+ return None
+
+ message_payload = response_data.get("message") if isinstance(response_data, dict) else None
+ text_response: Optional[str] = None
+ if isinstance(message_payload, dict):
+ text_response = message_payload.get("content")
+ elif isinstance(response_data, dict):
+ text_response = response_data.get("content")
+
+ if isinstance(text_response, str):
+ topics = text_response.strip()
+ logger.debug("Extracted topics via Ollama model %s: %s", model_name, topics)
+ return topics or None
+
+ logger.warning(
+ "Unexpected response format from Ollama topic extraction: %s",
+ response_data,
+ )
+ return None
diff --git a/mirix/utils.py b/mirix/utils.py
index e192f896..cdc5d632 100755
--- a/mirix/utils.py
+++ b/mirix/utils.py
@@ -24,8 +24,6 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
- Any,
- Dict,
List,
Optional,
Union,
@@ -43,7 +41,7 @@
from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename
import mirix
-from mirix.client.utils import get_utc_time, json_dumps # Re-export from client
+from mirix.client.utils import json_dumps # Re-export from client
from mirix.constants import (
CLI_WARNING_PREFIX,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
@@ -64,7 +62,6 @@
TextContent,
)
from mirix.schemas.openai.chat_completion_request import Tool, ToolCall
-from mirix.schemas.openai.chat_completion_response import ChatCompletionResponse
if TYPE_CHECKING:
from mirix.services.file_manager import FileManager
@@ -579,6 +576,9 @@ def is_optional_type(hint):
return hint.__origin__ is Union and type(None) in hint.__args__
return False
+def get_utc_time() -> datetime:
+ """Get the current UTC time"""
+ return datetime.now(timezone.utc)
def enforce_types(func):
@wraps(func)
@@ -776,84 +776,6 @@ def create_random_username() -> str:
noun = random.choice(NOUN_BANK).capitalize()
return adjective + noun
-
-def verify_first_message_correctness(
- response: ChatCompletionResponse,
- require_send_message: bool = True,
- require_monologue: bool = False,
-) -> bool:
- """Can be used to enforce that the first message always uses send_message"""
- response_message = response.choices[0].message
-
- # First message should be a call to send_message with a non-empty content
- if (
- hasattr(response_message, "function_call")
- and response_message.function_call is not None
- ) and (
- hasattr(response_message, "tool_calls")
- and response_message.tool_calls is not None
- ):
- printd(
- f"First message includes both function call AND tool call: {response_message}"
- )
- return False
- elif (
- hasattr(response_message, "function_call")
- and response_message.function_call is not None
- ):
- function_call = response_message.function_call
- elif (
- hasattr(response_message, "tool_calls")
- and response_message.tool_calls is not None
- ):
- function_call = response_message.tool_calls[0].function
- else:
- printd(f"First message didn't include function call: {response_message}")
- return False
-
- function_name = function_call.name if function_call is not None else ""
- if (
- require_send_message
- and function_name != "send_message"
- and function_name != "archival_memory_search"
- ):
- printd(
- f"First message function call wasn't send_message or archival_memory_search: {response_message}"
- )
- return False
-
- if require_monologue and (
- not response_message.content
- or response_message.content is None
- or response_message.content == ""
- ):
- printd(f"First message missing internal monologue: {response_message}")
- return False
-
- if response_message.content:
- ### Extras
- monologue = response_message.content
-
- def contains_special_characters(s):
- special_characters = '(){}[]"'
- return any(char in s for char in special_characters)
-
- if contains_special_characters(monologue):
- printd(
- f"First message internal monologue contained special characters: {response_message}"
- )
- return False
- # if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower():
- if "functions" in monologue or "send_message" in monologue:
- # Sometimes the syntax won't be correct and internal syntax will leak into message.context
- printd(
- f"First message internal monologue contained reserved words: {response_message}"
- )
- return False
-
- return True
-
-
def is_valid_url(url):
try:
result = urlparse(url)
diff --git a/pyproject.toml b/pyproject.toml
index 44372e71..f0f7a89b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "mirix"
-version = "0.1.4"
+version = "0.1.6"
description = "Multi-Agent Personal Assistant with an Advanced Memory System"
readme = "README.md"
license = {text = "Apache License 2.0"}
diff --git a/samples/custom_auth_example.py b/samples/custom_auth_example.py
index 3bcc003a..8d0c2886 100644
--- a/samples/custom_auth_example.py
+++ b/samples/custom_auth_example.py
@@ -142,7 +142,7 @@ def test_auth_provider():
print("Registered auth providers:", list_auth_providers())
sample_config = create_llm_config_with_auth()
- print(f"\nCreated LLM config with auth provider:")
+ print("\nCreated LLM config with auth provider:")
print(f" - {sample_config.auth_provider}")
test_auth_provider()
diff --git a/samples/generate_demo_api_key.py b/samples/generate_demo_api_key.py
index 34c06150..2df1ab39 100644
--- a/samples/generate_demo_api_key.py
+++ b/samples/generate_demo_api_key.py
@@ -15,11 +15,11 @@
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
-from mirix.security.api_keys import generate_api_key
-from mirix.services.client_manager import ClientManager
-from mirix.services.organization_manager import OrganizationManager
-from mirix.schemas.client import Client as PydanticClient
-from mirix.schemas.organization import Organization as PydanticOrganization
+from mirix.security.api_keys import generate_api_key # noqa: E402
+from mirix.services.client_manager import ClientManager # noqa: E402
+from mirix.services.organization_manager import OrganizationManager # noqa: E402
+from mirix.schemas.client import Client as PydanticClient # noqa: E402
+from mirix.schemas.organization import Organization as PydanticOrganization # noqa: E402
ORG_ID = "demo-org"
diff --git a/samples/langgraph_integration.py b/samples/langgraph_integration.py
index 5d043df0..09dac1f4 100644
--- a/samples/langgraph_integration.py
+++ b/samples/langgraph_integration.py
@@ -12,10 +12,9 @@
load_dotenv(os.path.join(mirix_root, ".env"))
-import logging
+import logging # noqa: E402
-from mirix.schemas.agent import AgentType
-from mirix.client import MirixClient
+from mirix.client import MirixClient # noqa: E402
# Configure logging
logging.basicConfig(
diff --git a/samples/load_json_batch_conversations.py b/samples/load_json_batch_conversations.py
index 6549b20d..4e96edd9 100644
--- a/samples/load_json_batch_conversations.py
+++ b/samples/load_json_batch_conversations.py
@@ -26,7 +26,7 @@
import sys
import time
from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any
from mirix.client import MirixClient
diff --git a/samples/poetry.lock b/samples/poetry.lock
index 6842edb7..19a51d50 100644
--- a/samples/poetry.lock
+++ b/samples/poetry.lock
@@ -2867,7 +2867,7 @@ files = [
[[package]]
name = "mirix"
-version = "0.1.4"
+version = "0.1.6"
description = "Multi-Agent Personal Assistant with an Advanced Memory System"
optional = false
python-versions = ">=3.10,<4.0"
diff --git a/samples/pyproject.toml b/samples/pyproject.toml
index 5dfb1e73..2a27285d 100644
--- a/samples/pyproject.toml
+++ b/samples/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "mirix-samples"
-version = "0.1.4"
+version = "0.1.6"
description = "A sample directory for MIRIX quick start"
readme = "README.md"
license = {text = "Apache License 2.0"}
diff --git a/samples/run_client.py b/samples/run_client.py
index c7791efb..efdda919 100644
--- a/samples/run_client.py
+++ b/samples/run_client.py
@@ -12,6 +12,7 @@
import logging
import os
import sys
+import yaml
from pathlib import Path
from dotenv import load_dotenv
@@ -25,7 +26,7 @@
# Load env vars from repo-root `./.env` (does not override already-set env vars).
load_dotenv(REPO_ROOT / ".env", override=False)
-from mirix import MirixClient
+from mirix import MirixClient # noqa: E402
# Configure logging
logging.basicConfig(
@@ -52,7 +53,7 @@ def print_memories(memories):
print(f" 🕐 Temporal expression detected: '{memories.get('temporal_expression')}'")
if memories.get("date_range"):
date_range = memories.get("date_range")
- print(f" 📅 Date range applied:")
+ print(" 📅 Date range applied:")
print(f" Start: {date_range.get('start')}")
print(f" End: {date_range.get('end')}")
@@ -382,30 +383,20 @@ def main():
# Navigate to project root (parent of samples/)
project_root = script_dir.parent
# Build path to config file
- config_path = project_root / "mirix" / "configs" / "examples" / "mirix_gemini.yaml"
-
- # Verify the config file exists
- if not config_path.exists():
- raise FileNotFoundError(f"Config file not found: {config_path}")
# Create MirixClient (connects to server via REST API)
- client_id = 'sales-loader-client' #'demo-client-app' # Identifies the client application
- # user_id = 'demo-user' #'demo-user' # Identifies the end-user within the client app
- org_id = 'demo-org'
api_key = os.environ.get("MIRIX_API_KEY")
if not api_key:
raise ValueError("MIRIX_API_KEY is required to run this sample.")
client = MirixClient(
api_key=api_key,
- # client_id="sales-loader-client",
- # client_scope="Sales",
- # org_id="demo-org",
- debug=True,
)
+ config = yaml.safe_load(open(args.config))
+
client.initialize_meta_agent(
- config_path=args.config,
+ config=config,
update_agents=True,
)
@@ -432,7 +423,7 @@ def main():
}]
},
],
- chaining=True
+ chaining=False
)
print(f"[OK] Memory added successfully: {result.get('success', False)}")
diff --git a/scripts/clean_cache.py b/scripts/clean_cache.py
index e495dcf9..41bf9dc4 100755
--- a/scripts/clean_cache.py
+++ b/scripts/clean_cache.py
@@ -11,7 +11,6 @@
./scripts/clean_cache.py
"""
-import os
import shutil
from pathlib import Path
from typing import List, Tuple
diff --git a/scripts/packaging/setup_client.py b/scripts/packaging/setup_client.py
index 641aeb1c..e1462f8c 100644
--- a/scripts/packaging/setup_client.py
+++ b/scripts/packaging/setup_client.py
@@ -14,7 +14,7 @@
import os
import sys
-from setuptools import find_packages, setup
+from setuptools import setup
# Parse command line arguments for package name and version
package_name = "mirix-client" # Default value
diff --git a/scripts/start_server.py b/scripts/start_server.py
index 67e70862..10ea140f 100755
--- a/scripts/start_server.py
+++ b/scripts/start_server.py
@@ -72,9 +72,8 @@ def main():
# Check if running in production mode
if args.production:
- try:
- import gunicorn.app.base
- except ImportError:
+ import importlib.util
+ if importlib.util.find_spec("gunicorn") is None:
print("Error: gunicorn is required for production mode")
print("Install it with: pip install gunicorn")
sys.exit(1)
diff --git a/tests/test_agent_prompt_update.py b/tests/test_agent_prompt_update.py
index f9c662a4..6a1a561d 100644
--- a/tests/test_agent_prompt_update.py
+++ b/tests/test_agent_prompt_update.py
@@ -30,7 +30,7 @@
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
-from mirix.client import MirixClient
+from mirix.client import MirixClient # noqa: E402
# Mark all tests as integration tests
pytestmark = [
@@ -86,7 +86,7 @@ def client(server_check, api_auth):
print("[SETUP] Client created via API key")
# Create or get user (ensures user exists in backend database)
- print(f"[SETUP] Creating/getting user: demo-user")
+ print("[SETUP] Creating/getting user: demo-user")
try:
user_id = client.create_or_get_user(
user_id="demo-user",
@@ -96,7 +96,7 @@ def client(server_check, api_auth):
except Exception as e:
import traceback
error_details = traceback.format_exc()
- print(f"\n[ERROR] Failed to create/get user:")
+ print("\n[ERROR] Failed to create/get user:")
print(f" Exception: {e}")
print(f" Details:\n{error_details}")
pytest.skip(f"Failed to create/get user: {e}")
@@ -194,7 +194,7 @@ def client(server_check, api_auth):
except Exception as e:
import traceback
error_details = traceback.format_exc()
- print(f"\n[ERROR] Failed to initialize meta agent:")
+ print("\n[ERROR] Failed to initialize meta agent:")
print(f" Exception: {e}")
print(f" Details:\n{error_details}")
pytest.skip(f"Failed to initialize meta agent: {e}")
@@ -339,18 +339,18 @@ def test_update_agent_system_prompt(client, agent_name):
agent_name=agent_name,
system_prompt=new_system_prompt
)
- print(f"[OK] Update request successful")
+ print("[OK] Update request successful")
print(f" Updated system prompt: {updated_agent.system[:80]}...")
except Exception as e:
pytest.fail(f"Failed to update system prompt: {e}")
# Step 3: Verify the update in the returned agent state
- print(f"\n[Step 3] Verifying update in returned agent state...")
+ print("\n[Step 3] Verifying update in returned agent state...")
assert updated_agent.system == new_system_prompt, \
"System prompt in returned agent should match the new prompt"
- print(f"[OK] System prompt matches in returned state")
+ print("[OK] System prompt matches in returned state")
# Verify system message ID changed
new_message_id = updated_agent.message_ids[0] if updated_agent.message_ids else None
@@ -360,18 +360,18 @@ def test_update_agent_system_prompt(client, agent_name):
print(f"[OK] System message ID changed: {original_message_id} → {new_message_id}")
# Step 4: Wait for cache and database to sync
- print(f"\n[Step 4] Waiting 2 seconds for cache/database sync...")
+ print("\n[Step 4] Waiting 2 seconds for cache/database sync...")
time.sleep(2)
# Step 5: Verify persistence by fetching agent again (tests Redis cache)
- print(f"\n[Step 5] Fetching agent again to verify persistence (Redis cache)...")
+ print("\n[Step 5] Fetching agent again to verify persistence (Redis cache)...")
refetched_agent = get_agent_direct_from_api(client, agent_name)
assert refetched_agent is not None, "Agent should still exist after update"
assert refetched_agent.system == new_system_prompt, \
"System prompt should persist in cache/database"
- print(f"[OK] System prompt persisted in cache")
+ print("[OK] System prompt persisted in cache")
print(f" Cached prompt: {refetched_agent.system[:80]}...")
# Verify message_ids[0] is still the new one
@@ -381,20 +381,20 @@ def test_update_agent_system_prompt(client, agent_name):
print(f"[OK] System message ID persisted: {cached_message_id}")
# Step 6: Verify system prompt in agent state
- print(f"\n[Step 6] Verifying system prompt is stored correctly...")
+ print("\n[Step 6] Verifying system prompt is stored correctly...")
# The agent.system field should contain the new prompt
assert refetched_agent.system == new_system_prompt, \
"Agent's system field should contain the new system prompt"
- print(f"[OK] System prompt verified in agent state")
+ print("[OK] System prompt verified in agent state")
print(f" Prompt: {refetched_agent.system[:80]}...")
# Step 7: Verify old and new are different
- print(f"\n[Step 7] Verifying changes were actually made...")
+ print("\n[Step 7] Verifying changes were actually made...")
assert updated_agent.system != original_agent.system, \
"New system prompt should be different from original"
- print(f"[OK] System prompt was successfully changed")
+ print("[OK] System prompt was successfully changed")
print(f"\n✓ TEST PASSED for '{agent_name}' agent")
print("="*70)
@@ -496,7 +496,7 @@ def test_update_same_agent_multiple_times(client):
# Verify prompt changed
assert updated.system == new_prompt, f"Update {i} should apply new prompt"
- print(f" ✓ Prompt updated")
+ print(" ✓ Prompt updated")
# Verify message_ids[0] changed
current_message_id = updated.message_ids[0] if updated.message_ids else None
@@ -509,7 +509,7 @@ def test_update_same_agent_multiple_times(client):
if previous_prompt:
assert updated.system != previous_prompt, \
f"Update {i} should change prompt from previous"
- print(f" ✓ Prompt changed from previous")
+ print(" ✓ Prompt changed from previous")
previous_message_id = current_message_id
previous_prompt = new_prompt
@@ -518,14 +518,14 @@ def test_update_same_agent_multiple_times(client):
time.sleep(1)
# Final verification
- print(f"\n[Final Verification] Fetching agent to verify last update...")
+ print("\n[Final Verification] Fetching agent to verify last update...")
final_agent = get_agent_direct_from_api(client, agent_name)
assert final_agent.system == previous_prompt, \
"Final prompt should match last update"
- print(f" ✓ Final prompt matches last update")
+ print(" ✓ Final prompt matches last update")
- print(f"\n✓ TEST PASSED: Multiple updates to same agent work correctly")
+ print("\n✓ TEST PASSED: Multiple updates to same agent work correctly")
print("="*70)
@@ -564,7 +564,7 @@ def test_error_handling_nonexistent_agent(client):
# Verify it suggests available agents (if any exist)
if "available agents:" in error_lower or "available" in error_lower:
- print(f" ✓ Error message suggests available agents")
+ print(" ✓ Error message suggests available agents")
# Test Case 2: Typo in short name (e.g., "episodick" instead of "episodic")
print("\n[Test Case 2] Attempting to update with typo in agent name...")
@@ -605,7 +605,7 @@ def test_error_handling_nonexistent_agent(client):
)
print(f" ✓ Exception raised for wrong case: {type(exc_info.value).__name__}")
- print(f" Note: Agent names are case-sensitive")
+ print(" Note: Agent names are case-sensitive")
# Test Case 5: Empty string
print("\n[Test Case 5] Attempting to update with empty agent name...")
@@ -618,7 +618,7 @@ def test_error_handling_nonexistent_agent(client):
print(f" ✓ Exception raised for empty name: {type(exc_info.value).__name__}")
- print(f"\n✓ TEST PASSED: All error handling scenarios work correctly")
+ print("\n✓ TEST PASSED: All error handling scenarios work correctly")
print(" - Invalid agent names are rejected")
print(" - Error messages are informative")
print(" - Server remains stable")
diff --git a/tests/test_client_agent_isolation.py b/tests/test_client_agent_isolation.py
index 355d828c..a7383b41 100644
--- a/tests/test_client_agent_isolation.py
+++ b/tests/test_client_agent_isolation.py
@@ -14,7 +14,6 @@
import requests
from mirix.client import MirixClient
-from mirix.schemas.agent import AgentState
# Configure logging
@@ -276,7 +275,7 @@ def test_get_agent_by_id_enforces_client_ownership(client_a, client_b, meta_agen
error_message = str(exc_info.value).lower()
assert "not found" in error_message or "404" in error_message, \
f"Expected 404/not found error, got: {exc_info.value}"
- print(f"✅ Client B correctly denied access to client A's agent (404)")
+ print("✅ Client B correctly denied access to client A's agent (404)")
# Client A tries to access client B's agent - should fail
with pytest.raises(Exception) as exc_info:
@@ -285,7 +284,7 @@ def test_get_agent_by_id_enforces_client_ownership(client_a, client_b, meta_agen
error_message = str(exc_info.value).lower()
assert "not found" in error_message or "404" in error_message, \
f"Expected 404/not found error, got: {exc_info.value}"
- print(f"✅ Client A correctly denied access to client B's agent (404)")
+ print("✅ Client A correctly denied access to client B's agent (404)")
print("✅ Client ownership enforcement verified")
@@ -328,7 +327,7 @@ def test_child_agents_filtered_by_client(client_a, client_b, meta_agent_a, meta_
if meta_a.children and len(meta_a.children) > 0:
print(f"✅ Client A's meta agent has {len(meta_a.children)} sub-agents")
else:
- print(f"ℹ️ Client A's meta agent children not yet populated (async timing)")
+ print("ℹ️ Client A's meta agent children not yet populated (async timing)")
# Verify client B's meta agent has children
# If children field is empty, it might be due to async timing - log but don't fail immediately
@@ -346,7 +345,7 @@ def test_child_agents_filtered_by_client(client_a, client_b, meta_agent_a, meta_
if meta_b.children and len(meta_b.children) > 0:
print(f"✅ Client B's meta agent has {len(meta_b.children)} sub-agents")
else:
- print(f"ℹ️ Client B's meta agent children not yet populated (async timing)")
+ print("ℹ️ Client B's meta agent children not yet populated (async timing)")
# Verify all of client A's children are created by client A (if children exist)
if meta_a.children:
@@ -390,12 +389,12 @@ def test_memory_apis_use_correct_client_agents(client_a, client_b, meta_agent_a,
]
try:
- result_a = client_a.add(
+ client_a.add(
user_id=TEST_USER_A_ID,
messages=messages_a
)
# The client automatically uses its own meta agent
- print(f"✅ Client A successfully added memory using its meta agent")
+ print("✅ Client A successfully added memory using its meta agent")
except Exception as e:
pytest.fail(f"Client A failed to add memory: {e}")
@@ -405,12 +404,12 @@ def test_memory_apis_use_correct_client_agents(client_a, client_b, meta_agent_a,
]
try:
- result_b = client_b.add(
+ client_b.add(
user_id=TEST_USER_B_ID,
messages=messages_b
)
# The client automatically uses its own meta agent
- print(f"✅ Client B successfully added memory using its meta agent")
+ print("✅ Client B successfully added memory using its meta agent")
except Exception as e:
pytest.fail(f"Client B failed to add memory: {e}")
@@ -419,24 +418,24 @@ def test_memory_apis_use_correct_client_agents(client_a, client_b, meta_agent_a,
# Try to retrieve memories for client A
try:
- memories_a = client_a.retrieve_memory_with_topic(
+ client_a.retrieve_memory_with_topic(
user_id=TEST_USER_A_ID,
topic="Python"
)
# If retrieval succeeds, it used the correct client's agents
- print(f"✅ Client A successfully retrieved memories using its agents")
+ print("✅ Client A successfully retrieved memories using its agents")
except Exception as e:
# Some errors are acceptable (e.g., no memories found yet)
print(f"ℹ️ Client A memory retrieval: {e}")
# Try to retrieve memories for client B
try:
- memories_b = client_b.retrieve_memory_with_topic(
+ client_b.retrieve_memory_with_topic(
user_id=TEST_USER_B_ID,
topic="Java"
)
# If retrieval succeeds, it used the correct client's agents
- print(f"✅ Client B successfully retrieved memories using its agents")
+ print("✅ Client B successfully retrieved memories using its agents")
except Exception as e:
# Some errors are acceptable (e.g., no memories found yet)
print(f"ℹ️ Client B memory retrieval: {e}")
@@ -449,8 +448,8 @@ def test_memory_apis_use_correct_client_agents(client_a, client_b, meta_agent_a,
# Actual behavior: Each client automatically uses its own meta agent (enforced by design)
# This is MORE SECURE than allowing agent_id as a parameter
- print(f"ℹ️ Agent isolation is enforced by design: each client automatically uses its own meta agent")
- print(f"✅ Client B can only use its own agents (no way to specify client A's agent)")
+ print("ℹ️ Agent isolation is enforced by design: each client automatically uses its own meta agent")
+ print("✅ Client B can only use its own agents (no way to specify client A's agent)")
print("✅ Memory API client isolation verified")
@@ -492,7 +491,7 @@ def test_redis_cache_respects_client_isolation(client_a, client_b, meta_agent_a,
error_message = str(exc_info.value).lower()
assert "not found" in error_message or "404" in error_message, \
f"Expected 404/not found error, got: {exc_info.value}"
- print(f"✅ Client B denied access to cached agent from client A")
+ print("✅ Client B denied access to cached agent from client A")
# Verify client A cannot access client B's cached agent
with pytest.raises(Exception) as exc_info:
@@ -501,7 +500,7 @@ def test_redis_cache_respects_client_isolation(client_a, client_b, meta_agent_a,
error_message = str(exc_info.value).lower()
assert "not found" in error_message or "404" in error_message, \
f"Expected 404/not found error, got: {exc_info.value}"
- print(f"✅ Client A denied access to cached agent from client B")
+ print("✅ Client A denied access to cached agent from client B")
print("✅ Redis cache client isolation verified")
@@ -529,7 +528,7 @@ def test_initialization_creates_separate_hierarchies(client_a, client_b):
assert len(top_level_b) == 1, \
f"Client B should have exactly 1 top-level agent, got {len(top_level_b)}"
- print(f"✅ Each client has exactly 1 top-level agent (meta agent)")
+ print("✅ Each client has exactly 1 top-level agent (meta agent)")
# Get child agents for each
meta_a = top_level_a[0]
diff --git a/tests/test_credit_system.py b/tests/test_credit_system.py
index 64770763..30d293e6 100644
--- a/tests/test_credit_system.py
+++ b/tests/test_credit_system.py
@@ -25,20 +25,20 @@ def _ensure_mirix_package():
_ensure_mirix_package()
-import mirix.services.client_manager as client_manager_module
-from mirix.client.utils import get_utc_time
-from mirix.llm_api.llm_client import LLMClient
-from mirix.pricing import calculate_cost
-from mirix.schemas.client import Client as PydanticClient
-from mirix.schemas.llm_config import LLMConfig
-from mirix.schemas.openai.chat_completion_response import (
+import mirix.services.client_manager as client_manager_module # noqa: E402
+from mirix.utils import get_utc_time # noqa: E402
+from mirix.llm_api.llm_client import LLMClient # noqa: E402
+from mirix.pricing import calculate_cost # noqa: E402
+from mirix.schemas.client import Client as PydanticClient # noqa: E402
+from mirix.schemas.llm_config import LLMConfig # noqa: E402
+from mirix.schemas.openai.chat_completion_response import ( # noqa: E402
ChatCompletionResponse,
Choice,
Message,
PromptTokensDetails,
UsageStatistics,
)
-from mirix.services.client_manager import ClientManager
+from mirix.services.client_manager import ClientManager # noqa: E402
class StubLLMClient:
diff --git a/tests/test_memory_decay.py b/tests/test_memory_decay.py
index 90b02a1c..94f62dc4 100644
--- a/tests/test_memory_decay.py
+++ b/tests/test_memory_decay.py
@@ -21,17 +21,17 @@
module="speech_recognition",
)
-from mirix.schemas.agent import (
+from mirix.schemas.agent import ( # noqa: E402
AgentType,
CreateMetaAgent,
MemoryConfig,
MemoryDecayConfig,
UpdateAgent,
)
-from mirix.schemas.client import Client
-from mirix.schemas.llm_config import LLMConfig
-from mirix.schemas.organization import Organization
-from mirix.schemas.user import User as PydanticUser
+from mirix.schemas.client import Client # noqa: E402
+from mirix.schemas.llm_config import LLMConfig # noqa: E402
+from mirix.schemas.organization import Organization # noqa: E402
+from mirix.schemas.user import User as PydanticUser # noqa: E402
TEST_RUN_ID = uuid.uuid4().hex[:8]
diff --git a/tests/test_memory_integration.py b/tests/test_memory_integration.py
index 48492e78..f66e74fd 100644
--- a/tests/test_memory_integration.py
+++ b/tests/test_memory_integration.py
@@ -21,6 +21,7 @@
import os
import sys
import time
+import yaml
import requests
from pathlib import Path
@@ -35,8 +36,7 @@
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
-from mirix import EmbeddingConfig, LLMConfig
-from mirix.client import MirixClient
+from mirix.client import MirixClient # noqa: E402
TEST_USER_ID = "demo-user"
TEST_CLIENT_ID = "demo-client"
@@ -92,8 +92,11 @@ def client(server_process, api_auth):
# Construct absolute path to config file
config_path = project_root / "mirix" / "configs" / "examples" / "mirix_gemini.yaml"
- result = client.initialize_meta_agent(
- config_path=str(config_path),
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+
+ client.initialize_meta_agent(
+ config=config,
update_agents=False # Don't update if already exists, just use existing
)
@@ -133,7 +136,7 @@ def test_add(client):
assert result is not None
assert result.get("success") is True
- print(f"[OK] Memory added successfully")
+ print("[OK] Memory added successfully")
def test_retrieve_with_conversation(client):
@@ -174,7 +177,7 @@ def test_retrieve_with_conversation(client):
assert result is not None
assert result.get("success") is True
assert "memories" in result
- print(f"[OK] Retrieved memories successfully")
+ print("[OK] Retrieved memories successfully")
# Display results
if result.get("memories"):
diff --git a/tests/test_memory_server.py b/tests/test_memory_server.py
index 01f7226a..ff3bbe7a 100644
--- a/tests/test_memory_server.py
+++ b/tests/test_memory_server.py
@@ -37,9 +37,9 @@
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
-from mirix.server.server import SyncServer
-from mirix.schemas.agent import AgentType
-from mirix.schemas.client import Client
+from mirix.server.server import SyncServer # noqa: E402
+from mirix.schemas.agent import AgentType # noqa: E402
+from mirix.schemas.client import Client # noqa: E402
# Skip all tests if no API key
pytestmark = pytest.mark.skipif(
diff --git a/tests/test_message_handling.py b/tests/test_message_handling.py
index 2dd03a7a..ddb65080 100644
--- a/tests/test_message_handling.py
+++ b/tests/test_message_handling.py
@@ -10,11 +10,9 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
-import pytest
from mirix.schemas.agent import AgentState
from mirix.schemas.client import Client
-from mirix.schemas.message import Message
from mirix.schemas.user import User
from mirix.services.agent_manager import AgentManager
from mirix.services.message_manager import MessageManager
diff --git a/tests/test_queue.py b/tests/test_queue.py
index bee1891e..32b88b5f 100644
--- a/tests/test_queue.py
+++ b/tests/test_queue.py
@@ -24,7 +24,6 @@
# Note: ProtoUser and ProtoMessageCreate are generated from message.proto
from mirix.queue.message_pb2 import MessageCreate as ProtoMessageCreate
from mirix.queue.message_pb2 import QueueMessage
-from mirix.queue.message_pb2 import User as ProtoUser
from mirix.queue.queue_util import put_messages
from mirix.queue.worker import QueueWorker
@@ -249,7 +248,7 @@ def test_same_user_routes_to_same_partition(self, sample_client):
# Put it back for counting
queue._partitions[partition_id].put(msg)
break
- except:
+ except Exception:
continue
assert partition_with_messages is not None
@@ -260,7 +259,7 @@ def test_same_user_routes_to_same_partition(self, sample_client):
try:
queue.get_from_partition(partition_with_messages, timeout=0.1)
count += 1
- except:
+ except Exception:
break
assert count == 10 # All 10 messages in same partition
@@ -287,7 +286,7 @@ def test_different_users_can_route_to_different_partitions(self, sample_client):
try:
queue.get_from_partition(partition_id, timeout=0.01)
partitions_with_messages.add(partition_id)
- except:
+ except Exception:
continue
# With 50 users and 100 partitions, we should have some spread
@@ -349,7 +348,7 @@ def test_fallback_to_actor_id_when_no_user_id(self, sample_client):
try:
queue.get_from_partition(partition_id, timeout=0.01)
count += 1
- except:
+ except Exception:
break
partition_counts.append(count)
diff --git a/tests/test_search_all_users.py b/tests/test_search_all_users.py
index 7a2b958f..6a816cdd 100644
--- a/tests/test_search_all_users.py
+++ b/tests/test_search_all_users.py
@@ -469,8 +469,8 @@ def test_search_excludes_user3_without_matching_scope(self, client1, user3_id):
user_ids_in_results = set(result['user_id'] for result in results['results'])
logger.info(f"User IDs in results: {user_ids_in_results}")
- logger.info(f"Searching with client1 scope='read_write'")
- logger.info(f"User 3 has scope='read_only' (via client3)")
+ logger.info("Searching with client1 scope='read_write'")
+ logger.info("User 3 has scope='read_only' (via client3)")
assert user3_id not in user_ids_in_results, "User 3 memories should be excluded (scope='read_only' doesn't match 'read_write')"
diff --git a/tests/test_temporal_queries.py b/tests/test_temporal_queries.py
index 583961b3..1e4c8caa 100644
--- a/tests/test_temporal_queries.py
+++ b/tests/test_temporal_queries.py
@@ -1,7 +1,7 @@
"""Tests for temporal query functionality."""
import pytest
-from datetime import datetime, timedelta
+from datetime import datetime
from mirix.temporal.temporal_parser import parse_temporal_expression, TemporalRange
diff --git a/tests/test_user.py b/tests/test_user.py
index 1c955bec..aeb1b2a1 100644
--- a/tests/test_user.py
+++ b/tests/test_user.py
@@ -30,7 +30,7 @@
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
-from mirix.client import MirixClient
+from mirix.client import MirixClient # noqa: E402
TEST_ORG_ID = "test-user-org"
TEST_CLIENT_ID = "test-user-client"
@@ -130,7 +130,7 @@ def test_explicit_user_creation_then_add_memory(client):
print(f"\n[Step 1] Generated user_id: {user_id}")
# Step 2: Create user explicitly
- print(f"[Step 2] Creating user with create_or_get_user()...")
+ print("[Step 2] Creating user with create_or_get_user()...")
created_user_id = client.create_or_get_user(
user_id=user_id,
user_name=f"Test User {user_id}",
@@ -140,7 +140,7 @@ def test_explicit_user_creation_then_add_memory(client):
assert created_user_id == user_id, "Returned user_id should match requested user_id"
# Step 3: Verify user exists in database
- print(f"[Step 3] Verifying user exists in database...")
+ print("[Step 3] Verifying user exists in database...")
time.sleep(1) # Small delay to ensure database write is complete
assert user_exists(client, user_id), f"User {user_id} should exist in database"
print(f"[OK] User {user_id} verified in database")
@@ -171,23 +171,23 @@ def test_explicit_user_creation_then_add_memory(client):
verbose=False
)
- print(f"[OK] Memory add request submitted")
+ print("[OK] Memory add request submitted")
print(f" Response: {response}")
# Step 5: Wait for processing
- print(f"[Step 5] Waiting 15 seconds for memory processing...")
+ print("[Step 5] Waiting 15 seconds for memory processing...")
time.sleep(15)
- print(f"[OK] Processing complete")
+ print("[OK] Processing complete")
# Step 6: Verify memory can be retrieved
- print(f"[Step 6] Retrieving memory to verify...")
+ print("[Step 6] Retrieving memory to verify...")
retrieve_response = client.retrieve_with_conversation(
user_id=user_id,
messages=[{"role": "user", "content": [{"type": "text", "text": "What is my favorite color?"}]}],
limit=5
)
- print(f"[OK] Memory retrieval successful")
+ print("[OK] Memory retrieval successful")
print(f" Retrieved {len(retrieve_response.get('memories', []))} memories")
# Verify we got some memories back
@@ -216,12 +216,12 @@ def test_auto_user_creation_on_add_memory(client):
print(f"\n[Step 1] Generated user_id: {user_id}")
# Step 2: Verify user does NOT exist yet
- print(f"[Step 2] Verifying user does NOT exist yet...")
+ print("[Step 2] Verifying user does NOT exist yet...")
assert not user_exists(client, user_id), f"User {user_id} should NOT exist yet"
print(f"[OK] Confirmed user {user_id} does not exist")
# Step 3: Add memory WITHOUT creating user first
- print(f"[Step 3] Adding memory WITHOUT calling create_or_get_user()...")
+ print("[Step 3] Adding memory WITHOUT calling create_or_get_user()...")
messages = [
{
"role": "user",
@@ -246,28 +246,28 @@ def test_auto_user_creation_on_add_memory(client):
verbose=False
)
- print(f"[OK] Memory add request submitted")
+ print("[OK] Memory add request submitted")
print(f" Response: {response}")
# Step 4: Wait for processing
- print(f"[Step 4] Waiting 15 seconds for memory processing and user auto-creation...")
+ print("[Step 4] Waiting 15 seconds for memory processing and user auto-creation...")
time.sleep(15)
- print(f"[OK] Processing complete")
+ print("[OK] Processing complete")
# Step 5: Verify user was auto-created
- print(f"[Step 5] Verifying user was auto-created in database...")
+ print("[Step 5] Verifying user was auto-created in database...")
assert user_exists(client, user_id), f"User {user_id} should have been auto-created"
print(f"[OK] User {user_id} was auto-created successfully")
# Step 6: Verify memory can be retrieved
- print(f"[Step 6] Retrieving memory to verify...")
+ print("[Step 6] Retrieving memory to verify...")
retrieve_response = client.retrieve_with_conversation(
user_id=user_id,
messages=[{"role": "user", "content": [{"type": "text", "text": "Where am I moving to?"}]}],
limit=5
)
- print(f"[OK] Memory retrieval successful")
+ print("[OK] Memory retrieval successful")
print(f" Retrieved {len(retrieve_response.get('memories', []))} memories")
# Verify we got some memories back
@@ -300,7 +300,7 @@ def test_idempotent_create_or_get_user(client):
print(f"[OK] User created (1st call): {created_user_id_1}")
# Step 2: Call again with same user_id
- print(f"[Step 2] Calling create_or_get_user() again with same user_id...")
+ print("[Step 2] Calling create_or_get_user() again with same user_id...")
time.sleep(1) # Small delay
created_user_id_2 = client.create_or_get_user(