Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src-tauri/src/bin/codex_monitor_daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,15 @@ impl DaemonState {
codex_core::archive_thread_core(&self.sessions, workspace_id, thread_id).await
}

async fn rollback_thread(
&self,
workspace_id: String,
thread_id: String,
turn_id: String,
) -> Result<Value, String> {
codex_core::rollback_thread_core(&self.sessions, workspace_id, thread_id, turn_id).await
}

async fn compact_thread(
&self,
workspace_id: String,
Expand Down
15 changes: 15 additions & 0 deletions src-tauri/src/bin/codex_monitor_daemon/rpc/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,21 @@ pub(super) async fn try_handle(
};
Some(state.archive_thread(workspace_id, thread_id).await)
}
"rollback_thread" => {
let workspace_id = match parse_string(params, "workspaceId") {
Ok(value) => value,
Err(err) => return Some(Err(err)),
};
let thread_id = match parse_string(params, "threadId") {
Ok(value) => value,
Err(err) => return Some(Err(err)),
};
let turn_id = match parse_string(params, "turnId") {
Ok(value) => value,
Err(err) => return Some(Err(err)),
};
Some(state.rollback_thread(workspace_id, thread_id, turn_id).await)
}
"compact_thread" => {
let workspace_id = match parse_string(params, "workspaceId") {
Ok(value) => value,
Expand Down
21 changes: 21 additions & 0 deletions src-tauri/src/codex/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,27 @@ pub(crate) async fn archive_thread(
codex_core::archive_thread_core(&state.sessions, workspace_id, thread_id).await
}

#[tauri::command]
pub(crate) async fn rollback_thread(
workspace_id: String,
thread_id: String,
turn_id: String,
state: State<'_, AppState>,
app: AppHandle,
) -> Result<Value, String> {
if remote_backend::is_remote_mode(&*state).await {
return remote_backend::call_remote(
&*state,
app,
"rollback_thread",
json!({ "workspaceId": workspace_id, "threadId": thread_id, "turnId": turn_id }),
)
.await;
}

codex_core::rollback_thread_core(&state.sessions, workspace_id, thread_id, turn_id).await
}

#[tauri::command]
pub(crate) async fn compact_thread(
workspace_id: String,
Expand Down
1 change: 1 addition & 0 deletions src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ pub fn run() {
codex::list_threads,
codex::list_mcp_server_status,
codex::archive_thread,
codex::rollback_thread,
codex::compact_thread,
codex::set_thread_name,
codex::collaboration_mode_list,
Expand Down
98 changes: 98 additions & 0 deletions src-tauri/src/shared/codex_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,57 @@ pub(crate) async fn archive_thread_core(
.await
}

pub(crate) async fn rollback_thread_core(
sessions: &Mutex<HashMap<String, Arc<WorkspaceSession>>>,
workspace_id: String,
thread_id: String,
turn_id: String,
) -> Result<Value, String> {
let session = get_session_clone(sessions, &workspace_id).await?;
let thread_response = read_thread_core(sessions, workspace_id.clone(), thread_id.clone()).await?;
let num_turns = rollback_num_turns_from_response(&thread_response, &turn_id)?;
let params = json!({ "threadId": thread_id, "numTurns": num_turns });
session
.send_request_for_workspace(&workspace_id, "thread/rollback", params)
.await
}

fn rollback_num_turns_from_response(response: &Value, turn_id: &str) -> Result<usize, String> {
let thread = extract_thread_from_response(response)
.ok_or_else(|| "Rollback failed: thread/read response missing thread payload.".to_string())?;
let turns = thread
.get("turns")
.and_then(Value::as_array)
.ok_or_else(|| "Rollback failed: thread/read response missing turns.".to_string())?;
let turn_index = turns
.iter()
.position(|turn| {
turn.as_object().is_some_and(|record| {
record
.get("id")
.or_else(|| record.get("turnId"))
.or_else(|| record.get("turn_id"))
.and_then(Value::as_str)
.is_some_and(|value| value == turn_id)
})
})
.ok_or_else(|| format!("Rollback failed: turn '{turn_id}' was not found in thread."))?;
Ok(turns.len().saturating_sub(turn_index))
}

fn extract_thread_from_response<'a>(response: &'a Value) -> Option<&'a Map<String, Value>> {
response
.as_object()
.and_then(|record| {
record
.get("result")
.and_then(Value::as_object)
.and_then(|result| result.get("thread"))
.or_else(|| record.get("thread"))
})
.and_then(Value::as_object)
}

pub(crate) async fn compact_thread_core(
sessions: &Mutex<HashMap<String, Arc<WorkspaceSession>>>,
workspace_id: String,
Expand Down Expand Up @@ -1030,4 +1081,51 @@ mod tests {
assert!(THREAD_LIST_SOURCE_KINDS.contains(&"subAgentCompact"));
assert!(THREAD_LIST_SOURCE_KINDS.contains(&"subAgentThreadSpawn"));
}

#[test]
fn rollback_num_turns_counts_from_target_turn_to_end() {
let response = json!({
"result": {
"thread": {
"turns": [
{ "id": "turn-1" },
{ "id": "turn-2" },
{ "id": "turn-3" }
]
}
}
});

let num_turns = rollback_num_turns_from_response(&response, "turn-2").unwrap();
assert_eq!(num_turns, 2);
}

#[test]
fn rollback_num_turns_supports_turn_id_aliases() {
let response = json!({
"thread": {
"turns": [
{ "turn_id": "turn-1" },
{ "turnId": "turn-2" }
]
}
});

let num_turns = rollback_num_turns_from_response(&response, "turn-2").unwrap();
assert_eq!(num_turns, 1);
}

#[test]
fn rollback_num_turns_errors_when_turn_is_missing() {
let response = json!({
"result": {
"thread": {
"turns": [{ "id": "turn-1" }]
}
}
});

let err = rollback_num_turns_from_response(&response, "turn-9").unwrap_err();
assert!(err.contains("turn-9"));
}
}
12 changes: 12 additions & 0 deletions src/features/app/components/MainApp.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import { useRemoteThreadLiveConnection } from "@app/hooks/useRemoteThreadLiveCon
import { useTrayRecentThreads } from "@app/hooks/useTrayRecentThreads";
import { useTraySessionUsage } from "@app/hooks/useTraySessionUsage";
import { useTauriEvent } from "@app/hooks/useTauriEvent";
import { useMessageEdit } from "@/features/messages/hooks/useMessageEdit";
import { useAppBootstrapOrchestration } from "@app/bootstrap/useAppBootstrapOrchestration";
import {
useThreadCodexBootstrapOrchestration,
Expand Down Expand Up @@ -494,6 +495,7 @@ export default function MainApp() {
handleUserInputSubmit,
refreshAccountInfo,
refreshAccountRateLimits,
editAndRegenerateMessage,
} = useThreads({
activeWorkspace,
onWorkspaceConnected: markWorkspaceConnected,
Expand All @@ -516,6 +518,15 @@ export default function MainApp() {
threadSortKey: threadListSortKey,
onThreadCodexMetadataDetected: handleThreadCodexMetadataDetected,
});

const messageEditState = useMessageEdit({
onRegenerate: async (itemId, newText, images) => {
if (!activeWorkspace || !activeThreadId) {
return;
}
await editAndRegenerateMessage(activeWorkspace, activeThreadId, itemId, newText, images);
},
});
const { connectionState: remoteThreadConnectionState, reconnectLive } =
useRemoteThreadLiveConnection({
backendMode: appSettings.backendMode,
Expand Down Expand Up @@ -1648,6 +1659,7 @@ export default function MainApp() {
promptActions,
worktreeState,
sidebarHandlers: sidebarMenuOrchestration,
messageEditState,
displayNodes,
threadPinning: {
pinThread,
Expand Down
22 changes: 22 additions & 0 deletions src/features/app/hooks/useAppServerEvents.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ describe("useAppServerEvents", () => {
onPlanDelta: vi.fn(),
onApprovalRequest: vi.fn(),
onRequestUserInput: vi.fn(),
onItemStarted: vi.fn(),
onItemCompleted: vi.fn(),
onAgentMessageCompleted: vi.fn(),
onAccountRateLimitsUpdated: vi.fn(),
Expand Down Expand Up @@ -324,6 +325,7 @@ describe("useAppServerEvents", () => {
method: "item/completed",
params: {
threadId: "thread-1",
turnId: "turn-2",
item: { type: "agentMessage", id: "item-2", text: "Done" },
},
},
Expand All @@ -332,6 +334,7 @@ describe("useAppServerEvents", () => {
expect(handlers.onItemCompleted).toHaveBeenCalledWith("ws-1", "thread-1", {
type: "agentMessage",
id: "item-2",
turnId: "turn-2",
text: "Done",
});
expect(handlers.onAgentMessageCompleted).toHaveBeenCalledWith({
Expand All @@ -341,6 +344,25 @@ describe("useAppServerEvents", () => {
text: "Done",
});

act(() => {
listener?.({
workspace_id: "ws-1",
message: {
method: "item/started",
params: {
threadId: "thread-1",
turnId: "turn-3",
item: { type: "userMessage", id: "item-3" },
},
},
});
});
expect(handlers.onItemStarted).toHaveBeenCalledWith("ws-1", "thread-1", {
type: "userMessage",
id: "item-3",
turnId: "turn-3",
});

act(() => {
listener?.({
workspace_id: "ws-1",
Expand Down
8 changes: 6 additions & 2 deletions src/features/app/hooks/useAppServerEvents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,10 @@ export function useAppServerEvents(handlers: AppServerEventHandlers) {
if (method === "item/completed") {
const threadId = String(params.threadId ?? params.thread_id ?? "");
const item = params.item as Record<string, unknown> | undefined;
const turnId = String(params.turnId ?? params.turn_id ?? "").trim();
if (threadId && item) {
currentHandlers.onItemCompleted?.(workspace_id, threadId, item);
const itemWithTurnId = turnId ? { ...item, turnId } : item;
currentHandlers.onItemCompleted?.(workspace_id, threadId, itemWithTurnId);
}
if (threadId && item?.type === "agentMessage") {
const itemId = String(item.id ?? "");
Expand All @@ -488,8 +490,10 @@ export function useAppServerEvents(handlers: AppServerEventHandlers) {
if (method === "item/started") {
const threadId = String(params.threadId ?? params.thread_id ?? "");
const item = params.item as Record<string, unknown> | undefined;
const turnId = String(params.turnId ?? params.turn_id ?? "").trim();
if (threadId && item) {
currentHandlers.onItemStarted?.(workspace_id, threadId, item);
const itemWithTurnId = turnId ? { ...item, turnId } : item;
currentHandlers.onItemStarted?.(workspace_id, threadId, itemWithTurnId);
}
return;
}
Expand Down
17 changes: 16 additions & 1 deletion src/features/app/hooks/useMainAppLayoutSurfaces.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { RefObject } from "react";
import type { AppSettings, ComposerEditorSettings, WorkspaceInfo } from "@/types";
import type { AppSettings, ComposerEditorSettings, ConversationItem, WorkspaceInfo } from "@/types";
import type { UseMessageEditResult } from "@/features/messages/hooks/useMessageEdit";
import type { ThreadState } from "@/features/threads/hooks/useThreadsReducer";
import type { WorkspaceLaunchScriptsState } from "@app/hooks/useWorkspaceLaunchScripts";
import { REMOTE_THREAD_POLL_INTERVAL_MS } from "@app/hooks/useRemoteThreadRefreshOnFocus";
Expand Down Expand Up @@ -225,6 +226,7 @@ type UseMainAppLayoutSurfacesArgs = {
dismissErrorToast: LayoutNodesOptions["primary"]["errorToastsProps"]["onDismiss"];
showDebugButton: boolean;
handleDebugClick: () => void;
messageEditState?: UseMessageEditResult;
};

type MainAppLayoutSurfacesContext = UseMainAppLayoutSurfacesArgs & {
Expand Down Expand Up @@ -374,6 +376,7 @@ function buildPrimarySurface({
dismissErrorToast,
showDebugButton,
handleDebugClick,
messageEditState,
}: MainAppLayoutSurfacesContext): LayoutNodesOptions["primary"] {
return {
sidebarProps: {
Expand Down Expand Up @@ -465,6 +468,16 @@ function buildPrimarySurface({
: null,
showPollingFetchStatus: showMobilePollingFetchStatus,
pollingIntervalMs: REMOTE_THREAD_POLL_INTERVAL_MS,
editingItemId: messageEditState?.editingItemId,
editText: messageEditState?.editText,
isConfirmingEdit: messageEditState?.isConfirming,
isRegeneratingEdit: messageEditState?.isRegenerating,
onStartEdit: messageEditState?.startEdit ? (item: Extract<ConversationItem, { kind: "message" }>) => messageEditState.startEdit(item.id, item.text, item.images) : undefined,
onCancelEdit: messageEditState?.cancelEdit,
onUpdateEditText: messageEditState?.updateEditText,
onRequestRegenerate: messageEditState?.requestRegenerate,
onCancelConfirm: messageEditState?.cancelConfirm,
onExecuteRegenerate: messageEditState?.executeRegenerate,
},
composerProps: composerWorkspaceState.showComposer
? {
Expand Down Expand Up @@ -1098,6 +1111,7 @@ export function useMainAppLayoutSurfaces({
dismissErrorToast,
showDebugButton,
handleDebugClick,
messageEditState,
}: UseMainAppLayoutSurfacesArgs): LayoutNodesOptions {
const sidebarRateLimits = activeWorkspace ? activeRateLimits : homeRateLimits;
const sidebarAccount = activeWorkspace ? activeAccount : homeAccount;
Expand Down Expand Up @@ -1260,6 +1274,7 @@ export function useMainAppLayoutSurfaces({
dismissErrorToast,
showDebugButton,
handleDebugClick,
messageEditState,
sidebarRateLimits,
sidebarAccount,
};
Expand Down
Loading