diff --git a/src-tauri/src/bin/codex_monitor_daemon.rs b/src-tauri/src/bin/codex_monitor_daemon.rs index 59bfc00bc..b4bc42b4d 100644 --- a/src-tauri/src/bin/codex_monitor_daemon.rs +++ b/src-tauri/src/bin/codex_monitor_daemon.rs @@ -786,6 +786,15 @@ impl DaemonState { codex_core::archive_thread_core(&self.sessions, workspace_id, thread_id).await } + async fn rollback_thread( + &self, + workspace_id: String, + thread_id: String, + turn_id: String, + ) -> Result { + codex_core::rollback_thread_core(&self.sessions, workspace_id, thread_id, turn_id).await + } + async fn compact_thread( &self, workspace_id: String, diff --git a/src-tauri/src/bin/codex_monitor_daemon/rpc/codex.rs b/src-tauri/src/bin/codex_monitor_daemon/rpc/codex.rs index cc278938e..9a8803d7e 100644 --- a/src-tauri/src/bin/codex_monitor_daemon/rpc/codex.rs +++ b/src-tauri/src/bin/codex_monitor_daemon/rpc/codex.rs @@ -130,6 +130,21 @@ pub(super) async fn try_handle( }; Some(state.archive_thread(workspace_id, thread_id).await) } + "rollback_thread" => { + let workspace_id = match parse_string(params, "workspaceId") { + Ok(value) => value, + Err(err) => return Some(Err(err)), + }; + let thread_id = match parse_string(params, "threadId") { + Ok(value) => value, + Err(err) => return Some(Err(err)), + }; + let turn_id = match parse_string(params, "turnId") { + Ok(value) => value, + Err(err) => return Some(Err(err)), + }; + Some(state.rollback_thread(workspace_id, thread_id, turn_id).await) + } "compact_thread" => { let workspace_id = match parse_string(params, "workspaceId") { Ok(value) => value, diff --git a/src-tauri/src/codex/mod.rs b/src-tauri/src/codex/mod.rs index e55d1e9ee..d9bc9a2de 100644 --- a/src-tauri/src/codex/mod.rs +++ b/src-tauri/src/codex/mod.rs @@ -294,6 +294,27 @@ pub(crate) async fn archive_thread( codex_core::archive_thread_core(&state.sessions, workspace_id, thread_id).await } +#[tauri::command] +pub(crate) async fn rollback_thread( + workspace_id: String, + thread_id: String, + turn_id: String, + state: State<'_, AppState>, + app: AppHandle, +) -> Result { + if remote_backend::is_remote_mode(&*state).await { + return remote_backend::call_remote( + &*state, + app, + "rollback_thread", + json!({ "workspaceId": workspace_id, "threadId": thread_id, "turnId": turn_id }), + ) + .await; + } + + codex_core::rollback_thread_core(&state.sessions, workspace_id, thread_id, turn_id).await +} + #[tauri::command] pub(crate) async fn compact_thread( workspace_id: String, diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 83e9dacae..49aabc462 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -226,6 +226,7 @@ pub fn run() { codex::list_threads, codex::list_mcp_server_status, codex::archive_thread, + codex::rollback_thread, codex::compact_thread, codex::set_thread_name, codex::collaboration_mode_list, diff --git a/src-tauri/src/shared/codex_core.rs b/src-tauri/src/shared/codex_core.rs index a0f2a4ea6..38a4bf6b1 100644 --- a/src-tauri/src/shared/codex_core.rs +++ b/src-tauri/src/shared/codex_core.rs @@ -372,6 +372,57 @@ pub(crate) async fn archive_thread_core( .await } +pub(crate) async fn rollback_thread_core( + sessions: &Mutex>>, + workspace_id: String, + thread_id: String, + turn_id: String, +) -> Result { + let session = get_session_clone(sessions, &workspace_id).await?; + let thread_response = read_thread_core(sessions, workspace_id.clone(), thread_id.clone()).await?; + let num_turns = rollback_num_turns_from_response(&thread_response, &turn_id)?; + let params = json!({ "threadId": thread_id, "numTurns": num_turns }); + session + .send_request_for_workspace(&workspace_id, "thread/rollback", params) + .await +} + +fn rollback_num_turns_from_response(response: &Value, turn_id: &str) -> Result { + let thread = extract_thread_from_response(response) + .ok_or_else(|| "Rollback failed: thread/read response missing thread payload.".to_string())?; + let turns = thread + .get("turns") + .and_then(Value::as_array) + .ok_or_else(|| "Rollback failed: thread/read response missing turns.".to_string())?; + let turn_index = turns + .iter() + .position(|turn| { + turn.as_object().is_some_and(|record| { + record + .get("id") + .or_else(|| record.get("turnId")) + .or_else(|| record.get("turn_id")) + .and_then(Value::as_str) + .is_some_and(|value| value == turn_id) + }) + }) + .ok_or_else(|| format!("Rollback failed: turn '{turn_id}' was not found in thread."))?; + Ok(turns.len().saturating_sub(turn_index)) +} + +fn extract_thread_from_response<'a>(response: &'a Value) -> Option<&'a Map> { + response + .as_object() + .and_then(|record| { + record + .get("result") + .and_then(Value::as_object) + .and_then(|result| result.get("thread")) + .or_else(|| record.get("thread")) + }) + .and_then(Value::as_object) +} + pub(crate) async fn compact_thread_core( sessions: &Mutex>>, workspace_id: String, @@ -1030,4 +1081,51 @@ mod tests { assert!(THREAD_LIST_SOURCE_KINDS.contains(&"subAgentCompact")); assert!(THREAD_LIST_SOURCE_KINDS.contains(&"subAgentThreadSpawn")); } + + #[test] + fn rollback_num_turns_counts_from_target_turn_to_end() { + let response = json!({ + "result": { + "thread": { + "turns": [ + { "id": "turn-1" }, + { "id": "turn-2" }, + { "id": "turn-3" } + ] + } + } + }); + + let num_turns = rollback_num_turns_from_response(&response, "turn-2").unwrap(); + assert_eq!(num_turns, 2); + } + + #[test] + fn rollback_num_turns_supports_turn_id_aliases() { + let response = json!({ + "thread": { + "turns": [ + { "turn_id": "turn-1" }, + { "turnId": "turn-2" } + ] + } + }); + + let num_turns = rollback_num_turns_from_response(&response, "turn-2").unwrap(); + assert_eq!(num_turns, 1); + } + + #[test] + fn rollback_num_turns_errors_when_turn_is_missing() { + let response = json!({ + "result": { + "thread": { + "turns": [{ "id": "turn-1" }] + } + } + }); + + let err = rollback_num_turns_from_response(&response, "turn-9").unwrap_err(); + assert!(err.contains("turn-9")); + } } diff --git a/src/features/app/components/MainApp.tsx b/src/features/app/components/MainApp.tsx index 1ce7cc587..255da6202 100644 --- a/src/features/app/components/MainApp.tsx +++ b/src/features/app/components/MainApp.tsx @@ -66,6 +66,7 @@ import { useRemoteThreadLiveConnection } from "@app/hooks/useRemoteThreadLiveCon import { useTrayRecentThreads } from "@app/hooks/useTrayRecentThreads"; import { useTraySessionUsage } from "@app/hooks/useTraySessionUsage"; import { useTauriEvent } from "@app/hooks/useTauriEvent"; +import { useMessageEdit } from "@/features/messages/hooks/useMessageEdit"; import { useAppBootstrapOrchestration } from "@app/bootstrap/useAppBootstrapOrchestration"; import { useThreadCodexBootstrapOrchestration, @@ -494,6 +495,7 @@ export default function MainApp() { handleUserInputSubmit, refreshAccountInfo, refreshAccountRateLimits, + editAndRegenerateMessage, } = useThreads({ activeWorkspace, onWorkspaceConnected: markWorkspaceConnected, @@ -516,6 +518,15 @@ export default function MainApp() { threadSortKey: threadListSortKey, onThreadCodexMetadataDetected: handleThreadCodexMetadataDetected, }); + + const messageEditState = useMessageEdit({ + onRegenerate: async (itemId, newText, images) => { + if (!activeWorkspace || !activeThreadId) { + return; + } + await editAndRegenerateMessage(activeWorkspace, activeThreadId, itemId, newText, images); + }, + }); const { connectionState: remoteThreadConnectionState, reconnectLive } = useRemoteThreadLiveConnection({ backendMode: appSettings.backendMode, @@ -1648,6 +1659,7 @@ export default function MainApp() { promptActions, worktreeState, sidebarHandlers: sidebarMenuOrchestration, + messageEditState, displayNodes, threadPinning: { pinThread, diff --git a/src/features/app/hooks/useAppServerEvents.test.tsx b/src/features/app/hooks/useAppServerEvents.test.tsx index 92bfb98e3..4e448ee75 100644 --- a/src/features/app/hooks/useAppServerEvents.test.tsx +++ b/src/features/app/hooks/useAppServerEvents.test.tsx @@ -61,6 +61,7 @@ describe("useAppServerEvents", () => { onPlanDelta: vi.fn(), onApprovalRequest: vi.fn(), onRequestUserInput: vi.fn(), + onItemStarted: vi.fn(), onItemCompleted: vi.fn(), onAgentMessageCompleted: vi.fn(), onAccountRateLimitsUpdated: vi.fn(), @@ -324,6 +325,7 @@ describe("useAppServerEvents", () => { method: "item/completed", params: { threadId: "thread-1", + turnId: "turn-2", item: { type: "agentMessage", id: "item-2", text: "Done" }, }, }, @@ -332,6 +334,7 @@ describe("useAppServerEvents", () => { expect(handlers.onItemCompleted).toHaveBeenCalledWith("ws-1", "thread-1", { type: "agentMessage", id: "item-2", + turnId: "turn-2", text: "Done", }); expect(handlers.onAgentMessageCompleted).toHaveBeenCalledWith({ @@ -341,6 +344,25 @@ describe("useAppServerEvents", () => { text: "Done", }); + act(() => { + listener?.({ + workspace_id: "ws-1", + message: { + method: "item/started", + params: { + threadId: "thread-1", + turnId: "turn-3", + item: { type: "userMessage", id: "item-3" }, + }, + }, + }); + }); + expect(handlers.onItemStarted).toHaveBeenCalledWith("ws-1", "thread-1", { + type: "userMessage", + id: "item-3", + turnId: "turn-3", + }); + act(() => { listener?.({ workspace_id: "ws-1", diff --git a/src/features/app/hooks/useAppServerEvents.ts b/src/features/app/hooks/useAppServerEvents.ts index 6f21dafb4..012bbdbf9 100644 --- a/src/features/app/hooks/useAppServerEvents.ts +++ b/src/features/app/hooks/useAppServerEvents.ts @@ -467,8 +467,10 @@ export function useAppServerEvents(handlers: AppServerEventHandlers) { if (method === "item/completed") { const threadId = String(params.threadId ?? params.thread_id ?? ""); const item = params.item as Record | undefined; + const turnId = String(params.turnId ?? params.turn_id ?? "").trim(); if (threadId && item) { - currentHandlers.onItemCompleted?.(workspace_id, threadId, item); + const itemWithTurnId = turnId ? { ...item, turnId } : item; + currentHandlers.onItemCompleted?.(workspace_id, threadId, itemWithTurnId); } if (threadId && item?.type === "agentMessage") { const itemId = String(item.id ?? ""); @@ -488,8 +490,10 @@ export function useAppServerEvents(handlers: AppServerEventHandlers) { if (method === "item/started") { const threadId = String(params.threadId ?? params.thread_id ?? ""); const item = params.item as Record | undefined; + const turnId = String(params.turnId ?? params.turn_id ?? "").trim(); if (threadId && item) { - currentHandlers.onItemStarted?.(workspace_id, threadId, item); + const itemWithTurnId = turnId ? { ...item, turnId } : item; + currentHandlers.onItemStarted?.(workspace_id, threadId, itemWithTurnId); } return; } diff --git a/src/features/app/hooks/useMainAppLayoutSurfaces.ts b/src/features/app/hooks/useMainAppLayoutSurfaces.ts index b6ac05279..883d2120e 100644 --- a/src/features/app/hooks/useMainAppLayoutSurfaces.ts +++ b/src/features/app/hooks/useMainAppLayoutSurfaces.ts @@ -1,5 +1,6 @@ import type { RefObject } from "react"; -import type { AppSettings, ComposerEditorSettings, WorkspaceInfo } from "@/types"; +import type { AppSettings, ComposerEditorSettings, ConversationItem, WorkspaceInfo } from "@/types"; +import type { UseMessageEditResult } from "@/features/messages/hooks/useMessageEdit"; import type { ThreadState } from "@/features/threads/hooks/useThreadsReducer"; import type { WorkspaceLaunchScriptsState } from "@app/hooks/useWorkspaceLaunchScripts"; import { REMOTE_THREAD_POLL_INTERVAL_MS } from "@app/hooks/useRemoteThreadRefreshOnFocus"; @@ -225,6 +226,7 @@ type UseMainAppLayoutSurfacesArgs = { dismissErrorToast: LayoutNodesOptions["primary"]["errorToastsProps"]["onDismiss"]; showDebugButton: boolean; handleDebugClick: () => void; + messageEditState?: UseMessageEditResult; }; type MainAppLayoutSurfacesContext = UseMainAppLayoutSurfacesArgs & { @@ -374,6 +376,7 @@ function buildPrimarySurface({ dismissErrorToast, showDebugButton, handleDebugClick, + messageEditState, }: MainAppLayoutSurfacesContext): LayoutNodesOptions["primary"] { return { sidebarProps: { @@ -465,6 +468,16 @@ function buildPrimarySurface({ : null, showPollingFetchStatus: showMobilePollingFetchStatus, pollingIntervalMs: REMOTE_THREAD_POLL_INTERVAL_MS, + editingItemId: messageEditState?.editingItemId, + editText: messageEditState?.editText, + isConfirmingEdit: messageEditState?.isConfirming, + isRegeneratingEdit: messageEditState?.isRegenerating, + onStartEdit: messageEditState?.startEdit ? (item: Extract) => messageEditState.startEdit(item.id, item.text, item.images) : undefined, + onCancelEdit: messageEditState?.cancelEdit, + onUpdateEditText: messageEditState?.updateEditText, + onRequestRegenerate: messageEditState?.requestRegenerate, + onCancelConfirm: messageEditState?.cancelConfirm, + onExecuteRegenerate: messageEditState?.executeRegenerate, }, composerProps: composerWorkspaceState.showComposer ? { @@ -1098,6 +1111,7 @@ export function useMainAppLayoutSurfaces({ dismissErrorToast, showDebugButton, handleDebugClick, + messageEditState, }: UseMainAppLayoutSurfacesArgs): LayoutNodesOptions { const sidebarRateLimits = activeWorkspace ? activeRateLimits : homeRateLimits; const sidebarAccount = activeWorkspace ? activeAccount : homeAccount; @@ -1260,6 +1274,7 @@ export function useMainAppLayoutSurfaces({ dismissErrorToast, showDebugButton, handleDebugClick, + messageEditState, sidebarRateLimits, sidebarAccount, }; diff --git a/src/features/messages/components/MessageRows.tsx b/src/features/messages/components/MessageRows.tsx index 9e2848606..3310e8663 100644 --- a/src/features/messages/components/MessageRows.tsx +++ b/src/features/messages/components/MessageRows.tsx @@ -8,9 +8,11 @@ import Diff from "lucide-react/dist/esm/icons/diff"; import FileDiffIcon from "lucide-react/dist/esm/icons/file-diff"; import FileText from "lucide-react/dist/esm/icons/file-text"; import Image from "lucide-react/dist/esm/icons/image"; +import Pencil from "lucide-react/dist/esm/icons/pencil"; import Quote from "lucide-react/dist/esm/icons/quote"; import Search from "lucide-react/dist/esm/icons/search"; import Terminal from "lucide-react/dist/esm/icons/terminal"; +import TriangleAlert from "lucide-react/dist/esm/icons/triangle-alert"; import Users from "lucide-react/dist/esm/icons/users"; import Wrench from "lucide-react/dist/esm/icons/wrench"; import X from "lucide-react/dist/esm/icons/x"; @@ -61,6 +63,16 @@ type MessageRowProps = MarkdownFileLinkProps & { onCopy: (item: Extract) => void; onQuote?: (item: Extract, selectedText?: string) => void; codeBlockCopyUseModifier?: boolean; + isEditing?: boolean; + editText?: string; + isConfirming?: boolean; + isRegenerating?: boolean; + onStartEdit?: (item: Extract) => void; + onCancelEdit?: () => void; + onUpdateEditText?: (text: string) => void; + onRequestRegenerate?: () => void; + onCancelConfirm?: () => void; + onExecuteRegenerate?: () => void; }; type ReasoningRowProps = MarkdownFileLinkProps & { @@ -377,11 +389,24 @@ export const MessageRow = memo(function MessageRow({ onOpenFileLink, onOpenFileLinkMenu, onOpenThreadLink, + isEditing = false, + editText = "", + isConfirming = false, + isRegenerating = false, + onStartEdit, + onCancelEdit, + onUpdateEditText, + onRequestRegenerate, + onCancelConfirm, + onExecuteRegenerate, }: MessageRowProps) { const [lightboxIndex, setLightboxIndex] = useState(null); const bubbleRef = useRef(null); + const editTextareaRef = useRef(null); const selectionSnapshotRef = useRef(null); const hasText = item.text.trim().length > 0; + const isUserMessage = item.role === "user"; + const canEdit = isUserMessage && Boolean(onStartEdit) && !isRegenerating; const imageItems = useMemo(() => { if (!item.images || item.images.length === 0) { return []; @@ -402,6 +427,28 @@ export const MessageRow = memo(function MessageRow({ imageItems.length === 0 && isStandaloneMarkdownTable(item.text); + useEffect(() => { + if (isEditing && editTextareaRef.current) { + const textarea = editTextareaRef.current; + textarea.focus(); + textarea.setSelectionRange(textarea.value.length, textarea.value.length); + } + }, [isEditing]); + + const handleEditKeyDown = useCallback( + (event: React.KeyboardEvent) => { + if (event.key === "Escape") { + event.preventDefault(); + if (isConfirming) { + onCancelConfirm?.(); + } else { + onCancelEdit?.(); + } + } + }, + [isConfirming, onCancelConfirm, onCancelEdit], + ); + const getSelectedMessageText = useCallback(() => { const bubble = bubbleRef.current; const selection = window.getSelection(); @@ -422,7 +469,7 @@ export const MessageRow = memo(function MessageRow({ return false; } const element = node instanceof Element ? node : node.parentElement; - return Boolean(element?.closest(".message-quote-button, .message-copy-button")); + return Boolean(element?.closest(".message-quote-button, .message-copy-button, .message-edit-button")); }; if (isWithinMessageControls(selection.anchorNode) || isWithinMessageControls(selection.focusNode)) { @@ -440,6 +487,84 @@ export const MessageRow = memo(function MessageRow({ onQuote(item, selectedText); }, [getSelectedMessageText, item, onQuote]); + if (isEditing) { + return ( +
+
+ {imageItems.length > 0 && ( + + )} +