diff --git a/client/webui/frontend/src/lib/components/chat/SessionList.tsx b/client/webui/frontend/src/lib/components/chat/SessionList.tsx index b1712fdc16..2182d7927c 100644 --- a/client/webui/frontend/src/lib/components/chat/SessionList.tsx +++ b/client/webui/frontend/src/lib/components/chat/SessionList.tsx @@ -37,10 +37,13 @@ const SessionName: React.FC = ({ session, respondingSessionId return false; // No pulse when auto title generation is disabled } const isNewChat = !session.name || session.name === "New Chat"; - // Only pulse if this session is the one that started the response + // Pulse if this session is the one that started the response const isThisSessionResponding = respondingSessionId === session.id; - return isThisSessionResponding && isNewChat; - }, [session.name, session.id, respondingSessionId, isGenerating, autoTitleGenerationEnabled]); + // Also pulse if this session has a running background task and no title yet + // This handles the case where user switched away while task is running + const hasBackgroundTaskWithNewTitle = session.hasRunningBackgroundTask && isNewChat; + return (isThisSessionResponding && isNewChat) || hasBackgroundTaskWithNewTitle; + }, [session.name, session.id, respondingSessionId, isGenerating, autoTitleGenerationEnabled, session.hasRunningBackgroundTask]); // Show slow pulse while waiting for title, faster pulse during transition animation const animationClass = useMemo(() => { diff --git a/client/webui/frontend/src/lib/hooks/useBackgroundTaskMonitor.ts b/client/webui/frontend/src/lib/hooks/useBackgroundTaskMonitor.ts index c323a710b6..4a3df9af86 100644 --- a/client/webui/frontend/src/lib/hooks/useBackgroundTaskMonitor.ts +++ b/client/webui/frontend/src/lib/hooks/useBackgroundTaskMonitor.ts @@ -1,3 +1,8 @@ +/** + * Hook for monitoring and reconnecting to background tasks. + * Stores active background tasks in localStorage and automatically reconnects on session load. 
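+ *
+ * Illustrative usage (parameter shape taken from UseBackgroundTaskMonitorProps below):
+ *   const monitor = useBackgroundTaskMonitor({ userId, onTaskCompleted, onTaskFailed });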
+ */ + import { useState, useEffect, useCallback, useRef } from "react"; import { api } from "@/lib/api"; import type { BackgroundTaskState, BackgroundTaskStatusResponse, ActiveBackgroundTasksResponse, BackgroundTaskNotification } from "@/lib/types/background-tasks"; @@ -17,8 +22,8 @@ interface UseBackgroundTaskMonitorProps { */ export function useBackgroundTaskMonitor({ userId, onTaskCompleted, onTaskFailed }: UseBackgroundTaskMonitorProps) { const [backgroundTasks, setBackgroundTasks] = useState([]); + const backgroundTasksRef = useRef(backgroundTasks); const [notifications, setNotifications] = useState([]); - const backgroundTasksRef = useRef([]); // Load background tasks from localStorage on mount useEffect(() => { @@ -34,6 +39,11 @@ export function useBackgroundTaskMonitor({ userId, onTaskCompleted, onTaskFailed } }, []); + // Keep the ref in sync with state + useEffect(() => { + backgroundTasksRef.current = backgroundTasks; + }, [backgroundTasks]); + // Save background tasks to localStorage whenever they change useEffect(() => { if (backgroundTasks.length > 0) { @@ -101,6 +111,7 @@ export function useBackgroundTaskMonitor({ userId, onTaskCompleted, onTaskFailed ); // Check all background tasks and update their status + // Uses backgroundTasksRef to read current tasks, making this callback stable const checkAllBackgroundTasks = useCallback(async () => { const tasks = backgroundTasksRef.current; if (tasks.length === 0) { @@ -202,6 +213,7 @@ export function useBackgroundTaskMonitor({ userId, onTaskCompleted, onTaskFailed // Periodic checking to detect background task completion when not connected to SSE // This handles the case where a task completes while the user is on a different session const hasBackgroundTasks = backgroundTasks.length > 0; + useEffect(() => { if (!hasBackgroundTasks) { return; diff --git a/client/webui/frontend/src/lib/hooks/useBeforeUnload.ts b/client/webui/frontend/src/lib/hooks/useBeforeUnload.ts index 1a79348ff2..a7ae630064 100644 --- a/client/webui/frontend/src/lib/hooks/useBeforeUnload.ts +++ b/client/webui/frontend/src/lib/hooks/useBeforeUnload.ts @@ -3,15 +3,25 @@ import { useChatContext } from "./useChatContext"; import { useConfigContext } from "./useConfigContext"; export function useBeforeUnload() { - const { messages } = useChatContext(); + const { messages, isResponding } = useChatContext(); const config = useConfigContext(); /** * Cross-browser beforeunload event handler - * Only warns when persistence is disabled and messages exist + * Warns when: + * 1. Persistence is disabled and messages exist (chat history would be lost) + * 2. A task is running and background tasks are disabled (task results would be lost) */ const handleBeforeUnload = useCallback( (event: BeforeUnloadEvent): string | void => { + // Case 1: Task is running and background tasks are disabled + // In this case, navigating away may cause the user to lose the task results + if (isResponding && config?.backgroundTasksEnabled === false) { + event.preventDefault(); + return "A task is currently running. If you leave now, you may lose the response. Are you sure you want to leave?"; + } + + // Case 2: Persistence disabled and messages exist (original behavior) if (config?.persistenceEnabled !== false) { return; } @@ -24,7 +34,7 @@ export function useBeforeUnload() { return "Are you sure you want to leave? 
Your chat history will be lost."; }, - [messages.length, config?.persistenceEnabled] + [messages.length, config?.persistenceEnabled, config?.backgroundTasksEnabled, isResponding] ); /** diff --git a/client/webui/frontend/src/lib/hooks/useTitleGeneration.ts b/client/webui/frontend/src/lib/hooks/useTitleGeneration.ts index ec8baf2552..6d5c02a279 100644 --- a/client/webui/frontend/src/lib/hooks/useTitleGeneration.ts +++ b/client/webui/frontend/src/lib/hooks/useTitleGeneration.ts @@ -35,7 +35,7 @@ export const useTitleGeneration = () => { } } - console.log(`[useTitleGeneration] Initial title: "${initialTitle}"`); + console.debug(`[useTitleGeneration] Initial title: "${initialTitle}"`); // Dispatch event to indicate title generation is starting if (typeof window !== "undefined") { @@ -66,7 +66,7 @@ export const useTitleGeneration = () => { return; } - console.log("[useTitleGeneration] Title generation triggered, polling for update..."); + console.debug("[useTitleGeneration] Title generation triggered, polling for update..."); // Poll for title update with exponential backoff const pollForTitle = async () => { diff --git a/client/webui/frontend/src/lib/providers/ChatProvider.tsx b/client/webui/frontend/src/lib/providers/ChatProvider.tsx index 6494540930..d5d244c1eb 100644 --- a/client/webui/frontend/src/lib/providers/ChatProvider.tsx +++ b/client/webui/frontend/src/lib/providers/ChatProvider.tsx @@ -12,6 +12,7 @@ import { ChatContext, type ChatContextValue, type PendingPromptData } from "@/li import { useConfigContext, useArtifacts, useAgentCards, useTaskContext, useErrorDialog, useTitleGeneration, useBackgroundTaskMonitor, useArtifactPreview, useArtifactOperations, useAuthContext } from "@/lib/hooks"; import { useProjectContext, registerProjectDeletedCallback } from "@/lib/providers"; import { getErrorMessage, fileToBase64, migrateTask, CURRENT_SCHEMA_VERSION, getApiBearerToken, internalToDisplayText } from "@/lib/utils"; +import { ConfirmationDialog } from "@/lib/components/common/ConfirmationDialog"; import type { CancelTaskRequest, @@ -96,6 +97,16 @@ export const ChatProvider: React.FC = ({ children }) => { const backgroundTasksRef = useRef([]); const messagesRef = useRef([]); + // Ref to hold the replay function - allows calling from loadSessionTasks which is defined earlier + const replayBufferedEventsRef = useRef<((taskId: string) => Promise) | null>(null); + + // Ref to hold handleSseMessage - allows calling from loadSessionTasks for buffer replay + const handleSseMessageRef = useRef<((event: MessageEvent) => void) | null>(null); + + // Track if we're currently replaying buffered events + // When true, handleSseMessage will skip save operations since the data is already persisted + const isReplayingEventsRef = useRef(false); + // Track query history for deep research progress timeline // This accumulates queries and their URLs as they come in via deep_research_progress events const deepResearchQueryHistoryRef = useRef< @@ -130,6 +141,11 @@ export const ChatProvider: React.FC = ({ children }) => { // Pending prompt state for starting new chat with a prompt template const [pendingPrompt, setPendingPrompt] = useState(null); + // Running task navigation warning state + // Used when background tasks are disabled and user tries to switch sessions while a task is running + const [runningTaskWarningOpen, setRunningTaskWarningOpen] = useState(false); + const [pendingNavigationAction, setPendingNavigationAction] = useState<(() => void) | null>(null); + // Notification Helper const addNotification 
= useCallback((message: string, type?: "success" | "info" | "warning") => { setNotifications(prev => { @@ -249,29 +265,23 @@ export const ChatProvider: React.FC = ({ children }) => { sessionsWithAutoGeneratedTitles.current.add(taskSessionId); try { - // Fetch messages from database - const taskData = await api.webui.get(`/api/v1/sessions/${taskSessionId}/chat-tasks`); - const tasks = taskData.tasks || []; - - // Find the task that matches - const matchingTask = tasks.find((t: StoredTaskData) => t.taskId === taskId); - if (matchingTask) { - const messageBubbles = JSON.parse(matchingTask.messageBubbles); - let userMessageText = ""; - let agentResponseText = ""; - - for (const bubble of messageBubbles) { - const text = bubble.text || ""; - if (bubble.type === "user" && !userMessageText) { - userMessageText = text; - } else if (bubble.type === "agent" && !agentResponseText) { - agentResponseText = text; - } - } - - if (userMessageText && agentResponseText) { - await generateTitle(taskSessionId, userMessageText, agentResponseText); - } + // Use the title-data endpoint which extracts messages from chat_tasks OR SSE buffer + // This works even when FE wasn't watching and chat_tasks weren't saved yet + console.debug("[ChatProvider] Fetching title-data for background task:", taskId); + const titleData = await api.webui.get(`/api/v1/tasks/${taskId}/title-data`); + console.debug("[ChatProvider] Title-data response:", titleData); + + const userMessageText = titleData?.user_message || ""; + const agentResponseText = titleData?.agent_response || ""; + + if (userMessageText && agentResponseText) { + console.debug("[ChatProvider] Calling generateTitle for background task"); + await generateTitle(taskSessionId, userMessageText, agentResponseText); + } else { + console.warn("[ChatProvider] Missing messages for title generation:", { + hasUser: !!userMessageText, + hasAgent: !!agentResponseText, + }); } } catch (error) { console.error("[ChatProvider] Error generating title for completed background task:", error); @@ -528,9 +538,52 @@ export const ChatProvider: React.FC = ({ children }) => { // Deserialize all tasks to messages const allMessages: MessageFE[] = []; + + // Track which tasks have buffered events (need replay) + const tasksWithBufferedEvents = new Map(); + + if (replayBufferedEventsRef.current && backgroundTasksEnabled) { + try { + // Use include_events=true to get all events in one batch + const response = await api.webui.get(`/api/v1/sessions/${sessionId}/events/unconsumed?include_events=true`); + if (response.has_events && response.events_by_task) { + // Populate the map from the batched response + for (const [taskId, taskEvents] of Object.entries(response.events_by_task)) { + const typedEvents = taskEvents as { events_buffered: boolean; events: any[] }; + if (typedEvents.events_buffered && typedEvents.events.length > 0) { + console.debug(`[loadSessionTasks] Task ${taskId} has ${typedEvents.events.length} buffered events (batched)`); + tasksWithBufferedEvents.set(taskId, typedEvents); + } + } + console.debug(`[loadSessionTasks] Loaded buffered events for ${tasksWithBufferedEvents.size} tasks in single request`); + } + } catch (error) { + // Fall back to per-task queries if batched endpoint fails + console.warn(`[loadSessionTasks] Batched event fetch failed, falling back to per-task queries:`, error); + for (const task of migratedTasks) { + try { + const response = await api.webui.get(`/api/v1/tasks/${task.taskId}/events/buffered?mark_consumed=false`); + if (response.events_buffered && 
response.events.length > 0) { + console.debug(`[loadSessionTasks] Task ${task.taskId} has ${response.events.length} buffered events (fallback)`); + tasksWithBufferedEvents.set(task.taskId, response); + } + } catch (taskError) { + console.debug(`[loadSessionTasks] Could not check buffered events for task ${task.taskId}:`, taskError); + } + } + } + } + for (const task of migratedTasks) { const taskMessages = deserializeTaskToMessages(task, sessionId); - allMessages.push(...taskMessages); + + if (tasksWithBufferedEvents.has(task.taskId)) { + const userMessages = taskMessages.filter(m => m.isUser); + allMessages.push(...userMessages); + } else { + // No buffered events - add all messages from chat_tasks + allMessages.push(...taskMessages); + } } // Extract feedback state from task metadata @@ -562,8 +615,7 @@ export const ChatProvider: React.FC = ({ children }) => { } } - // Update state - setMessages(allMessages); + // Update feedback state setSubmittedFeedback(feedbackMap); // Restore RAG data @@ -581,8 +633,161 @@ export const ChatProvider: React.FC = ({ children }) => { const mostRecentTask = migratedTasks[migratedTasks.length - 1]; setTaskIdInSidePanel(mostRecentTask.taskId); } + + // Process messages and buffered events to build the final message array + // We need to reconstruct agent responses from buffered events and insert them + // in the correct position (after their corresponding user messages) + if (replayBufferedEventsRef.current && tasksWithBufferedEvents.size > 0) { + console.debug(`[loadSessionTasks] Processing ${migratedTasks.length} tasks with ${tasksWithBufferedEvents.size} having buffered events`); + + // Build the complete message array with proper ordering + const finalMessages: MessageFE[] = []; + + // Track tasks where we used buffer reconstruction (need save + cleanup) + const tasksNeedingSaveAndCleanup: string[] = []; + // Track tasks where chat_tasks already exists but buffer needs cleanup (cleanup only, no save) + const tasksNeedingBufferCleanupOnly: string[] = []; + + // Collect tasks that need buffer replay (no saved agent response) + const tasksNeedingReplay: Array<{ taskId: string; events: any[] }> = []; + + for (const task of migratedTasks) { + const taskMessages = deserializeTaskToMessages(task, sessionId); + const bufferedData = tasksWithBufferedEvents.get(task.taskId); + const agentMessagesFromChatTasks = taskMessages.filter(m => !m.isUser); + const userMessages = taskMessages.filter(m => m.isUser); + + // IMPORTANT: If chat_tasks already has agent messages, PREFER them over buffer replay. + // The buffer should only be used when: + // 1. There are no agent messages saved yet (task was running when user switched away) + // 2. Buffer replay is needed for incomplete saves + // This prevents issues with complex parsing of SSE events. 
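+                // Decision sketch (summary comment; mirrors the branches below):
+                //   buffered events, no saved agent bubbles -> replay buffer through handleSseMessage
+                //   buffered events, saved agent bubbles    -> keep chat_tasks data, clean up buffer only
+                //   no buffered events                      -> use chat_tasks data as-is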
+ if (bufferedData && agentMessagesFromChatTasks.length === 0) { + // Task has buffered events AND no saved agent response yet - need replay + console.debug(`[loadSessionTasks] Task ${task.taskId} has buffered events and no saved agent response, will replay through handleSseMessage`); + // Add user messages first + finalMessages.push(...userMessages); + // Store for replay after we set initial messages + tasksNeedingReplay.push({ taskId: task.taskId, events: bufferedData.events }); + // Mark this task for save and cleanup (will update session modified date) + tasksNeedingSaveAndCleanup.push(task.taskId); + } else { + // Either no buffered events OR chat_tasks already has the agent response + // Use the saved data from chat_tasks (it has correct ordering) + if (bufferedData && agentMessagesFromChatTasks.length > 0) { + console.debug(`[loadSessionTasks] Task ${task.taskId} has buffered events but also has ${agentMessagesFromChatTasks.length} saved agent messages - preferring chat_tasks data, scheduling buffer cleanup only`); + // Chat task already exists - only clean up the buffer (don't re-save to avoid updating modified date) + tasksNeedingBufferCleanupOnly.push(task.taskId); + } + finalMessages.push(...taskMessages); + } + } + + // Set all messages in ONE batch to avoid scroll issues + setMessages(finalMessages); + + // Replay buffered events through handleSseMessage (uses exact same code path as live streaming!) + // This is done AFTER setting initial messages so user messages are visible first + if (tasksNeedingReplay.length > 0) { + console.debug(`[loadSessionTasks] Replaying ${tasksNeedingReplay.length} tasks through handleSseMessage`); + + // Use setTimeout to not block the UI update + setTimeout(async () => { + // Get the current handleSseMessage function via ref + const handleSseMessageFn = handleSseMessageRef.current; + if (!handleSseMessageFn) { + console.error(`[loadSessionTasks] handleSseMessageRef not available for replay`); + return; + } + + for (const { taskId, events } of tasksNeedingReplay) { + console.debug(`[loadSessionTasks] Replaying ${events.length} events for task ${taskId}`); + + // Set replay flag to prevent saves during replay (state updates are async) + isReplayingEventsRef.current = true; + + try { + // Process each buffered event through handleSseMessage + // This is the EXACT same code path as live SSE streaming! 
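+                        // Expected buffered-event shape (sketch; matches the description in
+                        // replayBufferedEvents below): bufferedEvent.data = { event: "message",
+                        // data: "<serialized JSON string>" }, so ssePayload.data is the same
+                        // string EventSource would deliver as MessageEvent.data.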
+ for (const bufferedEvent of events) { + const ssePayload = bufferedEvent.data; + if (ssePayload?.data) { + // Create a synthetic MessageEvent-like object + const syntheticEvent = { + data: ssePayload.data, + } as MessageEvent; + + // Process through SSE handler - exact same as live streaming + handleSseMessageFn(syntheticEvent); + } + } + } finally { + isReplayingEventsRef.current = false; + } + } + + // After replay, get messages and save using setMessages functional updater + const taskIdsToSave = tasksNeedingReplay.map(t => t.taskId); + setMessages(currentMessages => { + setTimeout(async () => { + for (const taskId of taskIdsToSave) { + const taskMessages = currentMessages.filter(m => m.taskId === taskId && !m.isStatusBubble); + if (taskMessages.length > 0) { + console.debug(`[loadSessionTasks] Saving task ${taskId} after replay (${taskMessages.length} messages)`); + const messageBubbles = taskMessages.map(serializeMessageBubble); + const userMessage = taskMessages.find(m => m.isUser); + const userMessageText = + userMessage?.parts + ?.filter(p => p.kind === "text") + .map(p => (p as TextPart).text) + .join("") || ""; + const hasError = taskMessages.some(m => m.isError); + const taskStatus = hasError ? "error" : "completed"; + + await saveTaskToBackend( + { + task_id: taskId, + user_message: userMessageText, + message_bubbles: messageBubbles, + task_metadata: { + schema_version: CURRENT_SCHEMA_VERSION, + status: taskStatus, + }, + }, + sessionId + ); + } else { + console.warn(`[loadSessionTasks] No messages found for task ${taskId} after replay`); + } + } + }, 0); + // Return unchanged - we're just reading state + return currentMessages; + }); + }, 50); + } + + // Trigger buffer cleanup only (no save) for tasks where chat_tasks already exists + // This avoids updating the session's modified date just from visiting it + if (tasksNeedingBufferCleanupOnly.length > 0) { + console.debug(`[loadSessionTasks] Scheduling buffer cleanup only for ${tasksNeedingBufferCleanupOnly.length} tasks`); + setTimeout(async () => { + for (const taskId of tasksNeedingBufferCleanupOnly) { + try { + await api.webui.delete(`/api/v1/tasks/${taskId}/events/buffered`); + console.debug(`[loadSessionTasks] Buffer cleanup successful for task ${taskId}`); + } catch (error) { + console.error(`[loadSessionTasks] Failed to clean up buffer for task ${taskId}:`, error); + } + } + }, 100); + } + } else { + // No tasks with buffered events - just set all messages at once + setMessages(allMessages); + } }, - [deserializeTaskToMessages, setRagData] + [deserializeTaskToMessages, setRagData, backgroundTasksEnabled, serializeMessageBubble, saveTaskToBackend] ); // Session State @@ -1304,186 +1509,99 @@ export const ChatProvider: React.FC = ({ children }) => { setIsCancelling(false); } - // Save complete task when agent response is done (Step 10.5-10.9) - // Note: For background tasks, the backend TaskLoggerService handles saving automatically - // For non-background tasks, we save here - if (currentTaskIdFromResult) { + // Save on final event - the save is idempotent (upsert) and will trigger buffer cleanup + // SKIP save during buffer replay because messagesRef won't have the updated messages yet + // (React state updates are async). The save will happen in loadSessionTasks after replay. 
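+            // Save-path sketch: a live final event reaches this block with the replay flag
+            // false and saves immediately; a replayed final event is skipped here because
+            // isReplayingEventsRef.current is true, and loadSessionTasks performs one save
+            // after the whole buffer has been replayed.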
+ if (currentTaskIdFromResult && !isReplayingEventsRef.current) { const isBackgroundTask = isTaskRunningInBackground(currentTaskIdFromResult); - - // Only save non-background tasks from frontend - // Background tasks are saved by TaskLoggerService to avoid race conditions - if (!isBackgroundTask) { - // Use messagesRef to get the latest messages - const taskMessages = messagesRef.current.filter(msg => msg.taskId === currentTaskIdFromResult && !msg.isStatusBubble); - - if (taskMessages.length > 0) { - // Serialize all message bubbles - const messageBubbles = taskMessages.map(serializeMessageBubble); - - // Extract user message text - const userMessage = taskMessages.find(m => m.isUser); - const userMessageText = - userMessage?.parts - ?.filter(p => p.kind === "text") - .map(p => (p as TextPart).text) - .join("") || ""; - - // Determine task status - const hasError = taskMessages.some(m => m.isError); - const taskStatus = hasError ? "error" : "completed"; - - const taskRagData = ragDataRef.current.filter(r => r.taskId === currentTaskIdFromResult); - - // Get the session ID from the task's context - const taskSessionId = (result as TaskStatusUpdateEvent).contextId || sessionId; - - // Save complete task - saveTaskToBackend( - { - task_id: currentTaskIdFromResult, - user_message: userMessageText, - message_bubbles: messageBubbles, - task_metadata: { - schema_version: CURRENT_SCHEMA_VERSION, - status: taskStatus, - agent_name: selectedAgentName, - rag_data: taskRagData.length > 0 ? taskRagData : undefined, // Persist RAG data - }, + const taskSessionId = (result as TaskStatusUpdateEvent).contextId || sessionId; + + // Use messagesRef to get the latest messages + const taskMessages = messagesRef.current.filter(msg => msg.taskId === currentTaskIdFromResult && !msg.isStatusBubble); + + if (taskMessages.length > 0) { + // Serialize all message bubbles + const messageBubbles = taskMessages.map(serializeMessageBubble); + + // Extract user message text + const userMessage = taskMessages.find(m => m.isUser); + const userMessageText = + userMessage?.parts + ?.filter(p => p.kind === "text") + .map(p => (p as TextPart).text) + .join("") || ""; + + // Determine task status + const hasError = taskMessages.some(m => m.isError); + const taskStatus = hasError ? "error" : "completed"; + + const taskRagData = ragDataRef.current.filter(r => r.taskId === currentTaskIdFromResult); + + // Save complete task + saveTaskToBackend( + { + task_id: currentTaskIdFromResult, + user_message: userMessageText, + message_bubbles: messageBubbles, + task_metadata: { + schema_version: CURRENT_SCHEMA_VERSION, + status: taskStatus, + agent_name: selectedAgentName, + rag_data: taskRagData.length > 0 ? 
taskRagData : undefined, // Persist RAG data }, - taskSessionId - ) - .then(async saved => { - if (saved && typeof window !== "undefined") { + }, + taskSessionId + ) + .then(async saved => { + if (saved) { + if (typeof window !== "undefined") { window.dispatchEvent(new CustomEvent("new-chat-session")); } - // Handle session title based on feature flag - if (taskSessionId && !sessionsWithAutoGeneratedTitles.current.has(taskSessionId)) { - if (autoTitleGenerationEnabled) { - // Trigger automatic title generation for new sessions (if feature is enabled) - // Only trigger once per session (tracked by sessionsWithAutoGeneratedTitles ref) - // The generateTitle function will check the actual session name from the API - // and skip generation if the session already has a meaningful title - // Extract agent response text for title generation - const agentMessage = taskMessages.find(m => !m.isUser); - const agentResponseText = - agentMessage?.parts - ?.filter(p => p.kind === "text") - .map(p => (p as TextPart).text) - .join("") || ""; - - // Pass messages directly - no database dependency, no delays needed - if (userMessageText && agentResponseText) { - // Mark this session as having had AUTOMATIC title generation attempted - sessionsWithAutoGeneratedTitles.current.add(taskSessionId); - generateTitle(taskSessionId, userMessageText, agentResponseText).catch(error => { - console.error("[ChatProvider] Title generation failed:", error); - }); - } - } - } - }) - .catch(error => { - console.error(`[ChatProvider] Error saving task ${currentTaskIdFromResult}:`, error); - }); - } - } else { - // For background tasks, unregister after completion - console.log(`[ChatProvider] Background task ${currentTaskIdFromResult} completed. isBackgroundTask=true`); - unregisterBackgroundTask(currentTaskIdFromResult); - - // Trigger session list refresh - if (typeof window !== "undefined") { - window.dispatchEvent(new CustomEvent("new-chat-session")); - } - - // Also trigger title generation for background tasks (if feature is enabled) - const taskSessionId = (result as TaskStatusUpdateEvent).contextId || sessionId; - console.log(`[ChatProvider] Title generation check: autoTitleGenerationEnabled=${autoTitleGenerationEnabled}, taskSessionId=${taskSessionId}, alreadyGenerated=${sessionsWithAutoGeneratedTitles.current.has(taskSessionId)}`); - - if (autoTitleGenerationEnabled && taskSessionId && !sessionsWithAutoGeneratedTitles.current.has(taskSessionId)) { - // Mark this session as having had title generation attempted immediately - // to prevent duplicate attempts - sessionsWithAutoGeneratedTitles.current.add(taskSessionId); - console.log(`[ChatProvider] Starting title generation for background task ${currentTaskIdFromResult} in session ${taskSessionId}`); - - // Use an async IIFE to handle the async operations - (async () => { - // Use messagesRef to get the latest messages for background tasks too - const bgTaskMessages = messagesRef.current.filter(msg => msg.taskId === currentTaskIdFromResult && !msg.isStatusBubble); - console.log(`[ChatProvider] Found ${bgTaskMessages.length} messages in memory for task ${currentTaskIdFromResult}`); - console.log( - `[ChatProvider] All messages in memory:`, - messagesRef.current.map(m => ({ taskId: m.taskId, isUser: m.isUser, hasText: m.parts?.some(p => p.kind === "text") })) - ); - - let userMessageText = ""; - let agentResponseText = ""; - - const userMessage = bgTaskMessages.find(m => m.isUser); - userMessageText = - userMessage?.parts - ?.filter(p => p.kind === "text") - .map(p => 
(p as TextPart).text) - .join("") || ""; - const agentMessage = bgTaskMessages.find(m => !m.isUser); - agentResponseText = - agentMessage?.parts - ?.filter(p => p.kind === "text") - .map(p => (p as TextPart).text) - .join("") || ""; - - console.log(`[ChatProvider] From memory - userMessageText: "${userMessageText.substring(0, 50)}...", agentResponseText: "${agentResponseText.substring(0, 50)}..."`); - - // If messages not available in memory (e.g., after browser refresh), - // fetch from database - if (!userMessageText || !agentResponseText) { - console.log(`[ChatProvider] Messages not in memory for background task ${currentTaskIdFromResult}, fetching from database`); - try { - const taskData = await api.webui.get(`/api/v1/sessions/${taskSessionId}/chat-tasks`); - const tasks = taskData.tasks || []; - console.log(`[ChatProvider] Fetched ${tasks.length} tasks from database`); - - // Find the task that matches - const matchingTask = tasks.find((t: StoredTaskData) => t.taskId === currentTaskIdFromResult); - console.log(`[ChatProvider] Matching task found: ${!!matchingTask}`); + } - if (matchingTask) { - const messageBubbles = JSON.parse(matchingTask.messageBubbles); - console.log(`[ChatProvider] Task has ${messageBubbles.length} message bubbles`); + // Unregister background task after save completes + if (isBackgroundTask) { + unregisterBackgroundTask(currentTaskIdFromResult); + } - for (const bubble of messageBubbles) { - const text = bubble.text || ""; - console.log(`[ChatProvider] Bubble type=${bubble.type}, text length=${text.length}`); - if (bubble.type === "user" && !userMessageText) { - userMessageText = text; - } else if (bubble.type === "agent" && !agentResponseText) { - agentResponseText = text; - } - } + // Handle session title based on feature flag + if (taskSessionId && !sessionsWithAutoGeneratedTitles.current.has(taskSessionId)) { + if (autoTitleGenerationEnabled) { + // Trigger automatic title generation for new sessions (if feature is enabled) + // Only trigger once per session (tracked by sessionsWithAutoGeneratedTitles ref) + // The generateTitle function will check the actual session name from the API + // and skip generation if the session already has a meaningful title + // Extract agent response text for title generation + const agentMessage = taskMessages.find(m => !m.isUser); + const agentResponseText = + agentMessage?.parts + ?.filter(p => p.kind === "text") + .map(p => (p as TextPart).text) + .join("") || ""; + + // Pass messages directly - no database dependency, no delays needed + if (userMessageText && agentResponseText) { + // Mark this session as having had AUTOMATIC title generation attempted + sessionsWithAutoGeneratedTitles.current.add(taskSessionId); + generateTitle(taskSessionId, userMessageText, agentResponseText).catch(error => { + console.error("[ChatProvider] Title generation failed:", error); + }); } - } catch (error) { - console.error("[ChatProvider] Error fetching task data for title generation:", error); } } - - console.log(`[ChatProvider] Final - userMessageText: "${userMessageText.substring(0, 50)}...", agentResponseText: "${agentResponseText.substring(0, 50)}..."`); - - if (userMessageText && agentResponseText) { - console.log(`[ChatProvider] Calling generateTitle for session ${taskSessionId}`); - try { - await generateTitle(taskSessionId, userMessageText, agentResponseText); - console.log(`[ChatProvider] generateTitle completed for session ${taskSessionId}`); - } catch (error) { - console.error("[ChatProvider] Background task title generation 
failed:", error); - } - } else { - console.warn(`[ChatProvider] Cannot generate title - missing messages. User: ${!!userMessageText}, Agent: ${!!agentResponseText}`); + }) + .catch(error => { + console.error(`[ChatProvider] Error saving task ${currentTaskIdFromResult}:`, error); + // Still unregister background task even on save error + if (isBackgroundTask) { + unregisterBackgroundTask(currentTaskIdFromResult); } - })(); - } else { - console.log( - `[ChatProvider] Skipping title generation: autoTitleGenerationEnabled=${autoTitleGenerationEnabled}, taskSessionId=${taskSessionId}, alreadyGenerated=${sessionsWithAutoGeneratedTitles.current.has(taskSessionId)}` - ); + }); + } else if (isBackgroundTask) { + // No messages but task was background - still unregister + unregisterBackgroundTask(currentTaskIdFromResult); + if (typeof window !== "undefined") { + window.dispatchEvent(new CustomEvent("new-chat-session")); } } } @@ -1559,7 +1677,78 @@ export const ChatProvider: React.FC = ({ children }) => { ] ); - const handleNewSession = useCallback( + // Helper function to replay buffered SSE events for a background task + // This is used when a background task completed while the user was away + const replayBufferedEvents = useCallback( + async (taskId: string): Promise => { + try { + console.debug(`[ChatProvider] Fetching buffered events for task ${taskId}`); + const response = await api.webui.get(`/api/v1/tasks/${taskId}/events/buffered`); + + if (!response.events_buffered || response.events.length === 0) { + console.debug(`[ChatProvider] No buffered events for task ${taskId}`); + return false; + } + + console.debug(`[ChatProvider] Replaying ${response.events.length} buffered events for task ${taskId}`); + + // Reset the SSE event sequence counter for replay + // This ensures proper sequencing during replay + const originalSequence = sseEventSequenceRef.current; + sseEventSequenceRef.current = 0; + + // Clear any existing AGENT messages for this task before replaying + // Keep user messages since they're not in the SSE events + // This prevents duplicate content when replaying while preserving user input + setMessages(prev => prev.filter(msg => msg.taskId !== taskId || msg.isUser)); + + // Set the replay flag to prevent save operations during replay + // The data is already persisted, so we don't want to re-save and update timestamps + isReplayingEventsRef.current = true; + + try { + // Process each buffered event through the SSE handler + for (const bufferedEvent of response.events) { + // The buffered event data contains the full SSE payload + // which has {event: "message", data: "...serialized JSON..."} + const ssePayload = bufferedEvent.data; + + if (ssePayload?.data) { + // Create a synthetic MessageEvent-like object + const syntheticEvent = { + data: ssePayload.data, // This is the serialized JSON string + } as MessageEvent; + + // Process through the SSE handler + handleSseMessage(syntheticEvent); + } + } + } finally { + // Always reset the replay flag, even if an error occurred + isReplayingEventsRef.current = false; + } + + // Restore the sequence counter (or keep the new value if higher) + sseEventSequenceRef.current = Math.max(originalSequence, sseEventSequenceRef.current); + + console.debug(`[ChatProvider] Finished replaying buffered events for task ${taskId}`); + return true; + } catch (error) { + console.error(`[ChatProvider] Error replaying buffered events for task ${taskId}:`, error); + isReplayingEventsRef.current = false; // Reset flag on error + return false; + } + }, + 
[handleSseMessage] + ); + + // Keep the ref in sync with the latest replayBufferedEvents function + useEffect(() => { + replayBufferedEventsRef.current = replayBufferedEvents; + }, [replayBufferedEvents]); + + // Core implementation - called directly or after confirmation + const handleNewSessionCore = useCallback( async (preserveProjectContext: boolean = false) => { const log_prefix = "ChatProvider.handleNewSession:"; @@ -1624,6 +1813,22 @@ export const ChatProvider: React.FC = ({ children }) => { [isResponding, currentTaskId, selectedAgentName, isCancelling, closeCurrentEventSource, activeProject, setActiveProject, closePreview, isTaskRunningInBackground, setRagData] ); + // Wrapper that shows confirmation when task is running and background tasks are disabled + const handleNewSession = useCallback( + async (preserveProjectContext: boolean = false) => { + // Check if we need to warn the user about losing a running task + if (isResponding && backgroundTasksEnabled === false) { + // Store the action to execute after confirmation + setPendingNavigationAction(() => () => handleNewSessionCore(preserveProjectContext)); + setRunningTaskWarningOpen(true); + return; + } + // No warning needed - proceed directly + await handleNewSessionCore(preserveProjectContext); + }, + [isResponding, backgroundTasksEnabled, handleNewSessionCore] + ); + // Start a new chat session with a prompt template pre-filled const startNewChatWithPrompt = useCallback( (promptData: PendingPromptData) => { @@ -1640,7 +1845,8 @@ export const ChatProvider: React.FC = ({ children }) => { setPendingPrompt(null); }, []); - const handleSwitchSession = useCallback( + // Core implementation - called directly or after confirmation + const handleSwitchSessionCore = useCallback( async (newSessionId: string) => { const log_prefix = "ChatProvider.handleSwitchSession:"; console.log(`${log_prefix} Switching to session ${newSessionId}...`); @@ -1754,8 +1960,9 @@ export const ChatProvider: React.FC = ({ children }) => { break; } else { // Task is no longer running - it completed while we were away - // Unregister it and trigger title generation if needed - console.log(`[ChatProvider] Background task ${bgTask.taskId} completed while away, unregistering`); + console.log(`[ChatProvider] Background task ${bgTask.taskId} completed while away, already reconstructed in loadSessionTasks`); + + // Unregister the background task unregisterBackgroundTask(bgTask.taskId); // Trigger title generation for completed background task (if feature is enabled) @@ -1831,9 +2038,26 @@ export const ChatProvider: React.FC = ({ children }) => { autoTitleGenerationEnabled, generateTitle, isTaskRunningInBackground, + replayBufferedEvents, ] ); + // Wrapper that shows confirmation when task is running and background tasks are disabled + const handleSwitchSession = useCallback( + async (newSessionId: string) => { + // Check if we need to warn the user about losing a running task + if (isResponding && backgroundTasksEnabled === false) { + // Store the action to execute after confirmation + setPendingNavigationAction(() => () => handleSwitchSessionCore(newSessionId)); + setRunningTaskWarningOpen(true); + return; + } + // No warning needed - proceed directly + await handleSwitchSessionCore(newSessionId); + }, + [isResponding, backgroundTasksEnabled, handleSwitchSessionCore] + ); + const updateSessionName = useCallback( async (sessionId: string, newName: string) => { try { @@ -2394,8 +2618,9 @@ export const ChatProvider: React.FC = ({ children }) => { useEffect(() => { // 
Listen for background task completion events - // When a background task completes, reload ANY session it belongs to (not just current) - // This ensures we get the latest data even if the task completed while we were in a different session + // When a background task completes, reload the session if it's currently active + // If the user is on a different session, the buffer stays until they switch back + // (buffer cleanup happens after replay + save) const handleBackgroundTaskCompleted = async (event: Event) => { const customEvent = event as CustomEvent; const { taskId: completedTaskId } = customEvent.detail; @@ -2403,15 +2628,28 @@ export const ChatProvider: React.FC = ({ children }) => { // Find the completed task const completedTask = backgroundTasksRef.current.find(t => t.taskId === completedTaskId); if (completedTask) { - console.log(`[ChatProvider] Background task ${completedTaskId} completed, will reload session ${completedTask.sessionId} after delay`); - // Wait a bit to ensure any pending operations complete - setTimeout(async () => { - // Reload the session if it's currently active - if (currentSessionIdRef.current === completedTask.sessionId) { - console.log(`[ChatProvider] Reloading current session ${completedTask.sessionId} to get latest data`); - await loadSessionTasks(completedTask.sessionId); - } - }, 1500); // Increased delay to ensure save completes + console.log(`[ChatProvider] Background task ${completedTaskId} completed for session ${completedTask.sessionId}`); + + // Only replay if the user is currently viewing the session where the task completed + if (currentSessionIdRef.current === completedTask.sessionId) { + console.log(`[ChatProvider] User is on same session - will replay buffered events after delay`); + // Wait a bit to ensure any pending operations complete + setTimeout(async () => { + // Try to replay buffered events first (new single-path approach) + // This ensures embeds and templates are properly resolved through the frontend's SSE processing + const replayedSuccessfully = await replayBufferedEvents(completedTaskId); + + if (!replayedSuccessfully) { + // Fall back to loading from chat_tasks if no buffered events + console.log(`[ChatProvider] No buffered events, falling back to loadSessionTasks`); + await loadSessionTasks(completedTask.sessionId); + } + }, 1500); // Delay to ensure save completes + } else { + // User is on a different session - buffer stays until they switch back to that session + // When they switch back, handleSwitchSession will detect buffered events and replay them + console.log(`[ChatProvider] User is on different session (${currentSessionIdRef.current}) - buffer preserved for later replay when they switch back to ${completedTask.sessionId}`); + } } }; @@ -2419,7 +2657,7 @@ export const ChatProvider: React.FC = ({ children }) => { return () => { window.removeEventListener("background-task-completed", handleBackgroundTaskCompleted); }; - }, [loadSessionTasks]); + }, [loadSessionTasks, replayBufferedEvents]); useEffect(() => { // When the active project changes, reset the chat view to a clean slate @@ -2500,7 +2738,7 @@ export const ChatProvider: React.FC = ({ children }) => { }, [agents, configWelcomeMessage, messages.length, selectedAgentName, sessionId, isLoadingSession, activeProject]); // Store the latest handlers in refs so they can be accessed without triggering effect re-runs - const handleSseMessageRef = useRef(handleSseMessage); + // Note: handleSseMessageRef is declared earlier (line ~103) for use in loadSessionTasks const 
handleSseOpenRef = useRef(handleSseOpen);
    const handleSseErrorRef = useRef(handleSseError);

@@ -2552,7 +2790,9 @@ export const ChatProvider: React.FC = ({ children }) => {
        };

        const wrappedHandleSseMessage = (event: MessageEvent) => {
-            handleSseMessageRef.current(event);
+            if (handleSseMessageRef.current) {
+                handleSseMessageRef.current(event);
+            }
        };

        eventSource.onopen = wrappedHandleSseOpen;
@@ -2662,10 +2902,37 @@ export const ChatProvider: React.FC = ({ children }) => {
        isTaskRunningInBackground,
    };

+    // Handlers for the running task warning dialog
+    const handleConfirmNavigation = useCallback(() => {
+        setRunningTaskWarningOpen(false);
+        if (pendingNavigationAction) {
+            pendingNavigationAction();
+            setPendingNavigationAction(null);
+        }
+    }, [pendingNavigationAction]);
+
+    const handleCancelNavigation = useCallback(() => {
+        setRunningTaskWarningOpen(false);
+        setPendingNavigationAction(null);
+    }, []);
+
    return (
        {children}
+            {/* Warning dialog when user tries to navigate away while a task is running and background tasks are disabled */}
+
    );
};
diff --git a/examples/gateways/webui_gateway_example.yaml b/examples/gateways/webui_gateway_example.yaml
index a424407ded..28d3942d6b 100644
--- a/examples/gateways/webui_gateway_example.yaml
+++ b/examples/gateways/webui_gateway_example.yaml
@@ -90,6 +90,9 @@ apps:
          task_logging:
            enabled: true
+            hybrid_buffer:
+              enabled: false # Set to true to buffer events in RAM before flushing to DB
+              flush_threshold: 10 # Flush to DB after N events (lower for easier testing)
            log_status_updates: true
            log_artifact_events: true
            log_file_parts: true
diff --git a/src/solace_agent_mesh/common/utils/embeds/resolver.py b/src/solace_agent_mesh/common/utils/embeds/resolver.py
index b71cb3e92e..0ec512c915 100644
--- a/src/solace_agent_mesh/common/utils/embeds/resolver.py
+++ b/src/solace_agent_mesh/common/utils/embeds/resolver.py
@@ -550,11 +550,12 @@ async def resolve_embeds_in_string(
                if embed_type in types_to_resolve:
                    log.info(
-                        "%s Found embed type '%s' to resolve: expr='%s', fmt='%s'",
+                        "%s Found embed type '%s' to resolve: expr='%s', fmt='%s', types_to_resolve=%s",
                        log_identifier,
                        embed_type,
                        expression,
                        format_spec,
+                        types_to_resolve,
                    )
                    resolved_value = await resolver_func(
                        embed_type,
@@ -614,9 +615,10 @@ async def resolve_embeds_in_string(
                else:
                    log.debug(
-                        "%s Skipping embed type '%s' (not in types_to_resolve)",
+                        "%s Skipping embed type '%s' (not in types_to_resolve=%s)",
                        log_identifier,
                        embed_type,
+                        types_to_resolve,
                    )
                    resolved_parts.append(match.group(0))
diff --git a/src/solace_agent_mesh/gateway/base/component.py b/src/solace_agent_mesh/gateway/base/component.py
index edfc503675..9ac1b91a54 100644
--- a/src/solace_agent_mesh/gateway/base/component.py
+++ b/src/solace_agent_mesh/gateway/base/component.py
@@ -1363,6 +1363,14 @@ async def _resolve_embeds_and_handle_signals(
                text_to_resolve = current_buffer + part.text
                current_buffer = ""  # Buffer is now being processed

+                # Debug: Log the text before embed resolution
+                log.debug(
+                    "%s Input text for embed resolution (len=%d): %s",
+                    log_id_prefix,
+                    len(text_to_resolve),
+                    text_to_resolve[:500] + "..." if len(text_to_resolve) > 500 else text_to_resolve,
+                )
+
                (
                    resolved_text,
                    processed_idx,
                ) = await resolve_embeds_in_string(
@@ -1377,6 +1385,16 @@
                    config=embed_eval_config,
                )

+                # Debug: Log the resolved text
+                log.debug(
+                    "%s Resolved text (processed_idx=%d, signals=%d, len=%d): %s",
+                    log_id_prefix,
+                    processed_idx,
+                    len(signals_with_placeholders),
+                    len(resolved_text),
+                    resolved_text[:500] + "..."
if len(resolved_text) > 500 else resolved_text, + ) + if not signals_with_placeholders: new_parts.append(a2a.create_text_part(text=resolved_text)) else: diff --git a/src/solace_agent_mesh/gateway/http_sse/alembic/versions/20260207_add_sse_event_buffer.py b/src/solace_agent_mesh/gateway/http_sse/alembic/versions/20260207_add_sse_event_buffer.py new file mode 100644 index 0000000000..6021055e43 --- /dev/null +++ b/src/solace_agent_mesh/gateway/http_sse/alembic/versions/20260207_add_sse_event_buffer.py @@ -0,0 +1,80 @@ +"""Add SSE event buffer table for background task event replay + +Revision ID: 20260207_sse_event_buffer +Revises: 20260204_fix_command_constraint +Create Date: 2026-02-07 00:00:00.000000 + +This migration adds a table to persist SSE events for background tasks. +When a background task completes while the user is away, the events are +stored in this table and replayed when the user returns to the session. +This ensures the frontend processes events through the same code path +as live streaming, guaranteeing consistency. +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '20260207_sse_event_buffer' +down_revision = '20260204_fix_command_constraint' +branch_labels = None +depends_on = None + + +def upgrade(): + """Add SSE event buffer table and related columns.""" + + # Create sse_event_buffer table + op.create_table( + 'sse_event_buffer', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('task_id', sa.String(255), nullable=False), + sa.Column('session_id', sa.String(255), nullable=False), + sa.Column('user_id', sa.String(255), nullable=False), + sa.Column('event_sequence', sa.Integer(), nullable=False), + sa.Column('event_type', sa.String(50), nullable=False), + sa.Column('event_data', sa.JSON(), nullable=False), + sa.Column('created_at', sa.BigInteger(), nullable=False), # Epoch milliseconds + sa.Column('consumed', sa.Boolean(), nullable=False, server_default=sa.text('FALSE')), + sa.Column('consumed_at', sa.BigInteger(), nullable=True), # Epoch milliseconds + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('task_id', 'event_sequence', name='sse_event_buffer_task_seq_unique') + ) + + # Create indexes for efficient queries + op.create_index('idx_sse_event_buffer_task_id', 'sse_event_buffer', ['task_id']) + op.create_index('idx_sse_event_buffer_session_id', 'sse_event_buffer', ['session_id']) + op.create_index('idx_sse_event_buffer_consumed', 'sse_event_buffer', ['consumed']) + op.create_index('idx_sse_event_buffer_created_at', 'sse_event_buffer', ['created_at']) + # Composite index for has_unconsumed_events query optimization + op.create_index('idx_sse_event_buffer_task_consumed', 'sse_event_buffer', ['task_id', 'consumed']) + + # Add columns to tasks table for tracking event buffer state + op.add_column('tasks', sa.Column('session_id', sa.String(255), nullable=True)) + op.add_column('tasks', sa.Column('events_buffered', sa.Boolean(), nullable=True, server_default=sa.text('FALSE'))) + op.add_column('tasks', sa.Column('events_consumed', sa.Boolean(), nullable=True, server_default=sa.text('FALSE'))) + + # Create indexes for efficient queries on event buffer state + op.create_index('idx_tasks_session_id', 'tasks', ['session_id']) + op.create_index('idx_tasks_events_buffered', 'tasks', ['events_buffered']) + + +def downgrade(): + """Remove SSE event buffer table and related columns.""" + + # Remove columns from tasks table + op.drop_index('idx_tasks_events_buffered', table_name='tasks') 
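+    # Downgrade order mirrors upgrade in reverse: indexes on `tasks` are dropped
+    # before the columns they cover, then the buffer table's indexes before the
+    # table itself.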
+ op.drop_index('idx_tasks_session_id', table_name='tasks') + op.drop_column('tasks', 'events_consumed') + op.drop_column('tasks', 'events_buffered') + op.drop_column('tasks', 'session_id') + + # Remove indexes from sse_event_buffer table + op.drop_index('idx_sse_event_buffer_task_consumed', table_name='sse_event_buffer') + op.drop_index('idx_sse_event_buffer_created_at', table_name='sse_event_buffer') + op.drop_index('idx_sse_event_buffer_consumed', table_name='sse_event_buffer') + op.drop_index('idx_sse_event_buffer_session_id', table_name='sse_event_buffer') + op.drop_index('idx_sse_event_buffer_task_id', table_name='sse_event_buffer') + + # Drop sse_event_buffer table + op.drop_table('sse_event_buffer') diff --git a/src/solace_agent_mesh/gateway/http_sse/component.py b/src/solace_agent_mesh/gateway/http_sse/component.py index ba1543465a..82675076ae 100644 --- a/src/solace_agent_mesh/gateway/http_sse/component.py +++ b/src/solace_agent_mesh/gateway/http_sse/component.py @@ -1332,17 +1332,29 @@ def _start_fastapi_server(self): # This must be done *after* setup_dependencies has run. session_factory = dependencies.SessionLocal if self.database_url else None + # Get task logging config for hybrid buffer settings + # Hybrid buffer is OFF by default for safety - it buffers events in RAM + # before flushing to DB to reduce database writes + task_logging_config = self.get_config("task_logging", {}) + hybrid_buffer_config = task_logging_config.get("hybrid_buffer", {}) + hybrid_buffer_enabled = hybrid_buffer_config.get("enabled", False) + hybrid_buffer_threshold = hybrid_buffer_config.get("flush_threshold", 10) + # Initialize SSE manager with session factory for background task detection self.sse_manager = SSEManager( max_queue_size=self.sse_max_queue_size, event_buffer=self.sse_event_buffer, - session_factory=session_factory + session_factory=session_factory, + hybrid_buffer_enabled=hybrid_buffer_enabled, + hybrid_buffer_threshold=hybrid_buffer_threshold, ) log.debug( - "%s SSE manager initialized with database session factory.", + "%s SSE manager initialized with database session factory (hybrid_buffer=%s, threshold=%d).", self.log_identifier, + hybrid_buffer_enabled, + hybrid_buffer_threshold, ) - task_logging_config = self.get_config("task_logging", {}) + # task_logging_config already obtained above for hybrid_buffer settings self.task_logger_service = TaskLoggerService( session_factory=session_factory, config=task_logging_config ) diff --git a/src/solace_agent_mesh/gateway/http_sse/persistent_sse_event_buffer.py b/src/solace_agent_mesh/gateway/http_sse/persistent_sse_event_buffer.py new file mode 100644 index 0000000000..b5604b99f3 --- /dev/null +++ b/src/solace_agent_mesh/gateway/http_sse/persistent_sse_event_buffer.py @@ -0,0 +1,733 @@ +""" +A database-backed buffer for holding SSE events for background tasks. + +This buffer persists events to the database so they survive server restarts +and can be replayed when the user returns to the session. It works alongside +the in-memory SSEEventBuffer - the in-memory buffer handles short-term buffering +for race conditions, while this persistent buffer handles long-term storage +for background tasks. + +HYBRID MODE: +When hybrid_mode_enabled=True, events are first buffered in RAM and only +flushed to the database when: +1. The RAM buffer reaches the flush threshold (default: 10 events) +2. The SSE connection is closed (client disconnects) +3. The task completes +4. 
Explicit flush is requested
+
+This reduces database write pressure for short-lived tasks while maintaining
+durability for longer-running background tasks.
+"""
+
+import logging
+import threading
+import time
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+log = logging.getLogger(__name__)
+
+
+class PersistentSSEEventBuffer:
+    """
+    Database-backed buffer for SSE events.
+
+    This buffer stores events in the database for background tasks that need
+    to be replayed when the user returns. It's designed to work with the
+    SSEManager to provide persistent event storage.
+
+    HYBRID MODE:
+    When enabled, events are buffered in RAM first and batched to DB to reduce
+    database write pressure. This is off by default for safety but can be
+    enabled via config for performance optimization.
+    """
+
+    def __init__(
+        self,
+        session_factory: Optional[Callable] = None,
+        enabled: bool = True,
+        hybrid_mode_enabled: bool = False,
+        hybrid_flush_threshold: int = 10,
+    ):
+        """
+        Initialize the persistent event buffer.
+
+        Args:
+            session_factory: Factory function to create database sessions
+            enabled: Whether persistent buffering is enabled
+            hybrid_mode_enabled: Whether to use RAM-first buffering (default: False)
+            hybrid_flush_threshold: Number of events before flushing RAM to DB (default: 10)
+        """
+        self._session_factory = session_factory
+        self._enabled = enabled
+        self._lock = threading.Lock()
+        self.log_identifier = "[PersistentSSEEventBuffer]"
+
+        # Hybrid mode configuration
+        self._hybrid_mode_enabled = hybrid_mode_enabled
+        self._hybrid_flush_threshold = hybrid_flush_threshold
+
+        # Cache for task metadata to avoid repeated DB queries
+        self._task_metadata_cache: Dict[str, Dict[str, str]] = {}
+
+        # RAM buffer for hybrid mode: task_id -> list of
+        # (event_type, event_data, timestamp, session_id, user_id) tuples
+        self._ram_buffer: Dict[str, List[Tuple[str, Dict[str, Any], int, str, str]]] = {}
+
+        log.info(
+            "%s Initialized (enabled=%s, has_session_factory=%s, hybrid_mode=%s, flush_threshold=%d)",
+            self.log_identifier,
+            self._enabled,
+            self._session_factory is not None,
+            self._hybrid_mode_enabled,
+            self._hybrid_flush_threshold,
+        )
+
+    def is_enabled(self) -> bool:
+        """Check if persistent buffering is enabled and configured."""
+        return self._enabled and self._session_factory is not None
+
+    def is_hybrid_mode_enabled(self) -> bool:
+        """Check if hybrid RAM+DB buffering mode is enabled."""
+        return self._hybrid_mode_enabled and self.is_enabled()
+
+    def set_task_metadata(
+        self,
+        task_id: str,
+        session_id: str,
+        user_id: str,
+    ) -> None:
+        """
+        Store task metadata for later use when buffering events.
+
+        This stores metadata in memory (for fast access). Cross-process access
+        is handled via get_task_metadata which falls back to database lookup.
+
+        Args:
+            task_id: The task ID
+            session_id: The session ID
+            user_id: The user ID
+        """
+        # Store in memory cache only - don't write to DB to avoid conflicts
+        with self._lock:
+            self._task_metadata_cache[task_id] = {
+                "session_id": session_id,
+                "user_id": user_id,
+            }
+
+        log.debug(
+            "%s Cached metadata for task %s: session=%s, user=%s",
+            self.log_identifier,
+            task_id,
+            session_id,
+            user_id,
+        )
+
+    def get_task_metadata(self, task_id: str) -> Optional[Dict[str, str]]:
+        """
+        Get task metadata from cache or database.
+
+        First checks the in-memory cache, then falls back to database lookup.
+        This ensures metadata is available even across process boundaries.
+ + Args: + task_id: The task ID + + Returns: + Dictionary with session_id and user_id, or None if not found + """ + # Check in-memory cache first + with self._lock: + cached = self._task_metadata_cache.get(task_id) + if cached: + return cached + + # Fall back to database lookup + if self._session_factory: + try: + from .repository.task_repository import TaskRepository + + db = self._session_factory() + try: + repo = TaskRepository() + task = repo.find_by_id(db, task_id) + if task and task.session_id and task.user_id: + metadata = { + "session_id": task.session_id, + "user_id": task.user_id, + } + # Cache it for future lookups + with self._lock: + self._task_metadata_cache[task_id] = metadata + log.debug( + "%s Retrieved task metadata from database for %s: session=%s, user=%s", + self.log_identifier, + task_id, + task.session_id, + task.user_id, + ) + return metadata + finally: + db.close() + except Exception as e: + log.debug( + "%s Failed to get task metadata from database: %s", + self.log_identifier, + e, + ) + + return None + + def clear_task_metadata(self, task_id: str) -> None: + """ + Clear cached task metadata. + + Args: + task_id: The task ID + """ + with self._lock: + self._task_metadata_cache.pop(task_id, None) + + def buffer_event( + self, + task_id: str, + event_type: str, + event_data: Dict[str, Any], + session_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> bool: + """ + Buffer an SSE event. + + In normal mode: writes directly to database. + In hybrid mode: buffers to RAM first, flushes to DB when threshold reached. + + Args: + task_id: The task ID this event belongs to + event_type: The SSE event type (e.g., 'message') + event_data: The event data payload (already serialized) + session_id: The session ID (optional, will use cached if not provided) + user_id: The user ID (optional, will use cached if not provided) + + Returns: + True if the event was buffered, False otherwise + """ + if not self.is_enabled(): + return False + + # Get metadata from cache if not provided + if not session_id or not user_id: + metadata = self.get_task_metadata(task_id) + if metadata: + session_id = session_id or metadata.get("session_id") + user_id = user_id or metadata.get("user_id") + + if not session_id or not user_id: + log.warning( + "%s Cannot buffer event for task %s: missing session_id or user_id", + self.log_identifier, + task_id, + ) + return False + + # HYBRID MODE: Buffer to RAM first, flush to DB on threshold + if self.is_hybrid_mode_enabled(): + return self._buffer_event_hybrid(task_id, event_type, event_data, session_id, user_id) + + # NORMAL MODE: Write directly to database + return self._buffer_event_to_db(task_id, event_type, event_data, session_id, user_id) + + def _buffer_event_hybrid( + self, + task_id: str, + event_type: str, + event_data: Dict[str, Any], + session_id: str, + user_id: str, + ) -> bool: + """ + Buffer event to RAM, flush to DB when threshold reached. 
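+
+        Worked example (illustrative): with hybrid_flush_threshold=10, events
+        1-9 only append to the RAM list; appending the 10th makes the size
+        reach the threshold, the whole batch is flushed to the DB outside the
+        lock, and the RAM list for the task starts over empty.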
+ + Args: + task_id: The task ID + event_type: The SSE event type + event_data: The event payload + session_id: The session ID + user_id: The user ID + + Returns: + True if buffered successfully + """ + timestamp = int(time.time() * 1000) # milliseconds + should_flush = False + buffer_size_before_flush = 0 + + with self._lock: + # Add to RAM buffer + if task_id not in self._ram_buffer: + self._ram_buffer[task_id] = [] + + self._ram_buffer[task_id].append((event_type, event_data, timestamp, session_id, user_id)) + + buffer_size_before_flush = len(self._ram_buffer[task_id]) + + # Check if we should flush + if buffer_size_before_flush >= self._hybrid_flush_threshold: + should_flush = True + log.debug( + "%s [Hybrid] RAM buffer for task %s reached threshold (%d >= %d), will flush to DB", + self.log_identifier, + task_id, + buffer_size_before_flush, + self._hybrid_flush_threshold, + ) + + # Flush outside the lock to avoid holding it during DB operations + if should_flush: + self.flush_task_buffer(task_id) + + # Get actual current buffer size for accurate logging + current_buffer_size = self.get_ram_buffer_size(task_id) + + log.debug( + "%s [Hybrid] Buffered event to RAM for task %s (type=%s, ram_buffer_size=%d)", + self.log_identifier, + task_id, + event_type, + current_buffer_size, + ) + + return True + + def _buffer_event_to_db( + self, + task_id: str, + event_type: str, + event_data: Dict[str, Any], + session_id: str, + user_id: str, + ) -> bool: + """ + Buffer event directly to database (normal mode). + + Args: + task_id: The task ID + event_type: The SSE event type + event_data: The event payload + session_id: The session ID + user_id: The user ID + + Returns: + True if buffered successfully + """ + try: + from .repository.sse_event_buffer_repository import SSEEventBufferRepository + + db = self._session_factory() + try: + repo = SSEEventBufferRepository() + repo.buffer_event( + db=db, + task_id=task_id, + session_id=session_id, + user_id=user_id, + event_type=event_type, + event_data=event_data, + ) + + # Note: We don't update the tasks table here because: + # 1. The task record may not exist yet (it's created by task_logger_service) + # 2. We can determine if events exist by querying sse_event_buffer directly + + db.commit() + + log.debug( + "%s Buffered event for task %s (type=%s)", + self.log_identifier, + task_id, + event_type, + ) + return True + finally: + db.close() + except Exception as e: + log.error( + "%s Failed to buffer event for task %s: %s", + self.log_identifier, + task_id, + e, + ) + return False + + def flush_task_buffer(self, task_id: str) -> int: + """ + Flush RAM buffer for a specific task to database. + + This is called: + 1. When RAM buffer reaches threshold + 2. When SSE connection is closed + 3. 
When task completes
+
+        Args:
+            task_id: The task ID to flush
+
+        Returns:
+            Number of events flushed
+        """
+        if not self.is_hybrid_mode_enabled():
+            return 0
+
+        # Extract events from RAM buffer under lock
+        events_to_flush = []
+        with self._lock:
+            events_to_flush = self._ram_buffer.pop(task_id, [])
+
+        if not events_to_flush:
+            return 0
+
+        # Flush to database outside the lock
+        try:
+            from .repository.sse_event_buffer_repository import SSEEventBufferRepository
+
+            db = self._session_factory()
+            try:
+                repo = SSEEventBufferRepository()
+
+                for event_type, event_data, timestamp, session_id, user_id in events_to_flush:
+                    repo.buffer_event(
+                        db=db,
+                        task_id=task_id,
+                        session_id=session_id,
+                        user_id=user_id,
+                        event_type=event_type,
+                        event_data=event_data,
+                        created_time=timestamp,
+                    )
+
+                db.commit()
+
+                log.info(
+                    "%s [Hybrid] Flushed %d events for task %s from RAM to DB",
+                    self.log_identifier,
+                    len(events_to_flush),
+                    task_id,
+                )
+
+                return len(events_to_flush)
+            finally:
+                db.close()
+        except Exception as e:
+            log.error(
+                "%s [Hybrid] Failed to flush events for task %s: %s. Re-queueing %d events in RAM for retry",
+                self.log_identifier,
+                task_id,
+                e,
+                len(events_to_flush),
+            )
+            # Re-add events to the buffer on failure so a later flush retries them
+            with self._lock:
+                if task_id not in self._ram_buffer:
+                    self._ram_buffer[task_id] = []
+                # Prepend the failed events (they came first)
+                self._ram_buffer[task_id] = events_to_flush + self._ram_buffer[task_id]
+            return 0
+
+    def flush_all_buffers(self) -> int:
+        """
+        Flush all RAM buffers to database.
+
+        This is called during shutdown to ensure no events are lost.
+
+        Returns:
+            Total number of events flushed
+        """
+        if not self.is_hybrid_mode_enabled():
+            return 0
+
+        # Get all task IDs with buffered events
+        with self._lock:
+            task_ids = list(self._ram_buffer.keys())
+
+        total_flushed = 0
+        for task_id in task_ids:
+            total_flushed += self.flush_task_buffer(task_id)
+
+        if total_flushed > 0:
+            log.info(
+                "%s [Hybrid] Flushed all buffers: %d total events",
+                self.log_identifier,
+                total_flushed,
+            )
+
+        return total_flushed
+
+    def get_ram_buffer_size(self, task_id: str) -> int:
+        """
+        Get the number of events in RAM buffer for a task.
+
+        Args:
+            task_id: The task ID
+
+        Returns:
+            Number of events in RAM buffer
+        """
+        with self._lock:
+            return len(self._ram_buffer.get(task_id, []))
+
+    def get_buffered_events(
+        self,
+        task_id: str,
+        mark_consumed: bool = True,
+    ) -> List[Dict[str, Any]]:
+        """
+        Get all buffered events for a task.
+
+        In hybrid mode: first flushes RAM buffer to DB, then retrieves from DB.
+        This ensures all events are returned in correct order.
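+
+        Example replay loop (a sketch; `send_to_client` is a hypothetical
+        callback, and each item carries keys "type", "data", and "sequence"):
+
+            for ev in buffer.get_buffered_events("task-1", mark_consumed=True):
+                send_to_client(ev["type"], ev["data"])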
+ + Args: + task_id: The task ID + mark_consumed: Whether to mark events as consumed + + Returns: + List of event dictionaries + """ + if not self.is_enabled(): + return [] + + # In hybrid mode, flush RAM buffer first to ensure all events are in DB + if self.is_hybrid_mode_enabled(): + self.flush_task_buffer(task_id) + + try: + from .repository.sse_event_buffer_repository import SSEEventBufferRepository + + db = self._session_factory() + try: + repo = SSEEventBufferRepository() + events = repo.get_buffered_events( + db=db, + task_id=task_id, + mark_consumed=mark_consumed, + ) + if mark_consumed: + db.commit() + + log.info( + "%s Retrieved %d buffered events for task %s (mark_consumed=%s)", + self.log_identifier, + len(events), + task_id, + mark_consumed, + ) + return events + finally: + db.close() + except Exception as e: + log.error( + "%s Failed to get buffered events for task %s: %s", + self.log_identifier, + task_id, + e, + ) + return [] + + def has_unconsumed_events(self, task_id: str) -> bool: + """ + Check if a task has unconsumed buffered events. + + In hybrid mode, also checks RAM buffer. + + Args: + task_id: The task ID + + Returns: + True if there are unconsumed events + """ + if not self.is_enabled(): + return False + + # In hybrid mode, check RAM buffer first + if self.is_hybrid_mode_enabled(): + with self._lock: + if self._ram_buffer.get(task_id): + return True + + try: + from .repository.sse_event_buffer_repository import SSEEventBufferRepository + + db = self._session_factory() + try: + repo = SSEEventBufferRepository() + return repo.has_unconsumed_events(db, task_id) + finally: + db.close() + except Exception as e: + log.error( + "%s Failed to check unconsumed events for task %s: %s", + self.log_identifier, + task_id, + e, + ) + return False + + def get_unconsumed_events_for_session( + self, + session_id: str, + ) -> Dict[str, List[Dict[str, Any]]]: + """ + Get all unconsumed events for a session, grouped by task. + + Args: + session_id: The session ID + + Returns: + Dictionary mapping task_id to list of events + """ + if not self.is_enabled(): + return {} + + try: + from .repository.sse_event_buffer_repository import SSEEventBufferRepository + + db = self._session_factory() + try: + repo = SSEEventBufferRepository() + # Repository returns List[SSEEventBufferModel], we need to convert to Dict[task_id, List[events]] + events = repo.get_unconsumed_events_for_session(db, session_id) + + # Group events by task_id + result: Dict[str, List[Dict[str, Any]]] = {} + for event in events: + task_id = event.task_id + if task_id not in result: + result[task_id] = [] + result[task_id].append({ + "event_type": event.event_type, + "event_data": event.event_data, + "event_sequence": event.event_sequence, + "created_at": event.created_at, + }) + return result + finally: + db.close() + except Exception as e: + log.error( + "%s Failed to get unconsumed events for session %s: %s", + self.log_identifier, + session_id, + e, + ) + return {} + + def delete_events_for_task(self, task_id: str) -> int: + """ + Delete all buffered events for a task. + + In hybrid mode, also clears RAM buffer. 
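+
+        Example (sketch): the chat-task save path in routers/sessions.py calls
+        this once the frontend has persisted the conversation, so already
+        replayed events are not kept around:
+
+            persistent_buffer.delete_events_for_task(task_id)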
+ + Args: + task_id: The task ID + + Returns: + Number of events deleted + """ + log.debug( + "%s [BufferCleanup] delete_events_for_task called for task_id=%s, is_enabled=%s, hybrid_mode=%s", + self.log_identifier, + task_id, + self.is_enabled(), + self.is_hybrid_mode_enabled(), + ) + if not self.is_enabled(): + log.debug("%s [BufferCleanup] Buffer not enabled, returning 0", self.log_identifier) + return 0 + + # In hybrid mode, clear RAM buffer first (discard without flushing to DB) + ram_cleared = 0 + if self.is_hybrid_mode_enabled(): + with self._lock: + events = self._ram_buffer.pop(task_id, []) + ram_cleared = len(events) + if ram_cleared > 0: + log.debug( + "%s [BufferCleanup] Cleared %d events from RAM buffer for task %s (discarded, not flushed)", + self.log_identifier, + ram_cleared, + task_id, + ) + + try: + from .repository.sse_event_buffer_repository import SSEEventBufferRepository + + db = self._session_factory() + try: + repo = SSEEventBufferRepository() + deleted = repo.delete_events_for_task(db, task_id) + db.commit() + + log.debug( + "%s [BufferCleanup] Deleted %d events for task %s from database (+ %d from RAM)", + self.log_identifier, + deleted, + task_id, + ram_cleared, + ) + + # Clear cached metadata + self.clear_task_metadata(task_id) + + return deleted + ram_cleared + finally: + db.close() + except Exception as e: + log.error( + "%s Failed to delete events for task %s: %s", + self.log_identifier, + task_id, + e, + ) + return ram_cleared # Still return RAM cleared count + + def cleanup_old_events(self, days: int = 7) -> int: + """ + Clean up consumed events older than the specified number of days. + + Args: + days: Number of days to keep consumed events + + Returns: + Number of events deleted + """ + if not self.is_enabled(): + return 0 + + try: + from .repository.sse_event_buffer_repository import SSEEventBufferRepository + from solace_agent_mesh.shared.utils.timestamp_utils import now_epoch_ms + + # Calculate cutoff time + cutoff_ms = now_epoch_ms() - (days * 24 * 60 * 60 * 1000) + + db = self._session_factory() + try: + repo = SSEEventBufferRepository() + deleted = repo.cleanup_consumed_events(db, cutoff_ms) + db.commit() + + if deleted > 0: + log.info( + "%s Cleaned up %d consumed events older than %d days", + self.log_identifier, + deleted, + days, + ) + + return deleted + finally: + db.close() + except Exception as e: + log.error( + "%s Failed to cleanup old events: %s", + self.log_identifier, + e, + ) + return 0 diff --git a/src/solace_agent_mesh/gateway/http_sse/repository/__init__.py b/src/solace_agent_mesh/gateway/http_sse/repository/__init__.py index bb8f162003..3f850f58a8 100644 --- a/src/solace_agent_mesh/gateway/http_sse/repository/__init__.py +++ b/src/solace_agent_mesh/gateway/http_sse/repository/__init__.py @@ -16,6 +16,7 @@ from .feedback_repository import FeedbackRepository from .project_repository import ProjectRepository from .session_repository import SessionRepository +from .sse_event_buffer_repository import SSEEventBufferRepository from .task_repository import TaskRepository # Entities (re-exported for convenience) @@ -24,6 +25,7 @@ # Models (re-exported for convenience) from .models.base import Base from .models.session_model import SessionModel +from .models.sse_event_buffer_model import SSEEventBufferModel __all__ = [ # Interfaces @@ -37,10 +39,12 @@ "FeedbackRepository", "ProjectRepository", "SessionRepository", + "SSEEventBufferRepository", "TaskRepository", # Entities "Session", # Models "Base", "SessionModel", + "SSEEventBufferModel", ] 
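With the re-exports above in place, the new repository and model are importable from the package root. A minimal usage sketch (module path taken from the diff headers):

    from solace_agent_mesh.gateway.http_sse.repository import (
        SSEEventBufferModel,
        SSEEventBufferRepository,
    )

    repo = SSEEventBufferRepository()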
diff --git a/src/solace_agent_mesh/gateway/http_sse/repository/entities/task.py b/src/solace_agent_mesh/gateway/http_sse/repository/entities/task.py index 6513a8bd7e..2a10a1ce4f 100644 --- a/src/solace_agent_mesh/gateway/http_sse/repository/entities/task.py +++ b/src/solace_agent_mesh/gateway/http_sse/repository/entities/task.py @@ -27,6 +27,11 @@ class Task(BaseModel): last_activity_time: int | None = None background_execution_enabled: bool | None = False max_execution_time_ms: int | None = None + + # SSE event buffer fields + session_id: str | None = None + events_buffered: bool | None = False + events_consumed: bool | None = False class Config: from_attributes = True diff --git a/src/solace_agent_mesh/gateway/http_sse/repository/models/__init__.py b/src/solace_agent_mesh/gateway/http_sse/repository/models/__init__.py index d519ca4af7..372f5f9533 100644 --- a/src/solace_agent_mesh/gateway/http_sse/repository/models/__init__.py +++ b/src/solace_agent_mesh/gateway/http_sse/repository/models/__init__.py @@ -8,6 +8,7 @@ from .project_model import ProjectModel, CreateProjectModel, UpdateProjectModel from .project_user_model import ProjectUserModel, CreateProjectUserModel, UpdateProjectUserModel from .session_model import SessionModel, CreateSessionModel, UpdateSessionModel +from .sse_event_buffer_model import SSEEventBufferModel from .task_event_model import TaskEventModel from .task_model import TaskModel from .prompt_model import PromptGroupModel, PromptModel, PromptGroupUserModel @@ -19,6 +20,7 @@ "ProjectModel", "ProjectUserModel", "SessionModel", + "SSEEventBufferModel", "CreateProjectModel", "UpdateProjectModel", "CreateProjectUserModel", diff --git a/src/solace_agent_mesh/gateway/http_sse/repository/models/sse_event_buffer_model.py b/src/solace_agent_mesh/gateway/http_sse/repository/models/sse_event_buffer_model.py new file mode 100644 index 0000000000..d32d3cb765 --- /dev/null +++ b/src/solace_agent_mesh/gateway/http_sse/repository/models/sse_event_buffer_model.py @@ -0,0 +1,35 @@ +""" +SSE Event Buffer SQLAlchemy model. + +This model stores SSE events for background tasks that need to be replayed +when the user returns to the session. Events are stored in sequence order +and can be fetched and replayed through the frontend's existing SSE processing. 
+""" + +from sqlalchemy import BigInteger, Boolean, Column, Integer, JSON, String + +from .base import Base + + +class SSEEventBufferModel(Base): + """SQLAlchemy model for SSE event buffer entries.""" + + __tablename__ = "sse_event_buffer" + + id = Column(Integer, primary_key=True, autoincrement=True) + task_id = Column(String(255), nullable=False, index=True) + session_id = Column(String(255), nullable=False, index=True) + user_id = Column(String(255), nullable=False) + event_sequence = Column(Integer, nullable=False) + event_type = Column(String(50), nullable=False) + event_data = Column(JSON, nullable=False) + created_at = Column(BigInteger, nullable=False) # Epoch milliseconds + consumed = Column(Boolean, nullable=False, default=False, index=True) + consumed_at = Column(BigInteger, nullable=True) # Epoch milliseconds + + def __repr__(self): + return ( + f"" + ) diff --git a/src/solace_agent_mesh/gateway/http_sse/repository/models/task_model.py b/src/solace_agent_mesh/gateway/http_sse/repository/models/task_model.py index 4f2f923f34..9fb0199860 100644 --- a/src/solace_agent_mesh/gateway/http_sse/repository/models/task_model.py +++ b/src/solace_agent_mesh/gateway/http_sse/repository/models/task_model.py @@ -32,6 +32,11 @@ class TaskModel(Base): last_activity_time = Column(BigInteger, nullable=True, index=True) background_execution_enabled = Column(Boolean, nullable=True, default=False) max_execution_time_ms = Column(BigInteger, nullable=True) + + # SSE event buffer state columns + session_id = Column(String, nullable=True, index=True) + events_buffered = Column(Boolean, nullable=True, default=False, index=True) + events_consumed = Column(Boolean, nullable=True, default=False) # Relationship to events events = relationship( diff --git a/src/solace_agent_mesh/gateway/http_sse/repository/sse_event_buffer_repository.py b/src/solace_agent_mesh/gateway/http_sse/repository/sse_event_buffer_repository.py new file mode 100644 index 0000000000..bbc59f1fdf --- /dev/null +++ b/src/solace_agent_mesh/gateway/http_sse/repository/sse_event_buffer_repository.py @@ -0,0 +1,337 @@ +""" +Repository for SSE Event Buffer operations. + +This repository handles persistence of SSE events for background tasks +that need to be replayed when the user returns to the session. +""" + +import logging +from typing import List, Optional + +from sqlalchemy import func +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session as DBSession + +from .models.sse_event_buffer_model import SSEEventBufferModel +from .models.task_model import TaskModel +from solace_agent_mesh.shared.utils.timestamp_utils import now_epoch_ms + +log = logging.getLogger(__name__) + +# Maximum retry attempts for sequence number race condition +MAX_SEQUENCE_RETRIES = 3 + + +class SSEEventBufferRepository: + """Repository for SSE event buffer database operations.""" + + def __init__(self): + self.log_identifier = "[SSEEventBufferRepository]" + + def buffer_event( + self, + db: DBSession, + task_id: str, + session_id: str, + user_id: str, + event_type: str, + event_data: dict, + created_time: int | None = None, + ) -> SSEEventBufferModel: + """ + Buffer an SSE event for later replay. + + Uses retry logic to handle rare race conditions where concurrent writes + could result in duplicate sequence numbers. The unique constraint on + (task_id, event_sequence) ensures data integrity. + + Note: In practice, this race condition is extremely unlikely because: + 1. Events for each task flow through a single-threaded message processor + 2. 
Each task has its own event stream arriving sequentially + + The retry logic is a defensive measure for edge cases like hybrid mode + RAM buffer flush coinciding with a new event, or multi-worker deployments. + + Args: + db: Database session + task_id: The task ID this event belongs to + session_id: The session ID + user_id: The user ID + event_type: The SSE event type (e.g., 'message') + event_data: The event data payload + created_time: Optional timestamp for when the event was originally created + (used by hybrid mode when flushing RAM buffer) + + Returns: + The created SSEEventBufferModel instance + + Raises: + IntegrityError: If all retry attempts fail (should be extremely rare) + """ + last_error = None + + for attempt in range(MAX_SEQUENCE_RETRIES): + try: + # Get next sequence number for this task + max_seq = db.query(func.max(SSEEventBufferModel.event_sequence))\ + .filter(SSEEventBufferModel.task_id == task_id)\ + .scalar() or 0 + + # Create buffer entry + buffer_entry = SSEEventBufferModel( + task_id=task_id, + session_id=session_id, + user_id=user_id, + event_sequence=max_seq + 1, + event_type=event_type, + event_data=event_data, + created_at=created_time if created_time is not None else now_epoch_ms(), + consumed=False, + ) + db.add(buffer_entry) + + # Mark task as having buffered events + task = db.query(TaskModel).filter(TaskModel.id == task_id).first() + if task and not task.events_buffered: + task.events_buffered = True + + db.flush() # Flush to get the ID and trigger constraint check + + log.debug( + "%s Buffered event for task %s: sequence=%d, type=%s", + self.log_identifier, + task_id, + buffer_entry.event_sequence, + event_type, + ) + + return buffer_entry + + except IntegrityError as e: + # Rollback the failed transaction + db.rollback() + last_error = e + + # Check if this is a sequence number collision + if "sse_event_buffer_task_seq_unique" in str(e) or "UNIQUE constraint" in str(e): + log.warning( + "%s Sequence number race condition detected for task %s (attempt %d/%d), retrying...", + self.log_identifier, + task_id, + attempt + 1, + MAX_SEQUENCE_RETRIES, + ) + continue + else: + # Some other integrity error, re-raise + raise + + # All retries failed - this should be extremely rare + log.error( + "%s Failed to buffer event after %d attempts for task %s: %s", + self.log_identifier, + MAX_SEQUENCE_RETRIES, + task_id, + last_error, + ) + raise last_error + + def get_buffered_events( + self, + db: DBSession, + task_id: str, + mark_consumed: bool = True, + ) -> List[dict]: + """ + Get all buffered events for a task in sequence order. 
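+
+        Example of one returned item (shape only; values illustrative):
+
+            {"type": "message", "data": {...}, "sequence": 1}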
+ + Args: + db: Database session + task_id: The task ID to get events for + mark_consumed: Whether to mark events as consumed + + Returns: + List of event dictionaries with type, data, and sequence + """ + events = db.query(SSEEventBufferModel)\ + .filter(SSEEventBufferModel.task_id == task_id)\ + .order_by(SSEEventBufferModel.event_sequence)\ + .all() + + event_data = [ + { + "type": event.event_type, + "data": event.event_data, + "sequence": event.event_sequence, + } + for event in events + ] + + if mark_consumed and events: + # Mark events as consumed + consumed_at = now_epoch_ms() + db.query(SSEEventBufferModel)\ + .filter(SSEEventBufferModel.task_id == task_id)\ + .update({ + "consumed": True, + "consumed_at": consumed_at, + }) + + # Mark task as consumed + task = db.query(TaskModel).filter(TaskModel.id == task_id).first() + if task: + task.events_consumed = True + + log.info( + "%s Marked %d events as consumed for task %s", + self.log_identifier, + len(events), + task_id, + ) + + return event_data + + def get_unconsumed_events_for_session( + self, + db: DBSession, + session_id: str, + task_id: Optional[str] = None, + mark_consumed: bool = False, + ) -> List[SSEEventBufferModel]: + """ + Get unconsumed events for a session, optionally filtered by task. + + Args: + db: Database session + session_id: The session ID + task_id: Optional task ID to filter by + mark_consumed: Whether to mark events as consumed + + Returns: + List of SSEEventBufferModel instances + """ + query = db.query(SSEEventBufferModel)\ + .filter(SSEEventBufferModel.session_id == session_id)\ + .filter(SSEEventBufferModel.consumed.is_(False)) + + if task_id: + query = query.filter(SSEEventBufferModel.task_id == task_id) + + events = query.order_by(SSEEventBufferModel.event_sequence).all() + + if mark_consumed and events: + # Mark events as consumed + consumed_at = now_epoch_ms() + for event in events: + event.consumed = True + event.consumed_at = consumed_at + + log.info( + "%s Marked %d events as consumed for session %s (task=%s)", + self.log_identifier, + len(events), + session_id, + task_id, + ) + + return events + + def has_unconsumed_events( + self, + db: DBSession, + task_id: str, + ) -> bool: + """ + Check if a task has unconsumed buffered events. + + Args: + db: Database session + task_id: The task ID to check + + Returns: + True if there are unconsumed events, False otherwise + """ + count = db.query(func.count(SSEEventBufferModel.id))\ + .filter(SSEEventBufferModel.task_id == task_id)\ + .filter(SSEEventBufferModel.consumed.is_(False))\ + .scalar() + + return count > 0 + + def get_event_count( + self, + db: DBSession, + task_id: str, + ) -> int: + """ + Get the total number of buffered events for a task. + + Args: + db: Database session + task_id: The task ID + + Returns: + Number of buffered events + """ + return db.query(func.count(SSEEventBufferModel.id))\ + .filter(SSEEventBufferModel.task_id == task_id)\ + .scalar() or 0 + + def cleanup_consumed_events( + self, + db: DBSession, + older_than_ms: int, + ) -> int: + """ + Clean up consumed events older than the specified time. 
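+
+        Example cutoff for a 7-day retention window (mirrors cleanup_old_events
+        in persistent_sse_event_buffer.py):
+
+            cutoff_ms = now_epoch_ms() - 7 * 24 * 60 * 60 * 1000
+            deleted = repo.cleanup_consumed_events(db, cutoff_ms)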
+ + Args: + db: Database session + older_than_ms: Delete events consumed before this epoch time (ms) + + Returns: + Number of events deleted + """ + deleted = db.query(SSEEventBufferModel)\ + .filter(SSEEventBufferModel.consumed.is_(True))\ + .filter(SSEEventBufferModel.consumed_at < older_than_ms)\ + .delete() + + if deleted > 0: + log.info( + "%s Cleaned up %d consumed events older than %d", + self.log_identifier, + deleted, + older_than_ms, + ) + + return deleted + + def delete_events_for_task( + self, + db: DBSession, + task_id: str, + ) -> int: + """ + Delete all buffered events for a task. + + Args: + db: Database session + task_id: The task ID + + Returns: + Number of events deleted + """ + deleted = db.query(SSEEventBufferModel)\ + .filter(SSEEventBufferModel.task_id == task_id)\ + .delete() + + if deleted > 0: + log.debug( + "%s Deleted %d events for task %s", + self.log_identifier, + deleted, + task_id, + ) + + return deleted diff --git a/src/solace_agent_mesh/gateway/http_sse/repository/task_repository.py b/src/solace_agent_mesh/gateway/http_sse/repository/task_repository.py index 76d4211146..9bbe6122d0 100644 --- a/src/solace_agent_mesh/gateway/http_sse/repository/task_repository.py +++ b/src/solace_agent_mesh/gateway/http_sse/repository/task_repository.py @@ -33,6 +33,10 @@ def save_task(self, session: DBSession, task: Task) -> Task: model.last_activity_time = task.last_activity_time model.background_execution_enabled = task.background_execution_enabled model.max_execution_time_ms = task.max_execution_time_ms + # SSE event buffer state + model.session_id = task.session_id + model.events_buffered = task.events_buffered + model.events_consumed = task.events_consumed else: model = TaskModel( id=task.id, @@ -50,6 +54,10 @@ def save_task(self, session: DBSession, task: Task) -> Task: last_activity_time=task.last_activity_time, background_execution_enabled=task.background_execution_enabled, max_execution_time_ms=task.max_execution_time_ms, + # SSE event buffer state + session_id=task.session_id, + events_buffered=task.events_buffered, + events_consumed=task.events_consumed, ) session.add(model) @@ -70,8 +78,9 @@ def save_event(self, session: DBSession, event: TaskEvent) -> TaskEvent: ) session.add(model) session.flush() - session.refresh(model) - return self._event_model_to_entity(model) + # Note: We don't refresh here since we already have all the data, + # and refresh can fail in certain edge cases (e.g., foreign key constraints) + return event def find_by_id(self, session: DBSession, task_id: str) -> Task | None: """Find a task by its ID.""" diff --git a/src/solace_agent_mesh/gateway/http_sse/routers/sessions.py b/src/solace_agent_mesh/gateway/http_sse/routers/sessions.py index c2fc3cfedf..84deaa6e94 100644 --- a/src/solace_agent_mesh/gateway/http_sse/routers/sessions.py +++ b/src/solace_agent_mesh/gateway/http_sse/routers/sessions.py @@ -1,12 +1,20 @@ import asyncio import json import logging +import re from typing import Optional, TYPE_CHECKING from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from pydantic import ValidationError from sqlalchemy.orm import Session -from ..dependencies import get_session_business_service, get_db, get_title_generation_service +from ....common.utils.embeds import ( + LATE_EMBED_TYPES, + evaluate_embed, + resolve_embeds_in_string, +) +from ....common.utils.embeds.types import ResolutionMode +from ....common.utils.templates import resolve_template_blocks_in_string +from ..dependencies import get_session_business_service, 
get_db, get_title_generation_service, get_shared_artifact_service, get_sac_component from ..services.session_service import SessionService from solace_agent_mesh.shared.api.auth_utils import get_current_user from solace_agent_mesh.shared.api.pagination import DataResponse, PaginatedResponse, PaginationParams @@ -199,6 +207,8 @@ async def save_task( db: Session = Depends(get_db), user: dict = Depends(get_current_user), session_service: SessionService = Depends(get_session_business_service), + artifact_service = Depends(get_shared_artifact_service), + component = Depends(get_sac_component), ): """ Save a complete task interaction (upsert). @@ -229,14 +239,119 @@ async def save_task( existing_task = task_repo.find_by_id(db, request.task_id, user_id) is_update = existing_task is not None + # Resolve embeds in message_bubbles before saving + message_bubbles = request.message_bubbles + if artifact_service and message_bubbles and '«' in message_bubbles: + try: + # Parse the message bubbles JSON + bubbles = json.loads(message_bubbles) + resolved_bubbles = [] + + gateway_id = component.gateway_id if component else "webui" + + for bubble in bubbles: + if isinstance(bubble, dict): + # Resolve embeds in the text field + text = bubble.get("text", "") + if text and '«' in text: + embed_eval_context = { + "artifact_service": artifact_service, + "session_context": { + "app_name": gateway_id, + "user_id": user_id, + "session_id": session_id, + }, + } + embed_eval_config = { + "gateway_max_artifact_resolve_size_bytes": 1024 * 1024, # 1MB limit + "gateway_recursive_embed_depth": 3, + } + + resolved_text, _, signals = await resolve_embeds_in_string( + text=text, + context=embed_eval_context, + resolver_func=evaluate_embed, + types_to_resolve=LATE_EMBED_TYPES, + resolution_mode=ResolutionMode.A2A_MESSAGE_TO_USER, + log_identifier=f"[SaveTask:{request.task_id}]", + config=embed_eval_config, + ) + + # Resolve template blocks (template_liquid) + if '«««template' in resolved_text: + resolved_text = await resolve_template_blocks_in_string( + text=resolved_text, + artifact_service=artifact_service, + session_context={ + "app_name": gateway_id, + "user_id": user_id, + "session_id": session_id, + }, + log_identifier=f"[SaveTask:{request.task_id}][TemplateResolve]", + ) + + # Strip status_update embeds (they're for real-time display only) + status_update_pattern = r'«status_update:[^»]+»\n?' 
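+                        # (removes markers like «status_update:Analyzing...»
+                        # together with the trailing newline, if present)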
+ resolved_text = re.sub(status_update_pattern, '', resolved_text) + + # Strip any remaining template blocks that weren't resolved + template_block_pattern = r'«««template(?:_liquid)?:[^\n]+\n(?:(?!»»»).)*?»»»' + resolved_text = re.sub(template_block_pattern, '', resolved_text, flags=re.DOTALL) + + bubble["text"] = resolved_text + + # Also resolve embeds and templates in parts if they contain text + parts = bubble.get("parts", []) + resolved_parts = [] + for part in parts: + if isinstance(part, dict) and part.get("kind") == "text": + part_text = part.get("text", "") + if part_text and '«' in part_text: + resolved_part_text, _, _ = await resolve_embeds_in_string( + text=part_text, + context=embed_eval_context, + resolver_func=evaluate_embed, + types_to_resolve=LATE_EMBED_TYPES, + resolution_mode=ResolutionMode.A2A_MESSAGE_TO_USER, + log_identifier=f"[SaveTask:{request.task_id}]", + config=embed_eval_config, + ) + # Resolve template blocks in parts + if '«««template' in resolved_part_text: + resolved_part_text = await resolve_template_blocks_in_string( + text=resolved_part_text, + artifact_service=artifact_service, + session_context={ + "app_name": gateway_id, + "user_id": user_id, + "session_id": session_id, + }, + log_identifier=f"[SaveTask:{request.task_id}][TemplateResolve]", + ) + part["text"] = resolved_part_text + resolved_parts.append(part) + bubble["parts"] = resolved_parts + + resolved_bubbles.append(bubble) + + message_bubbles = json.dumps(resolved_bubbles) + log.debug("Resolved embeds in message_bubbles for task %s", request.task_id) + except Exception as e: + log.warning( + "Failed to resolve embeds in message_bubbles for task %s: %s. Saving as-is.", + request.task_id, + e, + ) + # Save the task - pass strings directly + # Use the resolved message_bubbles if embeds were resolved, otherwise use the original saved_task = session_service.save_task( db=db, task_id=request.task_id, session_id=session_id, user_id=user_id, user_message=request.user_message, - message_bubbles=request.message_bubbles, # Already a string + message_bubbles=message_bubbles, # Use resolved message_bubbles task_metadata=request.task_metadata, # Already a string ) @@ -247,6 +362,51 @@ async def save_task( session_id, ) + # Clear SSE event buffer for this task (implicit cleanup) + # This is atomic with the save operation + try: + from ..dependencies import get_sac_component + component = get_sac_component() + log.debug( + "[BufferCleanup] Task %s: Starting cleanup. 
component=%s, sse_manager=%s", + request.task_id, + component is not None, + component.sse_manager is not None if component else False, + ) + if component and component.sse_manager: + persistent_buffer = component.sse_manager.get_persistent_buffer() + log.debug( + "[BufferCleanup] Task %s: persistent_buffer=%s, is_enabled=%s", + request.task_id, + persistent_buffer is not None, + persistent_buffer.is_enabled() if persistent_buffer else False, + ) + if persistent_buffer and persistent_buffer.is_enabled(): + deleted_count = persistent_buffer.delete_events_for_task(request.task_id) + if deleted_count > 0: + log.info( + "[BufferCleanup] Task %s: Cleared %d buffered SSE events after chat_task save", + request.task_id, + deleted_count, + ) + else: + log.debug( + "[BufferCleanup] Task %s: Buffer disabled or not available, skipping cleanup", + request.task_id, + ) + else: + log.debug( + "[BufferCleanup] Task %s: No component or sse_manager available", + request.task_id, + ) + except Exception as buffer_error: + # Non-critical - buffer will be cleaned up by retention policy + log.warning( + "[BufferCleanup] Task %s: Failed to clear buffer: %s", + request.task_id, + buffer_error, + ) + # Convert to response DTO response = TaskResponse( task_id=saved_task.id, @@ -431,6 +591,132 @@ async def get_session_history( ) from e +@router.get("/sessions/{session_id}/events/unconsumed") +async def get_session_unconsumed_events( + session_id: str, + include_events: bool = False, + db: Session = Depends(get_db), + user: dict = Depends(get_current_user), + session_service: SessionService = Depends(get_session_business_service), +): + """ + Check for unconsumed buffered SSE events for a session. + + This endpoint is used by the frontend to determine if there are + any buffered events that need to be replayed when switching to a session. + + In the unified event buffer architecture: + 1. Frontend switches to a session + 2. Frontend calls this endpoint with include_events=true to get all buffered events in one request + 3. Frontend replays events through handleSseMessage + 4. Frontend saves to chat_tasks (which implicitly cleans up buffer) + + Query Parameters: + include_events: If true, include the actual event data in the response. + This enables batched retrieval to avoid N+1 queries. 
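+
+    Example response with include_events=true (shape only; values are
+    illustrative):
+
+        {
+            "has_events": true,
+            "task_ids": ["task-1"],
+            "session_id": "sess-1",
+            "events_by_task": {
+                "task-1": {
+                    "events_buffered": true,
+                    "events": [
+                        {"sequence": 1, "event_type": "message", "data": {}}
+                    ]
+                }
+            }
+        }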
+ + Returns: + JSON object with: + - has_events: boolean indicating if there are unconsumed events + - task_ids: list of task IDs with unconsumed events + - events_by_task: (only if include_events=true) dict mapping task_id to list of events + """ + user_id = user.get("id") + log.info( + "User %s checking for unconsumed events in session %s (include_events=%s)", + user_id, session_id, include_events + ) + + try: + if ( + not session_id + or session_id.strip() == "" + or session_id in ["null", "undefined"] + ): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=SESSION_NOT_FOUND_MSG + ) + + # Verify user owns this session + session_domain = session_service.get_session_details( + db=db, session_id=session_id, user_id=user_id + ) + if not session_domain: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=SESSION_NOT_FOUND_MSG + ) + + # Get the SSE manager to access the persistent buffer + from ..dependencies import get_sac_component + from ..component import WebUIBackendComponent + + component: WebUIBackendComponent = get_sac_component() + + if component is None: + log.warning("WebUI backend component not available") + return {"has_events": False, "task_ids": [], "events_by_task": {} if include_events else None} + + sse_manager = component.sse_manager + persistent_buffer = sse_manager.get_persistent_buffer() if sse_manager else None + if persistent_buffer is None: + log.debug("Persistent buffer not available") + return {"has_events": False, "task_ids": [], "events_by_task": {} if include_events else None} + + # Get unconsumed events for this session (already grouped by task_id) + unconsumed_by_task = persistent_buffer.get_unconsumed_events_for_session(session_id) + + task_ids = list(unconsumed_by_task.keys()) + has_events = len(task_ids) > 0 + + log.info( + "Session %s has %d tasks with unconsumed events: %s", + session_id, + len(task_ids), + task_ids + ) + + response = { + "has_events": has_events, + "task_ids": task_ids, + "session_id": session_id + } + + # Include actual events if requested (enables batched retrieval) + if include_events: + # Convert events to the format expected by frontend + # Format matches the per-task endpoint: {events_buffered: bool, events: [...]} + events_by_task = {} + for task_id, events in unconsumed_by_task.items(): + events_by_task[task_id] = { + "events_buffered": len(events) > 0, + "events": [ + { + "sequence": event.get("event_sequence", 0), + "event_type": event.get("event_type", "message"), + "data": event.get("event_data", {}), + } + for event in events + ] + } + response["events_by_task"] = events_by_task + + return response + + except HTTPException: + raise + except Exception as e: + log.exception( + "Error checking unconsumed events for session %s, user %s: %s", + session_id, + user_id, + e, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to check unconsumed events", + ) from e + + @router.patch("/sessions/{session_id}", response_model=SessionResponse) async def update_session_name( session_id: str, diff --git a/src/solace_agent_mesh/gateway/http_sse/routers/tasks.py b/src/solace_agent_mesh/gateway/http_sse/routers/tasks.py index f2284cf9c9..853c1767e5 100644 --- a/src/solace_agent_mesh/gateway/http_sse/routers/tasks.py +++ b/src/solace_agent_mesh/gateway/http_sse/routers/tasks.py @@ -539,6 +539,33 @@ async def _submit_task( log.info("%sTask submitted successfully. 
TaskID: %s", log_prefix, task_id) + # UNIFIED ARCHITECTURE: Register ALL tasks for persistent SSE event buffering + # when the feature is enabled (tied to background_tasks feature flag). + # This enables session switching, browser refresh recovery, and reconnection for ALL tasks. + # The FE will clear the buffer after successfully saving the chat_task. + try: + sse_manager = component.sse_manager + if sse_manager and sse_manager.get_persistent_buffer().is_enabled(): + sse_manager.register_task_for_persistent_buffer( + task_id=task_id, + session_id=session_id, + user_id=user_id, + ) + is_background = additional_metadata.get("backgroundExecutionEnabled", False) + log.info( + "%sRegistered task %s for persistent SSE buffering (session=%s, background=%s)", + log_prefix, + task_id, + session_id, + is_background, + ) + except Exception as e: + log.warning( + "%sFailed to register task for persistent buffering: %s", + log_prefix, + e, + ) + task_object = a2a.create_initial_task( task_id=task_id, context_id=session_id, @@ -875,6 +902,415 @@ async def get_task_events( ) +@router.get("/tasks/{task_id}/events/buffered", tags=["Tasks"]) +async def get_buffered_task_events( + task_id: str, + request: FastAPIRequest, + db: DBSession = Depends(get_db), + user_id: UserId = Depends(get_user_id), + user_config: dict = Depends(get_user_config), + repo: ITaskRepository = Depends(get_task_repository), + mark_consumed: bool = Query( + default=True, + description="Whether to mark events as consumed after fetching" + ), +): + """ + Retrieves buffered SSE events for a background task. + + This endpoint is used by the frontend to replay SSE events for background tasks + that completed while the user was disconnected. The events are returned in the + same format as the live SSE stream, allowing the frontend to process them + through its existing event handling logic. 
+ + Args: + task_id: The ID of the task to fetch buffered events for + mark_consumed: If True, marks events as consumed after fetching (default: True) + + Returns: + A list of buffered SSE events in sequence order, ready for frontend replay + """ + log_prefix = f"[GET /api/v1/tasks/{task_id}/events/buffered] " + log.info("%sRequest from user %s, mark_consumed=%s", log_prefix, user_id, mark_consumed) + + try: + # First verify the task exists and user has permission + result = repo.find_by_id_with_events(db, task_id) + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Task with ID '{task_id}' not found.", + ) + + task, _ = result + + can_read_all = user_config.get("scopes", {}).get("tasks:read:all", False) + if task.user_id != user_id and not can_read_all: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to view this task.", + ) + + # Fetch buffered events from the persistent buffer + # Note: We query the sse_event_buffer table directly instead of relying on + # task.events_buffered flag, which may not be set if the task was created + # after events started being buffered (timing issue) + from ..repository.sse_event_buffer_repository import SSEEventBufferRepository + + buffer_repo = SSEEventBufferRepository() + + # Check if this task has buffered events by querying the buffer table directly + has_buffered = buffer_repo.has_unconsumed_events(db, task_id) + if not has_buffered: + # Also check for consumed events (already replayed but still stored) + event_count = buffer_repo.get_event_count(db, task_id) + if event_count == 0: + log.info("%sTask %s does not have buffered events", log_prefix, task_id) + return { + "task_id": task_id, + "events": [], + "has_more": False, + "events_buffered": False, + "events_consumed": task.events_consumed or False, + } + + if mark_consumed: + # Get unconsumed events and mark them as consumed + # Note: We use task_id directly, not session_id, since session_id might not be set + events = buffer_repo.get_buffered_events( + db=db, + task_id=task_id, + mark_consumed=True, + ) + + # The repository already marks events as consumed + else: + # Get all buffered events without marking as consumed + events = buffer_repo.get_buffered_events( + db=db, + task_id=task_id, + mark_consumed=False, + ) + + # events is already a list of dicts with keys: type, data, sequence + # Just pass them through, the format matches what frontend expects + log.info( + "%sReturning %d buffered events for task %s", + log_prefix, + len(events), + task_id, + ) + + # Commit the transaction to persist the consumed state + if mark_consumed and events: + db.commit() + + return { + "task_id": task_id, + "events": events, + "has_more": False, + "events_buffered": len(events) > 0, + "events_consumed": mark_consumed and len(events) > 0, + } + + except HTTPException: + raise + except Exception as e: + log.exception("%sError retrieving buffered events: %s", log_prefix, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while retrieving buffered events.", + ) + + +@router.delete("/tasks/{task_id}/events/buffered", tags=["Tasks"]) +async def clear_buffered_task_events( + task_id: str, + request: FastAPIRequest, + db: DBSession = Depends(get_db), + user_id: UserId = Depends(get_user_id), +): + """ + Clear all buffered SSE events for a task. + + This endpoint is used to clean up orphan buffered events without + triggering a chat_task save. Use cases: + 1. 
Clean up leftover events when a chat_task already exists + 2. Explicitly clear buffer without updating session modified time + + NOTE: Buffer cleanup also happens implicitly in save_task endpoint + (POST /sessions/{session_id}/chat-tasks), so this endpoint is only + needed when you want cleanup without a save operation. + + Returns: + JSON object with the number of events deleted + """ + log_prefix = f"[DELETE /api/v1/tasks/{task_id}/events/buffered] " + log.debug("%sRequest from user %s to clear buffered events", log_prefix, user_id) + + try: + # Get the SSE manager to access the persistent buffer + component: "WebUIBackendComponent" = get_sac_component() + + if component is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="WebUI backend component not available", + ) + + sse_manager = component.sse_manager + persistent_buffer = sse_manager.get_persistent_buffer() if sse_manager else None + if persistent_buffer is None: + log.debug("%sPersistent buffer not available", log_prefix) + return {"deleted": 0, "message": "Persistent buffer not enabled"} + + # Verify user owns this task by checking the task's user_id in the buffer metadata + # or the task itself in the database + task_metadata = persistent_buffer.get_task_metadata(task_id) + if task_metadata: + task_user_id = task_metadata.get("user_id") + if task_user_id and task_user_id != user_id: + log.warning( + "%sUser %s attempted to clear buffer for task %s owned by %s", + log_prefix, + user_id, + task_id, + task_user_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to clear events for this task", + ) + else: + # No metadata found, try to verify via database task record + from ..repository.task_repository import TaskRepository + + repo = TaskRepository() + task = repo.find_by_id(db, task_id) + if task and hasattr(task, 'user_id') and task.user_id: + if task.user_id != user_id: + log.warning( + "%sUser %s attempted to clear buffer for task %s owned by %s", + log_prefix, + user_id, + task_id, + task.user_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to clear events for this task", + ) + + # Delete all events for this task + deleted_count = persistent_buffer.delete_events_for_task(task_id) + + if deleted_count > 0: + log.info("%sDeleted %d buffered events for task %s", log_prefix, deleted_count, task_id) + + return { + "deleted": deleted_count, + "task_id": task_id + } + + except HTTPException: + raise + except Exception as e: + log.exception("%sError clearing buffered events: %s", log_prefix, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while clearing buffered events.", + ) + + +@router.get("/tasks/{task_id}/title-data", tags=["Tasks"]) +async def get_task_title_data( + task_id: str, + db: DBSession = Depends(get_db), + user_id: UserId = Depends(get_user_id), +): + """ + Extract user message and agent response from task for title generation. + + This endpoint extracts the first user message and final agent response from: + 1. The Task table (initial_request_text for user message) + 2. The SSE event buffer (final response for agent response) + + Used for background task title generation when the frontend was not watching. 
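+
+    Example response (shape from the return statement below; values are
+    illustrative):
+
+        {
+            "user_message": "Summarize the Q3 report",
+            "agent_response": "The Q3 report shows...",
+            "task_id": "task-1",
+            "session_id": "sess-1"
+        }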
+ """ + log_prefix = f"[GET /api/v1/tasks/{task_id}/title-data] " + log.info("%sRequest from user %s", log_prefix, user_id) + + try: + from ..repository.task_repository import TaskRepository + from ..repository.sse_event_buffer_repository import SSEEventBufferRepository + from ..repository.chat_task_repository import ChatTaskRepository + import json + + task_repo = TaskRepository() + buffer_repo = SSEEventBufferRepository() + chat_task_repo = ChatTaskRepository() + + # Get task for initial_request_text (user message) and session_id + task = task_repo.find_by_id(db, task_id) + if not task: + log.warning("%sTask %s not found", log_prefix, task_id) + return { + "user_message": None, + "agent_response": None, + "error": "Task not found" + } + + # Authorization: Verify user owns this task + if task.user_id and task.user_id != user_id: + log.warning( + "%sUser %s attempted to access title-data for task %s owned by %s", + log_prefix, + user_id, + task_id, + task.user_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to access this task's data", + ) + + user_message = None + agent_response = None + + try: + chat_task = chat_task_repo.find_by_id(db, task_id, user_id) + if chat_task: + log.info("%sPrimary: Found chat_task for task %s", log_prefix, task_id) + + # Use the clean user_message from chat_task + user_message = chat_task.user_message + + # Extract agent response from message_bubbles + if chat_task.message_bubbles: + bubbles = json.loads(chat_task.message_bubbles) + for bubble in reversed(bubbles): # Start from most recent + if bubble.get("direction") == "agent" or bubble.get("sender") == "agent": + # Look for text parts in the bubble + parts = bubble.get("parts", []) + for part in parts: + if part.get("type") == "text" or part.get("kind") == "text": + text = part.get("text", "") + if text and len(text) > 10: + agent_response = text + break + if agent_response: + break + + if user_message and agent_response: + log.info("%sUsing chat_task data: user=%d chars, agent=%d chars", + log_prefix, len(user_message), len(agent_response)) + except Exception as e: + log.warning("%sError reading from chat_tasks: %s", log_prefix, e) + + # Fallback to task.initial_request_text if no user_message from chat_task + if not user_message: + user_message = task.initial_request_text + + # FALLBACK: SSE event buffer (if chat_task didn't have agent response) + # This handles cases where task completed but FE hasn't saved chat_task yet + if not agent_response: + try: + events = buffer_repo.get_buffered_events(db, task_id, mark_consumed=False) + log.info("%sFallback SSE buffer: Found %d buffered events for task %s", log_prefix, len(events), task_id) + + # Collect streaming text fragments from status-update events (agent_progress_update) + # In streaming mode, text is sent incrementally, not in the final task response + streaming_text_parts = [] + + # Look for final "task" event with response text OR accumulate streaming text + for event in events: # Process in sequence order for streaming text + event_data = event.get("data", "") + if isinstance(event_data, str): + try: + parsed = json.loads(event_data) + except json.JSONDecodeError: + continue + else: + parsed = event_data + + # Check if this is an SSE wrapper with nested data + if "data" in parsed and isinstance(parsed.get("data"), str): + try: + inner_data = json.loads(parsed["data"]) + parsed = inner_data + except json.JSONDecodeError: + pass + + # Check for task response with text parts (non-streaming 
final response) + result = parsed.get("result", {}) + if result.get("kind") == "task": + task_data = result.get("task", {}) + artifacts = task_data.get("artifacts", []) + for artifact in artifacts: + parts = artifact.get("parts", []) + for part in parts: + if part.get("kind") == "text": + text = part.get("text", "") + if text and len(text) > 10: # Meaningful response + agent_response = text + break + if agent_response: + break + if agent_response: + break + + # Collect streaming text from status updates (agent_progress_update) + if result.get("kind") == "status-update": + status_data = result.get("status", {}) + message = status_data.get("message", {}) + if message: + parts = message.get("parts", []) + for part in parts: + if part.get("kind") == "text": + text = part.get("text", "") + if text: + streaming_text_parts.append(text) + + # Also check for agent_progress_update type (direct SSE event type) + if parsed.get("type") == "agent_progress_update": + text = parsed.get("text", "") + if text: + streaming_text_parts.append(text) + + # If no bundled response, use accumulated streaming text + if not agent_response and streaming_text_parts: + agent_response = "".join(streaming_text_parts) + log.info("%sReconstructed agent response from %d streaming fragments (%d chars)", + log_prefix, len(streaming_text_parts), len(agent_response)) + + except Exception as e: + log.warning("%sError extracting agent response from SSE buffer: %s", log_prefix, e) + + + log.info( + "%sExtracted title data: user_message=%s, agent_response=%s", + log_prefix, + "yes" if user_message else "no", + "yes" if agent_response else "no" + ) + + return { + "user_message": user_message, + "agent_response": agent_response, + "task_id": task_id, + "session_id": task.session_id, + } + + except HTTPException: + raise + except Exception as e: + log.exception("%sError extracting title data: %s", log_prefix, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while extracting title data.", + ) + + @router.get("/tasks/{task_id}", tags=["Tasks"]) async def get_task_as_stim_file( task_id: str, diff --git a/src/solace_agent_mesh/gateway/http_sse/services/task_logger_service.py b/src/solace_agent_mesh/gateway/http_sse/services/task_logger_service.py index 6ac21ced0f..5ecbdf853d 100644 --- a/src/solace_agent_mesh/gateway/http_sse/services/task_logger_service.py +++ b/src/solace_agent_mesh/gateway/http_sse/services/task_logger_service.py @@ -75,7 +75,7 @@ def log_event(self, event_data: Dict[str, Any]): repo = TaskRepository() # Infer details from the parsed event - direction, task_id, user_id = self._infer_event_details( + direction, task_id, user_id, session_id = self._infer_event_details( topic, parsed_event, user_properties ) @@ -146,12 +146,14 @@ def log_event(self, event_data: Dict[str, Any]): last_activity_time=current_time, background_execution_enabled=background_execution_enabled, max_execution_time_ms=max_execution_time_ms, + session_id=session_id, # Store session_id for persistent event buffering ) repo.save_task(db, new_task) log.info( f"{self.log_identifier} Created new task record for ID: {task_id}" + (f" with parent: {parent_task_id}" if parent_task_id else "") + (f" (background execution enabled)" if background_execution_enabled else "") + + (f" (session: {session_id})" if session_id else "") ) else: # We received an event for a task we haven't seen the start of. 
@@ -166,6 +168,7 @@ def log_event(self, event_data: Dict[str, Any]): execution_mode="foreground", last_activity_time=current_time, background_execution_enabled=False, + session_id=session_id, # Store session_id for persistent event buffering ) repo.save_task(db, placeholder_task) log.info( @@ -173,8 +176,19 @@ def log_event(self, event_data: Dict[str, Any]): ) else: # Update last activity time for existing task - task.last_activity_time = now_epoch_ms() - repo.save_task(db, task) + # This is a non-critical update that can fail due to cross-process SQLite concurrency + try: + task.last_activity_time = now_epoch_ms() + repo.save_task(db, task) + except Exception as activity_update_error: + # StaleDataError or other concurrency issues - log and continue + # The task may have been modified/deleted by another process (FastAPI vs SAC) + log.debug( + f"{self.log_identifier} Non-critical: Failed to update last_activity_time for task {task_id}: {activity_update_error}" + ) + # Rollback and begin a new transaction so subsequent operations can continue + db.rollback() + db.begin() # Create and save the event using the sanitized raw payload task_event = TaskEvent( @@ -218,19 +232,6 @@ def log_event(self, event_data: Dict[str, Any]): f"{self.log_identifier} Finalized task record for ID: {task_id} with status: {final_status}" ) - # For background tasks, save chat messages when task completes - # Only save for top-level tasks (no parent_task_id) to avoid duplicates - # Sub-tasks (delegated by orchestrator) have parent_task_id set and contain system prompts - if task_to_update.background_execution_enabled and not task_to_update.parent_task_id: - self._save_chat_messages_for_background_task(db, task_id, task_to_update, repo) - - # Note: The frontend will detect task completion through: - # 1. SSE final_response event (if connected to that task) - # 2. Session list refresh triggered by the ChatProvider - # 3. Database status check when loading sessions - log.info( - f"{self.log_identifier} Background task {task_id} completed and chat messages saved" - ) db.commit() except Exception as e: @@ -288,10 +289,15 @@ def _parse_a2a_event(self, topic: str, payload: dict) -> Union[ def _infer_event_details( self, topic: str, parsed_event: Any, user_props: Dict | None - ) -> tuple[str, str | None, str | None]: - """Infers direction, task_id, and user_id from a parsed A2A event.""" + ) -> tuple[str, str | None, str | None, str | None]: + """Infers direction, task_id, user_id, and session_id from a parsed A2A event. 
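+
+        Example (illustrative values): a TaskStatusUpdateEvent for task
+        "task-123" submitted by user "user-42" in session "sess-7" yields
+        ("response", "task-123", "user-42", "sess-7"), with session_id taken
+        from the event's context_id.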
+ + Returns: + Tuple of (direction, task_id, user_id, session_id) + """ direction = "unknown" task_id = None + session_id = None # Will be extracted from context_id # Ensure user_props is a dict, not None user_props = user_props or {} user_id = user_props.get("userId") @@ -299,6 +305,10 @@ def _infer_event_details( if isinstance(parsed_event, A2ARequest): direction = "request" task_id = a2a.get_request_id(parsed_event) + # Extract session_id from context_id in the message + message = a2a.get_message_from_send_request(parsed_event) + if message: + session_id = a2a.get_context_id(message) elif isinstance( parsed_event, (A2ATask, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) ): @@ -306,10 +316,13 @@ def _infer_event_details( task_id = getattr(parsed_event, "task_id", None) or getattr( parsed_event, "id", None ) + # Extract session_id from context_id + session_id = getattr(parsed_event, "context_id", None) elif isinstance(parsed_event, JSONRPCError): direction = "error" if isinstance(parsed_event.data, dict): task_id = parsed_event.data.get("taskId") + session_id = parsed_event.data.get("contextId") if not user_id: user_config = user_props.get("a2aUserConfig") or user_props.get("a2a_user_config") @@ -318,7 +331,7 @@ def _infer_event_details( if isinstance(user_profile, dict): user_id = user_profile.get("id") - return direction, str(task_id) if task_id else None, user_id + return direction, str(task_id) if task_id else None, user_id, session_id def _extract_initial_text(self, parsed_event: Any) -> str | None: """Extracts the initial text from a send message request.""" @@ -395,6 +408,10 @@ def _save_chat_messages_for_background_task( Save chat messages for a completed background task by reconstructing them from task events. This ensures chat history is available when users return to a session after a background task completes. Uses upsert to avoid duplicates. + + NOTE: Even if SSE events are buffered (task.events_buffered=True), we still save to chat_tasks + from task_events. The chat_tasks data will have unresolved embeds, but this serves as a fallback. + The frontend should prefer replaying from sse_event_buffer when available. """ try: # Get all events for this task diff --git a/src/solace_agent_mesh/gateway/http_sse/sse_manager.py b/src/solace_agent_mesh/gateway/http_sse/sse_manager.py index 6879ef303c..d67af81d05 100644 --- a/src/solace_agent_mesh/gateway/http_sse/sse_manager.py +++ b/src/solace_agent_mesh/gateway/http_sse/sse_manager.py @@ -11,6 +11,7 @@ import math from .sse_event_buffer import SSEEventBuffer +from .persistent_sse_event_buffer import PersistentSSEEventBuffer log = logging.getLogger(__name__) trace_logger = logging.getLogger("sam_trace") @@ -25,7 +26,15 @@ class SSEManager: different event loops (e.g., FastAPI event loop and SAC component event loop). 
""" - def __init__(self, max_queue_size: int, event_buffer: SSEEventBuffer, session_factory: Optional[Callable] = None): + def __init__( + self, + max_queue_size: int, + event_buffer: SSEEventBuffer, + session_factory: Optional[Callable] = None, + persistent_buffer_enabled: bool = True, + hybrid_buffer_enabled: bool = False, + hybrid_buffer_threshold: int = 10, + ): self._connections: Dict[str, List[asyncio.Queue]] = {} self._event_buffer = event_buffer # Use a single threading lock for cross-event-loop synchronization @@ -35,6 +44,23 @@ def __init__(self, max_queue_size: int, event_buffer: SSEEventBuffer, session_fa self._session_factory = session_factory self._background_task_cache: Dict[str, bool] = {} # Cache to avoid repeated DB queries self._tasks_with_prior_connection: set = set() # Track tasks that have had at least one SSE connection + + # Initialize persistent buffer for background tasks + # Hybrid mode enables RAM-first buffering + # to reduce database writes for short-lived tasks + self._persistent_buffer = PersistentSSEEventBuffer( + session_factory=session_factory, + enabled=persistent_buffer_enabled, + hybrid_mode_enabled=hybrid_buffer_enabled, + hybrid_flush_threshold=hybrid_buffer_threshold, + ) + + if hybrid_buffer_enabled: + log.info( + "%s Hybrid buffer mode ENABLED (threshold=%d events)", + self.log_identifier, + hybrid_buffer_threshold, + ) def _sanitize_json(self, obj): if isinstance(obj, dict): @@ -102,11 +128,15 @@ async def remove_sse_connection( ): """ Removes a specific SSE connection queue for a task. + + In hybrid mode, flushes RAM buffer to DB when last connection is removed. Args: task_id: The ID of the task. connection_queue: The specific queue instance to remove. """ + should_flush_ram_buffer = False + with self._lock: if task_id in self._connections: try: @@ -124,6 +154,10 @@ async def remove_sse_connection( self.log_identifier, task_id, ) + # In hybrid mode, flush RAM buffer when last connection is removed + # This ensures events are persisted to DB before the client disconnects + if self._persistent_buffer.is_hybrid_mode_enabled(): + should_flush_ram_buffer = True except ValueError: log.debug( "%s Attempted to remove an already removed queue for Task ID: %s.", @@ -131,28 +165,67 @@ async def remove_sse_connection( task_id, ) else: - log.warning( + # This can happen due to timing - task may have been cleaned up already + log.debug( "%s Attempted to remove queue for non-existent Task ID: %s.", self.log_identifier, task_id, ) + + # Flush RAM buffer outside the lock to avoid holding it during DB operations + if should_flush_ram_buffer: + flushed = self._persistent_buffer.flush_task_buffer(task_id) + if flushed > 0: + log.info( + "%s [Hybrid] Flushed %d events from RAM buffer on connection close for task %s", + self.log_identifier, + flushed, + task_id, + ) - def _is_background_task(self, task_id: str) -> bool: + def _is_task_registered_for_buffering(self, task_id: str) -> bool: """ - Check if a task is a background task by querying the database. - Uses caching to avoid repeated queries. + Check if a task is registered for persistent event buffering. + + Note: All tasks get registered for buffering when they're submitted. + This enables session switching, browser refresh recovery, and reconnection + for all tasks. + + This method returns True if the task has metadata registered, which means events + will be persisted to the database for later replay. 
The in-memory buffer can + safely drop events for these tasks when the client disconnects since the events + are already persisted. + + The order of checks is: + 1. Cache (fastest) + 2. Persistent buffer metadata (registered when task is submitted) + 3. Database query (fallback for legacy tasks or worker restarts) Args: task_id: The ID of the task to check Returns: - True if the task is a background task, False otherwise + True if the task is registered for persistent buffering, False otherwise """ # Check cache first if task_id in self._background_task_cache: return self._background_task_cache[task_id] - # If no session factory, assume not a background task + # Check if task has metadata registered for persistent buffering + # This is set when the task is submitted (before the task record is created in DB) + metadata = self._persistent_buffer.get_task_metadata(task_id) + if metadata is not None: + # Metadata exists - this task is registered for buffering + self._background_task_cache[task_id] = True + log.debug( + "%s Task %s is registered for persistent buffering (has metadata)", + self.log_identifier, + task_id, + ) + return True + + # Fallback to database query + # (for legacy tasks or tasks that were submitted before metadata was registered) if not self._session_factory: return False @@ -163,23 +236,70 @@ def _is_background_task(self, task_id: str) -> bool: try: repo = TaskRepository() task = repo.find_by_id(db, task_id) - is_background = task and task.background_execution_enabled + # For legacy compatibility, also check background_execution_enabled + # These tasks should also have their events persisted + is_registered = task and (task.session_id or task.background_execution_enabled) # Cache the result - self._background_task_cache[task_id] = is_background + self._background_task_cache[task_id] = is_registered - return is_background + # If found in DB with session_id, reconstruct metadata for buffering + # This handles worker restarts where in-memory metadata was lost + if is_registered and task: + session_id = getattr(task, 'session_id', None) + user_id = getattr(task, 'user_id', None) + + if session_id and user_id: + self._persistent_buffer.set_task_metadata(task_id, session_id, user_id) + log.info( + "%s Reconstructed metadata for task %s from database: session=%s, user=%s", + self.log_identifier, + task_id, + session_id, + user_id, + ) + + return is_registered finally: db.close() except Exception as e: log.warning( - "%s Failed to check if task %s is a background task: %s", + "%s Failed to check if task %s is registered for buffering: %s", self.log_identifier, task_id, e, ) return False + def register_task_for_persistent_buffer( + self, + task_id: str, + session_id: str, + user_id: str, + ) -> None: + """ + Register task metadata for persistent buffering. + + This should be called when a background task is created so we have + the session_id and user_id available when buffering events. 
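The lookup order documented above (cache, then in-memory metadata, then database) distills to a small layered pattern. A sketch under the assumption that the three layers are injected as callables; none of these names exist in the PR:

```python
from typing import Callable, Dict, Optional

def is_registered_for_buffering(
    task_id: str,
    cache: Dict[str, bool],
    get_metadata: Callable[[str], Optional[dict]],
    db_lookup: Callable[[str], bool],
) -> bool:
    """Check order distilled from _is_task_registered_for_buffering:
    the cache answers repeat calls, in-memory metadata covers tasks
    registered at submission, and the database fallback covers legacy
    tasks and worker restarts."""
    if task_id in cache:
        return cache[task_id]
    if get_metadata(task_id) is not None:
        cache[task_id] = True
        return True
    registered = db_lookup(task_id)  # may also reconstruct metadata from the row
    cache[task_id] = registered
    return registered
```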
+ + Args: + task_id: The task ID + session_id: The session ID + user_id: The user ID + """ + self._persistent_buffer.set_task_metadata(task_id, session_id, user_id) + log.debug( + "%s Registered task %s for persistent buffering (session=%s)", + self.log_identifier, + task_id, + session_id, + ) + + def get_persistent_buffer(self) -> PersistentSSEEventBuffer: + """Get the persistent buffer instance.""" + return self._persistent_buffer + async def send_event( self, task_id: str, event_data: Dict[str, Any], event_type: str = "message" ): @@ -208,6 +328,45 @@ async def send_event( sse_payload = {"event": event_type, "data": serialized_data} + # Check if this task is registered for persistent buffering + is_registered = self._is_task_registered_for_buffering(task_id) + + # UNIFIED ARCHITECTURE: Always buffer to persistent storage for replay + # This enables session switching, browser refresh recovery, and reconnection for ALL tasks + # The FE will clear the buffer after successfully saving the chat_task + # Note: We check is_enabled() which is tied to the background_tasks feature flag + if self._persistent_buffer.is_enabled(): + # Check if this task has metadata registered (ensures we have session_id/user_id) + # All tasks with registered metadata get buffered + task_metadata = self._persistent_buffer.get_task_metadata(task_id) + if task_metadata is not None: + buffered = self._persistent_buffer.buffer_event( + task_id=task_id, + event_type=event_type, + event_data=sse_payload, # Store the full SSE payload + ) + log.debug( + "%s Buffered event for task %s: type=%s, registered=%s, result=%s", + self.log_identifier, + task_id, + event_type, + is_registered, + buffered, + ) + elif is_registered: + # Fallback for registered tasks without metadata in cache + # This shouldn't happen in normal flow but provides safety + log.warning( + "%s Registered task %s has no metadata in cache, attempting to buffer anyway", + self.log_identifier, + task_id, + ) + self._persistent_buffer.buffer_event( + task_id=task_id, + event_type=event_type, + event_data=sse_payload, + ) + # Get queues and decide action under the lock queues_copy = None @@ -215,19 +374,16 @@ async def send_event( queues = self._connections.get(task_id) if not queues: - # Check if this is a background task (outside lock would be better, - # but we need the decision to be atomic with the buffering) - is_background_task = self._is_background_task(task_id) - # Check if this task has ever had a connection has_had_connection = task_id in self._tasks_with_prior_connection - # Only drop events for background tasks that have HAD a connection before + # For registered tasks that have HAD a connection before, drop in-memory events + # Events are already persisted to database, so we don't need in-memory buffer # If no connection has ever been made, we must buffer so the first client gets the events - if is_background_task and has_had_connection: - # For background tasks where client disconnected, drop events to prevent buffer overflow + if is_registered and has_had_connection: log.debug( - "%s No active SSE connections for background task %s (had prior connection). Dropping event to prevent buffer overflow.", + "%s No active SSE connections for registered task %s (had prior connection). " + "Events persisted to database for replay.", self.log_identifier, task_id, ) @@ -387,10 +543,25 @@ async def close_all_for_task(self, task_id: str): Closes all SSE connections associated with a specific task. If a connection existed, it also cleans up the event buffer. 
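Taken together, `register_task_for_persistent_buffer`, the buffering in `send_event`, and the buffer accessors imply a replay-then-clear consumption order on the client's behalf. A hypothetical consumer-side helper showing that order; `deliver` and the helper itself are illustrative, only the buffer methods are from this PR:

```python
def replay_and_clear(sse_manager, task_id: str, deliver) -> int:
    """Replay persisted events to the client, then clear the buffer so
    the same events are not replayed on the next load. `deliver` is any
    callable that pushes one event to the client."""
    buffer = sse_manager.get_persistent_buffer()
    events = buffer.get_buffered_events(task_id)  # flushes RAM to DB first in hybrid mode
    for event in events:
        deliver(event)
    buffer.delete_events_for_task(task_id)
    return len(events)
```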
If no connection ever existed, the buffer is left for a late-connecting client. + + In hybrid mode: always flushes RAM buffer to DB to ensure events that came in + after SSE disconnect are persisted for later retrieval. """ queues_to_close = None should_remove_buffer = False + # In hybrid mode, flush RAM buffer to DB BEFORE closing connections + # This ensures any events that arrived after client disconnect are persisted + if self._persistent_buffer.is_hybrid_mode_enabled(): + flushed = self._persistent_buffer.flush_task_buffer(task_id) + if flushed > 0: + log.info( + "%s [Hybrid] Flushed %d events from RAM buffer on task close for task %s", + self.log_identifier, + flushed, + task_id, + ) + with self._lock: if task_id in self._connections: # This is the "normal" case: a client is or was connected. @@ -458,8 +629,21 @@ def cleanup_old_locks(self): pass async def close_all(self): - """Closes all active SSE connections managed by this instance.""" + """Closes all active SSE connections managed by this instance. + + In hybrid mode, flushes all RAM buffers to DB before closing to ensure no events are lost. + """ self.cleanup_old_locks() + + # In hybrid mode, flush all RAM buffers first to ensure no events are lost + if self._persistent_buffer.is_hybrid_mode_enabled(): + flushed = self._persistent_buffer.flush_all_buffers() + if flushed > 0: + log.info( + "%s [Hybrid] Flushed %d events from RAM buffers during shutdown", + self.log_identifier, + flushed, + ) # Collect all queues to close under the lock all_queues_to_close = [] diff --git a/templates/webui.yaml b/templates/webui.yaml index 9364923dd0..a3fcad9997 100644 --- a/templates/webui.yaml +++ b/templates/webui.yaml @@ -63,6 +63,9 @@ apps: task_logging: enabled: true + hybrid_buffer: + enabled: false # false=events go directly to DB (default); true=buffer in RAM first, then flush to DB + flush_threshold: 10 # Flush to DB after N events (lower for easier testing) log_status_updates: true log_artifact_events: true log_file_parts: true diff --git a/tests/unit/gateway/http_sse/routers/test_sessions_unconsumed_events_auth.py b/tests/unit/gateway/http_sse/routers/test_sessions_unconsumed_events_auth.py new file mode 100644 index 0000000000..1cb247fd49 --- /dev/null +++ b/tests/unit/gateway/http_sse/routers/test_sessions_unconsumed_events_auth.py @@ -0,0 +1,290 @@ +"""Unit tests for GET /sessions/{session_id}/events/unconsumed endpoint authorization.""" + +import pytest +from unittest.mock import MagicMock, patch + + +class TestGetSessionUnconsumedEventsAuthorization: + """Tests for GET /sessions/{session_id}/events/unconsumed endpoint authorization.""" + + @pytest.fixture + def mock_db(self): + """Create a mock database session.""" + return MagicMock() + + @pytest.fixture + def mock_session_service(self): + """Create a mock SessionService.""" + return MagicMock() + + @pytest.fixture + def mock_sse_manager(self): + """Create a mock SSEManager.""" + mock_manager = MagicMock() + mock_persistent_buffer = MagicMock() + mock_manager.get_persistent_buffer.return_value = mock_persistent_buffer + return mock_manager, mock_persistent_buffer + + @pytest.mark.asyncio + async def test_returns_404_when_session_belongs_to_different_user( + self, mock_db, mock_session_service + ): + """Test that endpoint returns 404 when session belongs to different user.""" + from fastapi import HTTPException + + # Session service returns None (session not found or not owned by user) + mock_session_service.get_session_details.return_value = None + + requesting_user = {"id": 
"different-user-id"} + session_id = "test-session-id" + + # Patch at the dependencies module level since it's imported inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.dependencies.get_sac_component", + return_value=MagicMock(), + ): + from solace_agent_mesh.gateway.http_sse.routers.sessions import ( + get_session_unconsumed_events, + ) + + with pytest.raises(HTTPException) as exc_info: + await get_session_unconsumed_events( + session_id=session_id, + include_events=False, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + assert exc_info.value.status_code == 404 + # Session not found means either doesn't exist or user doesn't own it + assert "not found" in exc_info.value.detail.lower() + + # Verify the service was called with correct user_id + mock_session_service.get_session_details.assert_called_once_with( + db=mock_db, session_id=session_id, user_id="different-user-id" + ) + + @pytest.mark.asyncio + async def test_allows_access_when_session_belongs_to_same_user( + self, mock_db, mock_session_service, mock_sse_manager + ): + """Test that endpoint allows access when session belongs to requesting user.""" + mock_manager, mock_persistent_buffer = mock_sse_manager + + # Session service returns session (user owns it) + mock_session = MagicMock() + mock_session.id = "test-session-id" + mock_session_service.get_session_details.return_value = mock_session + + # Mock persistent buffer returns empty + mock_persistent_buffer.get_unconsumed_events_for_session.return_value = {} + + mock_component = MagicMock() + mock_component.sse_manager = mock_manager + + requesting_user = {"id": "owner-user-id"} + session_id = "test-session-id" + + # Patch at the dependencies module level since it's imported inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.dependencies.get_sac_component", + return_value=mock_component, + ): + from solace_agent_mesh.gateway.http_sse.routers.sessions import ( + get_session_unconsumed_events, + ) + + result = await get_session_unconsumed_events( + session_id=session_id, + include_events=False, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + # Should return successfully + assert result["session_id"] == session_id + assert result["has_events"] is False + assert result["task_ids"] == [] + + @pytest.mark.asyncio + async def test_returns_unconsumed_events_for_valid_session( + self, mock_db, mock_session_service, mock_sse_manager + ): + """Test that endpoint returns unconsumed events when include_events=True.""" + mock_manager, mock_persistent_buffer = mock_sse_manager + + # Session service returns session (user owns it) + mock_session = MagicMock() + mock_session.id = "test-session-id" + mock_session_service.get_session_details.return_value = mock_session + + # Mock persistent buffer returns events for two tasks + mock_persistent_buffer.get_unconsumed_events_for_session.return_value = { + "task-1": [ + {"event_sequence": 1, "event_type": "message", "event_data": {"text": "hello"}} + ], + "task-2": [ + {"event_sequence": 1, "event_type": "artifact", "event_data": {"name": "file.txt"}} + ], + } + + mock_component = MagicMock() + mock_component.sse_manager = mock_manager + + requesting_user = {"id": "owner-user-id"} + session_id = "test-session-id" + + # Patch at the dependencies module level since it's imported inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.dependencies.get_sac_component", + return_value=mock_component, + ): + from 
solace_agent_mesh.gateway.http_sse.routers.sessions import ( + get_session_unconsumed_events, + ) + + result = await get_session_unconsumed_events( + session_id=session_id, + include_events=True, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + # Should return events grouped by task_id + assert result["session_id"] == session_id + assert result["has_events"] is True + assert "task-1" in result["task_ids"] + assert "task-2" in result["task_ids"] + assert "events_by_task" in result + assert len(result["events_by_task"]) == 2 + + @pytest.mark.asyncio + async def test_returns_404_for_invalid_session_id( + self, mock_db, mock_session_service + ): + """Test that endpoint returns 404 for invalid session IDs.""" + from fastapi import HTTPException + + requesting_user = {"id": "some-user-id"} + + # Patch at the dependencies module level since it's imported inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.dependencies.get_sac_component", + return_value=MagicMock(), + ): + from solace_agent_mesh.gateway.http_sse.routers.sessions import ( + get_session_unconsumed_events, + ) + + # Test with "null" string + with pytest.raises(HTTPException) as exc_info: + await get_session_unconsumed_events( + session_id="null", + include_events=False, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + assert exc_info.value.status_code == 404 + + # Test with "undefined" string + with pytest.raises(HTTPException) as exc_info: + await get_session_unconsumed_events( + session_id="undefined", + include_events=False, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + assert exc_info.value.status_code == 404 + + # Test with empty string + with pytest.raises(HTTPException) as exc_info: + await get_session_unconsumed_events( + session_id="", + include_events=False, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_handles_missing_persistent_buffer( + self, mock_db, mock_session_service + ): + """Test that endpoint handles missing persistent buffer gracefully.""" + # Session service returns session (user owns it) + mock_session = MagicMock() + mock_session.id = "test-session-id" + mock_session_service.get_session_details.return_value = mock_session + + # Mock component with no persistent buffer + mock_manager = MagicMock() + mock_manager.get_persistent_buffer.return_value = None + mock_component = MagicMock() + mock_component.sse_manager = mock_manager + + requesting_user = {"id": "owner-user-id"} + session_id = "test-session-id" + + # Patch at the dependencies module level since it's imported inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.dependencies.get_sac_component", + return_value=mock_component, + ): + from solace_agent_mesh.gateway.http_sse.routers.sessions import ( + get_session_unconsumed_events, + ) + + result = await get_session_unconsumed_events( + session_id=session_id, + include_events=False, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + # Should return empty result, not error + assert result["has_events"] is False + assert result["task_ids"] == [] + + @pytest.mark.asyncio + async def test_handles_missing_component( + self, mock_db, mock_session_service + ): + """Test that endpoint handles missing component gracefully.""" + # Session service returns session (user owns it) + mock_session = MagicMock() + 
mock_session.id = "test-session-id" + mock_session_service.get_session_details.return_value = mock_session + + requesting_user = {"id": "owner-user-id"} + session_id = "test-session-id" + + # Patch at the dependencies module level since it's imported inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.dependencies.get_sac_component", + return_value=None, + ): + from solace_agent_mesh.gateway.http_sse.routers.sessions import ( + get_session_unconsumed_events, + ) + + result = await get_session_unconsumed_events( + session_id=session_id, + include_events=False, + db=mock_db, + user=requesting_user, + session_service=mock_session_service, + ) + + # Should return empty result, not error + assert result["has_events"] is False + assert result["task_ids"] == [] diff --git a/tests/unit/gateway/http_sse/routers/test_tasks_buffered_events_auth.py b/tests/unit/gateway/http_sse/routers/test_tasks_buffered_events_auth.py new file mode 100644 index 0000000000..0038a66b15 --- /dev/null +++ b/tests/unit/gateway/http_sse/routers/test_tasks_buffered_events_auth.py @@ -0,0 +1,195 @@ +"""Unit tests for /tasks/{task_id}/events/buffered endpoint authorization.""" + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + + +class TestGetBufferedTaskEventsAuthorization: + """Tests for GET /tasks/{task_id}/events/buffered endpoint authorization.""" + + @pytest.fixture + def mock_db(self): + """Create a mock database session.""" + return MagicMock() + + @pytest.fixture + def mock_task_repo(self): + """Create a mock TaskRepository.""" + return MagicMock() + + @pytest.fixture + def mock_buffer_repo(self): + """Create a mock SSEEventBufferRepository.""" + return MagicMock() + + @pytest.fixture + def mock_request(self): + """Create a mock FastAPI request.""" + return MagicMock() + + @pytest.mark.asyncio + async def test_returns_403_when_task_belongs_to_different_user( + self, mock_db, mock_task_repo, mock_request + ): + """Test that endpoint returns 403 when task belongs to different user.""" + from fastapi import HTTPException + + # Create a mock task belonging to a different user + mock_task = MagicMock() + mock_task.user_id = "owner-user-id" + mock_task.events_consumed = False + mock_task_repo.find_by_id_with_events.return_value = (mock_task, []) + + requesting_user_id = "different-user-id" + task_id = "test-task-id" + + # User config without read:all scope + user_config = {"scopes": {"tasks:read:all": False}} + + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.TaskRepository", + return_value=mock_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_buffered_task_events, + ) + + with pytest.raises(HTTPException) as exc_info: + await get_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + user_config=user_config, + repo=mock_task_repo, + mark_consumed=True, + ) + + assert exc_info.value.status_code == 403 + assert "permission" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_allows_access_when_task_belongs_to_same_user( + self, mock_db, mock_task_repo, mock_buffer_repo, mock_request + ): + """Test that endpoint allows access when task belongs to the requesting user.""" + # Create a mock task belonging to the requesting user + mock_task = MagicMock() + mock_task.user_id = "owner-user-id" + mock_task.events_consumed = False + mock_task_repo.find_by_id_with_events.return_value = (mock_task, []) + + # Mock buffer repo + 
mock_buffer_repo.has_unconsumed_events.return_value = True + mock_buffer_repo.get_buffered_events.return_value = [ + {"type": "message", "data": {"text": "test"}, "sequence": 1} + ] + + requesting_user_id = "owner-user-id" # Same as owner + task_id = "test-task-id" + user_config = {"scopes": {"tasks:read:all": False}} + + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.TaskRepository", + return_value=mock_task_repo, + ), patch( + "solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository", + return_value=mock_buffer_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_buffered_task_events, + ) + + result = await get_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + user_config=user_config, + repo=mock_task_repo, + mark_consumed=True, + ) + + # Should return successfully with events + assert result["task_id"] == task_id + assert "events" in result + + @pytest.mark.asyncio + async def test_allows_access_with_read_all_scope( + self, mock_db, mock_task_repo, mock_buffer_repo, mock_request + ): + """Test that endpoint allows access when user has tasks:read:all scope.""" + # Create a mock task belonging to a different user + mock_task = MagicMock() + mock_task.user_id = "owner-user-id" + mock_task.events_consumed = False + mock_task_repo.find_by_id_with_events.return_value = (mock_task, []) + + # Mock buffer repo + mock_buffer_repo.has_unconsumed_events.return_value = False + mock_buffer_repo.get_event_count.return_value = 0 + + requesting_user_id = "admin-user-id" # Different from owner + task_id = "test-task-id" + # User has read:all scope + user_config = {"scopes": {"tasks:read:all": True}} + + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.TaskRepository", + return_value=mock_task_repo, + ), patch( + "solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository", + return_value=mock_buffer_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_buffered_task_events, + ) + + result = await get_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + user_config=user_config, + repo=mock_task_repo, + mark_consumed=True, + ) + + # Should return successfully (empty events) + assert result["task_id"] == task_id + assert result["events_buffered"] is False + + @pytest.mark.asyncio + async def test_returns_404_when_task_not_found( + self, mock_db, mock_task_repo, mock_request + ): + """Test that endpoint returns 404 when task is not found.""" + from fastapi import HTTPException + + mock_task_repo.find_by_id_with_events.return_value = None + + requesting_user_id = "some-user-id" + task_id = "nonexistent-task-id" + user_config = {"scopes": {"tasks:read:all": False}} + + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.TaskRepository", + return_value=mock_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_buffered_task_events, + ) + + with pytest.raises(HTTPException) as exc_info: + await get_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + user_config=user_config, + repo=mock_task_repo, + mark_consumed=True, + ) + + assert exc_info.value.status_code == 404 + assert "not found" in exc_info.value.detail.lower() diff --git a/tests/unit/gateway/http_sse/routers/test_tasks_clear_buffered_events_auth.py 
b/tests/unit/gateway/http_sse/routers/test_tasks_clear_buffered_events_auth.py new file mode 100644 index 0000000000..f66d005f43 --- /dev/null +++ b/tests/unit/gateway/http_sse/routers/test_tasks_clear_buffered_events_auth.py @@ -0,0 +1,232 @@ +"""Unit tests for DELETE /tasks/{task_id}/events/buffered endpoint authorization.""" + +import pytest +from unittest.mock import MagicMock, patch + + +class TestClearBufferedTaskEventsAuthorization: + """Tests for DELETE /tasks/{task_id}/events/buffered endpoint authorization.""" + + @pytest.fixture + def mock_db(self): + """Create a mock database session.""" + return MagicMock() + + @pytest.fixture + def mock_request(self): + """Create a mock FastAPI request.""" + return MagicMock() + + @pytest.fixture + def mock_sse_manager(self): + """Create a mock SSEManager.""" + mock_manager = MagicMock() + mock_persistent_buffer = MagicMock() + mock_manager.get_persistent_buffer.return_value = mock_persistent_buffer + return mock_manager, mock_persistent_buffer + + @pytest.fixture + def mock_component(self, mock_sse_manager): + """Create a mock WebUIBackendComponent.""" + mock_comp = MagicMock() + mock_comp.sse_manager = mock_sse_manager[0] + return mock_comp + + @pytest.mark.asyncio + async def test_returns_403_when_task_metadata_belongs_to_different_user( + self, mock_db, mock_request, mock_sse_manager + ): + """Test that endpoint returns 403 when task metadata belongs to different user.""" + from fastapi import HTTPException + + mock_manager, mock_persistent_buffer = mock_sse_manager + + # Task metadata shows different owner + mock_persistent_buffer.get_task_metadata.return_value = { + "user_id": "owner-user-id" + } + + requesting_user_id = "different-user-id" + task_id = "test-task-id" + + mock_component = MagicMock() + mock_component.sse_manager = mock_manager + + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.get_sac_component", + return_value=mock_component, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + clear_buffered_task_events, + ) + + with pytest.raises(HTTPException) as exc_info: + await clear_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + ) + + assert exc_info.value.status_code == 403 + assert "permission" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_returns_403_when_db_task_belongs_to_different_user( + self, mock_db, mock_request, mock_sse_manager + ): + """Test that endpoint returns 403 when DB task belongs to different user.""" + from fastapi import HTTPException + + mock_manager, mock_persistent_buffer = mock_sse_manager + + # No task metadata (fall back to DB lookup) + mock_persistent_buffer.get_task_metadata.return_value = None + + # Create a mock task from DB belonging to different user + mock_task = MagicMock() + mock_task.user_id = "owner-user-id" + + mock_task_repo = MagicMock() + mock_task_repo.find_by_id.return_value = mock_task + + requesting_user_id = "different-user-id" + task_id = "test-task-id" + + mock_component = MagicMock() + mock_component.sse_manager = mock_manager + + # Patch at the repository module level to intercept the import inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.get_sac_component", + return_value=mock_component, + ), patch( + "solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository", + return_value=mock_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + clear_buffered_task_events, + ) 
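The task endpoints, unlike the session endpoint, distinguish 403 from 404. The access rule these tests pin down can be stated as a single guard; this is reconstructed from the assertions, not copied from the router:

```python
from fastapi import HTTPException

def check_task_read_access(task, requesting_user_id: str, user_config: dict) -> None:
    """Reconstructed from the test assertions: owners and holders of the
    tasks:read:all scope may read; tasks with no user_id (legacy rows)
    stay readable; anyone else gets 403."""
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    scopes = (user_config or {}).get("scopes", {})
    if scopes.get("tasks:read:all", False):
        return
    if task.user_id and task.user_id != requesting_user_id:
        raise HTTPException(
            status_code=403,
            detail="You do not have permission to access this task",
        )
```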
+ + with pytest.raises(HTTPException) as exc_info: + await clear_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + ) + + assert exc_info.value.status_code == 403 + assert "permission" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_allows_access_when_task_metadata_belongs_to_same_user( + self, mock_db, mock_request, mock_sse_manager + ): + """Test that endpoint allows access when task metadata belongs to requesting user.""" + mock_manager, mock_persistent_buffer = mock_sse_manager + + # Task metadata shows same owner + mock_persistent_buffer.get_task_metadata.return_value = { + "user_id": "owner-user-id" + } + mock_persistent_buffer.delete_events_for_task.return_value = 5 + + requesting_user_id = "owner-user-id" # Same as owner + task_id = "test-task-id" + + mock_component = MagicMock() + mock_component.sse_manager = mock_manager + + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.get_sac_component", + return_value=mock_component, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + clear_buffered_task_events, + ) + + result = await clear_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + ) + + # Should return successfully with deleted count + assert result["task_id"] == task_id + assert result["deleted"] == 5 + mock_persistent_buffer.delete_events_for_task.assert_called_once_with(task_id) + + @pytest.mark.asyncio + async def test_allows_access_when_no_metadata_and_no_db_task( + self, mock_db, mock_request, mock_sse_manager + ): + """Test that endpoint allows access when no metadata and no DB task (orphan buffer).""" + mock_manager, mock_persistent_buffer = mock_sse_manager + + # No task metadata + mock_persistent_buffer.get_task_metadata.return_value = None + mock_persistent_buffer.delete_events_for_task.return_value = 0 + + # No DB task either + mock_task_repo = MagicMock() + mock_task_repo.find_by_id.return_value = None + + requesting_user_id = "some-user-id" + task_id = "orphan-task-id" + + mock_component = MagicMock() + mock_component.sse_manager = mock_manager + + # Patch at the repository module level to intercept the import inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.get_sac_component", + return_value=mock_component, + ), patch( + "solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository", + return_value=mock_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + clear_buffered_task_events, + ) + + result = await clear_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + user_id=requesting_user_id, + ) + + # Should return successfully (permissive for orphan data) + assert result["task_id"] == task_id + assert "deleted" in result + + @pytest.mark.asyncio + async def test_returns_503_when_component_not_available( + self, mock_db, mock_request + ): + """Test that endpoint returns 503 when component is not available.""" + from fastapi import HTTPException + + requesting_user_id = "some-user-id" + task_id = "test-task-id" + + with patch( + "solace_agent_mesh.gateway.http_sse.routers.tasks.get_sac_component", + return_value=None, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + clear_buffered_task_events, + ) + + with pytest.raises(HTTPException) as exc_info: + await clear_buffered_task_events( + task_id=task_id, + request=mock_request, + db=mock_db, + 
user_id=requesting_user_id, + ) + + assert exc_info.value.status_code == 503 + assert "not available" in exc_info.value.detail.lower() diff --git a/tests/unit/gateway/http_sse/routers/test_tasks_title_data.py b/tests/unit/gateway/http_sse/routers/test_tasks_title_data.py new file mode 100644 index 0000000000..500c7fe41e --- /dev/null +++ b/tests/unit/gateway/http_sse/routers/test_tasks_title_data.py @@ -0,0 +1,184 @@ +"""Unit tests for /tasks/{task_id}/title-data endpoint authorization.""" + +import pytest +from unittest.mock import MagicMock, patch + + +class TestGetTaskTitleDataAuthorization: + """Tests for GET /tasks/{task_id}/title-data endpoint authorization.""" + + @pytest.fixture + def mock_db(self): + """Create a mock database session.""" + return MagicMock() + + @pytest.fixture + def mock_task_repo(self): + """Create a mock TaskRepository.""" + return MagicMock() + + @pytest.fixture + def mock_buffer_repo(self): + """Create a mock SSEEventBufferRepository.""" + return MagicMock() + + @pytest.fixture + def mock_chat_task_repo(self): + """Create a mock ChatTaskRepository.""" + return MagicMock() + + @pytest.mark.asyncio + async def test_returns_403_when_task_belongs_to_different_user( + self, mock_db, mock_task_repo + ): + """Test that endpoint returns 403 when task belongs to different user.""" + from fastapi import HTTPException + + # Create a mock task belonging to a different user + mock_task = MagicMock() + mock_task.user_id = "owner-user-id" + mock_task.initial_request_text = "test message" + mock_task.session_id = "test-session-id" + mock_task_repo.find_by_id.return_value = mock_task + + requesting_user_id = "different-user-id" + task_id = "test-task-id" + + # Patch at the repository module level since imports are inside the function + with patch( + "solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository", + return_value=mock_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_task_title_data, + ) + + with pytest.raises(HTTPException) as exc_info: + await get_task_title_data( + task_id=task_id, + db=mock_db, + user_id=requesting_user_id, + ) + + assert exc_info.value.status_code == 403 + assert "permission" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_allows_access_when_task_belongs_to_same_user( + self, mock_db, mock_task_repo, mock_chat_task_repo, mock_buffer_repo + ): + """Test that endpoint allows access when task belongs to the requesting user.""" + # Create a mock task belonging to the requesting user + mock_task = MagicMock() + mock_task.user_id = "owner-user-id" + mock_task.initial_request_text = "test user message" + mock_task.session_id = "test-session-id" + mock_task_repo.find_by_id.return_value = mock_task + + # Mock chat_task with message bubbles + mock_chat_task = MagicMock() + mock_chat_task.user_message = "test user message" + mock_chat_task.message_bubbles = None # No bubbles to simplify test + mock_chat_task_repo.find_by_id.return_value = mock_chat_task + + # Mock empty buffer + mock_buffer_repo.get_buffered_events.return_value = [] + + requesting_user_id = "owner-user-id" # Same as owner + task_id = "test-task-id" + + with patch( + "solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository", + return_value=mock_task_repo, + ), patch( + "solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository", + return_value=mock_buffer_repo, + ), patch( + 
"solace_agent_mesh.gateway.http_sse.repository.chat_task_repository.ChatTaskRepository", + return_value=mock_chat_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_task_title_data, + ) + + result = await get_task_title_data( + task_id=task_id, + db=mock_db, + user_id=requesting_user_id, + ) + + # Should return successfully with the user message + assert result["user_message"] == "test user message" + assert result["task_id"] == task_id + + @pytest.mark.asyncio + async def test_allows_access_when_task_has_no_user_id( + self, mock_db, mock_task_repo, mock_chat_task_repo, mock_buffer_repo + ): + """Test that endpoint allows access when task has no user_id (legacy data).""" + # Create a mock task with no user_id + mock_task = MagicMock() + mock_task.user_id = None # No user_id set + mock_task.initial_request_text = "test user message" + mock_task.session_id = "test-session-id" + mock_task_repo.find_by_id.return_value = mock_task + + # Mock no chat_task found + mock_chat_task_repo.find_by_id.return_value = None + + # Mock empty buffer + mock_buffer_repo.get_buffered_events.return_value = [] + + requesting_user_id = "some-user-id" + task_id = "test-task-id" + + with patch( + "solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository", + return_value=mock_task_repo, + ), patch( + "solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository", + return_value=mock_buffer_repo, + ), patch( + "solace_agent_mesh.gateway.http_sse.repository.chat_task_repository.ChatTaskRepository", + return_value=mock_chat_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_task_title_data, + ) + + result = await get_task_title_data( + task_id=task_id, + db=mock_db, + user_id=requesting_user_id, + ) + + # Should return successfully (permissive for legacy data) + assert result["user_message"] == "test user message" + assert result["task_id"] == task_id + + @pytest.mark.asyncio + async def test_returns_error_when_task_not_found(self, mock_db, mock_task_repo): + """Test that endpoint returns error info when task is not found.""" + mock_task_repo.find_by_id.return_value = None + + requesting_user_id = "some-user-id" + task_id = "nonexistent-task-id" + + with patch( + "solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository", + return_value=mock_task_repo, + ): + from solace_agent_mesh.gateway.http_sse.routers.tasks import ( + get_task_title_data, + ) + + result = await get_task_title_data( + task_id=task_id, + db=mock_db, + user_id=requesting_user_id, + ) + + # Should return error info, not raise exception + assert result["user_message"] is None + assert result["agent_response"] is None + assert result["error"] == "Task not found" diff --git a/tests/unit/gateway/http_sse/services/test_task_logger_service.py b/tests/unit/gateway/http_sse/services/test_task_logger_service.py index f1ceae5e3d..102980b0f8 100644 --- a/tests/unit/gateway/http_sse/services/test_task_logger_service.py +++ b/tests/unit/gateway/http_sse/services/test_task_logger_service.py @@ -306,14 +306,17 @@ def test_infer_details_from_request(self, service): with patch('solace_agent_mesh.gateway.http_sse.services.task_logger_service.a2a') as mock_a2a: mock_a2a.get_request_id.return_value = "task-123" + mock_a2a.get_context_id.return_value = None + mock_a2a.get_message_from_send_request.return_value = None - direction, task_id, user_id = service._infer_event_details( + direction, task_id, user_id, session_id = 
service._infer_event_details( "some/topic", mock_request, {"userId": "user-456"} ) assert direction == "request" assert task_id == "task-123" assert user_id == "user-456" + assert session_id is None # No context_id in simple test case def test_infer_details_from_task(self, service): """Test inferring details from A2ATask.""" @@ -322,13 +325,14 @@ def test_infer_details_from_task(self, service): mock_task = Mock(spec=A2ATask) mock_task.id = "task-789" - direction, task_id, user_id = service._infer_event_details( + direction, task_id, user_id, session_id = service._infer_event_details( "some/topic", mock_task, {"userId": "user-123"} ) assert direction == "response" assert task_id == "task-789" assert user_id == "user-123" + assert session_id is None # A2ATask doesn't have context_id def test_infer_details_from_status_update(self, service): """Test inferring details from TaskStatusUpdateEvent.""" @@ -337,12 +341,13 @@ def test_infer_details_from_status_update(self, service): mock_event = Mock(spec=TaskStatusUpdateEvent) mock_event.task_id = "task-status-123" - direction, task_id, user_id = service._infer_event_details( + direction, task_id, user_id, session_id = service._infer_event_details( "some/topic", mock_event, {} ) assert direction == "status" assert task_id == "task-status-123" + assert session_id is None # TaskStatusUpdateEvent doesn't have context_id def test_infer_details_from_error(self, service): """Test inferring details from JSONRPCError.""" @@ -351,12 +356,13 @@ def test_infer_details_from_error(self, service): mock_error = Mock(spec=JSONRPCError) mock_error.data = {"taskId": "task-error-123"} - direction, task_id, user_id = service._infer_event_details( + direction, task_id, user_id, session_id = service._infer_event_details( "some/topic", mock_error, {} ) assert direction == "error" assert task_id == "task-error-123" + assert session_id is None # JSONRPCError doesn't have context_id def test_infer_details_user_id_from_a2a_config(self, service): """Test inferring user_id from a2aUserConfig.""" @@ -366,6 +372,8 @@ def test_infer_details_user_id_from_a2a_config(self, service): with patch('solace_agent_mesh.gateway.http_sse.services.task_logger_service.a2a') as mock_a2a: mock_a2a.get_request_id.return_value = "task-123" + mock_a2a.get_context_id.return_value = None + mock_a2a.get_message_from_send_request.return_value = None user_props = { "a2aUserConfig": { @@ -375,11 +383,12 @@ def test_infer_details_user_id_from_a2a_config(self, service): } } - direction, task_id, user_id = service._infer_event_details( + direction, task_id, user_id, session_id = service._infer_event_details( "some/topic", mock_request, user_props ) assert user_id == "user-from-config" + assert session_id is None # No context_id in simple test case def test_infer_details_with_none_user_props(self, service): """Test inferring details with None user_props.""" @@ -389,14 +398,17 @@ def test_infer_details_with_none_user_props(self, service): with patch('solace_agent_mesh.gateway.http_sse.services.task_logger_service.a2a') as mock_a2a: mock_a2a.get_request_id.return_value = "task-123" + mock_a2a.get_context_id.return_value = None + mock_a2a.get_message_from_send_request.return_value = None - direction, task_id, user_id = service._infer_event_details( + direction, task_id, user_id, session_id = service._infer_event_details( "some/topic", mock_request, None ) assert direction == "request" assert task_id == "task-123" assert user_id is None + assert session_id is None # No user_props, no context_id class 
TestParseA2AEvent: diff --git a/tests/unit/gateway/http_sse/test_persistent_sse_event_buffer.py b/tests/unit/gateway/http_sse/test_persistent_sse_event_buffer.py new file mode 100644 index 0000000000..59d94a3c8a --- /dev/null +++ b/tests/unit/gateway/http_sse/test_persistent_sse_event_buffer.py @@ -0,0 +1,879 @@ +"""Unit tests for PersistentSSEEventBuffer. + +These tests focus on the core buffer behaviors: +1. Buffering events to RAM (hybrid mode) +2. Task metadata storage and retrieval +3. Buffer mode detection (hybrid vs direct DB) + +These tests use minimal mocking - only testing the in-memory behaviors +that don't require actual database connections. +""" + +import pytest +from unittest.mock import Mock, patch + + +class TestBufferEventRouting: + """Tests for buffer_event routing logic between hybrid and direct modes.""" + + def test_routes_to_db_when_hybrid_disabled(self): + """When hybrid buffer is disabled, events should route to _buffer_event_to_db.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=False, + ) + + # Set up metadata first + buffer.set_task_metadata("task-123", "session-abc", "user-xyz") + + # Mock the DB method + with patch.object(buffer, '_buffer_event_to_db', return_value=True) as mock_db: + with patch.object(buffer, '_buffer_event_hybrid', return_value=True) as mock_hybrid: + buffer.buffer_event( + task_id="task-123", + event_type="message", + event_data={"text": "Hello"}, + ) + + # Should call DB method, not hybrid + mock_db.assert_called_once() + mock_hybrid.assert_not_called() + + def test_routes_to_hybrid_when_hybrid_enabled(self): + """When hybrid buffer is enabled, events should route to _buffer_event_hybrid.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=True, + hybrid_flush_threshold=5, + ) + + # Set up metadata first + buffer.set_task_metadata("task-123", "session-abc", "user-xyz") + + with patch.object(buffer, '_buffer_event_to_db', return_value=True) as mock_db: + with patch.object(buffer, '_buffer_event_hybrid', return_value=True) as mock_hybrid: + buffer.buffer_event( + task_id="task-123", + event_type="message", + event_data={"text": "Hello"}, + ) + + # Should call hybrid method, not DB + mock_hybrid.assert_called_once() + mock_db.assert_not_called() + + +class TestRamBuffer: + """Tests for in-memory RAM buffer operations (hybrid mode).""" + + def test_events_stored_in_ram_buffer(self): + """Events should be stored in RAM buffer in hybrid mode.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=True, + hybrid_flush_threshold=100, # High threshold so no auto-flush + ) + + # Manually call the hybrid buffer method (bypassing metadata checks) + buffer._buffer_event_hybrid( + task_id="task-123", + event_type="message", + event_data={"text": "Event 1"}, + session_id="session-abc", + user_id="user-xyz", + ) + + buffer._buffer_event_hybrid( + task_id="task-123", + event_type="message", + event_data={"text": "Event 2"}, + session_id="session-abc", + user_id="user-xyz", + ) + + # Should have 2 events in RAM + assert buffer.get_ram_buffer_size("task-123") == 2 + + def 
test_ram_buffer_cleared_after_delete(self): + """RAM buffer should be cleared when delete_events_for_task is called.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + hybrid_mode_enabled=True, + hybrid_flush_threshold=100, + ) + + # Add events to RAM manually + buffer._ram_buffer["task-123"] = [ + ("message", {"text": "Event 1"}, 1000, "session-abc", "user-xyz"), + ("message", {"text": "Event 2"}, 2000, "session-abc", "user-xyz"), + ] + + assert buffer.get_ram_buffer_size("task-123") == 2 + + # Delete - mock the DB portion + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ): + # This will fail to delete from DB but should still clear RAM + count = buffer.delete_events_for_task("task-123") + + # RAM buffer should be empty + assert buffer.get_ram_buffer_size("task-123") == 0 + + +class TestTaskMetadata: + """Tests for task metadata storage (session_id, user_id for authorization).""" + + def test_set_and_get_metadata_from_cache(self): + """Metadata should be stored in cache and retrievable.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=None, # No DB needed for this test + enabled=True, + ) + + # Set metadata + buffer.set_task_metadata("task-123", "session-abc", "user-xyz") + + # Get metadata + metadata = buffer.get_task_metadata("task-123") + + assert metadata is not None + assert metadata["session_id"] == "session-abc" + assert metadata["user_id"] == "user-xyz" + + def test_clear_metadata_removes_from_cache(self): + """clear_task_metadata should remove from in-memory cache.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=None, + enabled=True, + ) + + buffer.set_task_metadata("task-123", "session-abc", "user-xyz") + assert "task-123" in buffer._task_metadata_cache + + buffer.clear_task_metadata("task-123") + assert "task-123" not in buffer._task_metadata_cache + + def test_metadata_survives_across_calls(self): + """Metadata should persist across multiple get calls.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=None, + enabled=True, + ) + + buffer.set_task_metadata("task-123", "session-abc", "user-xyz") + + # Multiple gets should all succeed + for _ in range(3): + metadata = buffer.get_task_metadata("task-123") + assert metadata["session_id"] == "session-abc" + + +class TestBufferEnabled: + """Tests for buffer enabled/disabled states.""" + + def test_is_enabled_when_configured(self): + """Buffer should report enabled when both flag and factory are set.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + ) + + assert buffer.is_enabled() is True + + def test_is_disabled_when_flag_false(self): + """Buffer should report disabled when enabled=False.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, 
+ ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + ) + + assert buffer.is_enabled() is False + + def test_is_disabled_when_no_factory(self): + """Buffer should report disabled when no session factory.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=None, + enabled=True, + ) + + assert buffer.is_enabled() is False + + def test_buffer_event_returns_false_when_disabled(self): + """buffer_event should return False when disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + ) + + result = buffer.buffer_event( + task_id="task-123", + event_type="message", + event_data={"text": "Hello"}, + ) + + assert result is False + + +class TestHybridModeDetection: + """Tests for hybrid mode enabled/disabled detection.""" + + def test_hybrid_mode_enabled(self): + """is_hybrid_mode_enabled should return True when properly configured.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=True, + ) + + assert buffer.is_hybrid_mode_enabled() is True + + def test_hybrid_mode_disabled_when_buffer_disabled(self): + """Hybrid mode should be disabled when buffer itself is disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + hybrid_mode_enabled=True, # Flag is true but buffer disabled + ) + + assert buffer.is_hybrid_mode_enabled() is False + + def test_hybrid_mode_disabled_by_default(self): + """Hybrid mode should be disabled by default.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + ) + + assert buffer.is_hybrid_mode_enabled() is False + + +class TestMissingMetadata: + """Tests for handling missing metadata scenarios.""" + + def test_buffer_event_fails_without_metadata(self): + """buffer_event should return False when no metadata is available.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + ) + + # Don't set metadata + result = buffer.buffer_event( + task_id="task-unknown", + event_type="message", + event_data={"text": "Hello"}, + ) + + # Should fail because no metadata + assert result is False + + def test_buffer_event_succeeds_with_explicit_ids(self): + """buffer_event should work when session_id and user_id are provided explicitly.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=False, + ) + + # Mock the DB method to not actually hit database + with patch.object(buffer, '_buffer_event_to_db', return_value=True): + result = buffer.buffer_event( + task_id="task-123", + event_type="message", + event_data={"text": "Hello"}, + session_id="session-abc", # Explicit + user_id="user-xyz", # Explicit + ) + + assert result is True + + +class 
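The flush tests that follow pin down the hybrid contract: accumulate in RAM, auto-flush once the threshold is reached, and put the batch back on failure. A toy model of that contract, deliberately ignoring locking and the double-persist risk of a partially flushed batch, neither of which these tests exercise:

```python
import time
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple

# (event_type, event_data, timestamp_ms, session_id, user_id), matching
# the tuple shape these tests seed into _ram_buffer.
BufferedEvent = Tuple[str, Dict[str, Any], int, str, str]

class MiniHybridBuffer:
    """Toy model of the tested contract only, not the real class."""

    def __init__(self, persist: Callable[[str, BufferedEvent], None], threshold: int) -> None:
        self._persist = persist
        self._threshold = threshold
        self._ram: Dict[str, List[BufferedEvent]] = defaultdict(list)

    def buffer_event(self, task_id: str, event_type: str, event_data: dict,
                     session_id: str, user_id: str) -> None:
        self._ram[task_id].append(
            (event_type, event_data, int(time.time() * 1000), session_id, user_id)
        )
        if len(self._ram[task_id]) >= self._threshold:
            self.flush_task_buffer(task_id)

    def flush_task_buffer(self, task_id: str) -> int:
        batch = self._ram.pop(task_id, [])
        try:
            for event in batch:
                self._persist(task_id, event)
            return len(batch)
        except Exception:
            self._ram[task_id] = batch + self._ram[task_id]  # put the batch back for retry
            return 0
```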
TestFlushTaskBuffer: + """Tests for flush_task_buffer RAM to DB flushing.""" + + def test_flush_when_hybrid_disabled_returns_zero(self): + """flush_task_buffer should return 0 when hybrid mode disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=False, + ) + + result = buffer.flush_task_buffer("task-123") + assert result == 0 + + def test_flush_empty_buffer_returns_zero(self): + """flush_task_buffer should return 0 when buffer is empty.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=True, + ) + + result = buffer.flush_task_buffer("nonexistent-task") + assert result == 0 + + def test_auto_flush_at_threshold(self): + """Events should auto-flush to DB when threshold reached.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + hybrid_mode_enabled=True, + hybrid_flush_threshold=3, + ) + + # Mock repository + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ) as MockRepo: + mock_repo_instance = Mock() + MockRepo.return_value = mock_repo_instance + + # Add 3 events (threshold) + for i in range(3): + buffer._buffer_event_hybrid( + task_id="task-123", + event_type="message", + event_data={"text": f"Event {i}"}, + session_id="session-abc", + user_id="user-xyz", + ) + + # Should have flushed (called buffer_event 3 times on repo) + assert mock_repo_instance.buffer_event.call_count == 3 + + def test_flush_failure_readds_to_buffer(self): + """Failed flush should re-add events to buffer for retry.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session_factory.side_effect = Exception("DB connection failed") + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + hybrid_mode_enabled=True, + hybrid_flush_threshold=100, + ) + + # Add events to RAM manually + buffer._ram_buffer["task-123"] = [ + ("message", {"text": "Event 1"}, 1000, "session-abc", "user-xyz"), + ("message", {"text": "Event 2"}, 2000, "session-abc", "user-xyz"), + ] + + # Flush should fail + result = buffer.flush_task_buffer("task-123") + assert result == 0 + + # Events should be back in buffer + assert buffer.get_ram_buffer_size("task-123") == 2 + + +class TestFlushAllBuffers: + """Tests for flush_all_buffers batch flushing.""" + + def test_flush_all_when_hybrid_disabled_returns_zero(self): + """flush_all_buffers should return 0 when hybrid mode disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=False, + ) + + result = buffer.flush_all_buffers() + assert result == 0 + + def test_flush_all_multiple_tasks(self): + """flush_all_buffers should flush all tasks.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + 
PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + hybrid_mode_enabled=True, + hybrid_flush_threshold=100, + ) + + # Add events to RAM for multiple tasks + buffer._ram_buffer["task-1"] = [ + ("message", {"text": "Event 1"}, 1000, "session-abc", "user-xyz"), + ] + buffer._ram_buffer["task-2"] = [ + ("message", {"text": "Event 2"}, 2000, "session-abc", "user-xyz"), + ("message", {"text": "Event 3"}, 3000, "session-abc", "user-xyz"), + ] + + # Mock repository + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ) as MockRepo: + mock_repo_instance = Mock() + MockRepo.return_value = mock_repo_instance + + result = buffer.flush_all_buffers() + + # Should have flushed 3 total events (1 + 2) + assert result == 3 + assert mock_repo_instance.buffer_event.call_count == 3 + + +class TestGetBufferedEvents: + """Tests for get_buffered_events retrieval.""" + + def test_get_buffered_events_when_disabled_returns_empty(self): + """get_buffered_events should return empty list when disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + ) + + result = buffer.get_buffered_events("task-123") + assert result == [] + + def test_get_buffered_events_flushes_ram_first_in_hybrid(self): + """get_buffered_events should flush RAM to DB first in hybrid mode.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + hybrid_mode_enabled=True, + ) + + with patch.object(buffer, 'flush_task_buffer') as mock_flush: + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ) as MockRepo: + mock_repo_instance = Mock() + mock_repo_instance.get_buffered_events.return_value = [] + MockRepo.return_value = mock_repo_instance + + buffer.get_buffered_events("task-123") + + # flush_task_buffer should be called first + mock_flush.assert_called_once_with("task-123") + + +class TestHasUnconsumedEvents: + """Tests for has_unconsumed_events checking.""" + + def test_has_unconsumed_when_disabled_returns_false(self): + """has_unconsumed_events should return False when disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + ) + + result = buffer.has_unconsumed_events("task-123") + assert result is False + + def test_has_unconsumed_checks_ram_in_hybrid_mode(self): + """has_unconsumed_events should check RAM buffer first in hybrid mode.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=True, + hybrid_mode_enabled=True, + ) + + # Add events to RAM + buffer._ram_buffer["task-123"] = [ + ("message", {"text": "Event 1"}, 1000, "session-abc", "user-xyz"), + ] + + # Should return True without hitting DB because RAM has events + result = 
buffer.has_unconsumed_events("task-123") + assert result is True + + def test_has_unconsumed_checks_db_when_ram_empty(self): + """has_unconsumed_events should check DB when RAM is empty in hybrid mode.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + hybrid_mode_enabled=True, + ) + + # Mock repository + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ) as MockRepo: + mock_repo_instance = Mock() + mock_repo_instance.has_unconsumed_events.return_value = True + MockRepo.return_value = mock_repo_instance + + result = buffer.has_unconsumed_events("task-123") + + # Should check DB + mock_repo_instance.has_unconsumed_events.assert_called_once() + assert result is True + + +class TestGetUnconsumedEventsForSession: + """Tests for get_unconsumed_events_for_session session-level retrieval.""" + + def test_get_unconsumed_for_session_when_disabled_returns_empty(self): + """get_unconsumed_events_for_session should return empty dict when disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + ) + + result = buffer.get_unconsumed_events_for_session("session-abc") + assert result == {} + + def test_get_unconsumed_for_session_groups_by_task(self): + """get_unconsumed_events_for_session should group events by task_id.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + ) + + # Mock repository to return events from multiple tasks + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ) as MockRepo: + # Create mock event objects + mock_event_1 = Mock() + mock_event_1.task_id = "task-1" + mock_event_1.event_type = "message" + mock_event_1.event_data = {"text": "Hello"} + mock_event_1.event_sequence = 1 + mock_event_1.created_at = 1000 + + mock_event_2 = Mock() + mock_event_2.task_id = "task-1" + mock_event_2.event_type = "message" + mock_event_2.event_data = {"text": "World"} + mock_event_2.event_sequence = 2 + mock_event_2.created_at = 2000 + + mock_event_3 = Mock() + mock_event_3.task_id = "task-2" + mock_event_3.event_type = "artifact" + mock_event_3.event_data = {"name": "file.txt"} + mock_event_3.event_sequence = 1 + mock_event_3.created_at = 3000 + + mock_repo_instance = Mock() + mock_repo_instance.get_unconsumed_events_for_session.return_value = [ + mock_event_1, mock_event_2, mock_event_3 + ] + MockRepo.return_value = mock_repo_instance + + result = buffer.get_unconsumed_events_for_session("session-abc") + + # Should have 2 tasks + assert len(result) == 2 + assert "task-1" in result + assert "task-2" in result + + # task-1 should have 2 events + assert len(result["task-1"]) == 2 + # task-2 should have 1 event + assert len(result["task-2"]) == 1 + + +class TestDeleteEventsForTask: + """Tests for delete_events_for_task cleanup.""" + + def test_delete_when_disabled_returns_zero(self): + 
"""delete_events_for_task should return 0 when disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + ) + + result = buffer.delete_events_for_task("task-123") + assert result == 0 + + def test_delete_clears_metadata(self): + """delete_events_for_task should clear task metadata cache.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + ) + + # Set metadata + buffer.set_task_metadata("task-123", "session-abc", "user-xyz") + assert buffer.get_task_metadata("task-123") is not None + + # Mock repository + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ) as MockRepo: + mock_repo_instance = Mock() + mock_repo_instance.delete_events_for_task.return_value = 5 + MockRepo.return_value = mock_repo_instance + + buffer.delete_events_for_task("task-123") + + # Metadata should be cleared + assert buffer.get_task_metadata("task-123") is None + + +class TestCleanupOldEvents: + """Tests for cleanup_old_events scheduled cleanup.""" + + def test_cleanup_when_disabled_returns_zero(self): + """cleanup_old_events should return 0 when disabled.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + buffer = PersistentSSEEventBuffer( + session_factory=Mock(), + enabled=False, + ) + + result = buffer.cleanup_old_events(days=7) + assert result == 0 + + def test_cleanup_calls_repository(self): + """cleanup_old_events should call repository cleanup method.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + ) + + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.sse_event_buffer_repository.SSEEventBufferRepository' + ) as MockRepo: + with patch( + 'solace_agent_mesh.shared.utils.timestamp_utils.now_epoch_ms', + return_value=1000000000 # Mock timestamp + ): + mock_repo_instance = Mock() + mock_repo_instance.cleanup_consumed_events.return_value = 42 + MockRepo.return_value = mock_repo_instance + + result = buffer.cleanup_old_events(days=7) + + # Should call repository cleanup + mock_repo_instance.cleanup_consumed_events.assert_called_once() + assert result == 42 + + +class TestGetTaskMetadataDbFallback: + """Tests for get_task_metadata database fallback.""" + + def test_metadata_fallback_to_db_when_not_in_cache(self): + """get_task_metadata should fall back to DB when not in cache.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + ) + + # Mock TaskRepository + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository' + ) as MockRepo: + mock_task = Mock() + mock_task.session_id = "session-from-db" + 
mock_task.user_id = "user-from-db" + + mock_repo_instance = Mock() + mock_repo_instance.find_by_id.return_value = mock_task + MockRepo.return_value = mock_repo_instance + + # Should return data from DB + metadata = buffer.get_task_metadata("task-123") + + assert metadata is not None + assert metadata["session_id"] == "session-from-db" + assert metadata["user_id"] == "user-from-db" + + # Should now be in cache + assert "task-123" in buffer._task_metadata_cache + + def test_metadata_db_fallback_returns_none_if_not_found(self): + """get_task_metadata should return None if not in cache or DB.""" + from solace_agent_mesh.gateway.http_sse.persistent_sse_event_buffer import ( + PersistentSSEEventBuffer, + ) + + mock_session_factory = Mock() + mock_session = Mock() + mock_session_factory.return_value = mock_session + + buffer = PersistentSSEEventBuffer( + session_factory=mock_session_factory, + enabled=True, + ) + + # Mock TaskRepository to return None + with patch( + 'solace_agent_mesh.gateway.http_sse.repository.task_repository.TaskRepository' + ) as MockRepo: + mock_repo_instance = Mock() + mock_repo_instance.find_by_id.return_value = None + MockRepo.return_value = mock_repo_instance + + metadata = buffer.get_task_metadata("task-unknown") + assert metadata is None