e.key === "Enter" && handleCardClick()}>
@@ -229,7 +229,7 @@ export default function MembersPage() {
-
发起会话
+
打开线程
diff --git a/frontend/app/src/pages/NewChatPage.test.tsx b/frontend/app/src/pages/NewChatPage.test.tsx
new file mode 100644
index 000000000..cb07bdfd6
--- /dev/null
+++ b/frontend/app/src/pages/NewChatPage.test.tsx
@@ -0,0 +1,179 @@
+// @vitest-environment jsdom
+
+import { render, screen, waitFor } from "@testing-library/react";
+import { MemoryRouter, Outlet, Route, Routes } from "react-router-dom";
+import { beforeEach, describe, expect, it, vi } from "vitest";
+import NewChatPage from "./NewChatPage";
+import { useAuthStore } from "../store/auth-store";
+import { useAppStore } from "../store/app-store";
+
+const handleGetMainThread = vi.fn();
+
+vi.mock("zustand/middleware", async () => {
+ const actual = await vi.importActual("zustand/middleware");
+ return {
+ ...actual,
+ persist: ((initializer: unknown) => initializer) as typeof actual.persist,
+ };
+});
+
+vi.mock("../components/CenteredInputBox", () => ({
+ default: () => centered-input-box
,
+}));
+
+vi.mock("../components/WorkspaceSetupModal", () => ({
+ default: () => null,
+}));
+
+vi.mock("../components/FilesystemBrowser", () => ({
+ default: () => null,
+}));
+
+vi.mock("../components/MemberAvatar", () => ({
+ default: ({ name }: { name: string }) => {name}
,
+}));
+
+vi.mock("../hooks/use-workspace-settings", () => ({
+ useWorkspaceSettings: () => ({
+ settings: { default_workspace: null, recent_workspaces: [], default_model: "leon:large", enabled_models: ["leon:large"] },
+ loading: false,
+ hasWorkspace: false,
+ refreshSettings: vi.fn(),
+ setDefaultWorkspace: vi.fn(),
+ }),
+}));
+
+vi.mock("../api", () => ({
+ postRun: vi.fn(),
+}));
+
+vi.mock("../api/client", () => ({
+ getDefaultThreadConfig: vi.fn(() => new Promise(() => {})),
+ listMyLeases: vi.fn(async () => []),
+ saveDefaultThreadConfig: vi.fn(async () => undefined),
+}));
+
+function ContextOutlet() {
+ return (
+
+ );
+}
+
+describe("NewChatPage", () => {
+ beforeEach(() => {
+ handleGetMainThread.mockReset();
+ handleGetMainThread.mockResolvedValue(null);
+
+ useAuthStore.setState({
+ token: "token",
+ user: { id: "u-1", name: "tester", type: "human", avatar: null },
+ agent: null,
+ entityId: "u-1",
+ setupInfo: null,
+ login: vi.fn(),
+ sendOtp: vi.fn(),
+ verifyOtp: vi.fn(),
+ completeRegister: vi.fn(),
+ clearSetupInfo: vi.fn(),
+ logout: vi.fn(),
+ });
+
+ useAppStore.setState({
+ memberList: [{
+ id: "m_xVuNpKJNxblZ",
+ name: "Morel",
+ description: "",
+ status: "active",
+ version: "1.0.0",
+ avatar_url: "/avatars/morel.png",
+ config: {
+ prompt: "",
+ rules: [],
+ tools: [],
+ mcps: [],
+ skills: [],
+ subAgents: [],
+ },
+ created_at: 0,
+ updated_at: 0,
+ }],
+ taskList: [],
+ cronJobs: [],
+ librarySkills: [],
+ libraryMcps: [],
+ libraryAgents: [],
+ libraryRecipes: [],
+ userProfile: { name: "User", initials: "U", email: "" },
+ loaded: true,
+ error: null,
+ loadAll: vi.fn(),
+ retry: vi.fn(),
+ resetSessionData: vi.fn(),
+ fetchMembers: vi.fn(),
+ addMember: vi.fn(),
+ updateMember: vi.fn(),
+ updateMemberConfig: vi.fn(),
+ publishMember: vi.fn(),
+ deleteMember: vi.fn(),
+ getMemberById: vi.fn(),
+ fetchTasks: vi.fn(),
+ addTask: vi.fn(),
+ updateTask: vi.fn(),
+ deleteTask: vi.fn(),
+ bulkUpdateTaskStatus: vi.fn(),
+ bulkDeleteTasks: vi.fn(),
+ fetchCronJobs: vi.fn(),
+ addCronJob: vi.fn(),
+ updateCronJob: vi.fn(),
+ deleteCronJob: vi.fn(),
+ triggerCronJob: vi.fn(),
+ fetchLibrary: vi.fn(),
+ fetchLibraryNames: vi.fn(),
+ addResource: vi.fn(),
+ updateResource: vi.fn(),
+ deleteResource: vi.fn(),
+ fetchResourceContent: vi.fn(),
+ updateResourceContent: vi.fn(),
+ fetchProfile: vi.fn(),
+ updateProfile: vi.fn(),
+ getMemberNames: vi.fn(),
+ getResourceUsedBy: vi.fn(),
+ });
+ });
+
+ it("does not block the create-chat UI on a pending default-config fetch once main thread resolves null", async () => {
+ render(
+
+
+ }>
+ } />
+
+
+ ,
+ );
+
+ await waitFor(() => {
+ expect(screen.getByText("开始与 Morel 对话")).toBeTruthy();
+ });
+ expect(screen.queryByText("正在检查 Morel 的主对话")).toBeNull();
+ expect(screen.getByText("centered-input-box")).toBeTruthy();
+ });
+});
diff --git a/frontend/app/src/pages/NewChatPage.tsx b/frontend/app/src/pages/NewChatPage.tsx
index 235ca48f4..4e1c739be 100644
--- a/frontend/app/src/pages/NewChatPage.tsx
+++ b/frontend/app/src/pages/NewChatPage.tsx
@@ -22,6 +22,34 @@ interface OutletContext {
setSessionsOpen: (value: boolean) => void;
}
+function ResolveStateCard({
+ memberName,
+ memberAvatarUrl,
+ title,
+ description,
+ destructive = false,
+}: {
+ memberName: string;
+ memberAvatarUrl?: string;
+ title: string;
+ description: string;
+ destructive?: boolean;
+}) {
+ return (
+
+
+
+
+
+
{title}
+
+ {description}
+
+
+
+ );
+}
+
const PROVIDER_TYPE_LABELS: Record = {
local: "Local",
daytona: "Daytona",
@@ -472,39 +500,29 @@ export default function NewChatPage({ mode = "member" }: { mode?: "member" | "ne
? `复用 ${providerSummaryLabel} 的现有 sandbox`
: `新建 ${providerSummaryLabel} sandbox · ${recipeSummaryLabel}`;
- if (loading || resolveState === "resolving" || configDefaultsLoading) {
+ // @@@defer-default-config - default config should refine the create form, not block
+ // entry into the no-main-thread UI. If the config fetch stalls, users still need the
+ // create-chat surface with sane local defaults.
+ if (loading || resolveState === "resolving") {
return (
-
-
-
-
-
-
- 正在检查 {memberName} 的主对话
-
-
- 如果没有主对话,这里会进入创建界面。
-
-
-
+
);
}
if (resolveState === "error") {
return (
-
-
-
-
-
-
- 无法检查 {memberName} 的主对话
-
-
- {error ?? "未知错误"}
-
-
-
+
);
}
diff --git a/frontend/app/src/pages/RootLayout.test.tsx b/frontend/app/src/pages/RootLayout.test.tsx
new file mode 100644
index 000000000..cb1a1090a
--- /dev/null
+++ b/frontend/app/src/pages/RootLayout.test.tsx
@@ -0,0 +1,72 @@
+// @vitest-environment jsdom
+
+import { fireEvent, render, screen, waitFor } from "@testing-library/react";
+import { beforeEach, describe, expect, it, vi } from "vitest";
+import { MemoryRouter, Route, Routes } from "react-router-dom";
+import { LoginForm } from "./RootLayout";
+import { useAuthStore } from "../store/auth-store";
+
+vi.mock("zustand/middleware", async () => {
+ const actual = await vi.importActual("zustand/middleware");
+ return {
+ ...actual,
+ persist: ((initializer: unknown) => initializer) as typeof actual.persist,
+ };
+});
+
+describe("LoginForm", () => {
+ beforeEach(() => {
+ useAuthStore.setState({
+ token: null,
+ user: null,
+ agent: null,
+ entityId: null,
+ setupInfo: null,
+ login: vi.fn(async () => {
+ useAuthStore.setState({
+ token: "token",
+ user: { id: "u-1", name: "tester", type: "human", avatar: null },
+ agent: null,
+ entityId: null,
+ setupInfo: null,
+ });
+ }),
+ sendOtp: vi.fn(async () => undefined),
+ verifyOtp: vi.fn(async () => ({ tempToken: "temp" })),
+ completeRegister: vi.fn(async () => undefined),
+ clearSetupInfo: vi.fn(),
+ logout: vi.fn(),
+ });
+ });
+
+ it("redirects to /threads after a successful login", async () => {
+ render(
+
+
+
+
+ login-page
+ >
+ }
+ />
+ threads-page} />
+
+ ,
+ );
+
+ fireEvent.change(screen.getByPlaceholderText("邮箱或 Mycel ID"), {
+ target: { value: "otpfull_1775371370@example.com" },
+ });
+ fireEvent.change(screen.getByPlaceholderText("密码"), {
+ target: { value: "LeonFull123!" },
+ });
+ fireEvent.click(screen.getByRole("button", { name: "登录" }));
+
+ await waitFor(() => {
+ expect(screen.getByText("threads-page")).toBeTruthy();
+ });
+ });
+});
diff --git a/frontend/app/src/pages/RootLayout.tsx b/frontend/app/src/pages/RootLayout.tsx
index d285056f0..db8c4496b 100644
--- a/frontend/app/src/pages/RootLayout.tsx
+++ b/frontend/app/src/pages/RootLayout.tsx
@@ -1,5 +1,5 @@
import { NavLink, Outlet, useLocation, useNavigate } from "react-router-dom";
-import { MessageSquare, MessagesSquare, Users, ListTodo, Store, Layers, Plug, Settings, Plus, ChevronLeft, ChevronRight, LogOut, Camera, Eye, EyeOff } from "lucide-react";
+import { MessageSquare, MessagesSquare, Users, ListTodo, Store, Layers, Settings, Plus, ChevronLeft, ChevronRight, LogOut, Camera, Eye, EyeOff } from "lucide-react";
import { useState, useEffect, useCallback, useRef } from "react";
import { uploadMemberAvatar } from "@/api/client";
import MemberAvatar from "@/components/MemberAvatar";
@@ -18,7 +18,6 @@ const navItems = [
{ to: "/tasks", icon: ListTodo, label: "Tasks" },
{ to: "/resources", icon: Layers, label: "Resources" },
{ to: "/marketplace", icon: Store, label: "Marketplace" },
- { to: "/connections", icon: Plug, label: "Connections" },
];
const mobileNavItems = [
@@ -65,9 +64,21 @@ function AuthenticatedLayout() {
}, [authUser]);
const loadAll = useAppStore((s) => s.loadAll);
+ const resetSessionData = useAppStore((s) => s.resetSessionData);
const storeAddTask = useAppStore((s) => s.addTask);
+ const lastLoadedUserIdRef = useRef
(null);
- useEffect(() => { loadAll(); }, [loadAll]);
+ useEffect(() => {
+ const userId = authUser?.id ?? null;
+ if (!userId) return;
+ if (lastLoadedUserIdRef.current === userId) return;
+ // @@@auth-session-reset - switching users in the same SPA process must discard
+ // panel caches before reloading, otherwise the next account inherits old
+ // members/tasks and the sidebar mixes identities.
+ lastLoadedUserIdRef.current = userId;
+ resetSessionData();
+ void loadAll();
+ }, [authUser?.id, loadAll, resetSessionData]);
const [expanded, setExpanded] = useState(() => {
const saved = localStorage.getItem("sidebar-expanded");
@@ -187,7 +198,10 @@ function AuthenticatedLayout() {
{/* Main content - no top bar, pages have their own headers */}
-
+ {/* @@@outlet-no-route-key - thread switches should not remount the entire
+ outlet tree; RootLayout route keys were re-triggering AppLayout
+ bootstrap fetches on every /threads/:memberId/:threadId hop. */}
+
{/* Bottom tab bar */}
@@ -311,7 +325,7 @@ function AuthenticatedLayout() {
-
+
@@ -354,7 +368,7 @@ function CreateDropdown({
新建成员
{cache?.loading ? (
@@ -922,4 +924,3 @@ export default function Tasks() {
-
diff --git a/frontend/app/src/pages/ThreadsIndexRedirect.tsx b/frontend/app/src/pages/ThreadsIndexRedirect.tsx
index 2fb79079c..025511dfe 100644
--- a/frontend/app/src/pages/ThreadsIndexRedirect.tsx
+++ b/frontend/app/src/pages/ThreadsIndexRedirect.tsx
@@ -1,14 +1,58 @@
import { useEffect } from "react";
import { useNavigate } from "react-router-dom";
+import { getMainThread } from "../api/client";
import { useAuthStore } from "../store/auth-store";
+const mainThreadInflight = new Map>>>();
+
+function loadMainThread(memberId: string) {
+ const existing = mainThreadInflight.get(memberId);
+ if (existing) return existing;
+ const pending = getMainThread(memberId).finally(() => {
+ mainThreadInflight.delete(memberId);
+ });
+ mainThreadInflight.set(memberId, pending);
+ return pending;
+}
+
export default function ThreadsIndexRedirect() {
const agent = useAuthStore((s) => s.agent);
const navigate = useNavigate();
useEffect(() => {
if (!agent?.id) return;
- navigate(`/threads/${encodeURIComponent(agent.id)}`, { replace: true });
+ const agentId = agent.id;
+
+ let cancelled = false;
+
+ async function redirectToThread() {
+ const memberId = encodeURIComponent(agentId);
+ try {
+ // @@@threads-index-direct-main-route - /threads is a pure entrypoint; resolve the
+ // main thread here so login/setup flows do not bounce through NewChatPage first.
+ // @@@threads-index-inflight-dedup - React StrictMode remounts /threads in dev.
+ // Reuse the first main-thread request and ignore stale callbacks instead of
+ // aborting the first fetch and polluting network/devtools with ERR_ABORTED.
+ const thread = await loadMainThread(agentId);
+ if (cancelled) return;
+ navigate(
+ thread
+ ? `/threads/${memberId}/${encodeURIComponent(thread.thread_id)}`
+ : `/threads/${memberId}`,
+ { replace: true },
+ );
+ } catch (error) {
+ if (cancelled) return;
+ if (error instanceof DOMException && error.name === "AbortError") return;
+ console.error("[ThreadsIndexRedirect] resolve main thread failed:", error);
+ navigate(`/threads/${memberId}`, { replace: true });
+ }
+ }
+
+ void redirectToThread();
+ return () => {
+ cancelled = true;
+ };
}, [agent?.id, navigate]);
return null;
diff --git a/frontend/app/src/pages/ask-user-question.test.ts b/frontend/app/src/pages/ask-user-question.test.ts
new file mode 100644
index 000000000..899c58006
--- /dev/null
+++ b/frontend/app/src/pages/ask-user-question.test.ts
@@ -0,0 +1,38 @@
+import { describe, expect, it } from "vitest";
+import { askUserQuestionSelectionKey, buildAskUserAnswers } from "./ask-user-question";
+import type { AskUserQuestionPrompt } from "../api";
+
+describe("ask-user-question helpers", () => {
+ it("keeps duplicate prompts independently addressable by position", () => {
+ const questions: AskUserQuestionPrompt[] = [
+ {
+ header: "Style",
+ question: "Choose a style",
+ options: [{ label: "Minimal", description: "Keep it simple" }],
+ },
+ {
+ header: "Style",
+ question: "Choose a style",
+ options: [{ label: "Bold", description: "Make it loud" }],
+ },
+ ];
+
+ const answers = buildAskUserAnswers(questions, {
+ [askUserQuestionSelectionKey(0)]: ["Minimal"],
+ [askUserQuestionSelectionKey(1)]: ["Bold"],
+ });
+
+ expect(answers).toEqual([
+ {
+ header: "Style",
+ question: "Choose a style",
+ selected_options: ["Minimal"],
+ },
+ {
+ header: "Style",
+ question: "Choose a style",
+ selected_options: ["Bold"],
+ },
+ ]);
+ });
+});
diff --git a/frontend/app/src/pages/ask-user-question.ts b/frontend/app/src/pages/ask-user-question.ts
new file mode 100644
index 000000000..a1ce5faad
--- /dev/null
+++ b/frontend/app/src/pages/ask-user-question.ts
@@ -0,0 +1,16 @@
+import type { AskUserAnswer, AskUserQuestionPrompt } from "../api";
+
+export function askUserQuestionSelectionKey(questionIndex: number): string {
+ return String(questionIndex);
+}
+
+export function buildAskUserAnswers(
+ questions: AskUserQuestionPrompt[],
+ selections: Record,
+): AskUserAnswer[] {
+ return questions.map((question, index) => ({
+ header: question.header,
+ question: question.question,
+ selected_options: selections[askUserQuestionSelectionKey(index)] ?? [],
+ }));
+}
diff --git a/frontend/app/src/pages/resources/CapabilityIcons.tsx b/frontend/app/src/pages/resources/CapabilityIcons.tsx
index 886ef02aa..c3c32cbc0 100644
--- a/frontend/app/src/pages/resources/CapabilityIcons.tsx
+++ b/frontend/app/src/pages/resources/CapabilityIcons.tsx
@@ -52,38 +52,3 @@ export function CapabilityStrip({ capabilities }: { capabilities: ProviderCapabi
);
}
-
-/** Detailed capability tiles for ProviderDetail */
-export function CapabilityGrid({ capabilities }: { capabilities: ProviderCapabilities }) {
- return (
-
- {CAPABILITY_KEYS.map((key) => {
- const Icon = CAPABILITY_ICON_MAP[key];
- const has = capabilities[key];
- return (
-
-
-
-
-
- {CAPABILITY_LABELS[key]}
-
-
- );
- })}
-
- );
-}
diff --git a/frontend/app/src/router.tsx b/frontend/app/src/router.tsx
index 024478143..b45f6193f 100644
--- a/frontend/app/src/router.tsx
+++ b/frontend/app/src/router.tsx
@@ -15,7 +15,6 @@ import MarketplacePage from './pages/MarketplacePage';
import MarketplaceDetailPage from './pages/MarketplaceDetailPage';
import LibraryItemDetailPage from './pages/LibraryItemDetailPage';
import ResourcesPage from './pages/ResourcesPage';
-import ConnectionsPage from './pages/ConnectionsPage';
import InviteCodesPage from './pages/InviteCodesPage';
export const router = createBrowserRouter([
@@ -34,23 +33,27 @@ export const router = createBrowserRouter([
},
{
path: 'threads',
- element:
,
children: [
{
index: true,
element:
,
},
{
- path: ':memberId',
- element:
,
- },
- {
- path: ':memberId/new',
- element:
,
- },
- {
- path: ':memberId/:threadId',
- element:
,
+ element:
,
+ children: [
+ {
+ path: ':memberId',
+ element:
,
+ },
+ {
+ path: ':memberId/new',
+ element:
,
+ },
+ {
+ path: ':memberId/:threadId',
+ element:
,
+ },
+ ],
},
],
},
@@ -100,10 +103,6 @@ export const router = createBrowserRouter([
path: 'library',
element:
,
},
- {
- path: 'connections',
- element:
,
- },
{
path: 'invite-codes',
element:
,
diff --git a/frontend/app/src/store/app-store.test.ts b/frontend/app/src/store/app-store.test.ts
new file mode 100644
index 000000000..350c25ba7
--- /dev/null
+++ b/frontend/app/src/store/app-store.test.ts
@@ -0,0 +1,36 @@
+// @vitest-environment jsdom
+
+import { beforeEach, describe, expect, it } from "vitest";
+import { useAppStore } from "./app-store";
+
+describe("useAppStore", () => {
+ beforeEach(() => {
+ useAppStore.setState({
+ memberList: [],
+ taskList: [],
+ cronJobs: [],
+ librarySkills: [],
+ libraryMcps: [],
+ libraryAgents: [],
+ libraryRecipes: [],
+ userProfile: { name: "User", initials: "U", email: "" },
+ loaded: false,
+ error: null,
+ });
+ });
+
+ it("resets loaded member state when auth identity changes", () => {
+ useAppStore.setState({
+ memberList: [{ id: "m-old", name: "Old", status: "active" } as never],
+ loaded: true,
+ error: "stale",
+ });
+
+ useAppStore.getState().resetSessionData();
+
+ const state = useAppStore.getState();
+ expect(state.memberList).toEqual([]);
+ expect(state.loaded).toBe(false);
+ expect(state.error).toBeNull();
+ });
+});
diff --git a/frontend/app/src/store/app-store.ts b/frontend/app/src/store/app-store.ts
index e54bd1ef5..3cbab9423 100644
--- a/frontend/app/src/store/app-store.ts
+++ b/frontend/app/src/store/app-store.ts
@@ -6,6 +6,7 @@ import type {
import { useAuthStore } from "./auth-store";
const API = "/api/panel";
+let loadAllInflight: Promise
| null = null;
interface AppState {
// ── Data ──
@@ -23,6 +24,7 @@ interface AppState {
// ── Init ──
loadAll: () => Promise;
retry: () => Promise;
+ resetSessionData: () => void;
// ── Members ──
fetchMembers: () => Promise;
@@ -71,6 +73,38 @@ interface AppState {
getResourceUsedBy: (type: string, name: string) => string[];
}
+type LibraryType = "skill" | "mcp" | "agent" | "recipe";
+type LibraryStateKey = "librarySkills" | "libraryMcps" | "libraryAgents" | "libraryRecipes";
+
+const DEFAULT_PROFILE: UserProfile = { name: "User", initials: "U", email: "" };
+const LIBRARY_STATE_KEYS: Record = {
+ skill: "librarySkills",
+ mcp: "libraryMcps",
+ agent: "libraryAgents",
+ recipe: "libraryRecipes",
+};
+
+function getLibraryStateKey(type: string): LibraryStateKey {
+ const key = LIBRARY_STATE_KEYS[type as LibraryType];
+ if (!key) throw new Error(`Unsupported library type: ${type}`);
+ return key;
+}
+
+function emptySessionState() {
+ return {
+ memberList: [],
+ taskList: [],
+ cronJobs: [],
+ librarySkills: [],
+ libraryMcps: [],
+ libraryAgents: [],
+ libraryRecipes: [],
+ userProfile: DEFAULT_PROFILE,
+ loaded: false,
+ error: null,
+ };
+}
+
async function api(path: string, opts?: RequestInit): Promise {
const token = useAuthStore.getState().token;
const headers: Record = { "Content-Type": "application/json" };
@@ -81,35 +115,42 @@ async function api(path: string, opts?: RequestInit): Promise {
}
export const useAppStore = create()((set, get) => ({
- memberList: [],
- taskList: [],
- cronJobs: [],
- librarySkills: [],
- libraryMcps: [],
- libraryAgents: [],
- libraryRecipes: [],
- userProfile: { name: "User", initials: "U", email: "" },
- loaded: false,
- error: null,
+ ...emptySessionState(),
loadAll: async () => {
if (get().loaded) return;
- set({ error: null });
+ if (loadAllInflight) return loadAllInflight;
+
+ const pending = (async () => {
+ set({ error: null });
+ try {
+ // @@@load-all-singleflight - RootLayout can mount twice in dev StrictMode and /threads
+ // index redirect now avoids AppLayout, so keep the global panel bootstrap idempotent
+ // instead of firing duplicate members/tasks/library/profile bursts.
+ await Promise.all([
+ get().fetchMembers(),
+ get().fetchTasks(),
+ get().fetchCronJobs(),
+ get().fetchLibrary("skill"),
+ get().fetchLibrary("mcp"),
+ get().fetchLibrary("agent"),
+ get().fetchLibrary("recipe"),
+ get().fetchProfile(),
+ ]);
+ set({ loaded: true });
+ } catch (e) {
+ const msg = e instanceof Error ? e.message : String(e);
+ set({ error: `数据加载失败: ${msg}`, loaded: true });
+ }
+ })();
+
+ loadAllInflight = pending;
try {
- await Promise.all([
- get().fetchMembers(),
- get().fetchTasks(),
- get().fetchCronJobs(),
- get().fetchLibrary("skill"),
- get().fetchLibrary("mcp"),
- get().fetchLibrary("agent"),
- get().fetchLibrary("recipe"),
- get().fetchProfile(),
- ]);
- set({ loaded: true });
- } catch (e) {
- const msg = e instanceof Error ? e.message : String(e);
- set({ error: `数据加载失败: ${msg}`, loaded: true });
+ await pending;
+ } finally {
+ if (loadAllInflight === pending) {
+ loadAllInflight = null;
+ }
}
},
@@ -118,6 +159,11 @@ export const useAppStore = create()((set, get) => ({
await get().loadAll();
},
+ resetSessionData: () => {
+ loadAllInflight = null;
+ set(emptySessionState());
+ },
+
// ── Members ──
fetchMembers: async () => {
const data = await api<{ items: Member[] }>("/members");
@@ -254,10 +300,8 @@ export const useAppStore = create()((set, get) => ({
// ── Library ──
fetchLibrary: async (type) => {
const data = await api<{ items: ResourceItem[] }>(`/library/${type}`);
- if (type === "skill") set({ librarySkills: data.items });
- else if (type === "mcp") set({ libraryMcps: data.items });
- else if (type === "agent") set({ libraryAgents: data.items });
- else if (type === "recipe") set({ libraryRecipes: data.items });
+ const key = getLibraryStateKey(type);
+ set({ [key]: data.items } as Pick);
},
fetchLibraryNames: async (type) => {
@@ -270,10 +314,8 @@ export const useAppStore = create()((set, get) => ({
method: "POST",
body: JSON.stringify({ name, desc, ...extra }),
});
- if (type === "skill") set((s) => ({ librarySkills: [...s.librarySkills, item] }));
- else if (type === "mcp") set((s) => ({ libraryMcps: [...s.libraryMcps, item] }));
- else if (type === "agent") set((s) => ({ libraryAgents: [...s.libraryAgents, item] }));
- else set((s) => ({ libraryRecipes: [...s.libraryRecipes, item] }));
+ const key = getLibraryStateKey(type);
+ set((s) => ({ [key]: [...s[key], item] }) as Pick);
return item;
},
@@ -282,23 +324,23 @@ export const useAppStore = create()((set, get) => ({
method: "PUT",
body: JSON.stringify(fields),
});
- const updater = (list: ResourceItem[]) => list.map((x) => (x.id === id ? updated : x));
- if (type === "skill") set((s) => ({ librarySkills: updater(s.librarySkills) }));
- else if (type === "mcp") set((s) => ({ libraryMcps: updater(s.libraryMcps) }));
- else if (type === "agent") set((s) => ({ libraryAgents: updater(s.libraryAgents) }));
- else set((s) => ({ libraryRecipes: updater(s.libraryRecipes) }));
+ const key = getLibraryStateKey(type);
+ set((s) => ({
+ [key]: s[key].map((item) => (item.id === id ? updated : item)),
+ }) as Pick);
},
deleteResource: async (type, id) => {
await api(`/library/${type}/${id}`, { method: "DELETE" });
- const filter = (list: ResourceItem[]) => list.filter((x) => x.id !== id);
- if (type === "skill") set((s) => ({ librarySkills: filter(s.librarySkills) }));
- else if (type === "mcp") set((s) => ({ libraryMcps: filter(s.libraryMcps) }));
- else if (type === "agent") set((s) => ({ libraryAgents: filter(s.libraryAgents) }));
- else {
+ if (type === "recipe") {
const data = await api<{ items: ResourceItem[] }>(`/library/${type}`);
set({ libraryRecipes: data.items });
+ return;
}
+ const key = getLibraryStateKey(type);
+ set((s) => ({
+ [key]: s[key].filter((item) => item.id !== id),
+ }) as Pick);
},
fetchResourceContent: async (type, id) => {
diff --git a/frontend/app/src/store/auth-store.ts b/frontend/app/src/store/auth-store.ts
index fb0d7b1d8..f782dac72 100644
--- a/frontend/app/src/store/auth-store.ts
+++ b/frontend/app/src/store/auth-store.ts
@@ -1,15 +1,11 @@
/**
* Auth store — JWT token, user identity, login/register/logout.
* Persisted to localStorage via Zustand persist middleware.
- *
- * Set VITE_DEV_SKIP_AUTH=true in .env.development to bypass login during dev.
*/
import { create } from "zustand";
import { persist } from "zustand/middleware";
-const DEV_SKIP_AUTH = import.meta.env.VITE_DEV_SKIP_AUTH === "true";
-
// Allow overriding the API origin at runtime via window.__MYCEL_CONFIG__.apiBase
// (injected by docker-entrypoint.sh), falling back to the Vite build-time variable.
// Relative URLs are used when neither is set (same-origin / local dev).
@@ -18,7 +14,6 @@ const API_BASE = (
?? import.meta.env.VITE_API_BASE
?? ""
).replace(/\/$/, "");
-
export interface AuthIdentity {
id: string;
name: string;
@@ -62,15 +57,13 @@ async function apiPost(endpoint: string, body: Record) {
return res.json();
}
-const DEV_MOCK_USER: AuthIdentity = { id: "dev-user", name: "Dev", type: "human" };
-
export const useAuthStore = create()(
persist(
(set) => ({
- token: DEV_SKIP_AUTH ? "dev-skip-auth" : null,
- user: DEV_SKIP_AUTH ? DEV_MOCK_USER : null,
+ token: null,
+ user: null,
agent: null,
- entityId: DEV_SKIP_AUTH ? "dev-user" : null,
+ entityId: null,
setupInfo: null,
login: async (identifier, password) => {
@@ -81,7 +74,6 @@ export const useAuthStore = create()(
agent: data.agent,
entityId: data.user?.id ?? null,
});
- window.location.href = "/threads";
},
sendOtp: async (email, password, inviteCode) => {
@@ -117,10 +109,6 @@ export const useAuthStore = create()(
}),
{
name: "leon-auth",
- ...(DEV_SKIP_AUTH && {
- // In skip-auth mode, never let persisted null overwrite the mock identity
- merge: (_persisted: unknown, current: AuthState) => current,
- }),
},
),
);
@@ -140,7 +128,7 @@ export async function authFetch(url: string, init?: RequestInit): Promise=6.0.0",
"uvicorn>=0.30.0",
"sse-starlette>=1.6.0",
+ "multilspy>=0.0.15",
+ "pyright>=1.1.0",
"supabase>=2.28.3",
"fastapi>=0.118.0",
"langgraph-checkpoint-postgres>=3.0.5",
@@ -91,6 +93,7 @@ packages = [
"core.tools.filesystem",
"core.tools.filesystem.read",
"core.tools.filesystem.read.readers",
+ "core.tools.lsp",
"core.tools.search",
"core.tools.skills",
"core.tools.task",
diff --git a/sandbox/base.py b/sandbox/base.py
index 0a423f25a..2ae32a676 100644
--- a/sandbox/base.py
+++ b/sandbox/base.py
@@ -9,6 +9,7 @@
import asyncio
import logging
+import threading
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING
@@ -70,6 +71,49 @@ def __getattr__(self, name: str):
return getattr(self._remote._get_capability().command, name)
+def _cached_capability_is_stale(manager, thread_id: str, capability) -> bool:
+ session = getattr(capability, "_session", None)
+ if session is None:
+ return True
+ if getattr(session, "status", None) in {"closed", "failed", "paused"}:
+ return True
+ # @@@capability-cache-session-liveness - cached wrappers outlive session teardown;
+ # always confirm the cached session still exists as the current active session.
+ current = manager.session_manager.get(thread_id, session.terminal.terminal_id)
+ if current is None:
+ return True
+ return current.session_id != session.session_id
+
+
+def _run_coroutine_blocking(coro, *, timeout: float | None = None):
+ try:
+ asyncio.get_running_loop()
+ except RuntimeError:
+ return asyncio.run(coro)
+
+ result: dict[str, object] = {}
+ error: dict[str, BaseException] = {}
+ done = threading.Event()
+
+ # @@@same-loop-init-bridge - init commands can run while the web request event loop is already active;
+ # running run_coroutine_threadsafe(...).result() on that same loop deadlocks, so bridge through a helper thread.
+ def _runner() -> None:
+ try:
+ result["value"] = asyncio.run(coro)
+ except BaseException as exc: # pragma: no cover - defensive relay
+ error["value"] = exc
+ finally:
+ done.set()
+
+ thread = threading.Thread(target=_runner, daemon=True)
+ thread.start()
+ if not done.wait(timeout):
+ raise TimeoutError(f"Coroutine timed out after {timeout}s")
+ if "value" in error:
+ raise error["value"]
+ return result.get("value")
+
+
class RemoteSandbox(Sandbox):
"""Concrete sandbox for all provider-backed environments (AgentBay, Docker, E2B, Daytona)."""
@@ -103,26 +147,30 @@ def _get_capability(self) -> SandboxCapability:
thread_id = get_current_thread_id()
if not thread_id:
raise RuntimeError("No thread_id set. Call set_current_thread_id first.")
+ print(f"[RemoteSandbox._get_capability] provider={self._provider.name} thread_id={thread_id}", flush=True)
+ cached = self._capability_cache.get(thread_id)
+ if cached is not None and _cached_capability_is_stale(self._manager, thread_id, cached):
+ self._capability_cache.pop(thread_id, None)
if thread_id not in self._capability_cache:
+ print(
+ f"[RemoteSandbox._get_capability] provider={self._provider.name} thread_id={thread_id} cache=miss",
+ flush=True,
+ )
capability = self._manager.get_sandbox(thread_id)
if self._config.init_commands and thread_id not in self._init_commands_run:
self._run_init_commands(capability)
self._init_commands_run.add(thread_id)
self._capability_cache[thread_id] = capability
+ else:
+ print(
+ f"[RemoteSandbox._get_capability] provider={self._provider.name} thread_id={thread_id} cache=hit",
+ flush=True,
+ )
return self._capability_cache[thread_id]
def _run_init_commands(self, capability: SandboxCapability) -> None:
for i, cmd in enumerate(self._config.init_commands, 1):
- try:
- loop = asyncio.get_running_loop()
- except RuntimeError:
- loop = None
-
- if loop:
- future = asyncio.run_coroutine_threadsafe(capability.command.execute(cmd), loop)
- result = future.result(timeout=30)
- else:
- result = asyncio.run(capability.command.execute(cmd))
+ result = _run_coroutine_blocking(capability.command.execute(cmd), timeout=30)
if result.exit_code != 0:
raise RuntimeError(
@@ -229,6 +277,9 @@ def _get_capability(self) -> SandboxCapability:
thread_id = get_current_thread_id()
if not thread_id:
raise RuntimeError("No thread_id set. Call set_current_thread_id first.")
+ cached = self._capability_cache.get(thread_id)
+ if cached is not None and _cached_capability_is_stale(self._manager, thread_id, cached):
+ self._capability_cache.pop(thread_id, None)
if thread_id not in self._capability_cache:
self._capability_cache[thread_id] = self._manager.get_sandbox(thread_id)
return self._capability_cache[thread_id]
diff --git a/sandbox/capability.py b/sandbox/capability.py
index 4b278742a..b5269a30f 100644
--- a/sandbox/capability.py
+++ b/sandbox/capability.py
@@ -9,7 +9,7 @@
import shlex
import uuid
-from pathlib import Path
+from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING
from sandbox.interfaces.executor import BaseExecutor
@@ -36,7 +36,7 @@ class SandboxCapability:
def __init__(self, session: ChatSession, manager: SandboxManager | None = None):
self._session = session
self._command_wrapper = _CommandWrapper(session, manager=manager)
- self._fs_wrapper = _FileSystemWrapper(session)
+ self._fs_wrapper = _FileSystemWrapper(session, manager=manager)
@property
def command(self) -> BaseExecutor:
@@ -95,6 +95,14 @@ async def execute(self, command: str, cwd: str | None = None, timeout: float | N
self._session.touch()
# @@@command-context - CommandMiddleware passes Cwd/env; preserve that context for remote runtimes.
wrapped, _ = self._wrap_command(command, cwd, env)
+ print(
+ "[_CommandWrapper.execute] "
+ f"thread_id={self._session.thread_id} "
+ f"terminal_id={self._session.terminal.terminal_id} "
+ f"command={command[:200]!r} "
+ f"cwd={cwd!r} timeout={timeout}",
+ flush=True,
+ )
return await self._session.runtime.execute(wrapped, timeout)
async def execute_async(self, command: str, cwd: str | None = None, env: dict[str, str] | None = None):
@@ -178,8 +186,9 @@ class _FileSystemWrapper(FileSystemBackend):
is_remote = True
- def __init__(self, session: ChatSession):
+ def __init__(self, session: ChatSession, manager: SandboxManager | None = None):
self._session = session
+ self._manager = manager
def _get_provider(self):
"""Get provider from session's lease."""
@@ -193,7 +202,14 @@ def _get_instance_id(self) -> str:
# @@@lease-convergence - File operations can also wake paused instances; always converge through lease.
provider = getattr(self._session.runtime, "provider", None)
if provider is not None:
- instance = self._session.lease.ensure_active_instance(provider)
+ try:
+ instance = self._session.lease.ensure_active_instance(provider)
+ except RuntimeError:
+ if self._manager is None or getattr(self._session.lease, "observed_state", None) != "paused":
+ raise
+ if not self._manager.resume_session(self._session.thread_id, source="auto_resume"):
+ raise
+ instance = self._session.lease.ensure_active_instance(provider)
else:
instance = self._session.lease.get_instance()
if not instance:
@@ -242,7 +258,30 @@ def file_mtime(self, path: str) -> float | None:
return None
def file_size(self, path: str) -> int | None:
- """Not available for remote sandbox."""
+ """Best-effort size lookup via parent directory listing."""
+ self._session.touch()
+ provider = self._get_provider()
+ instance_id = self._get_instance_id()
+
+ target = PurePosixPath(path)
+ if not target.name:
+ return None
+
+ parent = str(target.parent) or "/"
+ try:
+ entries = provider.list_dir(instance_id, parent)
+ except Exception:
+ return None
+
+ for entry in entries or []:
+ if entry.get("name") != target.name:
+ continue
+ size = entry.get("size")
+ if isinstance(size, int):
+ return size
+ if isinstance(size, float):
+ return int(size)
+ return None
return None
def is_dir(self, path: str) -> bool:
diff --git a/sandbox/manager.py b/sandbox/manager.py
index 29f380b0a..54237f710 100644
--- a/sandbox/manager.py
+++ b/sandbox/manager.py
@@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any
+from config.user_paths import user_home_path
from sandbox.capability import SandboxCapability
from sandbox.chat_session import ChatSessionManager, ChatSessionPolicy
from sandbox.lease import lease_from_row
@@ -20,7 +21,7 @@
from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path
from storage.providers.sqlite.lease_repo import SQLiteLeaseRepo
from storage.providers.sqlite.terminal_repo import SQLiteTerminalRepo
-from storage.providers.sqlite.thread_repo import SQLiteThreadRepo
+from storage.runtime import build_storage_container, build_thread_repo
logger = logging.getLogger(__name__)
@@ -53,6 +54,76 @@ def lookup_sandbox_for_thread(thread_id: str, db_path: Path | None = None) -> st
lease_repo.close()
+def resolve_existing_lease_cwd(
+ lease_id: str,
+ fallback_cwd: str | None = None,
+ db_path: Path | None = None,
+) -> str:
+ if fallback_cwd:
+ return fallback_cwd
+
+ target_db = db_path or resolve_role_db_path(SQLiteDBRole.SANDBOX)
+ terminal_repo = SQLiteTerminalRepo(db_path=target_db)
+ try:
+ row = terminal_repo.get_latest_by_lease(lease_id)
+ finally:
+ terminal_repo.close()
+ if row and row.get("cwd"):
+ return str(row["cwd"])
+ return str(Path.home())
+
+
+def bind_thread_to_existing_lease(
+ thread_id: str,
+ lease_id: str,
+ *,
+ cwd: str | None = None,
+ db_path: Path | None = None,
+) -> str:
+ target_db = db_path or resolve_role_db_path(SQLiteDBRole.SANDBOX)
+ terminal_repo = SQLiteTerminalRepo(db_path=target_db)
+ try:
+ existing = terminal_repo.get_active(thread_id)
+ if existing is not None:
+ return str(existing["cwd"])
+ initial_cwd = resolve_existing_lease_cwd(lease_id, cwd, db_path=target_db)
+ terminal_repo.create(
+ terminal_id=f"term-{uuid.uuid4().hex[:12]}",
+ thread_id=thread_id,
+ lease_id=lease_id,
+ initial_cwd=initial_cwd,
+ )
+ return initial_cwd
+ finally:
+ terminal_repo.close()
+
+
+def bind_thread_to_existing_thread_lease(
+ thread_id: str,
+ source_thread_id: str,
+ *,
+ cwd: str | None = None,
+ db_path: Path | None = None,
+) -> str | None:
+ target_db = db_path or resolve_role_db_path(SQLiteDBRole.SANDBOX)
+ terminal_repo = SQLiteTerminalRepo(db_path=target_db)
+ try:
+ source_terminal = terminal_repo.get_active(source_thread_id)
+ finally:
+ terminal_repo.close()
+ if source_terminal is None:
+ return None
+ # @@@subagent-lease-reuse
+ # Child threads need their own terminal/session state, but must attach
+ # to the parent's existing lease instead of silently provisioning a new one.
+ return bind_thread_to_existing_lease(
+ thread_id,
+ str(source_terminal["lease_id"]),
+ cwd=cwd,
+ db_path=target_db,
+ )
+
+
class SandboxManager:
def __init__(
self,
@@ -105,27 +176,82 @@ def get_lease(self, lease_id: str):
def _default_terminal_cwd(self) -> str:
return resolve_provider_cwd(self.provider)
+ def _sandbox_volume_repo(self):
+ # @@@volume-repo-align - thread creation persists volume metadata through the
+ # active storage container; sandbox startup must read the same repo instead
+        # of hardcoding SQLite, or Supabase-backed threads lose their volume row.
+ container = build_storage_container(main_db_path=resolve_role_db_path(SQLiteDBRole.MAIN))
+ return container.sandbox_volume_repo()
+
+ def _requires_volume_bootstrap(self) -> bool:
+ # @@@local-shell-no-volume-gate - local runtimes execute directly on the host
+ # and should not fail to start a shell just because file-channel volume
+ # metadata is absent or stored in a different backend.
+ return self.provider_capability.runtime_kind != "local"
+
+ def _ensure_thread_volume(self, thread_id: str, lease) -> None:
+ if not self._requires_volume_bootstrap() or lease.volume_id:
+ return
+
+ volume_id = str(uuid.uuid4())
+ self._create_volume_entry(thread_id, volume_id)
+
+ # @@@remote-volume-self-heal - legacy threads can lose their eager-created lease row
+ # and get rebound through manager recovery; persist a replacement volume_id before mount/sync.
+ self.lease_store.set_volume_id(lease.lease_id, volume_id)
+ lease.volume_id = volume_id
+
+ def _create_volume_entry(self, thread_id: str, volume_id: str) -> None:
+ import json
+ import os
+
+ from sandbox.volume_source import HostVolume
+
+ now_str = datetime.now().isoformat()
+ volume_root = Path(os.environ.get("LEON_SANDBOX_VOLUME_ROOT", str(user_home_path("volumes")))).expanduser().resolve()
+ volume_root.mkdir(parents=True, exist_ok=True)
+ source = HostVolume(volume_root / volume_id)
+
+ repo = self._sandbox_volume_repo()
+ try:
+ repo.create(volume_id, json.dumps(source.serialize()), f"vol-{thread_id}", now_str)
+ finally:
+ repo.close()
+
+ def _resolve_volume_entry(self, thread_id: str, lease) -> dict[str, Any]:
+ repo = self._sandbox_volume_repo()
+ try:
+ entry = repo.get(lease.volume_id)
+ finally:
+ repo.close()
+ if entry:
+ return entry
+ # @@@missing-volume-row-self-heal - old remote threads can retain a live lease.volume_id
+ # after the sandbox volume row was pruned; recreate the row in place before mount/sync.
+ self._create_volume_entry(thread_id, lease.volume_id)
+ repo = self._sandbox_volume_repo()
+ try:
+ entry = repo.get(lease.volume_id)
+ finally:
+ repo.close()
+ if not entry:
+ raise ValueError(f"Volume not found: {lease.volume_id}")
+ return entry
+
def _setup_mounts(self, thread_id: str) -> dict:
"""Mount the lease's volume into the sandbox. Pure sandbox-layer operation."""
import json
from sandbox.volume_source import DaytonaVolume, deserialize_volume_source
- from storage.providers.sqlite.sandbox_volume_repo import SQLiteSandboxVolumeRepo
terminal = self._get_active_terminal(thread_id)
if not terminal:
raise ValueError(f"No active terminal for thread {thread_id}")
lease = self._get_lease(terminal.lease_id)
- if not lease or not lease.volume_id:
+ if not lease:
raise ValueError(f"No volume for thread {thread_id}")
-
- repo = SQLiteSandboxVolumeRepo()
- try:
- entry = repo.get(lease.volume_id)
- finally:
- repo.close()
- if not entry:
- raise ValueError(f"Volume not found: {lease.volume_id}")
+ self._ensure_thread_volume(thread_id, lease)
+ entry = self._resolve_volume_entry(thread_id, lease)
source = deserialize_volume_source(json.loads(entry["source"]))
volume_id = lease.volume_id
@@ -152,11 +278,10 @@ def _upgrade_to_daytona_volume(self, thread_id: str, current_source, volume_id:
import json
from sandbox.volume_source import DaytonaVolume
- from storage.providers.sqlite.sandbox_volume_repo import SQLiteSandboxVolumeRepo
# @@@member-id-for-volume-naming - read from thread config in leon.db
member_id = "unknown"
- thread_repo = SQLiteThreadRepo(resolve_role_db_path(SQLiteDBRole.MAIN))
+ thread_repo = build_thread_repo(main_db_path=resolve_role_db_path(SQLiteDBRole.MAIN))
try:
row = thread_repo.get_by_id(thread_id)
if row:
@@ -172,6 +297,7 @@ def _upgrade_to_daytona_volume(self, thread_id: str, current_source, volume_id:
if "already exists" in str(e):
volume_name = f"leon-volume-{member_id}"
logger.info("Daytona volume already exists: %s, reusing", volume_name)
+ self.provider.wait_managed_volume_ready(volume_name)
else:
raise
@@ -180,7 +306,7 @@ def _upgrade_to_daytona_volume(self, thread_id: str, current_source, volume_id:
volume_name=volume_name,
)
- repo = SQLiteSandboxVolumeRepo()
+ repo = self._sandbox_volume_repo()
try:
repo.update_source(volume_id, json.dumps(new_source.serialize()))
finally:
@@ -251,30 +377,35 @@ def resolve_volume_source(self, thread_id: str):
import json
from sandbox.volume_source import deserialize_volume_source
- from storage.providers.sqlite.sandbox_volume_repo import SQLiteSandboxVolumeRepo
terminal = self._get_active_terminal(thread_id)
if not terminal:
raise ValueError(f"No active terminal for thread {thread_id}")
lease = self._get_lease(terminal.lease_id)
- if not lease or not lease.volume_id:
+ if not lease:
raise ValueError(f"No volume for thread {thread_id}")
- repo = SQLiteSandboxVolumeRepo()
- try:
- entry = repo.get(lease.volume_id)
- finally:
- repo.close()
- if not entry:
- raise ValueError(f"Volume not found: {lease.volume_id}")
+ self._ensure_thread_volume(thread_id, lease)
+ entry = self._resolve_volume_entry(thread_id, lease)
return deserialize_volume_source(json.loads(entry["source"]))
+ def _skip_volume_sync_for_local_lease(self, lease) -> bool:
+ # @@@local-no-volume-sync - local sessions may execute directly in host cwd with no sandbox volume row.
+        # In that shape there is nothing to upload/download, so sync paths must no-op instead of inventing a volume.
+ return lease is not None and not self._requires_volume_bootstrap() and not lease.volume_id
+
def _sync_to_sandbox(self, thread_id: str, instance_id: str, source=None, files: list[str] | None = None) -> None:
if source is None:
+ lease = self._get_thread_lease(thread_id)
+ if self._skip_volume_sync_for_local_lease(lease):
+ return
source = self.resolve_volume_source(thread_id)
self.volume.sync_upload(thread_id, instance_id, source, self.volume.resolve_mount_path(), files=files)
def _sync_from_sandbox(self, thread_id: str, instance_id: str, source=None) -> None:
if source is None:
+ lease = self._get_thread_lease(thread_id)
+ if self._skip_volume_sync_for_local_lease(lease):
+ return
source = self.resolve_volume_source(thread_id)
self.volume.sync_download(thread_id, instance_id, source, self.volume.resolve_mount_path())
@@ -307,7 +438,7 @@ def get_sandbox(self, thread_id: str, bind_mounts: list | None = None) -> Sandbo
if session:
self._assert_lease_provider(session.lease, thread_id)
# @@@activity-resume - Any new activity against a paused thread must resume before command execution.
- if session.status == "paused":
+ if session.status == "paused" or getattr(session.lease, "observed_state", None) == "paused":
if not self.resume_session(thread_id, source="auto_resume"):
raise RuntimeError(f"Failed to resume paused session for thread {thread_id}")
session = self.session_manager.get(thread_id, session.terminal.terminal_id)
@@ -339,13 +470,29 @@ def get_sandbox(self, thread_id: str, bind_mounts: list | None = None) -> Sandbo
if not lease:
lease = self._create_lease(terminal.lease_id, self.provider.name)
self._assert_lease_provider(lease, thread_id)
+ if lease.observed_state == "paused":
+ # @@@paused-lease-rehydrate - a persisted thread can lose its in-memory chat session
+ # while the lease stays paused in storage; resume before reconstructing capability.
+ if not self.resume_session(thread_id, source="auto_resume"):
+ raise RuntimeError(f"Failed to resume paused session for thread {thread_id}")
+ session = self.session_manager.get(thread_id, terminal.terminal_id)
+ if session:
+ self._assert_lease_provider(session.lease, thread_id)
+ self._ensure_bound_instance(session.lease)
+ return SandboxCapability(session, manager=self)
+ lease = self._get_lease(terminal.lease_id)
+ if not lease:
+ raise RuntimeError(f"Lease disappeared after resume for thread {thread_id}")
+ self._assert_lease_provider(lease, thread_id)
# Stamp bind_mounts on lease so lazy creation paths pick them up
if bind_mounts:
lease.bind_mounts = bind_mounts
- # @@@volume-strategy-gate - mount volume into sandbox
- storage = self._setup_mounts(thread_id)
+ storage = None
+ if self._requires_volume_bootstrap():
+ # @@@volume-strategy-gate - remote runtimes need volume mount/sync before first command.
+ storage = self._setup_mounts(thread_id)
self._ensure_bound_instance(lease)
@@ -375,7 +522,7 @@ def get_sandbox(self, thread_id: str, bind_mounts: list | None = None) -> Sandbo
lease=lease,
)
- if instance:
+ if instance and storage is not None:
# @@@workspace-upload - sync files to sandbox after creation
self._sync_to_sandbox(thread_id, instance.instance_id, source=storage["source"])
self._fire_session_ready(instance.instance_id, "create")
@@ -518,15 +665,26 @@ def enforce_idle_timeouts(self) -> int:
if self._lease_is_busy(lease.lease_id):
continue
status = lease.refresh_instance_status(self.provider)
- # Only pause remote providers (local sandbox doesn't need pause)
+ capability = self.provider.get_capability()
+ # @@@idle-reaper-reclaim-contract - idle timeout must reclaim remote resources; providers
+            # that cannot pause should destroy instead of repeatedly raising unsupported-operation errors.
if status == "running" and self.provider.name != "local":
try:
- paused = lease.pause_instance(self.provider, source="idle_reaper")
+ if capability.can_pause:
+ reclaimed = lease.pause_instance(self.provider, source="idle_reaper")
+ elif capability.can_destroy:
+ reclaimed = lease.destroy_instance(self.provider, source="idle_reaper") is None
+ else:
+ print(
+ f"[idle-reaper] provider {self.provider.name} cannot reclaim expired lease "
+ f"{lease.lease_id} for thread {thread_id}"
+ )
+ continue
except Exception as exc:
- print(f"[idle-reaper] failed to pause expired lease {lease.lease_id} for thread {thread_id}: {exc}")
+ print(f"[idle-reaper] failed to reclaim expired lease {lease.lease_id} for thread {thread_id}: {exc}")
continue
- if not paused:
- print(f"[idle-reaper] failed to pause expired lease {lease.lease_id} for thread {thread_id}")
+ if not reclaimed:
+ print(f"[idle-reaper] failed to reclaim expired lease {lease.lease_id} for thread {thread_id}")
continue
self.session_manager.delete(session_id, reason="idle_timeout")
@@ -596,6 +754,10 @@ def resume_session(self, thread_id: str, source: str = "user_resume") -> bool:
for terminal in terminals:
session = self.session_manager.get(thread_id, terminal.terminal_id)
if session:
+ session.lease = lease
+ runtime = getattr(session, "runtime", None)
+ if runtime is not None:
+ runtime.lease = lease
self.session_manager.resume(session.session_id)
resumed_any = True
diff --git a/sandbox/provider.py b/sandbox/provider.py
index fc298afed..d96524206 100644
--- a/sandbox/provider.py
+++ b/sandbox/provider.py
@@ -260,6 +260,10 @@ def delete_managed_volume(self, backend_ref: str) -> None:
"""Delete provider-managed persistent volume."""
raise NotImplementedError(f"{self.name} does not support managed volumes")
+ def wait_managed_volume_ready(self, backend_ref: str) -> None:
+ """Block until a previously created managed volume is reusable."""
+ return None
+
def set_thread_bind_mounts(self, thread_id: str, mounts: list) -> None:
"""Set per-thread bind mounts for next create_session(). No-op for providers without mount support."""
pass
diff --git a/sandbox/providers/agentbay.py b/sandbox/providers/agentbay.py
index 4f3e7c996..95e16da05 100644
--- a/sandbox/providers/agentbay.py
+++ b/sandbox/providers/agentbay.py
@@ -6,9 +6,14 @@
from __future__ import annotations
+import json
+import time
from dataclasses import replace
+from types import SimpleNamespace
from typing import TYPE_CHECKING, Any
+import requests
+
from sandbox.provider import (
Metrics,
ProviderCapability,
@@ -100,7 +105,7 @@ def create_session(self, context_id: str | None = None, thread_id: str | None =
if not result.success:
raise RuntimeError(f"Failed to create session: {result.error_message}")
- session = result.session
+ session = self._hydrate_direct_call_session(result.session)
self._sessions[session.session_id] = session
return SessionInfo(
@@ -111,7 +116,10 @@ def create_session(self, context_id: str | None = None, thread_id: str | None =
def destroy_session(self, session_id: str, sync: bool = True) -> bool:
session = self._get_session(session_id)
- result = session.delete(sync_context=sync)
+ # @@@agentbay-destroy-without-pause - some AgentBay account tiers wire delete(sync_context=True)
+ # through pause/sync first; when pause is unsupported, destroy must skip sync_context entirely.
+ effective_sync = sync and self.get_capability().can_pause
+ result = session.delete(sync_context=effective_sync)
if result.success:
self._sessions.pop(session_id, None)
return result.success
@@ -161,17 +169,65 @@ def execute(
) -> ProviderExecResult:
session = self._get_session(session_id)
timeout_ms = min(timeout_ms, 50000)
-
- result = session.command.execute_command(
- command=command,
- timeout_ms=timeout_ms,
- cwd=cwd or self.default_context_path,
+ exec_args = {
+ "command": command,
+ "timeout_ms": timeout_ms,
+ "cwd": cwd or self.default_context_path,
+ }
+ shell_server = self._resolve_shell_server(session)
+ session_tools = getattr(session, "mcpTools", None) or getattr(session, "mcp_tools", None) or []
+ print(
+ "[AgentBay.execute] "
+ f"session_id={session_id} "
+ f"has_link_url={bool(getattr(session, 'link_url', ''))} "
+ f"has_token={bool(getattr(session, 'token', ''))} "
+ f"shell_server={shell_server!r} "
+ f"tool_count={len(session_tools)} "
+ f"timeout_ms={timeout_ms}",
+ flush=True,
)
- if not result.success:
- return ProviderExecResult(output="", error=result.error_message)
+ if getattr(session, "link_url", "") and getattr(session, "token", "") and shell_server:
+ # @@@agentbay-shell-link-route - shared staging proved shell can degrade into the API path
+ # despite hydrated direct-call metadata; take the explicit LinkUrl route when shell server is known.
+ result = self._call_link_url_tool(session, "shell", exec_args, shell_server)
+ print(
+ "[AgentBay.execute] "
+ f"session_id={session_id} path=link_url exit_code={result.exit_code} "
+ f"error={result.error!r} output_len={len(result.output or '')}",
+ flush=True,
+ )
+ return result
- return ProviderExecResult(output=result.output or "")
+ print(f"[AgentBay.execute] session_id={session_id} path=sdk_command_execute", flush=True)
+ try:
+ result = session.command.execute_command(**exec_args)
+ except Exception as exc:
+ print(
+ f"[AgentBay.execute] session_id={session_id} path=sdk_command_execute raised={exc.__class__.__name__}: {exc}",
+ flush=True,
+ )
+ raise
+
+ if not result.success:
+ print(
+ "[AgentBay.execute] "
+ f"session_id={session_id} path=sdk_command_execute success=False "
+ f"exit_code={getattr(result, 'exit_code', None)} "
+ f"error={getattr(result, 'error_message', None)!r} "
+ f"output_len={len(getattr(result, 'output', '') or '')}",
+ flush=True,
+ )
+ return ProviderExecResult(output=result.output or "", exit_code=result.exit_code or 1, error=result.error_message)
+
+ print(
+ "[AgentBay.execute] "
+ f"session_id={session_id} path=sdk_command_execute success=True "
+ f"exit_code={getattr(result, 'exit_code', None)} "
+ f"output_len={len(getattr(result, 'output', '') or '')}",
+ flush=True,
+ )
+ return ProviderExecResult(output=result.output or "", exit_code=result.exit_code or 0)
def read_file(self, session_id: str, path: str) -> str:
session = self._get_session(session_id)
@@ -246,7 +302,168 @@ def _get_session(self, session_id: str):
if not result.success:
raise RuntimeError(f"Session not found: {session_id}")
self._sessions[session_id] = result.session
- return self._sessions[session_id]
+ cached = self._sessions[session_id]
+ hydrated = self._hydrate_direct_call_session(cached)
+ self._sessions[session_id] = hydrated
+ return hydrated
+
+ def _hydrate_direct_call_session(self, session: Any):
+ """Ensure cached session carries LinkUrl/token/tool metadata for direct shell calls."""
+ if not self._session_needs_direct_call_refresh(session):
+ return session
+ session_id = str(getattr(session, "session_id", "") or "")
+ if not session_id:
+ raise RuntimeError("AgentBay session missing session_id")
+ refreshed = self.client.get(session_id)
+ if not refreshed.success:
+ raise RuntimeError(f"Failed to hydrate AgentBay session {session_id}: {refreshed.error_message}")
+ hydrated = refreshed.session
+ if self._session_needs_direct_call_refresh(hydrated):
+ metadata = self._fetch_direct_call_metadata(session_id)
+ self._apply_direct_call_metadata(hydrated, metadata)
+ return hydrated
+
+ @staticmethod
+ def _resolve_shell_server(session: Any) -> str | None:
+ for resolver_name in ("_get_mcp_server_for_tool", "_find_server_for_tool"):
+ resolver = getattr(session, resolver_name, None)
+ if callable(resolver):
+ try:
+ server_name = resolver("shell")
+ except Exception:
+ continue
+ if server_name:
+ return str(server_name)
+ for tools_attr in ("mcpTools", "mcp_tools"):
+ tools = getattr(session, tools_attr, None) or []
+ for tool in tools:
+ if getattr(tool, "name", None) == "shell":
+ server_name = getattr(tool, "server", "") or ""
+ if server_name:
+ return str(server_name)
+ return None
+
+ @staticmethod
+ def _provider_exec_result_from_tool_result(tool_result: Any) -> ProviderExecResult:
+ if not getattr(tool_result, "success", False):
+ error_message = getattr(tool_result, "error_message", "") or "Failed to execute command"
+ return ProviderExecResult(output="", exit_code=1, error=error_message)
+ data = getattr(tool_result, "data", "")
+ try:
+ payload = json.loads(data) if isinstance(data, str) else data
+ except json.JSONDecodeError:
+ payload = None
+ if isinstance(payload, dict):
+ stdout = str(payload.get("stdout", "") or "")
+ stderr = str(payload.get("stderr", "") or "")
+ exit_code = int(payload.get("exit_code", 0) or 0)
+ error = stderr or None
+ return ProviderExecResult(output=stdout + stderr, exit_code=exit_code, error=error)
+ return ProviderExecResult(output=str(data or ""), exit_code=0)
+
+ def _call_link_url_tool(
+ self,
+ session: Any,
+ tool_name: str,
+ args: dict[str, Any],
+ server_name: str,
+ ) -> ProviderExecResult:
+ link_url = str(getattr(session, "link_url", "") or "")
+ token = str(getattr(session, "token", "") or "")
+ if not link_url or not token:
+ return ProviderExecResult(output="", exit_code=1, error="LinkUrl/token not available")
+
+ try:
+ response = requests.post(
+ link_url.rstrip("/") + "/callTool",
+ json={
+ "args": args,
+ "server": server_name,
+ "requestId": f"link-{int(time.time() * 1000)}",
+ "tool": tool_name,
+ "token": token,
+ },
+ headers={
+ "Content-Type": "application/json",
+ "X-Access-Token": token,
+ },
+ timeout=max(int(args.get("timeout_ms", 30000) or 30000) / 1000.0, 30.0),
+ )
+ except requests.RequestException as exc:
+ return ProviderExecResult(output="", exit_code=1, error=f"HTTP request failed: {exc}")
+ if response.status_code < 200 or response.status_code >= 300:
+ return ProviderExecResult(output="", exit_code=1, error=f"HTTP request failed with code: {response.status_code}")
+
+ outer = response.json()
+ data_field = outer.get("data")
+ if data_field is None:
+ return ProviderExecResult(output="", exit_code=1, error="No data field in LinkUrl response")
+ parsed_data = json.loads(data_field) if isinstance(data_field, str) else data_field
+ if not isinstance(parsed_data, dict):
+ return ProviderExecResult(output="", exit_code=1, error="Invalid data field type in LinkUrl response")
+
+ result_field = parsed_data.get("result", {})
+ if not isinstance(result_field, dict):
+ return ProviderExecResult(output="", exit_code=1, error="No result field in LinkUrl response data")
+
+ content = result_field.get("content", [])
+ text_content = ""
+ if isinstance(content, list) and content:
+ first = content[0]
+ if isinstance(first, str):
+ text_content = first
+ elif isinstance(first, dict):
+ text_content = str(first.get("text") or first.get("blob") or first.get("data") or "")
+ elif isinstance(content, str):
+ text_content = content
+
+ if result_field.get("isError", False):
+ error_message = text_content or json.dumps(result_field, ensure_ascii=False)
+ return ProviderExecResult(output="", exit_code=1, error=error_message)
+
+ return self._provider_exec_result_from_tool_result(SimpleNamespace(success=True, data=text_content, error_message=""))
+
+ @staticmethod
+ def _session_needs_direct_call_refresh(session: Any) -> bool:
+ # @@@agentbay-direct-call-hydration - shared staging may return a create-session object
+ # without token/link_url/mcpTools; refresh once so shell execution stays on the richer LinkUrl path.
+ if not getattr(session, "token", ""):
+ return True
+ if not getattr(session, "link_url", ""):
+ return True
+ tools = getattr(session, "mcpTools", None) or getattr(session, "mcp_tools", None)
+ return not bool(tools)
+
+ def _fetch_direct_call_metadata(self, session_id: str) -> dict[str, Any]:
+ from agentbay.api.models import GetSessionRequest
+
+ # @@@agentbay-raw-get-session - the SDK Session object drops LinkUrl/ToolList for this account tier,
+ # but the raw GetSession response still carries them. Pull that response directly and patch the session.
+ request = GetSessionRequest(authorization=f"Bearer {self.client.api_key}", session_id=session_id)
+ response = self.client.client.get_session(request)
+ body = response.to_map().get("body", {})
+ data = body.get("Data", {}) or {}
+ return {
+ "link_url": data.get("LinkUrl", "") or "",
+ "token": data.get("Token", "") or "",
+ "mcp_tools": [
+ SimpleNamespace(name=str(tool.get("Name", "") or ""), server=str(tool.get("Server", "") or ""))
+ for tool in (data.get("ToolList", []) or [])
+ ],
+ }
+
+ @staticmethod
+ def _apply_direct_call_metadata(session: Any, metadata: dict[str, Any]) -> None:
+ link_url = str(metadata.get("link_url", "") or "")
+ if link_url:
+ setattr(session, "link_url", link_url)
+ token = str(metadata.get("token", "") or "")
+ if token:
+ setattr(session, "token", token)
+ tools = metadata.get("mcp_tools", []) or []
+ if tools:
+ setattr(session, "mcp_tools", tools)
+ setattr(session, "mcpTools", tools)
def create_runtime(self, terminal: AbstractTerminal, lease: SandboxLease) -> PhysicalTerminalRuntime:
from sandbox.runtime import RemoteWrappedRuntime
diff --git a/sandbox/providers/daytona.py b/sandbox/providers/daytona.py
index def0f865f..f314d5621 100644
--- a/sandbox/providers/daytona.py
+++ b/sandbox/providers/daytona.py
@@ -15,6 +15,7 @@
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any
+from urllib.parse import urlparse, urlunparse
import httpx
@@ -107,6 +108,13 @@ def __init__(
os.environ["DAYTONA_API_KEY"] = api_key
os.environ["DAYTONA_API_URL"] = api_url
self.client = Daytona()
+ original_get_proxy_toolbox_url = self.client._get_proxy_toolbox_url
+
+ def _wrapped_get_proxy_toolbox_url(sandbox_id: str, region_id: str) -> str:
+ raw_url = original_get_proxy_toolbox_url(sandbox_id, region_id)
+ return self._normalize_toolbox_proxy_url(raw_url)
+
+ self.client._get_proxy_toolbox_url = _wrapped_get_proxy_toolbox_url
self._sandboxes: dict[str, Any] = {}
self._thread_bind_mounts: dict[str, list[MountSpec]] = {} # thread_id -> bind_mounts
self._volume_mounts: dict[str, tuple[str, str]] = {} # thread_id -> (volume_id, mount_path)
@@ -123,13 +131,17 @@ def create_managed_volume(self, member_id: str, mount_path: str) -> str:
logger.info("Creating managed volume: %s", volume_name)
# @@@volume-ready - volume transitions pending_create → ready (~6s)
self.client.volume.create(volume_name)
+ self.wait_managed_volume_ready(volume_name)
+ return volume_name
+
+ def wait_managed_volume_ready(self, backend_ref: str) -> None:
for _ in range(30):
- vol = self.client.volume.get(volume_name)
+ vol = self.client.volume.get(backend_ref)
if vol.state == "ready":
- logger.info("Managed volume ready: %s (id=%s)", volume_name, vol.id)
- return volume_name
+ logger.info("Managed volume ready: %s (id=%s)", backend_ref, vol.id)
+ return
time.sleep(1)
- raise RuntimeError(f"Volume {volume_name} did not become ready within 30s")
+ raise RuntimeError(f"Volume {backend_ref} did not become ready within 30s")
def set_managed_volume_mount(self, thread_id: str, backend_ref: str, mount_path: str) -> None:
self._volume_mounts[thread_id] = (backend_ref, mount_path)
@@ -390,6 +402,19 @@ def _get_sandbox(self, session_id: str):
self._sandboxes[session_id] = self.client.find_one(session_id)
return self._sandboxes[session_id]
+ def _normalize_toolbox_proxy_url(self, raw_url: str) -> str:
+ api_host = (urlparse(self.api_url).hostname or "").lower()
+ if api_host not in {"localhost", "127.0.0.1"}:
+ return raw_url
+
+ parsed = urlparse(raw_url)
+ if (parsed.hostname or "").lower() != "172.18.0.1":
+ return raw_url
+
+ # @@@local-toolbox-loopback - self-host Daytona local dev reaches toolbox through
+ # the SSH-forwarded loopback proxy on :4000, not the server-side docker bridge gateway.
+ return urlunparse(parsed._replace(netloc=f"127.0.0.1:{parsed.port or 4000}"))
+
def get_runtime_sandbox(self, session_id: str):
"""Expose native SDK sandbox for runtime-level persistent terminal handling."""
return self._get_sandbox(session_id)
diff --git a/sandbox/providers/e2b.py b/sandbox/providers/e2b.py
index 5827b124b..482f66cdf 100644
--- a/sandbox/providers/e2b.py
+++ b/sandbox/providers/e2b.py
@@ -68,6 +68,10 @@ def __init__(
timeout: int = 300,
provider_name: str | None = None,
):
+ # @@@e2b-sdk-presence - staging inventory must fail loudly when the SDK is absent,
+ # otherwise provider catalog/create-thread gates can overclaim e2b availability.
+ from e2b import Sandbox # noqa: F401
+
if provider_name:
self.name = provider_name
self.api_key = api_key
@@ -88,6 +92,16 @@ def create_session(self, context_id: str | None = None, thread_id: str | None =
api_key=self.api_key,
)
self._sandboxes[sandbox.sandbox_id] = sandbox
+ # @@@e2b-workspace-bootstrap - fresh E2B sandboxes do not guarantee our sync root exists.
+ # Create it eagerly so upload/download and file hints target a real path contract.
+ bootstrap = sandbox.commands.run(
+ f"mkdir -p {self.WORKSPACE_ROOT}/files",
+ cwd=self.default_cwd,
+ timeout=10,
+ )
+ if getattr(bootstrap, "exit_code", 0) != 0:
+ error = getattr(bootstrap, "stderr", "") or getattr(bootstrap, "stdout", "") or "unknown error"
+ raise RuntimeError(f"Failed to bootstrap E2B workspace root: {error}")
return SessionInfo(
session_id=sandbox.sandbox_id,
diff --git a/sandbox/providers/local.py b/sandbox/providers/local.py
index a8c6c6f02..b5766b9c9 100644
--- a/sandbox/providers/local.py
+++ b/sandbox/providers/local.py
@@ -7,6 +7,7 @@
import shlex
import subprocess
import threading
+import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
@@ -171,6 +172,12 @@ def list_dir(self, session_id: str, path: str) -> list[dict]:
return items
def get_metrics(self, session_id: str) -> Metrics | None:
+ if platform.system() == "Linux":
+ metrics = self._get_metrics_via_procfs()
+ if metrics is not None:
+ return metrics
+ return self.get_metrics_via_commands(session_id)
+
if platform.system() != "Darwin":
return self.get_metrics_via_commands(session_id)
@@ -222,6 +229,59 @@ def get_metrics(self, session_id: str) -> Metrics | None:
except Exception:
return None
+ def _get_metrics_via_procfs(self) -> Metrics | None:
+ try:
+ cpu_percent = self._sample_linux_cpu_percent()
+
+ meminfo: dict[str, int] = {}
+ with open("/proc/meminfo") as fh:
+ for line in fh:
+ key, _, raw = line.partition(":")
+ value = raw.strip().split()[0] if raw.strip() else ""
+ if value.isdigit():
+ meminfo[key] = int(value)
+
+ total_kb = meminfo.get("MemTotal")
+ available_kb = meminfo.get("MemAvailable")
+ memory_total_mb = (total_kb / 1024.0) if total_kb is not None else None
+ memory_used_mb = ((total_kb - available_kb) / 1024.0) if total_kb is not None and available_kb is not None else None
+
+ stat = os.statvfs("/")
+ total_bytes = stat.f_blocks * stat.f_frsize
+ free_bytes = stat.f_bavail * stat.f_frsize
+ disk_total_gb = total_bytes / (1024.0**3)
+ disk_used_gb = (total_bytes - free_bytes) / (1024.0**3)
+
+ return Metrics(
+ cpu_percent=cpu_percent,
+ memory_used_mb=memory_used_mb,
+ memory_total_mb=memory_total_mb,
+ disk_used_gb=disk_used_gb,
+ disk_total_gb=disk_total_gb,
+ )
+ except Exception:
+ return None
+
+ def _sample_linux_cpu_percent(self) -> float | None:
+ first_total, first_idle = self._read_linux_cpu_totals()
+ time.sleep(0.1)
+ second_total, second_idle = self._read_linux_cpu_totals()
+ total_delta = second_total - first_total
+ idle_delta = second_idle - first_idle
+ if total_delta <= 0:
+ return None
+ busy_delta = total_delta - idle_delta
+ return max(0.0, min(100.0, (busy_delta / total_delta) * 100.0))
+
+ def _read_linux_cpu_totals(self) -> tuple[int, int]:
+ with open("/proc/stat") as fh:
+ first = fh.readline().strip()
+ parts = first.split()
+ values = [int(value) for value in parts[1:9]]
+ total = sum(values)
+ idle = values[3] + values[4]
+ return total, idle
+
def create_runtime(self, terminal: AbstractTerminal, lease: SandboxLease) -> PhysicalTerminalRuntime:
from sandbox.providers.local import LocalPersistentShellRuntime
diff --git a/sandbox/runtime.py b/sandbox/runtime.py
index 87cecd024..d68a747ff 100644
--- a/sandbox/runtime.py
+++ b/sandbox/runtime.py
@@ -749,6 +749,8 @@ def _looks_like_infra_error(text: str) -> bool:
"no such session",
"session does not exist",
"failed to create pty session",
+ "failed to send input to pty",
+ "pty control error",
"no ip address found",
"is the sandbox started",
"is paused",
@@ -758,6 +760,9 @@ def _looks_like_infra_error(text: str) -> bool:
"websocket",
"close frame",
"no close frame",
+ "internal error",
+ "1011",
+ "broken pipe",
"transport",
"unreachable",
"timed out",
@@ -806,6 +811,17 @@ def _execute_once(self, command: str, timeout: float | None = None) -> ExecuteRe
instance = self.lease.ensure_active_instance(self.provider)
state = self.terminal.get_state()
timeout_ms = int(timeout * 1000) if timeout else 30000
+ print(
+ "[RemoteWrappedRuntime._execute_once] "
+ f"thread_id={self.terminal.thread_id} "
+ f"lease_id={self.lease.lease_id} "
+ f"instance_id={instance.instance_id} "
+ f"provider={getattr(self.provider, 'name', '?')} "
+ f"cwd={state.cwd!r} "
+ f"timeout_ms={timeout_ms} "
+ f"command={command[:200]!r}",
+ flush=True,
+ )
# @@@ _build_state_snapshot_cmd returns (start, end, cmd) but RemoteWrappedRuntime
# builds its own inline block to interleave cd/exports/command, so the pre-built cmd is unused.
start_marker, end_marker, _ = _build_state_snapshot_cmd()
@@ -832,14 +848,32 @@ def _execute_once(self, command: str, timeout: float | None = None) -> ExecuteRe
cwd=state.cwd,
)
raw_output = result.output or ""
-
- new_cwd, env_map, raw_output = _extract_state_from_output(
- raw_output,
- start_marker,
- end_marker,
- cwd_fallback=state.cwd,
- env_fallback=state.env_delta,
+ print(
+ "[RemoteWrappedRuntime._execute_once] "
+ f"thread_id={self.terminal.thread_id} "
+ f"provider_exit={result.exit_code} "
+ f"provider_error={result.error!r} "
+ f"output_len={len(raw_output)}",
+ flush=True,
)
+
+ try:
+ new_cwd, env_map, raw_output = _extract_state_from_output(
+ raw_output,
+ start_marker,
+ end_marker,
+ cwd_fallback=state.cwd,
+ env_fallback=state.env_delta,
+ )
+ except Exception as exc:
+ print(
+ "[RemoteWrappedRuntime._execute_once] "
+ f"thread_id={self.terminal.thread_id} "
+ f"state_parse_failed={exc.__class__.__name__}: {exc} "
+ f"raw_output_preview={raw_output[:400]!r}",
+ flush=True,
+ )
+ raise
from sandbox.terminal import TerminalState
self.update_terminal_state(TerminalState(cwd=new_cwd, env_delta=env_map))
diff --git a/sandbox/thread_context.py b/sandbox/thread_context.py
index d52ba7ef1..d98e9895c 100644
--- a/sandbox/thread_context.py
+++ b/sandbox/thread_context.py
@@ -3,10 +3,14 @@
from __future__ import annotations
from contextvars import ContextVar
+from typing import Any
_current_thread_id: ContextVar[str] = ContextVar("sandbox_thread_id", default="")
# @@@run-context - groups file ops per execution unit: checkpoint_id in TUI, run_id in web mode.
_current_run_id: ContextVar[str] = ContextVar("sandbox_run_id", default="")
+# Parent conversation messages — set by QueryLoop before tool execution; read by AgentService
+# for forkContext=True sub-agent spawning.
+_current_messages: ContextVar[list[Any]] = ContextVar("current_messages", default=[])
def set_current_thread_id(thread_id: str) -> None:
@@ -25,3 +29,11 @@ def set_current_run_id(run_id: str) -> None:
def get_current_run_id() -> str | None:
value = _current_run_id.get()
return value if value else None
+
+
+def set_current_messages(messages: list[Any]) -> None:
+ _current_messages.set(list(messages))
+
+
+def get_current_messages() -> list[Any]:
+ return _current_messages.get()
diff --git a/storage/providers/sqlite/agent_registry_repo.py b/storage/providers/sqlite/agent_registry_repo.py
index 02aa62aeb..a9a2c0e87 100644
--- a/storage/providers/sqlite/agent_registry_repo.py
+++ b/storage/providers/sqlite/agent_registry_repo.py
@@ -59,11 +59,27 @@ def get_by_id(self, agent_id: str) -> tuple | None:
(agent_id,),
).fetchone()
+ def list_running_by_name(self, name: str) -> list[tuple]:
+ with self._conn() as conn:
+ return conn.execute(
+ "SELECT agent_id, name, thread_id, status, parent_agent_id, subagent_type "
+ "FROM agents WHERE name=? AND status='running' ORDER BY created_at DESC, agent_id DESC",
+ (name,),
+ ).fetchall()
+
def update_status(self, agent_id: str, status: str) -> None:
with self._conn() as conn:
conn.execute("UPDATE agents SET status=? WHERE agent_id=?", (status, agent_id))
conn.commit()
+ def get_latest_by_name_and_parent(self, name: str, parent_agent_id: str | None) -> tuple | None:
+ with self._conn() as conn:
+ return conn.execute(
+ "SELECT agent_id, name, thread_id, status, parent_agent_id, subagent_type "
+ "FROM agents WHERE name=? AND parent_agent_id IS ? ORDER BY created_at DESC, agent_id DESC LIMIT 1",
+ (name, parent_agent_id),
+ ).fetchone()
+
def list_running(self) -> list[tuple]:
with self._conn() as conn:
return conn.execute(
diff --git a/storage/providers/sqlite/chat_repo.py b/storage/providers/sqlite/chat_repo.py
index f761c6e5a..37ca68ad7 100644
--- a/storage/providers/sqlite/chat_repo.py
+++ b/storage/providers/sqlite/chat_repo.py
@@ -172,6 +172,12 @@ def _ensure_table(self) -> None:
)
"""
)
+ # @@@entity-id-to-user-id-migration - old chat dbs still used entity_id.
+ # Rename first so later index creation does not explode on missing user_id.
+ try:
+ self._conn.execute("ALTER TABLE chat_entities RENAME COLUMN entity_id TO user_id")
+ except sqlite3.OperationalError:
+ pass # column already named user_id, or table is new
# @@@chat-entity-migration - add muted/mute_until if table already exists
try:
self._conn.execute("ALTER TABLE chat_entities ADD COLUMN muted INTEGER NOT NULL DEFAULT 0")
@@ -183,11 +189,6 @@ def _ensure_table(self) -> None:
pass
# @@@chat-entity-index — speeds up find_chat_between and list_chats_for_user
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_chat_entities_user ON chat_entities(user_id, chat_id)")
- # @@@entity-id-to-user-id-migration — rename column for existing databases
- try:
- self._conn.execute("ALTER TABLE chat_entities RENAME COLUMN entity_id TO user_id")
- except sqlite3.OperationalError:
- pass # column already named user_id, or table is new
self._conn.commit()
diff --git a/storage/providers/sqlite/lease_repo.py b/storage/providers/sqlite/lease_repo.py
index f0ab745c9..de9f7663e 100644
--- a/storage/providers/sqlite/lease_repo.py
+++ b/storage/providers/sqlite/lease_repo.py
@@ -250,6 +250,20 @@ def mark_needs_refresh(self, lease_id: str, hint_at: datetime | None = None) ->
self._conn.commit()
return cursor.rowcount > 0
+ def set_volume_id(self, lease_id: str, volume_id: str) -> bool:
+ with self._lock:
+ cursor = self._conn.execute(
+ """
+ UPDATE sandbox_leases
+ SET volume_id = ?,
+ updated_at = ?
+ WHERE lease_id = ?
+ """,
+ (volume_id, datetime.now().isoformat(), lease_id),
+ )
+ self._conn.commit()
+ return cursor.rowcount > 0
+
def delete(self, lease_id: str) -> None:
with self._lock:
self._conn.execute("DELETE FROM sandbox_instances WHERE lease_id = ?", (lease_id,))
diff --git a/storage/providers/supabase/agent_registry_repo.py b/storage/providers/supabase/agent_registry_repo.py
index 8aaccd1d0..31bca5506 100644
--- a/storage/providers/supabase/agent_registry_repo.py
+++ b/storage/providers/supabase/agent_registry_repo.py
@@ -55,6 +55,22 @@ def get_by_id(self, agent_id: str) -> tuple | None:
def update_status(self, agent_id: str, status: str) -> None:
self._table().update({"status": status}).eq("agent_id", agent_id).execute()
+ def get_latest_by_name_and_parent(self, name: str, parent_agent_id: str | None) -> tuple | None:
+ query = self._table().select("agent_id,name,thread_id,status,parent_agent_id,subagent_type").eq("name", name)
+ if parent_agent_id is None:
+ query = query.is_("parent_agent_id", "null")
+ else:
+ query = query.eq("parent_agent_id", parent_agent_id)
+ rows = q.rows(
+ query.order("created_at", desc=True).limit(1).execute(),
+ _REPO,
+ "get_latest_by_name_and_parent",
+ )
+ if not rows:
+ return None
+ r = rows[0]
+ return (r["agent_id"], r["name"], r["thread_id"], r["status"], r.get("parent_agent_id"), r.get("subagent_type"))
+
def list_running(self) -> list[tuple]:
rows = q.rows(
self._table().select("agent_id,name,thread_id,status,parent_agent_id,subagent_type").eq("status", "running").execute(),
diff --git a/storage/providers/supabase/chat_repo.py b/storage/providers/supabase/chat_repo.py
index d0cfaa0ab..401fb3726 100644
--- a/storage/providers/supabase/chat_repo.py
+++ b/storage/providers/supabase/chat_repo.py
@@ -212,6 +212,16 @@ def count_unread(self, chat_id: str, user_id: str) -> int:
raw = q.rows(response, _REPO_MSG, "count_unread")
return len(raw)
+ def has_unread_mention(self, chat_id: str, user_id: str) -> bool:
+ resp_ce = self._client.table(_TABLE_CHAT_ENTITIES).select("last_read_at").eq("chat_id", chat_id).eq("user_id", user_id).execute()
+ ce_rows = q.rows(resp_ce, _REPO_MSG, "has_unread_mention(last_read_at)")
+ if not ce_rows:
+ return False
+ for message in self.list_unread(chat_id, user_id):
+ if user_id in message.mentioned_ids:
+ return True
+ return False
+
def list_by_time_range(
self,
chat_id: str,
diff --git a/storage/providers/supabase/entity_repo.py b/storage/providers/supabase/entity_repo.py
index cb2e0dc84..b4ecc1dc7 100644
--- a/storage/providers/supabase/entity_repo.py
+++ b/storage/providers/supabase/entity_repo.py
@@ -43,6 +43,13 @@ def get_by_member_id(self, member_id: str) -> list[EntityRow]:
rows = q.rows(response, _REPO, "get_by_member_id")
return [EntityRow.model_validate(r) for r in rows]
+ def get_by_thread_id(self, thread_id: str) -> EntityRow | None:
+ response = self._t().select("*").eq("thread_id", thread_id).execute()
+ rows = q.rows(response, _REPO, "get_by_thread_id")
+ if not rows:
+ return None
+ return EntityRow.model_validate(rows[0])
+
def list_all(self) -> list[EntityRow]:
query = q.order(self._t().select("*"), "created_at", desc=False, repo=_REPO, operation="list_all")
rows = q.rows(query.execute(), _REPO, "list_all")
diff --git a/storage/runtime.py b/storage/runtime.py
index 0a2d1b394..a522fe3da 100644
--- a/storage/runtime.py
+++ b/storage/runtime.py
@@ -59,6 +59,68 @@ def build_storage_container(
)
+def build_thread_repo(
+ *,
+ main_db_path: str | Path | None = None,
+ strategy: str | None = None,
+ supabase_client: Any | None = None,
+ supabase_client_factory: str | None = None,
+ env: Mapping[str, str] | None = None,
+):
+ env_map = env if env is not None else os.environ
+ resolved_strategy = _resolve_strategy(strategy if strategy is not None else env_map.get("LEON_STORAGE_STRATEGY"))
+ if resolved_strategy == "supabase":
+ client = supabase_client
+ if client is None:
+ factory_ref = supabase_client_factory if supabase_client_factory is not None else env_map.get("LEON_SUPABASE_CLIENT_FACTORY")
+ if not factory_ref:
+ raise RuntimeError(
+ "Supabase thread repo requires runtime config. "
+ "Set LEON_SUPABASE_CLIENT_FACTORY=<module>:<attribute> "
+ "or inject supabase_client explicitly."
+ )
+ client = _load_factory(factory_ref)()
+ _ensure_supabase_client(client)
+ from storage.providers.supabase.thread_repo import SupabaseThreadRepo
+
+ return SupabaseThreadRepo(client)
+
+ from storage.providers.sqlite.thread_repo import SQLiteThreadRepo
+
+ return SQLiteThreadRepo(db_path=main_db_path)
+
+
+def build_member_repo(
+ *,
+ main_db_path: str | Path | None = None,
+ strategy: str | None = None,
+ supabase_client: Any | None = None,
+ supabase_client_factory: str | None = None,
+ env: Mapping[str, str] | None = None,
+):
+ env_map = env if env is not None else os.environ
+ resolved_strategy = _resolve_strategy(strategy if strategy is not None else env_map.get("LEON_STORAGE_STRATEGY"))
+ if resolved_strategy == "supabase":
+ client = supabase_client
+ if client is None:
+ factory_ref = supabase_client_factory if supabase_client_factory is not None else env_map.get("LEON_SUPABASE_CLIENT_FACTORY")
+ if not factory_ref:
+ raise RuntimeError(
+ "Supabase member repo requires runtime config. "
+ "Set LEON_SUPABASE_CLIENT_FACTORY=<module>:<attribute> "
+ "or inject supabase_client explicitly."
+ )
+ client = _load_factory(factory_ref)()
+ _ensure_supabase_client(client)
+ from storage.providers.supabase.member_repo import SupabaseMemberRepo
+
+ return SupabaseMemberRepo(client)
+
+ from storage.providers.sqlite.member_repo import SQLiteMemberRepo
+
+ return SQLiteMemberRepo(db_path=main_db_path)
+
+
def _resolve_strategy(raw: str | None) -> StorageStrategy:
value = (raw or "sqlite").strip().lower()
if value in {"", "sqlite"}:
diff --git a/tests/config/conftest.py b/tests/Config/conftest.py
similarity index 100%
rename from tests/config/conftest.py
rename to tests/Config/conftest.py
diff --git a/tests/config/test_loader.py b/tests/Config/test_loader.py
similarity index 74%
rename from tests/config/test_loader.py
rename to tests/Config/test_loader.py
index f3671fa09..c0874f38d 100644
--- a/tests/config/test_loader.py
+++ b/tests/Config/test_loader.py
@@ -1,11 +1,13 @@
"""Comprehensive tests for config.loader module."""
+import json
import os
import sys
+from pathlib import Path
import pytest
-from config.loader import ConfigLoader, load_config
+from config.loader import AgentLoader, ConfigLoader, load_config
from config.schema import LeonSettings
@@ -157,6 +159,27 @@ def test_expand_env_vars_nested(self):
assert result["paths"] == ["/base/path1", "/base/path2"]
assert result["config"]["root"] == "/base"
+ def test_discover_mcp_preserves_explicit_transport(self, tmp_path):
+ path = tmp_path / ".mcp.json"
+ path.write_text(
+ json.dumps(
+ {
+ "mcpServers": {
+ "wsdemo": {
+ "transport": "websocket",
+ "url": "ws://example.test/mcp",
+ }
+ }
+ }
+ ),
+ encoding="utf-8",
+ )
+
+ result = ConfigLoader._discover_mcp(tmp_path)
+
+ assert result["wsdemo"].transport == "websocket"
+ assert result["wsdemo"].url == "ws://example.test/mcp"
+
class TestLoadConfigFunction:
"""Tests for load_config convenience function."""
@@ -169,3 +192,32 @@ def test_load_config_with_workspace(self, tmp_path, monkeypatch):
settings = load_config(workspace_root=str(project_dir))
assert isinstance(settings, LeonSettings)
+
+
+def test_project_agent_file_does_not_claim_bundle_source_dir(tmp_path: Path):
+ agents_dir = tmp_path / ".leon" / "agents"
+ agents_dir.mkdir(parents=True)
+ (agents_dir / "explore.md").write_text(
+ "---\nname: explore\nmodel: project-model\n---\nproject prompt\n",
+ encoding="utf-8",
+ )
+
+ agent = AgentLoader(workspace_root=tmp_path).load_all_agents()["explore"]
+
+ assert agent.model == "project-model"
+ assert agent.source_dir is None
+
+
+def test_member_agent_retains_bundle_source_dir(tmp_path: Path, monkeypatch):
+ home_root = tmp_path
+ monkeypatch.setattr("config.loader.user_home_read_candidates", lambda *parts: (home_root.joinpath(*parts),))
+ member_dir = home_root / "members" / "alice"
+ member_dir.mkdir(parents=True)
+ (member_dir / "agent.md").write_text(
+ '---\nname: alice\ntools:\n - "*"\n---\nmember prompt\n',
+ encoding="utf-8",
+ )
+
+ agent = AgentLoader(workspace_root=tmp_path).load_all_agents()["alice"]
+
+ assert agent.source_dir == member_dir.resolve()
diff --git a/tests/config/test_loader_skill_dir_bootstrap.py b/tests/Config/test_loader_skill_dir_bootstrap.py
similarity index 100%
rename from tests/config/test_loader_skill_dir_bootstrap.py
rename to tests/Config/test_loader_skill_dir_bootstrap.py
diff --git a/tests/Fix/test_auth_entity_resolution.py b/tests/Fix/test_auth_entity_resolution.py
new file mode 100644
index 000000000..c445b566f
--- /dev/null
+++ b/tests/Fix/test_auth_entity_resolution.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+from fastapi import HTTPException
+
+from backend.web.core import dependencies
+
+
+class _Request:
+ def __init__(self, *, token: str, payload: dict, member_exists: bool = True) -> None:
+ self.headers = {"Authorization": f"Bearer {token}"}
+ self.app = SimpleNamespace(
+ state=SimpleNamespace(
+ auth_service=SimpleNamespace(verify_token=lambda seen: payload if seen == token else None),
+ member_repo=SimpleNamespace(get_by_id=lambda _user_id: object() if member_exists else None),
+ )
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_current_entity_id_derives_human_entity_when_jwt_has_no_entity_id():
+ request = _Request(token="tok-1", payload={"user_id": "user-123"})
+
+ entity_id = await dependencies.get_current_entity_id(request)
+
+ assert entity_id == "user-123-1"
+
+
+@pytest.mark.asyncio
+async def test_get_current_entity_id_keeps_explicit_entity_id_when_present():
+ request = _Request(token="tok-1", payload={"user_id": "user-123", "entity_id": "custom-entity"})
+
+ entity_id = await dependencies.get_current_entity_id(request)
+
+ assert entity_id == "custom-entity"
+
+
+@pytest.mark.asyncio
+async def test_get_current_user_id_still_rejects_deleted_user():
+ request = _Request(token="tok-1", payload={"user_id": "ghost-user"}, member_exists=False)
+
+ with pytest.raises(HTTPException) as exc_info:
+ await dependencies.get_current_user_id(request)
+
+ assert exc_info.value.status_code == 401
+ assert exc_info.value.detail == "User no longer exists — please re-login"
diff --git a/tests/Fix/test_auth_service_token_verification.py b/tests/Fix/test_auth_service_token_verification.py
new file mode 100644
index 000000000..f145b7bd6
--- /dev/null
+++ b/tests/Fix/test_auth_service_token_verification.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+from backend.web.services.auth_service import AuthService
+
+
+class _FakeSupabaseAuth:
+ def __init__(self, user_id: str = "user-1") -> None:
+ self.user_id = user_id
+ self.tokens: list[str] = []
+
+ def get_user(self, token: str):
+ self.tokens.append(token)
+ return SimpleNamespace(user=SimpleNamespace(id=self.user_id))
+
+
+class _FakeSupabaseClient:
+ def __init__(self, user_id: str = "user-1") -> None:
+ self.auth = _FakeSupabaseAuth(user_id=user_id)
+
+
+class _FakeLoginAuth:
+ def __init__(self) -> None:
+ self.calls: list[dict[str, str]] = []
+
+ def sign_in_with_password(self, payload: dict[str, str]):
+ self.calls.append(payload)
+ return SimpleNamespace(
+ user=SimpleNamespace(id="user-1"),
+ session=SimpleNamespace(access_token="tok-1"),
+ )
+
+
+class _FakeAuthClient:
+ def __init__(self) -> None:
+ self.auth = _FakeLoginAuth()
+
+
+class _FactoryBackedLoginAuth:
+ def __init__(self, owner: _FactoryBackedAuthClient) -> None:
+ self._owner = owner
+
+ def sign_in_with_password(self, payload: dict[str, str]):
+ self._owner.calls.append(payload)
+ return SimpleNamespace(
+ user=SimpleNamespace(id="user-1"),
+ session=SimpleNamespace(access_token="tok-1"),
+ )
+
+ def get_user(self, token: str):
+ self._owner.tokens.append(token)
+ return SimpleNamespace(user=SimpleNamespace(id="user-1"))
+
+
+class _FactoryBackedAuthClient:
+ def __init__(self) -> None:
+ self.calls: list[dict[str, str]] = []
+ self.tokens: list[str] = []
+ self.auth = _FactoryBackedLoginAuth(self)
+
+
+class _DirectAuthClient:
+ def __init__(self) -> None:
+ self.calls: list[dict[str, str]] = []
+ self.tokens: list[str] = []
+
+ def sign_in_with_password(self, payload: dict[str, str]):
+ self.calls.append(payload)
+ return SimpleNamespace(
+ user=SimpleNamespace(id="user-1"),
+ session=SimpleNamespace(access_token="tok-1"),
+ )
+
+ def get_user(self, token: str):
+ self.tokens.append(token)
+ return SimpleNamespace(user=SimpleNamespace(id="user-1"))
+
+ def sign_up(self, payload: dict[str, str]):
+ self.calls.append(payload)
+ return SimpleNamespace(user=SimpleNamespace(id="user-1"), session=None)
+
+ def verify_otp(self, payload: dict[str, str]):
+ self.calls.append(payload)
+ return SimpleNamespace(
+ user=SimpleNamespace(id="user-1"),
+ session=SimpleNamespace(access_token="temp-token-1"),
+ )
+
+
+def _service(
+ *,
+ supabase_client=None,
+ supabase_auth_client=None,
+ supabase_auth_client_factory=None,
+ member_repo=None,
+ entity_repo=None,
+ invite_codes=None,
+) -> AuthService:
+ return AuthService(
+ members=member_repo or SimpleNamespace(),
+ accounts=SimpleNamespace(),
+ entities=entity_repo or SimpleNamespace(),
+ supabase_client=supabase_client,
+ supabase_auth_client=supabase_auth_client,
+ supabase_auth_client_factory=supabase_auth_client_factory,
+ invite_codes=invite_codes,
+ )
+
+
+def test_verify_token_prefers_supabase_get_user_over_local_jwt_secret(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.delenv("SUPABASE_JWT_SECRET", raising=False)
+ sb = _FakeSupabaseClient(user_id="user-supabase")
+
+ payload = _service(supabase_auth_client=sb).verify_token("tok-live")
+
+ assert sb.auth.tokens == ["tok-live"]
+ assert payload == {"user_id": "user-supabase", "entity_id": None}
+
+
+def test_verify_token_without_supabase_client_still_fails_loudly_when_secret_missing(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.delenv("SUPABASE_JWT_SECRET", raising=False)
+
+ with pytest.raises(RuntimeError, match="SUPABASE_JWT_SECRET env var required"):
+ _service().verify_token("tok-live")
+
+
+def test_login_uses_dedicated_auth_client_instead_of_storage_client():
+ auth_client = _FakeAuthClient()
+ member_repo = SimpleNamespace(
+ get_by_id=lambda _user_id: SimpleNamespace(name="codex", mycel_id=10001, email="codex@example.com", avatar=None),
+ list_by_owner_user_id=lambda _user_id: [],
+ )
+ entity_repo = SimpleNamespace(get_by_member_id=lambda _user_id: [SimpleNamespace(id="user-1-1", type="human")])
+
+ result = _service(
+ supabase_client=SimpleNamespace(auth=None),
+ supabase_auth_client=auth_client,
+ member_repo=member_repo,
+ entity_repo=entity_repo,
+ ).login("codex@example.com", "pw-1")
+
+ assert auth_client.auth.calls == [{"email": "codex@example.com", "password": "pw-1"}]
+ assert result["token"] == "tok-1"
+
+
+def test_login_uses_fresh_auth_client_from_factory_per_call():
+ created: list[_FactoryBackedAuthClient] = []
+
+ def factory() -> _FactoryBackedAuthClient:
+ client = _FactoryBackedAuthClient()
+ created.append(client)
+ return client
+
+ member_repo = SimpleNamespace(
+ get_by_id=lambda _user_id: SimpleNamespace(name="codex", mycel_id=10001, email="codex@example.com", avatar=None),
+ list_by_owner_user_id=lambda _user_id: [],
+ )
+ entity_repo = SimpleNamespace(get_by_member_id=lambda _user_id: [SimpleNamespace(id="user-1-1", type="human")])
+ service = _service(
+ supabase_client=SimpleNamespace(auth=None),
+ supabase_auth_client_factory=factory,
+ member_repo=member_repo,
+ entity_repo=entity_repo,
+ )
+
+ service.login("codex@example.com", "pw-1")
+ service.login("codex@example.com", "pw-2")
+
+ assert len(created) == 2
+ assert created[0].calls == [{"email": "codex@example.com", "password": "pw-1"}]
+ assert created[1].calls == [{"email": "codex@example.com", "password": "pw-2"}]
+
+
+def test_verify_token_uses_fresh_auth_client_from_factory_per_call(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.delenv("SUPABASE_JWT_SECRET", raising=False)
+ created: list[_FactoryBackedAuthClient] = []
+
+ def factory() -> _FactoryBackedAuthClient:
+ client = _FactoryBackedAuthClient()
+ created.append(client)
+ return client
+
+ service = _service(supabase_auth_client_factory=factory)
+
+ assert service.verify_token("tok-1") == {"user_id": "user-1", "entity_id": None}
+ assert service.verify_token("tok-2") == {"user_id": "user-1", "entity_id": None}
+ assert len(created) == 2
+ assert created[0].tokens == ["tok-1"]
+ assert created[1].tokens == ["tok-2"]
+
+
+def test_login_accepts_direct_gotrue_client_without_auth_wrapper():
+ auth_client = _DirectAuthClient()
+ member_repo = SimpleNamespace(
+ get_by_id=lambda _user_id: SimpleNamespace(name="codex", mycel_id=10001, email="codex@example.com", avatar=None),
+ list_by_owner_user_id=lambda _user_id: [],
+ )
+ entity_repo = SimpleNamespace(get_by_member_id=lambda _user_id: [SimpleNamespace(id="user-1-1", type="human")])
+
+ result = _service(
+ supabase_client=SimpleNamespace(auth=None),
+ supabase_auth_client=auth_client,
+ member_repo=member_repo,
+ entity_repo=entity_repo,
+ ).login("codex@example.com", "pw-1")
+
+ assert auth_client.calls == [{"email": "codex@example.com", "password": "pw-1"}]
+ assert result["token"] == "tok-1"
+
+
+def test_verify_token_accepts_direct_gotrue_client_without_auth_wrapper(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.delenv("SUPABASE_JWT_SECRET", raising=False)
+ auth_client = _DirectAuthClient()
+
+ payload = _service(supabase_auth_client=auth_client).verify_token("tok-direct")
+
+ assert auth_client.tokens == ["tok-direct"]
+ assert payload == {"user_id": "user-1", "entity_id": None}
+
+
+def test_send_otp_accepts_direct_gotrue_client_without_auth_wrapper():
+ auth_client = _DirectAuthClient()
+ invite_codes = SimpleNamespace(is_valid=lambda code: code == "invite-1")
+
+ _service(
+ supabase_client=SimpleNamespace(auth=None),
+ supabase_auth_client=auth_client,
+ invite_codes=invite_codes,
+ ).send_otp("fresh@example.com", "pw-1", "invite-1")
+
+ assert auth_client.calls == [{"email": "fresh@example.com", "password": "pw-1"}]
+
+
+def test_verify_register_otp_accepts_direct_gotrue_client_without_auth_wrapper():
+ auth_client = _DirectAuthClient()
+
+ result = _service(
+ supabase_client=SimpleNamespace(auth=None),
+ supabase_auth_client=auth_client,
+ ).verify_register_otp("fresh@example.com", "123456")
+
+ assert auth_client.calls == [{"email": "fresh@example.com", "token": "123456", "type": "signup"}]
+ assert result == {"temp_token": "temp-token-1"}
diff --git a/tests/Fix/test_background_task_cleanup.py b/tests/Fix/test_background_task_cleanup.py
new file mode 100644
index 000000000..3b088bd38
--- /dev/null
+++ b/tests/Fix/test_background_task_cleanup.py
@@ -0,0 +1,493 @@
+"""Integration tests for background task cleanup across command/agent surfaces."""
+
+import asyncio
+import json
+import shutil
+import sys
+from pathlib import Path
+
+import pytest
+from langchain_core.messages import AIMessage
+
+from core.agents.registry import AgentEntry, AgentRegistry
+from core.agents.service import AgentService
+from core.runtime.middleware.queue import MessageQueueManager
+from core.runtime.middleware.queue.middleware import SteeringMiddleware
+from core.runtime.registry import ToolRegistry
+from core.tools.command.bash.executor import BashExecutor
+from core.tools.command.service import CommandService
+from sandbox.thread_context import set_current_thread_id
+
+
+class _FakeAgentRegistry:
+ async def register(self, entry):
+ self.entry = entry
+
+ async def update_status(self, agent_id: str, status: str):
+ self.last_status = (agent_id, status)
+
+
+class _SlowChildAgent:
+ def __init__(self, first_text: str, release_event: asyncio.Event, started_event: asyncio.Event):
+ self._first_text = first_text
+ self._release_event = release_event
+ self._started_event = started_event
+ self._agent_service = type(
+ "_ChildService",
+ (),
+ {"cleanup_background_runs": self._cleanup_background_runs},
+ )()
+ self.agent = type("_InnerAgent", (), {"astream": self._astream})()
+ self.closed = False
+
+ async def ainit(self):
+ return None
+
+ async def _astream(self, *args, **kwargs):
+ self._started_event.set()
+ yield {"agent": {"messages": [AIMessage(content=self._first_text)]}}
+ await self._release_event.wait()
+
+ async def _cleanup_background_runs(self):
+ return None
+
+ def close(self):
+ self.closed = True
+ return None
+
+
+class _CompleteChildAgent:
+ def __init__(self, text: str):
+ self._text = text
+ self._agent_service = type(
+ "_ChildService",
+ (),
+ {"cleanup_background_runs": self._cleanup_background_runs},
+ )()
+ self.agent = type("_InnerAgent", (), {"astream": self._astream})()
+ self.closed = False
+
+ async def ainit(self):
+ return None
+
+ async def _astream(self, *args, **kwargs):
+ yield {"agent": {"messages": [AIMessage(content=self._text)]}}
+
+ async def _cleanup_background_runs(self):
+ return None
+
+ def close(self):
+ self.closed = True
+ return None
+
+
+class _FailingInitChildAgent:
+ def __init__(self, error: Exception):
+ self._error = error
+
+ async def ainit(self):
+ raise self._error
+
+
+def _agent_tool_json(result) -> dict:
+ content = getattr(result, "content", result)
+ return json.loads(content)
+
+
@pytest.mark.skipif(
    sys.platform == "win32" or shutil.which("bash") is None,
    reason="bash background cleanup integration requires Unix-compatible bash",
)
def test_taskstop_terminates_real_background_bash_run(tmp_path):
    """End-to-end: a real backgrounded bash process is killed via TaskStop.

    CommandService and AgentService share one ``background_runs`` dict, so a
    task started on the command surface must be stoppable from the agent
    surface, and the underlying OS process must actually terminate.
    """
    async def run():
        registry = ToolRegistry()
        # Shared run table is what wires the two services together.
        shared_runs: dict[str, object] = {}
        executor = BashExecutor(default_cwd=str(tmp_path))
        command_service = CommandService(
            registry=registry,
            workspace_root=tmp_path,
            executor=executor,
            background_runs=shared_runs,
        )
        agent_service = AgentService(
            tool_registry=registry,
            agent_registry=_FakeAgentRegistry(),
            workspace_root=Path(tmp_path),
            model_name="gpt-test",
            shared_runs=shared_runs,
        )

        # Long-running command is expected to be backgrounded and handed a task id.
        result = await command_service._execute_async(
            "sleep 30",
            str(tmp_path),
            30.0,
            description="integration bash cleanup",
        )
        assert "task_id:" in result
        assert len(shared_runs) == 1

        task_id, running = next(iter(shared_runs.items()))
        assert running.is_done is False

        stop_result = await agent_service._handle_task_stop(task_id)

        assert stop_result == f"Task {task_id} cancelled"
        assert task_id not in shared_runs
        # returncode set means the sleep process really exited, not just bookkeeping.
        assert running._cmd.process.returncode is not None

    asyncio.run(run())
+
+
def test_sendmessage_search_hint_uses_queue_naming(tmp_path):
    """Registering AgentService should expose SendMessage with 'queue' wording."""
    registry = ToolRegistry()
    # Constructing the service registers its tools into the registry as a side effect.
    AgentService(
        tool_registry=registry,
        agent_registry=_FakeAgentRegistry(),
        workspace_root=Path(tmp_path),
        model_name="gpt-test",
    )

    send_message = registry.get("SendMessage")

    assert send_message is not None
    assert "queue" in send_message.search_hint
    assert "mailbox" not in send_message.search_hint
+
+
@pytest.mark.asyncio
async def test_sendmessage_enqueues_real_agent_notification_for_target_thread(tmp_path):
    """SendMessage to a named running agent enqueues an agent notification on its thread."""
    registry = ToolRegistry()
    agent_registry = AgentRegistry(db_path=tmp_path / "agents.db")
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    service = AgentService(
        tool_registry=registry,
        agent_registry=agent_registry,
        workspace_root=Path(tmp_path),
        model_name="gpt-test",
        queue_manager=queue_manager,
    )
    # Target must exist in the registry so name -> thread resolution can succeed.
    await agent_registry.register(
        AgentEntry(
            agent_id="agent-1",
            name="worker-1",
            thread_id="thread-worker-1",
            status="running",
        )
    )

    result = await service._handle_send_message(
        target_name="worker-1",
        message="hello from coordinator",
        sender_name="coordinator",
    )

    assert result == "Message sent to worker-1."
    items = queue_manager.drain_all("thread-worker-1")
    assert len(items) == 1
    assert items[0].notification_type == "agent"
    assert items[0].sender_name == "coordinator"
    assert "hello from coordinator" in items[0].content
+
+
@pytest.mark.asyncio
async def test_sendmessage_reaches_target_next_turn_via_steering_middleware(tmp_path):
    """A queued SendMessage payload is injected into the target's next model turn."""
    registry = ToolRegistry()
    agent_registry = AgentRegistry(db_path=tmp_path / "agents.db")
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    service = AgentService(
        tool_registry=registry,
        agent_registry=agent_registry,
        workspace_root=Path(tmp_path),
        model_name="gpt-test",
        queue_manager=queue_manager,
    )
    await agent_registry.register(
        AgentEntry(
            agent_id="agent-1",
            name="worker-1",
            thread_id="thread-worker-1",
            status="running",
        )
    )

    await service._handle_send_message(
        target_name="worker-1",
        message="queue payload",
        sender_name="coordinator",
    )

    # Simulate the target thread starting its next turn: the middleware should
    # drain the queue and hand the message to the model.
    injected = SteeringMiddleware(queue_manager=queue_manager).before_model(
        state={},
        runtime=None,
        config={"configurable": {"thread_id": "thread-worker-1"}},
    )

    assert injected is not None
    messages = injected["messages"]
    assert len(messages) == 1
    assert "queue payload" in str(messages[0].content)
    assert messages[0].metadata["notification_type"] == "agent"
    assert messages[0].metadata["sender_name"] == "coordinator"
+
+
@pytest.mark.asyncio
async def test_sendmessage_rejects_ambiguous_running_agent_names(tmp_path):
    """Two running agents with the same name must make SendMessage refuse to deliver."""
    registry = ToolRegistry()
    agent_registry = AgentRegistry(db_path=tmp_path / "agents.db")
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    service = AgentService(
        tool_registry=registry,
        agent_registry=agent_registry,
        workspace_root=Path(tmp_path),
        model_name="gpt-test",
        queue_manager=queue_manager,
    )
    # Same name "worker" registered twice on different threads -> ambiguous target.
    await agent_registry.register(
        AgentEntry(
            agent_id="agent-1",
            name="worker",
            thread_id="thread-worker-1",
            status="running",
        )
    )
    await agent_registry.register(
        AgentEntry(
            agent_id="agent-2",
            name="worker",
            thread_id="thread-worker-2",
            status="running",
        )
    )

    result = await service._handle_send_message(
        target_name="worker",
        message="hello dup",
        sender_name="coordinator",
    )

    assert "ambiguous" in result
    # Neither candidate thread may receive the message.
    assert queue_manager.drain_all("thread-worker-1") == []
    assert queue_manager.drain_all("thread-worker-2") == []
+
+
@pytest.mark.asyncio
async def test_background_agent_progress_notification_reaches_parent_next_turn(tmp_path, monkeypatch):
    """A still-running background child surfaces a progress note on the parent's next turn."""
    started = asyncio.Event()
    release = asyncio.Event()

    def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
        return _SlowChildAgent("Inspecting repository", release, started)

    monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)

    registry = ToolRegistry()
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    service = AgentService(
        tool_registry=registry,
        agent_registry=_FakeAgentRegistry(),
        workspace_root=Path(tmp_path),
        model_name="gpt-test",
        queue_manager=queue_manager,
        background_progress_interval_s=0.02,
    )

    set_current_thread_id("parent-thread")
    try:
        raw = await service._handle_agent(
            prompt="do work",
            name="worker-1",
            description="Investigating repository",
            run_in_background=True,
        )
        task_id = _agent_tool_json(raw)["task_id"]
        await asyncio.wait_for(started.wait(), timeout=1)
        # Give the 0.02s progress ticker at least one interval to fire.
        await asyncio.sleep(0.05)

        injected = SteeringMiddleware(queue_manager=queue_manager).before_model(
            state={},
            runtime=None,
            config={"configurable": {"thread_id": "parent-thread"}},
        )

        assert injected is not None
        text = str(injected["messages"][0].content)
        # NOTE(review): the empty-string membership check below is vacuous, and the
        # f-string wrapper on task_id is redundant — both look like markup tags were
        # stripped from these literals. Confirm the intended tokens.
        assert "" in text
        assert f"{task_id}" in text
        assert "Inspecting repository" in text
    finally:
        release.set()
        await service.cleanup_background_runs()
        set_current_thread_id("")
+
+
@pytest.mark.asyncio
async def test_background_agent_completion_notification_waits_for_followthrough_run(tmp_path, monkeypatch):
    """A completed child's terminal notification stays queued rather than injecting mid-turn."""
    def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
        return _CompleteChildAgent("Finished indexing")

    monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)

    registry = ToolRegistry()
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    service = AgentService(
        tool_registry=registry,
        agent_registry=_FakeAgentRegistry(),
        workspace_root=Path(tmp_path),
        model_name="gpt-test",
        queue_manager=queue_manager,
        background_progress_interval_s=0.02,
    )

    set_current_thread_id("parent-thread")
    try:
        raw = await service._handle_agent(
            prompt="do work",
            name="worker-1",
            description="Index repository",
            run_in_background=True,
        )
        task_id = _agent_tool_json(raw)["task_id"]
        running = service._tasks[task_id]
        # Wait until the child's background task fully completes.
        await asyncio.wait_for(running.task, timeout=1)

        injected = SteeringMiddleware(queue_manager=queue_manager).before_model(
            state={},
            runtime=None,
            config={"configurable": {"thread_id": "parent-thread"}},
        )

        # Terminal notifications are not injected here; they wait for a follow-up run.
        assert injected is None
        queued = queue_manager.list_queue("parent-thread")
        assert len(queued) == 1
        text = queued[0]["content"]
        # NOTE(review): the empty-string check is vacuous and the f-string wrapper is
        # redundant — likely stripped markup tags; confirm the intended tokens.
        assert "" in text
        assert f"{task_id}" in text
        assert "completed" in text
        assert "Finished indexing" in text
    finally:
        set_current_thread_id("")
+
+
@pytest.mark.asyncio
async def test_mixed_success_and_init_failure_background_agents_queue_both_terminal_notifications(tmp_path, monkeypatch):
    """One successful child and one failing-init child must each queue a terminal note."""
    created = 0

    def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
        # First construction succeeds; every later one fails during ainit().
        nonlocal created
        created += 1
        if created == 1:
            return _CompleteChildAgent("GOOD:BASE:2")
        return _FailingInitChildAgent(RuntimeError("bad child init"))

    monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)

    registry = ToolRegistry()
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    service = AgentService(
        tool_registry=registry,
        agent_registry=_FakeAgentRegistry(),
        workspace_root=Path(tmp_path),
        model_name="gpt-test",
        queue_manager=queue_manager,
    )

    set_current_thread_id("parent-thread")
    try:
        raw_good = await service._handle_agent(
            prompt="good child",
            name="good-child",
            description="good child",
            run_in_background=True,
        )
        raw_bad = await service._handle_agent(
            prompt="bad child",
            name="bad-child",
            description="bad child",
            run_in_background=True,
        )

        await asyncio.wait_for(service._tasks[_agent_tool_json(raw_good)["task_id"]].task, timeout=1)
        # The failing child's task re-raises its init error when awaited.
        with pytest.raises(RuntimeError, match="bad child init"):
            await asyncio.wait_for(service._tasks[_agent_tool_json(raw_bad)["task_id"]].task, timeout=1)

        queued = queue_manager.list_queue("parent-thread")

        assert len(queued) == 2
        contents = [item["content"] for item in queued]
        assert any("completed" in content and "GOOD:BASE:2" in content for content in contents)
        assert any("error" in content and "Agent failed" in content for content in contents)
    finally:
        set_current_thread_id("")
+
+
def test_terminal_background_notification_waits_for_followup_run_during_owner_turn(tmp_path):
    """During an owner-sourced run, terminal agent notifications stay queued."""
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    # NOTE(review): "errorAgent failed" reads like markup tags were stripped from
    # this payload (e.g. an error-status wrapper around "Agent failed") — confirm
    # the intended payload format against the queue producer.
    queue_manager.enqueue(
        "errorAgent failed",
        "parent-thread",
        notification_type="agent",
        source="system",
    )

    runtime = type("_Runtime", (), {"current_run_source": "owner"})()
    injected = SteeringMiddleware(queue_manager=queue_manager, agent_runtime=runtime).before_model(
        state={},
        runtime=None,
        config={"configurable": {"thread_id": "parent-thread"}},
    )

    assert injected is None
    queued = queue_manager.list_queue("parent-thread")
    assert len(queued) == 1
    # NOTE(review): the empty-string check below is vacuous — likely a stripped tag.
    assert "" in queued[0]["content"]
+
+
def test_terminal_background_notification_waits_for_followup_run_during_system_turn(tmp_path):
    """During a system-sourced run, terminal agent notifications also stay queued."""
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    # NOTE(review): "completedBG1:STEP1:2" reads like markup tags were stripped
    # from this payload — confirm the intended payload format.
    queue_manager.enqueue(
        "completedBG1:STEP1:2",
        "parent-thread",
        notification_type="agent",
        source="system",
    )

    runtime = type("_Runtime", (), {"current_run_source": "system"})()
    injected = SteeringMiddleware(queue_manager=queue_manager, agent_runtime=runtime).before_model(
        state={},
        runtime=None,
        config={"configurable": {"thread_id": "parent-thread"}},
    )

    assert injected is None
    queued = queue_manager.list_queue("parent-thread")
    assert len(queued) == 1
    # NOTE(review): the empty-string check below is vacuous — likely a stripped tag.
    assert "" in queued[0]["content"]
+
+
def test_steer_injection_emits_phase_boundary_events(tmp_path):
    """Injecting a steer message should emit run_done followed by run_start."""
    queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
    queue_manager.enqueue(
        "Stop the current plan and summarize status.",
        "parent-thread",
        notification_type="steer",
        source="owner",
        is_steer=True,
    )

    class _Runtime:
        # Collects activity events so phase-boundary ordering can be asserted.
        def __init__(self) -> None:
            self.events: list[dict[str, str]] = []

        def emit_activity_event(self, event: dict[str, str]) -> None:
            self.events.append(event)

    runtime = _Runtime()
    injected = SteeringMiddleware(queue_manager=queue_manager, agent_runtime=runtime).before_model(
        state={},
        runtime=None,
        config={"configurable": {"thread_id": "parent-thread"}},
    )

    assert injected is not None
    assert str(injected["messages"][0].content) == "Stop the current plan and summarize status."
    # Steering closes the current phase before opening the steered one.
    assert [event["event"] for event in runtime.events] == ["run_done", "run_start"]
diff --git a/tests/test_followup_requeue.py b/tests/Fix/test_followup_requeue.py
similarity index 97%
rename from tests/test_followup_requeue.py
rename to tests/Fix/test_followup_requeue.py
index 7a798aa7d..f19fa1b68 100644
--- a/tests/test_followup_requeue.py
+++ b/tests/Fix/test_followup_requeue.py
@@ -192,7 +192,7 @@ async def _run():
asyncio.run(_run())
def test_transition_failure_skips_start(self, mock_agent, mock_app, queue_manager):
- """When runtime.transition returns False, start_agent_run is not called."""
+ """When runtime.transition returns False, followup stays queued."""
queue_manager.enqueue("wont run", "thread-1")
mock_agent.runtime.transition.return_value = False
@@ -203,7 +203,8 @@ async def _run():
await _consume_followup_queue(mock_agent, "thread-1", mock_app)
mock_start.assert_not_called()
- # Message was consumed (dequeued) but not re-enqueued since no exception
- assert queue_manager.dequeue("thread-1") is None
+ item = queue_manager.dequeue("thread-1")
+ assert item is not None
+ assert item.content == "wont run"
asyncio.run(_run())
diff --git a/tests/Fix/test_monitor_resource_overview_uniqueness.py b/tests/Fix/test_monitor_resource_overview_uniqueness.py
new file mode 100644
index 000000000..dfcf08ba8
--- /dev/null
+++ b/tests/Fix/test_monitor_resource_overview_uniqueness.py
@@ -0,0 +1,323 @@
+from backend.web.services import resource_service
+
+
+class _FakeRepo:
+ def __init__(self, rows, lease_threads=None):
+ self._rows = rows
+ self._lease_threads = lease_threads or {}
+
+ def list_sessions_with_leases(self):
+ return list(self._rows)
+
+ def query_lease_threads(self, lease_id: str):
+ return [{"thread_id": tid} for tid in self._lease_threads.get(lease_id, [])]
+
+ def close(self):
+ pass
+
+
+class _FakeThreadRepo:
+ def __init__(self, rows):
+ self._rows = rows
+
+ def get_by_id(self, thread_id: str):
+ return self._rows.get(thread_id)
+
+ def close(self):
+ pass
+
+
+class _FakeMember:
+ def __init__(self, member_id: str, name: str, avatar: str | None = None):
+ self.id = member_id
+ self.name = name
+ self.avatar = avatar
+
+
+class _FakeMemberRepo:
+ def __init__(self, members):
+ self._members = members
+
+ def list_all(self):
+ return list(self._members)
+
+ def close(self):
+ pass
+
+
def test_list_resource_providers_deduplicates_terminal_fallback_rows(monkeypatch):
    """Two identical monitor rows for one lease/thread must collapse into one session."""
    rows = [
        {
            "provider": "local",
            "session_id": None,
            "thread_id": "thread-1",
            "lease_id": "lease-1",
            "observed_state": "running",
            "desired_state": "running",
            "created_at": "2026-04-04T00:00:00",
        },
        # Exact duplicate of the row above — the overview must not double count it.
        {
            "provider": "local",
            "session_id": None,
            "thread_id": "thread-1",
            "lease_id": "lease-1",
            "observed_state": "running",
            "desired_state": "running",
            "created_at": "2026-04-04T00:00:00",
        },
    ]

    monkeypatch.setattr(resource_service, "make_sandbox_monitor_repo", lambda: _FakeRepo(rows))
    monkeypatch.setattr(
        resource_service,
        "available_sandbox_types",
        lambda: [{"name": "local", "available": True}],
    )
    monkeypatch.setattr(
        resource_service,
        "_resolve_instance_capabilities",
        lambda _config_name: (resource_service._empty_capabilities(), None),
    )
    monkeypatch.setattr(
        resource_service,
        "_thread_owners",
        lambda thread_ids: {tid: {"member_id": "member-1", "member_name": "Toad", "avatar_url": None} for tid in thread_ids},
    )
    monkeypatch.setattr(resource_service, "list_resource_snapshots", lambda _lease_ids: {})

    payload = resource_service.list_resource_providers()
    local = payload["providers"][0]

    assert local["telemetry"]["running"]["used"] == 1
    assert local["sessions"] == [
        {
            "id": "lease-1:thread-1",
            "leaseId": "lease-1",
            "threadId": "thread-1",
            "memberId": "member-1",
            "memberName": "Toad",
            "avatarUrl": None,
            "status": "running",
            "startedAt": "2026-04-04T00:00:00",
            "metrics": None,
        }
    ]
+
+
def test_list_resource_providers_resolves_owner_metadata_from_runtime_storage(monkeypatch):
    """Owner name/id should be resolved via the thread and member repos when needed."""
    rows = [
        {
            "provider": "daytona",
            "session_id": "sess-1",
            "thread_id": "thread-supabase",
            "lease_id": "lease-1",
            "observed_state": "running",
            "desired_state": "running",
            "created_at": "2026-04-04T00:00:00",
        },
    ]

    monkeypatch.setattr(resource_service, "make_sandbox_monitor_repo", lambda: _FakeRepo(rows))
    monkeypatch.setattr(
        resource_service,
        "available_sandbox_types",
        lambda: [{"name": "daytona", "available": True}],
    )
    monkeypatch.setattr(resource_service, "resolve_provider_name", lambda *_args, **_kwargs: "daytona")
    monkeypatch.setattr(resource_service, "_resolve_console_url", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(
        resource_service,
        "_resolve_instance_capabilities",
        lambda _config_name: (resource_service._empty_capabilities(), None),
    )
    # Thread -> member-id mapping and member lookup come from runtime storage,
    # not from the monitor rows themselves.
    monkeypatch.setattr(
        resource_service,
        "build_thread_repo",
        lambda **_kwargs: _FakeThreadRepo({"thread-supabase": {"member_id": "member-1"}}),
    )
    monkeypatch.setattr(
        resource_service,
        "build_member_repo",
        lambda **_kwargs: _FakeMemberRepo([_FakeMember("member-1", "Toad")]),
    )
    monkeypatch.setattr(resource_service, "list_resource_snapshots", lambda _lease_ids: {})

    payload = resource_service.list_resource_providers()

    assert payload["providers"][0]["sessions"] == [
        {
            "id": "lease-1:thread-supabase",
            "leaseId": "lease-1",
            "threadId": "thread-supabase",
            "memberId": "member-1",
            "memberName": "Toad",
            "avatarUrl": None,
            "status": "running",
            "startedAt": "2026-04-04T00:00:00",
            "metrics": None,
        }
    ]
+
+
def test_list_resource_providers_hides_subagent_threads(monkeypatch):
    """Threads with the ``subagent-`` id prefix must be excluded from the overview."""
    rows = [
        {
            "provider": "daytona",
            "session_id": "sess-parent",
            "thread_id": "thread-parent",
            "lease_id": "lease-parent",
            "observed_state": "running",
            "desired_state": "running",
            "created_at": "2026-04-04T00:00:00",
        },
        # Child row whose thread id marks it as a subagent — should be hidden.
        {
            "provider": "daytona",
            "session_id": "sess-child",
            "thread_id": "subagent-deadbeef",
            "lease_id": "lease-child",
            "observed_state": "running",
            "desired_state": "running",
            "created_at": "2026-04-04T00:00:01",
        },
    ]

    monkeypatch.setattr(resource_service, "make_sandbox_monitor_repo", lambda: _FakeRepo(rows))
    monkeypatch.setattr(
        resource_service,
        "available_sandbox_types",
        lambda: [{"name": "daytona", "available": True}],
    )
    monkeypatch.setattr(resource_service, "resolve_provider_name", lambda *_args, **_kwargs: "daytona")
    monkeypatch.setattr(resource_service, "_resolve_console_url", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(
        resource_service,
        "_resolve_instance_capabilities",
        lambda _config_name: (resource_service._empty_capabilities(), None),
    )
    monkeypatch.setattr(
        resource_service,
        "_thread_owners",
        lambda thread_ids: {tid: {"member_id": tid, "member_name": tid, "avatar_url": None} for tid in thread_ids},
    )
    monkeypatch.setattr(resource_service, "list_resource_snapshots", lambda _lease_ids: {})

    payload = resource_service.list_resource_providers()
    sessions = payload["providers"][0]["sessions"]

    assert [session["threadId"] for session in sessions] == ["thread-parent"]
    assert payload["summary"]["running_sessions"] == 1
+
+
def test_list_resource_providers_projects_visible_parent_when_raw_monitor_row_is_subagent(monkeypatch):
    """A subagent-only monitor row should be re-projected onto its lease's parent thread."""
    rows = [
        {
            "provider": "daytona_selfhost",
            "session_id": None,
            "thread_id": "subagent-deadbeef",
            "lease_id": "lease-1",
            "observed_state": "paused",
            "desired_state": "paused",
            "created_at": "2026-04-04T00:00:00",
        },
    ]

    # The lease also carries a visible parent thread; the service should surface
    # that one instead of the hidden subagent thread.
    monkeypatch.setattr(
        resource_service,
        "make_sandbox_monitor_repo",
        lambda: _FakeRepo(rows, lease_threads={"lease-1": ["subagent-deadbeef", "thread-parent"]}),
    )
    monkeypatch.setattr(
        resource_service,
        "available_sandbox_types",
        lambda: [{"name": "daytona_selfhost", "available": True}],
    )
    monkeypatch.setattr(resource_service, "resolve_provider_name", lambda *_args, **_kwargs: "daytona")
    monkeypatch.setattr(resource_service, "_resolve_console_url", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(
        resource_service,
        "_resolve_instance_capabilities",
        lambda _config_name: (resource_service._empty_capabilities(), None),
    )
    monkeypatch.setattr(
        resource_service,
        "_thread_owners",
        lambda thread_ids: {tid: {"member_id": "member-1", "member_name": "Morel", "avatar_url": None} for tid in thread_ids},
    )
    monkeypatch.setattr(resource_service, "list_resource_snapshots", lambda _lease_ids: {})

    payload = resource_service.list_resource_providers()
    sessions = payload["providers"][0]["sessions"]

    assert sessions == [
        {
            "id": "lease-1:thread-parent",
            "leaseId": "lease-1",
            "threadId": "thread-parent",
            "memberId": "member-1",
            "memberName": "Morel",
            "avatarUrl": None,
            "status": "paused",
            "startedAt": "2026-04-04T00:00:00",
            "metrics": None,
        }
    ]
+
+
def test_list_resource_providers_deduplicates_same_lease_thread_even_with_distinct_session_ids(monkeypatch):
    """Dedup key is (lease, thread): differing session ids must not create extra sessions."""
    rows = [
        {
            "provider": "daytona_selfhost",
            "session_id": "sess-a",
            "thread_id": "thread-parent",
            "lease_id": "lease-1",
            "observed_state": "running",
            "desired_state": "running",
            "created_at": "2026-04-04T00:00:00",
        },
        # Same lease/thread but a different session id and later timestamp.
        {
            "provider": "daytona_selfhost",
            "session_id": "sess-b",
            "thread_id": "thread-parent",
            "lease_id": "lease-1",
            "observed_state": "running",
            "desired_state": "running",
            "created_at": "2026-04-04T00:00:01",
        },
    ]

    monkeypatch.setattr(resource_service, "make_sandbox_monitor_repo", lambda: _FakeRepo(rows))
    monkeypatch.setattr(
        resource_service,
        "available_sandbox_types",
        lambda: [{"name": "daytona_selfhost", "available": True}],
    )
    monkeypatch.setattr(resource_service, "resolve_provider_name", lambda *_args, **_kwargs: "daytona")
    monkeypatch.setattr(resource_service, "_resolve_console_url", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(
        resource_service,
        "_resolve_instance_capabilities",
        lambda _config_name: (resource_service._empty_capabilities(), None),
    )
    monkeypatch.setattr(
        resource_service,
        "_thread_owners",
        lambda thread_ids: {tid: {"member_id": "member-1", "member_name": "Toad", "avatar_url": None} for tid in thread_ids},
    )
    monkeypatch.setattr(resource_service, "list_resource_snapshots", lambda _lease_ids: {})

    payload = resource_service.list_resource_providers()
    sessions = payload["providers"][0]["sessions"]

    # The earliest created_at wins for the surviving session.
    assert sessions == [
        {
            "id": "lease-1:thread-parent",
            "leaseId": "lease-1",
            "threadId": "thread-parent",
            "memberId": "member-1",
            "memberName": "Toad",
            "avatarUrl": None,
            "status": "running",
            "startedAt": "2026-04-04T00:00:00",
            "metrics": None,
        }
    ]
diff --git a/tests/Fix/test_panel_auth_shell_coherence.py b/tests/Fix/test_panel_auth_shell_coherence.py
new file mode 100644
index 000000000..5a915b3c0
--- /dev/null
+++ b/tests/Fix/test_panel_auth_shell_coherence.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+from backend.web.routers import panel as panel_router
+from backend.web.services import member_service, profile_service
+from storage.contracts import MemberRow, MemberType
+
+
@pytest.mark.asyncio
async def test_panel_members_uses_injected_member_repo_for_owner_scope(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """panel.list_members must query the app-state member repo scoped to the caller."""
    now = 1_775_278_000.0
    agent = MemberRow(
        id="agent-1",
        name="Toad",
        type=MemberType.MYCEL_AGENT,
        owner_user_id="user-1",
        created_at=now,
    )
    seen: list[str] = []
    monkeypatch.setattr(
        member_service,
        "_member_to_dict",
        lambda _member_dir: {
            "id": "agent-1",
            "name": "Toad",
            "avatar_url": "avatars/agent-1.png",
            "config": {},
        },
    )
    # A member directory with an agent.md must exist for the service to pick it up.
    member_dir = tmp_path / "agent-1"
    member_dir.mkdir()
    (member_dir / "agent.md").write_text("stub", encoding="utf-8")
    monkeypatch.setattr(member_service, "MEMBERS_DIR", tmp_path)

    # `append(...) or [agent]` both records the queried owner id and returns rows.
    fake_repo = SimpleNamespace(
        list_by_owner_user_id=lambda owner_user_id: seen.append(owner_user_id) or [agent],
    )

    result = await panel_router.list_members(
        user_id="user-1",
        request=SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(member_repo=fake_repo))),
    )

    assert seen == ["user-1"]
    assert result["items"] == [{"id": "agent-1", "name": "Toad", "avatar_url": "avatars/agent-1.png", "config": {}}]
+
+
def test_profile_service_prefers_authenticated_member_over_config_defaults():
    """An explicit member row should drive the profile, not config fallbacks."""
    authenticated = MemberRow(
        id="user-1",
        name="codex",
        type=MemberType.HUMAN,
        email="codex@example.com",
        created_at=1.0,
    )

    result = profile_service.get_profile(member=authenticated)

    # Initials are derived from the member name ("codex" -> "CO").
    assert result == {"name": "codex", "initials": "CO", "email": "codex@example.com"}
+
+
def test_builtin_member_surface_exposes_chat_tools():
    """The built-in Leon member must expose the renamed chat tool set and drop legacy names."""
    builtin = member_service._leon_builtin()
    tools_by_name = {tool["name"]: tool for tool in builtin["config"]["tools"]}

    expected_chat_tools = ("list_chats", "read_messages", "send_message", "search_messages")
    for tool_name in expected_chat_tools:
        assert tool_name in tools_by_name
        assert tools_by_name[tool_name]["enabled"] is True
        assert tools_by_name[tool_name]["group"] == "chat"

    legacy_names = ("chats", "read_message", "search_message", "directory", "wechat_send", "wechat_contacts")
    for legacy_name in legacy_names:
        assert legacy_name not in tools_by_name
diff --git a/tests/Fix/test_sandbox_provider_availability.py b/tests/Fix/test_sandbox_provider_availability.py
new file mode 100644
index 000000000..5b12fb2b6
--- /dev/null
+++ b/tests/Fix/test_sandbox_provider_availability.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+
+from backend.web.services import sandbox_service
+from sandbox.providers.local import LocalSessionProvider
+
+
def test_available_sandbox_types_marks_configured_but_unavailable_provider(monkeypatch, tmp_path: Path) -> None:
    """A config file for a provider with no live instance must be reported unavailable."""
    local_provider = LocalSessionProvider(default_cwd=str(tmp_path))
    # daytona.json exists on disk, but only "local" is actually initialised below.
    (tmp_path / "daytona.json").write_text("{}")

    monkeypatch.setattr(sandbox_service, "SANDBOXES_DIR", tmp_path)
    monkeypatch.setattr(
        sandbox_service,
        "init_providers_and_managers",
        lambda: ({"local": local_provider}, {}),
    )
    monkeypatch.setattr(
        sandbox_service.SandboxConfig,
        "load",
        classmethod(lambda cls, name: SimpleNamespace(provider="daytona", name=name)),
    )

    types = sandbox_service.available_sandbox_types()
    daytona = next(item for item in types if item["name"] == "daytona")

    assert daytona["provider"] == "daytona"
    assert daytona["available"] is False
    assert "unavailable in the current process" in daytona["reason"]
+
+
def test_available_sandbox_types_marks_e2b_unavailable_when_sdk_missing(monkeypatch, tmp_path: Path) -> None:
    """Same as the daytona case: a configured e2b with no live provider is unavailable."""
    local_provider = LocalSessionProvider(default_cwd=str(tmp_path))
    (tmp_path / "e2b.json").write_text("{}")

    monkeypatch.setattr(sandbox_service, "SANDBOXES_DIR", tmp_path)
    monkeypatch.setattr(
        sandbox_service,
        "init_providers_and_managers",
        lambda: ({"local": local_provider}, {}),
    )
    monkeypatch.setattr(
        sandbox_service.SandboxConfig,
        "load",
        classmethod(lambda cls, name: SimpleNamespace(provider="e2b", name=name)),
    )

    types = sandbox_service.available_sandbox_types()
    e2b = next(item for item in types if item["name"] == "e2b")

    assert e2b["provider"] == "e2b"
    assert e2b["available"] is False
    assert "unavailable in the current process" in e2b["reason"]
+
+
def test_build_providers_and_managers_passes_agentbay_pause_capability_overrides(monkeypatch, tmp_path: Path) -> None:
    """Config-level supports_pause/supports_resume flags must reach the AgentBay provider."""
    (tmp_path / "agentbay.json").write_text("{}")
    monkeypatch.setattr(sandbox_service, "SANDBOXES_DIR", tmp_path)

    # Captures the kwargs the service forwards to the provider constructor.
    captured: dict[str, object] = {}

    class _FakeAgentBayProvider:
        def __init__(self, **kwargs) -> None:
            captured.update(kwargs)
            self.name = kwargs["provider_name"]

        def get_capability(self):
            return SimpleNamespace(can_pause=False, can_resume=False, can_destroy=True)

    class _FakeSandboxManager:
        def __init__(self, provider, db_path=None) -> None:
            self.provider = provider
            self.db_path = db_path

    monkeypatch.setattr(sandbox_service, "SandboxManager", _FakeSandboxManager)
    monkeypatch.setattr(
        sandbox_service.SandboxConfig,
        "load",
        classmethod(
            lambda cls, name: SimpleNamespace(
                provider="agentbay",
                agentbay=SimpleNamespace(
                    api_key="test-key",
                    region_id="ap-southeast-1",
                    context_path="/home/wuying",
                    image_id=None,
                    supports_pause=False,
                    supports_resume=False,
                ),
            )
        ),
    )

    import sandbox.providers.agentbay as agentbay_module

    monkeypatch.setattr(agentbay_module, "AgentBayProvider", _FakeAgentBayProvider)

    providers, managers = sandbox_service._build_providers_and_managers()

    assert "agentbay" in providers
    assert "agentbay" in managers
    assert captured["supports_pause"] is False
    assert captured["supports_resume"] is False
diff --git a/tests/Fix/test_sandbox_user_leases.py b/tests/Fix/test_sandbox_user_leases.py
new file mode 100644
index 000000000..158fa423f
--- /dev/null
+++ b/tests/Fix/test_sandbox_user_leases.py
@@ -0,0 +1,117 @@
+from types import SimpleNamespace
+
+from backend.web.services import sandbox_service
+
+
+class _FakeMonitorRepo:
+ def __init__(self, rows):
+ self._rows = rows
+
+ def list_leases_with_threads(self):
+ return list(self._rows)
+
+ def close(self):
+ pass
+
+
+class _FakeThreadRepo:
+ def __init__(self, rows):
+ self._rows = rows
+
+ def get_by_id(self, thread_id: str):
+ return self._rows.get(thread_id)
+
+ def close(self):
+ pass
+
+
+class _FakeMemberRepo:
+ def __init__(self, rows):
+ self._rows = rows
+
+ def get_by_id(self, member_id: str):
+ return self._rows.get(member_id)
+
+ def close(self):
+ pass
+
+
+def test_list_user_leases_hides_subagent_threads_and_deduplicates_visible_agents(monkeypatch):
+ rows = [
+ {
+ "lease_id": "lease-1",
+ "provider_name": "daytona_selfhost",
+ "recipe_id": "daytona:default",
+ "recipe_json": None,
+ "observed_state": "running",
+ "desired_state": "running",
+ "cwd": "/home/daytona/files/app",
+ "thread_id": "thread-parent",
+ },
+ {
+ "lease_id": "lease-1",
+ "provider_name": "daytona_selfhost",
+ "recipe_id": "daytona:default",
+ "recipe_json": None,
+ "observed_state": "running",
+ "desired_state": "running",
+ "cwd": "/home/daytona/files/app",
+ "thread_id": "subagent-deadbeef",
+ },
+ ]
+ thread_repo = _FakeThreadRepo(
+ {
+ "thread-parent": {"member_id": "member-1"},
+ "subagent-deadbeef": {"member_id": "member-1"},
+ }
+ )
+ member_repo = _FakeMemberRepo(
+ {
+ "member-1": SimpleNamespace(id="member-1", name="Morel", avatar="x", owner_user_id="owner-1"),
+ }
+ )
+
+ monkeypatch.setattr(sandbox_service, "make_sandbox_monitor_repo", lambda: _FakeMonitorRepo(rows))
+
+ leases = sandbox_service.list_user_leases(
+ "owner-1",
+ thread_repo=thread_repo,
+ member_repo=member_repo,
+ )
+
+ assert leases == [
+ {
+ "lease_id": "lease-1",
+ "provider_name": "daytona_selfhost",
+ "recipe_id": "daytona:default",
+ "recipe": {
+ "id": "daytona:default",
+ "name": "Daytona Default",
+ "desc": "Default recipe for daytona",
+ "provider_type": "daytona",
+ "features": {"lark_cli": False},
+ "configurable_features": {"lark_cli": True},
+ "feature_options": [
+ {
+ "key": "lark_cli",
+ "name": "Lark CLI",
+ "description": "在 sandbox 初始化时懒安装并校验。",
+ "icon": "feishu",
+ }
+ ],
+ "builtin": True,
+ },
+ "observed_state": "running",
+ "desired_state": "running",
+ "cwd": "/home/daytona/files/app",
+ "thread_ids": ["thread-parent"],
+ "agents": [
+ {
+ "member_id": "member-1",
+ "member_name": "Morel",
+ "avatar_url": "/api/members/member-1/avatar",
+ }
+ ],
+ "recipe_name": "Daytona Default",
+ }
+ ]
diff --git a/tests/test_session_file_operations_cleanup.py b/tests/Fix/test_session_file_operations_cleanup.py
similarity index 100%
rename from tests/test_session_file_operations_cleanup.py
rename to tests/Fix/test_session_file_operations_cleanup.py
diff --git a/tests/test_storage_import_boundary.py b/tests/Fix/test_storage_import_boundary.py
similarity index 100%
rename from tests/test_storage_import_boundary.py
rename to tests/Fix/test_storage_import_boundary.py
diff --git a/tests/Fix/test_thread_request_model.py b/tests/Fix/test_thread_request_model.py
new file mode 100644
index 000000000..1bfe188be
--- /dev/null
+++ b/tests/Fix/test_thread_request_model.py
@@ -0,0 +1,25 @@
+from backend.web.models.requests import CreateThreadRequest
+
+
+def test_create_thread_request_accepts_legacy_sandbox_type_key() -> None:
+ payload = CreateThreadRequest.model_validate(
+ {
+ "member_id": "member-1",
+ "sandbox_type": "daytona_selfhost",
+ "model": "gpt-5.4-mini",
+ }
+ )
+
+ assert payload.sandbox == "daytona_selfhost"
+
+
+def test_create_thread_request_prefers_primary_sandbox_key() -> None:
+ payload = CreateThreadRequest.model_validate(
+ {
+ "member_id": "member-1",
+ "sandbox": "local",
+ "sandbox_type": "daytona_selfhost",
+ }
+ )
+
+ assert payload.sandbox == "local"
diff --git a/tests/Integration/test_auth_router.py b/tests/Integration/test_auth_router.py
new file mode 100644
index 000000000..51d2f9ee2
--- /dev/null
+++ b/tests/Integration/test_auth_router.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+from fastapi import HTTPException
+
+from backend.web.routers import auth as auth_router
+from backend.web.routers import chats as chats_router
+
+
+class _FakeAuthService:
+ def __init__(self) -> None:
+ self.send_otp_calls: list[tuple[str, str, str]] = []
+ self.verify_otp_calls: list[tuple[str, str]] = []
+ self.complete_register_calls: list[tuple[str, str]] = []
+ self.login_calls: list[tuple[str, str]] = []
+ self.verify_otp_result = {"temp_token": "temp-otp"}
+ self.complete_register_result = {"token": "tok-register"}
+ self.login_result = {"token": "tok-login"}
+ self.send_otp_error: Exception | None = None
+ self.verify_otp_error: Exception | None = None
+ self.complete_register_error: Exception | None = None
+ self.login_error: Exception | None = None
+
+ def send_otp(self, email: str, password: str, invite_code: str) -> None:
+ self.send_otp_calls.append((email, password, invite_code))
+ if self.send_otp_error is not None:
+ raise self.send_otp_error
+
+ def verify_register_otp(self, email: str, token: str) -> dict:
+ self.verify_otp_calls.append((email, token))
+ if self.verify_otp_error is not None:
+ raise self.verify_otp_error
+ return self.verify_otp_result
+
+ def complete_register(self, temp_token: str, invite_code: str) -> dict:
+ self.complete_register_calls.append((temp_token, invite_code))
+ if self.complete_register_error is not None:
+ raise self.complete_register_error
+ return self.complete_register_result
+
+ def login(self, identifier: str, password: str) -> dict:
+ self.login_calls.append((identifier, password))
+ if self.login_error is not None:
+ raise self.login_error
+ return self.login_result
+
+
+@pytest.mark.asyncio
+async def test_send_otp_calls_auth_service_directly():
+ service = _FakeAuthService()
+ app = SimpleNamespace(state=SimpleNamespace(auth_service=service))
+
+ result = await auth_router.send_otp(
+ auth_router.SendOtpRequest(email="fresh@example.com", password="pass1234", invite_code="invite-1"),
+ app,
+ )
+
+ assert result == {"ok": True}
+ assert service.send_otp_calls == [("fresh@example.com", "pass1234", "invite-1")]
+
+
+@pytest.mark.asyncio
+async def test_send_otp_maps_value_error_to_bad_request():
+ service = _FakeAuthService()
+ service.send_otp_error = ValueError("邀请码无效或已过期")
+ app = SimpleNamespace(state=SimpleNamespace(auth_service=service))
+
+ with pytest.raises(HTTPException) as exc_info:
+ await auth_router.send_otp(
+ auth_router.SendOtpRequest(email="fresh@example.com", password="pass1234", invite_code="invite-1"),
+ app,
+ )
+
+ assert exc_info.value.status_code == 400
+ assert "邀请码无效" in str(exc_info.value.detail)
+
+
+@pytest.mark.asyncio
+async def test_verify_otp_calls_auth_service_directly():
+ service = _FakeAuthService()
+ app = SimpleNamespace(state=SimpleNamespace(auth_service=service))
+
+ result = await auth_router.verify_otp(
+ auth_router.VerifyOtpRequest(email="fresh@example.com", token="123456"),
+ app,
+ )
+
+ assert result == {"temp_token": "temp-otp"}
+ assert service.verify_otp_calls == [("fresh@example.com", "123456")]
+
+
+@pytest.mark.asyncio
+async def test_complete_register_calls_auth_service_directly():
+ service = _FakeAuthService()
+ app = SimpleNamespace(state=SimpleNamespace(auth_service=service))
+
+ result = await auth_router.complete_register(
+ auth_router.CompleteRegisterRequest(temp_token="temp-otp", invite_code="invite-1"),
+ app,
+ )
+
+ assert result == {"token": "tok-register"}
+ assert service.complete_register_calls == [("temp-otp", "invite-1")]
+
+
+@pytest.mark.asyncio
+async def test_login_calls_auth_service_directly():
+ service = _FakeAuthService()
+ app = SimpleNamespace(state=SimpleNamespace(auth_service=service))
+
+ result = await auth_router.login(auth_router.LoginRequest(identifier="fresh@example.com", password="pass1234"), app)
+
+ assert result == {"token": "tok-login"}
+ assert service.login_calls == [("fresh@example.com", "pass1234")]
+
+
+@pytest.mark.asyncio
+async def test_login_maps_value_error_to_unauthorized():
+ service = _FakeAuthService()
+ service.login_error = ValueError("Invalid username or password")
+ app = SimpleNamespace(state=SimpleNamespace(auth_service=service))
+
+ with pytest.raises(HTTPException) as exc_info:
+ await auth_router.login(auth_router.LoginRequest(identifier="fresh@example.com", password="pass1234"), app)
+
+ assert exc_info.value.status_code == 401
+ assert "Invalid username or password" in str(exc_info.value.detail)
+
+
+class _VerifyOnlyAuthService:
+ def __init__(self) -> None:
+ self.tokens: list[str] = []
+
+ def verify_token(self, token: str) -> dict:
+ self.tokens.append(token)
+ return {"user_id": "user-1"}
+
+
+@pytest.mark.asyncio
+async def test_chat_events_requires_token():
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ auth_service=_VerifyOnlyAuthService(),
+ chat_event_bus=SimpleNamespace(subscribe=lambda _chat_id: None),
+ )
+ )
+
+ with pytest.raises(HTTPException) as exc_info:
+ await chats_router.stream_chat_events("chat-1", token=None, app=app)
+
+ assert exc_info.value.status_code == 401
+ assert exc_info.value.detail == "Missing token"
+
+
+@pytest.mark.asyncio
+async def test_chat_events_verifies_provided_token():
+ auth_service = _VerifyOnlyAuthService()
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ auth_service=auth_service,
+ chat_event_bus=SimpleNamespace(subscribe=lambda _chat_id: None),
+ )
+ )
+
+ response = await chats_router.stream_chat_events("chat-1", token="tok-chat", app=app)
+
+ assert auth_service.tokens == ["tok-chat"]
+ assert response.media_type == "text/event-stream"
diff --git a/tests/Integration/test_child_thread_live_bridge.py b/tests/Integration/test_child_thread_live_bridge.py
new file mode 100644
index 000000000..84d1d26d7
--- /dev/null
+++ b/tests/Integration/test_child_thread_live_bridge.py
@@ -0,0 +1,635 @@
+from __future__ import annotations
+
+import asyncio
+import json
+from types import SimpleNamespace
+from unittest.mock import AsyncMock
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
+
+from backend.web.routers import threads as threads_router
+from backend.web.services.display_builder import DisplayBuilder
+from backend.web.services.event_buffer import ThreadEventBuffer
+from backend.web.services.streaming_service import run_child_thread_live
+from backend.web.utils.serializers import serialize_message
+from core.runtime.middleware.monitor import AgentState
+from core.runtime.middleware.queue.manager import MessageQueueManager
+
+
+class _FakeRuntime:
+ def __init__(self) -> None:
+ self.current_state = AgentState.IDLE
+ self._event_callback = None
+ self._activity_sink = None
+ self.state = SimpleNamespace(flags=SimpleNamespace(is_compacting=False))
+ self.calls = 0
+ self.tokens = 0
+ self.cost = 0.0
+ self.ctx_percent = 0.0
+
+ def transition(self, new_state: AgentState) -> bool:
+ self.current_state = new_state
+ return True
+
+ def set_event_callback(self, callback) -> None:
+ self._event_callback = callback
+
+ def bind_thread(self, activity_sink) -> None:
+ self._activity_sink = activity_sink
+
+ def unbind_thread(self) -> None:
+ self._activity_sink = None
+
+ def get_compact_dict(self) -> dict:
+ return {
+ "state": self.current_state.value,
+ "tokens": self.tokens,
+ "cost": self.cost,
+ "calls": self.calls,
+ "ctx_percent": self.ctx_percent,
+ }
+
+ def get_status_dict(self) -> dict:
+ return {
+ "state": {"state": self.current_state.value, "flags": {}},
+ "tokens": {"total": self.tokens},
+ "context": {"percent": self.ctx_percent},
+ "calls": self.calls,
+ "cost": self.cost,
+ }
+
+
+class _BlockingChildGraph:
+ def __init__(self) -> None:
+ self.messages: list = []
+ self.started = asyncio.Event()
+ self.release = asyncio.Event()
+ self.system_prompt = None
+
+ async def aget_state(self, _config):
+ return SimpleNamespace(values={"messages": list(self.messages)})
+
+ async def aupdate_state(self, _config, input_data, as_node=None):
+ self.messages.extend(input_data.get("messages", []))
+
+ async def astream(self, input_data, config=None, stream_mode=None):
+ if input_data is not None:
+ self.messages.extend(input_data.get("messages", []))
+ self.started.set()
+ await self.release.wait()
+ yield ("messages", (SimpleNamespace(__class__=SimpleNamespace(__name__="AIMessageChunk")), {}))
+ ai = AIMessage(content="CHILD_DONE")
+ ai.id = "ai-child-1"
+ self.messages.append(ai)
+ yield ("updates", {"agent": {"messages": [ai]}})
+
+
+class _BlockingChildAgent:
+ def __init__(self) -> None:
+ self.runtime = _FakeRuntime()
+ self.agent = _BlockingChildGraph()
+
+
+def _prime_agent_turn(
+ builder: DisplayBuilder,
+ thread_id: str,
+ *,
+ tool_call_id: str = "tc-agent-1",
+ args: dict | None = None,
+ run_id: str = "run-1",
+) -> None:
+ builder.apply_event(
+ thread_id,
+ "run_start",
+ {"run_id": run_id, "source": "owner", "showing": True},
+ )
+ builder.apply_event(
+ thread_id,
+ "tool_call",
+ {
+ "id": tool_call_id,
+ "name": "Agent",
+ "args": args or {"prompt": "do work"},
+ "showing": True,
+ },
+ )
+
+
+def _set_single_subagent_entry(
+ builder: DisplayBuilder,
+ thread_id: str,
+ *,
+ task_id: str,
+ thread_ref: str,
+ status: str,
+ result: str,
+ description: str = "inspect workspace",
+) -> None:
+ builder.set_entries(
+ thread_id,
+ [
+ {"id": "u1", "role": "user", "content": "do work", "timestamp": 1},
+ {
+ "id": "a1",
+ "role": "assistant",
+ "timestamp": 2,
+ "segments": [
+ {
+ "type": "tool",
+ "step": {
+ "id": "call-agent-1",
+ "name": "Agent",
+ "args": {"description": description},
+ "status": "done",
+ "result": result,
+ "subagent_stream": {
+ "task_id": task_id,
+ "thread_id": thread_ref,
+ "description": description,
+ "text": "",
+ "tool_calls": [],
+ "status": status,
+ },
+ },
+ }
+ ],
+ },
+ ],
+ )
+
+
+def _make_router_app(
+ builder: DisplayBuilder,
+ thread_id: str,
+ monkeypatch: pytest.MonkeyPatch,
+) -> SimpleNamespace:
+ fake_agent = SimpleNamespace(runtime=SimpleNamespace(current_state=AgentState.ACTIVE), agent=SimpleNamespace(aget_state=None))
+ monkeypatch.setattr(threads_router, "get_or_create_agent", AsyncMock(return_value=fake_agent))
+ return SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=builder,
+ agent_pool={},
+ thread_sandbox={thread_id: "local"},
+ )
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_child_thread_live_rebinds_from_parent_sink_and_surfaces_runtime_and_detail_before_completion():
+ child_thread_id = "subagent-live-1"
+ agent = _BlockingChildAgent()
+ parent_events: list[dict] = []
+
+ async def _parent_sink(event: dict) -> None:
+ parent_events.append(event)
+
+ agent.runtime.bind_thread(_parent_sink)
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ queue_manager=MessageQueueManager(),
+ _event_loop=asyncio.get_running_loop(),
+ thread_event_buffers={},
+ thread_tasks={},
+ thread_last_active={},
+ agent_pool={},
+ thread_sandbox={child_thread_id: "local"},
+ thread_cwd={},
+ thread_repo=SimpleNamespace(get_by_id=lambda thread_id: {"model": "gpt-live"} if thread_id == child_thread_id else None),
+ )
+ )
+
+ task = asyncio.create_task(
+ run_child_thread_live(
+ agent,
+ child_thread_id,
+ "child prompt",
+ app,
+ input_messages=[HumanMessage(content="child prompt")],
+ )
+ )
+
+ await agent.agent.started.wait()
+
+ runtime = await threads_router.get_thread_runtime(child_thread_id, stream=False, user_id="owner-1", app=app)
+ detail = await threads_router.get_thread_messages(child_thread_id, user_id="owner-1", app=app)
+
+ assert runtime["state"]["state"] == "active"
+ assert detail["entries"]
+ assert detail["entries"][0]["role"] == "user"
+ assert detail["entries"][0]["content"] == "child prompt"
+ assert isinstance(app.state.thread_event_buffers[child_thread_id], ThreadEventBuffer)
+ assert app.state.agent_pool[f"{child_thread_id}:local"] is agent
+ assert agent.runtime._activity_sink is not _parent_sink
+ assert parent_events == []
+
+ agent.agent.release.set()
+ result = await task
+
+ assert result == "CHILD_DONE"
+
+
+@pytest.mark.asyncio
+async def test_run_child_thread_live_raises_when_child_run_emits_error_event(monkeypatch):
+ child_thread_id = "subagent-live-error"
+ agent = _BlockingChildAgent()
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ queue_manager=MessageQueueManager(),
+ _event_loop=asyncio.get_running_loop(),
+ thread_event_buffers={},
+ thread_tasks={},
+ thread_last_active={},
+ agent_pool={},
+ thread_sandbox={child_thread_id: "local"},
+ thread_cwd={},
+ thread_repo=SimpleNamespace(get_by_id=lambda thread_id: {"model": "gpt-live"} if thread_id == child_thread_id else None),
+ )
+ )
+
+ def fake_start_agent_run(agent, thread_id, message, app, enable_trajectory=False, message_metadata=None, input_messages=None):
+ async def _fake_run():
+ thread_buf = app.state.thread_event_buffers[thread_id]
+ await thread_buf.put({"event": "error", "data": json.dumps({"error": "child model init failed"})})
+ return ""
+
+ app.state.thread_tasks[thread_id] = asyncio.create_task(_fake_run())
+ return "run-error-1"
+
+ monkeypatch.setattr("backend.web.services.streaming_service.start_agent_run", fake_start_agent_run)
+
+ with pytest.raises(RuntimeError, match="child model init failed"):
+ await run_child_thread_live(
+ agent,
+ child_thread_id,
+ "child prompt",
+ app,
+ input_messages=[HumanMessage(content="child prompt")],
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_child_thread_live_raises_when_child_never_makes_a_model_call(monkeypatch):
+ child_thread_id = "subagent-live-no-call"
+ agent = _BlockingChildAgent()
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ queue_manager=MessageQueueManager(),
+ _event_loop=asyncio.get_running_loop(),
+ thread_event_buffers={},
+ thread_tasks={},
+ thread_last_active={},
+ agent_pool={},
+ thread_sandbox={child_thread_id: "local"},
+ thread_cwd={},
+ thread_repo=SimpleNamespace(get_by_id=lambda thread_id: {"model": "gpt-live"} if thread_id == child_thread_id else None),
+ )
+ )
+
+ def fake_start_agent_run(agent, thread_id, message, app, enable_trajectory=False, message_metadata=None, input_messages=None):
+ async def _fake_run():
+ return ""
+
+ app.state.thread_tasks[thread_id] = asyncio.create_task(_fake_run())
+ return "run-no-call-1"
+
+ monkeypatch.setattr("backend.web.services.streaming_service.start_agent_run", fake_start_agent_run)
+
+ with pytest.raises(RuntimeError, match="before first model call"):
+ await run_child_thread_live(
+ agent,
+ child_thread_id,
+ "child prompt",
+ app,
+ input_messages=[HumanMessage(content="child prompt")],
+ )
+
+
+def test_live_tool_result_restores_subagent_stream_from_agent_background_json():
+ builder = DisplayBuilder()
+ thread_id = "parent-thread"
+ _prime_agent_turn(builder, thread_id, args={"prompt": "do work", "run_in_background": True})
+
+ delta = builder.apply_event(
+ thread_id,
+ "tool_result",
+ {
+ "tool_call_id": "tc-agent-1",
+ "name": "Agent",
+ "content": (
+ '{"task_id":"task-123","agent_name":"agent-task-123",'
+ '"thread_id":"subagent-task-123","status":"running",'
+ '"message":"Agent started in background. Use TaskOutput to get result."}'
+ ),
+ "metadata": {},
+ "showing": True,
+ },
+ )
+
+ seg = builder.get_entries(thread_id)[0]["segments"][0]
+ assert delta is not None
+ assert seg["step"]["subagent_stream"]["task_id"] == "task-123"
+ assert seg["step"]["subagent_stream"]["thread_id"] == "subagent-task-123"
+ assert seg["step"]["subagent_stream"]["status"] == "running"
+
+
+def test_live_tool_result_restores_subagent_stream_from_blocking_agent_metadata():
+ builder = DisplayBuilder()
+ thread_id = "parent-thread"
+ _prime_agent_turn(builder, thread_id)
+
+ delta = builder.apply_event(
+ thread_id,
+ "tool_result",
+ {
+ "tool_call_id": "tc-agent-1",
+ "name": "Agent",
+ "content": "CHILD_DONE",
+ "metadata": {
+ "task_id": "task-456",
+ "subagent_thread_id": "subagent-task-456",
+ "description": "blocking child",
+ },
+ "showing": True,
+ },
+ )
+
+ seg = builder.get_entries(thread_id)[0]["segments"][0]
+ assert delta is not None
+ assert seg["step"]["subagent_stream"]["task_id"] == "task-456"
+ assert seg["step"]["subagent_stream"]["thread_id"] == "subagent-task-456"
+ assert seg["step"]["subagent_stream"]["status"] == "completed"
+
+
+def test_task_start_can_patch_background_agent_after_tool_result_race():
+ builder = DisplayBuilder()
+ thread_id = "parent-thread"
+ _prime_agent_turn(
+ builder,
+ thread_id,
+ tool_call_id="tc-agent-race",
+ args={"prompt": "do work", "run_in_background": True},
+ )
+ builder.apply_event(
+ thread_id,
+ "tool_result",
+ {
+ "tool_call_id": "tc-agent-race",
+ "name": "Agent",
+ "content": "Agent started in background.",
+ "metadata": {},
+ "showing": True,
+ },
+ )
+
+ delta = builder.apply_event(
+ thread_id,
+ "task_start",
+ {
+ "task_id": "task-race",
+ "thread_id": "subagent-task-race",
+ "description": "late task start",
+ },
+ )
+
+ seg = builder.get_entries(thread_id)[0]["segments"][0]
+ assert delta is not None
+ assert seg["step"]["status"] == "done"
+ assert seg["step"]["subagent_stream"]["task_id"] == "task-race"
+ assert seg["step"]["subagent_stream"]["thread_id"] == "subagent-task-race"
+ assert seg["step"]["subagent_stream"]["status"] == "running"
+
+
+@pytest.mark.parametrize("task_status", ["completed", "error", "cancelled"])
+def test_live_notice_reconciles_subagent_stream_status_from_terminal_notification(task_status: str):
+ builder = DisplayBuilder()
+ thread_id = "parent-thread"
+ _prime_agent_turn(builder, thread_id, args={"prompt": "do work", "run_in_background": True})
+ builder.apply_event(
+ thread_id,
+ "tool_result",
+ {
+ "tool_call_id": "tc-agent-1",
+ "name": "Agent",
+ "content": (
+ '{"task_id":"task-123","agent_name":"agent-task-123",'
+ '"thread_id":"subagent-task-123","status":"running",'
+ '"message":"Agent started in background. Use TaskOutput to get result."}'
+ ),
+ "metadata": {},
+ "showing": True,
+ },
+ )
+
+ delta = builder.apply_event(
+ thread_id,
+ "notice",
+ {
+ "content": (
+ "\n"
+ "\n"
+ " task-123\n"
+ f" {task_status}\n"
+ " child task\n"
+ " child task\n"
+ " CHILD_DONE\n"
+ "\n"
+ ""
+ ),
+ "source": "system",
+ "notification_type": "agent",
+ },
+ )
+
+ seg = builder.get_entries(thread_id)[0]["segments"][0]
+ assert delta is not None
+ assert seg["step"]["subagent_stream"]["task_id"] == "task-123"
+ assert seg["step"]["subagent_stream"]["thread_id"] == "subagent-task-123"
+ assert seg["step"]["subagent_stream"]["status"] == task_status
+
+
+def test_checkpoint_rebuild_reconciles_subagent_stream_status_from_terminal_notification():
+ builder = DisplayBuilder()
+ thread_id = "parent-thread"
+
+ ai = AIMessage(
+ content="",
+ tool_calls=[{"name": "Agent", "args": {"prompt": "do work", "run_in_background": True}, "id": "tc-agent-1"}],
+ )
+ tool = ToolMessage(
+ content=(
+ '{"task_id":"task-123","agent_name":"agent-task-123",'
+ '"thread_id":"subagent-task-123","status":"running",'
+ '"message":"Agent started in background. Use TaskOutput to get result."}'
+ ),
+ name="Agent",
+ tool_call_id="tc-agent-1",
+ )
+ notice = HumanMessage(
+ content=(
+ "\n"
+ "\n"
+ " task-123\n"
+ " completed\n"
+ " child task\n"
+ " child task\n"
+ " CHILD_DONE\n"
+ "\n"
+ ""
+ )
+ )
+ notice.metadata = {"source": "system", "notification_type": "agent"}
+
+ entries = builder.build_from_checkpoint(
+ thread_id,
+ [serialize_message(ai), serialize_message(tool), serialize_message(notice)],
+ )
+
+ seg = entries[0]["segments"][0]
+ assert seg["step"]["subagent_stream"]["task_id"] == "task-123"
+ assert seg["step"]["subagent_stream"]["thread_id"] == "subagent-task-123"
+ assert seg["step"]["subagent_stream"]["status"] == "completed"
+
+
+def test_checkpoint_rebuild_restores_blocking_subagent_stream_from_tool_result_meta():
+ builder = DisplayBuilder()
+ thread_id = "parent-thread"
+
+ ai = AIMessage(
+ content="",
+ tool_calls=[{"name": "Agent", "args": {"prompt": "do work"}, "id": "tc-agent-1"}],
+ )
+ tool = ToolMessage(
+ content="CHILD_DONE",
+ name="Agent",
+ tool_call_id="tc-agent-1",
+ additional_kwargs={
+ "tool_result_meta": {
+ "task_id": "task-456",
+ "subagent_thread_id": "subagent-task-456",
+ "description": "blocking child",
+ "kind": "success",
+ "source": "local",
+ }
+ },
+ )
+
+ entries = builder.build_from_checkpoint(
+ thread_id,
+ [serialize_message(ai), serialize_message(tool)],
+ )
+
+ seg = entries[0]["segments"][0]
+ assert seg["step"]["subagent_stream"]["task_id"] == "task-456"
+ assert seg["step"]["subagent_stream"]["thread_id"] == "subagent-task-456"
+ assert seg["step"]["subagent_stream"]["status"] == "completed"
+
+
+@pytest.mark.asyncio
+async def test_list_tasks_includes_subagent_stream_from_display_entries():
+ thread_id = "parent-thread-tasks"
+ builder = DisplayBuilder()
+ _set_single_subagent_entry(
+ builder,
+ thread_id,
+ task_id="task-123",
+ thread_ref="subagent-task-123",
+ status="completed",
+ result="workspace looks empty",
+ )
+ monkeypatch = pytest.MonkeyPatch()
+ app = _make_router_app(builder, thread_id, monkeypatch)
+
+ tasks = await threads_router.list_tasks(thread_id, request=SimpleNamespace(app=app))
+
+ assert tasks == [
+ {
+ "task_id": "task-123",
+ "task_type": "agent",
+ "status": "completed",
+ "command_line": None,
+ "description": "inspect workspace",
+ "exit_code": None,
+ "error": None,
+ }
+ ]
+ monkeypatch.undo()
+
+
+@pytest.mark.asyncio
+async def test_get_task_returns_subagent_stream_result_from_display_entries():
+ thread_id = "parent-thread-task-detail"
+ builder = DisplayBuilder()
+ _set_single_subagent_entry(
+ builder,
+ thread_id,
+ task_id="task-123",
+ thread_ref="subagent-task-123",
+ status="completed",
+ result="workspace looks empty",
+ )
+ monkeypatch = pytest.MonkeyPatch()
+ app = _make_router_app(builder, thread_id, monkeypatch)
+
+ task = await threads_router.get_task(thread_id, "task-123", request=SimpleNamespace(app=app))
+
+ assert task == {
+ "task_id": "task-123",
+ "task_type": "agent",
+ "status": "completed",
+ "command_line": None,
+ "result": "workspace looks empty",
+ "text": "workspace looks empty",
+ }
+ monkeypatch.undo()
+
+
+@pytest.mark.asyncio
+async def test_blocking_subagent_done_state_overrides_stale_running_stream_on_detail_and_tasks(monkeypatch):
+ thread_id = "parent-thread-stale-running-completed"
+ builder = DisplayBuilder()
+ _set_single_subagent_entry(
+ builder,
+ thread_id,
+ task_id="task-stale-completed",
+ thread_ref="subagent-task-stale-completed",
+ status="running",
+ result="workspace looks empty",
+ )
+ app = _make_router_app(builder, thread_id, monkeypatch)
+
+ detail = await threads_router.get_thread_messages(thread_id, user_id="owner-1", app=app)
+ tasks = await threads_router.list_tasks(thread_id, request=SimpleNamespace(app=app))
+ task = await threads_router.get_task(thread_id, "task-stale-completed", request=SimpleNamespace(app=app))
+
+ stream = detail["entries"][1]["segments"][0]["step"]["subagent_stream"]
+ assert stream["status"] == "completed"
+ assert tasks[0]["status"] == "completed"
+ assert task["status"] == "completed"
+
+
+@pytest.mark.asyncio
+async def test_blocking_subagent_error_overrides_stale_running_stream_on_detail_and_tasks(monkeypatch):
+ thread_id = "parent-thread-stale-running-error"
+ builder = DisplayBuilder()
+ _set_single_subagent_entry(
+ builder,
+ thread_id,
+ task_id="task-stale-error",
+ thread_ref="subagent-task-stale-error",
+ status="running",
+ result="Agent failed: bad child model",
+ )
+ app = _make_router_app(builder, thread_id, monkeypatch)
+
+ detail = await threads_router.get_thread_messages(thread_id, user_id="owner-1", app=app)
+ tasks = await threads_router.list_tasks(thread_id, request=SimpleNamespace(app=app))
+ task = await threads_router.get_task(thread_id, "task-stale-error", request=SimpleNamespace(app=app))
+
+ stream = detail["entries"][1]["segments"][0]["step"]["subagent_stream"]
+ assert stream["status"] == "error"
+ assert tasks[0]["status"] == "error"
+ assert task["status"] == "error"
diff --git a/tests/test_daytona_e2e.py b/tests/Integration/test_daytona_e2e.py
similarity index 100%
rename from tests/test_daytona_e2e.py
rename to tests/Integration/test_daytona_e2e.py
diff --git a/tests/test_e2e_backend_api.py b/tests/Integration/test_e2e_backend_api.py
similarity index 100%
rename from tests/test_e2e_backend_api.py
rename to tests/Integration/test_e2e_backend_api.py
diff --git a/tests/test_e2e_providers.py b/tests/Integration/test_e2e_providers.py
similarity index 100%
rename from tests/test_e2e_providers.py
rename to tests/Integration/test_e2e_providers.py
diff --git a/tests/test_e2e_summary_persistence.py b/tests/Integration/test_e2e_summary_persistence.py
similarity index 100%
rename from tests/test_e2e_summary_persistence.py
rename to tests/Integration/test_e2e_summary_persistence.py
diff --git a/tests/Integration/test_entities_router.py b/tests/Integration/test_entities_router.py
new file mode 100644
index 000000000..5e7254497
--- /dev/null
+++ b/tests/Integration/test_entities_router.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+from backend.web.routers import entities as entities_router
+from storage.contracts import EntityRow, MemberRow
+
+
+@pytest.mark.asyncio
+async def test_list_entities_excludes_child_agent_branches_from_chat_discovery():
+ now = 1_775_223_756.0
+ user = MemberRow(id="u1", name="owner", type="human", created_at=now)
+ other_human = MemberRow(id="u2", name="other", type="human", created_at=now)
+ main_agent_member = MemberRow(
+ id="a-main",
+ name="Toad",
+ type="mycel_agent",
+ owner_user_id="u2",
+ created_at=now,
+ )
+ child_agent_member = MemberRow(
+ id="a-child",
+ name="Toad Branch",
+ type="mycel_agent",
+ owner_user_id="u2",
+ created_at=now,
+ )
+
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ entity_repo=SimpleNamespace(
+ list_by_type=lambda entity_type: (
+ [
+ EntityRow(id="a-main-1", type="agent", member_id="a-main", name="Toad", thread_id="thread-main", created_at=now),
+ EntityRow(
+ id="a-child-1",
+ type="agent",
+ member_id="a-child",
+ name="Toad · 分身1",
+ thread_id="thread-child",
+ created_at=now,
+ ),
+ ]
+ if entity_type == "agent"
+ else []
+ )
+ ),
+ member_repo=SimpleNamespace(list_all=lambda: [user, other_human, main_agent_member, child_agent_member]),
+ thread_repo=SimpleNamespace(
+ get_by_id=lambda thread_id: (
+ {"is_main": True, "branch_index": 0} if thread_id == "thread-main" else {"is_main": False, "branch_index": 1}
+ )
+ ),
+ )
+ )
+
+ result = await entities_router.list_entities(user_id="u1", app=app)
+
+ assert [item["id"] for item in result] == ["u2", "a-main-1"]
diff --git a/tests/Integration/test_leon_agent.py b/tests/Integration/test_leon_agent.py
new file mode 100644
index 000000000..e410f7df4
--- /dev/null
+++ b/tests/Integration/test_leon_agent.py
@@ -0,0 +1,1206 @@
+"""Integration tests for LeonAgent with QueryLoop.
+
+Uses mock model to verify the full astream pipeline without real API calls.
+"""
+
+import json
+import os
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, SystemMessage, ToolMessage
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _mock_model(text="Integration test response"):
+ """Create a mock LangChain model that returns a plain AIMessage."""
+ ai_msg = AIMessage(content=text)
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(return_value=ai_msg)
+ # configurable_fields support
+ model.configurable_fields.return_value = model
+ model.with_config.return_value = model
+ return model
+
+
+def _empty_stream_model():
+ class _EmptyStreamModel:
+ def bind_tools(self, tools):
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, **kwargs):
+ return self
+
+ async def astream(self, messages):
+ if False:
+ yield AIMessageChunk(content="")
+
+ return _EmptyStreamModel()
+
+
+def _patch_env_api_key():
+ """Ensure ANTHROPIC_API_KEY is set for LeonAgent init (uses a fake value)."""
+ return patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-test-integration"})
+
+
+class _MemoryCheckpointer:
+ def __init__(self):
+ self.store = {}
+
+ async def aget(self, cfg):
+ return self.store.get(cfg["configurable"]["thread_id"])
+
+ async def aput(self, cfg, checkpoint, metadata, new_versions):
+ self.store[cfg["configurable"]["thread_id"]] = checkpoint
+
+
+class _DirectCompactionProbeModel:
+ def __init__(self):
+ self.summary_calls = 0
+ self.turn_calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, **kwargs):
+ return self
+
+ def bind(self, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ first_content = getattr(messages[0], "content", "") if messages else ""
+ if isinstance(first_content, str) and "summarizing conversations" in first_content:
+ self.summary_calls += 1
+ return AIMessage(
+ content=(
+ "1. Request/Intent — summarize\n"
+ "2. Technical Concepts — compaction\n"
+ "3. Files/Code — none\n"
+ "4. Errors — none\n"
+ "5. Problem Solving — keep going\n"
+ "6. User Messages — large payloads\n"
+ "7. Pending Tasks — continue\n"
+ "8. Current Work — compacting\n"
+ "9. Next Step — answer user"
+ )
+ )
+
+ self.turn_calls += 1
+ return AIMessage(content=f"OK_{self.turn_calls}")
+
+
+class _MessageCaptureModel:
+ def __init__(self, text: str = "captured"):
+ self.calls: list[list[object]] = []
+ self.text = text
+
+ def bind_tools(self, tools):
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, **kwargs):
+ return self
+
+ def bind(self, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ self.calls.append(list(messages))
+ return AIMessage(content=self.text)
+
+
+def test_leon_agent_destructor_does_not_reenable_skipped_sandbox_cleanup():
+ """Explicit child close(cleanup_sandbox=False) must stay final under __del__."""
+ from core.runtime.agent import LeonAgent
+
+ agent = object.__new__(LeonAgent)
+ agent._session_started = False
+ agent._mark_terminated = MagicMock()
+ agent._cleanup_mcp_client = MagicMock()
+ agent._cleanup_sqlite_connection = MagicMock()
+ agent._cleanup_sandbox = MagicMock()
+
+ LeonAgent.close(agent, cleanup_sandbox=False)
+ LeonAgent.__del__(agent)
+
+ agent._cleanup_sandbox.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# Integration Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_simple_run(tmp_path):
+ """LeonAgent with mock model: astream completes and yields chunks."""
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Hello from integration test")
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ results = []
+ async for chunk in agent.agent.astream(
+ {"messages": [{"role": "user", "content": "hello"}]},
+ config={"configurable": {"thread_id": "test-integration-1"}},
+ stream_mode="updates",
+ ):
+ results.append(chunk)
+
+ assert len(results) > 0
+ # At least one agent chunk
+ agent_chunks = [c for c in results if "agent" in c]
+ assert len(agent_chunks) >= 1
+ # Agent message content matches mock
+ first_ai_msgs = agent_chunks[0]["agent"]["messages"]
+ assert any("integration test" in str(m.content) for m in first_ai_msgs)
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_astream_interface_compatible(tmp_path):
+ """astream yields dicts with 'agent' key — compatible with LangGraph stream_mode=updates."""
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Compatible response")
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ chunks = []
+ async for chunk in agent.agent.astream(
+ {"messages": [{"role": "user", "content": "test"}]},
+ config={"configurable": {"thread_id": "test-integration-2"}},
+ stream_mode="updates",
+ ):
+ chunks.append(chunk)
+
+ # All chunks are dicts
+ assert all(isinstance(c, dict) for c in chunks)
+ # All keys are one of "agent" or "tools"
+ for c in chunks:
+ assert set(c.keys()).issubset({"agent", "tools"})
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_astream_messages_updates_mode_yields_langgraph_tuples(tmp_path):
+ """messages+updates mode must yield LangGraph-style (mode, data) tuples for SSE consumers."""
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Tuple compatible response")
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ chunks = []
+ async for chunk in agent.agent.astream(
+ {"messages": [{"role": "user", "content": "tuple"}]},
+ config={"configurable": {"thread_id": "test-integration-tuples"}},
+ stream_mode=["messages", "updates"],
+ ):
+ chunks.append(chunk)
+
+ assert chunks
+ assert all(isinstance(chunk, tuple) and len(chunk) == 2 for chunk in chunks)
+ assert any(mode == "messages" for mode, _ in chunks)
+ assert any(mode == "updates" for mode, _ in chunks)
+
+ message_chunks = [data for mode, data in chunks if mode == "messages"]
+ first_msg_chunk, first_metadata = message_chunks[0]
+ assert isinstance(first_msg_chunk, AIMessageChunk)
+ assert "Tuple compatible response" in str(first_msg_chunk.content)
+ assert isinstance(first_metadata, dict)
+
+ update_chunks = [data for mode, data in chunks if mode == "updates"]
+ assert any("agent" in update for update in update_chunks)
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_astream_raises_loudly_on_empty_stream(tmp_path):
+ """Empty streaming responses should surface as errors, not silent empty iterators."""
+ from core.runtime.agent import LeonAgent
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=_empty_stream_model()),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ with pytest.raises(RuntimeError, match="streaming model returned no AIMessageChunk"):
+ async for _ in agent.astream(
+ "test",
+ thread_id="test-empty-stream",
+ stream_mode=["messages", "updates"],
+ ):
+ pass
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_bundle_dir_registers_mcp_resource_tools(tmp_path):
+ """Member bundle MCP config should surface MCP resource tools in the live registry."""
+ from core.runtime.agent import LeonAgent
+
+ member_dir = tmp_path / "members" / "toad"
+ member_dir.mkdir(parents=True)
+ (member_dir / "agent.md").write_text(
+ "---\nname: Toad\ndescription: Demo member\n---\nYou are Toad.\n",
+ encoding="utf-8",
+ )
+ (member_dir / ".mcp.json").write_text(
+ '{"mcpServers":{"nu50demo":{"transport":"stdio","command":"uv","args":["run","python","/tmp/nu50_mcp_server.py"]}}}',
+ encoding="utf-8",
+ )
+
+ mock_model = _mock_model("Bundle MCP response")
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(
+ workspace_root=str(tmp_path),
+ bundle_dir=str(member_dir),
+ api_key="sk-test-integration",
+ )
+ await agent.ainit()
+
+ assert agent._tool_registry.get("ListMcpResources") is not None
+ assert agent._tool_registry.get("ReadMcpResource") is not None
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_announces_mcp_instruction_delta_once_and_reannounces_on_change(tmp_path):
+ from core.runtime.agent import LeonAgent
+
+ member_dir = tmp_path / "members" / "toad"
+ member_dir.mkdir(parents=True)
+ (member_dir / "agent.md").write_text(
+ "---\nname: Toad\ndescription: Demo member\n---\nYou are Toad.\n",
+ encoding="utf-8",
+ )
+
+ def _write_mcp(instructions: str) -> None:
+ (member_dir / ".mcp.json").write_text(
+ json.dumps(
+ {
+ "mcpServers": {
+ "nu50demo": {
+ "transport": "stdio",
+ "command": "uv",
+ "args": ["run", "python", "/tmp/nu50_mcp_server.py"],
+ "instructions": instructions,
+ }
+ }
+ }
+ ),
+ encoding="utf-8",
+ )
+
+ def _message_text(message: object) -> str:
+ content = getattr(message, "content", "")
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ return "\n".join(str(block.get("text", "")) for block in content if isinstance(block, dict))
+ return str(content)
+
+ def _delta_messages(messages: list[object]) -> list[str]:
+ hits: list[str] = []
+ for message in messages:
+ content = _message_text(message)
+            if "<mcp_instructions>" in content:  # marker reconstructed: bare "" is vacuously a substring of every message — TODO confirm exact tag emitted by the MCP delta announcement
+ hits.append(content)
+ return hits
+
+ _write_mcp("Use nu50demo carefully.")
+ first_model = _MessageCaptureModel("First MCP delta response")
+ checkpointer = _MemoryCheckpointer()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=first_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(
+ workspace_root=str(tmp_path),
+ bundle_dir=str(member_dir),
+ api_key="sk-test-integration",
+ )
+ await agent.ainit()
+ agent.checkpointer = checkpointer
+ agent.agent.checkpointer = checkpointer
+
+ await agent.ainvoke("first turn", thread_id="mcp-delta-thread")
+ assert first_model.calls
+ first_messages = first_model.calls[0]
+ first_deltas = _delta_messages(first_messages)
+ assert len(first_deltas) == 1
+ assert "Use nu50demo carefully." in first_deltas[0]
+
+ second_call_index = len(first_model.calls)
+ await agent.ainvoke("second turn", thread_id="mcp-delta-thread")
+ assert len(first_model.calls) > second_call_index
+ second_messages = first_model.calls[second_call_index]
+ second_deltas = _delta_messages(second_messages)
+ assert len(second_deltas) == 1
+ assert second_deltas[0] == first_deltas[0]
+
+ agent.close()
+
+ _write_mcp("Use nu50demo only for trusted reads.")
+ second_model = _MessageCaptureModel("Second MCP delta response")
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=second_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(
+ workspace_root=str(tmp_path),
+ bundle_dir=str(member_dir),
+ api_key="sk-test-integration",
+ )
+ await agent.ainit()
+ agent.checkpointer = checkpointer
+ agent.agent.checkpointer = checkpointer
+
+ await agent.ainvoke("third turn", thread_id="mcp-delta-thread")
+ assert second_model.calls
+ third_messages = second_model.calls[0]
+ third_deltas = _delta_messages(third_messages)
+ assert len(third_deltas) == 2
+ assert "Use nu50demo carefully." in third_deltas[0]
+ assert "Use nu50demo only for trusted reads." in third_deltas[1]
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_memoizes_prompt_sections_between_builds(tmp_path):
+ """Pattern 6: prompt sections should be cached across repeated prompt assembly."""
+ from core.runtime import prompts as prompt_builders
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Prompt cache response")
+ original_context = prompt_builders.build_context_section
+ original_rules = prompt_builders.build_rules_section
+ counts = {"context": 0, "rules": 0}
+
+ def counted_context(*args, **kwargs):
+ counts["context"] += 1
+ return original_context(*args, **kwargs)
+
+ def counted_rules(*args, **kwargs):
+ counts["rules"] += 1
+ return original_rules(*args, **kwargs)
+
+ with (
+ patch("core.runtime.prompts.build_context_section", side_effect=counted_context),
+ patch("core.runtime.prompts.build_rules_section", side_effect=counted_rules),
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ first = agent._compose_system_prompt()
+ second = agent._compose_system_prompt()
+
+ assert first == second
+ assert counts == {"context": 1, "rules": 1}
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_clear_thread_invalidates_prompt_section_cache(tmp_path):
+ """Pattern 6: clear should invalidate cached prompt sections before rebuilding."""
+ from core.runtime import prompts as prompt_builders
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Prompt clear response")
+ original_context = prompt_builders.build_context_section
+ original_rules = prompt_builders.build_rules_section
+ counts = {"context": 0, "rules": 0}
+
+ def counted_context(*args, **kwargs):
+ counts["context"] += 1
+ return original_context(*args, **kwargs)
+
+ def counted_rules(*args, **kwargs):
+ counts["rules"] += 1
+ return original_rules(*args, **kwargs)
+
+ with (
+ patch("core.runtime.prompts.build_context_section", side_effect=counted_context),
+ patch("core.runtime.prompts.build_rules_section", side_effect=counted_rules),
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+ agent.agent.aclear = AsyncMock()
+
+ assert counts == {"context": 1, "rules": 1}
+
+ await agent.aclear_thread("prompt-clear-thread")
+
+ assert counts == {"context": 2, "rules": 2}
+
+ agent.close()
+
+
+def test_build_rules_section_unifies_core_risk_and_tool_preferences():
+ from core.runtime.prompts import build_rules_section
+
+ rules = build_rules_section(
+ is_sandbox=False,
+ working_dir="/repo",
+ workspace_root="/repo",
+ )
+
+ assert "**Workspace**" in rules
+ assert "**Absolute Paths**" in rules
+ assert "**Security**" in rules
+ assert "**Tool Priority**" in rules
+ assert "Do not guess URLs" in rules
+ assert "Do not add features, refactor code, or make speculative abstractions" in rules
+ assert "Don't create helpers, utilities, or abstractions for one-time operations" in rules
+ assert "Don't add error handling, fallbacks, or validation for scenarios that can't happen" in rules
+ assert "Prefer dedicated tools over `Bash`" in rules
+ assert "Use `Read` instead of `cat`, `head`, or `tail`." in rules
+ assert "Use `Glob`/`Grep` for file discovery and content search before falling back to `Bash`." in rules
+ assert "Ask before destructive, hard-to-reverse, or shared-state actions" in rules
+ assert (
+ "Examples: deleting files, force-pushing, dropping tables, killing unfamiliar processes, modifying shared infrastructure." in rules
+ )
+ assert "Background Task Description" not in rules
+
+
+def test_build_rules_section_includes_function_result_clearing_guidance_when_spill_buffer_enabled():
+ from core.runtime.prompts import build_rules_section
+
+ rules = build_rules_section(
+ is_sandbox=False,
+ working_dir="/repo",
+ workspace_root="/repo",
+ spill_buffer_enabled=True,
+ spill_keep_recent=3,
+ )
+
+ assert "**Function Result Clearing**" in rules
+ assert "Old tool results may be cleared from context to free up space." in rules
+ assert "The 3 most recent results are always kept." in rules
+ assert "write down any important information you might need later in your response" in rules
+
+
+def test_build_rules_section_omits_function_result_clearing_guidance_when_spill_buffer_disabled():
+ from core.runtime.prompts import build_rules_section
+
+ rules = build_rules_section(
+ is_sandbox=False,
+ working_dir="/repo",
+ workspace_root="/repo",
+ spill_buffer_enabled=False,
+ spill_keep_recent=3,
+ )
+
+ assert "**Function Result Clearing**" not in rules
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_session_start_hook_runs_on_ainit(tmp_path):
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Session start response")
+ seen = []
+
+ def on_start(payload):
+ seen.append(payload)
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ agent.app_state.add_session_hook("SessionStart", on_start)
+
+ await agent.ainit()
+
+ assert len(seen) == 1
+ assert seen[0]["event"] == "SessionStart"
+ assert seen[0]["sandbox"] == "local"
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_session_end_hook_runs_on_close(tmp_path):
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Session end response")
+ seen = []
+
+ def on_end(payload):
+ seen.append(payload)
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+ agent.app_state.add_session_hook("SessionEnd", on_end)
+
+ agent.close()
+
+ assert len(seen) == 1
+ assert seen[0]["event"] == "SessionEnd"
+ assert seen[0]["sandbox"] == "local"
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_session_hooks_support_async_callbacks_and_fire_once(tmp_path):
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Session once response")
+ seen = []
+
+ async def on_start(payload):
+ seen.append(("start", payload["event"]))
+
+ async def on_end(payload):
+ seen.append(("end", payload["event"]))
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ agent.app_state.add_session_hook("SessionStart", on_start)
+ agent.app_state.add_session_hook("SessionEnd", on_end)
+
+ await agent.ainit()
+ await agent.ainit()
+ agent.close()
+ agent.close()
+
+ assert seen == [("start", "SessionStart"), ("end", "SessionEnd")]
+
+
+class _DeferredDiscoveryProbeModel:
+ def __init__(self):
+ self.turn_tool_names: list[list[str]] = []
+ self._tools: list[dict] = []
+ self._turn = 0
+
+ def bind_tools(self, tools):
+ self._tools = list(tools or [])
+ self.turn_tool_names.append([tool.get("name") for tool in self._tools if isinstance(tool, dict)])
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, *args, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ if self._turn == 0:
+ self._turn += 1
+ return AIMessage(
+ content="",
+ tool_calls=[{"name": "tool_search", "args": {"query": "select:TaskCreate"}, "id": "tc-search"}],
+ )
+ self._turn += 1
+ return AIMessage(content="done")
+
+
+class _DeferredExecutionProbeModel:
+ def __init__(self):
+ self.turn_tool_names: list[list[str]] = []
+ self._tools: list[dict] = []
+ self._turn = 0
+
+ def bind_tools(self, tools):
+ self._tools = list(tools or [])
+ self.turn_tool_names.append([tool.get("name") for tool in self._tools if isinstance(tool, dict)])
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, *args, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ if self._turn == 0:
+ self._turn += 1
+ return AIMessage(
+ content="",
+ tool_calls=[{"name": "tool_search", "args": {"query": "select:TaskCreate"}, "id": "tc-search"}],
+ )
+ if self._turn == 1:
+ self._turn += 1
+ return AIMessage(
+ content="",
+ tool_calls=[
+ {
+ "name": "TaskCreate",
+ "args": {"subject": "PT02_EXEC", "description": "created after discovery"},
+ "id": "tc-task-create",
+ }
+ ],
+ )
+ self._turn += 1
+ return AIMessage(content="PT02_EXEC_DONE")
+
+
+class _DeferredCrossThreadProbeModel:
+ def __init__(self):
+ self.turn_tool_names: list[list[str]] = []
+ self._tools: list[dict] = []
+
+ def bind_tools(self, tools):
+ self._tools = list(tools or [])
+ self.turn_tool_names.append([tool.get("name") for tool in self._tools if isinstance(tool, dict)])
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, *args, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ joined = " ".join(str(getattr(msg, "content", "")) for msg in messages)
+ current_tool_names = {tool.get("name") for tool in self._tools if isinstance(tool, dict)}
+
+ if "discover task tools" in joined and "TaskCreate" not in current_tool_names:
+ return AIMessage(
+ content="",
+ tool_calls=[{"name": "tool_search", "args": {"query": "select:TaskCreate"}, "id": "tc-search"}],
+ )
+
+ if "discover task tools" in joined:
+ return AIMessage(content="discover-done")
+
+ return AIMessage(content="plain-done")
+
+
+class _DeferredInlineSelectProbeModel:
+ def __init__(self):
+ self.turn_tool_names: list[list[str]] = []
+ self._tools: list[dict] = []
+ self._turn = 0
+
+ def bind_tools(self, tools):
+ self._tools = list(tools or [])
+ self.turn_tool_names.append([tool.get("name") for tool in self._tools if isinstance(tool, dict)])
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, *args, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ if self._turn == 0:
+ self._turn += 1
+ return AIMessage(
+ content="",
+ tool_calls=[{"name": "tool_search", "args": {"query": "select:Read,TaskCreate"}, "id": "tc-search"}],
+ )
+ self._turn += 1
+ return AIMessage(content="after-inline-select")
+
+
+class _DeferredResumeProbeModel:
+ def __init__(self):
+ self.turn_tool_names: list[list[str]] = []
+ self._tools: list[dict] = []
+
+ def bind_tools(self, tools):
+ self._tools = list(tools or [])
+ self.turn_tool_names.append([tool.get("name") for tool in self._tools if isinstance(tool, dict)])
+ return self
+
+ def configurable_fields(self, **kwargs):
+ return self
+
+ def with_config(self, *args, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ return AIMessage(content="resume-done")
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_reinjects_discovered_deferred_tool_schemas_on_following_turn(tmp_path):
+ """Deferred tools discovered via tool_search must become real schemas on the next turn."""
+ from core.runtime.agent import LeonAgent
+
+ probe_model = _DeferredDiscoveryProbeModel()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=probe_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ result = await agent.ainvoke("discover task tools", thread_id="test-deferred-discovery")
+
+ assert result["reason"] == "completed"
+ assert len(probe_model.turn_tool_names) >= 2
+ first_turn, second_turn = probe_model.turn_tool_names[:2]
+ assert "TaskCreate" not in first_turn
+ assert "tool_search" in first_turn
+ assert "TaskCreate" in second_turn
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_can_execute_discovered_deferred_tool_on_following_turn(tmp_path):
+ """A deferred tool discovered via tool_search should become callable on the next turn."""
+ from core.runtime.agent import LeonAgent
+
+ probe_model = _DeferredExecutionProbeModel()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=probe_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ result = await agent.ainvoke("discover then run deferred task tool", thread_id="test-deferred-execution")
+
+ assert result["reason"] == "completed"
+ assert len(probe_model.turn_tool_names) >= 2
+ assert "TaskCreate" not in probe_model.turn_tool_names[0]
+ assert "TaskCreate" in probe_model.turn_tool_names[1]
+
+ task_tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage) and msg.tool_call_id == "tc-task-create"]
+ assert len(task_tool_messages) == 1
+ assert "PT02_EXEC" in str(task_tool_messages[0].content)
+ assert any(isinstance(msg, AIMessage) and msg.content == "PT02_EXEC_DONE" for msg in result["messages"])
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_deferred_discovery_does_not_leak_across_threads(tmp_path):
+ """Deferred tools discovered on one thread must not become inline on another thread."""
+ from core.runtime.agent import LeonAgent
+
+ probe_model = _DeferredCrossThreadProbeModel()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=probe_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ result_a = await agent.ainvoke("discover task tools", thread_id="thread-A")
+ result_b = await agent.ainvoke("plain request", thread_id="thread-B")
+
+ assert result_a["reason"] == "completed"
+ assert result_b["reason"] == "completed"
+ assert len(probe_model.turn_tool_names) >= 3
+
+ first_thread_a, second_thread_a, first_thread_b = probe_model.turn_tool_names[:3]
+ assert "TaskCreate" not in first_thread_a
+ assert "TaskCreate" in second_thread_a
+ assert "TaskCreate" not in first_thread_b
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_tool_search_exact_select_fails_loudly_for_inline_tools(tmp_path):
+ """Exact select should surface inline-tool misuse as a tool_use_error in the live loop."""
+ from core.runtime.agent import LeonAgent
+
+ probe_model = _DeferredInlineSelectProbeModel()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=probe_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ result = await agent.ainvoke("probe inline select", thread_id="test-inline-select")
+
+ assert result["reason"] == "completed"
+ tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage) and msg.tool_call_id == "tc-search"]
+ assert len(tool_messages) == 1
+        assert "<tool_use_error>" in str(tool_messages[0].content)
+ assert "inline/already-available tools: Read" in str(tool_messages[0].content)
+ assert any(isinstance(msg, AIMessage) and msg.content == "after-inline-select" for msg in result["messages"])
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_restores_discovered_deferred_tools_after_restart(tmp_path):
+ """Restarting the loop on the same thread should restore prior deferred discoveries from history."""
+ from core.runtime.agent import LeonAgent
+
+ checkpointer = _MemoryCheckpointer()
+ discovery_model = _DeferredDiscoveryProbeModel()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=discovery_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+ agent.checkpointer = checkpointer
+ agent.agent.checkpointer = checkpointer
+
+ result = await agent.ainvoke("discover task tools", thread_id="resume-thread")
+ assert result["reason"] == "completed"
+ agent.close()
+
+ resume_model = _DeferredResumeProbeModel()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=resume_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+ agent.checkpointer = checkpointer
+ agent.agent.checkpointer = checkpointer
+
+ result = await agent.ainvoke("after restart", thread_id="resume-thread")
+
+ assert result["reason"] == "completed"
+ assert resume_model.turn_tool_names
+ assert "TaskCreate" in resume_model.turn_tool_names[0]
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_multiple_thread_ids(tmp_path):
+ """Different thread_ids produce independent sessions (no cross-contamination)."""
+ from core.runtime.agent import LeonAgent
+
+ mock_model = MagicMock()
+ mock_model.bind_tools.return_value = mock_model
+ mock_model.with_config.return_value = mock_model
+ mock_model.configurable_fields.return_value = mock_model
+ mock_model.ainvoke = AsyncMock(
+ side_effect=[
+ AIMessage(content="Response for thread-A"),
+ AIMessage(content="Response for thread-B"),
+ ]
+ )
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ chunks_a = []
+ async for chunk in agent.agent.astream(
+ {"messages": [{"role": "user", "content": "hi A"}]},
+ config={"configurable": {"thread_id": "thread-A"}},
+ stream_mode="updates",
+ ):
+ chunks_a.append(chunk)
+
+ chunks_b = []
+ async for chunk in agent.agent.astream(
+ {"messages": [{"role": "user", "content": "hi B"}]},
+ config={"configurable": {"thread_id": "thread-B"}},
+ stream_mode="updates",
+ ):
+ chunks_b.append(chunk)
+
+ # Both sessions produced chunks
+ assert len(chunks_a) > 0
+ assert len(chunks_b) > 0
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_astream_wrapper_exposes_caller_surface(tmp_path):
+ """LeonAgent should expose a caller-owned astream surface instead of forcing callers onto agent.agent.astream."""
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Caller surface response")
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ chunks = []
+ async for chunk in agent.astream(
+ "caller stream",
+ thread_id="test-astream-wrapper",
+ stream_mode=["messages", "updates"],
+ ):
+ chunks.append(chunk)
+
+ assert chunks
+ assert all(isinstance(chunk, tuple) and len(chunk) == 2 for chunk in chunks)
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_astream_can_enforce_max_budget_per_event(tmp_path):
+ """Caller-owned astream surface should be able to stop once runtime cost exceeds a caller budget."""
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("Caller surface response")
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+
+ async def fake_stream(*args, **kwargs):
+ yield ("messages", ("first", {"langgraph_node": "agent"}))
+ yield ("updates", {"agent": {"messages": [AIMessage(content="done")]}})
+
+ agent.agent.astream = fake_stream
+ agent.runtime = SimpleNamespace(cost=0.75)
+
+ chunks = []
+ with pytest.raises(RuntimeError, match="max_budget_usd exceeded"):
+ async for chunk in agent.astream(
+ "caller stream",
+ thread_id="test-astream-budget",
+ stream_mode=["messages", "updates"],
+ max_budget_usd=0.5,
+ ):
+ chunks.append(chunk)
+
+ assert chunks == [("messages", ("first", {"langgraph_node": "agent"}))]
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_aclear_thread_resets_thread_history(tmp_path):
+ """aclear_thread should clear replayable thread history while preserving accumulators."""
+ from core.runtime.agent import LeonAgent
+
+ mock_model = _mock_model("clearable response")
+ checkpointer = _MemoryCheckpointer()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+ agent.checkpointer = checkpointer
+ agent.agent.checkpointer = checkpointer
+ agent.app_state.total_cost = 1.25
+
+ await agent.ainvoke("hello", thread_id="clear-agent-thread")
+ assert checkpointer.store["clear-agent-thread"]["channel_values"]["messages"]
+
+ agent.agent._tool_read_file_state["/tmp/file.py"] = {"partial": False}
+ agent.agent._tool_loaded_nested_memory_paths.add("/tmp/memory.md")
+ agent.agent._tool_discovered_skill_names.add("skill-a")
+ old_session_id = agent._bootstrap.session_id
+
+ await agent.aclear_thread("clear-agent-thread")
+
+ assert checkpointer.store["clear-agent-thread"]["channel_values"]["messages"] == []
+ assert agent.app_state.messages == []
+ assert agent.app_state.turn_count == 0
+ assert agent.app_state.compact_boundary_index == 0
+ assert agent.app_state.total_cost == 1.25
+ assert agent._bootstrap.session_id != old_session_id
+ assert agent._bootstrap.parent_session_id == old_session_id
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_aclear_thread_does_not_restore_stale_summary(tmp_path):
+ from core.runtime.agent import LeonAgent
+ from core.runtime.middleware import ModelRequest, ModelResponse
+ from core.runtime.middleware.memory.summary_store import SummaryStore
+ from sandbox.thread_context import set_current_thread_id
+
+ async def _handler(req: ModelRequest) -> ModelResponse:
+ return ModelResponse(result=[AIMessage(content="final")], request_messages=req.messages)
+
+ mock_model = _mock_model("clearable response")
+ checkpointer = _MemoryCheckpointer()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=mock_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+ agent.checkpointer = checkpointer
+ agent.agent.checkpointer = checkpointer
+
+ store = SummaryStore(tmp_path / "summary.db")
+ agent._memory_middleware.summary_store = store
+ store.save_summary(
+ thread_id="clear-summary-thread",
+ summary_text="STALE SUMMARY",
+ compact_up_to_index=2,
+ compacted_at=2,
+ )
+
+ await agent.aclear_thread("clear-summary-thread")
+
+ assert store.get_latest_summary("clear-summary-thread") is None
+
+ set_current_thread_id("clear-summary-thread")
+ request = ModelRequest(
+ model=mock_model,
+ messages=[HumanMessage(content="fresh-1"), HumanMessage(content="fresh-2")],
+ system_message=SystemMessage(content="sys"),
+ )
+ result = await agent._memory_middleware.awrap_model_call(request, _handler)
+
+ assert [msg.content for msg in result.request_messages] == ["fresh-1", "fresh-2"]
+
+ agent.close()
+
+
+@pytest.mark.asyncio
+@_patch_env_api_key()
+async def test_leon_agent_persists_summary_store_after_second_turn_compaction(tmp_path):
+ from core.runtime.agent import LeonAgent
+ from core.runtime.middleware.memory.summary_store import SummaryStore
+
+ checkpointer = _MemoryCheckpointer()
+ probe_model = _DirectCompactionProbeModel()
+
+ with (
+ patch("core.runtime.agent.LeonAgent._create_model", return_value=probe_model),
+ patch("core.runtime.agent.LeonAgent._init_async_components", return_value=(None, [])),
+ patch("core.runtime.agent.LeonAgent._init_checkpointer", new_callable=AsyncMock, return_value=None),
+ ):
+ agent = LeonAgent(workspace_root=str(tmp_path), api_key="sk-test-integration")
+ await agent.ainit()
+ agent.checkpointer = checkpointer
+ agent.agent.checkpointer = checkpointer
+
+ store = SummaryStore(tmp_path / "summary.db")
+ agent._memory_middleware.summary_store = store
+ agent._memory_middleware._compaction_threshold = 0.01
+ agent._memory_middleware.compactor.keep_recent_tokens = 10
+
+ turn1 = await agent.ainvoke("A" * 12000, thread_id="agent-compaction-thread")
+ assert turn1["reason"] == "completed"
+ assert store.get_latest_summary("agent-compaction-thread") is None
+
+ turn2 = await agent.ainvoke("B" * 12000, thread_id="agent-compaction-thread")
+ assert turn2["reason"] == "completed"
+ assert probe_model.summary_calls == 1
+ assert agent._memory_middleware._cached_summary is not None
+ assert agent._memory_middleware._compact_up_to_index > 0
+
+ summary = store.get_latest_summary("agent-compaction-thread")
+ assert summary is not None
+ assert summary.compact_up_to_index == agent._memory_middleware._compact_up_to_index
+ assert "Request/Intent" in summary.summary_text
+
+ agent.close()
diff --git a/tests/middleware/memory/test_memory_middleware_integration.py b/tests/Integration/test_memory_middleware_integration.py
similarity index 80%
rename from tests/middleware/memory/test_memory_middleware_integration.py
rename to tests/Integration/test_memory_middleware_integration.py
index 2892d1081..a33a60098 100644
--- a/tests/middleware/memory/test_memory_middleware_integration.py
+++ b/tests/Integration/test_memory_middleware_integration.py
@@ -3,13 +3,16 @@
Tests the complete flow: MemoryMiddleware → SummaryStore → SQLite → Checkpointer
"""
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.runnables import RunnableLambda
+from core.runtime.middleware import ModelRequest, ModelResponse
from core.runtime.middleware.memory.middleware import MemoryMiddleware
from core.runtime.middleware.memory.summary_store import SummaryStore
+from sandbox.thread_context import set_current_thread_id
@pytest.fixture
@@ -41,7 +44,7 @@ def mock_get(config):
@pytest.fixture
def mock_model():
"""Create mock LLM model for testing."""
- model = AsyncMock()
+ model = MagicMock()
async def mock_ainvoke(messages):
# Return a mock summary response
@@ -50,6 +53,7 @@ async def mock_ainvoke(messages):
return response
model.ainvoke = mock_ainvoke
+ model.bind.return_value = model
return model
@@ -165,6 +169,59 @@ async def mock_handler(req):
assert middleware2._compact_up_to_index == original_index
assert middleware2._summary_restored is True
+ @pytest.mark.asyncio
+ async def test_summary_restore_is_isolated_per_thread_on_shared_middleware(self, temp_db, mock_model):
+ middleware = MemoryMiddleware(
+ context_limit=10000,
+ compaction_threshold=0.5,
+ db_path=temp_db,
+ verbose=True,
+ )
+ middleware.set_model(mock_model)
+
+ store = SummaryStore(temp_db)
+ store.save_summary(
+ thread_id="t1",
+ summary_text="SUMMARY ONE",
+ compact_up_to_index=1,
+ compacted_at=2,
+ )
+ store.save_summary(
+ thread_id="t2",
+ summary_text="SUMMARY TWO",
+ compact_up_to_index=1,
+ compacted_at=2,
+ )
+
+ async def handler(req: ModelRequest) -> ModelResponse:
+ return ModelResponse(result=[], request_messages=req.messages)
+
+ request_t1 = ModelRequest(
+ model=RunnableLambda(lambda x: x),
+ messages=[HumanMessage(content="a1"), HumanMessage(content="a2")],
+ system_message=None,
+ )
+
+ request_t2 = ModelRequest(
+ model=RunnableLambda(lambda x: x),
+ messages=[HumanMessage(content="b1"), HumanMessage(content="b2")],
+ system_message=None,
+ )
+
+ set_current_thread_id("t1")
+ result_t1 = await middleware.awrap_model_call(request_t1, handler)
+ set_current_thread_id("t2")
+ result_t2 = await middleware.awrap_model_call(request_t2, handler)
+
+ assert [getattr(msg, "content", "") for msg in result_t1.request_messages] == [
+ "[Conversation Summary]\nSUMMARY ONE",
+ "a2",
+ ]
+ assert [getattr(msg, "content", "") for msg in result_t2.request_messages] == [
+ "[Conversation Summary]\nSUMMARY TWO",
+ "b2",
+ ]
+
class TestSplitTurnSaveAndRestore:
"""Test 3: Verify split turn summaries are saved and restored correctly."""
@@ -325,6 +382,59 @@ async def mock_handler(req):
assert summary1.summary_id != summary2.summary_id
+class TestCompactionBreakerScope:
+ """Breaker should gate proactive compaction without poisoning reactive recovery."""
+
+ @pytest.mark.asyncio
+ async def test_reactive_recovery_can_bypass_and_clear_thread_breaker(self, temp_db, mock_request):
+ class _EventuallyRecoveringModel:
+ def __init__(self):
+ self.compact_calls = 0
+
+ async def ainvoke(self, messages):
+ self.compact_calls += 1
+ if self.compact_calls <= 3:
+ raise RuntimeError("compaction failed")
+ response = MagicMock()
+ response.content = "Recovered summary"
+ return response
+
+ model = _EventuallyRecoveringModel()
+ middleware = MemoryMiddleware(
+ context_limit=10000,
+ compaction_threshold=0.5,
+ db_path=temp_db,
+ verbose=True,
+ )
+ middleware.set_model(model)
+
+ messages = create_large_message_list(30)
+ mock_request.messages = messages
+
+ async def mock_handler(req):
+ return ModelResponse(result=[], request_messages=req.messages)
+
+ for _ in range(3):
+ await middleware.awrap_model_call(mock_request, mock_handler)
+
+ snapshot = middleware.snapshot_thread_state("test-thread-1")
+ assert snapshot == {"failure_count": 3, "breaker_open": True}
+
+ recovered = await middleware.compact_messages_for_recovery(
+ messages,
+ thread_id="test-thread-1",
+ )
+ assert recovered is not None
+ assert getattr(recovered[0], "content", "").startswith("[Conversation Summary]\nRecovered summary")
+
+ snapshot = middleware.snapshot_thread_state("test-thread-1")
+ assert snapshot == {"failure_count": 0, "breaker_open": False}
+
+ result = await middleware.awrap_model_call(mock_request, mock_handler)
+ assert getattr(result.request_messages[0], "content", "").startswith("[Conversation Summary]\nRecovered summary")
+ assert model.compact_calls >= 5
+
+
class TestMissingThreadIdRaisesError:
"""Test 6: Verify missing thread_id is handled gracefully."""
diff --git a/tests/test_monitor_resources_route.py b/tests/Integration/test_monitor_resources_route.py
similarity index 100%
rename from tests/test_monitor_resources_route.py
rename to tests/Integration/test_monitor_resources_route.py
diff --git a/tests/test_p3_api_only.py b/tests/Integration/test_p3_api_only.py
similarity index 100%
rename from tests/test_p3_api_only.py
rename to tests/Integration/test_p3_api_only.py
diff --git a/tests/test_p3_e2e.py b/tests/Integration/test_p3_e2e.py
similarity index 100%
rename from tests/test_p3_e2e.py
rename to tests/Integration/test_p3_e2e.py
diff --git a/tests/Integration/test_query_loop_backend_bridge.py b/tests/Integration/test_query_loop_backend_bridge.py
new file mode 100644
index 000000000..c7fa25cd5
--- /dev/null
+++ b/tests/Integration/test_query_loop_backend_bridge.py
@@ -0,0 +1,1971 @@
+"""Backend-facing regression tests for QueryLoop caller-contract bridge."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+
+from backend.web.models.requests import SendMessageRequest
+from backend.web.routers import threads as threads_router
+from backend.web.routers.threads import get_thread_history, get_thread_messages
+from backend.web.services.display_builder import DisplayBuilder
+from backend.web.services.event_buffer import ThreadEventBuffer
+from backend.web.services.streaming_service import (
+ _ensure_thread_handlers,
+ _repair_incomplete_tool_calls,
+ _run_agent_to_buffer,
+ start_agent_run,
+)
+from core.runtime.loop import QueryLoop
+from core.runtime.middleware.memory.middleware import MemoryMiddleware
+from core.runtime.middleware.monitor.state_monitor import AgentState
+from core.runtime.middleware.queue.manager import MessageQueueManager
+from core.runtime.middleware.queue.middleware import SteeringMiddleware
+from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry
+from core.runtime.state import AppState, BootstrapConfig
+from core.tools.tool_search.service import ToolSearchService
+
+
+class _MemoryCheckpointer:
+ def __init__(self) -> None:
+ self.store: dict[str, dict] = {}
+
+ async def aget(self, cfg):
+ return self.store.get(cfg["configurable"]["thread_id"])
+
+ async def aget_tuple(self, cfg):
+ return None
+
+ async def aput(self, cfg, checkpoint, metadata, new_versions):
+ self.store[cfg["configurable"]["thread_id"]] = checkpoint
+
+
+class _NoToolModel:
+ def __init__(self, text: str = "done") -> None:
+ self._text = text
+
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ return AIMessage(content=self._text)
+
+
+class _TurnTextModel:
+ def __init__(self, *texts: str) -> None:
+ self._texts = list(texts)
+ self._index = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ if self._index < len(self._texts):
+ text = self._texts[self._index]
+ self._index += 1
+ return AIMessage(content=text)
+ return AIMessage(content=self._texts[-1] if self._texts else "done")
+
+
+class _TerminalFollowthroughPromptAwareModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ system_text = ""
+ if messages and messages[0].__class__.__name__ == "SystemMessage":
+ system_text = getattr(messages[0], "content", "") or ""
+ last_human = next(
+ (msg.content for msg in reversed(messages) if msg.__class__.__name__ == "HumanMessage"),
+ "",
+ )
+ if "CommandNotification" not in last_human and "task-notification" not in last_human:
+ return AIMessage(content="UNRELATED")
+ if "Terminal background completion notifications require an explicit assistant followthrough." in system_text:
+ return AIMessage(content="FOLLOWTHROUGH_ACK")
+ return AIMessage(content="")
+
+
+class _TerminalFollowthroughSilentModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ last_human = next(
+ (msg.content for msg in reversed(messages) if msg.__class__.__name__ == "HumanMessage"),
+ "",
+ )
+ if "CommandNotification" in last_human or "task-notification" in last_human:
+ return AIMessage(content="")
+ return AIMessage(content="UNRELATED")
+
+
+class _ChatNotificationSilentModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ last_human = next(
+ (msg.content for msg in reversed(messages) if msg.__class__.__name__ == "HumanMessage"),
+ "",
+ )
+ if "New message from" in last_human and "read_messages(chat_id=" in last_human:
+ return AIMessage(content="")
+ return AIMessage(content="UNRELATED")
+
+
+class _PromptTooLongTwiceModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ raise RuntimeError("prompt is too long")
+
+
+class _QueryOkWithFailingCompactorModel:
+ def bind_tools(self, tools):
+ return self
+
+ def bind(self, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ system_text = ""
+ if messages and messages[0].__class__.__name__ == "SystemMessage":
+ system_text = getattr(messages[0], "content", "") or ""
+ if "tasked with summarizing conversations" in system_text or "split turn" in system_text.lower():
+ raise RuntimeError("compaction failed")
+ return AIMessage(content="OK")
+
+
+class _BridgeReactiveCompactMiddleware:
+ compact_boundary_index = 1
+
+ async def compact_messages_for_recovery(self, messages):
+ return [SystemMessage(content="[Conversation Summary]\nSUMMARY")] + list(messages[-1:])
+
+
+class _ToolSearchInlineSelectModel:
+ def __init__(self) -> None:
+ self._turn = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ if self._turn == 0:
+ self._turn += 1
+ return AIMessage(
+ content="",
+ tool_calls=[{"name": "tool_search", "args": {"query": "select:Read,TaskCreate"}, "id": "tc-search"}],
+ )
+ return AIMessage(content="after-inline-select")
+
+
+class _ToolThenConcurrencyLimitModel:
+ def __init__(self) -> None:
+ self._turn = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ if self._turn == 0:
+ self._turn += 1
+ return AIMessage(
+ content="",
+ tool_calls=[{"name": "Write", "args": {"file_path": "/tmp/demo.txt", "content": "hi"}, "id": "tc-write"}],
+ )
+ raise RuntimeError("Concurrency limit exceeded for user, please retry later")
+
+
+class _SteerAwareTerminalModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ last_human = next(
+ (msg.content for msg in reversed(messages) if msg.__class__.__name__ == "HumanMessage"),
+ "",
+ )
+ return AIMessage(content="STEER_DONE" if last_human == "Stop and just say STEER_DONE." else "UNKNOWN")
+
+
+class _StopHonestyAwareModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ system_text = ""
+ if messages and messages[0].__class__.__name__ == "SystemMessage":
+ system_text = getattr(messages[0], "content", "") or ""
+ last_human = next(
+ (msg.content for msg in reversed(messages) if msg.__class__.__name__ == "HumanMessage"),
+ "",
+ )
+ if last_human != "Stop immediately. Do not continue the old task. Reply exactly STOPPED_NOW and do not write any file.":
+ return AIMessage(content="UNKNOWN")
+ if "Steer requests accepted during an active run are non-preemptive." in system_text:
+ return AIMessage(content="STOP_ACK_AFTER_COMPLETED_WORK")
+ return AIMessage(content="STOPPED_NOW")
+
+
+class _SteerCancelPoisonModel:
+ def __init__(self) -> None:
+ self._turn = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ if self._turn == 0:
+ self._turn += 1
+ return AIMessage(
+ content="",
+ tool_calls=[{"name": "SleepTool", "args": {}, "id": "tc-sleep"}],
+ )
+ last_human = next(
+ (msg.content for msg in reversed(messages) if msg.__class__.__name__ == "HumanMessage"),
+ "",
+ )
+ return AIMessage(content=f"LAST_HUMAN:{last_human}")
+
+
+class _FakeDisplayBuilder:
+ def __init__(self, cached_entries):
+ self._cached_entries = cached_entries
+ self.rebuilt_with: tuple[str, list[dict]] | None = None
+
+ def get_entries(self, thread_id: str):
+ return self._cached_entries
+
+ def build_from_checkpoint(self, thread_id: str, messages: list[dict]):
+ self.rebuilt_with = (thread_id, messages)
+ return [{"id": "rebuilt-notice", "role": "notice", "content": "rebuilt"}]
+
+ def get_display_seq(self, thread_id: str) -> int:
+ return 7
+
+
+class _StreamingGraphAgent:
+ checkpointer = None
+
+ async def aget_state(self, _config):
+ return SimpleNamespace(values={"messages": []})
+
+ async def astream(self, *_args, **_kwargs):
+ if False:
+ yield None
+
+
+class _NoResumeGraphAgent(_StreamingGraphAgent):
+ def __init__(self) -> None:
+ self.astream_calls = 0
+ self.aupdate_calls = 0
+
+ async def aupdate_state(self, *_args, **_kwargs):
+ self.aupdate_calls += 1
+
+ async def astream(self, *_args, **_kwargs):
+ self.astream_calls += 1
+ if False:
+ yield None
+ return
+
+
+class _StreamingRuntime:
+ current_state = AgentState.IDLE
+
+ def __init__(self) -> None:
+ self.current_run_source = None
+ self._event_callback = None
+ self.state = SimpleNamespace(flags=SimpleNamespace(is_compacting=False))
+
+ def set_event_callback(self, cb) -> None:
+ self._event_callback = cb
+
+ def bind_thread(self, *, activity_sink=None) -> None:
+ self._activity_sink = activity_sink
+
+ def get_status_dict(self) -> dict[str, object]:
+ return {"state": {"state": "idle", "flags": {}}}
+
+ def transition(self, new_state) -> bool:
+ valid = {
+ AgentState.IDLE: {AgentState.ACTIVE},
+ AgentState.ACTIVE: {AgentState.IDLE},
+ }
+ if new_state not in valid.get(self.current_state, set()):
+ return False
+ self.current_state = new_state
+ return True
+
+
+async def _wait_for_followthrough_text(loop: QueryLoop, thread_id: str, expected: str) -> None:
+ for _ in range(100):
+ state = await loop.aget_state({"configurable": {"thread_id": thread_id}})
+ messages = state.values.get("messages", []) if state and state.values else []
+ if any(msg.__class__.__name__ == "AIMessage" and getattr(msg, "content", None) == expected for msg in messages):
+ return
+ await asyncio.sleep(0.01)
+ raise AssertionError(f"followthrough text not observed: {expected}")
+
+
+def _make_loop(
+ *,
+ text: str = "done",
+ model=None,
+ registry: ToolRegistry | None = None,
+ checkpointer: _MemoryCheckpointer | None = None,
+ middleware: list | None = None,
+) -> QueryLoop:
+ return QueryLoop(
+ model=model or _NoToolModel(text=text),
+ system_prompt=SystemMessage(content="sys"),
+ middleware=middleware or [],
+ checkpointer=checkpointer,
+ registry=registry or ToolRegistry(),
+ app_state=AppState(),
+ runtime=None,
+ bootstrap=BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model"),
+ max_turns=5,
+ )
+
+
+def _patch_streaming_event_store(monkeypatch: pytest.MonkeyPatch) -> None:
+ seq = 0
+
+ async def fake_append_event(thread_id, run_id, event, message_id=None, run_event_repo=None):
+ nonlocal seq
+ seq += 1
+ return seq
+
+ async def fake_cleanup_old_runs(thread_id, keep_latest=1, run_event_repo=None):
+ return 0
+
+ monkeypatch.setattr("backend.web.services.event_store.append_event", fake_append_event)
+ monkeypatch.setattr("backend.web.services.streaming_service.cleanup_old_runs", fake_cleanup_old_runs)
+
+
+def _patch_direct_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
+ _patch_streaming_event_store(monkeypatch)
+ monkeypatch.setattr("backend.web.services.streaming_service._ensure_thread_handlers", lambda *args, **kwargs: None)
+
+
+def _make_streaming_agent(loop: QueryLoop, *, queue_manager: MessageQueueManager | None = None) -> SimpleNamespace:
+ agent = SimpleNamespace(
+ agent=loop,
+ runtime=_StreamingRuntime(),
+ storage_container=None,
+ )
+ if queue_manager is not None:
+ agent.queue_manager = queue_manager
+ return agent
+
+
+def _make_streaming_app(
+ tmp_path: Path,
+ *,
+ thread_id: str | None = None,
+ agent: SimpleNamespace | None = None,
+ queue_manager: MessageQueueManager | None = None,
+ include_route_locks: bool = False,
+) -> tuple[SimpleNamespace, MessageQueueManager]:
+ queue_manager = queue_manager or MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ state = SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ thread_tasks={},
+ thread_event_buffers={},
+ subagent_buffers={},
+ queue_manager=queue_manager,
+ thread_last_active={},
+ typing_tracker=None,
+ )
+ if thread_id is not None and agent is not None:
+ state.agent_pool = {f"{thread_id}:local": agent}
+ state.thread_sandbox = {thread_id: "local"}
+ state._event_loop = asyncio.get_running_loop()
+ if include_route_locks:
+ state.thread_locks = {}
+ state.thread_locks_guard = asyncio.Lock()
+ return SimpleNamespace(state=state), queue_manager
+
+
+def _make_direct_streaming_context(
+ tmp_path: Path,
+ loop: QueryLoop,
+ *,
+ queue_manager: MessageQueueManager | None = None,
+) -> tuple[SimpleNamespace, SimpleNamespace, ThreadEventBuffer]:
+ agent = _make_streaming_agent(loop, queue_manager=queue_manager)
+ app, _ = _make_streaming_app(tmp_path, queue_manager=queue_manager)
+ return agent, app, ThreadEventBuffer()
+
+
+def _make_route_followthrough_context(
+ tmp_path: Path,
+ *,
+ thread_id: str,
+ loop: QueryLoop,
+) -> tuple[MessageQueueManager, SimpleNamespace, SimpleNamespace]:
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ agent = _make_streaming_agent(loop, queue_manager=queue_manager)
+ app, _ = _make_streaming_app(tmp_path, thread_id=thread_id, agent=agent, queue_manager=queue_manager)
+ _ensure_thread_handlers(agent, thread_id, app)
+ return queue_manager, agent, app
+
+
+async def _run_direct_notification_followthrough(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+ *,
+ loop: QueryLoop,
+ thread_id: str,
+ message: str,
+ run_id: str,
+ message_metadata: dict[str, str] | None = None,
+) -> list[dict]:
+ _patch_direct_streaming(monkeypatch)
+ agent, app, thread_buf = _make_direct_streaming_context(tmp_path, loop)
+
+ await _run_agent_to_buffer(
+ agent,
+ thread_id,
+ message,
+ app,
+ False,
+ thread_buf,
+ run_id,
+ message_metadata=message_metadata,
+ )
+
+ entries = app.state.display_builder.get_entries(thread_id)
+ assert entries is not None
+ return entries
+
+
+def _assert_notice_then_text(entries: list[dict], notice_contains: str, expected_text: str) -> None:
+ assert entries[0]["segments"][0]["type"] == "notice"
+ assert notice_contains in entries[0]["segments"][0]["content"]
+ assert entries[0]["segments"][1] == {"type": "text", "content": expected_text}
+
+
+async def _get_local_thread_history(thread_id: str, *, agent: SimpleNamespace, app: SimpleNamespace) -> dict:
+ with (
+ patch.object(threads_router, "get_or_create_agent", return_value=agent),
+ patch.object(threads_router, "resolve_thread_sandbox", return_value="local"),
+ ):
+ return await get_thread_history(thread_id, limit=20, truncate=400, user_id="u", app=app)
+
+
+def _patch_fake_event_bus(monkeypatch: pytest.MonkeyPatch) -> None:
+ class _FakeEventBus:
+ def subscribe(self, *_args, **_kwargs):
+ return None
+
+ def make_emitter(self, **_kwargs):
+ async def _emit(_event):
+ return None
+
+ return _emit
+
+ monkeypatch.setattr("backend.web.event_bus.get_event_bus", lambda: _FakeEventBus())
+
+
+@pytest.mark.asyncio
+async def test_repair_incomplete_tool_calls_uses_query_loop_state_bridge():
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(checkpointer=checkpointer)
+ broken_ai = AIMessage(
+ content="",
+ tool_calls=[{"name": "Read", "args": {"file_path": "/tmp/a.txt"}, "id": "tc-1"}],
+ )
+ trailing = HumanMessage(content="after tool")
+ trailing.id = "human-after"
+ checkpointer.store["repair-live-thread"] = {"channel_values": {"messages": [broken_ai, trailing]}}
+
+ await _repair_incomplete_tool_calls(
+ SimpleNamespace(agent=loop),
+ {"configurable": {"thread_id": "repair-live-thread"}},
+ )
+
+ state = await loop.aget_state({"configurable": {"thread_id": "repair-live-thread"}})
+
+ assert [msg.__class__.__name__ for msg in state.values["messages"]] == [
+ "AIMessage",
+ "ToolMessage",
+ "HumanMessage",
+ ]
+ assert [getattr(msg, "content", None) for msg in state.values["messages"]] == [
+ "",
+ "Error: task was interrupted (server restart or timeout). Results unavailable.",
+ "after tool",
+ ]
+
+
+@pytest.mark.asyncio
+async def test_get_thread_history_reads_messages_via_query_loop_state_bridge():
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(text="history reply", checkpointer=checkpointer)
+ config = {"configurable": {"thread_id": "history-thread"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "hello"}]},
+ config=config,
+ ):
+ pass
+
+ fake_agent = SimpleNamespace(agent=loop)
+ fake_app = SimpleNamespace(state=SimpleNamespace())
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ ):
+ history = await get_thread_history(
+ "history-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert history["total"] == 2
+ assert history["thread_id"] == "history-thread"
+ assert [item["role"] for item in history["messages"]] == ["human", "assistant"]
+ assert history["messages"][1]["text"] == "history reply"
+
+
+@pytest.mark.asyncio
+async def test_get_thread_history_skips_empty_ai_messages_after_notifications():
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(checkpointer=checkpointer)
+ system_notice = HumanMessage(
+ content="errorAgent failed"
+ )
+ system_notice.metadata = {"source": "system"}
+ checkpointer.store["history-empty-ai-thread"] = {
+ "channel_values": {
+ "messages": [
+ HumanMessage(content="launch background task"),
+ system_notice,
+ AIMessage(content=""),
+ ]
+ }
+ }
+
+ fake_agent = SimpleNamespace(agent=loop)
+ fake_app = SimpleNamespace(state=SimpleNamespace())
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ ):
+ history = await get_thread_history(
+ "history-empty-ai-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert [item["role"] for item in history["messages"]] == ["human", "notification"]
+ assert history["messages"][-1]["text"].startswith("")  # FIXME(review): expected prefix literal was lost (markup stripped); startswith("") is always True, so this assertion is vacuous — restore the original tag/prefix
+
+
+@pytest.mark.asyncio
+async def test_get_thread_history_retains_tool_search_inline_select_error():
+ checkpointer = _MemoryCheckpointer()
+ registry = ToolRegistry()
+ registry.register(
+ ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={"name": "Read", "description": "read file"},
+ handler=lambda **_: "read",
+ source="test",
+ )
+ )
+ registry.register(
+ ToolEntry(
+ name="TaskCreate",
+ mode=ToolMode.DEFERRED,
+ schema={"name": "TaskCreate", "description": "create task"},
+ handler=lambda **_: "task",
+ source="test",
+ )
+ )
+ ToolSearchService(registry)
+ loop = _make_loop(
+ model=_ToolSearchInlineSelectModel(),
+ registry=registry,
+ checkpointer=checkpointer,
+ )
+ config = {"configurable": {"thread_id": "history-tool-search-inline-select"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "probe inline select"}]},
+ config=config,
+ ):
+ pass
+
+ fake_agent = SimpleNamespace(agent=loop)
+ fake_app = SimpleNamespace(state=SimpleNamespace())
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ ):
+ history = await get_thread_history(
+ "history-tool-search-inline-select",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert [item["role"] for item in history["messages"]] == ["human", "tool_call", "tool_result", "assistant"]
+ assert history["messages"][1]["tool"] == "tool_search"
+ assert "" in history["messages"][2]["text"]  # FIXME(review): expected substring literal was lost (markup stripped); "" in s is always True, so this assertion is vacuous — restore the original tag/substring
+ assert "inline/already-available tools: Read" in history["messages"][2]["text"]
+ assert history["messages"][3]["text"] == "after-inline-select"
+
+
+@pytest.mark.asyncio
+async def test_get_thread_history_persists_visible_assistant_error_after_model_failure():
+ checkpointer = _MemoryCheckpointer()
+ registry = ToolRegistry()
+ registry.register(
+ ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "description": "write file"},
+ handler=lambda **_: "FILE_WRITTEN",
+ source="test",
+ )
+ )
+ loop = _make_loop(
+ model=_ToolThenConcurrencyLimitModel(),
+ registry=registry,
+ checkpointer=checkpointer,
+ )
+ config = {"configurable": {"thread_id": "history-visible-model-error"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "write once, then continue"}]},
+ config=config,
+ ):
+ pass
+
+ fake_agent = SimpleNamespace(agent=loop)
+ fake_app = SimpleNamespace(state=SimpleNamespace())
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ ):
+ history = await get_thread_history(
+ "history-visible-model-error",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert [item["role"] for item in history["messages"]] == ["human", "tool_call", "tool_result", "assistant"]
+ assert history["messages"][-1]["text"] == "Error: Concurrency limit exceeded for user, please retry later"
+
+
+@pytest.mark.asyncio
+async def test_query_loop_persists_visible_terminal_followthrough_when_system_notification_resume_is_silent():
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(text="", checkpointer=checkpointer)
+ system_notice = HumanMessage(
+ content="errorAgent failed"
+ )
+ system_notice.metadata = {"source": "system", "notification_type": "agent"}
+ checkpointer.store["resume-empty-ai-thread"] = {
+ "channel_values": {
+ "messages": [
+ HumanMessage(content="launch background task"),
+ system_notice,
+ ]
+ }
+ }
+
+ async for _ in loop.query(
+ None,
+ config={"configurable": {"thread_id": "resume-empty-ai-thread"}},
+ ):
+ pass
+
+ state = await loop.aget_state({"configurable": {"thread_id": "resume-empty-ai-thread"}})
+
+ assert [msg.__class__.__name__ for msg in state.values["messages"]] == [
+ "HumanMessage",
+ "HumanMessage",
+ "AIMessage",
+ ]
+ assert state.values["messages"][-2].content.startswith("")  # FIXME(review): expected prefix literal was lost (markup stripped); assertion is vacuous — restore the original notification prefix
+ assert state.values["messages"][-1].content == "Background agent failed, but the followthrough assistant reply was empty."
+
+
+@pytest.mark.asyncio
+async def test_query_loop_persists_midrun_steer_message_into_checkpoint_state(tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ queue_manager.enqueue(
+ "Stop and just say STEER_DONE.",
+ "steer-persist-thread",
+ notification_type="steer",
+ source="owner",
+ is_steer=True,
+ )
+ runtime = SimpleNamespace(events=[], emit_activity_event=lambda event: runtime.events.append(event))
+ loop = _make_loop(
+ model=_SteerAwareTerminalModel(),
+ checkpointer=checkpointer,
+ middleware=[SteeringMiddleware(queue_manager=queue_manager, agent_runtime=runtime)],
+ )
+ checkpointer.store["steer-persist-thread"] = {
+ "channel_values": {
+ "messages": [
+ HumanMessage(content="Use Bash to run `sleep 20; echo LONG_PHASE_DONE`, then reply exactly ORIGINAL_DONE."),
+ AIMessage(
+ content="",
+ tool_calls=[{"name": "Bash", "args": {"command": "sleep 20; echo LONG_PHASE_DONE"}, "id": "tc-bash"}],
+ ),
+ ToolMessage(content="LONG_PHASE_DONE", name="Bash", tool_call_id="tc-bash"),
+ ]
+ }
+ }
+
+ async for _ in loop.query(None, config={"configurable": {"thread_id": "steer-persist-thread"}}):
+ pass
+
+ state = await loop.aget_state({"configurable": {"thread_id": "steer-persist-thread"}})
+ persisted = state.values["messages"]
+
+ assert [msg.__class__.__name__ for msg in persisted] == [
+ "HumanMessage",
+ "AIMessage",
+ "ToolMessage",
+ "HumanMessage",
+ "AIMessage",
+ ]
+ assert persisted[3].content == "Stop and just say STEER_DONE."
+ assert persisted[3].metadata["source"] == "owner"
+ assert persisted[3].metadata["is_steer"] is True
+ assert persisted[4].content == "STEER_DONE"
+
+
+@pytest.mark.asyncio
+async def test_get_thread_history_rebuilds_persisted_midrun_steer_message(tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ queue_manager.enqueue(
+ "Stop and just say STEER_DONE.",
+ "steer-history-thread",
+ notification_type="steer",
+ source="owner",
+ is_steer=True,
+ )
+ runtime = SimpleNamespace(events=[], emit_activity_event=lambda event: runtime.events.append(event))
+ loop = _make_loop(
+ model=_SteerAwareTerminalModel(),
+ checkpointer=checkpointer,
+ middleware=[SteeringMiddleware(queue_manager=queue_manager, agent_runtime=runtime)],
+ )
+ checkpointer.store["steer-history-thread"] = {
+ "channel_values": {
+ "messages": [
+ HumanMessage(content="Use Bash to run `sleep 20; echo LONG_PHASE_DONE`, then reply exactly ORIGINAL_DONE."),
+ AIMessage(
+ content="",
+ tool_calls=[{"name": "Bash", "args": {"command": "sleep 20; echo LONG_PHASE_DONE"}, "id": "tc-bash"}],
+ ),
+ ToolMessage(content="LONG_PHASE_DONE", name="Bash", tool_call_id="tc-bash"),
+ ]
+ }
+ }
+
+ async for _ in loop.query(None, config={"configurable": {"thread_id": "steer-history-thread"}}):
+ pass
+
+ fake_agent = SimpleNamespace(agent=loop)
+ fake_app = SimpleNamespace(state=SimpleNamespace())
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ ):
+ history = await get_thread_history(
+ "steer-history-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert [item["role"] for item in history["messages"]] == [
+ "human",
+ "tool_call",
+ "tool_result",
+ "human",
+ "assistant",
+ ]
+ assert history["messages"][3]["text"] == "Stop and just say STEER_DONE."
+ assert history["messages"][4]["text"] == "STEER_DONE"
+
+
+@pytest.mark.asyncio
+async def test_query_loop_adds_non_preemptive_steer_contract_before_terminal_reply(tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ queue_manager.enqueue(
+ "Stop immediately. Do not continue the old task. Reply exactly STOPPED_NOW and do not write any file.",
+ "steer-stop-honesty-thread",
+ notification_type="steer",
+ source="owner",
+ is_steer=True,
+ )
+ runtime = SimpleNamespace(events=[], emit_activity_event=lambda event: runtime.events.append(event))
+ loop = _make_loop(
+ model=_StopHonestyAwareModel(),
+ checkpointer=checkpointer,
+ middleware=[SteeringMiddleware(queue_manager=queue_manager, agent_runtime=runtime)],
+ )
+ checkpointer.store["steer-stop-honesty-thread"] = {
+ "channel_values": {
+ "messages": [
+ HumanMessage(content="Run the long bash."),
+ AIMessage(
+ content="",
+ tool_calls=[{"name": "Bash", "args": {"command": "sleep 15; echo LONG_PHASE_DONE"}, "id": "tc-bash"}],
+ ),
+ ToolMessage(content="LONG_PHASE_DONE", name="Bash", tool_call_id="tc-bash"),
+ ]
+ }
+ }
+
+ async for _ in loop.query(None, config={"configurable": {"thread_id": "steer-stop-honesty-thread"}}):
+ pass
+
+ state = await loop.aget_state({"configurable": {"thread_id": "steer-stop-honesty-thread"}})
+ persisted = state.values["messages"]
+
+ assert [msg.__class__.__name__ for msg in persisted] == [
+ "HumanMessage",
+ "AIMessage",
+ "ToolMessage",
+ "HumanMessage",
+ "AIMessage",
+ ]
+ assert persisted[3].content == "Stop immediately. Do not continue the old task. Reply exactly STOPPED_NOW and do not write any file."
+ assert persisted[4].content == "STOP_ACK_AFTER_COMPLETED_WORK"
+
+
+@pytest.mark.asyncio
+async def test_cancelled_midrun_steer_persists_and_does_not_poison_next_turn(monkeypatch, tmp_path):
+ monkeypatch.setattr("backend.web.services.streaming_service._ensure_thread_handlers", lambda *args, **kwargs: None)
+ checkpointer = _MemoryCheckpointer()
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ runtime = _StreamingRuntime()
+ tool_started = asyncio.Event()
+
+ async def sleep_tool() -> str:
+ tool_started.set()
+ try:
+ await asyncio.sleep(60)
+ except asyncio.CancelledError:
+ raise
+ return "SLEPT"
+
+ registry = ToolRegistry()
+ registry.register(
+ ToolEntry(
+ name="SleepTool",
+ mode=ToolMode.INLINE,
+ schema={"name": "SleepTool", "description": "sleep", "parameters": {}},
+ handler=sleep_tool,
+ source="test",
+ )
+ )
+ loop = _make_loop(
+ model=_SteerCancelPoisonModel(),
+ registry=registry,
+ checkpointer=checkpointer,
+ middleware=[SteeringMiddleware(queue_manager=queue_manager, agent_runtime=runtime)],
+ )
+ agent = SimpleNamespace(
+ agent=loop,
+ runtime=runtime,
+ storage_container=None,
+ )
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ thread_tasks={},
+ thread_event_buffers={},
+ subagent_buffers={},
+ queue_manager=queue_manager,
+ thread_last_active={},
+ typing_tracker=None,
+ )
+ )
+ thread_id = "steer-cancel-poison-thread"
+ config = {"configurable": {"thread_id": thread_id}}
+
+ start_agent_run(agent, thread_id, "start", app)
+ task = app.state.thread_tasks[thread_id]
+
+ await asyncio.wait_for(tool_started.wait(), timeout=2)
+ queue_manager.enqueue(
+ "Stop and just say STEER_DONE.",
+ thread_id,
+ notification_type="steer",
+ source="owner",
+ is_steer=True,
+ )
+
+ task.cancel()
+ await asyncio.gather(task, return_exceptions=True)
+
+ assert queue_manager.list_queue(thread_id) == []
+ assert app.state.thread_tasks.get(thread_id) is None
+ assert runtime.current_state == AgentState.IDLE
+
+ state_after_cancel = await loop.aget_state(config)
+ cancelled_contents = [getattr(msg, "content", "") for msg in state_after_cancel.values["messages"]]
+ assert cancelled_contents[:2] == ["start", "Stop and just say STEER_DONE."]
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "fresh user message"}]},
+ config=config,
+ ):
+ pass
+
+ final_state = await loop.aget_state(config)
+ final_contents = [getattr(msg, "content", "") for msg in final_state.values["messages"]]
+ assert final_contents == [
+ "start",
+ "Stop and just say STEER_DONE.",
+ "fresh user message",
+ "LAST_HUMAN:fresh user message",
+ ]
+
+
+@pytest.mark.asyncio
+async def test_get_thread_messages_rebuilds_idle_thread_when_cached_entries_are_stale():
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(text="history reply", checkpointer=checkpointer)
+ config = {"configurable": {"thread_id": "detail-thread"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "hello"}]},
+ config=config,
+ ):
+ pass
+
+ display_builder = _FakeDisplayBuilder(cached_entries=[{"id": "stale-turn", "role": "assistant", "segments": []}])
+ fake_agent = SimpleNamespace(
+ agent=loop,
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+ fake_app = SimpleNamespace(state=SimpleNamespace(display_builder=display_builder))
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ detail = await get_thread_messages(
+ "detail-thread",
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert detail["entries"] == [{"id": "rebuilt-notice", "role": "notice", "content": "rebuilt"}]
+ assert display_builder.rebuilt_with is not None
+ rebuilt_thread_id, rebuilt_messages = display_builder.rebuilt_with
+ assert rebuilt_thread_id == "detail-thread"
+ assert [msg["type"] for msg in rebuilt_messages] == ["HumanMessage", "AIMessage"]
+
+
+@pytest.mark.asyncio
+async def test_get_thread_messages_idle_rebuild_replays_latest_run_error_from_event_log():
+ human = HumanMessage(content="hello")
+ fake_agent = SimpleNamespace(
+ agent=SimpleNamespace(aget_state=AsyncMock(return_value=SimpleNamespace(values={"messages": [human]}))),
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+ fake_app = SimpleNamespace(state=SimpleNamespace(display_builder=DisplayBuilder()))
+ run_events = [
+ {
+ "seq": 1,
+ "event": "run_start",
+ "data": json.dumps(
+ {
+ "thread_id": "detail-thread",
+ "run_id": "run-error-1",
+ "source": "owner",
+ "showing": True,
+ }
+ ),
+ "message_id": None,
+ },
+ {
+ "seq": 2,
+ "event": "error",
+ "data": json.dumps({"error": "quota exploded"}),
+ "message_id": None,
+ },
+ {
+ "seq": 3,
+ "event": "run_done",
+ "data": json.dumps({"thread_id": "detail-thread", "run_id": "run-error-1"}),
+ "message_id": None,
+ },
+ ]
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ patch("backend.web.services.event_store.get_latest_run_id", AsyncMock(return_value="run-error-1")),
+ patch("backend.web.services.event_store.read_events_after", AsyncMock(return_value=run_events)),
+ ):
+ detail = await get_thread_messages(
+ "detail-thread",
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert detail["entries"][0]["role"] == "user"
+ assert any(
+ entry.get("role") == "assistant"
+ and any(segment.get("type") == "text" and "quota exploded" in segment.get("content", "") for segment in entry.get("segments", []))
+ for entry in detail["entries"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_cold_rebuild_surfaces_persisted_compaction_notice_in_detail_and_history():
+ checkpointer = _MemoryCheckpointer()
+ summary_model = MagicMock()
+ summary_model.bind.return_value = summary_model
+ summary_model.ainvoke = AsyncMock(return_value=AIMessage(content="SUMMARY"))
+ memory = MemoryMiddleware(
+ context_limit=40,
+ compaction_config=SimpleNamespace(reserve_tokens=0, keep_recent_tokens=10),
+ compaction_threshold=0.1,
+ )
+ memory.set_model(summary_model)
+ loop = _make_loop(
+ text="after compact",
+ checkpointer=checkpointer,
+ middleware=[memory],
+ )
+ config = {"configurable": {"thread_id": "compact-thread"}}
+
+ history = [
+ HumanMessage(content="A" * 80),
+ AIMessage(content="B" * 80),
+ HumanMessage(content="C" * 80),
+ HumanMessage(content="hello after compact"),
+ ]
+
+ async for _ in loop.query({"messages": history}, config=config):
+ pass
+
+ fake_agent = SimpleNamespace(
+ agent=loop,
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+ fake_app = SimpleNamespace(state=SimpleNamespace(display_builder=DisplayBuilder()))
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ detail = await get_thread_messages(
+ "compact-thread",
+ user_id="u",
+ app=fake_app,
+ )
+ rebuilt_history = await get_thread_history(
+ "compact-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert any(
+ any(segment.get("type") == "notice" and segment.get("notification_type") == "compact" for segment in entry.get("segments", []))
+ for entry in detail["entries"]
+ if entry.get("role") == "assistant"
+ )
+ assert any(
+ item.get("role") == "notification" and "Conversation compacted" in item.get("text", "") for item in rebuilt_history["messages"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_cold_rebuild_surfaces_persisted_prompt_too_long_notice_after_recovery_exhausts():
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(
+ model=_PromptTooLongTwiceModel(),
+ checkpointer=checkpointer,
+ middleware=[_BridgeReactiveCompactMiddleware()],
+ )
+ config = {"configurable": {"thread_id": "prompt-too-long-thread"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "start"}]},
+ config=config,
+ ):
+ pass
+
+ fake_agent = SimpleNamespace(
+ agent=loop,
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+ fake_app = SimpleNamespace(state=SimpleNamespace(display_builder=DisplayBuilder()))
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ detail = await get_thread_messages(
+ "prompt-too-long-thread",
+ user_id="u",
+ app=fake_app,
+ )
+ rebuilt_history = await get_thread_history(
+ "prompt-too-long-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert any(
+ any(
+ segment.get("type") == "notice" and "Prompt is too long. Automatic recovery exhausted." in segment.get("content", "")
+ for segment in entry.get("segments", [])
+ )
+ for entry in detail["entries"]
+ if entry.get("role") == "assistant"
+ )
+ assert any(
+ item.get("role") == "notification" and "Prompt is too long. Automatic recovery exhausted." in item.get("text", "")
+ for item in rebuilt_history["messages"]
+ )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("task_status", "result_text"),
+ [
+ ("completed", "CHILD_DONE"),
+ ("error", "Agent failed"),
+ ("cancelled", "Agent cancelled"),
+ ],
+)
+async def test_get_thread_messages_idle_rebuild_keeps_terminal_subagent_stream_status(
+ task_status: str,
+ result_text: str,
+):
+ ai = AIMessage(
+ content="",
+ tool_calls=[{"name": "Agent", "args": {"prompt": "do work", "run_in_background": True}, "id": "tc-agent-1"}],
+ )
+ tool = ToolMessage(
+ content=(
+ '{"task_id":"task-123","agent_name":"agent-task-123",'
+ '"thread_id":"subagent-task-123","status":"running",'
+ '"message":"Agent started in background. Use TaskOutput to get result."}'
+ ),
+ name="Agent",
+ tool_call_id="tc-agent-1",
+ )
+ notice = HumanMessage(
+ content=(
+ "\n"
+ "\n"
+ " task-123\n"
+ f" {task_status}\n"
+ " child task\n"
+ " child task\n"
+ f" {result_text}\n"
+ "\n"
+ ""
+ )
+ )
+ notice.metadata = {"source": "system", "notification_type": "agent"}
+
+ fake_agent = SimpleNamespace(
+ agent=SimpleNamespace(aget_state=AsyncMock(return_value=SimpleNamespace(values={"messages": [ai, tool, notice]}))),
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+ fake_app = SimpleNamespace(state=SimpleNamespace(display_builder=DisplayBuilder()))
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ detail = await get_thread_messages(
+ "parent-thread",
+ user_id="u",
+ app=fake_app,
+ )
+
+ seg = detail["entries"][0]["segments"][0]
+ assert seg["step"]["subagent_stream"]["task_id"] == "task-123"
+ assert seg["step"]["subagent_stream"]["thread_id"] == "subagent-task-123"
+ assert seg["step"]["subagent_stream"]["status"] == task_status
+
+
+@pytest.mark.asyncio
+async def test_compaction_clear_then_recovery_notice_rebuilds_honestly(tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ summary_model = MagicMock()
+ summary_model.bind.return_value = summary_model
+ summary_model.ainvoke = AsyncMock(return_value=AIMessage(content="SUMMARY"))
+
+ memory = MemoryMiddleware(
+ context_limit=40,
+ compaction_config=SimpleNamespace(reserve_tokens=0, keep_recent_tokens=10),
+ compaction_threshold=0.1,
+ db_path=tmp_path / "compaction-lifecycle.db",
+ )
+ memory.set_model(summary_model)
+ config = {"configurable": {"thread_id": "compaction-lifecycle-thread"}}
+ compact_loop = _make_loop(
+ text="after compact",
+ checkpointer=checkpointer,
+ middleware=[memory],
+ )
+
+ history = [
+ HumanMessage(content="A" * 80),
+ AIMessage(content="B" * 80),
+ HumanMessage(content="C" * 80),
+ HumanMessage(content="hello after compact"),
+ ]
+
+ async for _ in compact_loop.query({"messages": history}, config=config):
+ pass
+
+ assert memory.summary_store is not None
+ assert memory.summary_store.get_latest_summary("compaction-lifecycle-thread") is not None
+
+ fake_app = SimpleNamespace(state=SimpleNamespace(display_builder=DisplayBuilder()))
+ fake_agent = SimpleNamespace(
+ agent=compact_loop,
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ compact_detail = await get_thread_messages(
+ "compaction-lifecycle-thread",
+ user_id="u",
+ app=fake_app,
+ )
+ compact_history = await get_thread_history(
+ "compaction-lifecycle-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert any(
+ item.get("role") == "notification" and "Conversation compacted" in item.get("text", "") for item in compact_history["messages"]
+ )
+ assert any(
+ any(
+ segment.get("type") == "notice" and "Conversation compacted" in segment.get("content", "")
+ for segment in entry.get("segments", [])
+ )
+ for entry in compact_detail["entries"]
+ if entry.get("role") == "assistant"
+ )
+
+ await compact_loop.aclear("compaction-lifecycle-thread")
+
+ assert memory.summary_store.get_latest_summary("compaction-lifecycle-thread") is None
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ cleared_detail = await get_thread_messages(
+ "compaction-lifecycle-thread",
+ user_id="u",
+ app=fake_app,
+ )
+ cleared_history = await get_thread_history(
+ "compaction-lifecycle-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert cleared_detail["entries"] == []
+ assert cleared_history["messages"] == []
+
+ recovery_loop = _make_loop(
+ model=_PromptTooLongTwiceModel(),
+ checkpointer=checkpointer,
+ middleware=[_BridgeReactiveCompactMiddleware()],
+ )
+ recovery_agent = SimpleNamespace(
+ agent=recovery_loop,
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+
+ async for _ in recovery_loop.query(
+ {"messages": [{"role": "user", "content": "start"}]},
+ config=config,
+ ):
+ pass
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=recovery_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ recovery_detail = await get_thread_messages(
+ "compaction-lifecycle-thread",
+ user_id="u",
+ app=fake_app,
+ )
+ recovery_history = await get_thread_history(
+ "compaction-lifecycle-thread",
+ limit=20,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ notices = [item for item in recovery_history["messages"] if item.get("role") == "notification"]
+ assert notices == [
+ {
+ "role": "notification",
+ "text": "Prompt is too long. Automatic recovery exhausted. Clear the thread or start a new one.",
+ }
+ ]
+ assert not any("Conversation compacted" in item.get("text", "") for item in recovery_history["messages"])
+ assert any(
+ any(
+ segment.get("type") == "notice" and "Prompt is too long. Automatic recovery exhausted." in segment.get("content", "")
+ for segment in entry.get("segments", [])
+ )
+ for entry in recovery_detail["entries"]
+ if entry.get("role") == "assistant"
+ )
+
+
+@pytest.mark.asyncio
+async def test_cold_rebuild_surfaces_compaction_breaker_notice_after_repeated_failures(tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ model = _QueryOkWithFailingCompactorModel()
+ memory = MemoryMiddleware(
+ context_limit=10000,
+ compaction_threshold=0.5,
+ db_path=tmp_path / "compaction-breaker.db",
+ compaction_config=SimpleNamespace(reserve_tokens=0, keep_recent_tokens=10),
+ )
+ memory.set_model(model)
+ loop = _make_loop(
+ model=model,
+ checkpointer=checkpointer,
+ middleware=[memory],
+ )
+ config = {"configurable": {"thread_id": "compaction-breaker-thread"}}
+
+ for attempt in range(3):
+ async for _ in loop.query(
+ {
+ "messages": [
+ {"role": "user", "content": "A" * 8000},
+ {"role": "assistant", "content": "B" * 8000},
+ {"role": "user", "content": f"start {attempt} " + ("C" * 8000)},
+ ]
+ },
+ config=config,
+ ):
+ pass
+
+ fake_agent = SimpleNamespace(
+ agent=loop,
+ runtime=SimpleNamespace(current_state=AgentState.IDLE),
+ )
+ fake_app = SimpleNamespace(state=SimpleNamespace(display_builder=DisplayBuilder()))
+
+ with (
+ patch("backend.web.routers.threads.get_or_create_agent", return_value=fake_agent),
+ patch("backend.web.routers.threads.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.routers.threads.get_sandbox_info", return_value={"type": "local"}),
+ ):
+ detail = await get_thread_messages(
+ "compaction-breaker-thread",
+ user_id="u",
+ app=fake_app,
+ )
+ rebuilt_history = await get_thread_history(
+ "compaction-breaker-thread",
+ limit=50,
+ truncate=300,
+ user_id="u",
+ app=fake_app,
+ )
+
+ assert any(
+ entry.get("role") == "assistant"
+ and any(
+ seg.get("type") == "notice"
+ and "Automatic compaction disabled for this thread after repeated failures." in seg.get("content", "")
+ for seg in entry.get("segments", [])
+ )
+ for entry in detail["entries"]
+ )
+ assert any(
+ item.get("role") == "notification"
+ and "Automatic compaction disabled for this thread after repeated failures." in item.get("text", "")
+ for item in rebuilt_history["messages"]
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_emits_notice_for_system_agent_notifications(monkeypatch, tmp_path):
+ seq = 0
+
+ async def fake_append_event(thread_id, run_id, event, message_id=None, run_event_repo=None):
+ nonlocal seq
+ seq += 1
+ return seq
+
+ async def fake_cleanup_old_runs(thread_id, keep_latest=1, run_event_repo=None):
+ return 0
+
+ monkeypatch.setattr("backend.web.services.event_store.append_event", fake_append_event)
+ monkeypatch.setattr("backend.web.services.streaming_service.cleanup_old_runs", fake_cleanup_old_runs)
+ monkeypatch.setattr("backend.web.services.streaming_service._ensure_thread_handlers", lambda *args, **kwargs: None)
+
+ agent = SimpleNamespace(
+ agent=_StreamingGraphAgent(),
+ runtime=_StreamingRuntime(),
+ storage_container=None,
+ )
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ thread_tasks={},
+ thread_event_buffers={},
+ subagent_buffers={},
+ queue_manager=MessageQueueManager(db_path=str(tmp_path / "queue.db")),
+ thread_last_active={},
+ typing_tracker=None,
+ )
+ )
+ thread_buf = ThreadEventBuffer()
+
+ await _run_agent_to_buffer(
+ agent,
+ "thread-notice",
+ "completed",
+ app,
+ False,
+ thread_buf,
+ "run-notice",
+ message_metadata={"source": "system", "notification_type": "agent"},
+ )
+
+ entries = app.state.display_builder.get_entries("thread-notice")
+ assert entries is not None
+ assert entries[0]["segments"] == [
+ {
+ "type": "notice",
+ "content": "completed",
+ "notification_type": "agent",
+ }
+ ]
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_persists_terminal_notifications_before_assistant_followthrough(monkeypatch, tmp_path):
+ seq = 0
+
+ async def fake_append_event(thread_id, run_id, event, message_id=None, run_event_repo=None):
+ nonlocal seq
+ seq += 1
+ return seq
+
+ async def fake_cleanup_old_runs(thread_id, keep_latest=1, run_event_repo=None):
+ return 0
+
+ monkeypatch.setattr("backend.web.services.event_store.append_event", fake_append_event)
+ monkeypatch.setattr("backend.web.services.streaming_service.cleanup_old_runs", fake_cleanup_old_runs)
+ monkeypatch.setattr("backend.web.services.streaming_service._ensure_thread_handlers", lambda *args, **kwargs: None)
+
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(checkpointer=checkpointer)
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ queue_manager.enqueue(
+ "errorAgent failed",
+ "thread-terminal-history",
+ notification_type="agent",
+ source="system",
+ )
+
+ agent = SimpleNamespace(
+ agent=loop,
+ runtime=_StreamingRuntime(),
+ storage_container=None,
+ )
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ thread_tasks={},
+ thread_event_buffers={},
+ subagent_buffers={},
+ queue_manager=queue_manager,
+ thread_last_active={},
+ typing_tracker=None,
+ )
+ )
+ thread_buf = ThreadEventBuffer()
+
+ await _run_agent_to_buffer(
+ agent,
+ "thread-terminal-history",
+ "completedBG_OK",
+ app,
+ False,
+ thread_buf,
+ "run-terminal-history",
+ message_metadata={"source": "system", "notification_type": "agent"},
+ )
+
+ state = await loop.aget_state({"configurable": {"thread_id": "thread-terminal-history"}})
+
+ assert [msg.__class__.__name__ for msg in state.values["messages"]] == [
+ "HumanMessage",
+ "HumanMessage",
+ "AIMessage",
+ ]
+ assert "BG_OK" in state.values["messages"][0].content
+ assert "Agent failed" in state.values["messages"][1].content
+ assert state.values["messages"][2].content == "done"
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_resumes_graph_for_terminal_background_notifications(monkeypatch, tmp_path):
+ seq = 0
+
+ async def fake_append_event(thread_id, run_id, event, message_id=None, run_event_repo=None):
+ nonlocal seq
+ seq += 1
+ return seq
+
+ async def fake_cleanup_old_runs(thread_id, keep_latest=1, run_event_repo=None):
+ return 0
+
+ monkeypatch.setattr("backend.web.services.event_store.append_event", fake_append_event)
+ monkeypatch.setattr("backend.web.services.streaming_service.cleanup_old_runs", fake_cleanup_old_runs)
+ monkeypatch.setattr("backend.web.services.streaming_service._ensure_thread_handlers", lambda *args, **kwargs: None)
+
+ graph = _NoResumeGraphAgent()
+ agent = SimpleNamespace(
+ agent=graph,
+ runtime=_StreamingRuntime(),
+ storage_container=None,
+ )
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ thread_tasks={},
+ thread_event_buffers={},
+ subagent_buffers={},
+ queue_manager=MessageQueueManager(db_path=str(tmp_path / "queue.db")),
+ thread_last_active={},
+ typing_tracker=None,
+ )
+ )
+ thread_buf = ThreadEventBuffer()
+
+ await _run_agent_to_buffer(
+ agent,
+ "thread-terminal-notice",
+ "completedBG_SEEN:RESULT:3",
+ app,
+ False,
+ thread_buf,
+ "run-terminal-notice",
+ message_metadata={"source": "system", "notification_type": "agent"},
+ )
+
+ assert graph.astream_calls == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ (
+ "thread_id",
+ "run_id",
+ "message",
+ "message_metadata",
+ "notice_contains",
+ "expected_text",
+ ),
+ [
+ (
+ "thread-terminal-followthrough",
+ "run-terminal-followthrough",
+ "completedBG_OK",
+ {"source": "system", "notification_type": "agent"},
+ "BG_OK",
+ "AFTER_BG_DONE",
+ ),
+ (
+ "thread-command-followthrough",
+ "run-command-followthrough",
+ "completed",
+ {"source": "system", "notification_type": "command"},
+ "CommandNotification",
+ "AFTER_COMMAND_DONE",
+ ),
+ (
+ "thread-command-cancel-followthrough",
+ "run-command-cancel-followthrough",
+ 'cancelledcancelled task',
+ {"source": "system", "notification_type": "command"},
+ "cancelled",
+ "AFTER_COMMAND_CANCELLED",
+ ),
+ ],
+)
+async def test_run_agent_to_buffer_surfaces_notice_then_assistant_followthrough(
+ monkeypatch,
+ tmp_path,
+ thread_id: str,
+ run_id: str,
+ message: str,
+ message_metadata: dict[str, str],
+ notice_contains: str,
+ expected_text: str,
+):
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(text=expected_text, checkpointer=checkpointer)
+
+ entries = await _run_direct_notification_followthrough(
+ monkeypatch,
+ tmp_path,
+ loop=loop,
+ thread_id=thread_id,
+ message=message,
+ run_id=run_id,
+ message_metadata=message_metadata,
+ )
+
+ _assert_notice_then_text(entries, notice_contains, expected_text)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("thread_id", "message", "notification_type", "expected_notice", "expected_text"),
+ [
+ (
+ "thread-route-followthrough",
+ "completed",
+ "command",
+ "CommandNotification",
+ "AFTER_QUEUE_WAKE",
+ ),
+ (
+ "thread-route-agent-followthrough",
+ "completedSimple background tool testSimple Background Tool Test Done",
+ "agent",
+ "Simple Background Tool Test Done",
+ "AFTER_AGENT_WAKE",
+ ),
+ (
+ "thread-route-agent-error-followthrough",
+ "errorSimple background tool testAgent failed",
+ "agent",
+ "Agent failed",
+ "AFTER_AGENT_ERROR_WAKE",
+ ),
+ ],
+)
+async def test_queue_wake_handler_starts_terminal_followthrough_run(
+ monkeypatch,
+ tmp_path,
+ thread_id: str,
+ message: str,
+ notification_type: str,
+ expected_notice: str,
+ expected_text: str,
+):
+ _patch_streaming_event_store(monkeypatch)
+
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(text=expected_text, checkpointer=checkpointer)
+ queue_manager, agent, app = _make_route_followthrough_context(tmp_path, thread_id=thread_id, loop=loop)
+
+ queue_manager.enqueue(
+ message,
+ thread_id,
+ notification_type=notification_type,
+ source="system",
+ )
+
+ await _wait_for_followthrough_text(loop, thread_id, expected_text)
+ history = await _get_local_thread_history(thread_id, agent=agent, app=app)
+
+ assert [item["role"] for item in history["messages"]] == ["notification", "assistant"]
+ assert expected_notice in history["messages"][0]["text"]
+ assert history["messages"][1]["text"] == expected_text
+
+
+@pytest.mark.asyncio
+async def test_cancelled_task_notification_wakes_followthrough_run(monkeypatch, tmp_path):
+ _patch_streaming_event_store(monkeypatch)
+ _patch_fake_event_bus(monkeypatch)
+
+ thread_id = "thread-route-cancel-followthrough"
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(text="AFTER_CANCEL_WAKE", checkpointer=checkpointer)
+ queue_manager, agent, app = _make_route_followthrough_context(tmp_path, thread_id=thread_id, loop=loop)
+ run = SimpleNamespace(is_done=True, description="cancelled task", command="echo hi")
+ await threads_router._notify_task_cancelled(app, thread_id, "cmd-cancel", run)
+
+ await _wait_for_followthrough_text(loop, thread_id, "AFTER_CANCEL_WAKE")
+ history = await _get_local_thread_history(thread_id, agent=agent, app=app)
+ assert [item["role"] for item in history["messages"]] == ["notification", "assistant"]
+ assert "cancelled" in history["messages"][0]["text"]
+ assert history["messages"][1]["text"] == "AFTER_CANCEL_WAKE"
+
+
+@pytest.mark.asyncio
+async def test_send_message_route_then_agent_terminal_notification_reenters_followthrough(monkeypatch, tmp_path):
+ _patch_streaming_event_store(monkeypatch)
+
+ thread_id = "thread-route-send-message-followthrough"
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(model=_TurnTextModel("OWNER_OK", "AFTER_AGENT_ROUTE_WAKE"), checkpointer=checkpointer)
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ agent = _make_streaming_agent(loop, queue_manager=queue_manager)
+ app, _ = _make_streaming_app(
+ tmp_path,
+ thread_id=thread_id,
+ agent=agent,
+ queue_manager=queue_manager,
+ include_route_locks=True,
+ )
+
+ with (
+ patch("backend.web.services.agent_pool.get_or_create_agent", AsyncMock(return_value=agent)),
+ patch("backend.web.services.agent_pool.resolve_thread_sandbox", return_value="local"),
+ ):
+ result = await threads_router.send_message(
+ thread_id,
+ SendMessageRequest(message="start owner turn"),
+ user_id="u",
+ app=app,
+ )
+
+ assert result["status"] == "started"
+ await _wait_for_followthrough_text(loop, thread_id, "OWNER_OK")
+
+ queue_manager.enqueue(
+ "completedSimple background tool testSimple Background Tool Test Done",
+ thread_id,
+ notification_type="agent",
+ source="system",
+ )
+
+ await _wait_for_followthrough_text(loop, thread_id, "AFTER_AGENT_ROUTE_WAKE")
+
+ with (
+ patch.object(threads_router, "get_or_create_agent", return_value=agent),
+ patch.object(threads_router, "resolve_thread_sandbox", return_value="local"),
+ ):
+ history = await get_thread_history(thread_id, limit=20, truncate=400, user_id="u", app=app)
+
+ assert [item["role"] for item in history["messages"]] == ["human", "assistant", "notification", "assistant"]
+ assert history["messages"][0]["text"] == "start owner turn"
+ assert history["messages"][1]["text"] == "OWNER_OK"
+ assert "Simple Background Tool Test Done" in history["messages"][2]["text"]
+ assert history["messages"][3]["text"] == "AFTER_AGENT_ROUTE_WAKE"
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_adds_terminal_followthrough_system_note_to_prevent_silent_completion(monkeypatch, tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(model=_TerminalFollowthroughPromptAwareModel(), checkpointer=checkpointer)
+ entries = await _run_direct_notification_followthrough(
+ monkeypatch,
+ tmp_path,
+ loop=loop,
+ thread_id="thread-terminal-followthrough-note",
+ message="completed",
+ run_id="run-terminal-followthrough-note",
+ message_metadata={"source": "system", "notification_type": "command"},
+ )
+ _assert_notice_then_text(entries, "CommandNotification", "FOLLOWTHROUGH_ACK")
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_turns_silent_terminal_reentry_into_visible_followthrough(monkeypatch, tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(model=_TerminalFollowthroughSilentModel(), checkpointer=checkpointer)
+ entries = await _run_direct_notification_followthrough(
+ monkeypatch,
+ tmp_path,
+ loop=loop,
+ thread_id="thread-terminal-followthrough-silent",
+ message="completed",
+ run_id="run-terminal-followthrough-silent",
+ message_metadata={"source": "system", "notification_type": "command"},
+ )
+ _assert_notice_then_text(
+ entries,
+ "CommandNotification",
+ "Background command completed, but the followthrough assistant reply was empty.",
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_turns_silent_chat_notification_into_visible_followthrough(monkeypatch, tmp_path):
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(model=_ChatNotificationSilentModel(), checkpointer=checkpointer)
+ entries = await _run_direct_notification_followthrough(
+ monkeypatch,
+ tmp_path,
+ loop=loop,
+ thread_id="thread-chat-followthrough-silent",
+ message='\nNew message from alice in chat chat-123 (1 unread).\nRead it with read_messages(chat_id="chat-123").\nReply with send_message(chat_id="chat-123", content="...").\nDo not treat your normal assistant text as a chat reply.\n',
+ run_id="run-chat-followthrough-silent",
+ message_metadata={"source": "external", "notification_type": "chat"},
+ )
+ _assert_notice_then_text(
+ entries,
+ 'read_messages(chat_id="chat-123")',
+ 'I received a chat notification, but the followthrough assistant reply was empty. Read it with read_messages(chat_id="chat-123") before deciding whether to reply.',
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_tags_display_delta_with_source_seq(monkeypatch, tmp_path):
+ _patch_streaming_event_store(monkeypatch)
+ monkeypatch.setattr("backend.web.services.streaming_service._ensure_thread_handlers", lambda *args, **kwargs: None)
+
+ checkpointer = _MemoryCheckpointer()
+ loop = _make_loop(model=_NoToolModel("SEQ_OK"), checkpointer=checkpointer)
+ agent, app, thread_buf = _make_direct_streaming_context(tmp_path, loop)
+
+ await _run_agent_to_buffer(
+ agent,
+ "thread-display-delta-seq",
+ "hello",
+ app,
+ False,
+ thread_buf,
+ "run-display-delta-seq",
+ )
+
+ events, _ = await thread_buf.read_with_timeout(0, timeout=0.01)
+ assert events is not None
+ display_deltas = [json.loads(event["data"]) for event in events if event.get("event") == "display_delta"]
+ assert display_deltas
+ assert all(isinstance(delta.get("_seq"), int) for delta in display_deltas)
+
+
+@pytest.mark.asyncio
+async def test_run_agent_to_buffer_batches_additional_terminal_notifications(monkeypatch, tmp_path):
+ seq = 0
+
+ async def fake_append_event(thread_id, run_id, event, message_id=None, run_event_repo=None):
+ nonlocal seq
+ seq += 1
+ return seq
+
+ async def fake_cleanup_old_runs(thread_id, keep_latest=1, run_event_repo=None):
+ return 0
+
+ monkeypatch.setattr("backend.web.services.event_store.append_event", fake_append_event)
+ monkeypatch.setattr("backend.web.services.streaming_service.cleanup_old_runs", fake_cleanup_old_runs)
+ monkeypatch.setattr("backend.web.services.streaming_service._ensure_thread_handlers", lambda *args, **kwargs: None)
+
+ start_calls: list[tuple[str, str, dict | None]] = []
+
+ def fake_start_agent_run(agent, thread_id, message, app, enable_trajectory=False, message_metadata=None):
+ start_calls.append((thread_id, message, message_metadata))
+ return "run-next"
+
+ monkeypatch.setattr("backend.web.services.streaming_service.start_agent_run", fake_start_agent_run)
+
+ queue_manager = MessageQueueManager(db_path=str(tmp_path / "queue.db"))
+ queue_manager.enqueue(
+ "errorAgent failed",
+ "thread-batch-notice",
+ notification_type="agent",
+ )
+ queue_manager.enqueue(
+ "completed",
+ "thread-batch-notice",
+ notification_type="command",
+ )
+
+ agent = SimpleNamespace(
+ agent=_StreamingGraphAgent(),
+ runtime=_StreamingRuntime(),
+ storage_container=None,
+ )
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ display_builder=DisplayBuilder(),
+ thread_tasks={},
+ thread_event_buffers={},
+ subagent_buffers={},
+ queue_manager=queue_manager,
+ thread_last_active={},
+ typing_tracker=None,
+ )
+ )
+ thread_buf = ThreadEventBuffer()
+
+ await _run_agent_to_buffer(
+ agent,
+ "thread-batch-notice",
+ "completedBG_OK",
+ app,
+ False,
+ thread_buf,
+ "run-batch-notice",
+ message_metadata={"source": "system", "notification_type": "agent"},
+ )
+
+ entries = app.state.display_builder.get_entries("thread-batch-notice")
+ assert entries is not None
+ notice_segments = [segment for segment in entries[0]["segments"] if segment.get("type") == "notice"]
+ assert len(notice_segments) == 3
+ assert "BG_OK" in notice_segments[0]["content"]
+ assert "Agent failed" in notice_segments[1]["content"]
+ assert "CommandNotification" in notice_segments[2]["content"]
+ assert start_calls == []
+ assert queue_manager.list_queue("thread-batch-notice") == []
diff --git a/tests/test_queue_mode_integration.py b/tests/Integration/test_queue_mode_integration.py
similarity index 100%
rename from tests/test_queue_mode_integration.py
rename to tests/Integration/test_queue_mode_integration.py
diff --git a/tests/test_real_multiround.py b/tests/Integration/test_real_multiround.py
similarity index 100%
rename from tests/test_real_multiround.py
rename to tests/Integration/test_real_multiround.py
diff --git a/tests/test_sse_reconnect_integration.py b/tests/Integration/test_sse_reconnect_integration.py
similarity index 100%
rename from tests/test_sse_reconnect_integration.py
rename to tests/Integration/test_sse_reconnect_integration.py
diff --git a/tests/Integration/test_storage_runtime_wiring.py b/tests/Integration/test_storage_runtime_wiring.py
new file mode 100644
index 000000000..f4303b764
--- /dev/null
+++ b/tests/Integration/test_storage_runtime_wiring.py
@@ -0,0 +1,169 @@
+"""Runtime storage wiring tests for backend agent creation path."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from backend.web.services import agent_pool
+from storage.providers.sqlite.checkpoint_repo import SQLiteCheckpointRepo
+from storage.providers.sqlite.eval_repo import SQLiteEvalRepo
+from storage.providers.supabase.checkpoint_repo import SupabaseCheckpointRepo
+
+
+class _FakeSupabaseClient:
+ def table(self, table_name: str):
+ raise AssertionError(f"table() should not be called in this wiring test: {table_name}")
+
+
+def _build_fake_supabase_client() -> _FakeSupabaseClient:
+ return _FakeSupabaseClient()
+
+
+def _build_invalid_supabase_client() -> object:
+ return object()
+
+
+def _capture_create_leon_agent(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]:
+ captured: dict[str, Any] = {}
+
+ def _fake_create_leon_agent(**kwargs):
+ captured.update(kwargs)
+ return object()
+
+ monkeypatch.setattr(agent_pool, "create_leon_agent", _fake_create_leon_agent)
+ return captured
+
+
+def test_create_agent_sync_wires_supabase_storage_container(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
+ monkeypatch.setenv(
+ "LEON_SUPABASE_CLIENT_FACTORY",
+ "tests.Integration.test_storage_runtime_wiring:_build_fake_supabase_client",
+ )
+ monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
+ monkeypatch.setenv("LEON_EVAL_DB_PATH", str(tmp_path / "eval.db"))
+
+ captured = _capture_create_leon_agent(monkeypatch)
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+
+ container = captured["storage_container"]
+ assert isinstance(container.checkpoint_repo(), SupabaseCheckpointRepo)
+
+
+def test_create_agent_sync_supabase_missing_runtime_config_fails_loud(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
+ monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
+
+ with pytest.raises(
+ RuntimeError,
+ match="LEON_SUPABASE_CLIENT_FACTORY",
+ ):
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+
+
+def test_create_agent_sync_supabase_invalid_runtime_config_fails_loud(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
+ monkeypatch.setenv(
+ "LEON_SUPABASE_CLIENT_FACTORY",
+ "tests.Integration.test_storage_runtime_wiring:_build_invalid_supabase_client",
+ )
+
+ with pytest.raises(RuntimeError, match="callable table\\(name\\) API"):
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+
+
+def test_create_agent_sync_defaults_to_sqlite_storage_container(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.delenv("LEON_STORAGE_STRATEGY", raising=False)
+ monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
+ monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
+
+ captured = _capture_create_leon_agent(monkeypatch)
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+
+ container = captured["storage_container"]
+ assert isinstance(container.checkpoint_repo(), SQLiteCheckpointRepo)
+
+
+def test_create_agent_sync_enables_thread_permission_resolver_scope(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.delenv("LEON_STORAGE_STRATEGY", raising=False)
+ monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
+ monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
+
+ captured = _capture_create_leon_agent(monkeypatch)
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+
+ assert captured["permission_resolver_scope"] == "thread"
+
+
+def test_create_agent_sync_repo_override_supabase_with_sqlite_default(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", "sqlite")
+ monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", '{"checkpoint_repo":"supabase"}')
+ monkeypatch.setenv(
+ "LEON_SUPABASE_CLIENT_FACTORY",
+ "tests.Integration.test_storage_runtime_wiring:_build_fake_supabase_client",
+ )
+ monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
+
+ captured = _capture_create_leon_agent(monkeypatch)
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+ container = captured["storage_container"]
+ assert isinstance(container.checkpoint_repo(), SupabaseCheckpointRepo)
+
+
+def test_create_agent_sync_repo_override_sqlite_with_supabase_default(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
+ monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", '{"eval_repo":"sqlite"}')
+ monkeypatch.setenv(
+ "LEON_SUPABASE_CLIENT_FACTORY",
+ "tests.Integration.test_storage_runtime_wiring:_build_fake_supabase_client",
+ )
+ monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
+ monkeypatch.setenv("LEON_EVAL_DB_PATH", str(tmp_path / "eval.db"))
+
+ captured = _capture_create_leon_agent(monkeypatch)
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+ container = captured["storage_container"]
+ assert isinstance(container.eval_repo(), SQLiteEvalRepo)
+
+
+def test_create_agent_sync_repo_override_supabase_without_runtime_config_fails_loud(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", "sqlite")
+ monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", '{"checkpoint_repo":"supabase"}')
+ monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
+
+ with pytest.raises(RuntimeError, match="LEON_SUPABASE_CLIENT_FACTORY"):
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
+
+
+def test_create_agent_sync_invalid_repo_override_json_fails_loud(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", "not-json")
+
+ with pytest.raises(RuntimeError, match="Invalid LEON_STORAGE_REPO_PROVIDERS"):
+ agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
diff --git a/tests/Integration/test_threads_router.py b/tests/Integration/test_threads_router.py
new file mode 100644
index 000000000..21daac42b
--- /dev/null
+++ b/tests/Integration/test_threads_router.py
@@ -0,0 +1,893 @@
+from __future__ import annotations
+
+import json
+from contextlib import contextmanager
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
+
+from backend.web.models.requests import CreateThreadRequest
+from backend.web.routers import threads as threads_router
+from core.runtime.loop import QueryLoop
+from core.runtime.middleware.monitor import AgentState
+from core.runtime.registry import ToolRegistry
+from core.runtime.state import AppState, BootstrapConfig, ToolPermissionState
+from storage.contracts import MemberRow, MemberType
+
+
+class _FakeMemberRepo:
+ def __init__(self) -> None:
+ self._members = {
+ "member-1": MemberRow(
+ id="member-1",
+ name="Toad",
+ type=MemberType.MYCEL_AGENT,
+ owner_user_id="owner-1",
+ created_at=1.0,
+ )
+ }
+ self._seq = {"member-1": 0}
+
+ def get_by_id(self, member_id: str):
+ return self._members.get(member_id)
+
+ def increment_entity_seq(self, member_id: str) -> int:
+ self._seq[member_id] += 1
+ return self._seq[member_id]
+
+
+class _FakeThreadRepo:
+ def __init__(self) -> None:
+ self.rows: dict[str, dict] = {}
+
+ def get_by_id(self, thread_id: str):
+ row = self.rows.get(thread_id)
+ if row is None:
+ return None
+ return {"id": thread_id, **row}
+
+ def get_main_thread(self, member_id: str):
+ for row in self.rows.values():
+ if row["member_id"] == member_id and row["is_main"]:
+ return {"id": row["thread_id"], **row}
+ return None
+
+ def get_next_branch_index(self, member_id: str) -> int:
+ indices = [row["branch_index"] for row in self.rows.values() if row["member_id"] == member_id]
+ return max(indices, default=0) + 1
+
+ def create(self, **kwargs):
+ self.rows[kwargs["thread_id"]] = dict(kwargs)
+
+
+class _FakeEntityRepo:
+ def __init__(self) -> None:
+ self.rows = []
+
+ def create(self, row):
+ self.rows.append(row)
+
+ def get_by_id(self, entity_id: str):
+ for row in self.rows:
+ if row.id == entity_id:
+ return row
+ return None
+
+ def update_thread_id(self, entity_id: str, thread_id: str):
+ row = self.get_by_id(entity_id)
+ if row is not None:
+ row.thread_id = thread_id
+
+
+class _FakeAuthService:
+ def __init__(self) -> None:
+ self.tokens: list[str] = []
+
+ def verify_token(self, token: str) -> dict:
+ self.tokens.append(token)
+ return {"user_id": "owner-1"}
+
+
+class _FakeRequest:
+ def __init__(self, headers: dict[str, str] | None = None) -> None:
+ self.headers = headers or {}
+
+
+class _FakePermissionAgent:
+ def __init__(self) -> None:
+ self.pending = [
+ {
+ "request_id": "perm-1",
+ "thread_id": "thread-1",
+ "tool_name": "Write",
+ "args": {"path": "/tmp/demo.txt"},
+ "message": "needs approval",
+ }
+ ]
+ self.session_rules = {
+ "allow": ["Read"],
+ "deny": ["Bash"],
+ "ask": ["Edit"],
+ }
+ self.managed_only = False
+ self.resolve_calls: list[tuple[str, str, str | None, list[dict] | None, dict | None]] = []
+ self.rule_add_calls: list[tuple[str, str]] = []
+ self.rule_remove_calls: list[tuple[str, str]] = []
+ self.agent = SimpleNamespace(
+ aget_state=AsyncMock(return_value=SimpleNamespace(values={})),
+ apersist_state=AsyncMock(),
+ )
+
+ def get_pending_permission_requests(self, thread_id: str | None = None):
+ if thread_id is None:
+ return list(self.pending)
+ return [item for item in self.pending if item["thread_id"] == thread_id]
+
+ def resolve_permission_request(
+ self,
+ request_id: str,
+ *,
+ decision: str,
+ message: str | None = None,
+ answers: list[dict] | None = None,
+ annotations: dict | None = None,
+ ) -> bool:
+ self.resolve_calls.append((request_id, decision, message, answers, annotations))
+ if request_id != "perm-1":
+ return False
+ self.pending = []
+ return True
+
+ def drop_permission_request(self, request_id: str) -> bool:
+ before = len(self.pending)
+ self.pending = [item for item in self.pending if item["request_id"] != request_id]
+ return len(self.pending) != before
+
+ def get_thread_permission_rules(self, thread_id: str) -> dict[str, object]:
+ return {
+ "thread_id": thread_id,
+ "scope": "session",
+ "managed_only": self.managed_only,
+ "rules": dict(self.session_rules),
+ }
+
+ def add_thread_permission_rule(self, thread_id: str, *, behavior: str, tool_name: str) -> bool:
+ self.rule_add_calls.append((behavior, tool_name))
+ if self.managed_only:
+ return False
+ for bucket in self.session_rules.values():
+ if tool_name in bucket:
+ bucket.remove(tool_name)
+ bucket = self.session_rules.setdefault(behavior, [])
+ if tool_name not in bucket:
+ bucket.append(tool_name)
+ return True
+
+ def remove_thread_permission_rule(self, thread_id: str, *, behavior: str, tool_name: str) -> bool:
+ self.rule_remove_calls.append((behavior, tool_name))
+ bucket = self.session_rules.get(behavior, [])
+ if tool_name not in bucket:
+ return False
+ bucket.remove(tool_name)
+ return True
+
+
+class _MemoryCheckpointer:
+ def __init__(self, channel_values: dict | None = None) -> None:
+ self._checkpoint = {"channel_values": dict(channel_values or {})}
+
+ async def aget(self, _cfg):
+ return self._checkpoint
+
+
+class _LivePendingPermissionAgent:
+ def __init__(self) -> None:
+ app_state = AppState(
+ tool_permission_context=ToolPermissionState(alwaysAskRules={"session": ["Bash"]}),
+ pending_permission_requests={
+ "perm-live": {
+ "request_id": "perm-live",
+ "thread_id": "thread-1",
+ "tool_name": "Bash",
+ "args": {"command": "echo hi"},
+ "message": "Permission required by rule: Bash",
+ }
+ },
+ )
+ self.agent = QueryLoop(
+ model=MagicMock(),
+ system_prompt=SystemMessage(content="sys"),
+ middleware=[],
+ checkpointer=_MemoryCheckpointer(channel_values={"messages": []}),
+ registry=ToolRegistry(),
+ app_state=app_state,
+ runtime=SimpleNamespace(current_state=AgentState.ACTIVE),
+ bootstrap=BootstrapConfig(
+ workspace_root=Path("/tmp"),
+ model_name="test-model",
+ permission_resolver_scope="thread",
+ ),
+ max_turns=1,
+ )
+
+ def get_pending_permission_requests(self, thread_id: str | None = None):
+ requests = list(self.agent._app_state.pending_permission_requests.values())
+ if thread_id is None:
+ return requests
+ return [item for item in requests if item["thread_id"] == thread_id]
+
+ def get_thread_permission_rules(self, thread_id: str) -> dict[str, object]:
+ state = self.agent._app_state.tool_permission_context
+ return {
+ "thread_id": thread_id,
+ "scope": "session",
+ "managed_only": state.allowManagedPermissionRulesOnly,
+ "rules": {
+ "allow": list(state.alwaysAllowRules.get("session", [])),
+ "deny": list(state.alwaysDenyRules.get("session", [])),
+ "ask": list(state.alwaysAskRules.get("session", [])),
+ },
+ }
+
+
+class _FakeAskUserQuestionAgent(_FakePermissionAgent):
+ def __init__(self) -> None:
+ super().__init__()
+ self.pending = [
+ {
+ "request_id": "perm-ask",
+ "thread_id": "thread-1",
+ "tool_name": "AskUserQuestion",
+ "args": {
+ "questions": [
+ {
+ "header": "Style",
+ "question": "Choose a style",
+ "options": [
+ {"label": "Minimal", "description": "Keep it simple"},
+ {"label": "Bold", "description": "Make it loud"},
+ ],
+ }
+ ]
+ },
+ "message": "Please answer the following questions so Leon can continue.",
+ }
+ ]
+
+ def resolve_permission_request(
+ self,
+ request_id: str,
+ *,
+ decision: str,
+ message: str | None = None,
+ answers: list[dict] | None = None,
+ annotations: dict | None = None,
+ ) -> bool:
+ self.resolve_calls.append((request_id, decision, message, answers, annotations))
+ if request_id != "perm-ask":
+ return False
+ self.pending = []
+ return True
+
+
+class _NullLock:
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+
+class _FakeClearAgent:
+ def __init__(self, state: AgentState = AgentState.IDLE) -> None:
+ self.runtime = SimpleNamespace(current_state=state)
+ self.aclear_thread = AsyncMock()
+
+
+def _make_threads_app(
+ *,
+ member_repo=None,
+ thread_repo=None,
+ entity_repo=None,
+ **state_overrides,
+):
+ return SimpleNamespace(
+ state=SimpleNamespace(
+ member_repo=member_repo or _FakeMemberRepo(),
+ thread_repo=thread_repo or _FakeThreadRepo(),
+ entity_repo=entity_repo or _FakeEntityRepo(),
+ **state_overrides,
+ )
+ )
+
+
+def _make_clear_thread_app():
+ display_builder = SimpleNamespace(clear=MagicMock())
+ queue_manager = SimpleNamespace(clear_all=MagicMock())
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ agent_pool={},
+ display_builder=display_builder,
+ queue_manager=queue_manager,
+ thread_event_buffers={"thread-1": object()},
+ )
+ )
+ return app, display_builder, queue_manager
+
+
+@contextmanager
+def _patch_create_thread_noop_guards():
+ with (
+ patch.object(threads_router, "_validate_sandbox_provider_gate", return_value=None),
+ patch.object(threads_router, "_validate_mount_capability_gate", return_value=None),
+ patch.object(threads_router, "_create_thread_sandbox_resources", return_value=None) as create_resources,
+ patch.object(threads_router, "_invalidate_resource_overview_cache", return_value=None),
+ patch.object(threads_router, "save_last_successful_config", return_value=None),
+ ):
+ yield create_resources
+
+
+@contextmanager
+def _patch_local_clear_thread_agent(agent):
+ with (
+ patch.object(threads_router, "resolve_thread_sandbox", return_value="local"),
+ patch.object(threads_router, "get_or_create_agent", AsyncMock(return_value=agent)),
+ patch.object(threads_router, "get_thread_lock", AsyncMock(return_value=_NullLock())),
+ ):
+ yield
+
+
+@pytest.mark.asyncio
+async def test_create_thread_route_preserves_legacy_sandbox_type_alias():
+ app = _make_threads_app(thread_sandbox={}, thread_cwd={})
+ payload = CreateThreadRequest.model_validate(
+ {
+ "member_id": "member-1",
+ "sandbox_type": "daytona_selfhost",
+ "model": "gpt-5.4-mini",
+ }
+ )
+
+ with _patch_create_thread_noop_guards():
+ result = await threads_router.create_thread(payload, "owner-1", app)
+
+ assert result["sandbox"] == "daytona_selfhost"
+ assert app.state.thread_sandbox[result["thread_id"]] == "daytona_selfhost"
+ assert app.state.thread_repo.rows[result["thread_id"]]["sandbox_type"] == "daytona_selfhost"
+
+
+@pytest.mark.asyncio
+async def test_resolve_main_thread_returns_null_for_orphaned_main_thread_metadata():
+ thread_repo = _FakeThreadRepo()
+ thread_repo.create(
+ thread_id="thread-1",
+ member_id="member-1",
+ owner_user_id="owner-1",
+ sandbox_type="local",
+ is_main=True,
+ branch_index=0,
+ )
+ app = _make_threads_app(thread_repo=thread_repo)
+
+ payload = threads_router.ResolveMainThreadRequest(member_id="member-1")
+
+ result = await threads_router.resolve_main_thread(payload, "owner-1", app)
+
+ assert result == {"thread": None}
+
+
+@pytest.mark.asyncio
+async def test_create_thread_route_uses_canonical_existing_lease_binding_helper():
+ app = _make_threads_app(thread_sandbox={}, thread_cwd={})
+ payload = CreateThreadRequest.model_validate(
+ {
+ "member_id": "member-1",
+ "lease_id": "lease-1",
+ "cwd": "/workspace/reused",
+ }
+ )
+
+ with (
+ patch.object(
+ threads_router.sandbox_service,
+ "list_user_leases",
+ return_value=[{"lease_id": "lease-1", "provider_name": "local", "recipe": None}],
+ ),
+ patch.object(threads_router, "bind_thread_to_existing_lease", return_value="/workspace/reused") as bind_helper,
+ patch.object(threads_router, "_invalidate_resource_overview_cache", return_value=None),
+ patch.object(threads_router, "save_last_successful_config", return_value=None),
+ ):
+ result = await threads_router.create_thread(payload, "owner-1", app)
+
+ bind_helper.assert_called_once_with(
+ result["thread_id"],
+ "lease-1",
+ cwd="/workspace/reused",
+ )
+ assert app.state.thread_cwd[result["thread_id"]] == "/workspace/reused"
+
+
+@pytest.mark.asyncio
+async def test_create_thread_route_passes_local_cwd_into_sandbox_bootstrap():
+ app = _make_threads_app(thread_sandbox={}, thread_cwd={})
+ payload = CreateThreadRequest.model_validate(
+ {
+ "member_id": "member-1",
+ "cwd": "/tmp/fresh-local-thread",
+ }
+ )
+
+ with _patch_create_thread_noop_guards() as create_resources:
+ result = await threads_router.create_thread(payload, "owner-1", app)
+
+ create_resources.assert_called_once_with(
+ result["thread_id"],
+ "local",
+ None,
+ "/tmp/fresh-local-thread",
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_threads_hides_internal_subagent_threads():
+ app = _make_threads_app(
+ thread_repo=SimpleNamespace(
+ list_by_owner_user_id=lambda user_id: [
+ {
+ "id": "main-thread",
+ "sandbox_type": "local",
+ "member_name": "Toad",
+ "member_id": "member-1",
+ "entity_name": "Toad",
+ "branch_index": 0,
+ "is_main": True,
+ "member_avatar": None,
+ },
+ {
+ "id": "subagent-deadbeef",
+ "sandbox_type": "local",
+ "member_name": "Toad",
+ "member_id": "member-1",
+ "entity_name": "worker-1",
+ "branch_index": 1,
+ "is_main": False,
+ "member_avatar": None,
+ },
+ ]
+ ),
+ agent_pool={},
+ thread_last_active={},
+ )
+
+ payload = await threads_router.list_threads("owner-1", app)
+
+ assert [item["thread_id"] for item in payload["threads"]] == ["main-thread"]
+
+
+@pytest.mark.asyncio
+async def test_create_thread_route_rejects_unavailable_provider():
+ app = _make_threads_app(thread_sandbox={}, thread_cwd={})
+ payload = CreateThreadRequest.model_validate(
+ {
+ "member_id": "member-1",
+ "sandbox": "daytona",
+ }
+ )
+
+ with patch.object(threads_router.sandbox_service, "build_provider_from_config_name", return_value=None):
+ result = await threads_router.create_thread(payload, "owner-1", app)
+
+ assert isinstance(result, threads_router.JSONResponse)
+ assert result.status_code == 400
+ assert json.loads(result.body.decode()) == {
+ "error": "sandbox_provider_unavailable",
+ "provider": "daytona",
+ }
+ assert app.state.thread_repo.rows == {}
+
+
+@pytest.mark.asyncio
+async def test_create_thread_route_rejects_unavailable_provider_for_existing_lease():
+ app = _make_threads_app(thread_sandbox={}, thread_cwd={})
+ payload = CreateThreadRequest.model_validate(
+ {
+ "member_id": "member-1",
+ "lease_id": "lease-1",
+ }
+ )
+
+ with (
+ patch.object(
+ threads_router.sandbox_service,
+ "list_user_leases",
+ return_value=[{"lease_id": "lease-1", "provider_name": "daytona", "recipe": None}],
+ ),
+ patch.object(threads_router.sandbox_service, "build_provider_from_config_name", return_value=None),
+ ):
+ result = await threads_router.create_thread(payload, "owner-1", app)
+
+ assert isinstance(result, threads_router.JSONResponse)
+ assert result.status_code == 400
+ assert json.loads(result.body.decode()) == {
+ "error": "sandbox_provider_unavailable",
+ "provider": "daytona",
+ }
+ assert app.state.thread_repo.rows == {}
+
+
+@pytest.mark.asyncio
+async def test_stream_thread_events_requires_token():
+ app = _make_threads_app(
+ auth_service=_FakeAuthService(),
+ thread_repo=SimpleNamespace(get_by_id=lambda _thread_id: None),
+ thread_event_buffers={},
+ )
+
+ with pytest.raises(threads_router.HTTPException) as exc_info:
+ await threads_router.stream_thread_events(
+ "thread-1",
+ request=_FakeRequest(),
+ token=None,
+ app=app,
+ )
+
+ assert exc_info.value.status_code == 401
+ assert exc_info.value.detail == "Missing token"
+
+
+@pytest.mark.asyncio
+async def test_stream_thread_events_verifies_token_before_owner_check():
+ auth_service = _FakeAuthService()
+ thread_repo = SimpleNamespace(get_by_id=lambda _thread_id: {"member_id": "member-1"})
+ app = _make_threads_app(
+ auth_service=auth_service,
+ thread_repo=thread_repo,
+ thread_event_buffers={},
+ )
+
+ response = await threads_router.stream_thread_events(
+ "thread-1",
+ request=_FakeRequest(),
+ token="tok-thread",
+ app=app,
+ )
+
+ assert auth_service.tokens == ["tok-thread"]
+ assert response is not None
+
+
+@pytest.mark.asyncio
+async def test_get_thread_permissions_returns_thread_scoped_pending_requests():
+ agent = _FakePermissionAgent()
+
+ result = await threads_router.get_thread_permissions(
+ "thread-1",
+ user_id="owner-1",
+ agent=agent,
+ )
+
+ assert result == {
+ "thread_id": "thread-1",
+ "requests": [
+ {
+ "request_id": "perm-1",
+ "thread_id": "thread-1",
+ "tool_name": "Write",
+ "args": {"path": "/tmp/demo.txt"},
+ "message": "needs approval",
+ }
+ ],
+ "session_rules": {
+ "allow": ["Read"],
+ "deny": ["Bash"],
+ "ask": ["Edit"],
+ },
+ "managed_only": False,
+ }
+
+
+@pytest.mark.asyncio
+async def test_get_thread_permissions_does_not_clear_live_pending_requests_during_active_run():
+ agent = _LivePendingPermissionAgent()
+
+ result = await threads_router.get_thread_permissions(
+ "thread-1",
+ user_id="owner-1",
+ agent=agent,
+ )
+
+ assert result == {
+ "thread_id": "thread-1",
+ "requests": [
+ {
+ "request_id": "perm-live",
+ "thread_id": "thread-1",
+ "tool_name": "Bash",
+ "args": {"command": "echo hi"},
+ "message": "Permission required by rule: Bash",
+ }
+ ],
+ "session_rules": {
+ "allow": [],
+ "deny": [],
+ "ask": ["Bash"],
+ },
+ "managed_only": False,
+ }
+ assert agent.agent._app_state.pending_permission_requests == {
+ "perm-live": {
+ "request_id": "perm-live",
+ "thread_id": "thread-1",
+ "tool_name": "Bash",
+ "args": {"command": "echo hi"},
+ "message": "Permission required by rule: Bash",
+ }
+ }
+
+
+@pytest.mark.asyncio
+async def test_get_thread_history_does_not_clear_live_pending_requests_during_active_run():
+ agent = _LivePendingPermissionAgent()
+ agent.agent._app_state.messages = [
+ HumanMessage(content="please run bash"),
+ ToolMessage(content="Permission required by rule: Bash", tool_call_id="call-1", name="Bash"),
+ ]
+
+ with (
+ patch.object(threads_router, "resolve_thread_sandbox", return_value="local"),
+ patch.object(
+ threads_router,
+ "get_or_create_agent",
+ AsyncMock(return_value=agent),
+ ),
+ ):
+ result = await threads_router.get_thread_history(
+ "thread-1",
+ limit=20,
+ truncate=0,
+ user_id="owner-1",
+ app=SimpleNamespace(state=SimpleNamespace()),
+ )
+
+ assert result["messages"] == [
+ {"role": "human", "text": "please run bash"},
+ {"role": "tool_result", "tool": "Bash", "text": "Permission required by rule: Bash"},
+ ]
+ assert agent.agent._app_state.pending_permission_requests == {
+ "perm-live": {
+ "request_id": "perm-live",
+ "thread_id": "thread-1",
+ "tool_name": "Bash",
+ "args": {"command": "echo hi"},
+ "message": "Permission required by rule: Bash",
+ }
+ }
+
+
+@pytest.mark.asyncio
+async def test_resolve_thread_permission_request_persists_resolution():
+ agent = _FakePermissionAgent()
+
+ result = await threads_router.resolve_thread_permission_request(
+ "thread-1",
+ "perm-1",
+ SimpleNamespace(decision="allow", message="go ahead"),
+ user_id="owner-1",
+ agent=agent,
+ )
+
+ assert result == {"ok": True, "thread_id": "thread-1", "request_id": "perm-1"}
+ assert agent.resolve_calls == [("perm-1", "allow", "go ahead", None, None)]
+ agent.agent.apersist_state.assert_awaited_once_with("thread-1")
+
+
+@pytest.mark.asyncio
+async def test_resolve_ask_user_question_request_starts_followup_run_with_answers():
+ agent = _FakeAskUserQuestionAgent()
+ app = SimpleNamespace()
+ payload = SimpleNamespace(
+ decision="allow",
+ message=None,
+ answers=[
+ {
+ "header": "Style",
+ "question": "Choose a style",
+ "selected_options": ["Minimal"],
+ }
+ ],
+ annotations={"source": "ask-user-ui"},
+ )
+
+ with patch(
+ "backend.web.services.message_routing.route_message_to_brain",
+ AsyncMock(return_value={"status": "started", "routing": "direct", "thread_id": "thread-1"}),
+ ) as route_message:
+ result = await threads_router.resolve_thread_permission_request(
+ "thread-1",
+ "perm-ask",
+ payload,
+ user_id="owner-1",
+ agent=agent,
+ app=app,
+ )
+
+ assert result == {
+ "ok": True,
+ "thread_id": "thread-1",
+ "request_id": "perm-ask",
+ "followup": {"status": "started", "routing": "direct", "thread_id": "thread-1"},
+ }
+ assert agent.resolve_calls == [
+ (
+ "perm-ask",
+ "allow",
+ None,
+ [
+ {
+ "header": "Style",
+ "question": "Choose a style",
+ "selected_options": ["Minimal"],
+ }
+ ],
+ {"source": "ask-user-ui"},
+ )
+ ]
+ route_message.assert_awaited_once()
+ followup_message = route_message.await_args.args[2]
+ assert "AskUserQuestion" in followup_message
+ assert "Minimal" in followup_message
+ assert "Choose a style" in followup_message
+ assert agent.pending == []
+ assert agent.agent.apersist_state.await_count == 2
+ assert [call.args for call in agent.agent.apersist_state.await_args_list] == [("thread-1",), ("thread-1",)]
+
+
+@pytest.mark.asyncio
+async def test_resolve_ask_user_question_request_requires_answers_for_allow():
+ agent = _FakeAskUserQuestionAgent()
+
+ with pytest.raises(threads_router.HTTPException) as exc_info:
+ await threads_router.resolve_thread_permission_request(
+ "thread-1",
+ "perm-ask",
+ SimpleNamespace(decision="allow", message=None, answers=None, annotations=None),
+ user_id="owner-1",
+ agent=agent,
+ app=SimpleNamespace(),
+ )
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "AskUserQuestion answers are required when approving the request"
+ agent.agent.apersist_state.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_resolve_thread_permission_request_404s_missing_request():
+ agent = _FakePermissionAgent()
+
+ with pytest.raises(threads_router.HTTPException) as exc_info:
+ await threads_router.resolve_thread_permission_request(
+ "thread-1",
+ "missing",
+ SimpleNamespace(decision="deny", message="no"),
+ user_id="owner-1",
+ agent=agent,
+ )
+
+ assert exc_info.value.status_code == 404
+ assert exc_info.value.detail == "Permission request not found"
+ agent.agent.apersist_state.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_add_thread_permission_rule_persists_session_rule():
+ agent = _FakePermissionAgent()
+
+ result = await threads_router.add_thread_permission_rule(
+ "thread-1",
+ SimpleNamespace(behavior="allow", tool_name="Write"),
+ user_id="owner-1",
+ agent=agent,
+ )
+
+ assert result == {
+ "ok": True,
+ "thread_id": "thread-1",
+ "scope": "session",
+ "rules": {
+ "allow": ["Read", "Write"],
+ "deny": ["Bash"],
+ "ask": ["Edit"],
+ },
+ "managed_only": False,
+ }
+ assert agent.rule_add_calls == [("allow", "Write")]
+ agent.agent.apersist_state.assert_awaited_once_with("thread-1")
+
+
+@pytest.mark.asyncio
+async def test_add_thread_permission_rule_fails_loud_when_managed_only():
+ agent = _FakePermissionAgent()
+ agent.managed_only = True
+
+ with pytest.raises(threads_router.HTTPException) as exc_info:
+ await threads_router.add_thread_permission_rule(
+ "thread-1",
+ SimpleNamespace(behavior="allow", tool_name="Write"),
+ user_id="owner-1",
+ agent=agent,
+ )
+
+ assert exc_info.value.status_code == 409
+ assert exc_info.value.detail == "Managed permission rules only; session overrides are disabled"
+ agent.agent.apersist_state.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_remove_thread_permission_rule_persists_session_rule_change():
+ agent = _FakePermissionAgent()
+
+ result = await threads_router.delete_thread_permission_rule(
+ "thread-1",
+ "deny",
+ "Bash",
+ user_id="owner-1",
+ agent=agent,
+ )
+
+ assert result == {
+ "ok": True,
+ "thread_id": "thread-1",
+ "scope": "session",
+ "rules": {
+ "allow": ["Read"],
+ "deny": [],
+ "ask": ["Edit"],
+ },
+ "managed_only": False,
+ }
+ assert agent.rule_remove_calls == [("deny", "Bash")]
+ agent.agent.apersist_state.assert_awaited_once_with("thread-1")
+
+
+@pytest.mark.asyncio
+async def test_clear_thread_route_clears_agent_state_and_thread_buffers():
+ agent = _FakeClearAgent()
+ app, display_builder, queue_manager = _make_clear_thread_app()
+
+ with _patch_local_clear_thread_agent(agent):
+ result = await threads_router.clear_thread_history(
+ "thread-1",
+ user_id="owner-1",
+ app=app,
+ )
+
+ assert result == {"ok": True, "thread_id": "thread-1"}
+ agent.aclear_thread.assert_awaited_once_with("thread-1")
+ display_builder.clear.assert_called_once_with("thread-1")
+ queue_manager.clear_all.assert_called_once_with("thread-1")
+ assert app.state.thread_event_buffers == {}
+
+
+@pytest.mark.asyncio
+async def test_clear_thread_route_rejects_active_run():
+ agent = _FakeClearAgent(state=AgentState.ACTIVE)
+ app, display_builder, queue_manager = _make_clear_thread_app()
+
+ with _patch_local_clear_thread_agent(agent):
+ with pytest.raises(threads_router.HTTPException) as exc_info:
+ await threads_router.clear_thread_history(
+ "thread-1",
+ user_id="owner-1",
+ app=app,
+ )
+
+ assert exc_info.value.status_code == 409
+ assert exc_info.value.detail == "Cannot clear thread while run is in progress"
+ agent.aclear_thread.assert_not_awaited()
+ display_builder.clear.assert_not_called()
+ queue_manager.clear_all.assert_not_called()
+ assert "thread-1" in app.state.thread_event_buffers
diff --git a/tests/Unit/backend/test_message_routing.py b/tests/Unit/backend/test_message_routing.py
new file mode 100644
index 000000000..9c5cf47d4
--- /dev/null
+++ b/tests/Unit/backend/test_message_routing.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import asyncio
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from backend.web.services.message_routing import route_message_to_brain
+from core.runtime.middleware.monitor import AgentState
+
+
+class _FakeQueueManager:
+ def enqueue(self, *args, **kwargs) -> None:
+ raise AssertionError("enqueue should not be used for idle -> active routing")
+
+
+class _FakeRuntime:
+ def __init__(self) -> None:
+ self.current_state = AgentState.IDLE
+
+ def transition(self, next_state: AgentState) -> bool:
+ self.current_state = next_state
+ return True
+
+
+class _FakeAgent:
+ def __init__(self) -> None:
+ self.runtime = _FakeRuntime()
+
+
+@pytest.mark.asyncio
+async def test_route_message_to_brain_clears_resource_overview_cache_when_starting_run() -> None:
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ queue_manager=_FakeQueueManager(),
+ thread_locks={},
+ thread_locks_guard=asyncio.Lock(),
+ )
+ )
+ agent = _FakeAgent()
+
+ with (
+ patch("backend.web.services.agent_pool.resolve_thread_sandbox", return_value="local"),
+ patch("backend.web.services.agent_pool.get_or_create_agent", AsyncMock(return_value=agent)),
+ patch("backend.web.services.streaming_service.start_agent_run", return_value="run-123"),
+ patch("backend.web.services.resource_cache.clear_resource_overview_cache") as clear_cache,
+ ):
+ result = await route_message_to_brain(app, "thread-1", "hello")
+
+ assert result == {"status": "started", "routing": "direct", "run_id": "run-123", "thread_id": "thread-1"}
+ clear_cache.assert_called_once_with()
diff --git a/tests/Unit/core/test_agent_pool.py b/tests/Unit/core/test_agent_pool.py
new file mode 100644
index 000000000..1f537dfc2
--- /dev/null
+++ b/tests/Unit/core/test_agent_pool.py
@@ -0,0 +1,206 @@
+import asyncio
+import time
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, cast
+
+import pytest
+
+from backend.web.services import agent_pool
+
+
+class _FakeThreadRepo:
+ def get_by_id(self, thread_id: str):
+ return {"id": thread_id, "cwd": "/tmp", "model": "leon:large"}
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_agent_creates_once_per_thread(monkeypatch: pytest.MonkeyPatch):
+ created: list[object] = []
+
+ def _fake_create_agent_sync(
+ sandbox_name: str,
+ workspace_root=None,
+ model_name: str | None = None,
+ agent: str | None = None,
+ bundle_dir=None,
+ thread_repo=None,
+ entity_repo=None,
+ member_repo=None,
+ queue_manager=None,
+ chat_repos=None,
+ extra_allowed_paths=None,
+ web_app=None,
+ ) -> object:
+ time.sleep(0.05)
+ obj = SimpleNamespace()
+ created.append(obj)
+ return obj
+
+ monkeypatch.setattr(agent_pool, "create_agent_sync", _fake_create_agent_sync)
+ monkeypatch.setattr(agent_pool, "get_or_create_agent_id", lambda **_: "agent-1")
+
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ agent_pool={},
+ thread_repo=_FakeThreadRepo(),
+ thread_cwd={},
+ thread_sandbox={},
+ )
+ )
+
+ first, second = await asyncio.gather(
+ agent_pool.get_or_create_agent(cast(Any, app), "local", thread_id="thread-1"),
+ agent_pool.get_or_create_agent(cast(Any, app), "local", thread_id="thread-1"),
+ )
+
+ assert len(created) == 1
+ assert first is second
+ assert app.state.agent_pool["thread-1:local"] is first
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_agent_ignores_unavailable_local_cwd(monkeypatch: pytest.MonkeyPatch):
+ captured: dict[str, object] = {}
+
+ def _fake_create_agent_sync(
+ sandbox_name: str,
+ workspace_root=None,
+ model_name: str | None = None,
+ agent: str | None = None,
+ bundle_dir=None,
+ thread_repo=None,
+ entity_repo=None,
+ member_repo=None,
+ queue_manager=None,
+ chat_repos=None,
+ extra_allowed_paths=None,
+ web_app=None,
+ ) -> object:
+ captured["workspace_root"] = workspace_root
+ return SimpleNamespace()
+
+ class _ThreadRepo:
+ def get_by_id(self, thread_id: str):
+ return {
+ "id": thread_id,
+ "cwd": "/Users/lexicalmathical/Codebase/homeworks/aiagent",
+ "model": "leon:large",
+ }
+
+ monkeypatch.setattr(agent_pool, "create_agent_sync", _fake_create_agent_sync)
+ monkeypatch.setattr(agent_pool, "get_or_create_agent_id", lambda **_: "agent-2")
+ monkeypatch.setattr(Path, "exists", lambda self: False)
+
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ agent_pool={},
+ thread_repo=_ThreadRepo(),
+ thread_cwd={},
+ thread_sandbox={},
+ )
+ )
+
+ await agent_pool.get_or_create_agent(cast(Any, app), "local", thread_id="thread-2")
+
+ assert captured["workspace_root"] is None
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_agent_honors_fresh_local_thread_cwd_even_when_missing(monkeypatch: pytest.MonkeyPatch, tmp_path):
+ captured: dict[str, object] = {}
+ requested = tmp_path / "fresh-workspace"
+
+ def _fake_create_agent_sync(
+ sandbox_name: str,
+ workspace_root=None,
+ model_name: str | None = None,
+ agent: str | None = None,
+ bundle_dir=None,
+ thread_repo=None,
+ entity_repo=None,
+ member_repo=None,
+ queue_manager=None,
+ chat_repos=None,
+ extra_allowed_paths=None,
+ web_app=None,
+ ) -> object:
+ captured["workspace_root"] = workspace_root
+ return SimpleNamespace()
+
+ class _ThreadRepo:
+ def get_by_id(self, thread_id: str):
+ return {
+ "id": thread_id,
+ "cwd": None,
+ "model": "leon:large",
+ }
+
+ monkeypatch.setattr(agent_pool, "create_agent_sync", _fake_create_agent_sync)
+ monkeypatch.setattr(agent_pool, "get_or_create_agent_id", lambda **_: "agent-3")
+
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ agent_pool={},
+ thread_repo=_ThreadRepo(),
+ thread_cwd={"thread-3": str(requested)},
+ thread_sandbox={},
+ )
+ )
+
+ await agent_pool.get_or_create_agent(cast(Any, app), "local", thread_id="thread-3")
+
+ assert captured["workspace_root"] == requested.resolve()
+ assert requested.is_dir()
+ assert app.state.thread_cwd["thread-3"] == str(requested.resolve())
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_agent_passes_member_bundle_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
+ captured: dict[str, object] = {}
+ member_dir = tmp_path / "members" / "member-1"
+ member_dir.mkdir(parents=True)
+
+ def _fake_create_agent_sync(
+ sandbox_name: str,
+ workspace_root=None,
+ model_name: str | None = None,
+ agent: str | None = None,
+ bundle_dir=None,
+ thread_repo=None,
+ entity_repo=None,
+ member_repo=None,
+ queue_manager=None,
+ chat_repos=None,
+ extra_allowed_paths=None,
+ web_app=None,
+ ) -> object:
+ captured["bundle_dir"] = bundle_dir
+ return SimpleNamespace()
+
+ class _ThreadRepo:
+ def get_by_id(self, thread_id: str):
+ return {
+ "id": thread_id,
+ "cwd": None,
+ "model": "leon:large",
+ "member_id": "member-1",
+ "member_name": "Toad",
+ }
+
+ monkeypatch.setattr(agent_pool, "create_agent_sync", _fake_create_agent_sync)
+ monkeypatch.setattr(agent_pool, "get_or_create_agent_id", lambda **_: "agent-4")
+ monkeypatch.setattr(agent_pool, "preferred_existing_user_home_path", lambda *parts: member_dir)
+
+ app = SimpleNamespace(
+ state=SimpleNamespace(
+ agent_pool={},
+ thread_repo=_ThreadRepo(),
+ thread_cwd={},
+ thread_sandbox={},
+ )
+ )
+
+ await agent_pool.get_or_create_agent(cast(Any, app), "local", thread_id="thread-4")
+
+ assert captured["bundle_dir"] == member_dir.resolve()
diff --git a/tests/Unit/core/test_agent_service.py b/tests/Unit/core/test_agent_service.py
new file mode 100644
index 000000000..a5a8e530c
--- /dev/null
+++ b/tests/Unit/core/test_agent_service.py
@@ -0,0 +1,1525 @@
+"""Unit tests for AgentService sub-agent boundaries and policy."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import AsyncMock
+
+import pytest
+
+from core.agents.service import (
+ AGENT_DISALLOWED,
+ AGENT_SCHEMA,
+ ASK_USER_QUESTION_SCHEMA,
+ EXPLORE_ALLOWED,
+ TASK_OUTPUT_SCHEMA,
+ AgentService,
+ _BashBackgroundRun,
+ _RunningTask,
+)
+from core.runtime.registry import ToolRegistry
+from core.runtime.runner import ToolRunner
+from core.runtime.state import AppState, BootstrapConfig, ToolUseContext
+from sandbox.manager import SandboxManager
+from sandbox.providers.local import LocalSessionProvider
+from sandbox.thread_context import get_current_thread_id, set_current_messages, set_current_thread_id
+from storage.contracts import EntityRow
+
+
+class _FakeRegistry:
+ def register(self, entry):
+ self.last_entry = entry
+
+
+class _FakeAgentRegistry:
+ def __init__(self) -> None:
+ self._latest_by_name_parent: dict[tuple[str, str | None], object] = {}
+
+ async def register(self, entry):
+ self.entry = entry
+
+ async def update_status(self, agent_id: str, status: str):
+ self.last_status = (agent_id, status)
+
+ async def get_latest_by_name_and_parent(self, name: str, parent_agent_id: str | None):
+ return self._latest_by_name_parent.get((name, parent_agent_id))
+
+
+class _FakeThreadRepo:
+ def __init__(self, rows: dict[str, dict] | None = None):
+ self.rows = rows or {}
+ self.created: list[dict] = []
+
+ def get_by_id(self, thread_id: str):
+ return self.rows.get(thread_id)
+
+ def get_next_branch_index(self, member_id: str) -> int:
+ branch_indexes = [int(row["branch_index"]) for row in self.rows.values() if row["member_id"] == member_id]
+ return (max(branch_indexes) if branch_indexes else 0) + 1
+
+ def create(self, thread_id: str, member_id: str, sandbox_type: str, cwd: str | None, created_at: float, **extra):
+ row = {
+ "id": thread_id,
+ "member_id": member_id,
+ "sandbox_type": sandbox_type,
+ "cwd": cwd,
+ "model": extra.get("model"),
+ "is_main": bool(extra.get("is_main", False)),
+ "branch_index": int(extra["branch_index"]),
+ "created_at": created_at,
+ }
+ self.rows[thread_id] = row
+ self.created.append(row)
+
+
+class _FakeEntityRepo:
+ def __init__(self):
+ self.rows_by_thread: dict[str, EntityRow] = {}
+
+ def create(self, row: EntityRow):
+ self.rows_by_thread[row.thread_id] = row
+
+ def get_by_thread_id(self, thread_id: str):
+ return self.rows_by_thread.get(thread_id)
+
+
+class _FakeMemberRepo:
+ def __init__(self, names: dict[str, str]):
+ self._names = names
+
+ def get_by_id(self, member_id: str):
+ name = self._names.get(member_id)
+ if name is None:
+ return None
+ return SimpleNamespace(id=member_id, name=name, avatar=None)
+
+
+class _FakeChildAgent:
+ def __init__(self, workspace_root: Path, model_name: str):
+ self.workspace_root = workspace_root
+ self.model_name = model_name
+ self._bootstrap = BootstrapConfig(workspace_root=workspace_root, model_name=model_name)
+ self.apply_fork_calls: list[tuple[BootstrapConfig, ToolUseContext | None]] = []
+ self.cleanup_calls = 0
+ self.closed = False
+ self.close_kwargs: dict[str, object] = {}
+ self._agent_service = SimpleNamespace(
+ _parent_bootstrap=None,
+ _parent_tool_context=None,
+ cleanup_background_runs=self._cleanup_background_runs,
+ )
+ self.agent = SimpleNamespace(astream=self._astream)
+
+ async def ainit(self):
+ return None
+
+ async def _astream(self, *args, **kwargs):
+ if False:
+ yield None
+ return
+
+ async def _cleanup_background_runs(self):
+ self.cleanup_calls += 1
+
+ def close(self, **kwargs):
+ self.closed = True
+ self.close_kwargs = kwargs
+ return None
+
+ def apply_forked_child_context(
+ self,
+ bootstrap: BootstrapConfig,
+ *,
+ tool_context: ToolUseContext | None = None,
+ ) -> None:
+ self.apply_fork_calls.append((bootstrap, tool_context))
+ self._bootstrap = bootstrap
+ self.agent._bootstrap = bootstrap
+ self._agent_service._parent_bootstrap = bootstrap
+ if tool_context is not None:
+ self._agent_service._parent_tool_context = tool_context
+ self.agent._tool_abort_controller = tool_context.abort_controller
+
+
+class _FakeAsyncCommand:
+ def __init__(self):
+ self.done = False
+ self.stdout_buffer = []
+ self.stderr_buffer = []
+ self.exit_code = None
+ self.process = SimpleNamespace(terminate=self._terminate, kill=self._kill, wait=self._wait)
+ self.terminated = False
+ self.killed = False
+ self.wait_calls = 0
+
+ def _terminate(self):
+ self.terminated = True
+ self.done = True
+
+ def _kill(self):
+ self.killed = True
+ self.done = True
+
+ async def _wait(self):
+ self.wait_calls += 1
+ return 0
+
+
+def _make_parent_context(tmp_path: Path, model_name: str = "gpt-parent") -> ToolUseContext:
+ parent_state = AppState(turn_count=1)
+ return ToolUseContext(
+ bootstrap=BootstrapConfig(workspace_root=tmp_path, model_name=model_name),
+ get_app_state=parent_state.get_state,
+ set_app_state=parent_state.set_state,
+ set_app_state_for_tasks=parent_state.set_state,
+ read_file_state={"/tmp/readme.md": {"partial": False}},
+ loaded_nested_memory_paths={"/tmp/memory.md"},
+ discovered_skill_names={"skill-a"},
+ nested_memory_attachment_triggers={"turn-a"},
+ messages=["hello"],
+ )
+
+
+def _make_service(tmp_path: Path, **kwargs) -> AgentService:
+ tool_registry = kwargs.pop("tool_registry", None) or _FakeRegistry()
+ agent_registry = kwargs.pop("agent_registry", None) or _FakeAgentRegistry()
+ model_name = kwargs.pop("model_name", "gpt-test")
+ return AgentService(
+ tool_registry=tool_registry,
+ agent_registry=agent_registry,
+ workspace_root=tmp_path,
+ model_name=model_name,
+ **kwargs,
+ )
+
+
+def _agent_tool_json(result) -> dict:
+ content = getattr(result, "content", result)
+ return json.loads(content)
+
+
+async def _sleep_forever():
+ while True:
+ await asyncio.sleep(3600)
+
+
+@pytest.mark.asyncio
+async def test_task_output_reports_running_command_honestly(tmp_path):
+ service = _make_service(tmp_path)
+ async_cmd = _FakeAsyncCommand()
+ service._tasks["cmd_test123"] = _BashBackgroundRun(async_cmd, "echo hello")
+
+ payload = json.loads(await service._handle_task_output("cmd_test123", block=False))
+
+ assert payload == {
+ "task_id": "cmd_test123",
+ "status": "running",
+ "message": "Command is still running.",
+ }
+
+
+@pytest.mark.asyncio
+async def test_task_output_keeps_agent_running_message_for_agent_tasks(tmp_path):
+ service = _make_service(tmp_path)
+ task = asyncio.create_task(_sleep_forever())
+ service._tasks["task_agent123"] = _RunningTask(
+ task=task,
+ agent_id="agent-1",
+ thread_id="thread-1",
+ )
+
+ try:
+ payload = json.loads(await service._handle_task_output("task_agent123", block=False))
+ finally:
+ task.cancel()
+ with pytest.raises(asyncio.CancelledError):
+ await task
+
+ assert payload == {
+ "task_id": "task_agent123",
+ "status": "running",
+ "message": "Agent is still running.",
+ }
+
+
+@pytest.mark.asyncio
+async def test_task_output_times_out_when_blocking_wait_expires(tmp_path):
+ service = _make_service(tmp_path)
+ task = asyncio.create_task(_sleep_forever())
+ service._tasks["task_agent123"] = _RunningTask(
+ task=task,
+ agent_id="agent-1",
+ thread_id="thread-1",
+ )
+
+ try:
+ payload = json.loads(await service._handle_task_output("task_agent123", timeout=1))
+ finally:
+ task.cancel()
+ with pytest.raises(asyncio.CancelledError):
+ await task
+
+ assert payload == {
+ "task_id": "task_agent123",
+ "status": "timeout",
+ "message": "Agent is still running.",
+ }
+
+
+@pytest.mark.asyncio
+async def test_run_agent_applies_forked_bootstrap_to_child_agent(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ service._parent_bootstrap = BootstrapConfig(
+ workspace_root=Path("/workspace"),
+ original_cwd=Path("/launcher"),
+ project_root=Path("/workspace/project"),
+ cwd=Path("/workspace/project/src"),
+ model_name="gpt-parent",
+ api_key="sk-parent",
+ extra_allowed_paths=["/shared"],
+ total_cost_usd=1.5,
+ total_tool_duration_ms=77,
+ model_provider="openai",
+ base_url="https://api.example.com/v1",
+ context_limit=12345,
+ )
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ child = created[0]
+ assert child._bootstrap.original_cwd == Path("/launcher")
+ assert child._bootstrap.project_root == Path("/workspace/project")
+ assert child._bootstrap.cwd == Path("/workspace/project/src")
+ assert child._bootstrap.extra_allowed_paths == ["/shared"]
+ assert child._bootstrap.parent_session_id == service._parent_bootstrap.session_id
+ assert child._bootstrap.session_id != service._parent_bootstrap.session_id
+ assert child._bootstrap.total_cost_usd == 1.5
+ assert child._bootstrap.total_tool_duration_ms == 77
+ assert child._bootstrap.model_provider == "openai"
+ assert child._bootstrap.base_url == "https://api.example.com/v1"
+ assert child._bootstrap.context_limit == 12345
+
+
+@pytest.mark.asyncio
+async def test_run_agent_applies_isolated_tool_context_to_child_agent_service(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ parent_context = _make_parent_context(tmp_path)
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ parent_tool_context=parent_context,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ child_context = created[0]._agent_service._parent_tool_context
+ assert child_context is not None
+ assert child_context is not parent_context
+ assert child_context.bootstrap.parent_session_id == parent_context.bootstrap.session_id
+ child_context.set_app_state(lambda prev: prev.model_copy(update={"turn_count": 9}))
+ assert parent_context.get_app_state().turn_count == 1
+ child_context.set_app_state_for_tasks(lambda prev: prev.model_copy(update={"turn_count": 9}))
+ assert parent_context.get_app_state().turn_count == 9
+
+
+@pytest.mark.asyncio
+async def test_run_agent_uses_explicit_child_fork_wiring_api(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ parent_context = _make_parent_context(tmp_path)
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ parent_tool_context=parent_context,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert len(created[0].apply_fork_calls) == 1
+ applied_bootstrap, applied_context = created[0].apply_fork_calls[0]
+ assert applied_bootstrap is created[0]._bootstrap
+ assert applied_context is created[0]._agent_service._parent_tool_context
+
+
+@pytest.mark.asyncio
+async def test_run_agent_uses_injected_child_agent_factory(tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_child_agent_factory(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ service = _make_service(tmp_path, child_agent_factory=fake_child_agent_factory)
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert len(created) == 1
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_fork_context_uses_parent_tool_context_messages(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ class _CapturingChild(_FakeChildAgent):
+ async def _astream(self, payload, *args, **kwargs):
+ captured["messages"] = payload["messages"]
+ if False:
+ yield None
+ return
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ return _CapturingChild(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry)
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {"prompt": "inspect", "description": "inspect workspace", "fork_context": True},
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path),
+ )
+
+ result = await runner.awrap_tool_call(request, AsyncMock())
+
+ assert result.content == "(Agent completed with no text output)"
+ assert captured["messages"] == [
+ "hello",
+ {
+ "role": "user",
+ "content": (
+ "\n\n### ENTERING SUB-AGENT ROUTINE ###\n"
+ "Messages above are from the parent thread (read-only context).\n"
+ "Only complete the specific task assigned below.\n\n"
+ "inspect"
+ ),
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_fork_context_treats_empty_parent_messages_as_authoritative(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ class _CapturingChild(_FakeChildAgent):
+ async def _astream(self, payload, *args, **kwargs):
+ captured["messages"] = payload["messages"]
+ if False:
+ yield None
+ return
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ return _CapturingChild(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+ set_current_messages([{"role": "user", "content": "AMBIENT_LEAK"}])
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry)
+ runner = ToolRunner(registry=registry)
+ parent_context = _make_parent_context(tmp_path)
+ parent_context.messages = []
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {"prompt": "inspect", "description": "inspect workspace", "fork_context": True},
+ "id": "tc-1",
+ },
+ state=parent_context,
+ )
+
+ result = await runner.awrap_tool_call(request, AsyncMock())
+
+ assert result.content == "(Agent completed with no text output)"
+ assert captured["messages"] == [
+ {
+ "role": "user",
+ "content": (
+ "\n\n### ENTERING SUB-AGENT ROUTINE ###\n"
+ "Messages above are from the parent thread (read-only context).\n"
+ "Only complete the specific task assigned below.\n\n"
+ "inspect"
+ ),
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_run_agent_rolls_child_bootstrap_costs_back_into_parent_bootstrap(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ class _CostReportingChild(_FakeChildAgent):
+ async def _astream(self, *args, **kwargs):
+ self._bootstrap.total_cost_usd = 9.75
+ self._bootstrap.total_tool_duration_ms = 222
+ if False:
+ yield None
+ return
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _CostReportingChild(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ service._parent_bootstrap = BootstrapConfig(
+ workspace_root=Path("/workspace"),
+ model_name="gpt-parent",
+ total_cost_usd=1.5,
+ total_tool_duration_ms=77,
+ )
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert created[0]._bootstrap.total_cost_usd == 9.75
+ assert created[0]._bootstrap.total_tool_duration_ms == 222
+ assert service._parent_bootstrap.total_cost_usd == 9.75
+ assert service._parent_bootstrap.total_tool_duration_ms == 222
+
+
+@pytest.mark.asyncio
+async def test_run_agent_preserves_concurrent_parent_and_child_bootstrap_growth(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ class _ConcurrentCostChild(_FakeChildAgent):
+ async def _astream(self, *args, **kwargs):
+ service._parent_bootstrap.total_cost_usd = 2.0
+ service._parent_bootstrap.total_tool_duration_ms = 20
+ self._bootstrap.total_cost_usd = 1.5
+ self._bootstrap.total_tool_duration_ms = 15
+ if False:
+ yield None
+ return
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _ConcurrentCostChild(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ service._parent_bootstrap = BootstrapConfig(
+ workspace_root=Path("/workspace"),
+ model_name="gpt-parent",
+ total_cost_usd=1.0,
+ total_tool_duration_ms=10,
+ )
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert created[0]._bootstrap.total_cost_usd == 1.5
+ assert created[0]._bootstrap.total_tool_duration_ms == 15
+ assert service._parent_bootstrap.total_cost_usd == 2.5
+ assert service._parent_bootstrap.total_tool_duration_ms == 25
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_live_runner_path_passes_isolated_tool_context_to_child(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry)
+ runner = ToolRunner(registry=registry)
+ parent_context = _make_parent_context(tmp_path)
+ request = SimpleNamespace(
+ tool_call={"name": "Agent", "args": {"prompt": "do work", "description": "do work"}, "id": "tc-1"},
+ state=parent_context,
+ )
+
+ result = await runner.awrap_tool_call(request, AsyncMock())
+
+ assert result.content == "(Agent completed with no text output)"
+ child_context = created[0]._agent_service._parent_tool_context
+ assert child_context is not None
+ assert child_context.bootstrap.parent_session_id == parent_context.bootstrap.session_id
+ child_context.set_app_state(lambda prev: prev.model_copy(update={"turn_count": 9}))
+ assert parent_context.get_app_state().turn_count == 1
+
+
+@pytest.mark.asyncio
+async def test_run_agent_without_fork_context_does_not_inject_parent_messages(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ class _CapturingChild(_FakeChildAgent):
+ async def _astream(self, payload, *args, **kwargs):
+ captured["messages"] = payload["messages"]
+ if False:
+ yield None
+ return
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ return _CapturingChild(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ parent_context = _make_parent_context(tmp_path)
+ parent_context.messages = [
+ {
+ "role": "user",
+ "content": "PARENT_CONTROL_PROMPT",
+ }
+ ]
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="child task only",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ parent_tool_context=parent_context,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert captured["messages"] == [{"role": "user", "content": "child task only"}]
+
+
+@pytest.mark.asyncio
+async def test_run_agent_child_tool_context_deep_clones_read_file_state(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ parent_context = _make_parent_context(tmp_path)
+ parent_context.read_file_state = {"/tmp/readme.md": {"partial": False, "meta": {"seen": 1}}}
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ parent_tool_context=parent_context,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ child_context = created[0]._agent_service._parent_tool_context
+ child_context.read_file_state["/tmp/readme.md"]["partial"] = True
+ child_context.read_file_state["/tmp/readme.md"]["meta"]["seen"] = 9
+ assert parent_context.read_file_state["/tmp/readme.md"] == {
+ "partial": False,
+ "meta": {"seen": 1},
+ }
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_live_runner_path_applies_role_specific_tool_filters(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["workspace_root"] = Path(workspace_root)
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="gpt-parent")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {"prompt": "inspect", "description": "inspect workspace", "subagent_type": "explore"},
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="gpt-parent"),
+ )
+
+ result = await runner.awrap_tool_call(request, AsyncMock())
+
+ assert result.content == "(Agent completed with no text output)"
+ assert captured["model_name"] == "gpt-parent"
+ assert captured["kwargs"]["agent"] == "explore"
+ assert captured["kwargs"]["allowed_tools"] == EXPLORE_ALLOWED
+ assert captured["kwargs"]["extra_blocked_tools"] == AGENT_DISALLOWED
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_model_priority_prefers_env_over_tool_frontmatter_and_parent(monkeypatch, tmp_path):
+ agent_dir = tmp_path / ".leon" / "agents"
+ agent_dir.mkdir(parents=True)
+ (agent_dir / "explore.md").write_text(
+ "---\nname: explore\nmodel: frontmatter-model\ntools:\n - Read\n---\nfrontmatter prompt\n",
+ encoding="utf-8",
+ )
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+ monkeypatch.setenv("CLAUDE_CODE_SUBAGENT_MODEL", "env-model")
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="parent-model")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {
+ "prompt": "inspect",
+ "description": "inspect workspace",
+ "subagent_type": "explore",
+ "model": "tool-model",
+ },
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="parent-model"),
+ )
+
+ await runner.awrap_tool_call(request, AsyncMock())
+
+ assert captured["model_name"] == "env-model"
+ assert captured["kwargs"]["agent"] == "explore"
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_model_priority_prefers_tool_over_frontmatter_and_parent(monkeypatch, tmp_path):
+ agent_dir = tmp_path / ".leon" / "agents"
+ agent_dir.mkdir(parents=True)
+ (agent_dir / "explore.md").write_text(
+ "---\nname: explore\nmodel: frontmatter-model\ntools:\n - Read\n---\nfrontmatter prompt\n",
+ encoding="utf-8",
+ )
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="parent-model")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {
+ "prompt": "inspect",
+ "description": "inspect workspace",
+ "subagent_type": "explore",
+ "model": "tool-model",
+ },
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="parent-model"),
+ )
+
+ await runner.awrap_tool_call(request, AsyncMock())
+
+ assert captured["model_name"] == "tool-model"
+ assert captured["kwargs"]["agent"] == "explore"
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_model_default_literal_inherits_parent_model(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="parent-model")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {
+ "prompt": "inspect",
+ "description": "inspect workspace",
+ "subagent_type": "explore",
+ "model": "default",
+ },
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="parent-model"),
+ )
+
+ await runner.awrap_tool_call(request, AsyncMock())
+
+ assert captured["model_name"] == "parent-model"
+ assert captured["kwargs"]["agent"] == "explore"
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_model_inherit_literal_inherits_parent_model(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="parent-model")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {
+ "prompt": "inspect",
+ "description": "inspect workspace",
+ "subagent_type": "explore",
+ "model": "inherit",
+ },
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="parent-model"),
+ )
+
+ await runner.awrap_tool_call(request, AsyncMock())
+
+ assert captured["model_name"] == "parent-model"
+ assert captured["kwargs"]["agent"] == "explore"
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_inherited_default_bootstrap_model_uses_parent_service_model(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="parent-service-model")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {"prompt": "inspect", "description": "inspect workspace", "subagent_type": "explore"},
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="default"),
+ )
+
+ await runner.awrap_tool_call(request, AsyncMock())
+
+ assert captured["model_name"] == "parent-service-model"
+ assert captured["kwargs"]["agent"] == "explore"
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_model_priority_prefers_frontmatter_over_parent(monkeypatch, tmp_path):
+ agent_dir = tmp_path / ".leon" / "agents"
+ agent_dir.mkdir(parents=True)
+ (agent_dir / "explore.md").write_text(
+ "---\nname: explore\nmodel: frontmatter-model\ntools:\n - Read\n---\nfrontmatter prompt\n",
+ encoding="utf-8",
+ )
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="parent-model")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {"prompt": "inspect", "description": "inspect workspace", "subagent_type": "explore"},
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="parent-model"),
+ )
+
+ await runner.awrap_tool_call(request, AsyncMock())
+
+ assert captured["model_name"] == "frontmatter-model"
+ assert captured["kwargs"]["agent"] == "explore"
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_model_priority_inherits_parent_when_no_env_tool_or_frontmatter(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["kwargs"] = kwargs
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry, model_name="service-model")
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {"prompt": "inspect", "description": "inspect workspace", "subagent_type": "explore"},
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path, model_name="parent-model"),
+ )
+
+ await runner.awrap_tool_call(request, AsyncMock())
+
+ assert captured["model_name"] == "parent-model"
+ assert captured["kwargs"]["agent"] == "explore"
+
+
+@pytest.mark.asyncio
+async def test_cleanup_background_runs_cancels_pending_agent_and_shell_runs(tmp_path):
+ service = _make_service(tmp_path)
+ agent_task = asyncio.create_task(_sleep_forever())
+ shell_cmd = _FakeAsyncCommand()
+ service._tasks["agent-task"] = _RunningTask(
+ task=agent_task,
+ agent_id="agent-task",
+ thread_id="subagent-agent-task",
+ description="agent task",
+ )
+ service._tasks["bash-task"] = _BashBackgroundRun(
+ async_cmd=shell_cmd,
+ command="sleep 999",
+ description="bash task",
+ )
+
+ await service.cleanup_background_runs()
+
+ assert agent_task.cancelled() is True
+ assert shell_cmd.terminated is True
+ assert shell_cmd.wait_calls == 1
+ assert service._tasks == {}
+
+
+@pytest.mark.asyncio
+async def test_cleanup_background_runs_does_not_relabel_completed_agent_run(tmp_path):
+ registry = _FakeAgentRegistry()
+ service = _make_service(tmp_path, agent_registry=registry)
+ completed_task = asyncio.create_task(asyncio.sleep(0, result="done"))
+ await completed_task
+ service._tasks["agent-task"] = _RunningTask(
+ task=completed_task,
+ agent_id="agent-task",
+ thread_id="subagent-agent-task",
+ description="agent task",
+ )
+
+ await service.cleanup_background_runs()
+
+ assert getattr(registry, "last_status", None) is None
+ assert service._tasks == {}
+
+
+@pytest.mark.asyncio
+async def test_run_agent_cleans_up_child_background_runs_before_close(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-task-1",
+ prompt="hello",
+ subagent_type="explore",
+ max_turns=None,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert created[0].cleanup_calls == 1
+ assert created[0].closed is True
+
+
+@pytest.mark.asyncio
+async def test_run_agent_links_child_abort_controller_to_parent_tool_context(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ parent_context = _make_parent_context(tmp_path)
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-task-1",
+ prompt="hello",
+ subagent_type="explore",
+ max_turns=None,
+ parent_tool_context=parent_context,
+ )
+
+ assert result == "(Agent completed with no text output)"
+
+ child_context = created[0]._agent_service._parent_tool_context
+ assert child_context is not None
+ assert getattr(created[0].agent, "_tool_abort_controller", None) is child_context.abort_controller
+
+ parent_context.abort_controller.abort()
+
+ assert child_context.abort_controller.is_aborted() is True
+
+
+@pytest.mark.asyncio
+async def test_run_agent_reuses_parent_lease_for_child_thread_terminal(monkeypatch, tmp_path, temp_db):
+ created: list[_FakeChildAgent] = []
+ observed: dict[str, str] = {}
+ parent_thread_id = "parent-thread"
+ child_thread_id = "subagent-child"
+
+ manager = SandboxManager(
+ provider=LocalSessionProvider(default_cwd=str(tmp_path)),
+ db_path=temp_db,
+ )
+ monkeypatch.setenv("LEON_SANDBOX_DB_PATH", str(temp_db))
+ monkeypatch.setattr(manager, "_setup_mounts", lambda thread_id: {"source": object(), "remote_path": str(tmp_path)})
+ monkeypatch.setattr(manager, "_sync_to_sandbox", lambda *args, **kwargs: None)
+
+ parent_capability = manager.get_sandbox(parent_thread_id)
+ parent_terminal_id = parent_capability._session.terminal.terminal_id
+ parent_lease_id = parent_capability._session.lease.lease_id
+
+ class _LeaseCapturingChild(_FakeChildAgent):
+ async def _astream(self, *args, **kwargs):
+ child_capability = manager.get_sandbox(get_current_thread_id())
+ observed["child_terminal_id"] = child_capability._session.terminal.terminal_id
+ observed["child_lease_id"] = child_capability._session.lease.lease_id
+ if False:
+ yield None
+ return
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _LeaseCapturingChild(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+ set_current_thread_id(parent_thread_id)
+
+ service = _make_service(tmp_path)
+
+ try:
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id=child_thread_id,
+ prompt="hello",
+ subagent_type="explore",
+ max_turns=None,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert created
+ assert observed["child_terminal_id"] != parent_terminal_id
+ assert observed["child_lease_id"] == parent_lease_id
+ finally:
+ manager.close()
+
+
+@pytest.mark.asyncio
+async def test_run_agent_inherits_parent_sandbox_when_forking_child(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["model_name"] = model_name
+ captured["workspace_root"] = Path(workspace_root)
+ captured["sandbox"] = kwargs.get("sandbox")
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+ service._parent_bootstrap = BootstrapConfig(
+ workspace_root=Path("/home/daytona"),
+ original_cwd=Path("/home/daytona"),
+ project_root=Path("/home/daytona"),
+ cwd=Path("/home/daytona"),
+ model_name="gpt-parent",
+ sandbox_type="daytona_selfhost",
+ )
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert captured["workspace_root"] == Path("/home/daytona")
+ assert captured["sandbox"] == "daytona_selfhost"
+
+
+@pytest.mark.asyncio
+async def test_run_agent_child_cleanup_skips_sandbox_close(monkeypatch, tmp_path):
+ created: list[_FakeChildAgent] = []
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ child = _FakeChildAgent(Path(workspace_root), model_name)
+ created.append(child)
+ return child
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ service = _make_service(tmp_path)
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "(Agent completed with no text output)"
+ assert created[0].closed is True
+ assert created[0].close_kwargs == {"cleanup_sandbox": False}
+
+
+@pytest.mark.asyncio
+async def test_handle_agent_registers_subagent_thread_metadata_before_return(monkeypatch, tmp_path):
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ thread_repo = _FakeThreadRepo(
+ rows={
+ "parent-thread": {
+ "id": "parent-thread",
+ "member_id": "member-1",
+ "sandbox_type": "daytona_selfhost",
+ "cwd": "/home/daytona",
+ "model": "gpt-parent",
+ "is_main": True,
+ "branch_index": 0,
+ "created_at": 1.0,
+ }
+ }
+ )
+ entity_repo = _FakeEntityRepo()
+ member_repo = _FakeMemberRepo({"member-1": "Toad"})
+ service = _make_service(
+ tmp_path,
+ thread_repo=thread_repo,
+ entity_repo=entity_repo,
+ member_repo=member_repo,
+ )
+
+ set_current_thread_id("parent-thread")
+ try:
+ raw = await service._handle_agent(
+ prompt="do work",
+ name="worker-1",
+ run_in_background=True,
+ )
+ payload = _agent_tool_json(raw)
+ child_thread_id = payload["thread_id"]
+
+ child_thread = thread_repo.get_by_id(child_thread_id)
+ child_entity = entity_repo.get_by_thread_id(child_thread_id)
+
+ assert child_thread is not None
+ assert child_thread["member_id"] == "member-1"
+ assert child_thread["sandbox_type"] == "daytona_selfhost"
+ assert child_thread["cwd"] == "/home/daytona"
+ assert child_thread["is_main"] is False
+ assert child_thread["branch_index"] == 1
+ assert child_entity is not None
+ assert child_entity.id == child_thread_id
+ assert child_entity.member_id == "member-1"
+ assert child_entity.name == "worker-1"
+ finally:
+ await service.cleanup_background_runs()
+ set_current_thread_id("")
+
+
+@pytest.mark.asyncio
+async def test_handle_agent_reuses_existing_completed_child_thread_for_same_parent_and_name(monkeypatch, tmp_path):
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ thread_repo = _FakeThreadRepo(
+ rows={
+ "parent-thread": {
+ "id": "parent-thread",
+ "member_id": "member-1",
+ "sandbox_type": "daytona_selfhost",
+ "cwd": "/home/daytona",
+ "model": "gpt-parent",
+ "is_main": True,
+ "branch_index": 0,
+ "created_at": 1.0,
+ },
+ "subagent-existing": {
+ "id": "subagent-existing",
+ "member_id": "member-1",
+ "sandbox_type": "daytona_selfhost",
+ "cwd": "/home/daytona",
+ "model": "gpt-test",
+ "is_main": False,
+ "branch_index": 1,
+ "created_at": 2.0,
+ },
+ }
+ )
+ entity_repo = _FakeEntityRepo()
+ entity_repo.create(
+ EntityRow(
+ id="subagent-existing",
+ member_id="member-1",
+ thread_id="subagent-existing",
+ name="worker-1",
+ type="agent",
+ created_at=2.0,
+ )
+ )
+ registry = _FakeAgentRegistry()
+ registry._latest_by_name_parent[("worker-1", "parent-thread")] = SimpleNamespace(
+ agent_id="old-agent",
+ name="worker-1",
+ thread_id="subagent-existing",
+ status="completed",
+ parent_agent_id="parent-thread",
+ subagent_type="general",
+ )
+ service = _make_service(
+ tmp_path,
+ agent_registry=registry,
+ thread_repo=thread_repo,
+ entity_repo=entity_repo,
+ member_repo=_FakeMemberRepo({"member-1": "Toad"}),
+ )
+
+ set_current_thread_id("parent-thread")
+ try:
+ raw = await service._handle_agent(
+ prompt="continue work",
+ name="worker-1",
+ run_in_background=True,
+ )
+
+ payload = _agent_tool_json(raw)
+ assert payload["thread_id"] == "subagent-existing"
+ assert len(thread_repo.created) == 0
+ finally:
+ await service.cleanup_background_runs()
+ set_current_thread_id("")
+
+
+@pytest.mark.asyncio
+async def test_agent_tool_blocking_result_preserves_child_identity_metadata(monkeypatch, tmp_path):
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry)
+ runner = ToolRunner(registry=registry)
+ request = SimpleNamespace(
+ tool_call={
+ "name": "Agent",
+ "args": {"prompt": "inspect", "description": "inspect workspace"},
+ "id": "tc-1",
+ },
+ state=_make_parent_context(tmp_path),
+ )
+
+ result = await runner.awrap_tool_call(request, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert meta["task_id"]
+ assert meta["subagent_thread_id"].startswith("subagent-")
+
+
+@pytest.mark.asyncio
+async def test_run_agent_uses_live_child_thread_bridge_when_web_app_present(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ async def fake_run_child_thread_live(agent, thread_id, prompt, app, *, input_messages):
+ captured["agent"] = agent
+ captured["thread_id"] = thread_id
+ captured["prompt"] = prompt
+ captured["app"] = app
+ captured["input_messages"] = input_messages
+ return "LIVE_CHILD_DONE"
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ captured["child_web_app"] = kwargs.get("web_app")
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+ monkeypatch.setattr("backend.web.services.streaming_service.run_child_thread_live", fake_run_child_thread_live)
+
+ web_app = SimpleNamespace()
+ service = _make_service(tmp_path, web_app=web_app)
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt="do work",
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "LIVE_CHILD_DONE"
+ assert captured["thread_id"] == "subagent-1"
+ assert captured["prompt"] == "do work"
+ assert captured["app"] is web_app
+ assert captured["child_web_app"] is web_app
+ assert len(captured["input_messages"]) == 1
+ assert captured["input_messages"][0]["role"] == "user"
+ assert captured["input_messages"][0]["content"] == "do work"
+ assert captured["agent"].cleanup_calls == 1
+ assert captured["agent"].closed is False
+
+
+@pytest.mark.asyncio
+async def test_run_agent_normalizes_workspace_suffix_in_child_prompt(monkeypatch, tmp_path):
+ captured: dict[str, object] = {}
+
+ async def fake_run_child_thread_live(agent, thread_id, prompt, app, *, input_messages):
+ captured["prompt"] = prompt
+ captured["input_messages"] = input_messages
+ return "LIVE_CHILD_DONE"
+
+ def fake_create_leon_agent(*, model_name, workspace_root, **kwargs):
+ return _FakeChildAgent(Path(workspace_root), model_name)
+
+ monkeypatch.setattr("core.runtime.agent.create_leon_agent", fake_create_leon_agent)
+ monkeypatch.setattr("backend.web.services.streaming_service.run_child_thread_live", fake_run_child_thread_live)
+
+ service = _make_service(tmp_path, web_app=SimpleNamespace())
+ raw_prompt = f"Inspect the workspace at {tmp_path}/current working directory. Read-only only. Report existing files."
+
+ result = await service._run_agent(
+ task_id="task-1",
+ agent_name="child",
+ thread_id="subagent-1",
+ prompt=raw_prompt,
+ subagent_type="general",
+ max_turns=None,
+ fork_context=False,
+ )
+
+ assert result == "LIVE_CHILD_DONE"
+ expected_prompt = f"Inspect the workspace at {tmp_path}. Read-only only. Report existing files."
+ assert captured["prompt"] == expected_prompt
+ assert captured["input_messages"][0]["content"] == expected_prompt
+
+
+def test_agent_schema_does_not_claim_general_has_full_tool_access():
+ description = AGENT_SCHEMA["description"]
+
+ assert "general (full tool access)" not in description
+ assert "general (broad tool access except Agent, TaskOutput, and TaskStop)" in description
+
+
+def test_agent_schema_requires_description():
+ assert AGENT_SCHEMA["parameters"]["required"] == ["prompt", "description"]
+
+
+def test_task_output_schema_exposes_block_and_timeout():
+ properties = TASK_OUTPUT_SCHEMA["parameters"]["properties"]
+
+ assert properties["block"]["default"] is True
+ assert properties["timeout"]["default"] == 30000
+ assert properties["timeout"]["maximum"] == 600000
+
+
+@pytest.mark.asyncio
+async def test_ask_user_question_requests_structured_question_payload(tmp_path):
+ registry = ToolRegistry()
+ _make_service(tmp_path, tool_registry=registry)
+ runner = ToolRunner(registry=registry)
+ app_state = AppState()
+ captured: dict[str, object] = {}
+
+ def request_permission(name, args, context, request, message):
+ captured["name"] = name
+ captured["args"] = dict(args)
+ captured["message"] = message
+ return {"request_id": "ask-1"}
+
+ request = SimpleNamespace(
+ tool_call={
+ "name": "AskUserQuestion",
+ "args": {
+ "questions": [
+ {
+ "header": "Color",
+ "question": "Which color should I use?",
+ "options": [
+ {"label": "Blue", "description": "Use blue"},
+ {"label": "Green", "description": "Use green"},
+ ],
+ }
+ ]
+ },
+ "id": "tc-1",
+ },
+ state=ToolUseContext(
+ bootstrap=BootstrapConfig(workspace_root=tmp_path, model_name="gpt-test"),
+ get_app_state=app_state.get_state,
+ set_app_state=app_state.set_state,
+ request_permission=request_permission,
+ ),
+ )
+
+ result = await runner.awrap_tool_call(request, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert meta["kind"] == "permission_request"
+ assert meta["request_id"] == "ask-1"
+ assert result.content == "User input required to continue."
+ assert captured["name"] == "AskUserQuestion"
+ assert captured["message"] == "Please answer the following questions so Leon can continue."
+ assert captured["args"] == {
+ "questions": [
+ {
+ "header": "Color",
+ "question": "Which color should I use?",
+ "options": [
+ {"label": "Blue", "description": "Use blue"},
+ {"label": "Green", "description": "Use green"},
+ ],
+ }
+ ]
+ }
+
+
+def test_ask_user_question_schema_requires_questions():
+ assert ASK_USER_QUESTION_SCHEMA["parameters"]["required"] == ["questions"]
diff --git a/tests/Unit/core/test_capability_async.py b/tests/Unit/core/test_capability_async.py
new file mode 100644
index 000000000..d07334c3d
--- /dev/null
+++ b/tests/Unit/core/test_capability_async.py
@@ -0,0 +1,195 @@
+import asyncio
+import uuid
+from pathlib import Path
+from types import SimpleNamespace
+
+from sandbox.base import LocalSandbox
+from sandbox.capability import SandboxCapability
+from sandbox.interfaces.executor import AsyncCommand, ExecuteResult
+from sandbox.thread_context import set_current_thread_id
+
+
+class _DummyState:
+ cwd = "/tmp"
+
+
+class _DummyTerminal:
+ terminal_id = "dummy-term"
+
+ def get_state(self):
+ return _DummyState()
+
+
+class _DummyRuntime:
+ def __init__(self):
+ self.commands: list[str] = []
+ self._async_commands: dict[str, AsyncCommand] = {}
+
+ async def execute(self, command: str, timeout=None):
+ self.commands.append(command)
+ await asyncio.sleep(0.01)
+ return ExecuteResult(exit_code=0, stdout=f"ok:{command}", stderr="")
+
+ async def start_command(self, command: str, cwd: str) -> AsyncCommand:
+ command_id = f"cmd_{uuid.uuid4().hex[:12]}"
+ result = await self.execute(command)
+ async_cmd = AsyncCommand(
+ command_id=command_id,
+ command_line=command,
+ cwd=cwd,
+ exit_code=result.exit_code,
+ done=True,
+ stdout_buffer=[result.stdout],
+ )
+ self._async_commands[command_id] = async_cmd
+ return async_cmd
+
+ async def get_command(self, command_id: str) -> AsyncCommand | None:
+ return self._async_commands.get(command_id)
+
+ async def wait_for_command(self, command_id: str, timeout: float | None = None) -> ExecuteResult | None:
+ cmd = self._async_commands.get(command_id)
+ if cmd is None:
+ return None
+ return ExecuteResult(
+ exit_code=cmd.exit_code or 0,
+ stdout="".join(cmd.stdout_buffer),
+ stderr="".join(cmd.stderr_buffer),
+ )
+
+
+class _DummySession:
+ def __init__(self):
+ self.terminal = _DummyTerminal()
+ self.runtime = _DummyRuntime()
+ self.touches = 0
+
+ def touch(self):
+ self.touches += 1
+
+
+async def _run_async_command_flow():
+ session = _DummySession()
+ capability = SandboxCapability(session)
+
+ async_cmd = await capability.command.execute_async("echo hi", cwd="/tmp/demo", env={"A": "1"})
+ assert async_cmd.command_id.startswith("cmd_")
+
+ status = await capability.command.get_status(async_cmd.command_id)
+ assert status is not None
+
+ result = await capability.command.wait_for(async_cmd.command_id, timeout=1.0)
+ assert result is not None
+ assert result.exit_code == 0
+ assert "echo hi" in result.stdout
+ assert session.touches > 0
+
+
+def test_command_wrapper_supports_execute_async():
+ asyncio.run(_run_async_command_flow())
+
+
+def test_local_sandbox_rebuilds_stale_closed_capability_before_execute_async(tmp_path):
+ root = Path(tmp_path)
+ thread_id = "thread-stale-session"
+ sandbox = LocalSandbox(str(root), db_path=root / "sandbox.db")
+ set_current_thread_id(thread_id)
+ capability = sandbox._get_capability()
+ stale_session_id = capability._session.session_id
+ sandbox.manager.session_manager.delete(stale_session_id, reason="test_close")
+
+ async def run():
+ async_cmd = await sandbox.shell().execute_async("sleep 0.01; echo hi")
+ result = await sandbox.shell().wait_for(async_cmd.command_id, timeout=1.0)
+ return async_cmd, result
+
+ async_cmd, result = asyncio.run(run())
+
+ assert capability._session.status == "closed"
+ refreshed = sandbox._get_capability()
+ assert refreshed._session.session_id != stale_session_id
+ assert async_cmd.command_id.startswith("cmd_")
+ assert result is not None
+ assert result.exit_code == 0
+ assert "hi" in result.stdout
+
+
+def test_filesystem_wrapper_auto_resumes_paused_lease_before_listing():
+ class _PausedLease:
+ def __init__(self):
+ self.observed_state = "paused"
+
+ def ensure_active_instance(self, _provider):
+ if self.observed_state == "paused":
+ raise RuntimeError("Sandbox lease lease-1 is paused. Resume before executing commands.")
+ return SimpleNamespace(instance_id="inst-1")
+
+ class _RemoteProvider:
+ def list_dir(self, instance_id: str, path: str):
+ assert instance_id == "inst-1"
+ assert path == "/home/daytona"
+ return [{"name": "demo.txt", "type": "file", "size": 7}]
+
+ lease = _PausedLease()
+ provider = _RemoteProvider()
+ resume_calls: list[tuple[str, str]] = []
+
+ class _RemoteSession:
+ def __init__(self):
+ self.thread_id = "thread-paused"
+ self.terminal = _DummyTerminal()
+ self.lease = lease
+ self.runtime = SimpleNamespace(provider=provider)
+ self.touches = 0
+
+ def touch(self):
+ self.touches += 1
+
+ session = _RemoteSession()
+ manager = SimpleNamespace(
+ resume_session=lambda thread_id, source="user_resume": (
+ resume_calls.append((thread_id, source)) or setattr(lease, "observed_state", "running") or True
+ )
+ )
+
+ capability = SandboxCapability(session, manager=manager)
+
+ result = capability.fs.list_dir("/home/daytona")
+
+ assert resume_calls == [("thread-paused", "auto_resume")]
+ assert [entry.name for entry in result.entries] == ["demo.txt"]
+ assert result.error is None
+
+
+def test_filesystem_wrapper_derives_remote_file_size_from_parent_listing():
+ class _Lease:
+ observed_state = "running"
+
+ def ensure_active_instance(self, _provider):
+ return SimpleNamespace(instance_id="inst-1")
+
+ class _RemoteProvider:
+ def list_dir(self, instance_id: str, path: str):
+ assert instance_id == "inst-1"
+ assert path == "/home/daytona"
+ return [
+ {"name": "demo.txt", "type": "file", "size": 42},
+ {"name": "nested", "type": "directory", "size": 0},
+ ]
+
+ class _RemoteSession:
+ def __init__(self):
+ self.thread_id = "thread-size"
+ self.terminal = _DummyTerminal()
+ self.lease = _Lease()
+ self.runtime = SimpleNamespace(provider=_RemoteProvider())
+ self.touches = 0
+
+ def touch(self):
+ self.touches += 1
+
+ capability = SandboxCapability(_RemoteSession())
+
+ assert capability.fs.file_size("/home/daytona/demo.txt") == 42
+ assert capability.fs.file_size("/home/daytona/missing.txt") is None
+ assert capability.fs.file_size("/") is None
diff --git a/tests/Unit/core/test_chat_tool_service.py b/tests/Unit/core/test_chat_tool_service.py
new file mode 100644
index 000000000..facf94e15
--- /dev/null
+++ b/tests/Unit/core/test_chat_tool_service.py
@@ -0,0 +1,161 @@
+from types import SimpleNamespace
+
+from langchain_core.messages import HumanMessage
+
+from core.agents.communication.chat_tool_service import ChatToolService
+from core.runtime.agent import LeonAgent
+from core.runtime.registry import ToolRegistry
+from storage.contracts import EntityRow, MemberRow, MemberType
+
+
+class _EntityRepo:
+ def __init__(self, entities: list[EntityRow]) -> None:
+ self._entities = {entity.id: entity for entity in entities}
+
+ def list_all(self) -> list[EntityRow]:
+ return list(self._entities.values())
+
+ def get_by_id(self, entity_id: str) -> EntityRow | None:
+ return self._entities.get(entity_id)
+
+
+class _MemberRepo:
+ def __init__(self, members: list[MemberRow]) -> None:
+ self._members = {member.id: member for member in members}
+
+ def get_by_id(self, member_id: str) -> MemberRow | None:
+ return self._members.get(member_id)
+
+ def list_all(self) -> list[MemberRow]:
+ return list(self._members.values())
+
+
+def test_chat_tool_registry_exposes_only_canonical_chat_surface() -> None:
+ registry = ToolRegistry()
+ ChatToolService(
+ registry,
+ user_id="m_agent",
+ owner_user_id="u_owner",
+ entity_repo=_EntityRepo([]),
+ chat_service=SimpleNamespace(),
+ chat_entity_repo=SimpleNamespace(),
+ chat_message_repo=SimpleNamespace(),
+ member_repo=_MemberRepo([]),
+ chat_event_bus=SimpleNamespace(),
+ runtime_fn=lambda: None,
+ )
+
+ for tool_name in ("list_chats", "read_messages", "send_message", "search_messages"):
+ assert registry.get(tool_name) is not None
+
+ assert registry.get("chats") is None
+ assert registry.get("read_message") is None
+ assert registry.get("search_message") is None
+ assert registry.get("directory") is None
+
+
+def test_compose_system_prompt_hardens_chat_reply_contract() -> None:
+ owner_entity = EntityRow(id="e_owner", type="human", member_id="u_owner", name="Owner", created_at=1.0)
+ agent_entity = EntityRow(id="e_agent", type="agent", member_id="m_agent", name="Helper", created_at=2.0)
+
+ agent = LeonAgent.__new__(LeonAgent)
+ agent._chat_repos = {
+ "user_id": "m_agent",
+ "owner_user_id": "u_owner",
+ "entity_repo": _EntityRepo([owner_entity, agent_entity]),
+ "member_repo": _MemberRepo(
+ [
+ MemberRow(id="u_owner", name="Owner", type=MemberType.HUMAN, created_at=1.0),
+ MemberRow(id="m_agent", name="Helper Member", type=MemberType.MYCEL_AGENT, owner_user_id="u_owner", created_at=2.0),
+ ]
+ ),
+ }
+ agent._build_system_prompt = lambda: "BASE"
+ agent.config = SimpleNamespace(system_prompt=None)
+
+ prompt = agent._compose_system_prompt()
+
+ assert "you MUST read it with read_messages()" in prompt
+ assert "prefer using that exact chat_id directly" in prompt
+ assert "you MUST call send_message()" in prompt
+ assert "Never claim you replied unless send_message() succeeded." in prompt
+ assert "directory" not in prompt
+
+
+def test_read_messages_validate_input_fills_missing_chat_id_from_latest_notification() -> None:
+ registry = ToolRegistry()
+ ChatToolService(
+ registry,
+ user_id="m_agent",
+ owner_user_id="u_owner",
+ entity_repo=_EntityRepo([]),
+ chat_service=SimpleNamespace(),
+ chat_entity_repo=SimpleNamespace(),
+ chat_message_repo=SimpleNamespace(),
+ member_repo=_MemberRepo([]),
+ chat_event_bus=SimpleNamespace(),
+ runtime_fn=lambda: None,
+ )
+ entry = registry.get("read_messages")
+ assert entry is not None
+ assert entry.validate_input is not None
+
+ request = SimpleNamespace(
+ state=SimpleNamespace(
+ messages=[
+ HumanMessage(
+ content=(
+ "\n"
+ "New message from alice in chat chat-123 (1 unread).\n"
+ 'Read it with read_messages(chat_id="chat-123").\n'
+ ""
+ ),
+ metadata={"source": "external", "notification_type": "chat"},
+ )
+ ]
+ )
+ )
+
+ args = entry.validate_input({"chat_id": "", "range": "-10:"}, request)
+
+ assert args == {"chat_id": "chat-123", "range": "-10:"}
+
+
+def test_send_message_validate_input_fills_missing_chat_id_from_latest_notification() -> None:
+ registry = ToolRegistry()
+ ChatToolService(
+ registry,
+ user_id="m_agent",
+ owner_user_id="u_owner",
+ entity_repo=_EntityRepo([]),
+ chat_service=SimpleNamespace(),
+ chat_entity_repo=SimpleNamespace(),
+ chat_message_repo=SimpleNamespace(),
+ member_repo=_MemberRepo([]),
+ chat_event_bus=SimpleNamespace(),
+ runtime_fn=lambda: None,
+ )
+ entry = registry.get("send_message")
+ assert entry is not None
+ assert entry.validate_input is not None
+
+ request = SimpleNamespace(
+ state=SimpleNamespace(
+ messages=[
+ HumanMessage(
+ content=(
+ "\n"
+ "New message from alice in chat chat-456 (1 unread).\n"
+ 'Read it with read_messages(chat_id="chat-456").\n'
+ 'Reply with send_message(chat_id="chat-456", content="...").\n'
+ ""
+ ),
+ metadata={"source": "external", "notification_type": "chat"},
+ )
+ ]
+ )
+ )
+
+ args = entry.validate_input({"content": "hi", "chat_id": ""}, request)
+
+ assert args == {"content": "hi", "chat_id": "chat-456"}
diff --git a/tests/test_command_middleware.py b/tests/Unit/core/test_command_middleware.py
similarity index 73%
rename from tests/test_command_middleware.py
rename to tests/Unit/core/test_command_middleware.py
index 05d64edf1..c48e0b681 100644
--- a/tests/test_command_middleware.py
+++ b/tests/Unit/core/test_command_middleware.py
@@ -5,10 +5,12 @@
import pytest
+from core.runtime.registry import ToolRegistry
from core.tools.command.base import AsyncCommand, BaseExecutor, ExecuteResult
from core.tools.command.dispatcher import get_executor, get_shell_info
from core.tools.command.hooks.dangerous_commands import DangerousCommandsHook
from core.tools.command.middleware import CommandMiddleware
+from core.tools.command.service import CommandService
class TestExecuteResult:
@@ -107,6 +109,36 @@ def test_block_rm_rf(self):
assert not result.allow
assert "SECURITY" in result.error_message
+ def test_allow_dangerous_text_inside_quotes(self):
+ hook = DangerousCommandsHook(verbose=False)
+ result = hook.check_command('echo "rm -rf /"', {})
+ assert result.allow
+
+ def test_allow_dangerous_text_inside_comment(self):
+ hook = DangerousCommandsHook(verbose=False)
+ result = hook.check_command("echo hi # rm -rf /", {})
+ assert result.allow
+
+ def test_block_obfuscated_dangerous_command_name_with_inline_quotes(self):
+ hook = DangerousCommandsHook(verbose=False)
+ result = hook.check_command('s"u"do echo hi', {})
+ assert not result.allow
+
+ def test_block_obfuscated_file_mutation_command_name_with_inline_quotes(self):
+ hook = DangerousCommandsHook(verbose=False)
+ result = hook.check_command('ch"mo"d 777 /tmp/x', {})
+ assert not result.allow
+
+ def test_block_ansi_c_quoted_obfuscation(self):
+ hook = DangerousCommandsHook(verbose=False)
+ result = hook.check_command("s$'udo' echo hi", {})
+ assert not result.allow
+
+ def test_block_locale_quoted_obfuscation(self):
+ hook = DangerousCommandsHook(verbose=False)
+ result = hook.check_command('$"chmod" 777 /tmp/x', {})
+ assert not result.allow
+
def test_block_sudo(self):
hook = DangerousCommandsHook()
result = hook.check_command("sudo apt install", {})
@@ -185,6 +217,29 @@ def store_completed_result(self, command_id: str, command_line: str, cwd: str, r
return None
+class _BlankErrorExecutor(BaseExecutor):
+ runtime_owns_cwd = True
+ shell_name = "bash"
+
+ class BlankCommandError(Exception):
+ pass
+
+ async def execute(self, command: str, cwd: str | None = None, timeout: float | None = None, env=None):
+ raise self.BlankCommandError()
+
+ async def execute_async(self, command: str, cwd: str | None = None, env=None):
+ raise self.BlankCommandError()
+
+ async def get_status(self, command_id: str):
+ return None
+
+ async def wait_for(self, command_id: str, timeout: float | None = None):
+ return None
+
+ def store_completed_result(self, command_id: str, command_line: str, cwd: str, result: ExecuteResult) -> None:
+ return None
+
+
class TestCommandStatusFormatting:
@pytest.mark.asyncio
async def test_running_status_strips_pty_prompt_echo_noise(self, tmp_path):
@@ -224,3 +279,25 @@ async def test_running_status_includes_stderr_chunks(self, tmp_path):
output_block = out.split("Output so far:\n", 1)[1]
assert "out" in output_block
assert "err" in output_block
+
+
+class TestFailLoudBlankExceptions:
+ @pytest.mark.asyncio
+ async def test_command_middleware_surfaces_exception_type_when_message_is_blank(self, tmp_path):
+ middleware = CommandMiddleware(workspace_root=tmp_path, executor=_BlankErrorExecutor(), verbose=False)
+
+ out = await middleware._execute_blocking("pwd", str(tmp_path), timeout=1)
+
+ assert out == "Error executing command: BlankCommandError"
+
+ @pytest.mark.asyncio
+ async def test_command_service_surfaces_exception_type_when_message_is_blank(self, tmp_path):
+ service = CommandService(
+ registry=ToolRegistry(),
+ workspace_root=tmp_path,
+ executor=_BlankErrorExecutor(),
+ )
+
+ out = await service._bash("pwd")
+
+ assert out == "Error executing command: BlankCommandError"
diff --git a/tests/test_event_bus.py b/tests/Unit/core/test_event_bus.py
similarity index 100%
rename from tests/test_event_bus.py
rename to tests/Unit/core/test_event_bus.py
diff --git a/tests/Unit/core/test_loop.py b/tests/Unit/core/test_loop.py
new file mode 100644
index 000000000..15135c05e
--- /dev/null
+++ b/tests/Unit/core/test_loop.py
@@ -0,0 +1,2886 @@
+"""Unit tests for core.runtime.loop QueryLoop."""
+
+import asyncio
+import json
+import tempfile
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, RemoveMessage, SystemMessage, ToolMessage
+from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+
+from core.runtime.loop import QueryLoop, _StreamingToolExecutor
+from core.runtime.middleware import AgentMiddleware
+from core.runtime.middleware.memory import MemoryMiddleware
+from core.runtime.middleware.monitor import AgentState
+from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry
+from core.runtime.state import AppState, BootstrapConfig, ToolPermissionState
+from storage.providers.sqlite.kernel import connect_sqlite_async
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def make_registry(*entries):
+ reg = ToolRegistry()
+ for e in entries:
+ reg.register(e)
+ return reg
+
+
+def make_loop(model, registry=None, middleware=None, max_turns=10, app_state=None, runtime=None, bootstrap=None, checkpointer=None):
+ return QueryLoop(
+ model=model,
+ system_prompt=SystemMessage(content="You are a test assistant."),
+ middleware=middleware or [],
+ checkpointer=checkpointer,
+ registry=registry or make_registry(),
+ app_state=app_state,
+ runtime=runtime,
+ bootstrap=bootstrap or BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model"),
+ max_turns=max_turns,
+ )
+
+
+class _MemoryCheckpointer:
+ def __init__(self):
+ self.store = {}
+
+ async def aget(self, cfg):
+ return self.store.get(cfg["configurable"]["thread_id"])
+
+ async def aput(self, cfg, checkpoint, metadata, new_versions):
+ self.store[cfg["configurable"]["thread_id"]] = checkpoint
+
+
+def mock_model_no_tools(text="Hello!"):
+ """Model that returns a plain AIMessage (no tool calls)."""
+ ai_msg = AIMessage(content=text)
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(return_value=ai_msg)
+ return model
+
+
+def mock_model_with_tool_call(tool_name="echo", args=None, call_id="tc-1", then_text="Done"):
+ """Model that first responds with a tool call, then responds with plain text."""
+ args = args or {"message": "hi"}
+ tool_call_msg = AIMessage(
+ content="",
+ tool_calls=[{"name": tool_name, "args": args, "id": call_id}],
+ )
+ final_msg = AIMessage(content=then_text)
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
+ return model
+
+
+def mock_model_with_two_tool_turns():
+ first = AIMessage(content="", tool_calls=[{"name": "echo", "args": {"message": "one"}, "id": "tc-1"}])
+ second = AIMessage(content="", tool_calls=[{"name": "echo", "args": {"message": "two"}, "id": "tc-2"}])
+ final = AIMessage(content="done")
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(side_effect=[first, second, final])
+ return model
+
+
+def _make_summary_memory_middleware(*, context_limit=40, keep_recent_tokens=10, compaction_threshold=0.1):
+ summary_model = MagicMock()
+ summary_model.bind.return_value = summary_model
+ summary_model.ainvoke = AsyncMock(return_value=AIMessage(content="SUMMARY"))
+
+ memory = MemoryMiddleware(
+ context_limit=context_limit,
+ compaction_config=SimpleNamespace(reserve_tokens=0, keep_recent_tokens=keep_recent_tokens),
+ compaction_threshold=compaction_threshold,
+ )
+ memory.set_model(summary_model)
+ return memory, summary_model
+
+
+def _make_prompt_too_long_model(*responses):
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(side_effect=list(responses))
+ return model
+
+
+def make_inline_tool(name, handler, *, schema=None, is_concurrency_safe=True):
+ return ToolEntry(
+ name=name,
+ mode=ToolMode.INLINE,
+ schema=schema or {"name": name, "description": name, "parameters": {}},
+ handler=handler,
+ source="test",
+ is_concurrency_safe=is_concurrency_safe,
+ )
+
+
+def test_tool_use_context_get_app_state_is_live_closure():
+ app_state = AppState(turn_count=1)
+ loop = make_loop(mock_model_no_tools(), app_state=app_state)
+
+ ctx = loop._build_tool_use_context([])
+ assert ctx is not None
+ assert ctx.get_app_state().turn_count == 1
+
+ app_state.set_state(lambda prev: prev.model_copy(update={"turn_count": 7}))
+
+ assert ctx.get_app_state().turn_count == 7
+
+
+def test_tool_use_context_session_refs_persist_across_turns():
+ app_state = AppState()
+ loop = make_loop(mock_model_no_tools(), app_state=app_state)
+
+ ctx1 = loop._build_tool_use_context([HumanMessage(content="one")])
+ ctx2 = loop._build_tool_use_context([HumanMessage(content="two")])
+
+ assert ctx1 is not None
+ assert ctx2 is not None
+
+ ctx1.discovered_skill_names.add("skill-a")
+ ctx1.loaded_nested_memory_paths.add("/tmp/memory.md")
+ ctx1.read_file_state["/tmp/file.py"] = {"partial": False}
+
+ assert ctx2.discovered_skill_names is ctx1.discovered_skill_names
+ assert ctx2.loaded_nested_memory_paths is ctx1.loaded_nested_memory_paths
+ assert ctx2.read_file_state is ctx1.read_file_state
+ assert "skill-a" in ctx2.discovered_skill_names
+ assert "/tmp/memory.md" in ctx2.loaded_nested_memory_paths
+ assert "/tmp/file.py" in ctx2.read_file_state
+
+
+def test_tool_use_context_turn_refs_are_fresh_per_turn():
+ app_state = AppState()
+ loop = make_loop(mock_model_no_tools(), app_state=app_state)
+
+ ctx1 = loop._build_tool_use_context([HumanMessage(content="one")])
+ ctx2 = loop._build_tool_use_context([HumanMessage(content="two")])
+
+ assert ctx1 is not None
+ assert ctx2 is not None
+
+ ctx1.nested_memory_attachment_triggers.add("memo-a")
+
+ assert ctx2.nested_memory_attachment_triggers == set()
+ assert ctx2.nested_memory_attachment_triggers is not ctx1.nested_memory_attachment_triggers
+
+
+def test_tool_use_context_permission_request_surface_tracks_thread_pending_state():
+ app_state = AppState()
+ loop = make_loop(
+ mock_model_no_tools(),
+ app_state=app_state,
+ bootstrap=BootstrapConfig(
+ workspace_root=Path("/tmp"),
+ model_name="test-model",
+ permission_resolver_scope="thread",
+ ),
+ )
+
+ ctx = loop._build_tool_use_context([], thread_id="thread-a")
+ assert ctx is not None
+
+ request_id = ctx.request_permission("Write", {"path": "x"}, None, None, "needs approval")
+
+ assert isinstance(request_id, str)
+ assert app_state.pending_permission_requests[request_id]["thread_id"] == "thread-a"
+ assert app_state.pending_permission_requests[request_id]["tool_name"] == "Write"
+
+
+def test_tool_use_context_consumes_resolved_permission_once():
+ app_state = AppState(
+ resolved_permission_requests={
+ "perm-1": {
+ "thread_id": "thread-a",
+ "tool_name": "Write",
+ "args": {"path": "x"},
+ "decision": "allow",
+ "message": "approved",
+ }
+ }
+ )
+ loop = make_loop(mock_model_no_tools(), app_state=app_state)
+
+ ctx = loop._build_tool_use_context([], thread_id="thread-a")
+ assert ctx is not None
+
+ first = ctx.consume_permission_resolution("Write", {"path": "x"}, None, None)
+ second = ctx.consume_permission_resolution("Write", {"path": "x"}, None, None)
+
+ assert first == {"decision": "allow", "message": "approved"}
+ assert second is None
+ assert app_state.resolved_permission_requests == {}
+
+
+def test_tool_use_context_can_use_tool_reads_app_state_permission_rules():
+ app_state = AppState()
+ app_state.tool_permission_context.alwaysAskRules["session"] = ["Write"]
+ loop = make_loop(
+ mock_model_no_tools(),
+ app_state=app_state,
+ bootstrap=BootstrapConfig(
+ workspace_root=Path("/tmp"),
+ model_name="test-model",
+ permission_resolver_scope="thread",
+ ),
+ )
+
+ ctx = loop._build_tool_use_context([], thread_id="thread-a")
+ assert ctx is not None
+
+ decision = ctx.can_use_tool(
+ "Write",
+ {},
+ SimpleNamespace(is_read_only=False, is_destructive=False),
+ None,
+ )
+
+ assert decision == {
+ "decision": "ask",
+ "message": "Permission required by rule: Write",
+ }
+
+
+def test_tool_use_context_omits_permission_request_surface_without_interactive_resolver():
+ app_state = AppState()
+ loop = make_loop(mock_model_no_tools(), app_state=app_state)
+
+ ctx = loop._build_tool_use_context([], thread_id="thread-a")
+ assert ctx is not None
+
+ assert ctx.request_permission is None
+
+
+def test_tool_use_context_fails_loud_when_ask_has_no_interactive_resolver():
+ app_state = AppState()
+ app_state.tool_permission_context.alwaysAskRules["session"] = ["Write"]
+ loop = make_loop(mock_model_no_tools(), app_state=app_state)
+
+ ctx = loop._build_tool_use_context([], thread_id="thread-a")
+ assert ctx is not None
+
+ decision = ctx.can_use_tool(
+ "Write",
+ {},
+ SimpleNamespace(is_read_only=False, is_destructive=False),
+ None,
+ )
+
+ assert decision == {
+ "decision": "deny",
+ "message": "Permission required by rule: Write. No interactive permission resolver is available for this run.",
+ }
+
+
+class _CaptureTurnLocalStateMiddleware(AgentMiddleware):
+ def __init__(self):
+ self.turn_ids = []
+ self.trigger_snapshots = []
+
+ async def awrap_tool_call(self, request, handler):
+ self.turn_ids.append(request.state.turn_id)
+ self.trigger_snapshots.append(set(request.state.nested_memory_attachment_triggers))
+ if len(self.turn_ids) == 1:
+ request.state.nested_memory_attachment_triggers.add("first-turn-mark")
+ return await handler(request)
+
+
+@pytest.mark.asyncio
+async def test_query_loop_rebuilds_turn_local_tool_context_each_tool_turn():
+ model = mock_model_with_two_tool_turns()
+
+ def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "echo", "parameters": {}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=False,
+ )
+ capture = _CaptureTurnLocalStateMiddleware()
+ loop = make_loop(model, registry=make_registry(entry), middleware=[capture], app_state=AppState())
+
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "two turns"}]}):
+ pass
+
+ assert len(capture.turn_ids) == 2
+ assert capture.turn_ids[0] != capture.turn_ids[1]
+ assert capture.trigger_snapshots == [set(), set()]
+
+
+# ---------------------------------------------------------------------------
+# Tests: no tool calls → single agent chunk
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_no_tool_calls_yields_one_agent_chunk():
+ model = mock_model_no_tools("Hello world")
+ loop = make_loop(model)
+
+ chunks = []
+ async for chunk in loop.astream({"messages": [{"role": "user", "content": "hi"}]}):
+ chunks.append(chunk)
+
+ assert len(chunks) == 1
+ assert "agent" in chunks[0]
+ msgs = chunks[0]["agent"]["messages"]
+ assert len(msgs) == 1
+ assert msgs[0].content == "Hello world"
+
+
+@pytest.mark.asyncio
+async def test_no_tool_calls_model_called_once():
+ model = mock_model_no_tools()
+ loop = make_loop(model)
+
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "hi"}]}):
+ pass
+
+ assert model.ainvoke.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_query_loop_clear_resets_turn_state_but_preserves_accumulators():
+ model = mock_model_no_tools("after clear")
+ checkpointer = _MemoryCheckpointer()
+ app_state = AppState(total_cost=1.25, tool_overrides={"Bash": False})
+ bootstrap = BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model")
+ loop = make_loop(
+ model=model,
+ checkpointer=checkpointer,
+ app_state=app_state,
+ bootstrap=bootstrap,
+ )
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "hi"}]},
+ config={"configurable": {"thread_id": "clear-thread"}},
+ ):
+ pass
+
+ loop._tool_read_file_state["/tmp/file.py"] = {"partial": False}
+ loop._tool_loaded_nested_memory_paths.add("/tmp/memory.md")
+ loop._tool_discovered_skill_names.add("skill-a")
+ old_session_id = bootstrap.session_id
+
+ await loop.aclear("clear-thread")
+
+ assert checkpointer.store["clear-thread"]["channel_values"]["messages"] == []
+ assert app_state.messages == []
+ assert app_state.turn_count == 0
+ assert app_state.compact_boundary_index == 0
+ assert app_state.total_cost == 1.25
+ assert app_state.tool_overrides == {"Bash": False}
+ assert loop._tool_read_file_state == {}
+ assert loop._tool_loaded_nested_memory_paths == set()
+ assert loop._tool_discovered_skill_names == set()
+ assert bootstrap.session_id != old_session_id
+ assert bootstrap.parent_session_id == old_session_id
+
+
+@pytest.mark.asyncio
+async def test_query_loop_replays_messages_with_real_async_sqlite_saver():
+ db_path = Path(tempfile.mkdtemp()) / "checkpoints.db"
+ conn = await connect_sqlite_async(db_path)
+ saver = AsyncSqliteSaver(conn)
+ await saver.setup()
+
+ try:
+ model = mock_model_no_tools("persist me")
+ loop = make_loop(
+ model=model,
+ checkpointer=saver,
+ app_state=AppState(),
+ )
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "first"}]},
+ config={"configurable": {"thread_id": "persist-thread"}},
+ ):
+ pass
+
+ reloaded = await loop._load_messages("persist-thread")
+ assert [msg.content for msg in reloaded] == ["first", "persist me"]
+ finally:
+ await conn.close()
+
+
+@pytest.mark.asyncio
+async def test_query_loop_aclear_wipes_real_async_sqlite_saver_history():
+ db_path = Path(tempfile.mkdtemp()) / "checkpoints.db"
+ conn = await connect_sqlite_async(db_path)
+ saver = AsyncSqliteSaver(conn)
+ await saver.setup()
+
+ try:
+ model = mock_model_no_tools("persist me")
+ loop = make_loop(
+ model=model,
+ checkpointer=saver,
+ app_state=AppState(total_cost=1.25),
+ bootstrap=BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model", total_cost_usd=1.25),
+ )
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "first"}]},
+ config={"configurable": {"thread_id": "clear-real-thread"}},
+ ):
+ pass
+
+ assert [msg.content for msg in await loop._load_messages("clear-real-thread")] == ["first", "persist me"]
+
+ await loop.aclear("clear-real-thread")
+
+ assert await loop._load_messages("clear-real-thread") == []
+ assert loop._app_state is not None
+ assert loop._app_state.total_cost == 1.25
+ finally:
+ await conn.close()
+
+
+@pytest.mark.asyncio
+async def test_query_loop_aget_state_exposes_messages_for_backend_callers():
+ model = mock_model_no_tools("state me")
+ checkpointer = _MemoryCheckpointer()
+ loop = make_loop(
+ model=model,
+ checkpointer=checkpointer,
+ app_state=AppState(),
+ )
+ config = {"configurable": {"thread_id": "state-thread"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "hello"}]},
+ config=config,
+ ):
+ pass
+
+ state = await loop.aget_state(config)
+
+ assert state.values is not None
+ assert [msg.content for msg in state.values["messages"]] == ["hello", "state me"]
+
+
+@pytest.mark.asyncio
+async def test_query_loop_aget_state_exposes_persisted_permission_state_for_backend_callers():
+ checkpointer = _MemoryCheckpointer()
+ pending = {
+ "perm-1": {
+ "request_id": "perm-1",
+ "thread_id": "perm-thread",
+ "tool_name": "Write",
+ "args": {"path": "/tmp/a.txt"},
+ "message": "needs approval",
+ }
+ }
+ resolved = {
+ "perm-2": {
+ "request_id": "perm-2",
+ "thread_id": "perm-thread",
+ "tool_name": "Edit",
+ "args": {"path": "/tmp/b.txt"},
+ "decision": "allow",
+ "message": "approved",
+ }
+ }
+ loop = make_loop(
+ model=mock_model_no_tools("persist permissions"),
+ checkpointer=checkpointer,
+ app_state=AppState(
+ tool_permission_context=ToolPermissionState(
+ alwaysAllowRules={"session": ["Write"]},
+ alwaysDenyRules={"session": ["Bash"]},
+ alwaysAskRules={"session": ["Edit"]},
+ ),
+ pending_permission_requests=pending,
+ resolved_permission_requests=resolved,
+ ),
+ )
+ config = {"configurable": {"thread_id": "perm-thread"}}
+
+ await loop._save_messages("perm-thread", [HumanMessage(content="hello")])
+
+ reloaded = make_loop(
+ model=mock_model_no_tools("unused"),
+ checkpointer=checkpointer,
+ app_state=AppState(),
+ )
+
+ state = await reloaded.aget_state(config)
+
+ assert state.values["pending_permission_requests"] == pending
+ assert state.values["resolved_permission_requests"] == resolved
+ assert state.values["tool_permission_context"] == {
+ "alwaysAllowRules": {"session": ["Write"]},
+ "alwaysDenyRules": {"session": ["Bash"]},
+ "alwaysAskRules": {"session": ["Edit"]},
+ "allowManagedPermissionRulesOnly": False,
+ }
+
+
+@pytest.mark.asyncio
+async def test_query_loop_aget_state_uses_live_permission_state_while_active():
+ checkpointer = _MemoryCheckpointer()
+ app_state = AppState(
+ messages=[HumanMessage(content="live human")],
+ tool_permission_context=ToolPermissionState(alwaysAskRules={"session": ["Bash"]}),
+ pending_permission_requests={
+ "perm-live": {
+ "request_id": "perm-live",
+ "thread_id": "perm-thread",
+ "tool_name": "Bash",
+ "args": {"command": "echo hi"},
+ "message": "Permission required by rule: Bash",
+ }
+ },
+ )
+ loop = make_loop(
+ model=mock_model_no_tools("unused"),
+ checkpointer=checkpointer,
+ app_state=app_state,
+ runtime=SimpleNamespace(current_state=AgentState.ACTIVE),
+ )
+ config = {"configurable": {"thread_id": "perm-thread"}}
+
+ state = await loop.aget_state(config)
+
+ assert [msg.content for msg in state.values["messages"]] == ["live human"]
+ assert state.values["pending_permission_requests"] == {
+ "perm-live": {
+ "request_id": "perm-live",
+ "thread_id": "perm-thread",
+ "tool_name": "Bash",
+ "args": {"command": "echo hi"},
+ "message": "Permission required by rule: Bash",
+ }
+ }
+ assert state.values["tool_permission_context"] == {
+ "alwaysAllowRules": {},
+ "alwaysDenyRules": {},
+ "alwaysAskRules": {"session": ["Bash"]},
+ "allowManagedPermissionRulesOnly": False,
+ }
+
+
+@pytest.mark.asyncio
+async def test_query_loop_restores_persisted_permission_state_into_live_app_state():
+ checkpointer = _MemoryCheckpointer()
+ pending = {
+ "perm-1": {
+ "request_id": "perm-1",
+ "thread_id": "perm-thread",
+ "tool_name": "Write",
+ "args": {"path": "/tmp/a.txt"},
+ "message": "needs approval",
+ }
+ }
+ resolved = {
+ "perm-2": {
+ "request_id": "perm-2",
+ "thread_id": "perm-thread",
+ "tool_name": "Edit",
+ "args": {"path": "/tmp/b.txt"},
+ "decision": "allow",
+ "message": "approved",
+ }
+ }
+ seed_loop = make_loop(
+ model=mock_model_no_tools("seed"),
+ checkpointer=checkpointer,
+ app_state=AppState(
+ tool_permission_context=ToolPermissionState(
+ alwaysAllowRules={"session": ["Write"]},
+ alwaysDenyRules={"session": ["Bash"]},
+ alwaysAskRules={"session": ["Edit"]},
+ ),
+ pending_permission_requests=pending,
+ resolved_permission_requests=resolved,
+ ),
+ )
+ await seed_loop._save_messages("perm-thread", [HumanMessage(content="existing")])
+
+ app_state = AppState()
+ reloaded = make_loop(
+ model=mock_model_no_tools("after restore"),
+ checkpointer=checkpointer,
+ app_state=app_state,
+ )
+
+ async for _ in reloaded.query(
+ {"messages": [{"role": "user", "content": "continue"}]},
+ config={"configurable": {"thread_id": "perm-thread"}},
+ ):
+ pass
+
+ assert app_state.pending_permission_requests == pending
+ assert app_state.resolved_permission_requests == resolved
+ assert app_state.tool_permission_context.alwaysAllowRules == {"session": ["Write"]}
+ assert app_state.tool_permission_context.alwaysDenyRules == {"session": ["Bash"]}
+ assert app_state.tool_permission_context.alwaysAskRules == {"session": ["Edit"]}
+
+
+@pytest.mark.asyncio
+async def test_query_loop_persists_cleared_permission_state_after_resolution_consumed():
+ checkpointer = _MemoryCheckpointer()
+ request_id = "perm-ask"
+ thread_id = "perm-thread"
+ args = {
+ "questions": [
+ {
+ "header": "Choice",
+ "question": "Pick one.",
+ "multiSelect": False,
+ "options": [{"label": "Alpha", "description": "Alpha"}],
+ }
+ ]
+ }
+ app_state = AppState(
+ messages=[HumanMessage(content="existing")],
+ pending_permission_requests={
+ request_id: {
+ "request_id": request_id,
+ "thread_id": thread_id,
+ "tool_name": "AskUserQuestion",
+ "args": args,
+ "message": "Answer questions?",
+ }
+ },
+ )
+ loop = make_loop(
+ model=mock_model_no_tools("seed"),
+ checkpointer=checkpointer,
+ app_state=app_state,
+ )
+
+ resolved_payload = {
+ "request_id": request_id,
+ "thread_id": thread_id,
+ "tool_name": "AskUserQuestion",
+ "args": args,
+ "decision": "allow",
+ "message": "Answer questions?",
+ "answers": [
+ {
+ "header": "Choice",
+ "question": "Pick one.",
+ "selected_options": ["Alpha"],
+ }
+ ],
+ }
+ app_state.set_state(
+ lambda prev: prev.model_copy(
+ update={
+ "pending_permission_requests": {},
+ "resolved_permission_requests": {request_id: resolved_payload},
+ }
+ )
+ )
+
+ await loop.apersist_state(thread_id)
+ persisted = await loop._load_checkpoint_channel_values(thread_id)
+ assert persisted["pending_permission_requests"] == {}
+ assert persisted["resolved_permission_requests"] == {request_id: resolved_payload}
+
+ ctx = loop._build_tool_use_context([], thread_id=thread_id)
+ assert ctx is not None
+ assert ctx.consume_permission_resolution("AskUserQuestion", args, None, None) == {
+ "decision": "allow",
+ "message": "Answer questions?",
+ }
+ assert app_state.pending_permission_requests == {}
+ assert app_state.resolved_permission_requests == {}
+
+ await loop.apersist_state(thread_id)
+ persisted = await loop._load_checkpoint_channel_values(thread_id)
+ assert persisted["pending_permission_requests"] == {}
+ assert persisted["resolved_permission_requests"] == {}
+
+
+@pytest.mark.asyncio
+async def test_query_loop_aupdate_state_appends_start_messages_for_resume():
+ model = mock_model_no_tools("after resume")
+ checkpointer = _MemoryCheckpointer()
+ loop = make_loop(
+ model=model,
+ checkpointer=checkpointer,
+ app_state=AppState(),
+ )
+ config = {"configurable": {"thread_id": "resume-thread"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "first"}]},
+ config=config,
+ ):
+ pass
+
+ await loop.aupdate_state(
+ config,
+ {"messages": [HumanMessage(content="second")]},
+ as_node="__start__",
+ )
+
+ state = await loop.aget_state(config)
+ assert [msg.content for msg in state.values["messages"]] == ["first", "after resume", "second"]
+
+
+@pytest.mark.asyncio
+async def test_query_loop_aupdate_state_applies_remove_and_insert_message_repairs():
+ checkpointer = _MemoryCheckpointer()
+ broken_ai = AIMessage(
+ content="",
+ tool_calls=[{"name": "Read", "args": {"file_path": "/tmp/a.txt"}, "id": "tc-1"}],
+ )
+ tool_reply = ToolMessage(content="old", tool_call_id="tc-1", name="Read")
+ trailing = HumanMessage(content="after tool")
+ tool_reply.id = "tool-old"
+ trailing.id = "human-after"
+ checkpointer.store["repair-thread"] = {"channel_values": {"messages": [broken_ai, tool_reply, trailing]}}
+
+ loop = make_loop(
+ model=mock_model_no_tools("unused"),
+ checkpointer=checkpointer,
+ app_state=AppState(),
+ )
+ config = {"configurable": {"thread_id": "repair-thread"}}
+
+ await loop.aupdate_state(
+ config,
+ {
+ "messages": [
+ RemoveMessage(id="tool-old"),
+ RemoveMessage(id="human-after"),
+ ToolMessage(content="repaired", tool_call_id="tc-1", name="Read"),
+ HumanMessage(content="after tool"),
+ ]
+ },
+ )
+
+ state = await loop.aget_state(config)
+ contents = [getattr(msg, "content", None) for msg in state.values["messages"]]
+ assert contents == ["", "repaired", "after tool"]
+
+
+@pytest.mark.asyncio
+async def test_query_loop_astream_none_resumes_after_state_injection():
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(
+ side_effect=[
+ AIMessage(content="first answer"),
+ AIMessage(content="resumed answer"),
+ ]
+ )
+ checkpointer = _MemoryCheckpointer()
+ loop = QueryLoop(
+ model=model,
+ system_prompt=SystemMessage(content="You are a test assistant."),
+ middleware=[],
+ checkpointer=checkpointer,
+ registry=make_registry(),
+ app_state=AppState(),
+ runtime=None,
+ bootstrap=BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model"),
+ max_turns=10,
+ )
+ config = {"configurable": {"thread_id": "resume-stream-thread"}}
+
+ async for _ in loop.query(
+ {"messages": [{"role": "user", "content": "first"}]},
+ config=config,
+ ):
+ pass
+
+ await loop.aupdate_state(
+ config,
+ {"messages": [HumanMessage(content="followup")]},
+ as_node="__start__",
+ )
+
+ events = []
+ async for event in loop.astream(None, config=config):
+ events.append(event)
+
+ assert any(msg.content == "resumed answer" for event in events for msg in event.get("agent", {}).get("messages", []))
+
+
+@pytest.mark.asyncio
+async def test_query_loop_aclear_deletes_persisted_summary_for_thread():
+ db_path = Path(tempfile.mkdtemp()) / "memory.db"
+ mm = MemoryMiddleware(db_path=db_path)
+ mm.summary_store.save_summary(
+ thread_id="clear-summary-thread",
+ summary_text="STALE SUMMARY",
+ compact_up_to_index=2,
+ compacted_at=2,
+ )
+
+ loop = QueryLoop(
+ model=mock_model_no_tools("done"),
+ system_prompt=SystemMessage(content="You are a test assistant."),
+ middleware=[mm],
+ checkpointer=None,
+ registry=make_registry(),
+ app_state=AppState(total_cost=1.25),
+ runtime=None,
+ bootstrap=BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model", total_cost_usd=1.25),
+ max_turns=10,
+ )
+
+ await loop.aclear("clear-summary-thread")
+
+ assert mm.summary_store.get_latest_summary("clear-summary-thread") is None
+
+
+# ---------------------------------------------------------------------------
+# Tests: with tool calls → agent chunk + tools chunk
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_tool_call_yields_agent_then_tools():
+ model = mock_model_with_tool_call()
+
+ # Register a simple echo tool
+ def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "echo", "parameters": {"type": "object", "properties": {}}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ registry = make_registry(entry)
+ loop = make_loop(model, registry=registry)
+
+ chunks = []
+ async for chunk in loop.astream({"messages": [{"role": "user", "content": "call echo"}]}):
+ chunks.append(chunk)
+
+ # First chunk: agent (with tool_calls)
+ # Second chunk: tools (ToolMessage results)
+ # Third chunk: agent (final text response)
+ agent_chunks = [c for c in chunks if "agent" in c]
+ tools_chunks = [c for c in chunks if "tools" in c]
+
+ assert len(agent_chunks) >= 1
+ assert len(tools_chunks) >= 1
+
+ # Tool result should be a ToolMessage
+ tool_msgs = tools_chunks[0]["tools"]["messages"]
+ assert len(tool_msgs) == 1
+ assert isinstance(tool_msgs[0], ToolMessage)
+
+
+@pytest.mark.asyncio
+async def test_tool_call_result_content():
+ model = mock_model_with_tool_call(tool_name="echo", args={"message": "test-val"})
+
+ def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "d", "parameters": {}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=False,
+ )
+ loop = make_loop(model, registry=make_registry(entry))
+
+ tool_results = []
+ async for chunk in loop.astream({"messages": [{"role": "user", "content": "x"}]}):
+ if "tools" in chunk:
+ tool_results.extend(chunk["tools"]["messages"])
+
+ assert len(tool_results) == 1
+ assert "echo: test-val" in tool_results[0].content
+
+
+def test_tool_concurrency_safety_does_not_infer_from_read_only():
+ entry = ToolEntry(
+ name="readonly_serial",
+ mode=ToolMode.INLINE,
+ schema={"name": "readonly_serial", "description": "d", "parameters": {}},
+ handler=lambda: "ok",
+ source="test",
+ is_read_only=True,
+ is_concurrency_safe=False,
+ )
+ loop = make_loop(mock_model_no_tools(), registry=make_registry(entry))
+
+ assert loop._tool_is_concurrency_safe({"name": "readonly_serial", "args": {}}) is False
+
+
+# ---------------------------------------------------------------------------
+# Tests: max_turns guard
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_max_turns_stops_loop():
+ """Agent that hits max_turns should fail loudly on the caller-facing astream surface."""
+
+ def noop_handler() -> str:
+ return "ok"
+
+ entry = ToolEntry(
+ name="noop",
+ mode=ToolMode.INLINE,
+ schema={"name": "noop", "description": "d", "parameters": {}},
+ handler=noop_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+
+ # Build a model that always returns a tool call
+ tool_call_msg = AIMessage(
+ content="",
+ tool_calls=[{"name": "noop", "args": {}, "id": "tc-1"}],
+ )
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(return_value=tool_call_msg)
+
+ loop = make_loop(model, registry=make_registry(entry), max_turns=3)
+
+ with pytest.raises(RuntimeError, match="max_turns"):
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "go"}]}):
+ pass
+
+ assert model.ainvoke.call_count == 3
+
+
+# ---------------------------------------------------------------------------
+# Tests: input parsing
+# ---------------------------------------------------------------------------
+
+
+def test_parse_input_dict_messages():
+ msgs = QueryLoop._parse_input({"messages": [{"role": "user", "content": "hello"}]})
+ assert len(msgs) == 1
+ assert isinstance(msgs[0], HumanMessage)
+ assert msgs[0].content == "hello"
+
+
+def test_parse_input_langchain_messages():
+ human = HumanMessage(content="hi")
+ msgs = QueryLoop._parse_input({"messages": [human]})
+ assert msgs[0] is human
+
+
+def test_parse_input_empty():
+ assert QueryLoop._parse_input({}) == []
+ assert QueryLoop._parse_input({"messages": []}) == []
+
+
+@pytest.mark.asyncio
+async def test_query_loop_syncs_app_state_on_completion():
+ model = mock_model_no_tools("AppState wired")
+ app_state = AppState(compact_boundary_index=99)
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=1.25))
+
+ async for _ in loop.query({"messages": [{"role": "user", "content": "sync"}]}):
+ pass
+
+ assert app_state.turn_count == 1
+ assert app_state.total_cost == 1.25
+ assert app_state.compact_boundary_index == 0
+ assert len(app_state.messages) == 2
+ assert app_state.messages[0].content == "sync"
+ assert app_state.messages[1].content == "AppState wired"
+
+
+@pytest.mark.asyncio
+async def test_query_loop_does_not_decrease_total_cost_when_runtime_reports_less():
+ model = mock_model_no_tools("cost stays monotonic")
+ app_state = AppState(total_cost=1.25)
+ bootstrap = BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model", total_cost_usd=1.25)
+ loop = QueryLoop(
+ model=model,
+ system_prompt=SystemMessage(content="You are a test assistant."),
+ middleware=[],
+ checkpointer=None,
+ registry=make_registry(),
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ bootstrap=bootstrap,
+ max_turns=10,
+ )
+
+ async for _ in loop.query({"messages": [{"role": "user", "content": "sync"}]}):
+ pass
+
+ assert app_state.total_cost == 1.25
+ assert bootstrap.total_cost_usd == 1.25
+
+
+@pytest.mark.asyncio
+async def test_query_loop_resets_dirty_app_state_turn_count_between_runs():
+ model = mock_model_no_tools("fresh")
+ app_state = AppState(turn_count=99, compact_boundary_index=7)
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ first = await loop.ainvoke({"messages": [{"role": "user", "content": "hi"}]})
+ second = await loop.ainvoke({"messages": [{"role": "user", "content": "again"}]})
+
+ assert first["reason"] == "completed"
+ assert second["reason"] == "completed"
+ assert app_state.turn_count == 1
+ assert app_state.compact_boundary_index == 0
+ assert len(app_state.messages) == 2
+
+
+@pytest.mark.asyncio
+async def test_query_loop_refreshes_tools_between_tool_turns():
+ events: list[str] = []
+
+ async def refresh_tools() -> None:
+ events.append("refresh")
+
+ def echo_handler(message: str) -> str:
+ events.append("tool")
+ return f"echo: {message}"
+
+ tool_call_msg = AIMessage(
+ content="",
+ tool_calls=[{"name": "echo", "args": {"message": "hi"}, "id": "tc-1"}],
+ )
+ final_msg = AIMessage(content="done")
+ model = MagicMock()
+ model.bind_tools.return_value = model
+
+ async def ainvoke_side_effect(*args, **kwargs):
+ if not events:
+ events.append("model-1")
+ return tool_call_msg
+ assert events == ["model-1", "tool", "refresh"]
+ events.append("model-2")
+ return final_msg
+
+ model.ainvoke = AsyncMock(side_effect=ainvoke_side_effect)
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "echo", "parameters": {"type": "object", "properties": {}}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(model, registry=make_registry(entry))
+ loop._refresh_tools = refresh_tools
+
+ async for _ in loop.query({"messages": [{"role": "user", "content": "call echo"}]}):
+ pass
+
+ assert events == ["model-1", "tool", "refresh", "model-2"]
+
+
+@pytest.mark.asyncio
+async def test_streaming_overlap_snapshots_reused_live_chunks_before_final_aggregation():
+ class ReusedChunkModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ chunk = AIMessageChunk(
+ content="",
+ response_metadata={"model_provider": "openai"},
+ id="shared-chunk",
+ tool_calls=[],
+ invalid_tool_calls=[],
+ tool_call_chunks=[],
+ )
+ yield chunk
+ chunk.content = "HEL"
+ yield chunk
+ chunk.content = "LO"
+ yield chunk
+ chunk.content = ""
+ chunk.usage_metadata = {"input_tokens": 10, "output_tokens": 2, "total_tokens": 12}
+ yield chunk
+ chunk.chunk_position = "last"
+ yield chunk
+
+ loop = make_loop(ReusedChunkModel())
+
+ agent_messages = []
+ async for event in loop.query({"messages": [{"role": "user", "content": "hi"}]}):
+ if "agent" in event:
+ agent_messages.extend(event["agent"]["messages"])
+
+ assert len(agent_messages) == 1
+ assert agent_messages[0].content == "HELLO"
+ assert agent_messages[0].usage_metadata == {
+ "input_tokens": 10,
+ "output_tokens": 2,
+ "total_tokens": 12,
+ }
+
+
+class _CaptureToolContextMiddleware:
+ def __init__(self):
+ self.messages = None
+ self.boundary = None
+
+ async def awrap_tool_call(self, request, handler):
+ self.messages = list(request.state.messages)
+ self.boundary = request.state.get_app_state().compact_boundary_index
+ return await handler(request)
+
+
+@pytest.mark.asyncio
+async def test_query_loop_syncs_tool_context_messages_to_query_time_array():
+ capture = _CaptureToolContextMiddleware()
+ model = mock_model_with_tool_call(tool_name="echo", args={"message": "ctx"}, then_text="done")
+
+ def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = make_inline_tool("echo", echo_handler)
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ middleware=[capture],
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.query({"messages": [{"role": "user", "content": "call echo"}]}):
+ pass
+
+ assert capture.messages is not None
+ assert len(capture.messages) == 1
+ assert capture.messages[0].content == "call echo"
+
+
+class _SummaryBoundaryMiddleware:
+ def __init__(self, boundary_index: int):
+ self.boundary_index = boundary_index
+ self.compact_boundary_index = boundary_index
+
+ async def awrap_model_call(self, request, handler):
+ rewritten = [SystemMessage(content="summary")] + list(request.messages[self.boundary_index :])
+ return await handler(request.override(messages=rewritten))
+
+
+class _ReactiveCompactMiddleware:
+ compact_boundary_index = 2
+
+ async def compact_messages_for_recovery(self, messages):
+ return [SystemMessage(content="[Conversation Summary]\nSUMMARY")] + list(messages[-1:])
+
+
+class _CollapseDrainMiddleware:
+ def __init__(self):
+ self.calls = 0
+
+ async def recover_from_overflow(self, messages):
+ self.calls += 1
+ return {
+ "committed": 1,
+ "messages": [SystemMessage(content="[Collapsed Context]\nDRAINED")] + list(messages[-1:]),
+ }
+
+
+class _EscalationModel:
+ def __init__(self):
+ self.max_tokens_values = []
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ def bind(self, **kwargs):
+ self.max_tokens_values.append(kwargs.get("max_tokens"))
+ return self
+
+ async def ainvoke(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ raise RuntimeError("max_output_tokens")
+ return AIMessage(content="after escalate")
+
+
+class _EscalationThenRecoveryModel:
+ def __init__(self):
+ self.max_tokens_values = []
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ def bind(self, **kwargs):
+ self.max_tokens_values.append(kwargs.get("max_tokens"))
+ return self
+
+ async def ainvoke(self, messages):
+ self.calls += 1
+ if self.calls in (1, 2):
+ raise RuntimeError("max_output_tokens")
+ return AIMessage(content="after recovery")
+
+
+class _ContextOverflowModel:
+ def __init__(self):
+ self.calls = 0
+ self.max_tokens_values = []
+
+ def bind_tools(self, tools):
+ return self
+
+ def bind(self, **kwargs):
+ self.max_tokens_values.append(kwargs.get("max_tokens"))
+ return self
+
+ async def ainvoke(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ raise RuntimeError("input length and `max_tokens` exceed context limit: 188059 + 20000 > 200000")
+ return AIMessage(content="after parsed overflow")
+
+
+class _TransientAPIError(Exception):
+ def __init__(self, status: int, message: str, headers: dict[str, str] | None = None):
+ super().__init__(message)
+ self.status = status
+ self.headers = headers or {}
+
+
+class _RetryOnceModel:
+ def __init__(self, status: int, headers: dict[str, str] | None = None):
+ self.calls = 0
+ self.status = status
+ self.headers = headers or {}
+
+ def bind_tools(self, tools):
+ return self
+
+ async def ainvoke(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ raise _TransientAPIError(self.status, f"transient {self.status}", self.headers)
+ return AIMessage(content=f"after retry {self.status}")
+
+
+class _EmptyStreamModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ if False:
+ yield AIMessageChunk(content="")
+
+
+class _TruncatedResponseModel:
+ def __init__(self, responses):
+ self.responses = list(responses)
+ self.calls = 0
+ self.max_tokens_values = []
+
+ def bind_tools(self, tools):
+ return self
+
+ def bind(self, **kwargs):
+ self.max_tokens_values.append(kwargs.get("max_tokens"))
+ return self
+
+ async def ainvoke(self, messages):
+ response = self.responses[self.calls]
+ self.calls += 1
+ return response
+
+
+class _QueryOkWithFailingCompactorModel:
+ def __init__(self):
+ self.query_calls = 0
+ self.compact_calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ def bind(self, **kwargs):
+ return self
+
+ async def ainvoke(self, messages):
+ system_text = ""
+ if messages and messages[0].__class__.__name__ == "SystemMessage":
+ system_text = getattr(messages[0], "content", "") or ""
+ if "tasked with summarizing conversations" in system_text or "split turn" in system_text.lower():
+ self.compact_calls += 1
+ raise RuntimeError("compaction failed")
+ self.query_calls += 1
+ return AIMessage(content="OK")
+
+
+class _StreamingToolModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(content="thinking")
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "echo", "args": '{"message":"hi"}', "id": "tc-1", "index": 0}],
+ )
+ await asyncio.sleep(0.05)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _SplitArgsStreamingToolModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "Read", "args": "", "id": "tc-read", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": None, "args": '{"file_path":"/tmp/a.txt"}', "id": "tc-read", "index": 0}],
+ )
+ await asyncio.sleep(0.01)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _SplitStringValueStreamingToolModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "Read", "args": '{"file_path":"/', "id": "tc-read", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": None, "args": 'tmp/a.txt"}', "id": "tc-read", "index": 0}],
+ )
+ await asyncio.sleep(0.01)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _SplitAnyOfStreamingToolModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "read_messages", "args": "", "id": "tc-chat-read", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": None, "args": '{"chat_id":"chat-1"}', "id": "tc-chat-read", "index": 0}],
+ )
+ await asyncio.sleep(0.01)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _TwoToolStreamingModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "unsafe", "args": '{"message":"u"}', "id": "tc-unsafe", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "safe", "args": '{"message":"s"}', "id": "tc-safe", "index": 1}],
+ )
+ await asyncio.sleep(0.05)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _FailingStreamingToolModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "echo", "args": '{"message":"boom"}', "id": "tc-1", "index": 0}],
+ )
+ await asyncio.sleep(0.005)
+ raise RuntimeError("stream exploded")
+
+
+class _FailingQueuedStreamingToolModel:
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "unsafe", "args": '{"message":"u"}', "id": "tc-unsafe", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "safe", "args": '{"message":"s"}', "id": "tc-safe", "index": 1}],
+ )
+ await asyncio.sleep(0.005)
+ raise RuntimeError("stream exploded")
+
+
+class _ToolThenFinalStreamingModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "echo", "args": '{"message":"boom"}', "id": "tc-1", "index": 0}],
+ )
+ await asyncio.sleep(0.01)
+ yield AIMessageChunk(content="tool turn")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _UnsafeThenSafeGapStreamingModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "unsafe", "args": '{"message":"u"}', "id": "tc-unsafe", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "safe", "args": '{"message":"s"}', "id": "tc-safe", "index": 1}],
+ )
+ await asyncio.sleep(0.08)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _BashAndSafeStreamingModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "bash", "args": '{"command":"boom"}', "id": "tc-bash", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "safe", "args": '{"message":"s"}', "id": "tc-safe", "index": 1}],
+ )
+ await asyncio.sleep(0.05)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+
+class _ExplodingToolMiddleware:
+ async def awrap_tool_call(self, request, handler):
+ raise RuntimeError("middleware boom")
+
+
+@pytest.mark.asyncio
+async def test_query_loop_does_not_double_apply_compact_boundary_before_memory_middleware():
+ capture = _CaptureToolContextMiddleware()
+ memory = _SummaryBoundaryMiddleware(boundary_index=3)
+ model = mock_model_with_tool_call(tool_name="echo", args={"message": "ctx"}, then_text="done")
+
+ def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = make_inline_tool("echo", echo_handler)
+ history = [
+ HumanMessage(content="h0"),
+ AIMessage(content="a1"),
+ HumanMessage(content="h2"),
+ HumanMessage(content="call echo"),
+ ]
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ middleware=[memory, capture],
+ app_state=AppState(compact_boundary_index=3),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.query({"messages": history}):
+ pass
+
+ assert capture.messages is not None
+ assert len(capture.messages) == 2
+ assert capture.messages[0].content == "summary"
+ assert capture.messages[1].content == "call echo"
+
+
+@pytest.mark.asyncio
+async def test_query_loop_syncs_compact_boundary_index_from_memory_middleware():
+ memory = _SummaryBoundaryMiddleware(boundary_index=3)
+ model = mock_model_no_tools("done")
+ app_state = AppState()
+ loop = make_loop(
+ model,
+ middleware=[memory],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.query({"messages": [{"role": "user", "content": "hello"}]}):
+ pass
+
+ assert app_state.compact_boundary_index == 3
+
+
+@pytest.mark.asyncio
+async def test_query_loop_syncs_tool_context_after_real_memory_compaction():
+ capture = _CaptureToolContextMiddleware()
+ memory, _summary_model = _make_summary_memory_middleware()
+
+ model = mock_model_with_tool_call(tool_name="echo", args={"message": "ctx"}, then_text="done")
+
+ def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = make_inline_tool("echo", echo_handler)
+
+ history = [
+ HumanMessage(content="A" * 80),
+ AIMessage(content="B" * 80),
+ HumanMessage(content="C" * 80),
+ HumanMessage(content="call echo"),
+ ]
+ app_state = AppState()
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ middleware=[memory, capture],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.query({"messages": history}):
+ pass
+
+ assert capture.messages is not None
+ assert isinstance(capture.messages[0], SystemMessage)
+ assert "Conversation Summary" in capture.messages[0].content
+ assert capture.messages[-1].content == "call echo"
+ assert app_state.compact_boundary_index > 0
+
+
+@pytest.mark.asyncio
+async def test_query_loop_syncs_compact_boundary_before_tool_execution():
+ capture = _CaptureToolContextMiddleware()
+ memory, _summary_model = _make_summary_memory_middleware()
+
+ model = mock_model_with_tool_call(tool_name="echo", args={"message": "ctx"}, then_text="done")
+
+ def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "echo", "parameters": {}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+
+ history = [
+ HumanMessage(content="A" * 80),
+ AIMessage(content="B" * 80),
+ HumanMessage(content="C" * 80),
+ HumanMessage(content="call echo"),
+ ]
+ app_state = AppState()
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ middleware=[memory, capture],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.query({"messages": history}):
+ pass
+
+ assert capture.messages is not None
+ assert capture.boundary == app_state.compact_boundary_index
+ assert capture.boundary > 0
+
+
+@pytest.mark.asyncio
+async def test_query_loop_persists_compaction_notice_when_boundary_advances():
+ memory, _summary_model = _make_summary_memory_middleware()
+
+ app_state = AppState()
+ loop = make_loop(
+ mock_model_no_tools("after compact"),
+ middleware=[memory],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ history = [
+ HumanMessage(content="A" * 80),
+ AIMessage(content="B" * 80),
+ HumanMessage(content="C" * 80),
+ HumanMessage(content="hello after compact"),
+ ]
+
+ async for _ in loop.query({"messages": history}):
+ pass
+
+ compact_notices = [
+ msg
+ for msg in app_state.messages
+ if msg.__class__.__name__ == "HumanMessage" and ((getattr(msg, "metadata", None) or {}).get("notification_type") == "compact")
+ ]
+
+ assert len(compact_notices) == 1
+ assert "Conversation compacted" in compact_notices[0].content
+ assert compact_notices[0].metadata["source"] == "system"
+ assert compact_notices[0].metadata["compact_boundary_index"] == app_state.compact_boundary_index
+ assert app_state.compact_boundary_index > 0
+
+
+@pytest.mark.asyncio
+async def test_memory_middleware_emits_runtime_compaction_notice():
+ memory, _summary_model = _make_summary_memory_middleware()
+ runtime = SimpleNamespace(cost=0.0, events=[], set_flag=lambda *_args, **_kwargs: None)
+ runtime.emit_activity_event = lambda event: runtime.events.append(event)
+ memory.set_runtime(runtime)
+
+ loop = make_loop(
+ mock_model_no_tools("after compact"),
+ middleware=[memory],
+ app_state=AppState(),
+ runtime=runtime,
+ )
+
+ history = [
+ HumanMessage(content="A" * 80),
+ AIMessage(content="B" * 80),
+ HumanMessage(content="C" * 80),
+ HumanMessage(content="hello after compact"),
+ ]
+
+ async for _ in loop.query({"messages": history}):
+ pass
+
+ compact_events = [event for event in runtime.events if event.get("event") == "notice"]
+
+ assert len(compact_events) == 1
+ payload = json.loads(compact_events[0]["data"])
+ assert payload["notification_type"] == "compact"
+ assert "Conversation compacted" in payload["content"]
+
+
+@pytest.mark.asyncio
+async def test_query_loop_recovers_from_max_output_tokens_with_explicit_continuation():
+ model = _EscalationThenRecoveryModel()
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["transition"].reason.value == "max_output_tokens_recovery"
+ assert model.calls == 3
+ assert model.max_tokens_values == [64000, 64000]
+ assert any(
+ getattr(msg, "content", "") == "Output token limit hit. Resume directly with no apology or recap." for msg in app_state.messages
+ )
+
+
+@pytest.mark.asyncio
+async def test_query_loop_escalates_max_output_tokens_before_continuation_recovery():
+ model = _EscalationModel()
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["transition"].reason.value == "max_output_tokens_escalate"
+ assert model.max_tokens_values == [64000]
+
+
+@pytest.mark.asyncio
+async def test_query_loop_parses_context_overflow_error_into_targeted_max_tokens_override():
+ model = _ContextOverflowModel()
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["messages"][-1].content == "after parsed overflow"
+ assert model.max_tokens_values == [10941]
+
+
+@pytest.mark.asyncio
+async def test_query_loop_retries_once_after_529_capacity_error():
+ model = _RetryOnceModel(529)
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["messages"][-1].content == "after retry 529"
+ assert model.calls == 2
+
+
+@pytest.mark.asyncio
+async def test_query_loop_retries_once_after_429_rate_limit_error():
+ model = _RetryOnceModel(429, headers={"retry-after": "0"})
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["messages"][-1].content == "after retry 429"
+ assert model.calls == 2
+
+
+@pytest.mark.asyncio
+async def test_query_loop_astream_raises_loudly_on_empty_stream():
+ loop = make_loop(_EmptyStreamModel(), app_state=AppState(), runtime=SimpleNamespace(cost=0.0))
+
+ with pytest.raises(RuntimeError, match="streaming model returned no AIMessageChunk"):
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode=["messages", "updates"]):
+ pass
+
+
+@pytest.mark.asyncio
+async def test_query_loop_detects_truncated_response_and_escalates_without_yielding_partial():
+ model = _TruncatedResponseModel(
+ [
+ AIMessage(content="partial", response_metadata={"finish_reason": "length"}),
+ AIMessage(content="after escalate"),
+ ]
+ )
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["transition"].reason.value == "max_output_tokens_escalate"
+ assert [msg.content for msg in result["messages"]] == ["after escalate"]
+ assert model.max_tokens_values == [64000]
+
+
+@pytest.mark.asyncio
+async def test_query_loop_recovers_from_truncated_response_with_withheld_message_pattern():
+ model = _TruncatedResponseModel(
+ [
+ AIMessage(content="partial-1", response_metadata={"finish_reason": "length"}),
+ AIMessage(content="partial-2", response_metadata={"stop_reason": "max_tokens"}),
+ AIMessage(content="after recovery"),
+ ]
+ )
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["transition"].reason.value == "max_output_tokens_recovery"
+ assert any(getattr(msg, "content", "") == "partial-2" for msg in app_state.messages)
+ assert any(
+ getattr(msg, "content", "") == "Output token limit hit. Resume directly with no apology or recap." for msg in app_state.messages
+ )
+
+
+@pytest.mark.asyncio
+async def test_query_loop_surfaces_withheld_truncated_message_after_recovery_exhausts():
+ model = _TruncatedResponseModel(
+ [
+ AIMessage(content="partial-1", response_metadata={"finish_reason": "length"}),
+ AIMessage(content="partial-2", response_metadata={"finish_reason": "length"}),
+ AIMessage(content="partial-3", response_metadata={"finish_reason": "length"}),
+ AIMessage(content="partial-4", response_metadata={"finish_reason": "length"}),
+ AIMessage(content="partial-5", response_metadata={"finish_reason": "length"}),
+ ]
+ )
+ app_state = AppState()
+ loop = make_loop(model, app_state=app_state, runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "model_error"
+ assert result["messages"][-1].content == "partial-5"
+
+
+@pytest.mark.asyncio
+async def test_query_loop_retries_prompt_too_long_via_reactive_compact():
+ model = _make_prompt_too_long_model(
+ RuntimeError("prompt is too long"),
+ AIMessage(content="after compact"),
+ )
+ app_state = AppState()
+ loop = make_loop(
+ model,
+ middleware=[_ReactiveCompactMiddleware()],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["transition"].reason.value == "reactive_compact_retry"
+ assert model.ainvoke.call_count == 2
+ assert isinstance(app_state.messages[0], SystemMessage)
+ assert "Conversation Summary" in app_state.messages[0].content
+
+
+@pytest.mark.asyncio
+async def test_handle_model_error_recovery_returns_typed_result_object():
+ loop = make_loop(mock_model_no_tools(), app_state=AppState(), runtime=SimpleNamespace(cost=0.0))
+
+ result = await loop._handle_model_error_recovery(
+ exc=RuntimeError("max_output_tokens exceeded"),
+ thread_id="thread-a",
+ messages=[HumanMessage(content="start")],
+ turn=1,
+ transition=None,
+ max_output_tokens_recovery_count=0,
+ has_attempted_reactive_compact=False,
+ max_output_tokens_override=None,
+ transient_api_retry_count=0,
+ )
+
+ assert result is not None
+ assert not isinstance(result, dict)
+ assert result.transition.reason.value == "max_output_tokens_escalate"
+ assert result.max_output_tokens_override == 64000
+
+
+@pytest.mark.asyncio
+async def test_query_loop_retries_prompt_too_long_via_collapse_drain_before_compact():
+ collapse = _CollapseDrainMiddleware()
+ model = _make_prompt_too_long_model(
+ RuntimeError("prompt is too long"),
+ AIMessage(content="after drain"),
+ )
+ app_state = AppState()
+ loop = make_loop(
+ model,
+ middleware=[collapse],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["transition"].reason.value == "collapse_drain_retry"
+ assert collapse.calls == 1
+ assert model.ainvoke.call_count == 2
+ assert isinstance(app_state.messages[0], SystemMessage)
+ assert "Collapsed Context" in app_state.messages[0].content
+
+
+@pytest.mark.asyncio
+async def test_query_loop_collapse_drain_is_single_shot_before_reactive_compact():
+ collapse = _CollapseDrainMiddleware()
+ model = _make_prompt_too_long_model(
+ RuntimeError("prompt is too long"),
+ RuntimeError("prompt is too long"),
+ AIMessage(content="after compact"),
+ )
+ app_state = AppState()
+ loop = make_loop(
+ model,
+ middleware=[collapse, _ReactiveCompactMiddleware()],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "completed"
+ assert result["transition"].reason.value == "reactive_compact_retry"
+ assert collapse.calls == 1
+ assert model.ainvoke.call_count == 3
+ assert isinstance(app_state.messages[0], SystemMessage)
+ assert "Conversation Summary" in app_state.messages[0].content
+
+
+@pytest.mark.asyncio
+async def test_query_loop_persists_prompt_too_long_notice_after_recovery_exhausts():
+ model = _make_prompt_too_long_model(
+ RuntimeError("prompt is too long"),
+ RuntimeError("prompt is too long"),
+ )
+ app_state = AppState()
+ loop = make_loop(
+ model,
+ middleware=[_ReactiveCompactMiddleware()],
+ app_state=app_state,
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "start"}]})
+
+ assert result["reason"] == "prompt_too_long"
+ notices = [
+ msg
+ for msg in app_state.messages
+ if msg.__class__.__name__ == "HumanMessage" and ((getattr(msg, "metadata", None) or {}).get("source") == "system")
+ ]
+ assert notices
+ assert notices[-1].content == "Prompt is too long. Automatic recovery exhausted. Clear the thread or start a new one."
+
+
+@pytest.mark.asyncio
+async def test_query_loop_astream_raises_prompt_too_long_notice_text_after_recovery_exhausts():
+ model = _make_prompt_too_long_model(
+ RuntimeError("prompt is too long"),
+ RuntimeError("prompt is too long"),
+ )
+ loop = make_loop(
+ model,
+ middleware=[_ReactiveCompactMiddleware()],
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ with pytest.raises(
+ RuntimeError,
+ match="Prompt is too long. Automatic recovery exhausted. Clear the thread or start a new one.",
+ ):
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "start"}]}, stream_mode=["updates"]):
+ pass
+
+
+@pytest.mark.asyncio
+async def test_query_loop_opens_and_clears_thread_scoped_compaction_breaker(tmp_path):
+ thread_id = "compact-breaker-thread"
+ checkpointer = _MemoryCheckpointer()
+ model = _QueryOkWithFailingCompactorModel()
+
+ def make_breaker_loop():
+ memory = MemoryMiddleware(
+ context_limit=10000,
+ compaction_threshold=0.5,
+ db_path=tmp_path / "compact-breaker.db",
+ compaction_config=SimpleNamespace(reserve_tokens=0, keep_recent_tokens=10),
+ )
+ memory.set_model(model)
+ return QueryLoop(
+ model=model,
+ system_prompt=SystemMessage(content="You are a test assistant."),
+ middleware=[memory],
+ checkpointer=checkpointer,
+ registry=make_registry(),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ bootstrap=BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model"),
+ max_turns=10,
+ )
+
+ loop = make_breaker_loop()
+ config = {"configurable": {"thread_id": thread_id}}
+
+ for attempt in range(1, 4):
+ result = await loop.ainvoke(
+ {
+ "messages": [
+ {"role": "user", "content": "A" * 8000},
+ {"role": "assistant", "content": "B" * 8000},
+ {"role": "user", "content": f"start {attempt} " + ("C" * 8000)},
+ ]
+ },
+ config=config,
+ )
+ assert result["reason"] == "completed"
+ assert model.compact_calls == attempt
+
+ state = await loop.aget_state(config)
+ breaker_notices = [
+ msg
+ for msg in state.values["messages"]
+ if msg.__class__.__name__ == "HumanMessage"
+ and ((getattr(msg, "metadata", None) or {}).get("notification_type") == "compact_breaker")
+ ]
+ assert len(breaker_notices) == 1
+ assert "Automatic compaction disabled for this thread after repeated failures." in breaker_notices[0].content
+
+ reloaded = make_breaker_loop()
+ result = await reloaded.ainvoke(
+ {
+ "messages": [
+ {"role": "user", "content": "A" * 8000},
+ {"role": "assistant", "content": "B" * 8000},
+ {"role": "user", "content": "after breaker " + ("C" * 8000)},
+ ]
+ },
+ config=config,
+ )
+ assert result["reason"] == "completed"
+ assert model.compact_calls == 3
+
+ await reloaded.aclear(thread_id)
+
+ post_clear = make_breaker_loop()
+ result = await post_clear.ainvoke(
+ {
+ "messages": [
+ {"role": "user", "content": "A" * 8000},
+ {"role": "assistant", "content": "B" * 8000},
+ {"role": "user", "content": "after clear " + ("C" * 8000)},
+ ]
+ },
+ config=config,
+ )
+ assert result["reason"] == "completed"
+ assert model.compact_calls == 4
+
+
+@pytest.mark.asyncio
+async def test_query_loop_can_emit_tool_results_before_final_agent_message():
+ model = _StreamingToolModel()
+
+ async def echo_handler(message: str) -> str:
+ await asyncio.sleep(0.01)
+ return f"echo: {message}"
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "echo", "parameters": {}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ event_order: list[str] = []
+ async for chunk in loop.astream({"messages": [{"role": "user", "content": "call echo"}]}):
+ if "tools" in chunk:
+ event_order.append("tools")
+ if "agent" in chunk:
+ event_order.append("agent")
+
+ assert "tools" in event_order
+ assert "agent" in event_order
+ assert event_order.index("tools") < event_order.index("agent")
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_blocks_safe_tool_behind_running_unsafe_tool():
+ model = _TwoToolStreamingModel()
+ starts: list[str] = []
+
+ async def unsafe_handler(message: str) -> str:
+ starts.append(f"start-unsafe-{message}")
+ await asyncio.sleep(0.03)
+ starts.append(f"end-unsafe-{message}")
+ return f"unsafe: {message}"
+
+ async def safe_handler(message: str) -> str:
+ starts.append(f"start-safe-{message}")
+ await asyncio.sleep(0.001)
+ starts.append(f"end-safe-{message}")
+ return f"safe: {message}"
+
+ unsafe_entry = ToolEntry(
+ name="unsafe",
+ mode=ToolMode.INLINE,
+ schema={"name": "unsafe", "description": "unsafe", "parameters": {}},
+ handler=unsafe_handler,
+ source="test",
+ is_concurrency_safe=False,
+ )
+ safe_entry = ToolEntry(
+ name="safe",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe", "description": "safe", "parameters": {}},
+ handler=safe_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(unsafe_entry, safe_entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "call both"}]}):
+ pass
+
+ assert starts == [
+ "start-unsafe-u",
+ "end-unsafe-u",
+ "start-safe-s",
+ "end-safe-s",
+ ]
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_discards_running_tasks_on_stream_failure():
+ model = _FailingStreamingToolModel()
+ events: list[str] = []
+
+ async def echo_handler(message: str) -> str:
+ events.append(f"start-{message}")
+ try:
+ await asyncio.sleep(0.05)
+ except asyncio.CancelledError:
+ events.append(f"cancel-{message}")
+ raise
+ events.append(f"finish-{message}")
+ return f"echo: {message}"
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "echo", "parameters": {}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "call echo"}]})
+ await asyncio.sleep(0.06)
+
+ assert result["reason"] == "model_error"
+ assert "start-boom" in events
+ assert "cancel-boom" in events
+ assert "finish-boom" not in events
+ assert any("streaming discarded: streaming_error" in msg.content for msg in result["messages"])
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_discards_queued_tools_without_starting_them():
+ model = _FailingQueuedStreamingToolModel()
+ events: list[str] = []
+
+ async def unsafe_handler(message: str) -> str:
+ events.append(f"start-unsafe-{message}")
+ try:
+ await asyncio.sleep(0.05)
+ except asyncio.CancelledError:
+ events.append(f"cancel-unsafe-{message}")
+ raise
+ events.append(f"finish-unsafe-{message}")
+ return f"unsafe: {message}"
+
+ async def safe_handler(message: str) -> str:
+ events.append(f"start-safe-{message}")
+ await asyncio.sleep(0.001)
+ events.append(f"finish-safe-{message}")
+ return f"safe: {message}"
+
+ unsafe_entry = ToolEntry(
+ name="unsafe",
+ mode=ToolMode.INLINE,
+ schema={"name": "unsafe", "description": "unsafe", "parameters": {}},
+ handler=unsafe_handler,
+ source="test",
+ is_concurrency_safe=False,
+ )
+ safe_entry = ToolEntry(
+ name="safe",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe", "description": "safe", "parameters": {}},
+ handler=safe_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(unsafe_entry, safe_entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "call both"}]})
+ await asyncio.sleep(0.06)
+
+ assert result["reason"] == "model_error"
+ assert "start-unsafe-u" in events
+ assert "cancel-unsafe-u" in events
+ assert "finish-unsafe-u" not in events
+ assert "start-safe-s" not in events
+ tool_errors = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
+ assert {msg.tool_call_id for msg in tool_errors} == {"tc-unsafe", "tc-safe"}
+ assert all("streaming discarded: streaming_error" in msg.content for msg in tool_errors)
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_uses_per_call_concurrency_safety():
+ class _DynamicConcurrencyStreamingModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "maybe_parallel", "args": '{"message":"u","parallel":false}', "id": "tc-maybe", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "safe", "args": '{"message":"s"}', "id": "tc-safe", "index": 1}],
+ )
+ await asyncio.sleep(0.05)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+ model = _DynamicConcurrencyStreamingModel()
+ starts: list[str] = []
+
+ async def maybe_parallel_handler(message: str, parallel: bool) -> str:
+ starts.append(f"start-maybe-{message}")
+ await asyncio.sleep(0.02)
+ starts.append(f"end-maybe-{message}")
+ return f"maybe: {message}"
+
+ async def safe_handler(message: str) -> str:
+ starts.append(f"start-safe-{message}")
+ await asyncio.sleep(0.001)
+ starts.append(f"end-safe-{message}")
+ return f"safe: {message}"
+
+ maybe_entry = ToolEntry(
+ name="maybe_parallel",
+ mode=ToolMode.INLINE,
+ schema={"name": "maybe_parallel", "description": "maybe", "parameters": {}},
+ handler=maybe_parallel_handler,
+ source="test",
+ is_concurrency_safe=lambda parsed: bool(parsed.get("parallel")),
+ )
+ safe_entry = ToolEntry(
+ name="safe",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe", "description": "safe", "parameters": {}},
+ handler=safe_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(maybe_entry, safe_entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "call both"}]}):
+ pass
+
+ assert starts == [
+ "start-maybe-u",
+ "end-maybe-u",
+ "start-safe-s",
+ "end-safe-s",
+ ]
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_missing_tool_completes_without_blocking_next_safe_tool():
+ class _MissingThenSafeStreamingModel:
+ def __init__(self):
+ self.calls = 0
+
+ def bind_tools(self, tools):
+ return self
+
+ async def astream(self, messages):
+ self.calls += 1
+ if self.calls == 1:
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "missing_tool", "args": "{}", "id": "tc-missing", "index": 0}],
+ )
+ yield AIMessageChunk(
+ content="",
+ tool_call_chunks=[{"name": "safe", "args": '{"message":"s"}', "id": "tc-safe", "index": 1}],
+ )
+ await asyncio.sleep(0.02)
+ yield AIMessageChunk(content="done")
+ return
+ yield AIMessageChunk(content="final answer")
+
+ model = _MissingThenSafeStreamingModel()
+ starts: list[str] = []
+
+ async def safe_handler(message: str) -> str:
+ starts.append(f"start-safe-{message}")
+ await asyncio.sleep(0.001)
+ starts.append(f"end-safe-{message}")
+ return f"safe: {message}"
+
+ safe_entry = ToolEntry(
+ name="safe",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe", "description": "safe", "parameters": {}},
+ handler=safe_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(safe_entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ pre_agent_tool_ids = []
+ async for chunk in loop.astream({"messages": [{"role": "user", "content": "call missing then safe"}]}):
+ if "tools" in chunk:
+ pre_agent_tool_ids.extend(msg.tool_call_id for msg in chunk["tools"]["messages"])
+ if "agent" in chunk:
+ break
+
+ assert pre_agent_tool_ids == ["tc-missing", "tc-safe"]
+ assert starts == ["start-safe-s", "end-safe-s"]
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_missing_tool_is_immediately_completed():
+ async def safe_handler(message: str) -> str:
+ return f"safe:{message}"
+
+ safe_entry = ToolEntry(
+ name="safe",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe", "description": "safe", "parameters": {}},
+ handler=safe_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ mock_model_no_tools(),
+ registry=make_registry(safe_entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+ executor = _StreamingToolExecutor(loop=loop, tool_context=None)
+
+ await executor.add_tool({"name": "missing_tool", "args": {}, "id": "tc-missing"})
+ await executor.add_tool({"name": "safe", "args": {"message": "s"}, "id": "tc-safe"})
+
+ assert [(tracked.tool_call.get("id"), tracked.status) for tracked in executor._tracked] == [
+ ("tc-missing", "completed"),
+ ("tc-safe", "executing"),
+ ]
+ assert executor._tracked[0].result is not None
+ assert "Tool 'missing_tool' not found" in executor._tracked[0].result.content
+
+
+@pytest.mark.asyncio
+async def test_execute_tools_preserves_order_blocking_for_safe_after_unsafe():
+ model = MagicMock()
+ model.bind_tools.return_value = model
+ model.ainvoke = AsyncMock(
+ side_effect=[
+ AIMessage(
+ content="",
+ tool_calls=[
+ {"name": "safe_a", "args": {"message": "a"}, "id": "tc-safe-a"},
+ {"name": "unsafe_b", "args": {"message": "b"}, "id": "tc-unsafe-b"},
+ {"name": "safe_c", "args": {"message": "c"}, "id": "tc-safe-c"},
+ ],
+ ),
+ AIMessage(content="done"),
+ ]
+ )
+ starts: list[str] = []
+
+ async def safe_a_handler(message: str) -> str:
+ starts.append(f"start-safe-a-{message}")
+ await asyncio.sleep(0.001)
+ starts.append(f"end-safe-a-{message}")
+ return f"safe-a: {message}"
+
+ async def unsafe_b_handler(message: str) -> str:
+ starts.append(f"start-unsafe-b-{message}")
+ await asyncio.sleep(0.02)
+ starts.append(f"end-unsafe-b-{message}")
+ return f"unsafe-b: {message}"
+
+ async def safe_c_handler(message: str) -> str:
+ starts.append(f"start-safe-c-{message}")
+ await asyncio.sleep(0.001)
+ starts.append(f"end-safe-c-{message}")
+ return f"safe-c: {message}"
+
+ loop = make_loop(
+ model,
+ registry=make_registry(
+ ToolEntry(
+ name="safe_a",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe_a", "description": "safe_a", "parameters": {}},
+ handler=safe_a_handler,
+ source="test",
+ is_concurrency_safe=True,
+ ),
+ ToolEntry(
+ name="unsafe_b",
+ mode=ToolMode.INLINE,
+ schema={"name": "unsafe_b", "description": "unsafe_b", "parameters": {}},
+ handler=unsafe_b_handler,
+ source="test",
+ is_concurrency_safe=False,
+ ),
+ ToolEntry(
+ name="safe_c",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe_c", "description": "safe_c", "parameters": {}},
+ handler=safe_c_handler,
+ source="test",
+ is_concurrency_safe=True,
+ ),
+ ),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ async for _ in loop.astream({"messages": [{"role": "user", "content": "call ordered tools"}]}):
+ pass
+
+ assert starts == [
+ "start-safe-a-a",
+ "end-safe-a-a",
+ "start-unsafe-b-b",
+ "end-unsafe-b-b",
+ "start-safe-c-c",
+ "end-safe-c-c",
+ ]
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_surfaces_middleware_exception_as_tool_error():
+ model = _ToolThenFinalStreamingModel()
+
+ async def echo_handler(message: str) -> str:
+ return f"echo: {message}"
+
+ entry = ToolEntry(
+ name="echo",
+ mode=ToolMode.INLINE,
+ schema={"name": "echo", "description": "echo", "parameters": {}},
+ handler=echo_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ middleware=[_ExplodingToolMiddleware()],
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "call echo"}]})
+
+ assert result["reason"] == "completed"
+ assert any(
+ isinstance(msg, ToolMessage) and msg.tool_call_id == "tc-1" and "middleware boom" in msg.content for msg in result["messages"]
+ )
+ assert any(isinstance(msg, AIMessage) and msg.content == "final answer" for msg in result["messages"])
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_restarts_queue_after_unsafe_completion_before_final_chunk():
+ model = _UnsafeThenSafeGapStreamingModel()
+ starts: list[str] = []
+
+ async def unsafe_handler(message: str) -> str:
+ starts.append(f"start-unsafe-{message}")
+ await asyncio.sleep(0.01)
+ starts.append(f"end-unsafe-{message}")
+ return f"unsafe: {message}"
+
+ async def safe_handler(message: str) -> str:
+ starts.append(f"start-safe-{message}")
+ await asyncio.sleep(0.001)
+ starts.append(f"end-safe-{message}")
+ return f"safe: {message}"
+
+ unsafe_entry = ToolEntry(
+ name="unsafe",
+ mode=ToolMode.INLINE,
+ schema={"name": "unsafe", "description": "unsafe", "parameters": {}},
+ handler=unsafe_handler,
+ source="test",
+ is_concurrency_safe=False,
+ )
+ safe_entry = ToolEntry(
+ name="safe",
+ mode=ToolMode.INLINE,
+ schema={"name": "safe", "description": "safe", "parameters": {}},
+ handler=safe_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(unsafe_entry, safe_entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ chunks = []
+ async for chunk in loop.astream({"messages": [{"role": "user", "content": "call both"}]}):
+ chunks.append(chunk)
+
+ first_agent_index = next(i for i, chunk in enumerate(chunks) if "agent" in chunk)
+ pre_agent_tool_ids = [msg.tool_call_id for chunk in chunks[:first_agent_index] for msg in chunk.get("tools", {}).get("messages", [])]
+
+ assert starts == [
+ "start-unsafe-u",
+ "end-unsafe-u",
+ "start-safe-s",
+ "end-safe-s",
+ ]
+ assert pre_agent_tool_ids == ["tc-unsafe", "tc-safe"]
+
+
+@pytest.mark.asyncio
+async def test_streaming_executor_bash_error_cancels_siblings_without_killing_parent():
+ model = _BashAndSafeStreamingModel()
+ events: list[str] = []
+
+ async def bash_handler(command: str) -> str:
+ events.append(f"start-bash-{command}")
+ await asyncio.sleep(0.005)
+ raise RuntimeError("bash exploded")
+
+ async def safe_handler(message: str) -> str:
+ events.append(f"start-safe-{message}")
+ try:
+ await asyncio.sleep(0.05)
+ except asyncio.CancelledError:
+ events.append(f"cancel-safe-{message}")
+ raise
+ events.append(f"finish-safe-{message}")
+ return f"safe: {message}"
+
+ bash_entry = make_inline_tool("bash", bash_handler)
+ safe_entry = make_inline_tool("safe", safe_handler)
+ loop = make_loop(
+ model,
+ registry=make_registry(bash_entry, safe_entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "call bash and safe"}]})
+
+ assert result["reason"] == "completed"
+ assert "start-bash-boom" in events
+ assert "start-safe-s" in events
+ assert "cancel-safe-s" in events
+ assert "finish-safe-s" not in events
+ tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
+ assert {msg.tool_call_id for msg in tool_messages} == {"tc-bash", "tc-safe"}
+ assert any(msg.tool_call_id == "tc-bash" and "bash exploded" in msg.content for msg in tool_messages)
+ assert any(msg.tool_call_id == "tc-safe" and "sibling" in msg.content for msg in tool_messages)
+
+
+@pytest.mark.asyncio
+async def test_query_loop_messages_updates_mode_forwards_live_stream_chunks():
+ model = _StreamingToolModel()
+
+ async def echo_handler(message: str) -> str:
+ await asyncio.sleep(0.01)
+ return f"echo: {message}"
+
+ entry = make_inline_tool("echo", echo_handler)
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ events = []
+ async for chunk in loop.astream(
+ {"messages": [{"role": "user", "content": "call echo"}]},
+ stream_mode=["messages", "updates"],
+ ):
+ events.append(chunk)
+
+ message_events = [data for mode, data in events if mode == "messages"]
+ texts = [msg.content for msg, _ in message_events if getattr(msg, "content", "")]
+ tool_update_index = next(i for i, item in enumerate(events) if item[0] == "updates" and "tools" in item[1])
+ thinking_index = next(i for i, item in enumerate(events) if item[0] == "messages" and item[1][0].content == "thinking")
+ tool_chunk_index = next(
+ i
+ for i, item in enumerate(events)
+ if item[0] == "messages" and getattr(item[1][0], "tool_call_chunks", None) and item[1][0].tool_call_chunks[0]["id"] == "tc-1"
+ )
+
+ assert thinking_index < tool_update_index
+ assert tool_chunk_index < tool_update_index
+ assert any(msg.content == "thinking" for msg, _ in message_events)
+ assert any(getattr(msg, "tool_call_chunks", None) and msg.tool_call_chunks[0]["id"] == "tc-1" for msg, _ in message_events)
+ assert texts == ["thinking", "done", "final answer"]
+
+
+@pytest.mark.asyncio
+async def test_streaming_overlap_waits_for_split_tool_call_args_before_execution():
+ model = _SplitArgsStreamingToolModel()
+ seen_args = []
+
+ def read_handler(file_path: str) -> str:
+ seen_args.append(file_path)
+ return f"read:{file_path}"
+
+ entry = ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Read",
+ "description": "read",
+ "parameters": {
+ "type": "object",
+ "required": ["file_path"],
+ "properties": {"file_path": {"type": "string"}},
+ },
+ },
+ handler=read_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "call read"}]})
+
+ tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
+ assert seen_args == ["/tmp/a.txt"]
+ assert any(msg.tool_call_id == "tc-read" and msg.content == "read:/tmp/a.txt" for msg in tool_messages)
+ assert not any("InputValidationError" in msg.content for msg in tool_messages)
+
+
+@pytest.mark.asyncio
+async def test_streaming_overlap_waits_for_split_string_value_before_execution():
+ model = _SplitStringValueStreamingToolModel()
+ seen_args = []
+
+ def read_handler(file_path: str) -> str:
+ seen_args.append(file_path)
+ return f"read:{file_path}"
+
+ entry = ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Read",
+ "description": "read",
+ "parameters": {
+ "type": "object",
+ "required": ["file_path"],
+ "properties": {"file_path": {"type": "string"}},
+ },
+ },
+ handler=read_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "call read"}]})
+
+ tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
+ assert seen_args == ["/tmp/a.txt"]
+ assert any(msg.tool_call_id == "tc-read" and msg.content == "read:/tmp/a.txt" for msg in tool_messages)
+ assert not any("InputValidationError" in msg.content for msg in tool_messages)
+
+
+@pytest.mark.asyncio
+async def test_streaming_overlap_waits_for_anyof_tool_args_before_execution():
+ model = _SplitAnyOfStreamingToolModel()
+ seen_calls = []
+
+ def read_messages_handler(entity_id: str | None = None, chat_id: str | None = None) -> str:
+ seen_calls.append({"entity_id": entity_id, "chat_id": chat_id})
+ if chat_id:
+ return f"chat:{chat_id}"
+ if entity_id:
+ return f"entity:{entity_id}"
+ return "Provide entity_id or chat_id."
+
+ entry = ToolEntry(
+ name="read_messages",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "read_messages",
+ "description": "read chat",
+ "parameters": {
+ "type": "object",
+ "required": [],
+ "properties": {
+ "entity_id": {"type": "string"},
+ "chat_id": {"type": "string"},
+ },
+ "x-leon-required-any-of": [
+ ["entity_id"],
+ ["chat_id"],
+ ],
+ },
+ },
+ handler=read_messages_handler,
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ model,
+ registry=make_registry(entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ result = await loop.ainvoke({"messages": [{"role": "user", "content": "read chat"}]})
+
+ tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
+ assert seen_calls == [{"entity_id": None, "chat_id": "chat-1"}]
+ assert any(msg.tool_call_id == "tc-chat-read" and msg.content == "chat:chat-1" for msg in tool_messages)
+ assert not any(msg.content == "Provide entity_id or chat_id." for msg in tool_messages)
+
+
+def test_normalize_stream_tool_call_keeps_aggregate_args_when_chunk_args_are_empty():
+ entry = ToolEntry(
+ name="read_messages",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "read_messages",
+ "description": "read chat",
+ "parameters": {
+ "type": "object",
+ "required": [],
+ "properties": {
+ "entity_id": {"type": "string"},
+ "chat_id": {"type": "string"},
+ },
+ "x-leon-required-any-of": [
+ ["entity_id"],
+ ["chat_id"],
+ ],
+ },
+ },
+ handler=lambda **_kwargs: "ok",
+ source="test",
+ is_concurrency_safe=True,
+ )
+ loop = make_loop(
+ mock_model_no_tools(),
+ registry=make_registry(entry),
+ app_state=AppState(),
+ runtime=SimpleNamespace(cost=0.0),
+ )
+
+ normalized = loop._normalize_stream_tool_call(
+ {"name": "read_messages", "args": {"chat_id": "chat-1"}, "id": "tc-chat-read"},
+ [{"name": "read_messages", "args": "", "id": "tc-chat-read", "index": 0}],
+ )
+
+ assert normalized == {
+ "name": "read_messages",
+ "args": {"chat_id": "chat-1"},
+ "id": "tc-chat-read",
+ }
diff --git a/tests/test_queue_formatters.py b/tests/Unit/core/test_queue_formatters.py
similarity index 85%
rename from tests/test_queue_formatters.py
rename to tests/Unit/core/test_queue_formatters.py
index 9d2e0982a..8ec57d72c 100644
--- a/tests/test_queue_formatters.py
+++ b/tests/Unit/core/test_queue_formatters.py
@@ -2,7 +2,21 @@
import xml.etree.ElementTree as ET
-from core.runtime.middleware.queue.formatters import format_command_notification
+from core.runtime.middleware.queue.formatters import format_chat_notification, format_command_notification
+
+
+class TestFormatChatNotification:
+ def test_includes_explicit_read_messages_and_send_message_instructions(self):
+ result = format_chat_notification(
+ sender_name="alice",
+ chat_id="chat-123",
+ unread_count=2,
+ )
+
+ assert 'read_messages(chat_id="chat-123")' in result
+ assert 'send_message(chat_id="chat-123", content="...")' in result
+ assert "Prefer using this exact chat_id directly" in result
+ assert "Do not treat your normal assistant text as a chat reply." in result
class TestFormatCommandNotification:
diff --git a/tests/test_runtime.py b/tests/Unit/core/test_runtime.py
similarity index 94%
rename from tests/test_runtime.py
rename to tests/Unit/core/test_runtime.py
index ef168ebbe..74ce15441 100644
--- a/tests/test_runtime.py
+++ b/tests/Unit/core/test_runtime.py
@@ -18,6 +18,7 @@
RemoteWrappedRuntime,
_extract_state_from_output,
_normalize_pty_result,
+ _RemoteRuntimeBase,
)
from sandbox.terminal import TerminalState, terminal_from_row
from storage.providers.sqlite.lease_repo import SQLiteLeaseRepo
@@ -89,6 +90,15 @@ def _wrap_remote_state_output(
return "\n".join(lines) + "\n"
+def test_remote_runtime_treats_daytona_pty_1011_as_infra_error():
+ text = 'Failed to send input to PTY: received 1011 (internal error) {"exitCode":1}'
+ assert _RemoteRuntimeBase._looks_like_infra_error(text) is True
+
+
+def test_remote_runtime_treats_broken_pipe_as_infra_error():
+ assert _RemoteRuntimeBase._looks_like_infra_error("[Errno 32] Broken pipe") is True
+
+
# TODO(windows-compat): LocalPersistentShellRuntime uses Unix PTY + /tmp paths.
# Tracked in: https://github.com/OpenDCAI/Mycel/issues — Windows shell support needed.
@pytest.mark.skipif(sys.platform == "win32", reason="LocalPersistentShellRuntime requires a Unix shell")
@@ -639,6 +649,36 @@ def _fake_run(handle, command: str, timeout: float | None, on_stdout_chunk=None)
await runtime.close()
+@pytest.mark.asyncio
+async def test_daytona_runtime_retries_once_after_broken_pipe(terminal_store, lease_store):
+ terminal = terminal_from_row(terminal_store.create("term-3b", "thread-3b", "lease-3b", "/tmp"), terminal_store.db_path)
+ lease = lease_store.create("lease-3b", "daytona")
+ provider = MagicMock()
+ from sandbox.providers.daytona import DaytonaSessionRuntime
+
+ runtime = DaytonaSessionRuntime(terminal, lease, provider)
+ calls: list[str] = []
+ recover_events: list[str] = []
+
+ def _fake_execute_once_sync(command: str, timeout: float | None = None, on_stdout_chunk=None):
+ calls.append(command)
+ if len(calls) == 1:
+ raise RuntimeError("[Errno 32] Broken pipe")
+ return ExecuteResult(exit_code=0, stdout="ok\n", stderr="")
+
+ runtime._execute_once_sync = _fake_execute_once_sync # type: ignore[attr-defined]
+ runtime._recover_infra = lambda: recover_events.append("recover") # type: ignore[attr-defined]
+ runtime._close_shell_sync = lambda: recover_events.append("close") # type: ignore[attr-defined]
+ runtime._schedule_snapshot = lambda generation, timeout: None # type: ignore[attr-defined]
+
+ result = await runtime.execute("echo ok")
+
+ assert result.exit_code == 0
+ assert result.stdout == "ok\n"
+ assert calls == ["echo ok", "echo ok"]
+ assert recover_events == ["recover", "close"]
+
+
def test_extract_state_from_output_ignores_prompt_noise():
start = "__LEON_STATE_START_deadbeef__"
end = "__LEON_STATE_END_deadbeef__"
diff --git a/tests/Unit/core/test_runtime_agent.py b/tests/Unit/core/test_runtime_agent.py
new file mode 100644
index 000000000..4999719e5
--- /dev/null
+++ b/tests/Unit/core/test_runtime_agent.py
@@ -0,0 +1,44 @@
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+from core.runtime.abort import AbortController
+from core.runtime.agent import LeonAgent
+from core.runtime.state import BootstrapConfig
+
+
+def test_apply_forked_child_context_updates_agent_and_service_seams():
+ agent = object.__new__(LeonAgent)
+ agent.agent = SimpleNamespace(_bootstrap=None, _tool_abort_controller=None)
+ agent._agent_service = SimpleNamespace(_parent_bootstrap=None, _parent_tool_context=None)
+
+ bootstrap = BootstrapConfig(workspace_root=Path("/tmp"), model_name="test-model")
+ tool_context = SimpleNamespace(abort_controller=AbortController())
+
+ LeonAgent.apply_forked_child_context(agent, bootstrap, tool_context=tool_context)
+
+ assert agent._bootstrap is bootstrap
+ assert agent.agent._bootstrap is bootstrap
+ assert agent._agent_service._parent_bootstrap is bootstrap
+ assert agent._agent_service._parent_tool_context is tool_context
+ assert agent.agent._tool_abort_controller is tool_context.abort_controller
+
+
+def test_close_skips_sandbox_cleanup_and_stays_idempotent():
+ agent = object.__new__(LeonAgent)
+ agent._session_started = False
+ agent._session_ended = False
+ agent._closing = False
+ agent._closed = False
+ agent._cleanup_sandbox = MagicMock()
+ agent._mark_terminated = MagicMock()
+ agent._cleanup_mcp_client = MagicMock()
+ agent._cleanup_sqlite_connection = MagicMock()
+
+ LeonAgent.close(agent, cleanup_sandbox=False)
+ LeonAgent.close(agent, cleanup_sandbox=True)
+
+ agent._cleanup_sandbox.assert_not_called()
+ agent._mark_terminated.assert_called_once()
+ agent._cleanup_mcp_client.assert_called_once()
+ agent._cleanup_sqlite_connection.assert_called_once()
diff --git a/tests/Unit/core/test_runtime_support.py b/tests/Unit/core/test_runtime_support.py
new file mode 100644
index 000000000..1fb809a10
--- /dev/null
+++ b/tests/Unit/core/test_runtime_support.py
@@ -0,0 +1,251 @@
+"""Focused runtime support tests for cleanup, fork, and state helpers."""
+
+import asyncio
+import signal
+from pathlib import Path
+from typing import Any, get_type_hints
+
+import pytest
+
+import core.runtime.state as runtime_state
+from core.runtime.abort import AbortController
+from core.runtime.cleanup import CleanupRegistry
+from core.runtime.fork import create_subagent_context, fork_context
+from core.runtime.state import AppState, BootstrapConfig, ToolUseContext
+
+
+@pytest.fixture
+def runtime_parent_bootstrap():
+ return BootstrapConfig(
+ workspace_root=Path("/workspace"),
+ original_cwd=Path("/launcher"),
+ project_root=Path("/workspace/project"),
+ cwd=Path("/workspace/project/src"),
+ model_name="claude-opus-4-5",
+ api_key="sk-parent",
+ block_dangerous_commands=True,
+ block_network_commands=True,
+ enable_audit_log=False,
+ enable_web_tools=True,
+ allowed_file_extensions=[".py"],
+ extra_allowed_paths=["/shared"],
+ max_turns=20,
+ model_provider="anthropic",
+ base_url="https://api.anthropic.com",
+ context_limit=200000,
+ total_cost_usd=1.25,
+ total_tool_duration_ms=42,
+ )
+
+
+@pytest.fixture
+def runtime_parent_tool_context(runtime_parent_bootstrap):
+ app_state = AppState(turn_count=1, tool_overrides={"Bash": True})
+
+ def set_app_state_for_tasks(updater):
+ app_state.set_state(updater)
+
+ return ToolUseContext(
+ bootstrap=runtime_parent_bootstrap,
+ get_app_state=app_state.get_state,
+ set_app_state=app_state.set_state,
+ set_app_state_for_tasks=set_app_state_for_tasks,
+ refresh_tools=None,
+ read_file_state={"/tmp/file.py": {"partial": False}},
+ loaded_nested_memory_paths={"/tmp/memory.md"},
+ discovered_skill_names={"skill-a"},
+ nested_memory_attachment_triggers={"turn-a"},
+ messages=["msg-1"],
+ )
+
+
+def test_bootstrap_config_minimal_creation():
+ bc = BootstrapConfig(workspace_root=Path("/tmp"), model_name="claude-3-5-sonnet-20241022")
+ assert bc.workspace_root == Path("/tmp")
+ assert bc.project_root == Path("/tmp")
+ assert bc.cwd == Path("/tmp")
+ assert bc.model_name == "claude-3-5-sonnet-20241022"
+ assert bc.api_key is None
+
+
+def test_bootstrap_config_directory_lifetimes_can_be_distinct():
+ bc = BootstrapConfig(
+ workspace_root=Path("/workspace"),
+ original_cwd=Path("/launcher"),
+ project_root=Path("/workspace/project"),
+ cwd=Path("/workspace/project/src"),
+ model_name="test",
+ )
+ assert bc.original_cwd == Path("/launcher")
+ assert bc.project_root == Path("/workspace/project")
+ assert bc.cwd == Path("/workspace/project/src")
+ assert bc.workspace_root == Path("/workspace")
+
+
+def test_app_state_defaults_cover_permission_tracks():
+ s = AppState()
+ assert s.messages == []
+ assert s.turn_count == 0
+ assert s.total_cost == 0.0
+ assert s.compact_boundary_index == 0
+ assert s.tool_permission_context.alwaysAllowRules == {}
+ assert s.tool_permission_context.alwaysDenyRules == {}
+ assert s.tool_permission_context.alwaysAskRules == {}
+ assert s.pending_permission_requests == {}
+ assert s.resolved_permission_requests == {}
+
+
+def test_app_state_session_hooks_can_be_added_and_removed_per_event():
+ seen = []
+
+ def start_hook(payload):
+ seen.append(payload["event"])
+
+ s = AppState()
+ s.add_session_hook("SessionStart", start_hook)
+
+ hooks = s.get_session_hooks("SessionStart")
+ assert hooks == [start_hook]
+
+ hooks[0]({"event": "SessionStart"})
+ assert seen == ["SessionStart"]
+
+ s.remove_session_hook("SessionStart", start_hook)
+ assert s.get_session_hooks("SessionStart") == []
+
+
+def test_tool_use_context_subagent_noop_set_state():
+ bc = BootstrapConfig(workspace_root=Path("/tmp"), model_name="test")
+ app_state = AppState(turn_count=5)
+ calls = []
+
+ def noop(_value):
+ calls.append("called")
+
+ ctx = ToolUseContext(bootstrap=bc, get_app_state=lambda: app_state, set_app_state=noop)
+ ctx.set_app_state(AppState(turn_count=99))
+ assert len(calls) == 1
+ assert app_state.turn_count == 5
+
+
+def test_tool_use_context_core_callable_fields_are_not_typed_as_any():
+ hints = get_type_hints(ToolUseContext, globalns=vars(runtime_state))
+
+ assert hints["get_app_state"] is not Any
+ assert hints["set_app_state"] is not Any
+ assert hints["set_app_state_for_tasks"] is not Any
+ assert hints["refresh_tools"] is not Any
+ assert hints["can_use_tool"] is not Any
+ assert hints["request_permission"] is not Any
+ assert hints["consume_permission_resolution"] is not Any
+ assert hints["abort_controller"] is not Any
+
+
+def test_fork_context_copies_bootstrap_and_generates_new_session_id(runtime_parent_bootstrap):
+ child = fork_context(runtime_parent_bootstrap)
+ assert child.workspace_root == runtime_parent_bootstrap.workspace_root
+ assert child.original_cwd == runtime_parent_bootstrap.original_cwd
+ assert child.project_root == runtime_parent_bootstrap.project_root
+ assert child.cwd == runtime_parent_bootstrap.cwd
+ assert child.model_name == runtime_parent_bootstrap.model_name
+ assert child.api_key == runtime_parent_bootstrap.api_key
+ assert child.session_id != runtime_parent_bootstrap.session_id
+ assert child.parent_session_id == runtime_parent_bootstrap.session_id
+
+
+def test_create_subagent_context_keeps_parent_state_isolation(runtime_parent_tool_context):
+ child = create_subagent_context(runtime_parent_tool_context)
+
+ child.set_app_state(lambda prev: prev.model_copy(update={"turn_count": 9}))
+ assert runtime_parent_tool_context.get_app_state().turn_count == 1
+
+ child.set_app_state_for_tasks(lambda prev: prev.model_copy(update={"turn_count": 9}))
+ assert runtime_parent_tool_context.get_app_state().turn_count == 9
+
+
+def test_create_subagent_context_copies_read_state_and_abort_link(runtime_parent_tool_context):
+ runtime_parent_tool_context.read_file_state = {"/tmp/readme.md": {"partial": False, "meta": {"seen": 1}}}
+ runtime_parent_tool_context.abort_controller = AbortController()
+
+ child = create_subagent_context(runtime_parent_tool_context)
+ child.read_file_state["/tmp/readme.md"]["partial"] = True
+ child.read_file_state["/tmp/readme.md"]["meta"]["seen"] = 9
+ child.abort_controller.abort()
+
+ assert runtime_parent_tool_context.read_file_state["/tmp/readme.md"] == {
+ "partial": False,
+ "meta": {"seen": 1},
+ }
+ assert runtime_parent_tool_context.abort_controller.is_aborted() is False
+
+
+@pytest.mark.asyncio
+async def test_cleanup_registry_runs_in_priority_order_and_survives_failures():
+ order = []
+ reg = CleanupRegistry()
+
+ def failing():
+ raise RuntimeError("boom")
+
+ reg.register(lambda: order.append(3), priority=3)
+ reg.register(failing, priority=1)
+ reg.register(lambda: order.append(2), priority=2)
+ await reg.run_cleanup()
+ assert order == [2, 3]
+
+
+@pytest.mark.asyncio
+async def test_cleanup_registry_reuses_first_inflight_run():
+ order = []
+ release = asyncio.Event()
+ reg = CleanupRegistry()
+
+ async def slow():
+ order.append("start")
+ await release.wait()
+ order.append("done")
+
+ reg.register(slow, priority=1)
+
+ first = asyncio.create_task(reg.run_cleanup())
+ for _ in range(10):
+ if order == ["start"]:
+ break
+ await asyncio.sleep(0)
+
+ second = asyncio.create_task(reg.run_cleanup())
+ await asyncio.sleep(0)
+ release.set()
+ await asyncio.gather(first, second)
+
+ assert order == ["start", "done"]
+
+
+def test_cleanup_registry_register_returns_deregister_handle():
+ order = []
+ reg = CleanupRegistry()
+
+ unregister = reg.register(lambda: order.append("gone"), priority=1)
+ reg.register(lambda: order.append("kept"), priority=2)
+ unregister()
+
+ asyncio.run(reg.run_cleanup())
+ assert order == ["kept"]
+
+
+def test_cleanup_registry_installs_signal_handlers(monkeypatch):
+ registered = []
+
+ class _FakeLoop:
+ def add_signal_handler(self, sig, handler):
+ registered.append(sig)
+
+ monkeypatch.setattr(asyncio, "get_event_loop", lambda: _FakeLoop())
+
+ CleanupRegistry()
+
+ expected = {signal.SIGINT, signal.SIGTERM}
+ if hasattr(signal, "SIGHUP"):
+ expected.add(signal.SIGHUP)
+
+ assert set(registered) == expected
diff --git a/tests/test_spill_buffer.py b/tests/Unit/core/test_spill_buffer.py
similarity index 75%
rename from tests/test_spill_buffer.py
rename to tests/Unit/core/test_spill_buffer.py
index 553011a24..caf07bc5f 100644
--- a/tests/test_spill_buffer.py
+++ b/tests/Unit/core/test_spill_buffer.py
@@ -1,6 +1,6 @@
"""Tests for core.spill_buffer: spill_if_needed() and SpillBufferMiddleware."""
-import os
+import posixpath
from types import SimpleNamespace
from unittest.mock import MagicMock
@@ -61,12 +61,12 @@ def test_large_output_triggers_spill_and_preview(self):
)
# Verify write_file was called with the correct spill path.
- expected_path = os.path.join("/workspace", ".leon", "tool-results", "call_big.txt")
+ expected_path = posixpath.join("/workspace", ".leon", "tool-results", "call_big.txt")
fs.write_file.assert_called_once_with(expected_path, large)
# Result must mention the file path and include a preview.
assert expected_path in result
- assert "Output too large" in result
+ assert result.startswith("" in result
+ assert 'path="/workspace/.leon/tool-results/call_wrapped.txt"' in result
+ assert f'bytes="{len(large.encode("utf-8"))}"' in result
+
+ def test_image_block_content_bypasses_spill(self):
+ """Image-containing blocks should bypass persistence logic."""
+ fs = _make_fs_backend()
+ content = [
+ {"type": "text", "text": "caption"},
+ {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
+ ]
+
+ result = spill_if_needed(
+ content=content,
+ threshold_bytes=1,
+ tool_call_id="call_image",
+ fs_backend=fs,
+ workspace_root="/workspace",
+ )
+
+ assert result is content
+ fs.write_file.assert_not_called()
+
+ def test_mcp_binary_blocks_are_saved_and_rewritten(self):
+ fs = _make_fs_backend()
+ mw = SpillBufferMiddleware(
+ fs_backend=fs,
+ workspace_root="/workspace",
+ default_threshold=50_000,
+ )
+ request = _make_request("mcp__server__image_tool", "call_mcp")
+ original_msg = ToolMessage(
+ content=[
+ {"type": "text", "text": "caption"},
+ {"type": "image", "base64": "QUJD", "mime_type": "image/png"},
+ ],
+ tool_call_id="call_mcp",
+ additional_kwargs={"tool_result_meta": {"source": "mcp"}},
+ )
+
+ result = mw._maybe_spill(request, original_msg)
+
+ expected_path = posixpath.join(
+ "/workspace",
+ ".leon",
+ "tool-results",
+ "call_mcp-1.png.base64",
+ )
+ fs.write_file.assert_called_once_with(expected_path, "QUJD")
+ assert isinstance(result.content, str)
+ assert "caption" in result.content
+ assert expected_path in result.content
+ assert "QUJD" not in result.content
+
# ===========================================================================
# SpillBufferMiddleware
@@ -236,7 +304,7 @@ def test_large_output_gets_spilled(self):
handler.assert_called_once_with(request)
assert result.content != large_content
- assert "Output too large" in result.content
+ assert result.content.startswith("done",
+ source="system",
+ notification_type="agent",
+ )
+ is True
+ )
+ assert (
+ is_terminal_background_notification(
+ "done",
+ source="system",
+ notification_type="command",
+ )
+ is True
+ )
+
+
+def test_is_terminal_background_notification_rejects_non_system_or_non_terminal_messages():
+ assert (
+ is_terminal_background_notification(
+ "done",
+ source="owner",
+ notification_type="agent",
+ )
+ is False
+ )
+ assert (
+ is_terminal_background_notification(
+ "plain reminder",
+ source="system",
+ notification_type="agent",
+ )
+ is False
+ )
diff --git a/tests/Unit/core/test_tool_registry_runner.py b/tests/Unit/core/test_tool_registry_runner.py
new file mode 100644
index 000000000..017f750a0
--- /dev/null
+++ b/tests/Unit/core/test_tool_registry_runner.py
@@ -0,0 +1,2371 @@
+"""Tests for ToolRegistry, ToolRunner, and ToolValidator (P0/P1 verification).
+
+Covers:
+- P0: Three-tier error normalization (Layer 1: validation, Layer 2: execution, Layer 3: soft)
+- P1: ToolRegistry inline/deferred mode
+- P1: ToolRunner dispatches registered tools and normalizes errors
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import time
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from langchain_core.tools import tool
+
+from core.runtime.agent import _make_mcp_tool_entry
+from core.runtime.errors import InputValidationError
+from core.runtime.middleware import AgentMiddleware, ToolCallRequest
+from core.runtime.permissions import ToolPermissionContext, can_auto_approve
+from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry
+from core.runtime.runner import ToolRunner
+from core.runtime.state import AppState, BootstrapConfig, ToolUseContext
+from core.runtime.tool_result import ToolResultEnvelope, tool_permission_denied
+from core.runtime.validator import ToolValidator
+from core.tools.command.hooks.dangerous_commands import DangerousCommandsHook
+from core.tools.command.service import CommandService
+from core.tools.filesystem.read import ReadLimits
+from core.tools.filesystem.read import read_file as read_file_dispatch
+from core.tools.filesystem.read.readers.pdf import read_pdf
+from core.tools.filesystem.service import FileSystemService
+from core.tools.tool_search.service import ToolSearchService
+from core.tools.web.service import WebService
+from sandbox.interfaces.filesystem import DirListResult, FileReadResult, FileSystemBackend, FileWriteResult
+
+# ---------------------------------------------------------------------------
+# ToolRegistry
+# ---------------------------------------------------------------------------
+
+
+class TestToolRegistry:
+ def _make_entry(self, name: str, mode: ToolMode = ToolMode.INLINE) -> ToolEntry:
+ return ToolEntry(
+ name=name,
+ mode=mode,
+ schema={"name": name, "description": f"{name} tool"},
+ handler=lambda: f"result:{name}",
+ source="test",
+ )
+
+ def test_register_and_get(self):
+ reg = ToolRegistry()
+ entry = self._make_entry("Read")
+ reg.register(entry)
+ assert reg.get("Read") is entry
+
+ def test_get_unknown_returns_none(self):
+ reg = ToolRegistry()
+ assert reg.get("NonExistent") is None
+
+ def test_inline_tools_appear_in_get_inline_schemas(self):
+ reg = ToolRegistry()
+ reg.register(self._make_entry("Read", ToolMode.INLINE))
+ reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
+ schemas = reg.get_inline_schemas()
+ names = [s["name"] for s in schemas]
+ assert "Read" in names
+ assert "TaskCreate" not in names # P1: deferred not in inline
+
+ def test_deferred_tools_not_in_inline_schemas(self):
+ reg = ToolRegistry()
+ reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
+ reg.register(self._make_entry("TaskUpdate", ToolMode.DEFERRED))
+ assert reg.get_inline_schemas() == []
+
+ def test_search_finds_by_name(self):
+ reg = ToolRegistry()
+ reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
+ reg.register(self._make_entry("Read", ToolMode.INLINE))
+ results = reg.search("task")
+ names = [e.name for e in results]
+ assert "TaskCreate" in names
+
+ def test_search_includes_deferred_tools(self):
+ """tool_search must discover deferred tools too."""
+ reg = ToolRegistry()
+ reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
+ results = reg.search("TaskCreate")
+ assert any(e.name == "TaskCreate" for e in results)
+
+ def test_search_no_match_returns_empty_results(self):
+ reg = ToolRegistry()
+ reg.register(self._make_entry("Read", ToolMode.INLINE))
+ reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
+ assert reg.search("nonesuch") == []
+
+ def test_allowed_tools_filter(self):
+ reg = ToolRegistry(allowed_tools={"Read", "Grep"})
+ reg.register(self._make_entry("Read"))
+ reg.register(self._make_entry("Grep"))
+ reg.register(self._make_entry("Bash"))
+ assert reg.get("Read") is not None
+ assert reg.get("Grep") is not None
+ assert reg.get("Bash") is None # filtered out
+
+ def test_dynamic_schema_callable(self):
+ call_count = 0
+
+ def schema_fn() -> dict:
+ nonlocal call_count
+ call_count += 1
+ return {"name": "DynTool", "description": "dynamic"}
+
+ reg = ToolRegistry()
+ entry = ToolEntry(
+ name="DynTool",
+ mode=ToolMode.INLINE,
+ schema=schema_fn,
+ handler=lambda: "ok",
+ source="test",
+ )
+ reg.register(entry)
+ schemas = reg.get_inline_schemas()
+ assert call_count >= 1
+ assert any(s["name"] == "DynTool" for s in schemas)
+
+
+    def test_inline_schemas_strip_runtime_only_schema_metadata(self):
+        reg = ToolRegistry()
+        reg.register(
+            ToolEntry(
+                name="ChatRead",
+                mode=ToolMode.INLINE,
+                schema={
+                    "name": "ChatRead",
+                    "description": "chat read",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "chat_id": {"type": "string"},
+                        },
+                        "x-leon-required-any-of": [["chat_id"]],
+                    },
+                },
+                handler=lambda **_kwargs: "ok",
+                source="test",
+            )
+        )
+
+        [schema] = reg.get_inline_schemas()
+
+        assert "x-leon-required-any-of" not in schema["parameters"]
+
+
+def test_agent_middleware_tools_are_not_shared_mutable_state():
+    first = AgentMiddleware()
+    second = AgentMiddleware()
+
+    first.tools = ["x"]
+
+    assert second.tools == ()
+
+# ---------------------------------------------------------------------------
+# ToolValidator
+# ---------------------------------------------------------------------------
+
+
+class TestToolValidator:
+ def _schema(self, required: list[str], props: dict) -> dict:
+ return {
+ "name": "TestTool",
+ "parameters": {
+ "type": "object",
+ "required": required,
+ "properties": {k: {"type": v} for k, v in props.items()},
+ },
+ }
+
+ def test_valid_args_pass(self):
+ v = ToolValidator()
+ schema = self._schema(["file_path"], {"file_path": "string"})
+ result = v.validate(schema, {"file_path": "/tmp/x"})
+ assert result.ok
+
+ def test_missing_required_raises_layer1(self):
+ v = ToolValidator()
+ schema = self._schema(["file_path"], {"file_path": "string"})
+ with pytest.raises(InputValidationError) as exc_info:
+ v.validate(schema, {})
+ assert "file_path" in str(exc_info.value)
+ assert "missing" in str(exc_info.value)
+ assert exc_info.value.error_code == "REQUIRED_FIELD_MISSING"
+ assert exc_info.value.details[0]["field"] == "file_path"
+
+ def test_wrong_type_raises_layer1(self):
+ v = ToolValidator()
+ schema = self._schema(["count"], {"count": "integer"})
+ with pytest.raises(InputValidationError) as exc_info:
+ v.validate(schema, {"count": "not-an-int"})
+ assert exc_info.value.error_code == "INVALID_TYPE"
+ assert exc_info.value.details[0]["field"] == "count"
+
+ def test_extra_params_allowed(self):
+ v = ToolValidator()
+ schema = self._schema(["a"], {"a": "string"})
+ result = v.validate(schema, {"a": "hello", "extra": "ok"})
+ assert result.ok
+
+ def test_required_any_of_requires_one_alternative(self):
+ v = ToolValidator()
+ schema = {
+ "name": "ChatRead",
+ "parameters": {
+ "type": "object",
+ "required": [],
+ "properties": {
+ "entity_id": {"type": "string"},
+ "chat_id": {"type": "string"},
+ },
+ "x-leon-required-any-of": [
+ ["entity_id"],
+ ["chat_id"],
+ ],
+ },
+ }
+
+ with pytest.raises(InputValidationError) as exc_info:
+ v.validate(schema, {})
+
+ assert "entity_id" in str(exc_info.value)
+ assert "chat_id" in str(exc_info.value)
+
+ def test_required_any_of_accepts_present_alternative(self):
+ v = ToolValidator()
+ schema = {
+ "name": "ChatRead",
+ "parameters": {
+ "type": "object",
+ "required": [],
+ "properties": {
+ "entity_id": {"type": "string"},
+ "chat_id": {"type": "string"},
+ },
+ "x-leon-required-any-of": [
+ ["entity_id"],
+ ["chat_id"],
+ ],
+ },
+ }
+
+ result = v.validate(schema, {"chat_id": "chat-1"})
+ assert result.ok
+
+ def test_string_constraints_raise_layer1(self):
+ v = ToolValidator()
+ schema = {
+ "name": "Read",
+ "parameters": {
+ "type": "object",
+ "required": ["file_path"],
+ "properties": {
+ "file_path": {
+ "type": "string",
+ "minLength": 1,
+ "pattern": "^/",
+ }
+ },
+ },
+ }
+
+ with pytest.raises(InputValidationError) as exc_info:
+ v.validate(schema, {"file_path": "relative/path.txt"})
+
+ assert "file_path" in str(exc_info.value)
+ assert "match pattern" in str(exc_info.value)
+ assert exc_info.value.error_code == "PATTERN_MISMATCH"
+ assert exc_info.value.details[0]["error_code"] == "PATTERN_MISMATCH"
+
+ def test_absolute_path_pattern_accepts_windows_drive_paths(self):
+ v = ToolValidator()
+ schema = {
+ "name": "Read",
+ "parameters": {
+ "type": "object",
+ "required": ["file_path"],
+ "properties": {
+ "file_path": {
+ "type": "string",
+ "minLength": 1,
+ "pattern": r"^(?:/|[A-Za-z]:[\\/])",
+ }
+ },
+ },
+ }
+
+ result = v.validate(schema, {"file_path": r"C:\tmp\file.txt"})
+
+ assert result.ok
+
+ def test_numeric_maximum_raises_layer1(self):
+ v = ToolValidator()
+ schema = {
+ "name": "TaskOutput",
+ "parameters": {
+ "type": "object",
+ "required": ["timeout"],
+ "properties": {
+ "timeout": {
+ "type": "integer",
+ "maximum": 600000,
+ }
+ },
+ },
+ }
+
+ with pytest.raises(InputValidationError) as exc_info:
+ v.validate(schema, {"timeout": 600001})
+
+ assert "timeout" in str(exc_info.value)
+ assert "at most" in str(exc_info.value)
+ assert exc_info.value.error_code == "NUMBER_TOO_LARGE"
+ assert exc_info.value.details[0]["field"] == "timeout"
+
+
+# ---------------------------------------------------------------------------
+# ToolRunner — P0 error normalization
+# ---------------------------------------------------------------------------
+
+
+def _make_runner(entries: list[ToolEntry]) -> ToolRunner:
+ reg = ToolRegistry()
+ for e in entries:
+ reg.register(e)
+ return ToolRunner(registry=reg)
+
+
+def _make_tool_call_request(name: str, args: dict, call_id: str = "tc-1"):
+ req = MagicMock()
+ req.tool_call = {"name": name, "args": args, "id": call_id}
+ return req
+
+
+class TestToolRunnerErrorNormalization:
+ """P0: three-tier error normalization."""
+
+ def test_layer1_missing_param_returns_input_validation_error(self):
+ entry = ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Read",
+ "parameters": {
+ "type": "object",
+ "required": ["file_path"],
+ "properties": {"file_path": {"type": "string"}},
+ },
+ },
+ handler=lambda file_path: "content",
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Read", {}) # missing file_path
+
+ called_upstream = []
+
+ def upstream(r):
+ called_upstream.append(r)
+ return MagicMock()
+
+ result = runner.wrap_tool_call(req, upstream)
+ # Layer 1 error format: InputValidationError: {name} failed due to...
+ assert "InputValidationError" in result.content
+ assert "Read" in result.content
+ assert result.additional_kwargs["tool_result_meta"]["error_code"] == "REQUIRED_FIELD_MISSING"
+ assert not called_upstream # must not fall through to upstream
+
+ def test_layer1_schema_failure_returns_structured_error_details(self):
+ entry = ToolEntry(
+ name="Bash",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Bash",
+ "parameters": {
+ "type": "object",
+ "required": ["timeout"],
+ "properties": {
+ "timeout": {"type": "integer", "maximum": 600000},
+ },
+ },
+ },
+ handler=lambda timeout: timeout,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Bash", {"timeout": 600001})
+
+ result = runner.wrap_tool_call(req, lambda r: MagicMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert meta["error_type"] == "input_validation"
+ assert meta["error_code"] == "NUMBER_TOO_LARGE"
+ assert meta["error_details"][0]["field"] == "timeout"
+
+ def test_layer2_handler_exception_returns_tool_use_error(self):
+ def bad_handler(**kwargs):
+ raise ValueError("disk full")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Write",
+ "parameters": {
+ "type": "object",
+ "required": [],
+ "properties": {},
+ },
+ },
+ handler=bad_handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ result = runner.wrap_tool_call(req, lambda r: MagicMock())
+        # Layer 2 error format: <tool_use_error>...</tool_use_error>
+        assert "<tool_use_error>" in result.content
+ assert "disk full" in result.content
+
+ @pytest.mark.asyncio
+ async def test_filesystem_service_read_preserves_image_blocks_on_local_path(self, tmp_path):
+ registry = ToolRegistry()
+ FileSystemService(
+ registry=registry,
+ workspace_root=tmp_path,
+ )
+ image = tmp_path / "tiny.png"
+ image.write_bytes(b"fake-png-payload")
+
+ runner = _make_runner(registry.list_all())
+ req = _make_tool_call_request("Read", {"file_path": str(image)})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert isinstance(result.content, list)
+ assert any(block.get("type") == "image" for block in result.content)
+ assert result.additional_kwargs["tool_result_meta"]["source"] == "local"
+
+ @pytest.mark.asyncio
+ async def test_filesystem_service_read_preserves_image_blocks_on_remote_path(self, tmp_path):
+ class RemoteImageBackend(FileSystemBackend):
+ is_remote = True
+
+ def __init__(self):
+ self._raw = b"remote-png-payload"
+
+ def read_file(self, path: str) -> FileReadResult:
+ return FileReadResult(content="opaque-binary-placeholder", size=len(self._raw))
+
+ def write_file(self, path: str, content: str) -> FileWriteResult:
+ return FileWriteResult(success=True)
+
+ def file_exists(self, path: str) -> bool:
+ return True
+
+ def file_mtime(self, path: str) -> float | None:
+ return None
+
+ def file_size(self, path: str) -> int | None:
+ return len(self._raw)
+
+ def is_dir(self, path: str) -> bool:
+ return False
+
+ def list_dir(self, path: str) -> DirListResult:
+ return DirListResult(entries=[])
+
+ def download_bytes(self, path: str) -> bytes:
+ return self._raw
+
+ registry = ToolRegistry()
+ FileSystemService(
+ registry=registry,
+ workspace_root="/workspace",
+ backend=RemoteImageBackend(),
+ )
+
+ runner = _make_runner(registry.list_all())
+ req = _make_tool_call_request("Read", {"file_path": "/workspace/tiny.png"})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert isinstance(result.content, list)
+ assert any(block.get("type") == "image" for block in result.content)
+ assert result.additional_kwargs["tool_result_meta"]["source"] == "local"
+
+ @pytest.mark.asyncio
+ async def test_filesystem_service_read_remote_pdf_uses_special_reader_path(self, tmp_path):
+ pdf_bytes = b"%PDF-1.4\nnot-a-real-pdf\n"
+ local_pdf = tmp_path / "sample.pdf"
+ local_pdf.write_bytes(pdf_bytes)
+ expected = read_file_dispatch(path=local_pdf, limits=ReadLimits()).format_output()
+ expected = expected.replace(str(local_pdf), "/workspace/sample.pdf")
+
+ class RemotePdfBackend(FileSystemBackend):
+ is_remote = True
+
+ def read_file(self, path: str) -> FileReadResult:
+ return FileReadResult(content="opaque-pdf-placeholder", size=len(pdf_bytes))
+
+ def write_file(self, path: str, content: str) -> FileWriteResult:
+ return FileWriteResult(success=True)
+
+ def file_exists(self, path: str) -> bool:
+ return True
+
+ def file_mtime(self, path: str) -> float | None:
+ return None
+
+ def file_size(self, path: str) -> int | None:
+ return len(pdf_bytes)
+
+ def is_dir(self, path: str) -> bool:
+ return False
+
+ def list_dir(self, path: str) -> DirListResult:
+ return DirListResult(entries=[])
+
+ def download_bytes(self, path: str) -> bytes:
+ return pdf_bytes
+
+ registry = ToolRegistry()
+ FileSystemService(
+ registry=registry,
+ workspace_root="/workspace",
+ backend=RemotePdfBackend(),
+ )
+
+ runner = _make_runner(registry.list_all())
+ req = _make_tool_call_request("Read", {"file_path": "/workspace/sample.pdf"})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == expected
+
+ @pytest.mark.asyncio
+ async def test_filesystem_service_remote_special_file_fails_before_download_when_size_known(self):
+ class RemoteLargePdfBackend(FileSystemBackend):
+ is_remote = True
+
+ def read_file(self, path: str) -> FileReadResult:
+ raise AssertionError("read_file should not run for oversize remote preflight")
+
+ def write_file(self, path: str, content: str) -> FileWriteResult:
+ return FileWriteResult(success=True)
+
+ def file_exists(self, path: str) -> bool:
+ return True
+
+ def file_mtime(self, path: str) -> float | None:
+ return None
+
+ def file_size(self, path: str) -> int | None:
+ return 11 * 1024 * 1024
+
+ def is_dir(self, path: str) -> bool:
+ return False
+
+ def list_dir(self, path: str) -> DirListResult:
+ return DirListResult(entries=[])
+
+ def download_bytes(self, path: str) -> bytes:
+ raise AssertionError("download_bytes should not run for oversize remote preflight")
+
+ registry = ToolRegistry()
+ FileSystemService(
+ registry=registry,
+ workspace_root="/workspace",
+ backend=RemoteLargePdfBackend(),
+ )
+
+ runner = _make_runner(registry.list_all())
+ req = _make_tool_call_request("Read", {"file_path": "/workspace/huge.pdf"})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "ToolValidationError" in result.content
+ assert "too large" in result.content.lower()
+ assert result.additional_kwargs["tool_result_meta"]["error_code"] == "FILE_TOO_LARGE"
+
+ @pytest.mark.asyncio
+ async def test_filesystem_service_read_accepts_pdf_pages_argument(self, tmp_path):
+ pdf_bytes = b"%PDF-1.4\nnot-a-real-pdf\n"
+ local_pdf = tmp_path / "paged.pdf"
+ local_pdf.write_bytes(pdf_bytes)
+ expected = read_pdf(local_pdf, ReadLimits(), start_page=1, limit_pages=1).format_output()
+
+ registry = ToolRegistry()
+ FileSystemService(
+ registry=registry,
+ workspace_root=tmp_path,
+ )
+ runner = _make_runner(registry.list_all())
+ req = _make_tool_call_request("Read", {"file_path": str(local_pdf), "pages": "1"})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == expected
+
+ def test_layer3_handler_returns_soft_failure_text(self):
+ def soft_fail(**kwargs):
+ return "No files found"
+
+ entry = ToolEntry(
+ name="Glob",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Glob",
+ "parameters": {
+ "type": "object",
+ "required": ["pattern"],
+ "properties": {"pattern": {"type": "string"}},
+ },
+ },
+ handler=soft_fail,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Glob", {"pattern": "**/*.xyz"})
+ result = runner.wrap_tool_call(req, lambda r: MagicMock())
+ # Layer 3: plain text, no tags
+ assert result.content == "No files found"
+        assert "<tool_use_error>" not in result.content
+ assert "InputValidationError" not in result.content
+
+ def test_unknown_tool_falls_through_to_upstream(self):
+ runner = _make_runner([]) # empty registry
+ req = _make_tool_call_request("UnknownMCPTool", {})
+ upstream_called = []
+
+ def upstream(r):
+ upstream_called.append(r)
+ msg = MagicMock()
+ msg.content = "mcp result"
+ return msg
+
+ result = runner.wrap_tool_call(req, upstream)
+ assert upstream_called
+ assert result.content == "mcp result"
+
+ @pytest.mark.asyncio
+ async def test_non_mcp_post_tool_use_hook_sees_materialized_tool_message(self):
+ events = []
+
+ def local_handler(**kwargs):
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=local_handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def post_tool_use(message, request):
+ events.append((type(message).__name__, message.content, message.additional_kwargs["tool_result_meta"]["source"]))
+ return message
+
+ req.state.post_tool_use = post_tool_use
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == "plain success"
+ assert events == [("ToolMessage", "plain success", "local")]
+
+ @pytest.mark.asyncio
+ async def test_async_post_tool_use_hooks_run_in_parallel(self):
+ def local_handler(**kwargs):
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=local_handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def post_hook_one(message, request):
+ await asyncio.sleep(0.05)
+ return None
+
+ async def post_hook_two(message, request):
+ await asyncio.sleep(0.05)
+ return None
+
+ req.state.post_tool_use = [post_hook_one, post_hook_two]
+
+ started = time.perf_counter()
+ result = await runner.awrap_tool_call(req, AsyncMock())
+ elapsed = time.perf_counter() - started
+
+ assert result.content == "plain success"
+ assert elapsed < 0.09
+
+ @pytest.mark.asyncio
+ async def test_async_post_tool_use_hook_timeout_cancels_hook_and_preserves_result(self):
+ events = []
+
+ def local_handler(**kwargs):
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=local_handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+ req.state.hook_timeout_ms = 50
+
+ async def stuck_hook(message, request):
+ try:
+ await asyncio.Future()
+ except asyncio.CancelledError:
+ events.append("post-cancelled")
+ raise
+
+ req.state.post_tool_use = stuck_hook
+
+ started = time.perf_counter()
+ result = await runner.awrap_tool_call(req, AsyncMock())
+ elapsed = time.perf_counter() - started
+
+ assert result.content == "plain success"
+ assert elapsed < 0.2
+ assert events == ["post-cancelled"]
+
+ @pytest.mark.asyncio
+ async def test_async_pre_tool_use_hook_timeout_cancels_hook_and_preserves_execution(self):
+ events = []
+
+ def local_handler(**kwargs):
+ events.append("handler")
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=local_handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+ req.state.hook_timeout_ms = 50
+
+ async def stuck_hook(payload, request):
+ try:
+ await asyncio.Future()
+ except asyncio.CancelledError:
+ events.append("pre-cancelled")
+ raise
+
+ req.state.pre_tool_use = stuck_hook
+
+ started = time.perf_counter()
+ result = await runner.awrap_tool_call(req, AsyncMock())
+ elapsed = time.perf_counter() - started
+
+ assert result.content == "plain success"
+ assert elapsed < 0.2
+ assert events == ["pre-cancelled", "handler"]
+
+ @pytest.mark.asyncio
+ async def test_post_tool_use_failure_hook_runs_on_materialized_error_message(self):
+ seen = []
+
+ def bad_handler(**kwargs):
+ raise ValueError("disk full")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=bad_handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def post_tool_use_failure(message, request):
+ seen.append((type(message).__name__, message.additional_kwargs["tool_result_meta"]["kind"]))
+ return message
+
+ req.state.post_tool_use_failure = post_tool_use_failure
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+        assert "<tool_use_error>" in result.content
+ assert seen == [("ToolMessage", "error")]
+
+ @pytest.mark.asyncio
+ async def test_permission_denied_result_keeps_distinct_metadata(self):
+ def denied_handler(**kwargs):
+ return tool_permission_denied(
+ "permission denied",
+ top_level_blocks=[{"type": "text", "text": "extra-block"}],
+ metadata={"policy": "workspace"},
+ )
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=denied_handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "permission denied"
+ assert meta["kind"] == "permission_denied"
+ assert meta["source"] == "local"
+ assert meta["top_level_blocks"] == [{"type": "text", "text": "extra-block"}]
+ assert meta["policy"] == "workspace"
+
+ @pytest.mark.asyncio
+ async def test_mcp_post_tool_use_hook_can_modify_result_before_materialization(self):
+ runner = _make_runner([]) # unknown tool => upstream/MCP path
+ req = _make_tool_call_request("mcp__server__tool", {})
+ req.state = MagicMock()
+ seen = []
+
+ def post_tool_use(payload, request):
+ seen.append(type(payload).__name__)
+ assert isinstance(payload, ToolResultEnvelope)
+ return ToolResultEnvelope(
+ kind=payload.kind,
+ content="hooked mcp result",
+ is_error=payload.is_error,
+ top_level_blocks=payload.top_level_blocks,
+ metadata={**payload.metadata, "hooked": True},
+ )
+
+ req.state.post_tool_use = post_tool_use
+
+ async def upstream(_request):
+ return ToolResultEnvelope(kind="success", content="raw mcp result")
+
+ result = await runner.awrap_tool_call(req, upstream)
+
+ assert seen == ["ToolResultEnvelope"]
+ assert result.content == "hooked mcp result"
+ assert result.additional_kwargs["tool_result_meta"]["source"] == "mcp"
+ assert result.additional_kwargs["tool_result_meta"]["hooked"] is True
+
+ @pytest.mark.asyncio
+ async def test_command_hook_denial_uses_permission_denied_result_path(self, tmp_path):
+ registry = ToolRegistry()
+ CommandService(
+ registry=registry,
+ workspace_root=tmp_path,
+ hooks=[DangerousCommandsHook()],
+ )
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("Bash", {"command": "rm -rf /"})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert "SECURITY" in result.content
+ assert meta["kind"] == "permission_denied"
+ assert meta["source"] == "local"
+ assert meta["policy"] == "command_hook"
+
+ @pytest.mark.asyncio
+ async def test_command_hook_does_not_block_quoted_dangerous_text(self, tmp_path):
+ registry = ToolRegistry()
+ CommandService(
+ registry=registry,
+ workspace_root=tmp_path,
+ hooks=[DangerousCommandsHook(verbose=False)],
+ )
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("Bash", {"command": 'echo "rm -rf /"'})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "SECURITY ERROR" not in result.content
+ assert "rm -rf /" in result.content
+
+ @pytest.mark.asyncio
+ async def test_command_hook_does_not_block_commented_dangerous_text(self, tmp_path):
+ registry = ToolRegistry()
+ CommandService(
+ registry=registry,
+ workspace_root=tmp_path,
+ hooks=[DangerousCommandsHook(verbose=False)],
+ )
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("Bash", {"command": "echo hi # rm -rf /"})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "SECURITY ERROR" not in result.content
+ assert "hi" in result.content
+
+ @pytest.mark.asyncio
+ async def test_command_hook_blocks_obfuscated_dangerous_command_name_with_inline_quotes(self, tmp_path):
+ registry = ToolRegistry()
+ CommandService(
+ registry=registry,
+ workspace_root=tmp_path,
+ hooks=[DangerousCommandsHook(verbose=False)],
+ )
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("Bash", {"command": 's"u"do echo hi'})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "SECURITY ERROR" in result.content
+ assert result.additional_kwargs["tool_result_meta"]["kind"] == "permission_denied"
+
+ @pytest.mark.asyncio
+ async def test_command_hook_blocks_ansi_c_quoted_obfuscation(self, tmp_path):
+ registry = ToolRegistry()
+ CommandService(
+ registry=registry,
+ workspace_root=tmp_path,
+ hooks=[DangerousCommandsHook(verbose=False)],
+ )
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("Bash", {"command": "s$'udo' echo hi"})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "SECURITY ERROR" in result.content
+ assert result.additional_kwargs["tool_result_meta"]["kind"] == "permission_denied"
+
+ @pytest.mark.asyncio
+ async def test_registered_mcp_tool_executes_through_runner_with_mcp_source(self):
+ @tool
+ async def sample_mcp_tool(x: int) -> str:
+ """sample mcp"""
+ return f"mcp:{x}"
+
+ registry = ToolRegistry()
+ registry.register(_make_mcp_tool_entry(sample_mcp_tool))
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("sample_mcp_tool", {"x": 3})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "mcp:3"
+ assert meta["source"] == "mcp"
+ assert meta["kind"] == "success"
+
+ @pytest.mark.asyncio
+ async def test_registered_mcp_tool_post_hook_sees_envelope_before_materialization(self):
+ @tool
+ async def sample_mcp_tool(x: int) -> str:
+ """sample mcp"""
+ return f"mcp:{x}"
+
+ registry = ToolRegistry()
+ registry.register(_make_mcp_tool_entry(sample_mcp_tool))
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("sample_mcp_tool", {"x": 3})
+ req.state = MagicMock()
+ seen = []
+
+ def post_tool_use(payload, request):
+ seen.append(type(payload).__name__)
+ assert isinstance(payload, ToolResultEnvelope)
+ return payload
+
+ req.state.post_tool_use = post_tool_use
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert seen == ["ToolResultEnvelope"]
+ assert result.content == "mcp:3"
+ assert result.additional_kwargs["tool_result_meta"]["source"] == "mcp"
+
+ @pytest.mark.asyncio
+ async def test_registered_mcp_tool_preserves_content_blocks_before_spill(self):
+ @tool
+ async def sample_mcp_tool(x: int) -> list[dict[str, str]]:
+ """sample mcp"""
+ return [
+ {"type": "text", "text": f"mcp:{x}"},
+ {"type": "image", "base64": "QUJD", "mime_type": "image/png"},
+ ]
+
+ registry = ToolRegistry()
+ registry.register(_make_mcp_tool_entry(sample_mcp_tool))
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("sample_mcp_tool", {"x": 3})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == [
+ {"type": "text", "text": "mcp:3"},
+ {"type": "image", "base64": "QUJD", "mime_type": "image/png"},
+ ]
+ assert result.additional_kwargs["tool_result_meta"]["source"] == "mcp"
+
+ @pytest.mark.asyncio
+ async def test_registered_mcp_hook_rematerialization_keeps_mcp_source(self):
+ @tool
+ async def sample_mcp_tool(x: int) -> str:
+ """sample mcp"""
+ return f"mcp:{x}"
+
+ registry = ToolRegistry()
+ registry.register(_make_mcp_tool_entry(sample_mcp_tool))
+ runner = ToolRunner(registry=registry)
+ req = _make_tool_call_request("sample_mcp_tool", {"x": 3})
+ req.state = MagicMock()
+
+ def post_tool_use(payload, request):
+ return ToolResultEnvelope(
+ kind="success",
+ content="hooked-remat",
+ metadata={"hooked": True},
+ )
+
+ req.state.post_tool_use = post_tool_use
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "hooked-remat"
+ assert meta["source"] == "mcp"
+ assert meta["hooked"] is True
+
+ @pytest.mark.asyncio
+ async def test_pre_tool_use_does_not_run_before_schema_validation(self):
+ events = []
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Write",
+ "parameters": {
+ "type": "object",
+ "required": ["path"],
+ "properties": {"path": {"type": "string"}},
+ },
+ },
+ handler=lambda path: f"ok:{path}",
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def pre_tool_use(payload, request):
+ events.append("pre")
+ return payload
+
+ req.state.pre_tool_use = pre_tool_use
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "InputValidationError" in result.content
+ assert events == []
+
+ @pytest.mark.asyncio
+ async def test_tool_specific_validation_runs_before_pre_tool_use_and_handler(self):
+ events = []
+
+ def validate_input(args, request):
+ events.append("tool-validate")
+ return {"path": args["path"], "normalized": True}
+
+ def handler(path, normalized=False):
+ events.append(("handler", path, normalized))
+ return "ok"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Write",
+ "parameters": {
+ "type": "object",
+ "required": ["path"],
+ "properties": {"path": {"type": "string"}},
+ },
+ },
+ handler=handler,
+ source="test",
+ validate_input=validate_input,
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {"path": "/tmp/a"})
+ req.state = MagicMock()
+
+ def pre_tool_use(payload, request):
+ events.append(("pre", dict(payload["args"])))
+ return payload
+
+ req.state.pre_tool_use = pre_tool_use
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == "ok"
+ assert events == [
+ "tool-validate",
+ ("pre", {"path": "/tmp/a", "normalized": True}),
+ ("handler", "/tmp/a", True),
+ ]
+
+ @pytest.mark.asyncio
+ async def test_tool_specific_validation_failure_object_stops_before_handler(self):
+ events = []
+
+ def validate_input(args, request):
+ events.append("tool-validate")
+ return {"result": False, "message": "tool says no", "errorCode": "E_NO"}
+
+ def handler(**kwargs):
+ events.append(("handler", kwargs))
+ return "should-not-run"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Write",
+ "parameters": {
+ "type": "object",
+ "required": [],
+ "properties": {},
+ },
+ },
+ handler=handler,
+ source="test",
+ validate_input=validate_input,
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "ToolValidationError" in result.content
+ assert "tool says no" in result.content
+ assert result.additional_kwargs["tool_result_meta"]["error_type"] == "tool_input_validation"
+ assert result.additional_kwargs["tool_result_meta"]["error_code"] == "E_NO"
+ assert events == ["tool-validate"]
+
+ @pytest.mark.asyncio
+ async def test_filesystem_list_dir_outside_workspace_fails_with_structured_error_code(self, tmp_path):
+ registry = ToolRegistry()
+ FileSystemService(
+ registry=registry,
+ workspace_root=tmp_path,
+ )
+ runner = _make_runner(registry.list_all())
+ outside = (tmp_path.parent / "outside").resolve()
+ req = _make_tool_call_request("list_dir", {"path": str(outside)})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "ToolValidationError" in result.content
+ assert "outside workspace" in result.content.lower()
+ assert result.additional_kwargs["tool_result_meta"]["error_type"] == "tool_input_validation"
+ assert result.additional_kwargs["tool_result_meta"]["error_code"] == "PATH_OUTSIDE_WORKSPACE"
+
+ @pytest.mark.asyncio
+ async def test_filesystem_read_large_file_fails_before_handler_as_tool_validation(self, tmp_path):
+ class LargeFileBackend(FileSystemBackend):
+ is_remote = False
+
+ def __init__(self):
+ self.read_calls = 0
+
+ def read_file(self, path: str) -> FileReadResult:
+ self.read_calls += 1
+ raise AssertionError("read_file should not run for oversize preflight")
+
+ def write_file(self, path: str, content: str) -> FileWriteResult:
+ return FileWriteResult(success=True)
+
+ def file_exists(self, path: str) -> bool:
+ return True
+
+ def file_mtime(self, path: str) -> float | None:
+ return None
+
+ def file_size(self, path: str) -> int | None:
+ return 11 * 1024 * 1024
+
+ def is_dir(self, path: str) -> bool:
+ return False
+
+ def list_dir(self, path: str) -> DirListResult:
+ return DirListResult(entries=[])
+
+ backend = LargeFileBackend()
+ registry = ToolRegistry()
+ FileSystemService(
+ registry=registry,
+ workspace_root=tmp_path,
+ backend=backend,
+ )
+ runner = _make_runner(registry.list_all())
+ target = (tmp_path / "too-large.txt").resolve()
+ req = _make_tool_call_request("Read", {"file_path": str(target)})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "ToolValidationError" in result.content
+ assert "too large" in result.content.lower()
+ assert result.additional_kwargs["tool_result_meta"]["error_type"] == "tool_input_validation"
+ assert result.additional_kwargs["tool_result_meta"]["error_code"] == "FILE_TOO_LARGE"
+ assert backend.read_calls == 0
+
+ @pytest.mark.asyncio
+ async def test_hook_allow_cannot_bypass_permission_deny_rule(self):
+ def handler(**kwargs):
+ raise AssertionError("handler should not run when permission denies")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def pre_tool_use(payload, request):
+ return {"permission": "allow"}
+
+ def can_use_tool(name, args, context, request):
+ return {"decision": "deny", "message": "settings deny"}
+
+ req.state.pre_tool_use = pre_tool_use
+ req.state.can_use_tool = can_use_tool
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "settings deny"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+
+ @pytest.mark.asyncio
+ async def test_parallel_pre_tool_use_hooks_keep_deny_precedence(self):
+ def handler(**kwargs):
+ raise AssertionError("handler should not run when a hook denies")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def allow_hook(payload, request):
+ await asyncio.sleep(0.01)
+ return {"permission": "allow", "message": "hook allow"}
+
+ async def deny_hook(payload, request):
+ await asyncio.sleep(0.01)
+ return {"decision": "deny", "message": "hook deny"}
+
+ req.state.pre_tool_use = [allow_hook, deny_hook]
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "hook deny"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+
+ @pytest.mark.asyncio
+ async def test_pre_tool_use_can_update_args_before_permission_and_handler(self):
+ seen = []
+
+ def handler(path):
+ seen.append(("handler", path))
+ return f"ok:{path}"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={
+ "name": "Write",
+ "parameters": {
+ "type": "object",
+ "required": ["path"],
+ "properties": {"path": {"type": "string"}},
+ },
+ },
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {"path": "raw"})
+ req.state = MagicMock()
+
+ def pre_tool_use(payload, request):
+ return {"args": {"path": "mutated"}}
+
+ def can_use_tool(name, args, context, request):
+ seen.append(("permission", args["path"]))
+ return {"decision": "allow"}
+
+ req.state.pre_tool_use = pre_tool_use
+ req.state.can_use_tool = can_use_tool
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == "ok:mutated"
+ assert seen == [("permission", "mutated"), ("handler", "mutated")]
+
+ @pytest.mark.asyncio
+ async def test_async_pre_tool_use_hooks_run_in_parallel(self):
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=lambda: "ok",
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def hook_one(payload, request):
+ await asyncio.sleep(0.05)
+ return None
+
+ async def hook_two(payload, request):
+ await asyncio.sleep(0.05)
+ return None
+
+ req.state.pre_tool_use = [hook_one, hook_two]
+
+ started = time.perf_counter()
+ result = await runner.awrap_tool_call(req, AsyncMock())
+ elapsed = time.perf_counter() - started
+
+ assert result.content == "ok"
+ assert elapsed < 0.09
+
+ @pytest.mark.asyncio
+ async def test_permission_checker_receives_permission_context_not_scheduler_flag(self):
+ seen = []
+
+ entry = ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={"name": "Read", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=lambda: "ok",
+ source="test",
+ is_read_only=True,
+ is_concurrency_safe=True,
+ is_destructive=True,
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Read", {})
+ req.state = MagicMock()
+
+ def can_use_tool(name, args, context, request):
+ seen.append((context.is_read_only, context.is_destructive, hasattr(context, "is_concurrency_safe")))
+ return {"decision": "allow"}
+
+ req.state.can_use_tool = can_use_tool
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == "ok"
+ assert seen == [(True, True, False)]
+
+ @pytest.mark.asyncio
+ async def test_async_permission_checker_is_awaited_before_handler(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ raise AssertionError("handler should not run when async permission denies")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "deny", "message": "async deny"}
+
+ req.state.can_use_tool = can_use_tool
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "async deny"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+ assert seen == ["checker"]
+
+ def test_sync_wrap_tool_call_awaits_async_permission_checker(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ raise AssertionError("handler should not run when async permission denies on sync path")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "deny", "message": "async deny sync-path"}
+
+ req.state.can_use_tool = can_use_tool
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "async deny sync-path"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+ assert seen == ["checker"]
+
+ @pytest.mark.asyncio
+ async def test_sync_wrap_tool_call_awaits_async_permission_checker_inside_running_loop(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ raise AssertionError("handler should not run when async permission denies on nested-loop sync path")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "deny", "message": "async deny nested-loop"}
+
+ req.state.can_use_tool = can_use_tool
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "async deny nested-loop"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+ assert seen == ["checker"]
+
+ def test_sync_wrap_tool_call_awaits_async_post_tool_use_hook(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def post_hook(result, request):
+ seen.append("post-start")
+ await asyncio.sleep(0)
+ seen.append("post-end")
+ return result
+
+ req.state.post_tool_use = post_hook
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ assert result.content == "plain success"
+ assert seen == ["handler", "post-start", "post-end"]
+
+ def test_sync_wrap_tool_call_awaits_async_pre_tool_use_hook(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def pre_hook(payload, request):
+ seen.append("pre-start")
+ await asyncio.sleep(0)
+ seen.append("pre-end")
+ return payload
+
+ req.state.pre_tool_use = pre_hook
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ assert result.content == "plain success"
+ assert seen == ["pre-start", "pre-end", "handler"]
+
+ def test_sync_wrap_tool_call_times_out_async_post_tool_use_hook(self):
+ events = []
+
+ def handler():
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+ req.state.hook_timeout_ms = 50
+
+ async def stuck_hook(result, request):
+ try:
+ await asyncio.Future()
+ except asyncio.CancelledError:
+ events.append("post-cancelled")
+ raise
+
+ req.state.post_tool_use = stuck_hook
+
+ started = time.perf_counter()
+ result = runner.wrap_tool_call(req, lambda _req: MagicMock())
+ elapsed = time.perf_counter() - started
+
+ assert result.content == "plain success"
+ assert elapsed < 0.2
+ assert events == ["post-cancelled"]
+
+ @pytest.mark.asyncio
+ async def test_sync_wrap_tool_call_awaits_async_post_tool_use_hook_inside_running_loop(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ return "plain success"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ async def post_hook(result, request):
+ seen.append("post-start")
+ await asyncio.sleep(0)
+ seen.append("post-end")
+ return result
+
+ req.state.post_tool_use = post_hook
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ assert result.content == "plain success"
+ assert seen == ["handler", "post-start", "post-end"]
+
+ @pytest.mark.asyncio
+ async def test_permission_request_hook_can_allow_without_creating_request(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ return "ok"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "ask", "message": "needs approval"}
+
+ def request_permission(*args, **kwargs):
+ raise AssertionError("request surface should not run when permission_request hook allows")
+
+ async def permission_request_hook(payload, request):
+ seen.append("permission-request-hook")
+ return {"decision": "allow"}
+
+ req.state.can_use_tool = can_use_tool
+ req.state.request_permission = request_permission
+ req.state.consume_permission_resolution = lambda *args, **kwargs: None
+ req.state.permission_request_hooks = permission_request_hook
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == "ok"
+ assert seen == ["checker", "permission-request-hook", "handler"]
+
+ def test_sync_wrap_tool_call_runs_permission_request_hook_before_prompt(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ return "ok"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "ask", "message": "needs approval"}
+
+ def request_permission(*args, **kwargs):
+ raise AssertionError("request surface should not run when permission_request hook denies")
+
+ async def permission_request_hook(payload, request):
+ seen.append("permission-request-hook")
+ return {"decision": "deny", "message": "hook blocked"}
+
+ req.state.can_use_tool = can_use_tool
+ req.state.request_permission = request_permission
+ req.state.consume_permission_resolution = lambda *args, **kwargs: None
+ req.state.permission_request_hooks = permission_request_hook
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "hook blocked"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+ assert seen == ["checker", "permission-request-hook"]
+
+ @pytest.mark.asyncio
+ async def test_sync_wrap_tool_call_runs_permission_request_hook_inside_running_loop(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ return "ok"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "ask", "message": "needs approval"}
+
+ def request_permission(*args, **kwargs):
+ raise AssertionError("request surface should not run when permission_request hook allows")
+
+ async def permission_request_hook(payload, request):
+ seen.append("permission-request-hook")
+ await asyncio.sleep(0)
+ return {"decision": "allow"}
+
+ req.state.can_use_tool = can_use_tool
+ req.state.request_permission = request_permission
+ req.state.consume_permission_resolution = lambda *args, **kwargs: None
+ req.state.permission_request_hooks = permission_request_hook
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ assert result.content == "ok"
+ assert seen == ["checker", "permission-request-hook", "handler"]
+
+ @pytest.mark.asyncio
+ async def test_ask_permission_returns_permission_request_when_request_surface_exists(self):
+ requests = {}
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=lambda: "ok",
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def can_use_tool(name, args, context, request):
+ return {"decision": "ask", "message": "needs approval"}
+
+ def request_permission(name, args, context, request, message):
+ requests["perm-1"] = {
+ "thread_id": "thread-a",
+ "tool_name": name,
+ "args": dict(args),
+ "message": message,
+ }
+ return {"request_id": "perm-1"}
+
+ req.state.can_use_tool = can_use_tool
+ req.state.request_permission = request_permission
+ req.state.consume_permission_resolution = lambda *args, **kwargs: None
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "needs approval"
+ assert meta["kind"] == "permission_request"
+ assert meta["decision"] == "ask"
+ assert meta["request_id"] == "perm-1"
+ assert requests["perm-1"]["message"] == "needs approval"
+
+ @pytest.mark.asyncio
+ async def test_ask_permission_fails_loud_when_request_surface_is_missing(self):
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=lambda: "ok",
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def can_use_tool(name, args, context, request):
+ return {
+ "decision": "ask",
+ "message": "Permission required by rule: Write. No interactive permission resolver is available for this run.",
+ }
+
+ req.state.can_use_tool = can_use_tool
+ req.state.request_permission = None
+ req.state.consume_permission_resolution = lambda *args, **kwargs: None
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "Permission required by rule: Write. No interactive permission resolver is available for this run."
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+
+ def test_sync_ask_permission_fails_loud_when_request_surface_is_missing(self):
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=lambda: "ok",
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def can_use_tool(name, args, context, request):
+ return {
+ "decision": "ask",
+ "message": "Permission required by rule: Write. No interactive permission resolver is available for this run.",
+ }
+
+ req.state.can_use_tool = can_use_tool
+ req.state.request_permission = None
+ req.state.consume_permission_resolution = lambda *args, **kwargs: None
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "Permission required by rule: Write. No interactive permission resolver is available for this run."
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+
+ @pytest.mark.asyncio
+ async def test_consumed_permission_resolution_allows_single_retry_without_reprompt(self):
+ seen = []
+ resolution = {"decision": "allow", "message": "approved"}
+
+ def handler():
+ seen.append("handler")
+ return "ok"
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def consume_permission_resolution(name, args, context, request):
+ nonlocal resolution
+ current = resolution
+ resolution = None
+ return current
+
+ def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "ask", "message": "needs approval"}
+
+ req.state.consume_permission_resolution = consume_permission_resolution
+ req.state.can_use_tool = can_use_tool
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == "ok"
+ assert seen == ["checker", "handler"]
+
+ @pytest.mark.asyncio
+ async def test_stale_resolved_allow_does_not_override_current_async_deny(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ raise AssertionError("handler should not run when current deny overrides stale approval")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def consume_permission_resolution(name, args, context, request):
+ seen.append("resolution")
+ return {"decision": "allow", "message": "approved earlier"}
+
+ def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "deny", "message": "deny now"}
+
+ req.state.consume_permission_resolution = consume_permission_resolution
+ req.state.can_use_tool = can_use_tool
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "deny now"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+ assert seen == ["checker"]
+
+ def test_stale_resolved_allow_does_not_override_current_sync_deny(self):
+ seen = []
+
+ def handler():
+ seen.append("handler")
+ raise AssertionError("handler should not run when current deny overrides stale approval")
+
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=handler,
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ def consume_permission_resolution(name, args, context, request):
+ seen.append("resolution")
+ return {"decision": "allow", "message": "approved earlier"}
+
+ def can_use_tool(name, args, context, request):
+ seen.append("checker")
+ return {"decision": "deny", "message": "deny now"}
+
+ req.state.consume_permission_resolution = consume_permission_resolution
+ req.state.can_use_tool = can_use_tool
+
+ result = runner.wrap_tool_call(req, lambda _req: None)
+
+ meta = result.additional_kwargs["tool_result_meta"]
+ assert result.content == "deny now"
+ assert meta["kind"] == "permission_denied"
+ assert meta["decision"] == "deny"
+ assert seen == ["checker"]
+
+ @pytest.mark.asyncio
+ async def test_destructive_metadata_is_advisory_not_runtime_deny(self):
+ entry = ToolEntry(
+ name="Write",
+ mode=ToolMode.INLINE,
+ schema={"name": "Write", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=lambda: "ok",
+ source="test",
+ is_destructive=True,
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Write", {})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == "ok"
+
+ @pytest.mark.asyncio
+ async def test_runner_injects_tool_context_into_handler_when_requested(self):
+ entry = ToolEntry(
+ name="Agent",
+ mode=ToolMode.INLINE,
+ schema={"name": "Agent", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=lambda tool_context: f"context:{tool_context.turn_id}",
+ source="test",
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("Agent", {})
+ app_state = AppState()
+ req.state = ToolUseContext(
+ bootstrap=BootstrapConfig(workspace_root="/tmp/workspace", model_name="gpt-test"),
+ get_app_state=app_state.get_state,
+ set_app_state=app_state.set_state,
+ )
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert result.content == f"context:{req.state.turn_id}"
+
+ @pytest.mark.asyncio
+ async def test_runner_maps_context_schema_fields_into_handler_kwargs(self):
+ seen = {}
+
+ def needs_ctx(*, boot):
+ seen["boot"] = boot
+ return f"boot:{boot}"
+
+ entry = ToolEntry(
+ name="NeedsCtx",
+ mode=ToolMode.INLINE,
+ schema={"name": "NeedsCtx", "parameters": {"type": "object", "required": [], "properties": {}}},
+ handler=needs_ctx,
+ source="test",
+ context_schema={"boot": "bootstrap.model_name"},
+ )
+ runner = _make_runner([entry])
+ req = _make_tool_call_request("NeedsCtx", {})
+ app_state = AppState()
+ req.state = ToolUseContext(
+ bootstrap=BootstrapConfig(workspace_root="/tmp/workspace", model_name="MODEL_X"),
+ get_app_state=app_state.get_state,
+ set_app_state=app_state.set_state,
+ )
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert seen == {"boot": "MODEL_X"}
+ assert result.content == "boot:MODEL_X"
+
+
+class TestToolRunnerInlineInjection:
+ """P1: ToolRunner injects inline schemas into model call."""
+
+ def test_inline_schemas_injected(self):
+ entry = ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={"name": "Read", "description": "read file"},
+ handler=lambda: "ok",
+ source="test",
+ )
+ runner = _make_runner([entry])
+
+ # Build a mock ModelRequest
+ request = MagicMock()
+ request.tools = []
+
+ captured = []
+
+ def handler(req):
+ captured.append(req)
+ return MagicMock()
+
+ request.override.return_value = request
+ runner.wrap_model_call(request, handler)
+
+ # Should have called override with tools containing Read
+ assert request.override.called
+ call_kwargs = request.override.call_args
+ _tools_arg = call_kwargs[1].get("tools") or (call_kwargs[0][0] if call_kwargs[0] else None)
+ # override was called — inline tools were injected
+
+ def test_deferred_schemas_not_injected(self):
+ deferred = ToolEntry(
+ name="TaskCreate",
+ mode=ToolMode.DEFERRED,
+ schema={"name": "TaskCreate", "description": "create task"},
+ handler=lambda: "ok",
+ source="test",
+ )
+ runner = _make_runner([deferred])
+ schemas = runner._registry.get_inline_schemas()
+ assert all(s["name"] != "TaskCreate" for s in schemas)
+
+
+# ---------------------------------------------------------------------------
+# P1: tool_modes from config honored
+# ---------------------------------------------------------------------------
+
+
+class TestToolModeFromConfig:
+ """Verify tool_modes config is applied during service init."""
+
+ def test_task_service_registers_deferred(self, tmp_path):
+ reg = ToolRegistry()
+ from core.tools.task.service import TaskService
+
+ _svc = TaskService(registry=reg, db_path=tmp_path / "test.db")
+ # TaskCreate/TaskUpdate/TaskList/TaskGet should be DEFERRED
+ for tool_name in ["TaskCreate", "TaskGet", "TaskList", "TaskUpdate"]:
+ entry = reg.get(tool_name)
+ assert entry is not None, f"{tool_name} not registered"
+ assert entry.mode == ToolMode.DEFERRED, f"{tool_name} should be DEFERRED, got {entry.mode}"
+
+ def test_search_service_registers_inline(self, tmp_path):
+ reg = ToolRegistry()
+ from core.tools.search.service import SearchService
+
+ _svc = SearchService(registry=reg, workspace_root=tmp_path)
+ for tool_name in ["Grep", "Glob"]:
+ entry = reg.get(tool_name)
+ assert entry is not None, f"{tool_name} not registered"
+ assert entry.mode == ToolMode.INLINE, f"{tool_name} should be INLINE, got {entry.mode}"
+
+ def test_task_service_read_only_queries_are_concurrency_safe(self, tmp_path):
+ reg = ToolRegistry()
+ from core.tools.task.service import TaskService
+
+ _svc = TaskService(registry=reg, db_path=tmp_path / "test.db")
+
+ for tool_name in ["TaskGet", "TaskList"]:
+ entry = reg.get(tool_name)
+ assert entry is not None, f"{tool_name} not registered"
+ assert entry.is_read_only is True
+ assert entry.is_concurrency_safe is True
+
+
+class TestToolSearchService:
+ def test_tool_search_schema_says_exact_lookup_is_for_deferred_tools(self):
+ reg = ToolRegistry()
+ ToolSearchService(reg)
+
+ schema = reg.get("tool_search").get_schema()
+
+ assert "deferred" in schema["description"].lower()
+ assert "deferred" in schema["parameters"]["properties"]["query"]["description"].lower()
+
+ def _make_ctx(self) -> ToolUseContext:
+ app = AppState()
+ return ToolUseContext(
+ bootstrap=BootstrapConfig(workspace_root="/tmp", model_name="test-model"),
+ get_app_state=lambda: app,
+ set_app_state=lambda fn: None,
+ )
+
+ def test_tool_search_keyword_results_are_capped_to_five(self):
+ reg = ToolRegistry()
+ for index in range(7):
+ reg.register(
+ ToolEntry(
+ name=f"Deferred{index}",
+ mode=ToolMode.DEFERRED,
+ schema={"name": f"Deferred{index}", "description": "alpha helper"},
+ handler=lambda: "ok",
+ source="test",
+ )
+ )
+ ToolSearchService(reg)
+ runner = _make_runner(reg.list_all())
+ req = ToolCallRequest(
+ tool_call={"name": "tool_search", "args": {"query": "alpha"}, "id": "tc-search"},
+ state=self._make_ctx(),
+ )
+
+ result = runner.wrap_tool_call(req, lambda r: MagicMock())
+
+ payload = json.loads(result.content)
+ assert len(payload) == 5
+
+ def test_tool_search_excludes_inline_tools(self):
+ reg = ToolRegistry()
+ reg.register(
+ ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={"name": "Read", "description": "read file content"},
+ handler=lambda: "read",
+ source="test",
+ )
+ )
+ reg.register(
+ ToolEntry(
+ name="TaskCreate",
+ mode=ToolMode.DEFERRED,
+ schema={"name": "TaskCreate", "description": "create task"},
+ handler=lambda: "task",
+ source="test",
+ )
+ )
+ ToolSearchService(reg)
+ ctx = self._make_ctx()
+ runner = _make_runner(reg.list_all())
+ req = ToolCallRequest(
+ tool_call={"name": "tool_search", "args": {"query": "read"}, "id": "tc-search"},
+ state=ctx,
+ )
+
+ result = runner.wrap_tool_call(req, lambda r: MagicMock())
+
+ assert json.loads(result.content) == []
+ assert ctx.discovered_tool_names == set()
+
+ def test_tool_search_exact_select_fails_loudly_for_inline_tools(self):
+ reg = ToolRegistry()
+ reg.register(
+ ToolEntry(
+ name="Read",
+ mode=ToolMode.INLINE,
+ schema={"name": "Read", "description": "read file content"},
+ handler=lambda: "read",
+ source="test",
+ )
+ )
+ reg.register(
+ ToolEntry(
+ name="TaskCreate",
+ mode=ToolMode.DEFERRED,
+ schema={"name": "TaskCreate", "description": "create task"},
+ handler=lambda: "task",
+ source="test",
+ )
+ )
+ ToolSearchService(reg)
+ runner = _make_runner(reg.list_all())
+ req = ToolCallRequest(
+ tool_call={"name": "tool_search", "args": {"query": "select:Read,TaskCreate"}, "id": "tc-search"},
+ state=self._make_ctx(),
+ )
+
+ result = runner.wrap_tool_call(req, lambda r: MagicMock())
+
+ assert "" in result.content
+ assert "Read" in result.content
+ assert "inline" in result.content.lower()
+ assert "TaskCreate" not in result.content
+
+
+class TestWebToolRegistration:
+ def test_web_tools_are_deferred_not_inline(self):
+ reg = ToolRegistry()
+ WebService(registry=reg)
+
+ assert reg.get("WebSearch").mode == ToolMode.DEFERRED
+ assert reg.get("WebFetch").mode == ToolMode.DEFERRED
+ assert [schema["name"] for schema in reg.get_inline_schemas()] == []
+
+ @pytest.mark.asyncio
+ async def test_web_search_schema_uses_allowed_and_blocked_domains(self):
+ reg = ToolRegistry()
+ service = WebService(registry=reg)
+ seen: dict[str, object] = {}
+
+ class _FakeSearcher:
+ async def search(self, *, query, max_results, include_domains=None, exclude_domains=None):
+ seen["query"] = query
+ seen["max_results"] = max_results
+ seen["include_domains"] = include_domains
+ seen["exclude_domains"] = exclude_domains
+ return SimpleNamespace(error=None, format_output=lambda: "fake results")
+
+ service._searchers = [("fake", _FakeSearcher())]
+
+ schema = reg.get("WebSearch").schema
+ props = schema["parameters"]["properties"]
+ assert "allowed_domains" in props
+ assert "blocked_domains" in props
+ assert "include_domains" not in props
+ assert "exclude_domains" not in props
+
+ result = await service._web_search(
+ query="docs",
+ allowed_domains=["example.com"],
+ blocked_domains=["bad.com"],
+ )
+
+ assert result == "fake results"
+ assert seen["include_domains"] == ["example.com"]
+ assert seen["exclude_domains"] == ["bad.com"]
+
+ def test_web_search_schema_carries_query_and_max_result_constraints(self):
+ reg = ToolRegistry()
+ WebService(registry=reg)
+
+ schema = reg.get("WebSearch").get_schema()
+ props = schema["parameters"]["properties"]
+
+ assert props["query"]["minLength"] == 1
+ assert props["max_results"]["minimum"] == 1
+ assert props["max_results"]["maximum"] == 10
+
+ @pytest.mark.asyncio
+ async def test_web_search_rejects_out_of_range_max_results_at_validation_layer(self):
+ reg = ToolRegistry()
+ WebService(registry=reg)
+ runner = _make_runner(reg.list_all())
+ req = _make_tool_call_request("WebSearch", {"query": "docs", "max_results": 11})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "InputValidationError" in result.content
+ assert "max_results" in result.content
+ assert "at most 10" in result.content
+
+ def test_web_fetch_schema_carries_non_empty_url_and_prompt_constraints(self):
+ reg = ToolRegistry()
+ WebService(registry=reg)
+
+ schema = reg.get("WebFetch").get_schema()
+ props = schema["parameters"]["properties"]
+
+ assert props["url"]["minLength"] == 1
+ assert props["prompt"]["minLength"] == 1
+
+ def test_list_dir_schema_uses_path(self, tmp_path):
+ reg = ToolRegistry()
+ FileSystemService(
+ registry=reg,
+ workspace_root=tmp_path,
+ )
+
+ schema = reg.get("list_dir").schema
+ props = schema["parameters"]["properties"]
+ assert "path" in props
+ assert "directory_path" not in props
+ assert schema["parameters"]["required"] == ["path"]
+
+ def test_bash_schema_carries_command_and_timeout_constraints(self, tmp_path):
+ reg = ToolRegistry()
+ CommandService(
+ registry=reg,
+ workspace_root=tmp_path,
+ )
+
+ schema = reg.get("Bash").get_schema()
+ props = schema["parameters"]["properties"]
+
+ assert props["command"]["minLength"] == 1
+ assert props["timeout"]["minimum"] == 1
+ assert props["timeout"]["maximum"] == 600000
+
+ @pytest.mark.asyncio
+ async def test_bash_rejects_out_of_range_timeout_at_validation_layer(self, tmp_path):
+ reg = ToolRegistry()
+ CommandService(
+ registry=reg,
+ workspace_root=tmp_path,
+ )
+ runner = _make_runner(reg.list_all())
+ req = _make_tool_call_request("Bash", {"command": "echo hi", "timeout": 600001})
+ req.state = MagicMock()
+
+ result = await runner.awrap_tool_call(req, AsyncMock())
+
+ assert "InputValidationError" in result.content
+ assert "timeout" in result.content
+ assert "at most 600000" in result.content
+
+ def test_can_auto_approve_only_for_read_only_non_destructive_tools(self):
+ assert can_auto_approve(ToolPermissionContext(is_read_only=True, is_destructive=False)) is True
+ assert can_auto_approve(ToolPermissionContext(is_read_only=False, is_destructive=False)) is False
+ assert can_auto_approve(ToolPermissionContext(is_read_only=True, is_destructive=True)) is False
diff --git a/tests/test_filesystem_extra_paths.py b/tests/Unit/filesystem/test_filesystem_extra_paths.py
similarity index 100%
rename from tests/test_filesystem_extra_paths.py
rename to tests/Unit/filesystem/test_filesystem_extra_paths.py
diff --git a/tests/Unit/filesystem/test_filesystem_service.py b/tests/Unit/filesystem/test_filesystem_service.py
new file mode 100644
index 000000000..a24a1455c
--- /dev/null
+++ b/tests/Unit/filesystem/test_filesystem_service.py
@@ -0,0 +1,395 @@
+from __future__ import annotations
+
+import threading
+import time
+from pathlib import Path, PurePosixPath
+
+from core.runtime.registry import ToolRegistry
+from core.tools.filesystem.service import FileSystemService, _ReadFileStateCache
+from sandbox.interfaces.filesystem import DirListResult, FileReadResult, FileSystemBackend, FileWriteResult
+
+
+def _make_service(
+ workspace: Path,
+ *,
+ max_read_cache_entries: int = 100,
+ max_edit_file_size: int | None = None,
+) -> FileSystemService:
+ return FileSystemService(
+ registry=ToolRegistry(),
+ workspace_root=workspace,
+ max_read_cache_entries=max_read_cache_entries,
+ max_edit_file_size=max_edit_file_size,
+ )
+
+
+def test_edit_rejects_if_last_read_was_partial_view(tmp_path: Path):
+ service = _make_service(tmp_path)
+ target = tmp_path / "sample.txt"
+ target.write_text("alpha\nbeta\ngamma\n", encoding="utf-8")
+
+ read_result = service._read_file(str(target), offset=2, limit=1)
+ assert " FileReadResult:
+ before = self._content
+ self._content = "alpha\nEXTERNAL\n"
+ self._mtime = 2.0
+ return FileReadResult(content=before, size=len(before))
+
+ def write_file(self, path: str, content: str) -> FileWriteResult:
+ self.writes.append(content)
+ self._content = content
+ return FileWriteResult(success=True)
+
+ def file_exists(self, path: str) -> bool:
+ return True
+
+ def file_mtime(self, path: str) -> float | None:
+ return self._mtime
+
+ def file_size(self, path: str) -> int | None:
+ return len(self._content.encode("utf-8"))
+
+ def is_dir(self, path: str) -> bool:
+ return False
+
+ def list_dir(self, path: str) -> DirListResult:
+ return DirListResult(entries=[])
+
+ backend = RacingBackend()
+ service = FileSystemService(
+ registry=ToolRegistry(),
+ workspace_root=tmp_path,
+ backend=backend,
+ )
+ target = (tmp_path / "race.txt").resolve()
+ service._read_files.set(
+ target,
+ state=service._read_files.make_state(timestamp=1.0, is_partial=False),
+ )
+
+ edit_result = service._edit_file(
+ str(target),
+ old_string="beta",
+ new_string="BETA",
+ )
+
+ assert "modified since last read" in edit_result
+ assert backend.writes == []
+ assert backend._content == "alpha\nEXTERNAL\n"
+
+
+def test_concurrent_edits_do_not_both_commit_from_same_stale_read(tmp_path: Path):
+ class ConcurrentBackend(FileSystemBackend):
+ is_remote = False
+
+ def __init__(self):
+ self._mtime = 1.0
+ self._content = "alpha\nbeta\n"
+ self._write_lock = threading.Lock()
+ self.writes: list[str] = []
+
+ def read_file(self, path: str) -> FileReadResult:
+ return FileReadResult(content=self._content, size=len(self._content))
+
+ def write_file(self, path: str, content: str) -> FileWriteResult:
+ time.sleep(0.05)
+ with self._write_lock:
+ self.writes.append(content)
+ self._content = content
+ self._mtime += 1.0
+ return FileWriteResult(success=True)
+
+ def file_exists(self, path: str) -> bool:
+ return True
+
+ def file_mtime(self, path: str) -> float | None:
+ return self._mtime
+
+ def file_size(self, path: str) -> int | None:
+ return len(self._content.encode("utf-8"))
+
+ def is_dir(self, path: str) -> bool:
+ return False
+
+ def list_dir(self, path: str) -> DirListResult:
+ return DirListResult(entries=[])
+
+ backend = ConcurrentBackend()
+ service = FileSystemService(
+ registry=ToolRegistry(),
+ workspace_root=tmp_path,
+ backend=backend,
+ )
+ target = (tmp_path / "race.txt").resolve()
+ service._read_files.set(
+ target,
+ state=service._read_files.make_state(timestamp=1.0, is_partial=False),
+ )
+
+ results: list[str] = []
+
+ def run_edit(new_string: str) -> None:
+ results.append(
+ service._edit_file(
+ str(target),
+ old_string="beta",
+ new_string=new_string,
+ )
+ )
+
+ t1 = threading.Thread(target=run_edit, args=("BETA-ONE",))
+ t2 = threading.Thread(target=run_edit, args=("BETA-TWO",))
+ t1.start()
+ t2.start()
+ t1.join()
+ t2.join()
+
+ success_count = sum("File edited" in result for result in results)
+ failure_count = sum(("modified since last read" in result) or ("String not found in file" in result) for result in results)
+
+ assert success_count == 1
+ assert failure_count == 1
+ assert len(backend.writes) == 1
+
+
+def test_remote_edit_does_not_trust_false_negative_exists_probe(tmp_path: Path):
+ class FlakyRemoteBackend(FileSystemBackend):
+ is_remote = True
+
+ def __init__(self):
+ self._content = "result = 3\n"
+ self.writes: list[str] = []
+
+ def read_file(self, path: str) -> FileReadResult:
+ return FileReadResult(content=self._content, size=len(self._content))
+
+ def write_file(self, path: str, content: str) -> FileWriteResult:
+ self.writes.append(content)
+ self._content = content
+ return FileWriteResult(success=True)
+
+ def file_exists(self, path: str) -> bool:
+ return False
+
+ def file_mtime(self, path: str) -> float | None:
+ return None
+
+ def file_size(self, path: str) -> int | None:
+ return len(self._content.encode("utf-8"))
+
+ def is_dir(self, path: str) -> bool:
+ return False
+
+ def list_dir(self, path: str) -> DirListResult:
+ return DirListResult(entries=[])
+
+ backend = FlakyRemoteBackend()
+ service = FileSystemService(
+ registry=ToolRegistry(),
+ workspace_root=Path("/home/daytona"),
+ backend=backend,
+ )
+ target = PurePosixPath("/home/daytona/interleave.py")
+ service._read_files.set(
+ target,
+ state=service._read_files.make_state(timestamp=None, is_partial=False),
+ )
+
+ edit_result = service._edit_file(
+ str(target),
+ old_string="result = 3",
+ new_string="result = 5",
+ )
+
+ assert "File edited" in edit_result
+ assert backend.writes == ["result = 5\n"]
diff --git a/tests/test_read_file_limits.py b/tests/Unit/filesystem/test_read_file_limits.py
similarity index 100%
rename from tests/test_read_file_limits.py
rename to tests/Unit/filesystem/test_read_file_limits.py
diff --git a/tests/test_monitor_resource_overview_cache.py b/tests/Unit/monitor/test_monitor_resource_overview_cache.py
similarity index 55%
rename from tests/test_monitor_resource_overview_cache.py
rename to tests/Unit/monitor/test_monitor_resource_overview_cache.py
index d0426c967..2f0440fb6 100644
--- a/tests/test_monitor_resource_overview_cache.py
+++ b/tests/Unit/monitor/test_monitor_resource_overview_cache.py
@@ -53,3 +53,50 @@ def _raise():
assert degraded["providers"][0]["id"] == "docker"
assert degraded["summary"]["refresh_status"] == "error"
assert degraded["summary"]["refresh_error"] == "probe failed"
+
+
+def test_resource_overview_cache_refreshes_when_live_session_counts_drift(monkeypatch):
+ cache.clear_resource_overview_cache()
+
+ stale_payload = {
+ "summary": {
+ "snapshot_at": "2026-03-03T00:00:00Z",
+ "total_providers": 1,
+ "active_providers": 0,
+ "unavailable_providers": 0,
+ "running_sessions": 0,
+ },
+ "providers": [
+ {
+ "id": "local",
+ "sessions": [],
+ "telemetry": {"running": {"used": 0}},
+ }
+ ],
+ }
+ fresh_payload = {
+ "summary": {
+ "snapshot_at": "2026-03-03T00:01:00Z",
+ "total_providers": 1,
+ "active_providers": 1,
+ "unavailable_providers": 0,
+ "running_sessions": 1,
+ },
+ "providers": [
+ {
+ "id": "local",
+ "sessions": [{"id": "lease-1:m_thread"}],
+ "telemetry": {"running": {"used": 1}},
+ }
+ ],
+ }
+
+ calls = iter([stale_payload, fresh_payload])
+ monkeypatch.setattr(cache.resource_service, "list_resource_providers", lambda: next(calls))
+ monkeypatch.setattr(cache.resource_service, "visible_resource_session_stats", lambda: {"local": {"sessions": 1, "running": 1}})
+
+ cache.refresh_resource_overview_sync()
+ payload = cache.get_resource_overview_snapshot()
+
+ assert payload["providers"][0]["telemetry"]["running"]["used"] == 1
+ assert len(payload["providers"][0]["sessions"]) == 1
diff --git a/tests/test_monitor_resource_probe.py b/tests/Unit/monitor/test_monitor_resource_probe.py
similarity index 100%
rename from tests/test_monitor_resource_probe.py
rename to tests/Unit/monitor/test_monitor_resource_probe.py
diff --git a/tests/Unit/monitor/test_sqlite_sandbox_monitor_repo.py b/tests/Unit/monitor/test_sqlite_sandbox_monitor_repo.py
new file mode 100644
index 000000000..d8e7a217c
--- /dev/null
+++ b/tests/Unit/monitor/test_sqlite_sandbox_monitor_repo.py
@@ -0,0 +1,97 @@
+import sqlite3
+
+from storage.providers.sqlite.sandbox_monitor_repo import SQLiteSandboxMonitorRepo
+
+
+def _bootstrap_monitor_db(db_path):
+ conn = sqlite3.connect(db_path)
+ try:
+ conn.executescript(
+ """
+ CREATE TABLE sandbox_leases (
+ lease_id TEXT PRIMARY KEY,
+ provider_name TEXT,
+ desired_state TEXT,
+ observed_state TEXT,
+ current_instance_id TEXT,
+ created_at TEXT,
+ updated_at TEXT
+ );
+
+ CREATE TABLE abstract_terminals (
+ terminal_id TEXT PRIMARY KEY,
+ lease_id TEXT,
+ thread_id TEXT,
+ cwd TEXT,
+ created_at TEXT
+ );
+
+ CREATE TABLE chat_sessions (
+ chat_session_id TEXT PRIMARY KEY,
+ thread_id TEXT,
+ lease_id TEXT,
+ status TEXT,
+ started_at TEXT
+ );
+ """
+ )
+ conn.commit()
+ finally:
+ conn.close()
+
+
+def test_list_sessions_with_leases_keeps_raw_newest_terminal_truth(tmp_path):
+ db_path = tmp_path / "sandbox.db"
+ _bootstrap_monitor_db(db_path)
+
+ conn = sqlite3.connect(db_path)
+ try:
+ conn.execute(
+ """
+ INSERT INTO sandbox_leases (
+ lease_id, provider_name, desired_state, observed_state, current_instance_id, created_at, updated_at
+ ) VALUES (?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ "lease-1",
+ "daytona_selfhost",
+ "paused",
+ "paused",
+ "instance-1",
+ "2026-04-05T13:00:00",
+ "2026-04-05T23:59:00",
+ ),
+ )
+ conn.executemany(
+ """
+ INSERT INTO abstract_terminals (terminal_id, lease_id, thread_id, cwd, created_at)
+ VALUES (?, ?, ?, ?, ?)
+ """,
+ [
+ ("term-parent", "lease-1", "thread-parent", "/home/daytona/files/app", "2026-04-05T13:35:08"),
+ ("term-subagent", "lease-1", "subagent-deadbeef", "/home/daytona/files/app", "2026-04-05T23:51:40"),
+ ],
+ )
+ conn.executemany(
+ """
+ INSERT INTO chat_sessions (chat_session_id, thread_id, lease_id, status, started_at)
+ VALUES (?, ?, ?, ?, ?)
+ """,
+ [
+ ("sess-parent", "thread-parent", "lease-1", "closed", "2026-04-05T23:24:06"),
+ ("sess-subagent", "subagent-deadbeef", "lease-1", "closed", "2026-04-05T23:51:42"),
+ ],
+ )
+ conn.commit()
+ finally:
+ conn.close()
+
+ repo = SQLiteSandboxMonitorRepo(db_path=db_path)
+ try:
+ rows = repo.list_sessions_with_leases()
+ finally:
+ repo.close()
+
+ assert len(rows) == 2
+ assert {row["thread_id"] for row in rows} == {"thread-parent", "subagent-deadbeef"}
+ assert all(row["lease_id"] == "lease-1" for row in rows)
diff --git a/tests/test_agentbay_capability_override.py b/tests/Unit/platform/test_agentbay_capability_override.py
similarity index 61%
rename from tests/test_agentbay_capability_override.py
rename to tests/Unit/platform/test_agentbay_capability_override.py
index f54d6ccd7..ed0d08b23 100644
--- a/tests/test_agentbay_capability_override.py
+++ b/tests/Unit/platform/test_agentbay_capability_override.py
@@ -6,13 +6,35 @@
def _install_fake_agentbay_module(monkeypatch) -> None:
fake_mod = types.ModuleType("agentbay")
+ fake_api_mod = types.ModuleType("agentbay.api")
+ fake_api_models_mod = types.ModuleType("agentbay.api.models")
class FakeAgentBay:
def __init__(self, api_key: str):
self.api_key = api_key
+ class FakeCreateSessionParams:
+ def __init__(self):
+ self.image_id = None
+ self.context_syncs = None
+
+ class FakeContextSync:
+ @staticmethod
+ def new(context_id: str, path: str):
+ return {"context_id": context_id, "path": path}
+
+ class FakeGetSessionRequest:
+ def __init__(self, authorization: str, session_id: str):
+ self.authorization = authorization
+ self.session_id = session_id
+
fake_mod.AgentBay = FakeAgentBay
+ fake_mod.CreateSessionParams = FakeCreateSessionParams
+ fake_mod.ContextSync = FakeContextSync
+ fake_api_models_mod.GetSessionRequest = FakeGetSessionRequest
monkeypatch.setitem(sys.modules, "agentbay", fake_mod)
+ monkeypatch.setitem(sys.modules, "agentbay.api", fake_api_mod)
+ monkeypatch.setitem(sys.modules, "agentbay.api.models", fake_api_models_mod)
def test_agentbay_capability_default_from_class(monkeypatch):
@@ -55,7 +77,12 @@ def screenshot(self):
return _ScreenshotResult()
class _FakeSession:
- computer = _FakeComputer()
+ def __init__(self) -> None:
+ self.session_id = "sess-1"
+ self.token = "tok"
+ self.link_url = "https://link"
+ self.mcpTools = [object()]
+ self.computer = _FakeComputer()
provider._sessions["sess-1"] = _FakeSession()
screenshot = provider.screenshot("sess-1")
diff --git a/tests/test_cron_api.py b/tests/Unit/platform/test_cron_api.py
similarity index 100%
rename from tests/test_cron_api.py
rename to tests/Unit/platform/test_cron_api.py
diff --git a/tests/test_cron_job_service.py b/tests/Unit/platform/test_cron_job_service.py
similarity index 100%
rename from tests/test_cron_job_service.py
rename to tests/Unit/platform/test_cron_job_service.py
diff --git a/tests/test_cron_service.py b/tests/Unit/platform/test_cron_service.py
similarity index 100%
rename from tests/test_cron_service.py
rename to tests/Unit/platform/test_cron_service.py
diff --git a/tests/Unit/platform/test_cron_tool_service.py b/tests/Unit/platform/test_cron_tool_service.py
new file mode 100644
index 000000000..69f546450
--- /dev/null
+++ b/tests/Unit/platform/test_cron_tool_service.py
@@ -0,0 +1,87 @@
+"""Tests for CronToolService — agent-callable cron CRUD surface."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import cast
+
+from core.runtime.registry import ToolRegistry
+from core.tools.cron.service import CronToolService
+
+
+def _redirect_cron_repo(monkeypatch, tmp_path: Path) -> None:
+ from storage.providers.sqlite.cron_job_repo import SQLiteCronJobRepo
+
+ db_path = tmp_path / "cron-tools.db"
+ monkeypatch.setattr(
+ "backend.web.services.cron_job_service.make_cron_job_repo",
+ lambda: SQLiteCronJobRepo(db_path=db_path),
+ )
+
+
+def test_cron_tool_registry_exposes_canonical_surface(monkeypatch, tmp_path: Path) -> None:
+ _redirect_cron_repo(monkeypatch, tmp_path)
+ registry = ToolRegistry()
+
+ CronToolService(registry)
+
+ for tool_name in ("CronCreate", "CronDelete", "CronList"):
+ assert registry.get(tool_name) is not None
+
+
+def test_cron_create_list_delete_roundtrip(monkeypatch, tmp_path: Path) -> None:
+ _redirect_cron_repo(monkeypatch, tmp_path)
+ registry = ToolRegistry()
+
+ CronToolService(registry)
+
+ create = registry.get("CronCreate")
+ list_jobs = registry.get("CronList")
+ delete = registry.get("CronDelete")
+
+ assert create is not None
+ assert list_jobs is not None
+ assert delete is not None
+
+ created_raw = create.handler(
+ name="nightly backup",
+ cron_expression="0 2 * * *",
+ description="backup prod",
+ task_template='{"title":"backup"}',
+ enabled=True,
+ )
+ created = json.loads(cast(str, created_raw))
+ job = created["item"]
+ assert job["name"] == "nightly backup"
+ assert job["cron_expression"] == "0 2 * * *"
+
+ listed = json.loads(cast(str, list_jobs.handler()))
+ assert listed["total"] == 1
+ assert listed["items"][0]["id"] == job["id"]
+
+ deleted = json.loads(cast(str, delete.handler(job_id=job["id"])))
+ assert deleted == {"ok": True, "id": job["id"]}
+
+ listed_after = json.loads(cast(str, list_jobs.handler()))
+ assert listed_after == {"items": [], "total": 0}
+
+
+def test_cron_create_requires_valid_json_template(monkeypatch, tmp_path: Path) -> None:
+ _redirect_cron_repo(monkeypatch, tmp_path)
+ registry = ToolRegistry()
+
+ CronToolService(registry)
+ create = registry.get("CronCreate")
+ assert create is not None
+
+ try:
+ create.handler(
+ name="broken",
+ cron_expression="0 2 * * *",
+ task_template="{not json}",
+ )
+ except ValueError as exc:
+ assert "task_template must be valid JSON" in str(exc)
+ else:
+ raise AssertionError("CronCreate should fail loudly on invalid JSON")
diff --git a/tests/Unit/platform/test_lsp_service.py b/tests/Unit/platform/test_lsp_service.py
new file mode 100644
index 000000000..8e851850e
--- /dev/null
+++ b/tests/Unit/platform/test_lsp_service.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import json
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from core.runtime.registry import ToolRegistry
+from core.tools.lsp.service import LSPService
+
+
+class _FakeSession:
+ def __init__(self):
+ self.calls: list[tuple[str, str, int, int]] = []
+
+ async def request_definition(self, rel_path: str, line: int, character: int):
+ self.calls.append(("definition", rel_path, line, character))
+ return [
+ {
+ "absolutePath": "/tmp/example.py",
+ "range": {"start": {"line": line, "character": character}},
+ }
+ ]
+
+
+class _FakePyright:
+ def __init__(self):
+ self.calls: list[tuple[str, str, int, int]] = []
+
+ async def request_implementation(self, rel_path: str, line: int, character: int):
+ self.calls.append(("implementation", rel_path, line, character))
+ return [
+ {
+ "absolutePath": "/tmp/example.py",
+ "range": {"start": {"line": line, "character": character}},
+ }
+ ]
+
+
+def test_lsp_schema_uses_one_based_character_positions(tmp_path):
+ reg = ToolRegistry()
+ LSPService(registry=reg, workspace_root=tmp_path)
+
+ schema = reg.get("LSP").get_schema()
+ props = schema["parameters"]["properties"]
+
+ assert "character" in props
+ assert "column" not in props
+ assert "1-based" in props["line"]["description"]
+ assert "1-based" in props["character"]["description"]
+
+
+@pytest.mark.asyncio
+async def test_lsp_handle_converts_one_based_positions_to_zero_based_for_definition(tmp_path):
+ reg = ToolRegistry()
+ service = LSPService(registry=reg, workspace_root=tmp_path)
+ fake = _FakeSession()
+ service._get_session = AsyncMock(return_value=fake)
+
+ file_path = tmp_path / "example.py"
+ file_path.write_text("x = 1\n", encoding="utf-8")
+
+ result = await service._handle(
+ operation="goToDefinition",
+ file_path=str(file_path),
+ line=5,
+ character=3,
+ )
+
+ assert fake.calls == [("definition", "example.py", 4, 2)]
+ payload = json.loads(result)
+ assert payload[0]["line"] == 4
+ assert payload[0]["column"] == 2
+
+
+@pytest.mark.asyncio
+async def test_lsp_handle_offloads_gitignored_filtering_from_event_loop(tmp_path, monkeypatch):
+ reg = ToolRegistry()
+ service = LSPService(registry=reg, workspace_root=tmp_path)
+ fake = _FakeSession()
+ service._get_session = AsyncMock(return_value=fake)
+
+ file_path = tmp_path / "example.py"
+ file_path.write_text("x = 1\n", encoding="utf-8")
+
+ filter_results = [
+ {
+ "absolutePath": "/tmp/example.py",
+ "range": {"start": {"line": 0, "character": 0}},
+ }
+ ]
+ filter_mock = MagicMock(return_value=filter_results)
+ service._filter_gitignored_batched = filter_mock
+
+ calls: list[tuple[object, tuple[object, ...]]] = []
+
+ async def fake_to_thread(func, *args, **kwargs):
+ calls.append((func, args))
+ return func(*args, **kwargs)
+
+ monkeypatch.setattr("core.tools.lsp.service.asyncio.to_thread", fake_to_thread)
+
+ result = await service._handle(
+ operation="goToDefinition",
+ file_path=str(file_path),
+ line=1,
+ character=1,
+ )
+
+ assert calls == [(filter_mock, (filter_mock.call_args.args[0],))]
+ assert filter_mock.call_count == 1
+ payload = json.loads(result)
+ assert payload[0]["file"] == "/tmp/example.py"
+
+
+@pytest.mark.asyncio
+async def test_lsp_handle_converts_one_based_positions_to_zero_based_for_pyright_ops(tmp_path):
+ reg = ToolRegistry()
+ service = LSPService(registry=reg, workspace_root=tmp_path)
+ fake = _FakePyright()
+ service._get_pyright = AsyncMock(return_value=fake)
+
+ file_path = tmp_path / "example.py"
+ file_path.write_text("x = 1\n", encoding="utf-8")
+
+ result = await service._handle(
+ operation="goToImplementation",
+ file_path=str(file_path),
+ line=7,
+ character=4,
+ )
+
+ assert fake.calls == [("implementation", "example.py", 6, 3)]
+ payload = json.loads(result)
+ assert payload[0]["line"] == 6
+ assert payload[0]["column"] == 3
diff --git a/tests/test_marketplace_client.py b/tests/Unit/platform/test_marketplace_client.py
similarity index 100%
rename from tests/test_marketplace_client.py
rename to tests/Unit/platform/test_marketplace_client.py
diff --git a/tests/test_marketplace_models.py b/tests/Unit/platform/test_marketplace_models.py
similarity index 100%
rename from tests/test_marketplace_models.py
rename to tests/Unit/platform/test_marketplace_models.py
diff --git a/tests/Unit/platform/test_mcp_resource_tool_service.py b/tests/Unit/platform/test_mcp_resource_tool_service.py
new file mode 100644
index 000000000..1377c4cbd
--- /dev/null
+++ b/tests/Unit/platform/test_mcp_resource_tool_service.py
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import json
+from collections.abc import Awaitable
+from contextlib import asynccontextmanager
+from types import SimpleNamespace
+from typing import Any, cast
+
+import pytest
+from pydantic import AnyUrl, TypeAdapter
+
+from core.runtime.registry import ToolRegistry
+from core.runtime.tool_result import ToolResultEnvelope
+from core.tools.mcp_resources.service import McpResourceToolService
+
+
+class _FakeSession:
+ def __init__(self, resources: list[SimpleNamespace], contents_by_uri: dict[str, list[SimpleNamespace]]) -> None:
+ self._resources = resources
+ self._contents_by_uri = contents_by_uri
+
+ async def list_resources(self):
+ return SimpleNamespace(resources=self._resources)
+
+ async def read_resource(self, uri: str):
+ return SimpleNamespace(contents=self._contents_by_uri[uri])
+
+
+class _FakeClient:
+ def __init__(self, sessions: dict[str, _FakeSession]) -> None:
+ self.connections = {name: object() for name in sessions}
+ self._sessions = sessions
+
+ @asynccontextmanager
+ async def session(self, server_name: str, *, auto_initialize: bool = True):
+ assert auto_initialize is True
+ yield self._sessions[server_name]
+
+
+def _unwrap_text(result: str | ToolResultEnvelope) -> str:
+ if isinstance(result, ToolResultEnvelope):
+ return cast(str, result.content)
+ return result
+
+
+async def _invoke_handler(handler: Any, /, **kwargs: Any) -> str | ToolResultEnvelope:
+ result = handler(**kwargs)
+ if isinstance(result, Awaitable):
+ return await result
+ return result
+
+
+@pytest.mark.asyncio
+async def test_mcp_resource_tool_service_registers_list_and_read_tools() -> None:
+ registry = ToolRegistry()
+ client = _FakeClient(
+ {
+ "demo": _FakeSession(
+ resources=[
+ SimpleNamespace(
+ uri="memo://alpha",
+ name="alpha",
+ mimeType="text/plain",
+ description="first resource",
+ )
+ ],
+ contents_by_uri={
+ "memo://alpha": [
+ SimpleNamespace(
+ uri="memo://alpha",
+ mimeType="text/plain",
+ text="hello from resource",
+ )
+ ]
+ },
+ )
+ }
+ )
+
+ McpResourceToolService(
+ registry=registry,
+ client_fn=lambda: client,
+ server_configs_fn=lambda: {"demo": object()},
+ )
+
+ list_entry = registry.get("ListMcpResources")
+ read_entry = registry.get("ReadMcpResource")
+ assert list_entry is not None
+ assert read_entry is not None
+
+ listed = json.loads(_unwrap_text(await _invoke_handler(list_entry.handler)))
+ assert listed == {
+ "items": [
+ {
+ "server": "demo",
+ "uri": "memo://alpha",
+ "name": "alpha",
+ "mime_type": "text/plain",
+ "description": "first resource",
+ }
+ ],
+ "total": 1,
+ }
+
+ content = json.loads(_unwrap_text(await _invoke_handler(read_entry.handler, server="demo", uri="memo://alpha")))
+ assert content == {
+ "server": "demo",
+ "uri": "memo://alpha",
+ "contents": [
+ {
+ "uri": "memo://alpha",
+ "mime_type": "text/plain",
+ "text": "hello from resource",
+ }
+ ],
+ }
+
+
+def test_mcp_resource_tool_service_skips_registration_without_servers() -> None:
+ registry = ToolRegistry()
+ McpResourceToolService(
+ registry=registry,
+ client_fn=lambda: None,
+ server_configs_fn=lambda: {},
+ )
+
+ assert registry.get("ListMcpResources") is None
+ assert registry.get("ReadMcpResource") is None
+
+
+@pytest.mark.asyncio
+async def test_mcp_resource_tool_service_fails_loudly_for_unknown_server() -> None:
+ registry = ToolRegistry()
+ client = _FakeClient({"demo": _FakeSession(resources=[], contents_by_uri={})})
+ McpResourceToolService(
+ registry=registry,
+ client_fn=lambda: client,
+ server_configs_fn=lambda: {"demo": object()},
+ )
+
+ read_entry = registry.get("ReadMcpResource")
+ assert read_entry is not None
+
+ with pytest.raises(ValueError, match='MCP server not found: "missing"'):
+ await _invoke_handler(read_entry.handler, server="missing", uri="memo://alpha")
+
+
+@pytest.mark.asyncio
+async def test_mcp_resource_tool_service_serializes_url_like_resource_uris() -> None:
+ registry = ToolRegistry()
+ uri = TypeAdapter(AnyUrl).validate_python("memo://alpha")
+ client = _FakeClient(
+ {
+ "demo": _FakeSession(
+ resources=[
+ SimpleNamespace(
+ uri=uri,
+ name="alpha",
+ mimeType="text/plain",
+ description="first resource",
+ )
+ ],
+ contents_by_uri={
+ "memo://alpha": [
+ SimpleNamespace(
+ uri=uri,
+ mimeType="text/plain",
+ text="hello from resource",
+ )
+ ]
+ },
+ )
+ }
+ )
+
+ McpResourceToolService(
+ registry=registry,
+ client_fn=lambda: client,
+ server_configs_fn=lambda: {"demo": object()},
+ )
+
+ list_entry = registry.get("ListMcpResources")
+ read_entry = registry.get("ReadMcpResource")
+ assert list_entry is not None
+ assert read_entry is not None
+
+ listed = json.loads(_unwrap_text(await _invoke_handler(list_entry.handler)))
+ assert listed["items"][0]["uri"] == "memo://alpha"
+
+ content = json.loads(_unwrap_text(await _invoke_handler(read_entry.handler, server="demo", uri="memo://alpha")))
+ assert content["contents"][0]["uri"] == "memo://alpha"
diff --git a/tests/Unit/platform/test_mcp_transport.py b/tests/Unit/platform/test_mcp_transport.py
new file mode 100644
index 000000000..f560f4d50
--- /dev/null
+++ b/tests/Unit/platform/test_mcp_transport.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+from config.schema import MCPConfig, MCPServerConfig
+from core.runtime.agent import LeonAgent
+
+
+@pytest.mark.asyncio
+async def test_init_mcp_tools_respects_explicit_websocket_transport(monkeypatch):
+ captured: dict[str, object] = {}
+
+ class FakeClient:
+ def __init__(self, configs, tool_name_prefix=False):
+ captured["configs"] = configs
+
+ async def get_tools(self):
+ return []
+
+ async def close(self):
+ return None
+
+ agent = LeonAgent.__new__(LeonAgent)
+ agent.config = SimpleNamespace(
+ mcp=MCPConfig(
+ enabled=True,
+ servers={
+ "wsdemo": MCPServerConfig(
+ transport="websocket",
+ url="ws://example.test/mcp",
+ )
+ },
+ )
+ )
+ agent.verbose = False
+ agent._mcp_client = None
+
+ monkeypatch.setattr(
+ "langchain_mcp_adapters.client.MultiServerMCPClient",
+ FakeClient,
+ )
+
+ await LeonAgent._init_mcp_tools(agent)
+
+ assert captured["configs"] == {
+ "wsdemo": {
+ "transport": "websocket",
+ "url": "ws://example.test/mcp",
+ }
+ }
diff --git a/tests/test_model_config_enrichment.py b/tests/Unit/platform/test_model_config_enrichment.py
similarity index 90%
rename from tests/test_model_config_enrichment.py
rename to tests/Unit/platform/test_model_config_enrichment.py
index 6e1e3e53d..6fc470582 100644
--- a/tests/test_model_config_enrichment.py
+++ b/tests/Unit/platform/test_model_config_enrichment.py
@@ -1,9 +1,12 @@
"""Tests for model config enrichment (based_on + context_limit)."""
+import importlib
+
import pytest
from pydantic import ValidationError
from config.models_schema import ActiveModel, CustomModelConfig, ModelsConfig, ModelSpec, PoolConfig
+from core.runtime.middleware.monitor import cost as cost_module
from core.runtime.middleware.monitor.cost import fetch_openrouter_pricing, get_model_context_limit
from core.runtime.middleware.monitor.middleware import MonitorMiddleware
@@ -131,6 +134,25 @@ def test_update_model_based_on_affects_cost_calculator(self):
mw.update_model("Alice", overrides={"based_on": "claude-sonnet-4.5"})
assert mw._token_monitor.cost_calculator.costs != {}
+ def test_empty_cached_pricing_falls_back_to_bundled_models(self, monkeypatch: pytest.MonkeyPatch):
+ importlib.reload(cost_module)
+
+ monkeypatch.setattr(
+ cost_module,
+ "_load_cache",
+ lambda: (
+ {},
+ {"claude-sonnet-4.5": SONNET_LIMIT},
+ {"claude-sonnet-4.5": "anthropic"},
+ ),
+ )
+ monkeypatch.setattr(cost_module, "_fetch_from_openrouter", lambda: None)
+
+ prices = cost_module.fetch_openrouter_pricing()
+
+ assert prices.get("claude-sonnet-4.5") is not None
+ assert cost_module.CostCalculator("claude-sonnet-4.5").costs != {}
+
class TestThreeLevelPriority:
"""Level 1 用户配置 > Level 2 OpenRouter > Level 3 Bundled"""
diff --git a/tests/test_model_params.py b/tests/Unit/platform/test_model_params.py
similarity index 100%
rename from tests/test_model_params.py
rename to tests/Unit/platform/test_model_params.py
diff --git a/tests/test_search_tools.py b/tests/Unit/platform/test_search_tools.py
similarity index 100%
rename from tests/test_search_tools.py
rename to tests/Unit/platform/test_search_tools.py
diff --git a/tests/test_task_service.py b/tests/Unit/platform/test_task_service.py
similarity index 92%
rename from tests/test_task_service.py
rename to tests/Unit/platform/test_task_service.py
index e3105c5da..8fd33d775 100644
--- a/tests/test_task_service.py
+++ b/tests/Unit/platform/test_task_service.py
@@ -2,6 +2,7 @@
import sqlite3
import time
+from types import SimpleNamespace
import pytest
@@ -120,6 +121,19 @@ def test_list_returns_all(self):
tasks = task_service.list_tasks()
assert len(tasks) >= 2
+ def test_list_enriches_member_id_from_thread_repo(self, monkeypatch):
+ task_service.create_task(title="task with thread", thread_id="thread-1")
+
+ thread_repo = SimpleNamespace(
+ get_by_id=lambda thread_id: {"member_id": "member-1"} if thread_id == "thread-1" else None,
+ close=lambda: None,
+ )
+ monkeypatch.setattr(task_service, "build_thread_repo", lambda **_: thread_repo)
+
+ tasks = task_service.list_tasks()
+
+ assert tasks[0]["member_id"] == "member-1"
+
def test_delete_existing(self):
task = task_service.create_task(title="to delete")
assert task_service.delete_task(task["id"]) is True
diff --git a/tests/Unit/sandbox/test_agentbay_provider.py b/tests/Unit/sandbox/test_agentbay_provider.py
new file mode 100644
index 000000000..593757e22
--- /dev/null
+++ b/tests/Unit/sandbox/test_agentbay_provider.py
@@ -0,0 +1,290 @@
+import json
+import sys
+import types
+from dataclasses import replace
+from types import SimpleNamespace
+
+from sandbox.providers.agentbay import AgentBayProvider
+
+
+def _install_fake_agentbay_module(monkeypatch) -> None:
+ fake_mod = types.ModuleType("agentbay")
+ fake_api_mod = types.ModuleType("agentbay.api")
+ fake_api_models_mod = types.ModuleType("agentbay.api.models")
+
+ class FakeCreateSessionParams:
+ def __init__(self):
+ self.image_id = None
+ self.context_syncs = None
+
+ class FakeContextSync:
+ @staticmethod
+ def new(context_id: str, path: str):
+ return {"context_id": context_id, "path": path}
+
+ class FakeGetSessionRequest:
+ def __init__(self, authorization: str, session_id: str):
+ self.authorization = authorization
+ self.session_id = session_id
+
+ fake_mod.CreateSessionParams = FakeCreateSessionParams
+ fake_mod.ContextSync = FakeContextSync
+ fake_api_models_mod.GetSessionRequest = FakeGetSessionRequest
+ monkeypatch.setitem(sys.modules, "agentbay", fake_mod)
+ monkeypatch.setitem(sys.modules, "agentbay.api", fake_api_mod)
+ monkeypatch.setitem(sys.modules, "agentbay.api.models", fake_api_models_mod)
+
+
+def _provider_with_fake_client(fake_client) -> AgentBayProvider:
+ provider = AgentBayProvider.__new__(AgentBayProvider)
+ provider.name = "agentbay"
+ provider.client = fake_client
+ provider.default_context_path = "/home/wuying"
+ provider.image_id = None
+ provider._sessions = {}
+ provider._capability = AgentBayProvider.CAPABILITY
+ return provider
+
+
+def test_create_session_refreshes_agentbay_session_when_direct_call_fields_missing(monkeypatch):
+ _install_fake_agentbay_module(monkeypatch)
+ raw_session = SimpleNamespace(session_id="sess-123", token="", link_url="", mcpTools=[])
+ hydrated_session = SimpleNamespace(session_id="sess-123", token="tok", link_url="https://link", mcpTools=[object()])
+ fake_client = SimpleNamespace(
+ context=SimpleNamespace(get=lambda *args, **kwargs: None),
+ create=lambda params: SimpleNamespace(success=True, session=raw_session, error_message=""),
+ get=lambda session_id: SimpleNamespace(success=True, session=hydrated_session, error_message=""),
+ )
+ provider = _provider_with_fake_client(fake_client)
+
+ info = provider.create_session()
+
+ assert info.session_id == "sess-123"
+ assert provider._sessions["sess-123"] is hydrated_session
+
+
+def test_get_session_refreshes_stale_cached_agentbay_session():
+ stale_session = SimpleNamespace(session_id="sess-123", token="", link_url="", mcpTools=[])
+ hydrated_session = SimpleNamespace(session_id="sess-123", token="tok", link_url="https://link", mcpTools=[object()])
+ fake_client = SimpleNamespace(
+ get=lambda session_id: SimpleNamespace(success=True, session=hydrated_session, error_message=""),
+ )
+ provider = _provider_with_fake_client(fake_client)
+ provider._sessions["sess-123"] = stale_session
+
+ session = provider._get_session("sess-123")
+
+ assert session is hydrated_session
+ assert provider._sessions["sess-123"] is hydrated_session
+
+
+def test_destroy_session_skips_sync_when_pause_capability_is_disabled():
+ calls: list[bool] = []
+
+ class _DeleteResult:
+ success = True
+
+ class _Session:
+ def __init__(self) -> None:
+ self.session_id = "sess-123"
+ self.token = "tok"
+ self.link_url = "https://link"
+ self.mcpTools = [object()]
+
+ def delete(self, *, sync_context: bool):
+ calls.append(sync_context)
+ return _DeleteResult()
+
+ provider = _provider_with_fake_client(SimpleNamespace())
+ provider._capability = replace(AgentBayProvider.CAPABILITY, can_pause=False, can_resume=False)
+ provider._sessions["sess-123"] = _Session()
+
+ assert provider.destroy_session("sess-123") is True
+ assert calls == [False]
+ assert "sess-123" not in provider._sessions
+
+
+def test_execute_prefers_link_url_shell_path_when_session_has_direct_call_metadata():
+ calls: list[tuple[str, object]] = []
+
+ class _Tool:
+ name = "shell"
+ server = "wuying_shell"
+
+ def _command_execute(**kwargs):
+ calls.append(("command", kwargs))
+ return SimpleNamespace(success=False, output="", error_message="should not be used")
+
+ session = SimpleNamespace(
+ session_id="sess-123",
+ token="tok",
+ link_url="https://link",
+ mcpTools=[_Tool()],
+ _get_mcp_server_for_tool=lambda tool_name: "wuying_shell" if tool_name == "shell" else None,
+ command=SimpleNamespace(execute_command=_command_execute),
+ )
+ provider = _provider_with_fake_client(SimpleNamespace())
+ provider._sessions["sess-123"] = session
+ provider._call_link_url_tool = lambda session, tool_name, args, server_name: (
+ calls.append(("link", {"tool": tool_name, "args": args, "server": server_name}))
+ or AgentBayProvider._provider_exec_result_from_tool_result(
+ SimpleNamespace(
+ success=True,
+ data=json.dumps({"stdout": "/home/wuying\n", "stderr": "", "exit_code": 0}),
+ error_message="",
+ )
+ )
+ )
+
+ result = provider.execute("sess-123", "pwd", timeout_ms=5000, cwd="/home/wuying")
+
+ assert result.output == "/home/wuying\n"
+ assert result.exit_code == 0
+ assert result.error is None
+ assert calls == [
+ (
+ "link",
+ {
+ "tool": "shell",
+ "args": {"command": "pwd", "timeout_ms": 5000, "cwd": "/home/wuying"},
+ "server": "wuying_shell",
+ },
+ )
+ ]
+
+
+def test_get_session_hydrates_sdk_shape_session_from_raw_get_session_metadata(monkeypatch):
+ _install_fake_agentbay_module(monkeypatch)
+ sdk_shape_session = SimpleNamespace(
+ session_id="sess-123",
+ token="tok",
+ resource_url="https://resource",
+ mcp_tools=[],
+ )
+ fake_response = SimpleNamespace(
+ to_map=lambda: {
+ "body": {
+ "Success": True,
+ "Data": {
+ "LinkUrl": "https://link",
+ "Token": "tok",
+ "ToolList": [{"Name": "shell", "Server": "wuying_shell"}],
+ },
+ }
+ }
+ )
+ fake_client = SimpleNamespace(
+ api_key="api-key",
+ get=lambda session_id: SimpleNamespace(success=True, session=sdk_shape_session, error_message=""),
+ client=SimpleNamespace(get_session=lambda request: fake_response),
+ )
+ provider = _provider_with_fake_client(fake_client)
+
+ session = provider._get_session("sess-123")
+
+ assert session is sdk_shape_session
+ assert getattr(session, "link_url") == "https://link"
+ assert getattr(session, "token") == "tok"
+ assert len(getattr(session, "mcp_tools")) == 1
+ assert getattr(session, "mcpTools") == getattr(session, "mcp_tools")
+ assert provider._resolve_shell_server(session) == "wuying_shell"
+
+
+def test_execute_prefers_link_url_shell_path_for_sdk_shape_session():
+ calls: list[tuple[str, object]] = []
+
+ def _command_execute(**kwargs):
+ calls.append(("command", kwargs))
+ return SimpleNamespace(success=False, output="", error_message="should not be used")
+
+ session = SimpleNamespace(
+ session_id="sess-123",
+ token="tok",
+ link_url="https://link",
+ mcp_tools=[SimpleNamespace(name="shell", server="wuying_shell")],
+ _find_server_for_tool=lambda tool_name: "wuying_shell" if tool_name == "shell" else "",
+ command=SimpleNamespace(execute_command=_command_execute),
+ )
+ provider = _provider_with_fake_client(SimpleNamespace())
+ provider._sessions["sess-123"] = session
+ provider._call_link_url_tool = lambda session, tool_name, args, server_name: (
+ calls.append(("link", {"tool": tool_name, "args": args, "server": server_name}))
+ or AgentBayProvider._provider_exec_result_from_tool_result(
+ SimpleNamespace(
+ success=True,
+ data=json.dumps({"stdout": "/home/wuying\n", "stderr": "", "exit_code": 0}),
+ error_message="",
+ )
+ )
+ )
+
+ result = provider.execute("sess-123", "pwd", timeout_ms=5000, cwd="/home/wuying")
+
+ assert result.output == "/home/wuying\n"
+ assert result.exit_code == 0
+ assert result.error is None
+ assert calls == [
+ (
+ "link",
+ {
+ "tool": "shell",
+ "args": {"command": "pwd", "timeout_ms": 5000, "cwd": "/home/wuying"},
+ "server": "wuying_shell",
+ },
+ )
+ ]
+
+
+def test_resolve_shell_server_falls_back_to_mcp_tools_when_sdk_resolver_raises():
+ session = SimpleNamespace(
+ mcp_tools=[SimpleNamespace(name="shell", server="wuying_shell")],
+ _find_server_for_tool=lambda tool_name: (_ for _ in ()).throw(StopIteration()),
+ )
+
+ assert AgentBayProvider._resolve_shell_server(session) == "wuying_shell"
+
+
+def test_execute_uses_provider_owned_link_call_instead_of_sdk_private_method():
+ calls: list[tuple[str, object]] = []
+
+ def _sdk_link(*args, **kwargs):
+ raise StopIteration()
+
+ def _provider_link(session: object, tool_name: str, args: dict, server_name: str):
+ calls.append(("provider-link", {"tool": tool_name, "args": args, "server": server_name}))
+ return AgentBayProvider._provider_exec_result_from_tool_result(
+ SimpleNamespace(
+ success=True,
+ data=json.dumps({"stdout": "/home/wuying\n", "stderr": "", "exit_code": 0}),
+ error_message="",
+ )
+ )
+
+ session = SimpleNamespace(
+ session_id="sess-123",
+ token="tok",
+ link_url="https://link",
+ mcp_tools=[SimpleNamespace(name="shell", server="wuying_shell")],
+ _find_server_for_tool=lambda tool_name: "wuying_shell",
+ _call_mcp_tool_link_url=_sdk_link,
+ command=SimpleNamespace(execute_command=lambda **kwargs: None),
+ )
+ provider = _provider_with_fake_client(SimpleNamespace())
+ provider._sessions["sess-123"] = session
+ provider._call_link_url_tool = _provider_link
+
+ result = provider.execute("sess-123", "pwd", timeout_ms=5000, cwd="/home/wuying")
+
+ assert result.output == "/home/wuying\n"
+ assert result.exit_code == 0
+ assert result.error is None
+ assert calls == [
+ (
+ "provider-link",
+ {
+ "tool": "shell",
+ "args": {"command": "pwd", "timeout_ms": 5000, "cwd": "/home/wuying"},
+ "server": "wuying_shell",
+ },
+ )
+ ]
diff --git a/tests/test_chat_session.py b/tests/Unit/sandbox/test_chat_session.py
similarity index 100%
rename from tests/test_chat_session.py
rename to tests/Unit/sandbox/test_chat_session.py
diff --git a/tests/test_daytona_provider.py b/tests/Unit/sandbox/test_daytona_provider.py
similarity index 100%
rename from tests/test_daytona_provider.py
rename to tests/Unit/sandbox/test_daytona_provider.py
diff --git a/tests/Unit/sandbox/test_daytona_provider_proxy.py b/tests/Unit/sandbox/test_daytona_provider_proxy.py
new file mode 100644
index 000000000..32f7f9533
--- /dev/null
+++ b/tests/Unit/sandbox/test_daytona_provider_proxy.py
@@ -0,0 +1,21 @@
+"""Unit tests for Daytona local toolbox URL normalization."""
+
+from sandbox.providers.daytona import DaytonaProvider
+
+
+def test_daytona_provider_rewrites_local_toolbox_proxy_url_to_loopback():
+ provider = object.__new__(DaytonaProvider)
+ provider.api_url = "http://localhost:3986/api"
+
+ rewritten = provider._normalize_toolbox_proxy_url("http://172.18.0.1:4000/toolbox")
+
+ assert rewritten == "http://127.0.0.1:4000/toolbox"
+
+
+def test_daytona_provider_leaves_remote_toolbox_proxy_url_unchanged():
+ provider = object.__new__(DaytonaProvider)
+ provider.api_url = "https://daytona.example.com/api"
+
+ untouched = provider._normalize_toolbox_proxy_url("https://proxy.example.com/toolbox")
+
+ assert untouched == "https://proxy.example.com/toolbox"
diff --git a/tests/test_e2b_provider.py b/tests/Unit/sandbox/test_e2b_provider.py
similarity index 57%
rename from tests/test_e2b_provider.py
rename to tests/Unit/sandbox/test_e2b_provider.py
index 8c88b614d..d64f72663 100644
--- a/tests/test_e2b_provider.py
+++ b/tests/Unit/sandbox/test_e2b_provider.py
@@ -1,16 +1,60 @@
"""Smoke test for E2B provider and sandbox."""
+import builtins
import os
import sys
+from types import SimpleNamespace
+
+import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sandbox.providers.e2b import E2BProvider
+def test_e2b_provider_requires_sdk(monkeypatch):
+ real_import = builtins.__import__
+
+ def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
+ if name == "e2b":
+ raise ModuleNotFoundError("No module named 'e2b'")
+ return real_import(name, globals, locals, fromlist, level)
+
+ monkeypatch.setattr(builtins, "__import__", fake_import)
+
+ with pytest.raises(ModuleNotFoundError, match="No module named 'e2b'"):
+ E2BProvider(api_key="test-key", timeout=60)
+
+
+def test_e2b_create_session_bootstraps_workspace_files_dir(monkeypatch):
+ calls: list[tuple[str, str | None, float | None]] = []
+
+ class _FakeCommands:
+ def run(self, command, cwd=None, timeout=None):
+ calls.append((command, cwd, timeout))
+ return SimpleNamespace(stdout="", stderr="", exit_code=0)
+
+ class _FakeSandbox:
+ def __init__(self):
+ self.sandbox_id = "sbx-123"
+ self.commands = _FakeCommands()
+
+ @classmethod
+ def beta_create(cls, template, timeout, auto_pause, api_key):
+ return cls()
+
+ monkeypatch.setitem(sys.modules, "e2b", SimpleNamespace(Sandbox=_FakeSandbox))
+
+ provider = E2BProvider(api_key="test-key", timeout=60)
+ info = provider.create_session()
+
+ assert info.session_id == "sbx-123"
+ assert calls == [("mkdir -p /home/user/workspace/files", "/home/user", 10.0)]
+
+
def test_e2b_provider():
api_key = os.getenv("E2B_API_KEY")
- if not api_key:
+ if not api_key or not api_key.startswith("e2b_"):
print("E2B_API_KEY not set, skipping")
return
diff --git a/tests/test_lease.py b/tests/Unit/sandbox/test_lease.py
similarity index 100%
rename from tests/test_lease.py
rename to tests/Unit/sandbox/test_lease.py
diff --git a/tests/test_lifecycle.py b/tests/Unit/sandbox/test_lifecycle.py
similarity index 100%
rename from tests/test_lifecycle.py
rename to tests/Unit/sandbox/test_lifecycle.py
diff --git a/tests/Unit/sandbox/test_local_provider_metrics.py b/tests/Unit/sandbox/test_local_provider_metrics.py
new file mode 100644
index 000000000..1cb1daabc
--- /dev/null
+++ b/tests/Unit/sandbox/test_local_provider_metrics.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+import builtins
+import io
+from types import SimpleNamespace
+
+from sandbox.providers import local as local_module
+from sandbox.providers.local import LocalSessionProvider
+
+
+def test_local_provider_reads_linux_procfs_metrics_without_top_or_free(monkeypatch) -> None:
+ provider = LocalSessionProvider()
+
+ cpu_samples = iter(
+ [
+ "cpu 100 0 100 800 0 0 0 0 0 0\n",
+ "cpu 130 0 120 850 0 0 0 0 0 0\n",
+ ]
+ )
+
+ def fake_open(path: str, *args, **kwargs):
+ if path == "/proc/stat":
+ return io.StringIO(next(cpu_samples))
+ if path == "/proc/meminfo":
+ return io.StringIO("MemTotal: 1048576 kB\nMemAvailable: 524288 kB\n")
+ raise FileNotFoundError(path)
+
+ monkeypatch.setattr("sandbox.providers.local.platform.system", lambda: "Linux")
+ monkeypatch.setattr(builtins, "open", fake_open)
+ monkeypatch.setattr(
+ local_module.os,
+ "statvfs",
+ lambda _path: SimpleNamespace(f_frsize=4096, f_blocks=262144, f_bavail=131072),
+ raising=False,
+ )
+
+ metrics = provider.get_metrics("host")
+
+ assert metrics is not None
+ assert metrics.cpu_percent == 50.0
+ assert metrics.memory_total_mb == 1024.0
+ assert metrics.memory_used_mb == 512.0
+ assert metrics.disk_total_gb == 1.0
+ assert metrics.disk_used_gb == 0.5
diff --git a/tests/Unit/sandbox/test_remote_sandbox_init_commands.py b/tests/Unit/sandbox/test_remote_sandbox_init_commands.py
new file mode 100644
index 000000000..72ad58a1e
--- /dev/null
+++ b/tests/Unit/sandbox/test_remote_sandbox_init_commands.py
@@ -0,0 +1,32 @@
+from types import SimpleNamespace
+
+import pytest
+
+from sandbox.base import RemoteSandbox
+from sandbox.config import SandboxConfig
+
+
+class _RecordingCommand:
+ def __init__(self) -> None:
+ self.calls: list[str] = []
+
+ async def execute(self, command: str):
+ self.calls.append(command)
+ return SimpleNamespace(exit_code=0, stderr="", stdout="")
+
+
+@pytest.mark.asyncio
+async def test_run_init_commands_avoids_same_loop_threadsafe_wait(monkeypatch: pytest.MonkeyPatch):
+ command = _RecordingCommand()
+ capability = SimpleNamespace(command=command)
+ sandbox = RemoteSandbox.__new__(RemoteSandbox)
+ sandbox._config = SandboxConfig(init_commands=["echo init"])
+
+ def _unexpected_threadsafe(*args, **kwargs):
+ raise AssertionError("same-loop run_coroutine_threadsafe path should not be used")
+
+ monkeypatch.setattr("sandbox.base.asyncio.run_coroutine_threadsafe", _unexpected_threadsafe)
+
+ sandbox._run_init_commands(capability)
+
+ assert command.calls == ["echo init"]
diff --git a/tests/Unit/sandbox/test_sandbox_manager_volume_repo.py b/tests/Unit/sandbox/test_sandbox_manager_volume_repo.py
new file mode 100644
index 000000000..82b9c76eb
--- /dev/null
+++ b/tests/Unit/sandbox/test_sandbox_manager_volume_repo.py
@@ -0,0 +1,536 @@
+import json
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+import sandbox.manager as sandbox_manager_module
+from sandbox.manager import SandboxManager
+from sandbox.providers.local import LocalSessionProvider
+from sandbox.volume_source import DaytonaVolume, HostVolume
+
+
+class _FakeVolumeRepo:
+ def __init__(self, source: dict[str, str]) -> None:
+ self._source = source
+ self.closed = False
+ self.requested_ids: list[str] = []
+ self.created: list[tuple[str, str | None]] = []
+
+ def get(self, volume_id: str):
+ self.requested_ids.append(volume_id)
+ if self.created and volume_id == self.created[-1][0]:
+ return {"source": json.dumps(self._source)}
+ return {"source": json.dumps(self._source)}
+
+ def create(self, volume_id: str, source_json: str, name: str | None, created_at: str) -> None:
+ self.created.append((volume_id, name))
+ self._source = json.loads(source_json)
+
+ def close(self) -> None:
+ self.closed = True
+
+
+class _FakeVolume:
+ def __init__(self) -> None:
+ self.mount_calls: list[tuple[str, str]] = []
+ self.upload_calls: list[tuple[str, str]] = []
+ self.download_calls: list[tuple[str, str]] = []
+ self.cleared: list[str] = []
+
+ def resolve_mount_path(self) -> str:
+ return "/workspace"
+
+ def mount(self, thread_id: str, source, remote_path: str) -> None:
+ self.mount_calls.append((thread_id, remote_path))
+
+ def mount_managed_volume(self, thread_id: str, volume_name: str, remote_path: str) -> None:
+ self.mount_calls.append((thread_id, remote_path))
+
+ def sync_upload(self, thread_id: str, session_id: str, source, remote_path: str, files=None) -> None:
+ self.upload_calls.append((thread_id, session_id))
+
+ def sync_download(self, thread_id: str, session_id: str, source, remote_path: str) -> None:
+ self.download_calls.append((thread_id, session_id))
+
+ def clear_sync_state(self, thread_id: str) -> None:
+ self.cleared.append(thread_id)
+
+
+class _FakeThreadRepo:
+ def __init__(self, row):
+ self._row = row
+ self.closed = False
+
+ def get_by_id(self, _thread_id: str):
+ return self._row
+
+ def close(self) -> None:
+ self.closed = True
+
+
+class _FakeUpdateRepo:
+ def __init__(self) -> None:
+ self.updated: list[tuple[str, str]] = []
+ self.closed = False
+
+ def update_source(self, volume_id: str, source_json: str) -> None:
+ self.updated.append((volume_id, source_json))
+
+ def close(self) -> None:
+ self.closed = True
+
+
+class _FakeLeaseStore:
+ def __init__(self) -> None:
+ self.volume_updates: list[tuple[str, str]] = []
+
+ def set_volume_id(self, lease_id: str, volume_id: str) -> None:
+ self.volume_updates.append((lease_id, volume_id))
+
+
+class _FakeSessionManager:
+ def __init__(self, active_rows) -> None:
+ self._active_rows = active_rows
+ self.deleted: list[tuple[str, str]] = []
+
+ def list_active(self):
+ return list(self._active_rows)
+
+ def delete(self, session_id: str, reason: str) -> None:
+ self.deleted.append((session_id, reason))
+
+
+class _FakeDaytonaProvider:
+ def __init__(self) -> None:
+ self.calls: list[tuple[str, str]] = []
+ self.ready_waits: list[str] = []
+
+ def create_managed_volume(self, member_id: str, mount_path: str) -> str:
+ self.calls.append((member_id, mount_path))
+ return f"leon-volume-{member_id}"
+
+ def wait_managed_volume_ready(self, volume_name: str) -> None:
+ self.ready_waits.append(volume_name)
+
+
+def test_setup_mounts_reads_volume_from_active_storage_repo(tmp_path):
+ manager = object.__new__(SandboxManager)
+ manager.provider_capability = SimpleNamespace(runtime_kind="local")
+ manager.volume = _FakeVolume()
+ manager._get_active_terminal = lambda _thread_id: SimpleNamespace(lease_id="lease-1")
+ manager._get_lease = lambda _lease_id: SimpleNamespace(volume_id="volume-1")
+ repo = _FakeVolumeRepo(HostVolume(Path(tmp_path) / "vol").serialize())
+ manager._sandbox_volume_repo = lambda: repo
+
+ result = manager._setup_mounts("thread-1")
+
+ assert repo.requested_ids == ["volume-1"]
+ assert repo.closed is True
+ assert isinstance(result["source"], HostVolume)
+ assert manager.volume.mount_calls == [("thread-1", "/workspace")]
+
+
+def test_resolve_volume_source_reads_volume_from_active_storage_repo(tmp_path):
+ manager = object.__new__(SandboxManager)
+ manager.provider_capability = SimpleNamespace(runtime_kind="agentbay")
+ manager._get_active_terminal = lambda _thread_id: SimpleNamespace(lease_id="lease-1")
+ manager._get_lease = lambda _lease_id: SimpleNamespace(volume_id="volume-1")
+ repo = _FakeVolumeRepo(HostVolume(Path(tmp_path) / "vol").serialize())
+ manager._sandbox_volume_repo = lambda: repo
+
+ source = manager.resolve_volume_source("thread-1")
+
+ assert repo.requested_ids == ["volume-1"]
+ assert repo.closed is True
+ assert isinstance(source, HostVolume)
+
+
+def test_setup_mounts_provisions_missing_remote_volume_metadata(monkeypatch, tmp_path):
+ manager = object.__new__(SandboxManager)
+ manager.provider_capability = SimpleNamespace(runtime_kind="agentbay")
+ manager.volume = _FakeVolume()
+ manager._get_active_terminal = lambda _thread_id: SimpleNamespace(lease_id="lease-1")
+ lease = SimpleNamespace(lease_id="lease-1", volume_id=None)
+ manager._get_lease = lambda _lease_id: lease
+ manager.lease_store = _FakeLeaseStore()
+ repo = _FakeVolumeRepo(HostVolume(Path(tmp_path) / "vol").serialize())
+ manager._sandbox_volume_repo = lambda: repo
+ monkeypatch.setenv("LEON_SANDBOX_VOLUME_ROOT", str(tmp_path / "volumes"))
+
+ result = manager._setup_mounts("thread-1")
+
+ assert lease.volume_id is not None
+ assert repo.created == [(lease.volume_id, "vol-thread-1")]
+ assert manager.lease_store.volume_updates == [("lease-1", lease.volume_id)]
+ assert repo.requested_ids == [lease.volume_id]
+ assert isinstance(result["source"], HostVolume)
+
+
+def test_setup_mounts_recreates_missing_remote_volume_row_for_existing_volume_id(monkeypatch, tmp_path):
+ class _MissingRowRepo(_FakeVolumeRepo):
+ def __init__(self) -> None:
+ super().__init__(HostVolume(tmp_path / "vol").serialize())
+ self._rows: dict[str, dict[str, str]] = {}
+
+ def get(self, volume_id: str):
+ self.requested_ids.append(volume_id)
+ return self._rows.get(volume_id)
+
+ def create(self, volume_id: str, source_json: str, name: str | None, created_at: str) -> None:
+ super().create(volume_id, source_json, name, created_at)
+ self._rows[volume_id] = {"source": source_json}
+
+ def update_source(self, volume_id: str, source_json: str) -> None:
+ self._rows[volume_id] = {"source": source_json}
+ self._source = json.loads(source_json)
+
+ manager = object.__new__(SandboxManager)
+ manager.provider_capability = SimpleNamespace(runtime_kind="daytona_pty")
+ manager.provider = _FakeDaytonaProvider()
+ manager.volume = _FakeVolume()
+ manager._get_active_terminal = lambda _thread_id: SimpleNamespace(lease_id="lease-1")
+ lease = SimpleNamespace(lease_id="lease-1", volume_id="volume-missing")
+ manager._get_lease = lambda _lease_id: lease
+ manager.lease_store = _FakeLeaseStore()
+ repo = _MissingRowRepo()
+ manager._sandbox_volume_repo = lambda: repo
+ thread_repo = _FakeThreadRepo({"member_id": "member-daytona"})
+ monkeypatch.setattr(
+ sandbox_manager_module,
+ "build_thread_repo",
+ lambda **_kwargs: thread_repo,
+ raising=False,
+ )
+ monkeypatch.setenv("LEON_SANDBOX_VOLUME_ROOT", str(tmp_path / "volumes"))
+
+ result = manager._setup_mounts("thread-1")
+
+ assert repo.created == [("volume-missing", "vol-thread-1")]
+ assert manager.lease_store.volume_updates == []
+ assert repo.requested_ids == ["volume-missing", "volume-missing"]
+ assert isinstance(result["source"], DaytonaVolume)
+ assert manager.provider.calls == [("member-daytona", "/workspace")]
+ assert thread_repo.closed is True
+
+
+def test_enforce_idle_timeouts_destroys_when_provider_cannot_pause(monkeypatch):
+ manager = object.__new__(SandboxManager)
+ manager.provider = SimpleNamespace(
+ name="agentbay",
+ get_capability=lambda: SimpleNamespace(can_pause=False, can_destroy=True),
+ )
+ manager.terminal_store = SimpleNamespace(
+ db_path=Path("/tmp/fake-sandbox.db"),
+ get_by_id=lambda _terminal_id: {"terminal_id": "term-1", "lease_id": "lease-1"},
+ )
+ active_rows = [
+ {
+ "session_id": "sess-1",
+ "thread_id": "thread-1",
+ "terminal_id": "term-1",
+ "lease_id": "lease-1",
+ "started_at": "2026-04-04T00:00:00",
+ "last_active_at": "2026-04-04T00:00:00",
+ "idle_ttl_sec": 1,
+ "max_duration_sec": 3600,
+ "status": "active",
+ }
+ ]
+ manager.session_manager = _FakeSessionManager(active_rows)
+ fake_lease = SimpleNamespace(
+ lease_id="lease-1",
+ provider_name="agentbay",
+ refresh_instance_status=lambda _provider: "running",
+ pause_instance=lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("pause should not be used")),
+ destroy_instance=lambda *_args, **_kwargs: destroy_calls.append(True),
+ )
+ destroy_calls: list[bool] = []
+ manager._get_lease = lambda _lease_id: fake_lease
+ manager._terminal_is_busy = lambda _terminal_id: False
+ manager._lease_is_busy = lambda _lease_id: False
+ monkeypatch.setattr(
+ sandbox_manager_module,
+ "terminal_from_row",
+ lambda _row, _db_path: SimpleNamespace(terminal_id="term-1", lease_id="lease-1"),
+ )
+
+ manager.enforce_idle_timeouts()
+
+ assert destroy_calls == [True]
+ assert manager.session_manager.deleted == [("sess-1", "idle_timeout")]
+
+
+def test_destroy_thread_resources_skips_local_sync_when_lease_has_no_volume_id():
+ manager = object.__new__(SandboxManager)
+ manager.provider_capability = SimpleNamespace(runtime_kind="local")
+ manager.provider = SimpleNamespace(name="local")
+ manager.volume = _FakeVolume()
+ manager._get_thread_lease = lambda _thread_id: lease
+ manager._get_lease = lambda _lease_id: lease
+ manager._resolve_volume_entry = lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("volume lookup should not happen"))
+ manager.terminal_store = SimpleNamespace(
+ list_by_thread=lambda _thread_id: [{"terminal_id": "term-1", "lease_id": "lease-1", "thread_id": "thread-1"}],
+ delete=lambda _terminal_id: deleted_terminals.append(_terminal_id),
+ list_all=lambda: [],
+ db_path=Path("/tmp/fake-sandbox.db"),
+ )
+ manager.session_manager = SimpleNamespace(
+ get=lambda _thread_id, _terminal_id: SimpleNamespace(session_id="sess-1"),
+ delete=lambda session_id, reason: deleted_sessions.append((session_id, reason)),
+ )
+ deleted_terminals: list[str] = []
+ deleted_sessions: list[tuple[str, str]] = []
+ destroy_calls: list[str] = []
+
+ class _Lease:
+ lease_id = "lease-1"
+ observed_state = "running"
+ volume_id = None
+
+ def get_instance(self):
+ return SimpleNamespace(instance_id="instance-1")
+
+ def destroy_instance(self, _provider):
+ destroy_calls.append("lease-1")
+
+ lease = _Lease()
+ manager.lease_store = SimpleNamespace(delete=lambda lease_id: deleted_leases.append(lease_id))
+ deleted_leases: list[str] = []
+
+ assert manager.destroy_thread_resources("thread-1") is True
+ assert manager.volume.download_calls == []
+ assert manager.volume.cleared == ["thread-1"]
+ assert deleted_sessions == [("sess-1", "thread_deleted")]
+ assert deleted_terminals == ["term-1"]
+ assert destroy_calls == ["lease-1"]
+ assert deleted_leases == ["lease-1"]
+
+
+def test_sync_uploads_skips_local_volume_sync_when_lease_has_no_volume_id():
+ manager = object.__new__(SandboxManager)
+ manager.provider_capability = SimpleNamespace(runtime_kind="local")
+ manager.volume = _FakeVolume()
+ manager._get_active_terminal = lambda _thread_id: SimpleNamespace(terminal_id="term-1", lease_id="lease-1")
+ manager._get_lease = lambda _lease_id: SimpleNamespace(volume_id=None)
+ manager._get_thread_lease = lambda _thread_id: SimpleNamespace(volume_id=None)
+ manager._resolve_volume_entry = lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("volume lookup should not happen"))
+ manager.session_manager = SimpleNamespace(
+ get=lambda _thread_id, _terminal_id: SimpleNamespace(
+ lease=SimpleNamespace(get_instance=lambda: SimpleNamespace(instance_id="instance-1"))
+ )
+ )
+
+ assert manager.sync_uploads("thread-1") is True
+ assert manager.volume.upload_calls == []
+
+
+def test_get_sandbox_local_provider_does_not_require_volume_bootstrap(tmp_path):
+ manager = SandboxManager(
+ provider=LocalSessionProvider(default_cwd=str(tmp_path)),
+ db_path=tmp_path / "sandbox.db",
+ )
+
+ capability = manager.get_sandbox("thread-local")
+
+ assert capability.command.runtime_owns_cwd is True
+ session = manager.session_manager.get("thread-local")
+ assert session is not None
+ assert session.lease.provider_name == "local"
+
+
+def test_get_sandbox_auto_resumes_paused_lease_when_reconstructing_session():
+ manager = object.__new__(SandboxManager)
+ manager.provider = SimpleNamespace(name="local")
+ manager.provider_capability = SimpleNamespace(runtime_kind="local", eager_instance_binding=False)
+ manager.volume = _FakeVolume()
+ terminal = SimpleNamespace(
+ terminal_id="term-1",
+ lease_id="lease-1",
+ get_state=lambda: SimpleNamespace(cwd="/tmp", env_delta={}, state_version=0),
+ update_state=lambda _state: None,
+ )
+ lease = SimpleNamespace(
+ provider_name="local",
+ observed_state="paused",
+ bind_mounts=None,
+ recipe=None,
+ get_instance=lambda: SimpleNamespace(instance_id="instance-1"),
+ )
+ manager._get_active_terminal = lambda _thread_id: terminal
+ manager._get_lease = lambda _lease_id: lease
+ manager._assert_lease_provider = lambda _lease, _thread_id: None
+ manager._ensure_bound_instance = lambda _lease: None
+ resume_calls: list[tuple[str, str]] = []
+ manager.resume_session = lambda thread_id, source="user_resume": resume_calls.append((thread_id, source)) or True
+ manager.session_manager = SimpleNamespace(
+ get=lambda _thread_id, _terminal_id: None,
+ create=lambda **_kwargs: SimpleNamespace(session_id="sess-1", terminal=terminal, lease=lease),
+ )
+
+ manager.get_sandbox("thread-1")
+
+ assert resume_calls == [("thread-1", "auto_resume")]
+
+
+def test_get_sandbox_auto_resumes_live_session_when_lease_state_is_paused():
+ manager = object.__new__(SandboxManager)
+ terminal = SimpleNamespace(
+ terminal_id="term-1",
+ lease_id="lease-1",
+ get_state=lambda: SimpleNamespace(cwd="/tmp", env_delta={}, state_version=0),
+ )
+ paused_lease = SimpleNamespace(
+ lease_id="lease-1",
+ provider_name="local",
+ observed_state="paused",
+ bind_mounts=None,
+ )
+ resumed_lease = SimpleNamespace(
+ lease_id="lease-1",
+ provider_name="local",
+ observed_state="running",
+ bind_mounts=None,
+ )
+ live_session = SimpleNamespace(
+ terminal=terminal,
+ lease=paused_lease,
+ status="active",
+ )
+
+ manager.provider = SimpleNamespace(name="local")
+ manager.provider_capability = SimpleNamespace(runtime_kind="local", eager_instance_binding=False)
+ manager.volume = _FakeVolume()
+ manager._assert_lease_provider = lambda _lease, _thread_id: None
+ manager._ensure_bound_instance = lambda _lease: None
+ resume_calls: list[tuple[str, str]] = []
+
+ def _get_session(_thread_id, _terminal_id):
+ if resume_calls:
+ return SimpleNamespace(terminal=terminal, lease=resumed_lease, status="active")
+ return live_session
+
+ manager._get_active_terminal = lambda _thread_id: terminal
+ manager.resume_session = lambda thread_id, source="user_resume": resume_calls.append((thread_id, source)) or True
+ manager.session_manager = SimpleNamespace(get=_get_session)
+
+ capability = manager.get_sandbox("thread-1")
+
+ assert resume_calls == [("thread-1", "auto_resume")]
+ assert capability._session.lease is resumed_lease
+
+
+def test_resume_session_rebinds_live_session_lease_after_resume():
+ manager = object.__new__(SandboxManager)
+ terminal = SimpleNamespace(terminal_id="term-1", lease_id="lease-1")
+ resumed_lease = SimpleNamespace(
+ lease_id="lease-1",
+ observed_state="running",
+ get_instance=lambda: SimpleNamespace(instance_id="instance-1"),
+ resume_instance=lambda _provider, source="user_resume": True,
+ )
+ stale_lease = SimpleNamespace(lease_id="lease-1", observed_state="paused")
+ runtime = SimpleNamespace(lease=stale_lease)
+ live_session = SimpleNamespace(
+ session_id="sess-1",
+ terminal=terminal,
+ lease=stale_lease,
+ runtime=runtime,
+ status="paused",
+ )
+ manager.provider = SimpleNamespace(name="local")
+ manager._get_thread_terminals = lambda _thread_id: [terminal]
+ manager._get_thread_lease = lambda _thread_id: resumed_lease
+ manager._sync_to_sandbox = lambda *_args, **_kwargs: None
+ manager._ensure_chat_session = lambda _thread_id: None
+ manager.session_manager = SimpleNamespace(
+ get=lambda _thread_id, _terminal_id: live_session,
+ resume=lambda _session_id: setattr(live_session, "status", "active"),
+ )
+
+ ok = manager.resume_session("thread-1", source="auto_resume")
+
+ assert ok is True
+ assert live_session.lease is resumed_lease
+ assert runtime.lease is resumed_lease
+
+
+def test_upgrade_to_daytona_volume_uses_runtime_thread_repo_for_member_lookup(monkeypatch, tmp_path):
+ manager = object.__new__(SandboxManager)
+ manager.provider = _FakeDaytonaProvider()
+ update_repo = _FakeUpdateRepo()
+ manager._sandbox_volume_repo = lambda: update_repo
+
+ thread_repo = _FakeThreadRepo({"member_id": "member-supabase"})
+ monkeypatch.setattr(
+ sandbox_manager_module,
+ "build_thread_repo",
+ lambda **_kwargs: thread_repo,
+ raising=False,
+ )
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
+
+ new_source = manager._upgrade_to_daytona_volume(
+ "thread-supabase",
+ HostVolume(tmp_path / "staging"),
+ "volume-1",
+ "/workspace",
+ )
+
+ assert manager.provider.calls == [("member-supabase", "/workspace")]
+ assert thread_repo.closed is True
+ assert isinstance(new_source, DaytonaVolume)
+ assert update_repo.closed is True
+ assert update_repo.updated
+
+
+def test_upgrade_to_daytona_volume_waits_when_reusing_existing_daytona_volume(monkeypatch, tmp_path):
+ manager = object.__new__(SandboxManager)
+ provider = _FakeDaytonaProvider()
+ update_repo = _FakeUpdateRepo()
+ manager.provider = provider
+ manager._sandbox_volume_repo = lambda: update_repo
+
+ thread_repo = _FakeThreadRepo({"member_id": "member-supabase"})
+ monkeypatch.setattr(
+ sandbox_manager_module,
+ "build_thread_repo",
+ lambda **_kwargs: thread_repo,
+ raising=False,
+ )
+
+ def _already_exists(member_id: str, mount_path: str) -> str:
+ provider.calls.append((member_id, mount_path))
+ raise RuntimeError("volume already exists")
+
+ provider.create_managed_volume = _already_exists
+
+ new_source = manager._upgrade_to_daytona_volume(
+ "thread-supabase",
+ HostVolume(tmp_path / "staging"),
+ "volume-1",
+ "/workspace",
+ )
+
+ assert isinstance(new_source, DaytonaVolume)
+ assert provider.ready_waits == ["leon-volume-member-supabase"]
+
+
+@pytest.mark.parametrize(
+ ("strategy", "expected_class_name"),
+ [
+ ("sqlite", "SQLiteSandboxMonitorRepo"),
+ ("supabase", "SQLiteSandboxMonitorRepo"),
+ ],
+)
+def test_make_sandbox_monitor_repo_uses_runtime_sandbox_db(monkeypatch, strategy, expected_class_name):
+ from backend.web.core import storage_factory
+
+ monkeypatch.setenv("LEON_STORAGE_STRATEGY", strategy)
+ storage_factory.make_sandbox_monitor_repo.cache_clear() if hasattr(storage_factory.make_sandbox_monitor_repo, "cache_clear") else None
+
+ repo = storage_factory.make_sandbox_monitor_repo()
+ try:
+ assert repo.__class__.__name__ == expected_class_name
+ finally:
+ repo.close()
diff --git a/tests/test_sandbox_state.py b/tests/Unit/sandbox/test_sandbox_state.py
similarity index 100%
rename from tests/test_sandbox_state.py
rename to tests/Unit/sandbox/test_sandbox_state.py
diff --git a/tests/test_terminal.py b/tests/Unit/sandbox/test_terminal.py
similarity index 100%
rename from tests/test_terminal.py
rename to tests/Unit/sandbox/test_terminal.py
diff --git a/tests/test_terminal_persistence.py b/tests/Unit/sandbox/test_terminal_persistence.py
similarity index 100%
rename from tests/test_terminal_persistence.py
rename to tests/Unit/sandbox/test_terminal_persistence.py
diff --git a/tests/test_checkpoint_repo.py b/tests/Unit/storage/test_checkpoint_repo.py
similarity index 100%
rename from tests/test_checkpoint_repo.py
rename to tests/Unit/storage/test_checkpoint_repo.py
diff --git a/tests/test_eval_repo.py b/tests/Unit/storage/test_eval_repo.py
similarity index 100%
rename from tests/test_eval_repo.py
rename to tests/Unit/storage/test_eval_repo.py
diff --git a/tests/test_file_operation_repo.py b/tests/Unit/storage/test_file_operation_repo.py
similarity index 100%
rename from tests/test_file_operation_repo.py
rename to tests/Unit/storage/test_file_operation_repo.py
diff --git a/tests/test_run_event_repo.py b/tests/Unit/storage/test_run_event_repo.py
similarity index 100%
rename from tests/test_run_event_repo.py
rename to tests/Unit/storage/test_run_event_repo.py
diff --git a/tests/test_sqlite_kernel.py b/tests/Unit/storage/test_sqlite_kernel.py
similarity index 100%
rename from tests/test_sqlite_kernel.py
rename to tests/Unit/storage/test_sqlite_kernel.py
diff --git a/tests/Unit/storage/test_storage_container_contract.py b/tests/Unit/storage/test_storage_container_contract.py
new file mode 100644
index 000000000..503f9dd3a
--- /dev/null
+++ b/tests/Unit/storage/test_storage_container_contract.py
@@ -0,0 +1,82 @@
+from pathlib import Path
+
+import pytest
+
+from storage import StorageContainer
+from storage.providers.sqlite.checkpoint_repo import SQLiteCheckpointRepo
+from storage.providers.sqlite.eval_repo import SQLiteEvalRepo
+from storage.providers.supabase.checkpoint_repo import SupabaseCheckpointRepo
+from storage.providers.supabase.eval_repo import SupabaseEvalRepo
+from storage.providers.supabase.file_operation_repo import SupabaseFileOperationRepo
+from storage.providers.supabase.run_event_repo import SupabaseRunEventRepo
+from storage.providers.supabase.summary_repo import SupabaseSummaryRepo
+
+
+class _FakeSupabaseClient:
+ def table(self, table_name: str):
+ raise AssertionError(f"table() should not be called in this container test: {table_name}")
+
+
+def test_storage_container_sqlite_strategy_uses_sqlite_checkpoint_repo(tmp_path: Path) -> None:
+ container = StorageContainer(main_db_path=tmp_path / "leon.db", strategy="sqlite")
+ assert isinstance(container.checkpoint_repo(), SQLiteCheckpointRepo)
+
+
+def test_storage_container_supabase_strategy_builds_concrete_repos() -> None:
+ container = StorageContainer(strategy="supabase", supabase_client=_FakeSupabaseClient())
+
+ assert isinstance(container.checkpoint_repo(), SupabaseCheckpointRepo)
+ assert isinstance(container.run_event_repo(), SupabaseRunEventRepo)
+ assert isinstance(container.file_operation_repo(), SupabaseFileOperationRepo)
+ assert isinstance(container.summary_repo(), SupabaseSummaryRepo)
+ assert isinstance(container.eval_repo(), SupabaseEvalRepo)
+
+
+@pytest.mark.parametrize(
+ ("strategy", "repo_providers", "repo_method", "expected_type"),
+ [
+ ("sqlite", {"checkpoint_repo": "supabase"}, "checkpoint_repo", SupabaseCheckpointRepo),
+ ("supabase", {"eval_repo": "sqlite"}, "eval_repo", SQLiteEvalRepo),
+ ],
+)
+def test_storage_container_repo_level_overrides(
+ strategy: str,
+ repo_providers: dict[str, str],
+ repo_method: str,
+ expected_type: type,
+) -> None:
+ container = StorageContainer(
+ strategy=strategy,
+ repo_providers=repo_providers,
+ supabase_client=_FakeSupabaseClient(),
+ )
+ assert isinstance(getattr(container, repo_method)(), expected_type)
+
+
+@pytest.mark.parametrize(
+ ("repo_method", "message"),
+ [
+ ("checkpoint_repo", "Supabase strategy checkpoint_repo requires supabase_client"),
+ ("run_event_repo", "Supabase strategy run_event_repo requires supabase_client"),
+ ("file_operation_repo", "Supabase strategy file_operation_repo requires supabase_client"),
+ ("summary_repo", "Supabase strategy summary_repo requires supabase_client"),
+ ("eval_repo", "Supabase strategy eval_repo requires supabase_client"),
+ ],
+)
+def test_storage_container_supabase_repos_require_client(repo_method: str, message: str) -> None:
+ container = StorageContainer(strategy="supabase")
+ with pytest.raises(RuntimeError, match=message):
+ getattr(container, repo_method)()
+
+
+@pytest.mark.parametrize(
+ ("kwargs", "message"),
+ [
+ ({"strategy": "redis"}, "Unsupported storage strategy: redis. Supported strategies: sqlite, supabase"),
+ ({"repo_providers": {"foo_repo": "sqlite"}}, "Unknown repo provider bindings: foo_repo"),
+ ({"repo_providers": {"checkpoint_repo": "mysql"}}, "Unsupported provider for checkpoint_repo"),
+ ],
+)
+def test_storage_container_rejects_invalid_configuration(kwargs: dict[str, object], message: str) -> None:
+ with pytest.raises(ValueError, match=message):
+ StorageContainer(**kwargs) # type: ignore[arg-type]
diff --git a/tests/test_summary_repo.py b/tests/Unit/storage/test_summary_repo.py
similarity index 100%
rename from tests/test_summary_repo.py
rename to tests/Unit/storage/test_summary_repo.py
diff --git a/tests/middleware/memory/test_summary_store.py b/tests/Unit/storage/test_summary_store.py
similarity index 100%
rename from tests/middleware/memory/test_summary_store.py
rename to tests/Unit/storage/test_summary_store.py
diff --git a/tests/Unit/storage/test_supabase_chat_repo.py b/tests/Unit/storage/test_supabase_chat_repo.py
new file mode 100644
index 000000000..95422182d
--- /dev/null
+++ b/tests/Unit/storage/test_supabase_chat_repo.py
@@ -0,0 +1,98 @@
+from storage.providers.supabase.chat_repo import SupabaseChatMessageRepo
+from tests.fakes.supabase import FakeSupabaseClient
+
+
+def test_supabase_chat_message_repo_has_unread_mention_tracks_mentions_after_last_read():
+ tables = {
+ "chat_entities": [
+ {
+ "chat_id": "chat-1",
+ "user_id": "entity-target",
+ "joined_at": 1.0,
+ "last_read_at": 5.0,
+ }
+ ],
+ "chat_messages": [
+ {
+ "id": "msg-old",
+ "chat_id": "chat-1",
+ "sender_id": "entity-other",
+ "content": "old mention",
+ "mentions": '["entity-target"]',
+ "created_at": 4.0,
+ },
+ {
+ "id": "msg-self",
+ "chat_id": "chat-1",
+ "sender_id": "entity-target",
+ "content": "self mention",
+ "mentions": '["entity-target"]',
+ "created_at": 6.0,
+ },
+ {
+ "id": "msg-unread",
+ "chat_id": "chat-1",
+ "sender_id": "entity-other",
+ "content": "new mention",
+ "mentions": '["entity-target"]',
+ "created_at": 7.0,
+ },
+ {
+ "id": "msg-unread-no-mention",
+ "chat_id": "chat-1",
+ "sender_id": "entity-other",
+ "content": "plain unread",
+ "mentions": "[]",
+ "created_at": 8.0,
+ },
+ ],
+ }
+ repo = SupabaseChatMessageRepo(FakeSupabaseClient(tables))
+
+ assert repo.has_unread_mention("chat-1", "entity-target") is True
+
+
+def test_supabase_chat_message_repo_has_unread_mention_false_without_matching_unread_mentions():
+ tables = {
+ "chat_entities": [
+ {
+ "chat_id": "chat-1",
+ "user_id": "entity-target",
+ "joined_at": 1.0,
+ "last_read_at": 5.0,
+ }
+ ],
+ "chat_messages": [
+ {
+ "id": "msg-unread",
+ "chat_id": "chat-1",
+ "sender_id": "entity-other",
+ "content": "plain unread",
+ "mentions": "[]",
+ "created_at": 7.0,
+ }
+ ],
+ }
+ repo = SupabaseChatMessageRepo(FakeSupabaseClient(tables))
+
+ assert repo.has_unread_mention("chat-1", "entity-target") is False
+
+
+def test_supabase_chat_message_repo_has_unread_mention_false_without_membership_row():
+ tables = {
+ "chat_entities": [],
+ "chat_messages": [
+ {
+ "id": "msg-unread",
+ "chat_id": "chat-1",
+ "sender_id": "entity-other",
+ "content": "new mention",
+ "mentions": '["entity-target"]',
+ "created_at": 7.0,
+ }
+ ],
+ }
+ repo = SupabaseChatMessageRepo(FakeSupabaseClient(tables))
+
+ assert repo.count_unread("chat-1", "entity-target") == 0
+ assert repo.has_unread_mention("chat-1", "entity-target") is False
diff --git a/tests/Unit/storage/test_supabase_entity_repo.py b/tests/Unit/storage/test_supabase_entity_repo.py
new file mode 100644
index 000000000..3a9180e0d
--- /dev/null
+++ b/tests/Unit/storage/test_supabase_entity_repo.py
@@ -0,0 +1,31 @@
+from storage.providers.supabase.entity_repo import SupabaseEntityRepo
+from tests.fakes.supabase import FakeSupabaseClient
+
+
+def test_supabase_entity_repo_get_by_thread_id_returns_matching_entity():
+ tables = {
+ "entities": [
+ {
+ "id": "entity-1",
+ "type": "agent",
+ "member_id": "member-1",
+ "name": "worker-1",
+ "avatar": None,
+ "thread_id": "thread-1",
+ "created_at": 1.0,
+ }
+ ]
+ }
+ repo = SupabaseEntityRepo(FakeSupabaseClient(tables))
+
+ row = repo.get_by_thread_id("thread-1")
+
+ assert row is not None
+ assert row.id == "entity-1"
+ assert row.thread_id == "thread-1"
+
+
+def test_supabase_entity_repo_get_by_thread_id_returns_none_when_missing():
+ repo = SupabaseEntityRepo(FakeSupabaseClient({"entities": []}))
+
+ assert repo.get_by_thread_id("thread-missing") is None
diff --git a/tests/Unit/storage/test_supabase_thread_repo.py b/tests/Unit/storage/test_supabase_thread_repo.py
new file mode 100644
index 000000000..7f684797b
--- /dev/null
+++ b/tests/Unit/storage/test_supabase_thread_repo.py
@@ -0,0 +1,74 @@
+from storage.providers.supabase.thread_repo import SupabaseThreadRepo
+
+
+class _FakeTable:
+ def __init__(self) -> None:
+ self.insert_payload = None
+ self.update_payload = None
+ self.eq_calls: list[tuple[str, object]] = []
+ self.rows = [
+ {
+ "id": "thread-1",
+ "member_id": "member-1",
+ "sandbox_type": "local",
+ "model": None,
+ "cwd": None,
+ "observation_provider": None,
+ "is_main": 1,
+ "branch_index": 0,
+ "created_at": 1.0,
+ }
+ ]
+
+ def insert(self, payload):
+ self.insert_payload = payload
+ return self
+
+ def update(self, payload):
+ self.update_payload = payload
+ return self
+
+ def select(self, _cols):
+ return self
+
+ def eq(self, key, value):
+ self.eq_calls.append((key, value))
+ return self
+
+ def execute(self):
+ return type("Resp", (), {"data": self.rows})()
+
+
+class _FakeClient:
+ def __init__(self) -> None:
+ self.table_obj = _FakeTable()
+
+ def table(self, _name):
+ return self.table_obj
+
+
+def test_supabase_thread_repo_create_writes_integer_main_flag():
+ client = _FakeClient()
+ repo = SupabaseThreadRepo(client)
+
+ repo.create(
+ thread_id="thread-1",
+ member_id="member-1",
+ sandbox_type="local",
+ created_at=1.0,
+ is_main=True,
+ branch_index=0,
+ )
+
+ assert client.table_obj.insert_payload["is_main"] == 1
+
+
+def test_supabase_thread_repo_update_writes_integer_main_flag():
+ client = _FakeClient()
+ client.table_obj.rows[0]["branch_index"] = 1
+ client.table_obj.rows[0]["is_main"] = 0
+ repo = SupabaseThreadRepo(client)
+
+ repo.update("thread-1", is_main=False)
+
+ assert client.table_obj.update_payload["is_main"] == 0
diff --git a/tests/test_sync_state_thread_safety.py b/tests/Unit/storage/test_sync_state_thread_safety.py
similarity index 100%
rename from tests/test_sync_state_thread_safety.py
rename to tests/Unit/storage/test_sync_state_thread_safety.py
diff --git a/tests/test_sync_strategy.py b/tests/Unit/storage/test_sync_strategy.py
similarity index 100%
rename from tests/test_sync_strategy.py
rename to tests/Unit/storage/test_sync_strategy.py
diff --git a/tests/test_thread_repo.py b/tests/Unit/storage/test_thread_repo.py
similarity index 100%
rename from tests/test_thread_repo.py
rename to tests/Unit/storage/test_thread_repo.py
diff --git a/tests/middleware/memory/test_summary_store_performance.py b/tests/middleware/memory/test_summary_store_performance.py
deleted file mode 100644
index ce3b0c3bb..000000000
--- a/tests/middleware/memory/test_summary_store_performance.py
+++ /dev/null
@@ -1,266 +0,0 @@
-"""Performance tests for SummaryStore.
-
-This module tests the performance characteristics of SummaryStore operations
-to ensure they meet production requirements.
-
-Test Cases:
-1. Query performance with many summaries (1000 summaries, query < 50ms)
-2. Concurrent write performance (10 threads, avg write < 100ms)
-3. Database size growth (100 summaries, DB < 1MB)
-"""
-
-import sys
-import threading
-import time
-from pathlib import Path
-
-import pytest
-
-_SKIP_WINDOWS = pytest.mark.skipif(
- sys.platform == "win32", reason="SQLite connection-per-call is slow on Windows; performance tests not meaningful there"
-)
-
-from core.runtime.middleware.memory.summary_store import SummaryStore
-
-
-@_SKIP_WINDOWS
-def test_query_performance_with_many_summaries(temp_db):
- """Test query performance with 1000 summaries.
-
- Requirements:
- - Create 1000 summaries across multiple threads
- - Query for latest summary should complete in < 50ms
- - Index should enable fast lookups even with large dataset
- """
- store = SummaryStore(temp_db)
-
- # Create 1000 summaries across 100 threads (10 summaries per thread)
- num_threads = 100
- summaries_per_thread = 10
-
- print(f"\n[Performance Test] Creating {num_threads * summaries_per_thread} summaries...")
- start_time = time.perf_counter()
-
- for thread_idx in range(num_threads):
- thread_id = f"thread-{thread_idx:04d}"
- for summary_idx in range(summaries_per_thread):
- store.save_summary(
- thread_id=thread_id,
- summary_text=f"Summary {summary_idx} for {thread_id}. " * 10, # ~500 chars
- compact_up_to_index=summary_idx * 10,
- compacted_at=summary_idx * 20,
- )
-
- creation_time = time.perf_counter() - start_time
- print(f"[Performance Test] Created 1000 summaries in {creation_time:.2f}s")
-
- # Now test query performance on a thread with many summaries
- # Query the middle thread to avoid edge cases
- target_thread = "thread-0050"
-
- # Warm up query (first query might be slower due to cold cache)
- store.get_latest_summary(target_thread)
-
- # Measure query performance over 10 iterations
- query_times = []
- for _ in range(10):
- start = time.perf_counter()
- summary = store.get_latest_summary(target_thread)
- elapsed = (time.perf_counter() - start) * 1000 # Convert to ms
- query_times.append(elapsed)
-
- assert summary is not None
- assert summary.thread_id == target_thread
-
- avg_query_time = sum(query_times) / len(query_times)
- max_query_time = max(query_times)
-
- print(f"[Performance Test] Query times: avg={avg_query_time:.2f}ms, max={max_query_time:.2f}ms")
-
- # Assert performance requirements
- assert avg_query_time < 50, f"Average query time {avg_query_time:.2f}ms exceeds 50ms threshold"
- assert max_query_time < 100, f"Max query time {max_query_time:.2f}ms exceeds 100ms threshold"
-
-
-@_SKIP_WINDOWS
-def test_concurrent_write_performance(temp_db):
- """Test concurrent write performance with 10 threads.
-
- Requirements:
- - 10 threads writing concurrently
- - Each thread writes 10 summaries
- - Average write time per summary < 100ms
- - No database locks or corruption
- """
- store = SummaryStore(temp_db)
-
- num_threads = 10
- summaries_per_thread = 10
-
- results = []
- errors = []
-
- def write_summaries(thread_idx: int):
- """Worker function to write summaries."""
- thread_id = f"concurrent-thread-{thread_idx:02d}"
- thread_times = []
-
- try:
- for summary_idx in range(summaries_per_thread):
- start = time.perf_counter()
-
- store.save_summary(
- thread_id=thread_id,
- summary_text=f"Concurrent summary {summary_idx} from thread {thread_idx}. " * 10,
- compact_up_to_index=summary_idx * 10,
- compacted_at=summary_idx * 20,
- )
-
- elapsed = (time.perf_counter() - start) * 1000 # Convert to ms
- thread_times.append(elapsed)
-
- results.append(
- {
- "thread_idx": thread_idx,
- "times": thread_times,
- "avg_time": sum(thread_times) / len(thread_times),
- }
- )
- except Exception as e:
- errors.append(
- {
- "thread_idx": thread_idx,
- "error": str(e),
- }
- )
-
- # Start all threads
- print(f"\n[Performance Test] Starting {num_threads} concurrent write threads...")
- start_time = time.perf_counter()
-
- threads = []
- for i in range(num_threads):
- t = threading.Thread(target=write_summaries, args=(i,))
- threads.append(t)
- t.start()
-
- # Wait for all threads to complete
- for t in threads:
- t.join()
-
- total_time = time.perf_counter() - start_time
-
- # Check for errors
- assert len(errors) == 0, f"Concurrent writes failed: {errors}"
- assert len(results) == num_threads, f"Expected {num_threads} results, got {len(results)}"
-
- # Calculate statistics
- all_times = []
- for result in results:
- all_times.extend(result["times"])
-
- avg_write_time = sum(all_times) / len(all_times)
- max_write_time = max(all_times)
- min_write_time = min(all_times)
-
- print(f"[Performance Test] Concurrent writes completed in {total_time:.2f}s")
- print(f"[Performance Test] Write times: avg={avg_write_time:.2f}ms, min={min_write_time:.2f}ms, max={max_write_time:.2f}ms")
-
- # Assert performance requirements
- assert avg_write_time < 100, f"Average write time {avg_write_time:.2f}ms exceeds 100ms threshold"
-
- # Verify data integrity - each thread should have its latest summary
- for i in range(num_threads):
- thread_id = f"concurrent-thread-{i:02d}"
- summary = store.get_latest_summary(thread_id)
- assert summary is not None, f"Missing summary for {thread_id}"
- assert summary.thread_id == thread_id
- assert summary.compact_up_to_index == (summaries_per_thread - 1) * 10
-
-
-@_SKIP_WINDOWS
-def test_database_size_growth(temp_db):
- """Test database size growth with 100 summaries.
-
- Requirements:
- - Create 100 summaries with realistic content
- - Database size (including WAL files) should be < 1MB
- - Verify efficient storage without excessive overhead
- """
- store = SummaryStore(temp_db)
-
- num_summaries = 100
-
- # Create realistic summary content (~2KB per summary)
- summary_template = (
- """
- The conversation covered the following topics:
- - User requested implementation of feature X
- - Discussion about architecture and design patterns
- - Code review and feedback on proposed changes
- - Testing strategy and coverage requirements
- - Documentation updates and API changes
- """
- * 10
- ) # ~2KB of text
-
- print(f"\n[Performance Test] Creating {num_summaries} summaries with realistic content...")
-
- for i in range(num_summaries):
- store.save_summary(
- thread_id=f"size-test-thread-{i:03d}",
- summary_text=f"Summary {i}: {summary_template}",
- compact_up_to_index=i * 10,
- compacted_at=i * 20,
- is_split_turn=(i % 5 == 0), # 20% split turns
- split_turn_prefix=f"Prefix for summary {i}" if i % 5 == 0 else None,
- )
-
- # Force WAL checkpoint to flush data to main database
- import sqlite3
-
- conn = sqlite3.connect(str(temp_db))
- try:
- conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
- conn.commit()
- finally:
- conn.close()
-
- # Calculate total database size (main DB + WAL files)
- db_size = temp_db.stat().st_size
-
- wal_size = 0
- for suffix in ["-wal", "-shm"]:
- wal_file = Path(str(temp_db) + suffix)
- if wal_file.exists():
- wal_size += wal_file.stat().st_size
-
- total_size = db_size + wal_size
- total_size_kb = total_size / 1024
- total_size_mb = total_size / (1024 * 1024)
-
- print("[Performance Test] Database sizes:")
- print(f" - Main DB: {db_size / 1024:.2f} KB")
- print(f" - WAL files: {wal_size / 1024:.2f} KB")
- print(f" - Total: {total_size_kb:.2f} KB ({total_size_mb:.3f} MB)")
-
- # Assert size requirements
- assert total_size < 1024 * 1024, f"Database size {total_size_mb:.3f}MB exceeds 1MB threshold"
-
- # Verify data integrity - spot check a few summaries
- for i in [0, 49, 99]:
- thread_id = f"size-test-thread-{i:03d}"
- summary = store.get_latest_summary(thread_id)
- assert summary is not None, f"Missing summary for {thread_id}"
- assert summary.thread_id == thread_id
- assert summary.compact_up_to_index == i * 10
- assert summary_template in summary.summary_text
-
- # Verify total count
- all_threads = [f"size-test-thread-{i:03d}" for i in range(num_summaries)]
- found_count = sum(1 for tid in all_threads if store.get_latest_summary(tid) is not None)
- assert found_count == num_summaries, f"Expected {num_summaries} summaries, found {found_count}"
-
-
-if __name__ == "__main__":
- pytest.main([__file__, "-v", "-s"])
diff --git a/tests/test_agent_pool.py b/tests/test_agent_pool.py
deleted file mode 100644
index 3ddd2945f..000000000
--- a/tests/test_agent_pool.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import asyncio
-import time
-from types import SimpleNamespace
-
-import pytest
-
-from backend.web.services import agent_pool
-
-
-class _FakeThreadRepo:
- def get_by_id(self, thread_id: str):
- return {"id": thread_id, "cwd": "/tmp", "model": "leon:large"}
-
-
-@pytest.mark.asyncio
-async def test_get_or_create_agent_creates_once_per_thread(monkeypatch: pytest.MonkeyPatch):
- created: list[object] = []
-
- def _fake_create_agent_sync(
- sandbox_name: str,
- workspace_root=None,
- model_name: str | None = None,
- agent: str | None = None,
- queue_manager=None,
- chat_repos=None,
- extra_allowed_paths=None,
- ) -> object:
- time.sleep(0.05)
- obj = SimpleNamespace()
- created.append(obj)
- return obj
-
- monkeypatch.setattr(agent_pool, "create_agent_sync", _fake_create_agent_sync)
- monkeypatch.setattr(agent_pool, "get_or_create_agent_id", lambda **_: "agent-1")
-
- app = SimpleNamespace(
- state=SimpleNamespace(
- agent_pool={},
- thread_repo=_FakeThreadRepo(),
- thread_cwd={},
- thread_sandbox={},
- )
- )
-
- first, second = await asyncio.gather(
- agent_pool.get_or_create_agent(app, "local", thread_id="thread-1"),
- agent_pool.get_or_create_agent(app, "local", thread_id="thread-1"),
- )
-
- assert len(created) == 1
- assert first is second
- assert app.state.agent_pool["thread-1:local"] is first
diff --git a/tests/test_capability_async.py b/tests/test_capability_async.py
deleted file mode 100644
index 8d1ba06d7..000000000
--- a/tests/test_capability_async.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import asyncio
-import uuid
-
-from sandbox.capability import SandboxCapability
-from sandbox.interfaces.executor import AsyncCommand, ExecuteResult
-
-
-class _DummyState:
- cwd = "/tmp"
-
-
-class _DummyTerminal:
- terminal_id = "dummy-term"
-
- def get_state(self):
- return _DummyState()
-
-
-class _DummyRuntime:
- def __init__(self):
- self.commands: list[str] = []
- self._async_commands: dict[str, AsyncCommand] = {}
-
- async def execute(self, command: str, timeout=None):
- self.commands.append(command)
- await asyncio.sleep(0.01)
- return ExecuteResult(exit_code=0, stdout=f"ok:{command}", stderr="")
-
- async def start_command(self, command: str, cwd: str) -> AsyncCommand:
- command_id = f"cmd_{uuid.uuid4().hex[:12]}"
- result = await self.execute(command)
- async_cmd = AsyncCommand(
- command_id=command_id,
- command_line=command,
- cwd=cwd,
- exit_code=result.exit_code,
- done=True,
- stdout_buffer=[result.stdout],
- )
- self._async_commands[command_id] = async_cmd
- return async_cmd
-
- async def get_command(self, command_id: str) -> AsyncCommand | None:
- return self._async_commands.get(command_id)
-
- async def wait_for_command(self, command_id: str, timeout: float | None = None) -> ExecuteResult | None:
- cmd = self._async_commands.get(command_id)
- if cmd is None:
- return None
- return ExecuteResult(
- exit_code=cmd.exit_code or 0,
- stdout="".join(cmd.stdout_buffer),
- stderr="".join(cmd.stderr_buffer),
- )
-
-
-class _DummySession:
- def __init__(self):
- self.terminal = _DummyTerminal()
- self.runtime = _DummyRuntime()
- self.touches = 0
-
- def touch(self):
- self.touches += 1
-
-
-async def _run_async_command_flow():
- session = _DummySession()
- capability = SandboxCapability(session)
-
- async_cmd = await capability.command.execute_async("echo hi", cwd="/tmp/demo", env={"A": "1"})
- assert async_cmd.command_id.startswith("cmd_")
-
- status = await capability.command.get_status(async_cmd.command_id)
- assert status is not None
-
- result = await capability.command.wait_for(async_cmd.command_id, timeout=1.0)
- assert result is not None
- assert result.exit_code == 0
- assert "echo hi" in result.stdout
- assert session.touches > 0
-
-
-def test_command_wrapper_supports_execute_async():
- asyncio.run(_run_async_command_flow())
diff --git a/tests/test_filesystem_touch_updates_session.py b/tests/test_filesystem_touch_updates_session.py
deleted file mode 100644
index 9a6bede32..000000000
--- a/tests/test_filesystem_touch_updates_session.py
+++ /dev/null
@@ -1,103 +0,0 @@
-"""FS wrapper should count as activity (touch ChatSession) for idle reaper."""
-
-# TODO: fs.list_dir now goes through volume-mount path; FakeProvider needs a volume_id to pass
-import pytest
-
-pytest.skip("pre-existing: FakeProvider missing volume setup — needs test update", allow_module_level=True)
-
-import sqlite3
-import tempfile
-import uuid
-from datetime import datetime
-from pathlib import Path
-
-from sandbox.manager import SandboxManager
-from sandbox.provider import Metrics, ProviderCapability, ProviderExecResult, SandboxProvider, SessionInfo
-
-
-class _FakeProvider(SandboxProvider):
- name = "fake"
-
- def __init__(self) -> None:
- self._statuses: dict[str, str] = {}
-
- def get_capability(self) -> ProviderCapability:
- return ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- supports_webhook=False,
- )
-
- def create_session(self, context_id: str | None = None) -> SessionInfo:
- sid = f"s-{uuid.uuid4().hex[:8]}"
- self._statuses[sid] = "running"
- return SessionInfo(session_id=sid, provider=self.name, status="running")
-
- def destroy_session(self, session_id: str, sync: bool = True) -> bool:
- self._statuses.pop(session_id, None)
- return True
-
- def pause_session(self, session_id: str) -> bool:
- self._statuses[session_id] = "paused"
- return True
-
- def resume_session(self, session_id: str) -> bool:
- self._statuses[session_id] = "running"
- return True
-
- def get_session_status(self, session_id: str) -> str:
- return self._statuses.get(session_id, "deleted")
-
- def execute(self, session_id: str, command: str, timeout_ms: int = 30000, cwd: str | None = None) -> ProviderExecResult:
- return ProviderExecResult(output="", exit_code=0)
-
- def read_file(self, session_id: str, path: str) -> str:
- return ""
-
- def write_file(self, session_id: str, path: str, content: str) -> str:
- return "ok"
-
- def list_dir(self, session_id: str, path: str) -> list[dict]:
- return [{"name": "a.txt", "type": "file", "size": 1}]
-
- def get_metrics(self, session_id: str) -> Metrics | None:
- return None
-
- def create_runtime(self, terminal, lease):
- from sandbox.runtime import RemoteWrappedRuntime
-
- return RemoteWrappedRuntime(terminal, lease, self)
-
-
-def _temp_db() -> Path:
- with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
- return Path(f.name)
-
-
-def test_fs_list_dir_touches_session_last_active_at() -> None:
- db = _temp_db()
- try:
- provider = _FakeProvider()
- mgr = SandboxManager(provider=provider, db_path=db)
-
- cap = mgr.get_sandbox("thread-1")
- session_id = cap._session.session_id # type: ignore[attr-defined]
-
- with sqlite3.connect(str(db)) as conn:
- before = conn.execute(
- "SELECT last_active_at FROM chat_sessions WHERE chat_session_id = ?",
- (session_id,),
- ).fetchone()[0]
-
- cap.fs.list_dir("/")
-
- with sqlite3.connect(str(db)) as conn:
- after = conn.execute(
- "SELECT last_active_at FROM chat_sessions WHERE chat_session_id = ?",
- (session_id,),
- ).fetchone()[0]
-
- assert datetime.fromisoformat(str(after)) >= datetime.fromisoformat(str(before))
- finally:
- db.unlink(missing_ok=True)
diff --git a/tests/test_idle_reaper_shared_lease.py b/tests/test_idle_reaper_shared_lease.py
deleted file mode 100644
index 172e07537..000000000
--- a/tests/test_idle_reaper_shared_lease.py
+++ /dev/null
@@ -1,146 +0,0 @@
-from __future__ import annotations
-
-# TODO: get_sandbox now calls _setup_mounts which requires lease.volume_id; FakeProvider needs update
-import pytest
-
-pytest.skip("pre-existing: FakeProvider missing volume setup — needs test update", allow_module_level=True)
-
-import sqlite3
-from dataclasses import dataclass
-from datetime import datetime, timedelta
-from pathlib import Path
-
-from sandbox.manager import SandboxManager
-from sandbox.provider import ProviderCapability, ProviderExecResult, SandboxProvider, SessionInfo
-
-
-@dataclass
-class _DummyInstance:
- instance_id: str
-
-
-class DummyProvider(SandboxProvider):
- """Minimal provider stub for lease + idle-reaper tests."""
-
- name = "daytona"
-
- def __init__(self) -> None:
- self._paused: set[str] = set()
- self._created: list[str] = []
- self._pause_calls: list[str] = []
-
- def get_capability(self) -> ProviderCapability:
- return ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- supports_status_probe=True,
- eager_instance_binding=False,
- runtime_kind="remote",
- )
-
- def create_session(self, context_id: str | None = None) -> SessionInfo:
- sid = f"sb-{len(self._created) + 1}"
- self._created.append(sid)
- return SessionInfo(session_id=sid, provider=self.name, status="running")
-
- def destroy_session(self, session_id: str, sync: bool = True) -> bool:
- return True
-
- def pause_session(self, session_id: str) -> bool:
- self._pause_calls.append(session_id)
- self._paused.add(session_id)
- return True
-
- def resume_session(self, session_id: str) -> bool:
- self._paused.discard(session_id)
- return True
-
- def get_session_status(self, session_id: str) -> str:
- return "paused" if session_id in self._paused else "running"
-
- def execute(
- self,
- session_id: str,
- command: str,
- timeout_ms: int = 30000,
- cwd: str | None = None,
- ) -> ProviderExecResult:
- return ProviderExecResult(output="", exit_code=0)
-
- def read_file(self, session_id: str, path: str) -> str:
- return ""
-
- def write_file(self, session_id: str, path: str, content: str) -> str:
- return "ok"
-
- def list_dir(self, session_id: str, path: str) -> list[dict]:
- return []
-
- def get_metrics(self, session_id: str):
- return None
-
- def create_runtime(self, terminal, lease):
- from sandbox.runtime import RemoteWrappedRuntime
-
- return RemoteWrappedRuntime(terminal, lease, self)
-
-
-def _connect(db: Path) -> sqlite3.Connection:
- conn = sqlite3.connect(str(db), timeout=30)
- conn.execute("PRAGMA busy_timeout=30000")
- return conn
-
-
-def test_idle_reaper_does_not_pause_shared_lease_when_other_session_active(tmp_path: Path) -> None:
- db = tmp_path / "sandbox.db"
- provider = DummyProvider()
- manager = SandboxManager(provider=provider, db_path=db)
-
- thread_id = "thread-1"
-
- # Create the main terminal/session.
- cap = manager.get_sandbox(thread_id)
- lease_id = cap._session.lease.lease_id # type: ignore[attr-defined]
-
- # Force-bind a physical instance so idle reaper has something to pause.
- cap._session.lease.ensure_active_instance(provider) # type: ignore[attr-defined]
-
- # Create a background terminal/session on the same lease (non-block command behavior).
- bg_session = manager.create_background_command_session(thread_id=thread_id, initial_cwd="/home/daytona")
-
- main_session_id = cap._session.session_id # type: ignore[attr-defined]
- bg_session_id = bg_session.session_id
-
- # Make the background session expired, keep the main session active.
- now = datetime.now()
- expired_at = (now - timedelta(seconds=10_000)).isoformat()
-
- with _connect(db) as conn:
- conn.execute(
- "UPDATE chat_sessions SET idle_ttl_sec = 1, last_active_at = ?, started_at = ? WHERE chat_session_id = ?",
- (expired_at, expired_at, bg_session_id),
- )
- conn.execute(
- "UPDATE chat_sessions SET idle_ttl_sec = 300, last_active_at = ?, started_at = ? WHERE chat_session_id = ?",
- (now.isoformat(), now.isoformat(), main_session_id),
- )
- conn.commit()
-
- closed = manager.enforce_idle_timeouts()
- assert closed == 1
-
- # The shared lease must NOT be paused because the main session is still active.
- lease = manager.lease_store.get(lease_id)
- assert lease is not None
- assert lease.desired_state == "running"
- assert provider._pause_calls == []
-
- with _connect(db) as conn:
- row = conn.execute(
- "SELECT status, close_reason FROM chat_sessions WHERE chat_session_id = ?",
- (bg_session_id,),
- ).fetchone()
- assert row is not None
- assert row[0] == "closed"
- assert row[1] == "idle_timeout"
diff --git a/tests/test_integration_new_arch.py b/tests/test_integration_new_arch.py
deleted file mode 100644
index 459919424..000000000
--- a/tests/test_integration_new_arch.py
+++ /dev/null
@@ -1,619 +0,0 @@
-"""Integration tests for the full new architecture flow.
-
-Tests the complete flow: Thread → ChatSession → Runtime → Terminal → Lease → Instance
-"""
-
-# TODO: get_sandbox now calls _setup_mounts requiring lease.volume_id; FakeProvider/mock_provider
-# needs a volume configured. Most tests in this file fail for the same reason.
-import pytest
-
-pytest.skip("pre-existing: FakeProvider missing volume setup — needs test update", allow_module_level=True)
-
-import asyncio
-import sqlite3
-import tempfile
-from pathlib import Path
-from unittest.mock import MagicMock
-
-from sandbox.chat_session import ChatSessionManager
-from sandbox.manager import SandboxManager
-from sandbox.provider import ProviderCapability, SessionInfo
-from sandbox.terminal import terminal_from_row
-from storage.providers.sqlite.lease_repo import SQLiteLeaseRepo
-from storage.providers.sqlite.terminal_repo import SQLiteTerminalRepo
-
-
-@pytest.fixture
-def temp_db():
- """Create temporary database for testing."""
- with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
- db_path = Path(f.name)
- yield db_path
- db_path.unlink(missing_ok=True)
-
-
-@pytest.fixture
-def mock_provider():
- """Create mock SandboxProvider for local testing."""
- provider = MagicMock()
- provider.name = "local"
- provider.default_cwd = "/tmp"
- provider.get_capability.return_value = ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- supports_webhook=False,
- supports_status_probe=False,
- eager_instance_binding=True,
- inspect_visible=True,
- runtime_kind="local",
- )
- provider.create_session.return_value = SessionInfo(
- session_id="local-inst-1",
- provider="local",
- status="running",
- )
- provider.get_session_status.return_value = "running"
- provider.pause_session.return_value = True
- provider.resume_session.return_value = True
- provider.destroy_session.return_value = True
-
- # Mock execute to return proper results
- def mock_execute(instance_id, command, timeout_ms=None, cwd=None):
- result = MagicMock()
- result.exit_code = 0
-
- if command == "pwd":
- result.stdout = cwd or "/root"
- result.stderr = ""
- elif command.startswith("cd "):
- result.stdout = ""
- result.stderr = ""
- else:
- result.stdout = "command output"
- result.stderr = ""
-
- return result
-
- provider.execute = mock_execute
- from sandbox.providers.local import LocalPersistentShellRuntime
-
- provider.create_runtime.side_effect = lambda terminal, lease: LocalPersistentShellRuntime(terminal, lease)
- return provider
-
-
-@pytest.fixture
-def mock_remote_provider():
- """Create mock remote provider that supports lease lifecycle + fs ops."""
- provider = MagicMock()
- provider.name = "e2b"
- provider.get_capability.return_value = ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- supports_webhook=False,
- runtime_kind="remote",
- )
- provider.create_session.return_value = SessionInfo(
- session_id="inst-remote-1",
- provider="e2b",
- status="running",
- )
- provider.get_session_status.return_value = "running"
- provider.pause_session.return_value = True
- provider.resume_session.return_value = True
- provider.write_file.return_value = "ok"
- provider.read_file.return_value = "content"
- provider.list_dir.return_value = []
- from sandbox.runtime import RemoteWrappedRuntime
-
- provider.create_runtime.side_effect = lambda terminal, lease: RemoteWrappedRuntime(terminal, lease, provider)
- return provider
-
-
-@pytest.fixture
-def sandbox_manager(temp_db, mock_provider):
- """Create SandboxManager with temp database."""
- return SandboxManager(provider=mock_provider, db_path=temp_db)
-
-
-@pytest.fixture
-def remote_sandbox_manager(temp_db, mock_remote_provider):
- """Create SandboxManager with remote provider."""
- return SandboxManager(provider=mock_remote_provider, db_path=temp_db)
-
-
-class TestFullArchitectureFlow:
- """Test complete flow through all layers."""
-
- @pytest.mark.skip(reason="pre-existing: get_sandbox now requires lease.volume_id — FakeProvider needs update")
- def test_get_sandbox_creates_all_layers(self, sandbox_manager, temp_db):
- """Test that get_sandbox creates Terminal → Lease → Runtime → ChatSession."""
- thread_id = "test-thread-1"
-
- # Get sandbox (should create everything)
- capability = sandbox_manager.get_sandbox(thread_id)
-
- assert capability is not None
- assert capability._session is not None
- assert capability._session.thread_id == thread_id
- assert capability._session.terminal is not None
- assert capability._session.lease is not None
- assert capability._session.runtime is not None
-
- # Verify persistence
- terminal_store = SQLiteTerminalRepo(db_path=temp_db)
- terminal_row = terminal_store.get_active(thread_id)
- assert terminal_row is not None
-
- lease_repo = SQLiteLeaseRepo(db_path=temp_db)
- lease_row = lease_repo.get(terminal_row["lease_id"])
- lease_repo.close()
- assert lease_row is not None
-
- def test_get_sandbox_reuses_existing_session(self, sandbox_manager):
- """Test that get_sandbox reuses existing session."""
- thread_id = "test-thread-2"
-
- # First call creates
- capability1 = sandbox_manager.get_sandbox(thread_id)
- session_id1 = capability1._session.session_id
-
- # Second call reuses
- capability2 = sandbox_manager.get_sandbox(thread_id)
- session_id2 = capability2._session.session_id
-
- assert session_id1 == session_id2
-
- @pytest.mark.asyncio
- async def test_command_execution_through_capability(self, sandbox_manager):
- """Test command execution through capability wrapper."""
- thread_id = "test-thread-3"
-
- capability = sandbox_manager.get_sandbox(thread_id)
-
- # Execute command
- result = await capability.command.execute("echo hello")
-
- assert result.exit_code == 0
- assert result.stdout is not None
-
- @pytest.mark.asyncio
- async def test_async_command_status_survives_session_recreate(self, sandbox_manager):
- """Completed async commands should remain queryable after ChatSession recreation."""
- thread_id = "test-thread-3b"
- capability1 = sandbox_manager.get_sandbox(thread_id)
- session_id_1 = capability1._session.session_id
-
- async_cmd = await capability1.command.execute_async("echo async-ok")
- done_1 = await capability1.command.wait_for(async_cmd.command_id, timeout=5.0)
- assert done_1 is not None
- assert done_1.exit_code == 0
- assert "async-ok" in done_1.stdout
-
- sandbox_manager.session_manager.delete(session_id_1, reason="test_rotate_session")
- capability2 = sandbox_manager.get_sandbox(thread_id)
- assert capability2._session.session_id != session_id_1
-
- status = await capability2.command.get_status(async_cmd.command_id)
- assert status is not None
- assert status.done
-
- done_2 = await capability2.command.wait_for(async_cmd.command_id, timeout=1.0)
- assert done_2 is not None
- assert done_2.exit_code == 0
- assert "async-ok" in done_2.stdout
-
- @pytest.mark.asyncio
- async def test_non_blocking_command_uses_new_abstract_terminal(self, sandbox_manager, temp_db):
- thread_id = "test-thread-async-terminal"
- capability = sandbox_manager.get_sandbox(thread_id)
- default_terminal_id = capability._session.terminal.terminal_id
- shared_lease_id = capability._session.lease.lease_id
-
- from sandbox.terminal import TerminalState
-
- capability._session.terminal.update_state(TerminalState(cwd="/tmp", env_delta={"FOO": "bar"}))
-
- async_cmd = await capability.command.execute_async("echo bg-terminal")
- result = await capability.command.wait_for(async_cmd.command_id, timeout=5.0)
- assert result is not None
- assert result.exit_code == 0
- assert "bg-terminal" in result.stdout
-
- terminal_rows = sandbox_manager.terminal_store.list_by_thread(thread_id)
- assert len(terminal_rows) == 2
- terminals = [terminal_from_row(r, sandbox_manager.terminal_store.db_path) for r in terminal_rows]
- default_row = sandbox_manager.terminal_store.get_default(thread_id)
- assert default_row is not None
- default_terminal = terminal_from_row(default_row, sandbox_manager.terminal_store.db_path)
- assert default_terminal.terminal_id == default_terminal_id
-
- background_terminal = next(t for t in terminals if t.terminal_id != default_terminal_id)
- assert background_terminal.lease_id == shared_lease_id
- bg_state = background_terminal.get_state()
- assert bg_state.cwd in {"/tmp", "/private/tmp"}
- assert bg_state.env_delta.get("FOO") == "bar"
-
- with sqlite3.connect(str(temp_db), timeout=30) as conn:
- row = conn.execute(
- "SELECT terminal_id FROM terminal_commands WHERE command_id = ?",
- (async_cmd.command_id,),
- ).fetchone()
- assert row is not None
- assert row[0] == background_terminal.terminal_id
-
- @pytest.mark.asyncio
- async def test_running_async_command_visible_from_new_manager(self, temp_db, mock_provider):
- thread_id = "test-thread-running-visible"
- manager1 = SandboxManager(provider=mock_provider, db_path=temp_db)
- capability1 = manager1.get_sandbox(thread_id)
-
- async_cmd = await capability1.command.execute_async("for i in 1 2 3; do echo tick-$i; sleep 1; done")
- await asyncio.sleep(1.2)
-
- # Simulate command_status query from a fresh API manager/session process.
- manager2 = SandboxManager(provider=mock_provider, db_path=temp_db)
- capability2 = manager2.get_sandbox(thread_id)
-
- running = await capability2.command.get_status(async_cmd.command_id)
- assert running is not None
- assert not running.done
- assert "Runtime restarted before command completion" not in "".join(running.stderr_buffer)
- assert "tick-1" in "".join(running.stdout_buffer)
-
- finished = await capability2.command.wait_for(async_cmd.command_id, timeout=5.0)
- assert finished is not None
- assert finished.exit_code == 0
- assert "tick-3" in finished.stdout
-
- def test_terminal_state_persists_across_sessions(self, sandbox_manager, temp_db):
- """Test that terminal state persists when session expires."""
- thread_id = "test-thread-4"
-
- # Create session and update terminal state
- capability1 = sandbox_manager.get_sandbox(thread_id)
- terminal_id = capability1._session.terminal.terminal_id
-
- # Update terminal state
- from sandbox.terminal import TerminalState
-
- new_state = TerminalState(cwd="/tmp", env_delta={"FOO": "bar"})
- capability1._session.terminal.update_state(new_state)
-
- # Delete session (simulating expiry)
- sandbox_manager.session_manager.delete(capability1._session.session_id)
-
- # Get sandbox again (creates new session)
- capability2 = sandbox_manager.get_sandbox(thread_id)
-
- # Terminal should be reused with persisted state
- assert capability2._session.terminal.terminal_id == terminal_id
- state = capability2._session.terminal.get_state()
- assert state.cwd == "/tmp"
- assert state.env_delta == {"FOO": "bar"}
-
- def test_get_sandbox_fails_on_provider_mismatch(self, temp_db, mock_provider, mock_remote_provider):
- local_mgr = SandboxManager(provider=mock_provider, db_path=temp_db)
- remote_mgr = SandboxManager(provider=mock_remote_provider, db_path=temp_db)
-
- thread_id = "test-thread-provider-mismatch"
- _ = local_mgr.get_sandbox(thread_id)
-
- with pytest.raises(RuntimeError, match="bound to provider"):
- remote_mgr.get_sandbox(thread_id)
-
- def test_pause_all_sessions_skips_provider_mismatch(self, temp_db, mock_provider, mock_remote_provider):
- local_mgr = SandboxManager(provider=mock_provider, db_path=temp_db)
- remote_mgr = SandboxManager(provider=mock_remote_provider, db_path=temp_db)
-
- _ = local_mgr.get_sandbox("test-thread-provider-mismatch-pause")
-
- assert remote_mgr.pause_all_sessions() == 0
-
- def test_lease_shared_across_terminals(self, sandbox_manager, temp_db):
- """Test that multiple terminals can share the same lease."""
- thread_id1 = "test-thread-5"
- thread_id2 = "test-thread-6"
-
- # Create first terminal
- capability1 = sandbox_manager.get_sandbox(thread_id1)
- lease_id1 = capability1._session.lease.lease_id
-
- # Manually create second terminal with same lease
- terminal_store = SQLiteTerminalRepo(db_path=temp_db)
- _terminal2 = terminal_store.create(
- terminal_id="term-shared",
- thread_id=thread_id2,
- lease_id=lease_id1,
- )
-
- # Get sandbox for second thread
- capability2 = sandbox_manager.get_sandbox(thread_id2)
- lease_id2 = capability2._session.lease.lease_id
-
- # Should share the same lease
- assert lease_id1 == lease_id2
-
- def test_session_touch_updates_activity(self, sandbox_manager):
- """Test that capability.touch() updates session activity."""
- thread_id = "test-thread-7"
-
- capability = sandbox_manager.get_sandbox(thread_id)
- old_activity = capability._session.last_active_at
-
- import time
-
- time.sleep(0.01)
-
- capability.touch()
-
- # Activity should be updated
- assert capability._session.last_active_at > old_activity
-
- def test_session_info_api(self, sandbox_manager):
- """Test that manager can expose current provider session info."""
- thread_id = "test-thread-8"
-
- session_info = sandbox_manager.get_or_create_session(thread_id)
- assert session_info is not None
- assert session_info.provider == "local"
-
- sessions = sandbox_manager.list_sessions()
- assert len(sessions) > 0
-
- def test_remote_fs_operation_fails_on_paused_lease(self, remote_sandbox_manager, mock_remote_provider):
- """Paused lease must fail fast until explicit resume."""
- thread_id = "test-thread-remote-fs-1"
- capability = remote_sandbox_manager.get_sandbox(thread_id)
-
- lease = capability._session.lease
- lease.ensure_active_instance(mock_remote_provider)
- lease.pause_instance(mock_remote_provider)
- assert lease.get_instance() is not None
- assert lease.get_instance().status == "paused"
- mock_remote_provider.get_session_status.return_value = "paused"
-
- with pytest.raises(RuntimeError, match="is paused"):
- capability.fs.write_file("/home/user/test.txt", "ok")
- assert lease.get_instance().status == "paused"
-
-
-class TestSessionLifecycle:
- """Test session lifecycle management."""
-
- def test_session_expiry_cleanup(self, sandbox_manager, temp_db):
- """Test that expired sessions are cleaned up."""
-
- thread_id = "test-thread-9"
-
- # Create session with very short timeout
- capability = sandbox_manager.get_sandbox(thread_id)
- _session_id = capability._session.session_id
-
- # Manually update policy to expire immediately
- session_manager = ChatSessionManager(
- provider=sandbox_manager.provider,
- db_path=temp_db,
- )
-
- import time
-
- time.sleep(0.1)
-
- # Cleanup expired
- count = session_manager.cleanup_expired()
-
- # Session should still exist (default policy is 10 minutes)
- assert count == 0
-
- def test_pause_and_resume_session(self, sandbox_manager):
- """Test pausing and resuming sessions."""
- thread_id = "test-thread-10"
-
- # Create session
- capability = sandbox_manager.get_sandbox(thread_id)
- session_id = capability._session.session_id
- terminal_id = capability._session.terminal.terminal_id
-
- assert sandbox_manager.pause_session(thread_id)
- paused = sandbox_manager.session_manager.get(thread_id, terminal_id)
- assert paused is not None
- assert paused.session_id == session_id
- assert paused.status == "paused"
-
- assert sandbox_manager.resume_session(thread_id)
- resumed = sandbox_manager.session_manager.get(thread_id, terminal_id)
- assert resumed is not None
- assert resumed.session_id == session_id
- assert resumed.status == "active"
-
- def test_pause_and_resume_cover_all_thread_terminals(self, sandbox_manager):
- thread_id = "test-thread-10b"
- capability = sandbox_manager.get_sandbox(thread_id)
- asyncio.run(capability.command.execute_async("echo bg"))
-
- terminal_rows = sandbox_manager.terminal_store.list_by_thread(thread_id)
- assert len(terminal_rows) == 2
-
- assert sandbox_manager.pause_session(thread_id)
- for row in terminal_rows:
- session = sandbox_manager.session_manager.get(thread_id, row["terminal_id"])
- assert session is not None
- assert session.status == "paused"
-
- assert sandbox_manager.resume_session(thread_id)
- for row in terminal_rows:
- session = sandbox_manager.session_manager.get(thread_id, row["terminal_id"])
- assert session is not None
- assert session.status == "active"
-
- def test_destroy_session(self, sandbox_manager):
- """Test destroying a session."""
- thread_id = "test-thread-11"
-
- # Create session
- capability = sandbox_manager.get_sandbox(thread_id)
- _session_id = capability._session.session_id
- terminal_id = capability._session.terminal.terminal_id
-
- # Destroy
- sandbox_manager.destroy_session(thread_id)
-
- # Session should be gone
- session = sandbox_manager.session_manager.get(thread_id, terminal_id)
- assert session is None
-
- def test_destroy_session_removes_all_thread_resources(self, sandbox_manager):
- thread_id = "test-thread-11b"
- capability = sandbox_manager.get_sandbox(thread_id)
- asyncio.run(capability.command.execute_async("echo bg"))
-
- terminal_rows_before = sandbox_manager.terminal_store.list_by_thread(thread_id)
- assert len(terminal_rows_before) == 2
-
- assert sandbox_manager.destroy_session(thread_id)
- assert sandbox_manager.terminal_store.list_by_thread(thread_id) == []
- assert all(sandbox_manager.session_manager.get(thread_id, row["terminal_id"]) is None for row in terminal_rows_before)
-
-
-class TestMultiThreadScenarios:
- """Test scenarios with multiple threads."""
-
- def test_multiple_threads_independent_sessions(self, sandbox_manager):
- """Test that multiple threads get independent sessions."""
- thread_ids = [f"test-thread-{i}" for i in range(3)]
-
- capabilities = [sandbox_manager.get_sandbox(tid) for tid in thread_ids]
-
- # All should have different sessions
- session_ids = [cap._session.session_id for cap in capabilities]
- assert len(set(session_ids)) == 3
-
- # All should have different terminals
- terminal_ids = [cap._session.terminal.terminal_id for cap in capabilities]
- assert len(set(terminal_ids)) == 3
-
- def test_thread_switch_preserves_state(self, sandbox_manager):
- """Test that switching between threads preserves state."""
- thread_id1 = "test-thread-12"
- thread_id2 = "test-thread-13"
-
- # Work on thread 1
- cap1 = sandbox_manager.get_sandbox(thread_id1)
- from sandbox.terminal import TerminalState
-
- cap1._session.terminal.update_state(TerminalState(cwd="/tmp"))
-
- # Switch to thread 2
- cap2 = sandbox_manager.get_sandbox(thread_id2)
- cap2._session.terminal.update_state(TerminalState(cwd="/home"))
-
- # Switch back to thread 1
- cap1_again = sandbox_manager.get_sandbox(thread_id1)
- state1 = cap1_again._session.terminal.get_state()
- assert state1.cwd == "/tmp"
-
- # Check thread 2 state
- cap2_again = sandbox_manager.get_sandbox(thread_id2)
- state2 = cap2_again._session.terminal.get_state()
- assert state2.cwd == "/home"
-
-
-class TestErrorHandling:
- """Test error handling scenarios."""
-
- def test_missing_terminal_recreates_with_same_id(self, sandbox_manager, temp_db):
- """Test that terminal is recreated when missing from DB.
-
- Note: The terminal_id is stored in the session, so when we delete
- the terminal but not the session, the session still references the
- old terminal_id. This is expected behavior - the terminal_id is
- stable across recreations.
- """
- thread_id = "test-thread-14"
-
- # Create session
- capability = sandbox_manager.get_sandbox(thread_id)
- terminal_id = capability._session.terminal.terminal_id
-
- # Delete terminal from DB (but not session)
- terminal_store = SQLiteTerminalRepo(db_path=temp_db)
- terminal_store.delete(terminal_id)
-
- # Delete session to force full recreation
- sandbox_manager.session_manager.delete(capability._session.session_id)
-
- # Get sandbox again - creates new terminal
- _capability2 = sandbox_manager.get_sandbox(thread_id)
-
- # Terminal should exist in DB now
- _terminal2 = terminal_store.get_active(thread_id)
- assert _terminal2 is not None
-
- def test_missing_lease_recreates_with_same_id(self, sandbox_manager, temp_db):
- """Test that lease is recreated when missing from DB.
-
- Note: The lease_id is stored in the terminal, so when we delete
- the lease but not the terminal, the terminal still references the
- old lease_id. This is expected behavior - the lease_id is stable.
- """
- thread_id = "test-thread-15"
-
- # Create session
- capability = sandbox_manager.get_sandbox(thread_id)
- lease_id = capability._session.lease.lease_id
-
- # Delete lease from DB
- lease_repo = SQLiteLeaseRepo(db_path=temp_db)
- lease_repo.delete(lease_id)
- lease_repo.close()
-
- # Delete session AND terminal to force full recreation
- sandbox_manager.session_manager.delete(capability._session.session_id)
- terminal_store = SQLiteTerminalRepo(db_path=temp_db)
- terminal_store.delete(capability._session.terminal.terminal_id)
-
- # Get sandbox again - creates new terminal + lease
- capability2 = sandbox_manager.get_sandbox(thread_id)
-
- # Lease should exist in DB now
- lease_repo2 = SQLiteLeaseRepo(db_path=temp_db)
- lease2 = lease_repo2.get(capability2._session.lease.lease_id)
- lease_repo2.close()
- assert lease2 is not None
-
-
-# ── create_sandbox() factory tests ──────────────────────────────────────────
-
-from sandbox import LocalSandbox, create_sandbox # noqa: E402
-from sandbox.config import SandboxConfig # noqa: E402
-
-
-def test_create_sandbox_local():
- sbx = create_sandbox(SandboxConfig(provider="local"), workspace_root="/tmp")
- assert isinstance(sbx, LocalSandbox)
- assert sbx.working_dir == "/tmp"
-
-
-def test_create_sandbox_agentbay_requires_api_key(monkeypatch):
- monkeypatch.delenv("AGENTBAY_API_KEY", raising=False)
- with pytest.raises(ValueError, match="AGENTBAY_API_KEY"):
- create_sandbox(SandboxConfig(provider="agentbay"))
-
-
-def test_create_sandbox_e2b_requires_api_key(monkeypatch):
- monkeypatch.delenv("E2B_API_KEY", raising=False)
- with pytest.raises(ValueError, match="E2B_API_KEY"):
- create_sandbox(SandboxConfig(provider="e2b"))
-
-
-def test_create_sandbox_daytona_requires_api_key(monkeypatch):
- monkeypatch.delenv("DAYTONA_API_KEY", raising=False)
- with pytest.raises(ValueError, match="DAYTONA_API_KEY"):
- create_sandbox(SandboxConfig(provider="daytona"))
-
-
-def test_create_sandbox_unknown_provider():
- with pytest.raises(ValueError, match="Unknown sandbox provider"):
- create_sandbox(SandboxConfig(provider="bogus"))
diff --git a/tests/test_local_chat_session.py b/tests/test_local_chat_session.py
deleted file mode 100644
index 49b45fb9a..000000000
--- a/tests/test_local_chat_session.py
+++ /dev/null
@@ -1,72 +0,0 @@
-"""Tests for local sandbox using ChatSession architecture."""
-
-from __future__ import annotations
-
-# TODO: pre-existing: get_sandbox requires lease.volume_id
-import pytest
-
-pytest.skip("pre-existing: FakeProvider missing volume setup — needs test update", allow_module_level=True)
-
-from pathlib import Path
-
-import pytest
-
-from sandbox.base import LocalSandbox
-from sandbox.manager import lookup_sandbox_for_thread
-from sandbox.providers.local import LocalSessionProvider
-from sandbox.thread_context import set_current_thread_id
-
-
-@pytest.mark.asyncio
-async def test_local_chat_session_persistence_and_resume(tmp_path: Path):
- workspace = tmp_path / "workspace"
- workspace.mkdir(parents=True, exist_ok=True)
- db_path = tmp_path / "sandbox.db"
-
- thread_id = "local-thread-1"
- sandbox = LocalSandbox(workspace_root=str(workspace), db_path=db_path)
- set_current_thread_id(thread_id)
- sandbox.ensure_session(thread_id)
-
- shell = sandbox.shell()
-
- first = await shell.execute("cd /tmp && export LEON_LOCAL_VAR=chat-session-ok && pwd")
- assert first.exit_code == 0
- assert "/tmp" in first.stdout
-
- second = await shell.execute("pwd")
- assert second.exit_code == 0
- assert "/tmp" in second.stdout
-
- third = await shell.execute("echo $LEON_LOCAL_VAR")
- assert third.exit_code == 0
- assert "chat-session-ok" in third.stdout
-
- assert sandbox.pause_thread(thread_id)
- assert lookup_sandbox_for_thread(thread_id, db_path=db_path) == "local"
- assert sandbox.resume_thread(thread_id)
-
- set_current_thread_id(thread_id)
- resumed_pwd = await shell.execute("pwd")
- assert resumed_pwd.exit_code == 0
- assert "/tmp" in resumed_pwd.stdout
-
- resumed_env = await shell.execute("echo $LEON_LOCAL_VAR")
- assert resumed_env.exit_code == 0
- assert "chat-session-ok" in resumed_env.stdout
-
- sandbox.close()
-
-
-def test_local_provider_pause_resume_state_recovery():
- provider = LocalSessionProvider()
- session = provider.create_session(context_id="leon-lease-test-session")
- sid = session.session_id
- provider._session_states.clear()
- assert provider.pause_session(sid)
- assert provider.get_session_status(sid) == "paused"
-
- provider._session_states.clear()
- assert provider.resume_session(sid)
- assert provider.get_session_status(sid) == "running"
- assert not provider.pause_session("unknown-session-id")
diff --git a/tests/test_main_thread_flow.py b/tests/test_main_thread_flow.py
deleted file mode 100644
index e9c2afbd3..000000000
--- a/tests/test_main_thread_flow.py
+++ /dev/null
@@ -1,243 +0,0 @@
-import pytest
-
-pytest.skip("pre-existing: thread_config and agent-member wiring broken — needs migration", allow_module_level=True)
-
-import asyncio
-import os
-from types import SimpleNamespace
-
-from backend.web.models.requests import CreateThreadRequest, ResolveMainThreadRequest
-from backend.web.routers import threads as threads_router
-from backend.web.services.auth_service import AuthService
-from storage.contracts import EntityRow
-from storage.providers.sqlite.entity_repo import SQLiteEntityRepo
-from storage.providers.sqlite.member_repo import SQLiteAccountRepo, SQLiteMemberRepo
-from storage.providers.sqlite.thread_repo import SQLiteThreadRepo
-
-
-def test_register_creates_agent_members_without_threads(tmp_path, monkeypatch):
- db_path = tmp_path / "leon.db"
- members_dir = tmp_path / "members"
-
- import backend.web.services.member_service as member_service
-
- monkeypatch.setattr(member_service, "MEMBERS_DIR", members_dir)
- monkeypatch.setattr(member_service, "LEON_HOME", tmp_path)
-
- member_repo = SQLiteMemberRepo(db_path)
- account_repo = SQLiteAccountRepo(db_path)
- entity_repo = SQLiteEntityRepo(db_path)
- thread_repo = SQLiteThreadRepo(db_path)
- service = AuthService(
- members=member_repo,
- accounts=account_repo,
- entities=entity_repo,
- )
-
- payload = service.register("fresh_user", "pass1234")
- claims = service.verify_token(payload["token"])
- account = account_repo.get_by_username("fresh_user")
-
- owned_agents = member_repo.list_by_owner_user_id(payload["user"]["id"])
- assert "member_id" not in claims
- assert claims["user_id"] == payload["user"]["id"]
- assert payload["user"]["name"] == "fresh_user"
- assert account is not None
- assert account.user_id == payload["user"]["id"]
- assert len(owned_agents) == 2
- assert [agent.name for agent in owned_agents] == ["Toad", "Morel"]
- for agent in owned_agents:
- assert thread_repo.list_by_member(agent.id) == []
- assert entity_repo.get_by_member_id(agent.id) == []
-
-
-def test_first_explicit_thread_becomes_main_then_followups_are_children(tmp_path):
- db_path = tmp_path / "leon.db"
-
- member_repo = SQLiteMemberRepo(db_path)
- entity_repo = SQLiteEntityRepo(db_path)
- thread_repo = SQLiteThreadRepo(db_path)
-
- from storage.contracts import MemberRow, MemberType
-
- member_repo.create(
- MemberRow(
- id="owner-1",
- name="owner",
- type=MemberType.HUMAN,
- created_at=1.0,
- )
- )
- member_repo.create(
- MemberRow(
- id="member-1",
- name="Template Agent",
- type=MemberType.MYCEL_AGENT,
- owner_user_id="owner-1",
- created_at=2.0,
- )
- )
-
- app = SimpleNamespace(
- state=SimpleNamespace(
- member_repo=member_repo,
- entity_repo=entity_repo,
- thread_repo=thread_repo,
- thread_sandbox={},
- thread_cwd={},
- )
- )
-
- first = threads_router._create_owned_thread(
- app,
- "owner-1",
- CreateThreadRequest(member_id="member-1", sandbox="local"),
- is_main=False,
- )
- second = threads_router._create_owned_thread(
- app,
- "owner-1",
- CreateThreadRequest(member_id="member-1", sandbox="local"),
- is_main=False,
- )
-
- assert first["is_main"] is True
- assert first["branch_index"] == 0
- assert first["entity_name"] == "Template Agent"
- assert second["is_main"] is False
- assert second["branch_index"] == 1
- assert second["entity_name"] == "Template Agent · 分身1"
- assert thread_repo.get_main_thread("member-1")["id"] == first["thread_id"]
-
-
-def test_member_rename_recomputes_agent_entity_names(tmp_path, monkeypatch):
- db_path = tmp_path / "leon.db"
- members_dir = tmp_path / "members"
- members_dir.mkdir(parents=True)
- os.environ["LEON_DB_PATH"] = str(db_path)
-
- import backend.web.services.member_service as member_service
-
- monkeypatch.setattr(member_service, "MEMBERS_DIR", members_dir)
- monkeypatch.setattr(member_service, "LEON_HOME", tmp_path)
-
- member_repo = SQLiteMemberRepo(db_path)
- entity_repo = SQLiteEntityRepo(db_path)
- thread_repo = SQLiteThreadRepo(db_path)
-
- from storage.contracts import MemberRow, MemberType
-
- member_repo.create(
- MemberRow(
- id="owner-1",
- name="owner",
- type=MemberType.HUMAN,
- created_at=1.0,
- )
- )
- member_repo.create(
- MemberRow(
- id="member-1",
- name="Toad",
- type=MemberType.MYCEL_AGENT,
- owner_user_id="owner-1",
- created_at=2.0,
- )
- )
-
- member_dir = members_dir / "member-1"
- member_dir.mkdir()
- (member_dir / "agent.md").write_text("---\nname: Toad\n---\n\n", encoding="utf-8")
- (member_dir / "meta.json").write_text("{}", encoding="utf-8")
-
- thread_repo.create(
- thread_id="member-1-1",
- member_id="member-1",
- sandbox_type="local",
- created_at=3.0,
- is_main=True,
- branch_index=0,
- )
- thread_repo.create(
- thread_id="member-1-2",
- member_id="member-1",
- sandbox_type="local",
- created_at=4.0,
- is_main=False,
- branch_index=1,
- )
- entity_repo.create(
- EntityRow(
- id="member-1-1",
- type="agent",
- member_id="member-1",
- name="Toad",
- thread_id="member-1-1",
- created_at=3.0,
- )
- )
- entity_repo.create(
- EntityRow(
- id="member-1-2",
- type="agent",
- member_id="member-1",
- name="Toad · 分身1",
- thread_id="member-1-2",
- created_at=4.0,
- )
- )
-
- updated = member_service.update_member("member-1", name="Scout")
-
- refreshed_entities = sorted(entity_repo.get_by_member_id("member-1"), key=lambda entity: entity.thread_id or "")
- assert updated is not None
- assert updated["name"] == "Scout"
- assert [entity.name for entity in refreshed_entities] == ["Scout", "Scout · 分身1"]
-
-
-def test_resolve_main_thread_returns_null_when_member_has_no_main(tmp_path):
- db_path = tmp_path / "leon.db"
-
- member_repo = SQLiteMemberRepo(db_path)
- entity_repo = SQLiteEntityRepo(db_path)
- thread_repo = SQLiteThreadRepo(db_path)
-
- from storage.contracts import MemberRow, MemberType
-
- member_repo.create(
- MemberRow(
- id="owner-1",
- name="owner",
- type=MemberType.HUMAN,
- created_at=1.0,
- )
- )
- member_repo.create(
- MemberRow(
- id="member-1",
- name="Template Agent",
- type=MemberType.MYCEL_AGENT,
- owner_user_id="owner-1",
- created_at=2.0,
- )
- )
-
- app = SimpleNamespace(
- state=SimpleNamespace(
- member_repo=member_repo,
- entity_repo=entity_repo,
- thread_repo=thread_repo,
- thread_sandbox={},
- thread_cwd={},
- )
- )
-
- result = asyncio.run(
- threads_router.resolve_main_thread(
- ResolveMainThreadRequest(member_id="member-1"),
- "owner-1",
- app,
- )
- )
-
- assert result == {"thread": None}
diff --git a/tests/test_manager_ground_truth.py b/tests/test_manager_ground_truth.py
deleted file mode 100644
index 59027d277..000000000
--- a/tests/test_manager_ground_truth.py
+++ /dev/null
@@ -1,303 +0,0 @@
-"""Tests for SandboxManager inspect ground-truth behavior."""
-
-import asyncio
-import sqlite3
-import tempfile
-import uuid
-from datetime import datetime, timedelta
-from pathlib import Path
-
-import pytest
-
-from sandbox.manager import SandboxManager
-from sandbox.provider import Metrics, ProviderCapability, ProviderExecResult, SandboxProvider, SessionInfo
-from storage import StorageContainer
-from storage.providers.sqlite.checkpoint_repo import SQLiteCheckpointRepo
-from storage.providers.sqlite.eval_repo import SQLiteEvalRepo
-from storage.providers.supabase.checkpoint_repo import SupabaseCheckpointRepo
-from storage.providers.supabase.eval_repo import SupabaseEvalRepo
-from storage.providers.supabase.file_operation_repo import SupabaseFileOperationRepo
-from storage.providers.supabase.run_event_repo import SupabaseRunEventRepo
-from storage.providers.supabase.summary_repo import SupabaseSummaryRepo
-
-
-class FakeProvider(SandboxProvider):
- name = "fake"
-
- def __init__(self):
- self._statuses: dict[str, str] = {}
- self.fail_pause = False
-
- def get_capability(self) -> ProviderCapability:
- return ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- supports_webhook=False,
- )
-
- def create_session(self, context_id: str | None = None, thread_id: str | None = None) -> SessionInfo:
- sid = f"s-{uuid.uuid4().hex[:8]}"
- self._statuses[sid] = "running"
- return SessionInfo(session_id=sid, provider=self.name, status="running")
-
- def destroy_session(self, session_id: str, sync: bool = True) -> bool:
- self._statuses.pop(session_id, None)
- return True
-
- def pause_session(self, session_id: str) -> bool:
- if self.fail_pause:
- return False
- if session_id in self._statuses:
- self._statuses[session_id] = "paused"
- return True
- return False
-
- def resume_session(self, session_id: str) -> bool:
- if session_id in self._statuses:
- self._statuses[session_id] = "running"
- return True
- return False
-
- def get_session_status(self, session_id: str) -> str:
- return self._statuses.get(session_id, "deleted")
-
- def execute(
- self,
- session_id: str,
- command: str,
- timeout_ms: int = 30000,
- cwd: str | None = None,
- ) -> ProviderExecResult:
- return ProviderExecResult(output="", exit_code=0, error=None)
-
- def read_file(self, session_id: str, path: str) -> str:
- return ""
-
- def write_file(self, session_id: str, path: str, content: str) -> str:
- return "ok"
-
- def list_dir(self, session_id: str, path: str) -> list[dict]:
- return []
-
- def get_metrics(self, session_id: str) -> Metrics | None:
- return None
-
- def list_provider_sessions(self) -> list[SessionInfo]:
- return [SessionInfo(session_id=sid, provider=self.name, status=status) for sid, status in self._statuses.items()]
-
- def create_runtime(self, terminal, lease):
- from sandbox.runtime import RemoteWrappedRuntime
-
- return RemoteWrappedRuntime(terminal, lease, self)
-
-
-class _FakeSupabaseClient:
- def table(self, table_name: str):
- raise AssertionError(f"table() should not be called in this container wiring test: {table_name}")
-
-
-def _temp_db() -> Path:
- with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
- return Path(f.name)
-
-
-@pytest.mark.skip(reason="pre-existing: get_sandbox requires lease.volume_id — FakeProvider needs update")
-def test_list_sessions_shows_running_lease_without_chat_session() -> None:
- db = _temp_db()
- try:
- provider = FakeProvider()
- mgr = SandboxManager(provider=provider, db_path=db)
- lease = mgr.lease_store.create("lease-1", provider.name)
- instance = lease.ensure_active_instance(provider)
- mgr.terminal_store.create("term-1", "thread-1", "lease-1", "/home/user")
-
- rows = mgr.list_sessions()
- assert rows
- row = rows[0]
- assert row["thread_id"] == "thread-1"
- assert row["instance_id"] == instance.instance_id
- assert row["status"] == "running"
- assert row["source"] == "lease"
- finally:
- db.unlink(missing_ok=True)
-
-
-def test_list_sessions_includes_provider_orphan(temp_db) -> None:
- provider = FakeProvider()
- mgr = SandboxManager(provider=provider, db_path=temp_db)
- orphan = provider.create_session()
- rows = mgr.list_sessions()
- assert any(r["instance_id"] == orphan.session_id and r["source"] == "provider_orphan" for r in rows)
-
-
-@pytest.mark.skip(reason="pre-existing: get_sandbox requires lease.volume_id — FakeProvider needs update")
-def test_enforce_idle_timeouts_pauses_lease_and_closes_session() -> None:
- db = _temp_db()
- try:
- provider = FakeProvider()
- mgr = SandboxManager(provider=provider, db_path=db)
-
- capability = mgr.get_sandbox("thread-1")
- asyncio.run(capability.command.execute("echo hi"))
- session_id = capability._session.session_id
- instance_id = capability._session.lease.get_instance().instance_id
-
- with sqlite3.connect(str(db)) as conn:
- conn.execute(
- """
- UPDATE chat_sessions
- SET idle_ttl_sec = 1, last_active_at = ?
- WHERE chat_session_id = ?
- """,
- ((datetime.now() - timedelta(seconds=5)).isoformat(), session_id),
- )
- conn.commit()
-
- count = mgr.enforce_idle_timeouts()
- assert count == 1
- assert provider.get_session_status(instance_id) == "paused"
- assert mgr.session_manager.get("thread-1") is None
- finally:
- db.unlink(missing_ok=True)
-
-
-@pytest.mark.skip(reason="pre-existing: get_sandbox requires lease.volume_id — FakeProvider needs update")
-def test_enforce_idle_timeouts_continues_on_pause_failure() -> None:
- db = _temp_db()
- try:
- provider = FakeProvider()
- mgr = SandboxManager(provider=provider, db_path=db)
-
- capability = mgr.get_sandbox("thread-1")
- asyncio.run(capability.command.execute("echo hi"))
- session_id = capability._session.session_id
-
- with sqlite3.connect(str(db)) as conn:
- conn.execute(
- """
- UPDATE chat_sessions
- SET idle_ttl_sec = 1, last_active_at = ?
- WHERE chat_session_id = ?
- """,
- ((datetime.now() - timedelta(seconds=5)).isoformat(), session_id),
- )
- conn.commit()
-
- provider.fail_pause = True
- count = mgr.enforce_idle_timeouts()
- assert count == 0
- assert mgr.session_manager.get("thread-1") is not None
- finally:
- db.unlink(missing_ok=True)
-
-
-def test_storage_container_sqlite_strategy_is_non_regression(temp_db) -> None:
- container = StorageContainer(main_db_path=temp_db, strategy="sqlite")
- repo = container.checkpoint_repo()
- assert isinstance(repo, SQLiteCheckpointRepo)
-
-
-def test_storage_container_supabase_repos_are_concrete() -> None:
- fake_client = _FakeSupabaseClient()
- container = StorageContainer(strategy="supabase", supabase_client=fake_client)
- checkpoint_repo = container.checkpoint_repo()
- assert isinstance(checkpoint_repo, SupabaseCheckpointRepo)
- run_event_repo = container.run_event_repo()
- assert isinstance(run_event_repo, SupabaseRunEventRepo)
- file_operation_repo = container.file_operation_repo()
- assert isinstance(file_operation_repo, SupabaseFileOperationRepo)
- summary_repo = container.summary_repo()
- assert isinstance(summary_repo, SupabaseSummaryRepo)
- eval_repo = container.eval_repo()
- assert isinstance(eval_repo, SupabaseEvalRepo)
-
-
-def test_storage_container_repo_level_provider_override_from_sqlite_default() -> None:
- fake_client = _FakeSupabaseClient()
- container = StorageContainer(
- strategy="sqlite",
- repo_providers={"checkpoint_repo": "supabase"},
- supabase_client=fake_client,
- )
- assert isinstance(container.checkpoint_repo(), SupabaseCheckpointRepo)
-
-
-def test_storage_container_repo_level_provider_override_from_supabase_default() -> None:
- fake_client = _FakeSupabaseClient()
- container = StorageContainer(
- strategy="supabase",
- repo_providers={"eval_repo": "sqlite"},
- supabase_client=fake_client,
- )
- assert isinstance(container.eval_repo(), SQLiteEvalRepo)
- assert isinstance(container.checkpoint_repo(), SupabaseCheckpointRepo)
-
-
-def test_storage_container_supabase_checkpoint_requires_client() -> None:
- container = StorageContainer(strategy="supabase")
- with pytest.raises(
- RuntimeError,
- match="Supabase strategy checkpoint_repo requires supabase_client",
- ):
- container.checkpoint_repo()
-
-
-def test_storage_container_supabase_run_event_requires_client() -> None:
- container = StorageContainer(strategy="supabase")
- with pytest.raises(
- RuntimeError,
- match="Supabase strategy run_event_repo requires supabase_client",
- ):
- container.run_event_repo()
-
-
-def test_storage_container_supabase_file_operation_requires_client() -> None:
- container = StorageContainer(strategy="supabase")
- with pytest.raises(
- RuntimeError,
- match="Supabase strategy file_operation_repo requires supabase_client",
- ):
- container.file_operation_repo()
-
-
-def test_storage_container_supabase_summary_requires_client() -> None:
- container = StorageContainer(strategy="supabase")
- with pytest.raises(
- RuntimeError,
- match="Supabase strategy summary_repo requires supabase_client",
- ):
- container.summary_repo()
-
-
-def test_storage_container_supabase_eval_requires_client() -> None:
- container = StorageContainer(strategy="supabase")
- with pytest.raises(
- RuntimeError,
- match="Supabase strategy eval_repo requires supabase_client",
- ):
- container.eval_repo()
-
-
-def test_storage_container_rejects_unknown_strategy() -> None:
- with pytest.raises(
- ValueError,
- match="Unsupported storage strategy: redis. Supported strategies: sqlite, supabase",
- ):
- StorageContainer(strategy="redis") # type: ignore[arg-type]
-
-
-def test_storage_container_rejects_unknown_repo_provider_binding() -> None:
- with pytest.raises(
- ValueError,
- match="Unknown repo provider bindings: foo_repo",
- ):
- StorageContainer(repo_providers={"foo_repo": "sqlite"})
-
-
-def test_storage_container_rejects_invalid_repo_provider_value() -> None:
- with pytest.raises(
- ValueError,
- match="Unsupported provider for checkpoint_repo",
- ):
- StorageContainer(repo_providers={"checkpoint_repo": "mysql"})
diff --git a/tests/test_monitor_core_overview.py b/tests/test_monitor_core_overview.py
deleted file mode 100644
index d80ace417..000000000
--- a/tests/test_monitor_core_overview.py
+++ /dev/null
@@ -1,415 +0,0 @@
-import pytest
-
-pytest.skip("pre-existing: monitor/resource_service API mismatch — needs test update", allow_module_level=True)
-
-import json
-from pathlib import Path
-from unittest.mock import MagicMock
-
-from backend.web.services import resource_service
-from sandbox.provider import ProviderCapability, build_resource_capabilities
-
-
-def _write_provider_config(tmp_path: Path, instance_name: str, payload: dict) -> None:
- (tmp_path / f"{instance_name}.json").write_text(json.dumps(payload))
-
-
-def _make_fake_thread_config_repo(agent_by_thread: dict[str, str]):
- """Fake ThreadConfigRepo backed by a simple dict — works for both SQLite and Supabase code paths."""
- repo = MagicMock()
- repo.lookup_config.side_effect = lambda tid: (
- {
- "sandbox_type": "local",
- "cwd": None,
- "model": None,
- "queue_mode": None,
- "observation_provider": None,
- "agent": agent_by_thread[tid],
- }
- if tid in agent_by_thread
- else None
- )
- repo.close.return_value = None
- return repo
-
-
-def _make_fake_repo(sessions: list[dict]):
- """Create a mock repo that returns pre-canned sessions."""
- repo = MagicMock()
- repo.list_sessions_with_leases.return_value = sessions
- repo.close.return_value = None
- return repo
-
-
-def _patch_resources_context(
- monkeypatch,
- *,
- tmp_path: Path,
- providers: list[dict],
- sessions: list[dict],
- snapshots: dict | None = None,
-) -> None:
- monkeypatch.setattr(resource_service, "SANDBOXES_DIR", tmp_path)
- monkeypatch.setattr(resource_service, "available_sandbox_types", lambda: providers)
- monkeypatch.setattr(
- resource_service,
- "SQLiteSandboxMonitorRepo",
- lambda: _make_fake_repo(sessions),
- )
- capability_by_provider = {
- "local": build_resource_capabilities(
- filesystem=True,
- terminal=True,
- metrics=False,
- screenshot=False,
- web=False,
- process=False,
- hooks=False,
- snapshot=False,
- ),
- "docker": build_resource_capabilities(
- filesystem=True,
- terminal=True,
- metrics=True,
- screenshot=False,
- web=False,
- process=False,
- hooks=False,
- snapshot=False,
- ),
- "e2b": build_resource_capabilities(
- filesystem=True,
- terminal=True,
- metrics=False,
- screenshot=False,
- web=False,
- process=False,
- hooks=False,
- snapshot=True,
- ),
- "daytona": build_resource_capabilities(
- filesystem=True,
- terminal=True,
- metrics=False,
- screenshot=False,
- web=False,
- process=False,
- hooks=True,
- snapshot=False,
- ),
- "agentbay": build_resource_capabilities(
- filesystem=True,
- terminal=True,
- metrics=True,
- screenshot=True,
- web=True,
- process=True,
- hooks=False,
- snapshot=False,
- ),
- }
-
- def _fake_provider_builder(config_name: str, *, sandboxes_dir: Path | None = None):
- provider_name = resource_service.resolve_provider_name(
- config_name,
- sandboxes_dir=sandboxes_dir or tmp_path,
- )
- resource_capabilities = capability_by_provider.get(provider_name)
- if resource_capabilities is None:
- return None
-
- class _FakeProvider:
- def get_capability(self) -> ProviderCapability:
- return ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- resource_capabilities=resource_capabilities,
- )
-
- return _FakeProvider()
-
- monkeypatch.setattr(resource_service, "build_provider_from_config_name", _fake_provider_builder)
- if snapshots is not None:
- monkeypatch.setattr(resource_service, "list_snapshots_by_lease_ids", lambda _: snapshots)
-
-
-def test_list_resource_providers_maps_status_and_metric_metadata(tmp_path, monkeypatch):
- _write_provider_config(tmp_path, "docker_dev", {"provider": "docker"})
-
- monkeypatch.setattr(
- resource_service,
- "_make_thread_config_repo",
- lambda: _make_fake_thread_config_repo({"thread-local-1": "member-1"}),
- )
- monkeypatch.setattr(resource_service, "_member_name_map", lambda: {"member-1": "Alice"})
- _patch_resources_context(
- monkeypatch,
- tmp_path=tmp_path,
- providers=[
- {"name": "local", "available": True},
- {"name": "docker_dev", "available": False, "reason": "docker daemon down"},
- ],
- sessions=[
- {
- "provider": "local",
- "session_id": "sess-local-1",
- "thread_id": "thread-local-1",
- "observed_state": "detached",
- "desired_state": "running",
- "created_at": "2026-03-03T00:00:00",
- },
- {
- "provider": "docker_dev",
- "session_id": "sess-docker-1",
- "thread_id": "thread-docker-1",
- "observed_state": "paused",
- "desired_state": "paused",
- "created_at": "2026-03-03T00:00:00",
- },
- ],
- )
-
- payload = resource_service.list_resource_providers()
- assert "summary" in payload
- assert "providers" in payload
- assert payload["summary"]["total_providers"] == 2
- assert payload["summary"]["active_providers"] == 1
- assert payload["summary"]["unavailable_providers"] == 1
- assert payload["summary"]["running_sessions"] == 1
-
- local = next(item for item in payload["providers"] if item["id"] == "local")
- assert local["status"] == "active"
- assert local["telemetry"]["running"]["used"] == 1
- assert local["telemetry"]["running"]["source"] == "sandbox_db"
- assert local["telemetry"]["running"]["freshness"] == "cached"
- assert local["sessions"][0]["threadId"] == "thread-local-1"
- assert local["sessions"][0]["agentId"] == "member-1"
- assert local["sessions"][0]["agentName"] == "Alice"
-
- docker = next(item for item in payload["providers"] if item["id"] == "docker_dev")
- assert docker["status"] == "unavailable"
- assert docker["error"]["code"] == "PROVIDER_UNAVAILABLE"
- assert docker["sessions"][0]["status"] == "paused"
- assert docker["sessions"][0]["agentName"] == "未绑定Agent"
-
-
-def test_list_resource_providers_marks_ready_when_no_running_sessions(tmp_path, monkeypatch):
- _write_provider_config(tmp_path, "e2b_test", {"provider": "e2b"})
- _patch_resources_context(
- monkeypatch,
- tmp_path=tmp_path,
- providers=[{"name": "e2b_test", "available": True}],
- sessions=[],
- )
-
- payload = resource_service.list_resource_providers()
- assert len(payload["providers"]) == 1
- assert payload["summary"]["active_providers"] == 0
- assert payload["summary"]["running_sessions"] == 0
-
- e2b = payload["providers"][0]
- assert e2b["id"] == "e2b_test"
- assert e2b["status"] == "ready"
- assert e2b["telemetry"]["running"]["used"] == 0
- assert e2b["telemetry"]["cpu"]["freshness"] == "stale"
- assert e2b["cardCpu"]["used"] is None
- assert e2b["cardCpu"]["limit"] is None
- assert e2b["cardCpu"]["error"] is not None
-
-
-def test_list_resource_providers_prefers_config_console_url_override(tmp_path, monkeypatch):
- _write_provider_config(
- tmp_path,
- "daytona_selfhost",
- {
- "provider": "daytona",
- "console_url": "https://ops.example.com/daytona",
- "daytona": {"target": "local", "api_url": "https://daytona.example.com/api"},
- },
- )
- _patch_resources_context(
- monkeypatch,
- tmp_path=tmp_path,
- providers=[{"name": "daytona_selfhost", "available": True}],
- sessions=[],
- )
-
- payload = resource_service.list_resource_providers()
- provider = payload["providers"][0]
- assert provider["id"] == "daytona_selfhost"
- assert provider["consoleUrl"] == "https://ops.example.com/daytona"
- assert provider["type"] == "container"
-
-
-def test_list_resource_providers_uses_snapshot_metrics(tmp_path, monkeypatch):
- _write_provider_config(tmp_path, "agentbay_prod", {"provider": "agentbay"})
- _patch_resources_context(
- monkeypatch,
- tmp_path=tmp_path,
- providers=[{"name": "agentbay_prod", "available": True}],
- sessions=[
- {
- "provider": "agentbay_prod",
- "session_id": "sess-1",
- "thread_id": "thread-1",
- "lease_id": "lease-1",
- "status": "running",
- "created_at": "2026-03-03T00:00:00",
- }
- ],
- snapshots={
- "lease-1": {
- "lease_id": "lease-1",
- "cpu_used": 21.0,
- "cpu_limit": 100.0,
- "memory_used_mb": 1024.0,
- "memory_total_mb": 4096.0,
- "disk_used_gb": 4.0,
- "disk_total_gb": 20.0,
- "collected_at": "2099-01-01T00:00:00Z",
- }
- },
- )
-
- payload = resource_service.list_resource_providers()
- provider = payload["providers"][0]
- assert provider["telemetry"]["cpu"]["used"] == 21.0
- assert provider["telemetry"]["cpu"]["limit"] == 100.0
- assert provider["telemetry"]["memory"]["used"] == 1.0
- assert provider["telemetry"]["memory"]["limit"] == 4.0
- assert provider["telemetry"]["disk"]["used"] == 4.0
- assert provider["telemetry"]["disk"]["limit"] == 20.0
- assert provider["telemetry"]["cpu"]["source"] == "api"
-
-
-def test_list_resource_providers_surfaces_snapshot_probe_error(tmp_path, monkeypatch):
- _write_provider_config(tmp_path, "daytona_cloud", {"provider": "daytona"})
- _patch_resources_context(
- monkeypatch,
- tmp_path=tmp_path,
- providers=[{"name": "daytona_cloud", "available": True}],
- sessions=[
- {
- "provider": "daytona_cloud",
- "session_id": "sess-1",
- "thread_id": "thread-1",
- "lease_id": "lease-1",
- "status": "paused",
- "created_at": "2026-03-03T00:00:00",
- }
- ],
- snapshots={
- "lease-1": {
- "lease_id": "lease-1",
- "cpu_used": None,
- "cpu_limit": None,
- "memory_used_mb": None,
- "memory_total_mb": None,
- "disk_used_gb": None,
- "disk_total_gb": None,
- "probe_error": "metrics unavailable",
- "collected_at": "2099-01-01T00:00:00Z",
- }
- },
- )
-
- payload = resource_service.list_resource_providers()
- provider = payload["providers"][0]
- assert provider["telemetry"]["cpu"]["used"] is None
- assert provider["telemetry"]["cpu"]["source"] == "sandbox_db"
- assert provider["telemetry"]["cpu"]["error"] == "metrics unavailable"
- assert provider["telemetry"]["memory"]["error"] == "metrics unavailable"
- assert provider["telemetry"]["disk"]["error"] == "metrics unavailable"
-
-
-def test_thread_owner_uses_agent_ref_as_name_when_member_lookup_missing(monkeypatch):
- monkeypatch.setattr(
- resource_service,
- "_make_thread_config_repo",
- lambda: _make_fake_thread_config_repo({"thread-1": "Lex"}),
- )
- monkeypatch.setattr(resource_service, "_member_name_map", lambda: {})
-
- owners = resource_service._thread_owners(["thread-1", "thread-2"])
- assert owners["thread-1"]["agent_id"] == "Lex"
- assert owners["thread-1"]["agent_name"] == "Lex"
- assert owners["thread-2"]["agent_id"] is None
- assert owners["thread-2"]["agent_name"] == "未绑定Agent"
-
-
-def test_thread_owner_works_with_supabase_backed_thread_config(monkeypatch):
- """Thread config lookup routes through ThreadConfigRepo abstraction,
- so it works identically whether the backing store is SQLite or Supabase."""
-
- class _FakeSupabaseThreadConfigRepo:
- """Mimics SupabaseThreadConfigRepo interface without a real Supabase connection."""
-
- def __init__(self):
- self._data = {"thread-supabase-1": "agent-uuid-abc"}
-
- def lookup_config(self, thread_id: str):
- agent = self._data.get(thread_id)
- return (
- {
- "sandbox_type": "local",
- "cwd": None,
- "model": None,
- "queue_mode": None,
- "observation_provider": None,
- "agent": agent,
- }
- if agent
- else None
- )
-
- def close(self):
- pass
-
- monkeypatch.setattr(resource_service, "_make_thread_config_repo", _FakeSupabaseThreadConfigRepo)
- monkeypatch.setattr(resource_service, "_member_name_map", lambda: {"agent-uuid-abc": "Bob"})
-
- owners = resource_service._thread_owners(["thread-supabase-1", "thread-missing"])
- assert owners["thread-supabase-1"]["agent_id"] == "agent-uuid-abc"
- assert owners["thread-supabase-1"]["agent_name"] == "Bob"
- assert owners["thread-missing"]["agent_id"] is None
- assert owners["thread-missing"]["agent_name"] == "未绑定Agent"
-
-
-def test_list_resource_providers_uses_instance_capability_single_source(tmp_path, monkeypatch):
- _write_provider_config(tmp_path, "agentbay_prod", {"provider": "agentbay"})
- _patch_resources_context(
- monkeypatch,
- tmp_path=tmp_path,
- providers=[{"name": "agentbay_prod", "available": True}],
- sessions=[],
- )
-
- class _InstanceOverrideProvider:
- def get_capability(self) -> ProviderCapability:
- return ProviderCapability(
- can_pause=False,
- can_resume=False,
- can_destroy=True,
- resource_capabilities=build_resource_capabilities(
- filesystem=True,
- terminal=True,
- metrics=False,
- screenshot=False,
- web=False,
- process=False,
- hooks=False,
- snapshot=False,
- ),
- )
-
- monkeypatch.setattr(
- resource_service,
- "build_provider_from_config_name",
- lambda _name, **_kwargs: _InstanceOverrideProvider(),
- )
-
- payload = resource_service.list_resource_providers()
- provider = payload["providers"][0]
- assert provider["capabilities"]["metrics"] is False
- assert provider["capabilities"]["web"] is False
diff --git a/tests/test_mount_pluggable.py b/tests/test_mount_pluggable.py
deleted file mode 100644
index b9bcdd049..000000000
--- a/tests/test_mount_pluggable.py
+++ /dev/null
@@ -1,212 +0,0 @@
-"""Mount contract tests for pluggable multi-folder mounts."""
-
-from __future__ import annotations
-
-# TODO: pre-existing failures — provider capability API changed
-import pytest
-
-pytest.skip("pre-existing: provider capability API mismatch — needs test update", allow_module_level=True)
-
-import subprocess
-import sys
-import types
-from pathlib import Path
-
-import pytest
-
-
-def test_mount_spec_defaults_to_mount_mode() -> None:
- from sandbox.config import MountSpec
-
- mount = MountSpec.model_validate({"source": "/host/x", "target": "/sandbox/x"})
- assert mount.mode == "mount"
-
-
-def test_create_thread_request_parses_bind_mounts_with_legacy_keys() -> None:
- from backend.web.models.requests import CreateThreadRequest
-
- payload = CreateThreadRequest.model_validate(
- {
- "sandbox": "local",
- "bind_mounts": [
- {"source": "/host/tasks", "target": "/sandbox/tasks", "mode": "mount", "read_only": False},
- {"host_path": "/host/docs", "mount_path": "/sandbox/docs", "mode": "copy", "read_only": True},
- ],
- }
- )
-
- assert len(payload.bind_mounts) == 2
- assert payload.bind_mounts[0].source == "/host/tasks"
- assert payload.bind_mounts[0].target == "/sandbox/tasks"
- assert payload.bind_mounts[1].source == "/host/docs"
- assert payload.bind_mounts[1].target == "/sandbox/docs"
- assert payload.bind_mounts[1].mode == "copy"
- assert payload.bind_mounts[1].read_only is True
-
-
-def test_mount_capability_gate_detects_mismatch() -> None:
- from backend.web.routers.threads import _find_mount_capability_mismatch
- from sandbox.config import MountSpec
- from sandbox.provider import MountCapability
-
- requested = [MountSpec.model_validate({"source": "/host/a", "target": "/sandbox/a", "mode": "copy"})]
- mismatch = _find_mount_capability_mismatch(
- requested_mounts=requested,
- mount_capability=MountCapability(supports_mount=True, supports_copy=False, supports_read_only=False),
- )
-
- assert mismatch is not None
- assert mismatch["requested"] == {"mode": "copy", "read_only": False}
- assert mismatch["capability"]["supports_copy"] is False
-
-
-def test_mount_capability_gate_accepts_supported_combo() -> None:
- from backend.web.routers.threads import _find_mount_capability_mismatch
- from sandbox.config import MountSpec
- from sandbox.provider import MountCapability
-
- requested = [
- MountSpec.model_validate({"source": "/host/a", "target": "/sandbox/a", "mode": "mount", "read_only": True}),
- MountSpec.model_validate({"source": "/host/b", "target": "/sandbox/b", "mode": "copy", "read_only": False}),
- ]
- mismatch = _find_mount_capability_mismatch(
- requested_mounts=requested,
- mount_capability=MountCapability(supports_mount=True, supports_copy=True, supports_read_only=True),
- )
- assert mismatch is None
-
-
-def test_mount_capability_gate_respects_mode_handlers() -> None:
- from backend.web.routers.threads import _find_mount_capability_mismatch
- from sandbox.config import MountSpec
- from sandbox.provider import MountCapability
-
- requested = [MountSpec.model_validate({"source": "/host/a", "target": "/sandbox/a", "mode": "copy"})]
- mismatch = _find_mount_capability_mismatch(
- requested_mounts=requested,
- mount_capability=MountCapability(
- supports_mount=True,
- supports_copy=True,
- supports_read_only=True,
- mode_handlers={"mount": True, "copy": False},
- ),
- )
-
- assert mismatch is not None
- assert mismatch["requested"] == {"mode": "copy", "read_only": False}
- assert mismatch["capability"]["mode_handlers"]["copy"] is False
-
-
-def test_docker_provider_supports_multiple_bind_mount_modes(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
- from sandbox.providers.docker import DockerProvider
-
- copy_source = tmp_path / "bootstrap"
- copy_source.mkdir(parents=True, exist_ok=True)
- (copy_source / "seed.txt").write_text("hello")
-
- provider = DockerProvider(
- image="python:3.12-slim",
- mount_path="/workspace",
- default_cwd="/home/leon",
- bind_mounts=[
- {"source": "/host/tasks", "target": "/home/leon/shared/tasks", "mode": "mount", "read_only": False},
- {"source": "/host/docs", "target": "/home/leon/shared/docs", "mode": "mount", "read_only": True},
- {"source": str(copy_source), "target": "/home/leon/bootstrap", "mode": "copy", "read_only": False},
- {
- "host_path": "/host/issues",
- "mount_path": "/home/leon/shared/issues",
- "mode": "mount",
- "read_only": False,
- },
- ],
- )
-
- calls: list[list[str]] = []
-
- def fake_run(cmd: list[str], **_: object) -> subprocess.CompletedProcess[str]:
- calls.append(cmd)
- return subprocess.CompletedProcess(cmd, 0, stdout="container-123\n", stderr="")
-
- monkeypatch.setattr(provider, "_run", fake_run)
-
- session = provider.create_session(context_id="ctx-volume")
- assert session.status == "running"
-
- run_cmd = calls[0]
- volume_specs = [run_cmd[i + 1] for i, token in enumerate(run_cmd) if token == "-v"]
- assert "/host/tasks:/home/leon/shared/tasks" in volume_specs
- assert "/host/docs:/home/leon/shared/docs:ro" in volume_specs
- assert "/host/issues:/home/leon/shared/issues" in volume_specs
- assert "ctx-volume:/workspace" in volume_specs
- assert all(str(copy_source) not in spec for spec in volume_specs)
-
- serialized_calls = [" ".join(cmd) for cmd in calls]
- assert any("docker cp" in cmd and "bootstrap/." in cmd and "container-123:/home/leon/bootstrap" in cmd for cmd in serialized_calls)
-
-
-def test_daytona_provider_maps_multiple_mounts_to_http_payload(monkeypatch: pytest.MonkeyPatch) -> None:
- captured: dict[str, object] = {}
-
- class FakeDaytona:
- def __init__(self) -> None:
- pass
-
- fake_sdk = types.SimpleNamespace(Daytona=FakeDaytona)
- monkeypatch.setitem(sys.modules, "daytona_sdk", fake_sdk)
-
- import sandbox.providers.daytona as daytona_module
- from sandbox.providers.daytona import DaytonaProvider
-
- class FakeResponse:
- def __init__(self, status_code: int, payload: dict[str, object]) -> None:
- self.status_code = status_code
- self._payload = payload
- self.text = str(payload)
-
- def json(self) -> dict[str, object]:
- return self._payload
-
- class FakeClient:
- def __init__(self, timeout: float) -> None:
- self.timeout = timeout
-
- def __enter__(self) -> FakeClient:
- return self
-
- def __exit__(self, exc_type, exc, tb) -> None:
- return None
-
- def post(self, url: str, headers: dict[str, str], json: dict[str, object]) -> FakeResponse:
- captured["url"] = url
- captured["headers"] = headers
- captured["json"] = json
- return FakeResponse(200, {"id": "sb-123"})
-
- monkeypatch.setattr(daytona_module.httpx, "Client", FakeClient)
-
- provider = DaytonaProvider(
- api_key="token-1",
- api_url="http://127.0.0.1:3000/api",
- bind_mounts=[
- {"source": "/host/tasks", "target": "/home/daytona/shared/tasks", "mode": "mount", "read_only": False},
- {"source": "/host/docs", "target": "/home/daytona/shared/docs", "mode": "mount", "read_only": True},
- {"source": "/host/bootstrap", "target": "/home/daytona/bootstrap", "mode": "copy", "read_only": False},
- {
- "host_path": "/host/issues",
- "mount_path": "/home/daytona/shared/issues",
- "mode": "mount",
- "read_only": False,
- },
- ],
- )
-
- sandbox_id = provider._create_via_http(provider.bind_mounts)
- assert sandbox_id == "sb-123"
-
- payload = captured["json"]
- assert isinstance(payload, dict)
- assert payload.get("bindMounts") == [
- {"hostPath": "/host/tasks", "mountPath": "/home/daytona/shared/tasks", "readOnly": False},
- {"hostPath": "/host/docs", "mountPath": "/home/daytona/shared/docs", "readOnly": True},
- {"hostPath": "/host/issues", "mountPath": "/home/daytona/shared/issues", "readOnly": False},
- ]
diff --git a/tests/test_remote_sandbox.py b/tests/test_remote_sandbox.py
deleted file mode 100644
index c0a48e22a..000000000
--- a/tests/test_remote_sandbox.py
+++ /dev/null
@@ -1,142 +0,0 @@
-"""Unit tests for RemoteSandbox._run_init_commands and RemoteSandbox.close()."""
-
-# TODO: pre-existing: get_sandbox now requires lease.volume_id
-import pytest
-
-pytest.skip("pre-existing: RemoteSandbox tests need volume setup — needs test update", allow_module_level=True)
-
-import asyncio
-import tempfile
-from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-
-from sandbox.base import RemoteSandbox
-from sandbox.config import SandboxConfig
-from sandbox.interfaces.executor import ExecuteResult
-from sandbox.provider import ProviderCapability, SessionInfo
-from sandbox.thread_context import set_current_thread_id
-
-
-@pytest.fixture
-def temp_db():
- with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
- db_path = Path(f.name)
- yield db_path
- db_path.unlink(missing_ok=True)
-
-
-def _make_provider(on_init_exit_code: int = 0) -> MagicMock:
- provider = MagicMock()
- provider.name = "mock"
- provider.default_cwd = "/tmp"
- provider.get_capability.return_value = ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- supports_status_probe=False,
- eager_instance_binding=True,
- )
- provider.create_session.return_value = SessionInfo(session_id="inst-1", provider="mock", status="running")
- provider.get_session_status.return_value = "running"
- provider.pause_session.return_value = True
- provider.resume_session.return_value = True
- provider.destroy_session.return_value = True
-
- runtime = MagicMock()
- runtime.runtime_id = "runtime-test-000001"
- runtime.chat_session_id = None
- runtime.execute = AsyncMock(
- return_value=ExecuteResult(
- exit_code=on_init_exit_code,
- stdout="ok" if on_init_exit_code == 0 else "",
- stderr="" if on_init_exit_code == 0 else "fail",
- )
- )
- runtime.close = AsyncMock()
- provider.create_runtime.return_value = runtime
- return provider
-
-
-def _make_sandbox(provider, db_path: Path, init_commands: list[str] | None = None, on_exit: str = "pause") -> RemoteSandbox:
- config = SandboxConfig(provider="mock", on_exit=on_exit, init_commands=init_commands or [])
- return RemoteSandbox(
- provider=provider,
- config=config,
- default_cwd="/tmp",
- db_path=db_path,
- name="mock",
- working_dir="/tmp",
- env_label="Mock",
- )
-
-
-# ── _run_init_commands ───────────────────────────────────────────────────────
-
-
-def test_run_init_commands_happy_path(temp_db):
- sandbox = _make_sandbox(_make_provider(), temp_db, init_commands=["echo hello"])
- set_current_thread_id("thread-init-1")
- assert sandbox._get_capability() is not None
- assert "thread-init-1" in sandbox._init_commands_run
-
-
-def test_run_init_commands_failure_raises(temp_db):
- sandbox = _make_sandbox(_make_provider(on_init_exit_code=1), temp_db, init_commands=["bad-cmd"])
- set_current_thread_id("thread-init-fail")
- with pytest.raises(RuntimeError, match="Init command #1 failed"):
- sandbox._get_capability()
-
-
-def test_run_init_commands_idempotent(temp_db):
- sandbox = _make_sandbox(_make_provider(), temp_db, init_commands=["echo once"])
- set_current_thread_id("thread-init-2")
- sandbox._get_capability()
- sandbox._get_capability()
- assert len(sandbox._init_commands_run) == 1
-
-
-@pytest.mark.asyncio
-async def test_run_init_commands_inside_running_loop(temp_db):
- """Covers the run_coroutine_threadsafe branch: _get_capability called from a running event loop."""
- sandbox = _make_sandbox(_make_provider(), temp_db, init_commands=["echo hello"])
- set_current_thread_id("thread-init-async")
- await asyncio.to_thread(sandbox._get_capability)
- assert "thread-init-async" in sandbox._init_commands_run
-
-
-# ── RemoteSandbox.close() ────────────────────────────────────────────────────
-
-
-def test_close_pause_calls_pause_all_sessions(temp_db):
- sandbox = _make_sandbox(_make_provider(), temp_db, on_exit="pause")
- sandbox._manager.pause_all_sessions = MagicMock(return_value=2)
- sandbox.close()
- sandbox._manager.pause_all_sessions.assert_called_once()
-
-
-def test_close_destroy_calls_destroy_for_each_session(temp_db):
- sandbox = _make_sandbox(_make_provider(), temp_db, on_exit="destroy")
- sandbox._manager.list_sessions = MagicMock(return_value=[{"thread_id": "t1"}, {"thread_id": "t2"}, {"thread_id": "t3"}])
- sandbox._manager.destroy_session = MagicMock(return_value=True)
- sandbox.close()
- assert sandbox._manager.destroy_session.call_count == 3
-
-
-def test_close_destroy_continues_after_one_failure(temp_db):
- sandbox = _make_sandbox(_make_provider(), temp_db, on_exit="destroy")
- sandbox._manager.list_sessions = MagicMock(return_value=[{"thread_id": "t1"}, {"thread_id": "t2"}, {"thread_id": "t3"}])
-
- call_count = 0
-
- def side_effect(thread_id):
- nonlocal call_count
- call_count += 1
- if thread_id == "t2":
- raise RuntimeError("network error")
- return True
-
- sandbox._manager.destroy_session = MagicMock(side_effect=side_effect)
- sandbox.close()
- assert call_count == 3
diff --git a/tests/test_resource_snapshot.py b/tests/test_resource_snapshot.py
deleted file mode 100644
index 314e2a194..000000000
--- a/tests/test_resource_snapshot.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import pytest
-
-pytest.skip("pre-existing: resource_snapshot API mismatch — needs test update", allow_module_level=True)
-
-from pathlib import Path
-from unittest.mock import MagicMock
-
-from sandbox.provider import Metrics, ProviderCapability, ProviderExecResult, SandboxProvider, SessionInfo
-from sandbox.resource_snapshot import (
- ensure_resource_snapshot_table,
- list_snapshots_by_lease_ids,
- probe_and_upsert_for_instance,
- upsert_lease_resource_snapshot,
-)
-
-
-class _FakeProvider(SandboxProvider):
- name = "fake"
-
- def get_capability(self) -> ProviderCapability:
- return ProviderCapability(
- can_pause=True,
- can_resume=True,
- can_destroy=True,
- resource_capabilities={
- "filesystem": True,
- "terminal": True,
- "metrics": True,
- "screenshot": False,
- "web": False,
- "process": False,
- "hooks": False,
- "mount": False,
- },
- )
-
- def create_session(self, context_id: str | None = None) -> SessionInfo:
- raise RuntimeError("unused")
-
- def destroy_session(self, session_id: str, sync: bool = True) -> bool:
- raise RuntimeError("unused")
-
- def pause_session(self, session_id: str) -> bool:
- raise RuntimeError("unused")
-
- def resume_session(self, session_id: str) -> bool:
- raise RuntimeError("unused")
-
- def get_session_status(self, session_id: str) -> str:
- raise RuntimeError("unused")
-
- def execute(self, session_id: str, command: str, timeout_ms: int = 30000, cwd: str | None = None) -> ProviderExecResult:
- raise RuntimeError("unused")
-
- def read_file(self, session_id: str, path: str) -> str:
- raise RuntimeError("unused")
-
- def write_file(self, session_id: str, path: str, content: str) -> str:
- raise RuntimeError("unused")
-
- def list_dir(self, session_id: str, path: str) -> list[dict]:
- raise RuntimeError("unused")
-
- def get_metrics(self, session_id: str) -> Metrics | None:
- return Metrics(
- cpu_percent=23.5,
- memory_used_mb=1536.0,
- memory_total_mb=4096.0,
- disk_used_gb=8.0,
- disk_total_gb=20.0,
- network_rx_kbps=30.0,
- network_tx_kbps=40.0,
- )
-
-
-def test_upsert_and_query_snapshot(tmp_path):
- db_path = Path(tmp_path) / "sandbox.db"
- ensure_resource_snapshot_table(db_path)
- upsert_lease_resource_snapshot(
- lease_id="lease-1",
- provider_name="agentbay_prod",
- observed_state="running",
- probe_mode="running_runtime",
- cpu_used=12.0,
- cpu_limit=100.0,
- memory_used_mb=512.0,
- memory_total_mb=1024.0,
- disk_used_gb=2.0,
- disk_total_gb=10.0,
- network_rx_kbps=1.0,
- network_tx_kbps=2.0,
- probe_error=None,
- db_path=db_path,
- )
- snapshots = list_snapshots_by_lease_ids(["lease-1"], db_path=db_path)
- assert snapshots["lease-1"]["provider_name"] == "agentbay_prod"
- assert snapshots["lease-1"]["cpu_used"] == 12.0
-
-
-def test_probe_and_upsert_from_provider_metrics(tmp_path):
- db_path = Path(tmp_path) / "sandbox.db"
- provider = _FakeProvider()
- result = probe_and_upsert_for_instance(
- lease_id="lease-2",
- provider_name="fake_provider",
- observed_state="running",
- probe_mode="create_running",
- provider=provider,
- instance_id="instance-1",
- db_path=db_path,
- )
- assert result["ok"] is True
- snapshots = list_snapshots_by_lease_ids(["lease-2"], db_path=db_path)
- assert snapshots["lease-2"]["cpu_used"] == 23.5
- assert snapshots["lease-2"]["memory_total_mb"] == 4096.0
-
-
-def test_probe_and_upsert_ignores_non_numeric_metrics(tmp_path):
- db_path = Path(tmp_path) / "sandbox.db"
- provider = _FakeProvider()
- provider.get_metrics = lambda _session_id: MagicMock()
- result = probe_and_upsert_for_instance(
- lease_id="lease-3",
- provider_name="fake_provider",
- observed_state="running",
- probe_mode="create_running",
- provider=provider,
- instance_id="instance-1",
- db_path=db_path,
- )
- assert result["ok"] is False
- assert result["error"] == "metrics unavailable"
- snapshots = list_snapshots_by_lease_ids(["lease-3"], db_path=db_path)
- assert snapshots["lease-3"]["cpu_used"] is None
- assert snapshots["lease-3"]["probe_error"] == "metrics unavailable"
diff --git a/tests/test_sandbox_e2e.py b/tests/test_sandbox_e2e.py
deleted file mode 100644
index f1dd64383..000000000
--- a/tests/test_sandbox_e2e.py
+++ /dev/null
@@ -1,234 +0,0 @@
-"""End-to-end headless test for sandbox mode.
-
-Tests that LeonAgent can:
-1. Initialize with sandbox=docker or sandbox=e2b
-2. Execute commands in the sandbox
-3. Read/write files in the sandbox
-4. All paths resolve correctly (no macOS firmlink leaks)
-
-Usage:
- # Docker sandbox (requires Docker running)
- pytest tests/test_sandbox_e2e.py -k docker -s
-
- # E2B sandbox (requires E2B_API_KEY)
- pytest tests/test_sandbox_e2e.py -k e2b -s
-
- # Both
- pytest tests/test_sandbox_e2e.py -s
-"""
-
-import pytest
-
-pytest.skip("pre-existing: Docker/E2B e2e tests require running providers", allow_module_level=True)
-
-import os
-import sys
-import uuid
-
-import pytest
-
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-# Load config.env so API keys are available
-from config.env_manager import ConfigManager
-
-ConfigManager().load_to_env()
-
-
-def _can_docker() -> bool:
- """Check if Docker is available."""
- import subprocess
-
- try:
- subprocess.run(["docker", "info"], capture_output=True, timeout=5)
- return True
- except Exception:
- return False
-
-
-def _can_e2b() -> bool:
- if os.getenv("E2B_API_KEY"):
- return True
- # Check sandbox config file
- from pathlib import Path
-
- config_file = Path.home() / ".leon" / "sandboxes" / "e2b.json"
- if config_file.exists():
- import json
-
- data = json.loads(config_file.read_text())
- key = data.get("e2b", {}).get("api_key")
- if key:
- os.environ["E2B_API_KEY"] = key
- return True
- return False
-
-
-def _invoke_and_extract(agent, message: str, thread_id: str) -> dict:
- """Invoke agent via async runner and extract tool calls + response."""
- import asyncio
-
- from core.runner import NonInteractiveRunner
- from sandbox.thread_context import set_current_thread_id
-
- set_current_thread_id(thread_id)
- runner = NonInteractiveRunner(agent, thread_id, debug=True)
- result = asyncio.run(runner.run_turn(message))
-
- return {
- "tool_calls": [tc["name"] for tc in result.get("tool_calls", [])],
- "response": result.get("response", ""),
- "error": result.get("error"),
- }
-
-
-def _get_model_name() -> str:
- return os.getenv("MODEL_NAME") or "claude-sonnet-4-5-20250929"
-
-
-# ---------------------------------------------------------------------------
-# Docker E2E
-# ---------------------------------------------------------------------------
-
-
-@pytest.mark.skipif(not _can_docker(), reason="Docker not available")
-class TestDockerSandboxE2E:
- def test_agent_init_and_command(self):
- """Agent initializes with docker sandbox and can run commands."""
- from agent import create_leon_agent
-
- thread_id = f"test-docker-{uuid.uuid4().hex[:8]}"
- agent = None
- try:
- agent = create_leon_agent(
- model_name=_get_model_name(),
- sandbox="docker",
- verbose=True,
- )
-
- # Verify workspace_root is the sandbox path, not a local resolved path
- assert str(agent.workspace_root) == "/workspace", f"workspace_root should be /workspace, got {agent.workspace_root}"
-
- # Ensure session exists before invoking
- agent._sandbox.ensure_session(thread_id)
-
- extracted = _invoke_and_extract(
- agent,
- "Use the run_command tool to execute: echo 'SANDBOX_OK' && pwd",
- thread_id,
- )
-
- print("\n--- Result ---")
- print(f"Response: {extracted['response'][:500]}")
- print(f"Tool calls: {extracted['tool_calls']}")
-
- assert "run_command" in extracted["tool_calls"], f"Expected run_command in {extracted['tool_calls']}"
-
- finally:
- if agent:
- agent.close()
-
- def test_file_operations(self):
- """Agent can read and write files in docker sandbox."""
- from agent import create_leon_agent
-
- thread_id = f"test-docker-{uuid.uuid4().hex[:8]}"
- agent = None
- try:
- agent = create_leon_agent(
- model_name=_get_model_name(),
- sandbox="docker",
- verbose=True,
- )
- agent._sandbox.ensure_session(thread_id)
-
- extracted = _invoke_and_extract(
- agent,
- "Write the text 'hello from test' to /workspace/test_e2e.txt, then read it back and tell me the content.",
- thread_id,
- )
-
- print("\n--- Result ---")
- print(f"Response: {extracted['response'][:500]}")
- print(f"Tool calls: {extracted['tool_calls']}")
-
- assert "write_file" in extracted["tool_calls"], f"Expected write_file in {extracted['tool_calls']}"
-
- finally:
- if agent:
- agent.close()
-
-
-# ---------------------------------------------------------------------------
-# E2B E2E
-# ---------------------------------------------------------------------------
-
-
-@pytest.mark.skipif(not _can_e2b(), reason="E2B_API_KEY not set")
-class TestE2BSandboxE2E:
- def test_agent_init_and_command(self):
- """Agent initializes with e2b sandbox and can run commands."""
- from agent import create_leon_agent
-
- thread_id = f"test-e2b-{uuid.uuid4().hex[:8]}"
- agent = None
- try:
- agent = create_leon_agent(
- model_name=_get_model_name(),
- sandbox="e2b",
- verbose=True,
- )
-
- assert str(agent.workspace_root) == "/home/user", f"workspace_root should be /home/user, got {agent.workspace_root}"
-
- agent._sandbox.ensure_session(thread_id)
-
- extracted = _invoke_and_extract(
- agent,
- "Use the run_command tool to execute: echo 'E2B_OK' && uname -a",
- thread_id,
- )
-
- print("\n--- Result ---")
- print(f"Response: {extracted['response'][:500]}")
- print(f"Tool calls: {extracted['tool_calls']}")
-
- assert "run_command" in extracted["tool_calls"], f"Expected run_command in {extracted['tool_calls']}"
-
- finally:
- if agent:
- agent.close()
-
- def test_file_operations(self):
- """Agent can read and write files in e2b sandbox."""
- from agent import create_leon_agent
-
- thread_id = f"test-e2b-{uuid.uuid4().hex[:8]}"
- agent = None
- try:
- agent = create_leon_agent(
- model_name=_get_model_name(),
- sandbox="e2b",
- verbose=True,
- )
- agent._sandbox.ensure_session(thread_id)
-
- extracted = _invoke_and_extract(
- agent,
- "Write the text 'e2b test content' to /home/user/test_e2e.txt, then read it back and tell me the content.",
- thread_id,
- )
-
- print("\n--- Result ---")
- print(f"Response: {extracted['response'][:500]}")
- print(f"Tool calls: {extracted['tool_calls']}")
-
- assert "write_file" in extracted["tool_calls"], f"Expected write_file in {extracted['tool_calls']}"
-
- finally:
- if agent:
- agent.close()
-
-
-if __name__ == "__main__":
- pytest.main([__file__, "-s", "-v"])
diff --git a/tests/test_storage_runtime_wiring.py b/tests/test_storage_runtime_wiring.py
deleted file mode 100644
index fcb60e8ae..000000000
--- a/tests/test_storage_runtime_wiring.py
+++ /dev/null
@@ -1,389 +0,0 @@
-"""Runtime storage wiring tests for backend agent creation path."""
-
-from __future__ import annotations
-
-import asyncio
-from pathlib import Path
-from types import SimpleNamespace
-from typing import Any
-
-import pytest
-
-from backend.web.services import agent_pool
-from backend.web.services.event_buffer import ThreadEventBuffer
-from backend.web.services.streaming_service import _run_agent_to_buffer
-from storage.providers.sqlite.checkpoint_repo import SQLiteCheckpointRepo
-from storage.providers.sqlite.eval_repo import SQLiteEvalRepo
-from storage.providers.supabase.checkpoint_repo import SupabaseCheckpointRepo
-
-
-class _FakeSupabaseClient:
- def table(self, table_name: str):
- raise AssertionError(f"table() should not be called in this wiring test: {table_name}")
-
-
-def _build_fake_supabase_client() -> _FakeSupabaseClient:
- return _FakeSupabaseClient()
-
-
-def _build_invalid_supabase_client() -> object:
- return object()
-
-
-def _capture_create_leon_agent(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]:
- captured: dict[str, Any] = {}
-
- def _fake_create_leon_agent(**kwargs):
- captured.update(kwargs)
- return object()
-
- monkeypatch.setattr(agent_pool, "create_leon_agent", _fake_create_leon_agent)
- return captured
-
-
-def test_create_agent_sync_wires_supabase_storage_container(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
- monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
- monkeypatch.setenv(
- "LEON_SUPABASE_CLIENT_FACTORY",
- "tests.test_storage_runtime_wiring:_build_fake_supabase_client",
- )
- monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
- monkeypatch.setenv("LEON_EVAL_DB_PATH", str(tmp_path / "eval.db"))
-
- captured = _capture_create_leon_agent(monkeypatch)
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
-
- container = captured["storage_container"]
- assert isinstance(container.checkpoint_repo(), SupabaseCheckpointRepo)
-
-
-def test_create_agent_sync_supabase_missing_runtime_config_fails_loud(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
- monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
-
- with pytest.raises(
- RuntimeError,
- match="LEON_SUPABASE_CLIENT_FACTORY",
- ):
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
-
-
-def test_create_agent_sync_supabase_invalid_runtime_config_fails_loud(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
- monkeypatch.setenv(
- "LEON_SUPABASE_CLIENT_FACTORY",
- "tests.test_storage_runtime_wiring:_build_invalid_supabase_client",
- )
-
- with pytest.raises(RuntimeError, match="callable table\\(name\\) API"):
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
-
-
-def test_create_agent_sync_defaults_to_sqlite_storage_container(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.delenv("LEON_STORAGE_STRATEGY", raising=False)
- monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
- monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
-
- captured = _capture_create_leon_agent(monkeypatch)
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
-
- container = captured["storage_container"]
- assert isinstance(container.checkpoint_repo(), SQLiteCheckpointRepo)
-
-
-def test_create_agent_sync_repo_override_supabase_with_sqlite_default(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.setenv("LEON_STORAGE_STRATEGY", "sqlite")
- monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", '{"checkpoint_repo":"supabase"}')
- monkeypatch.setenv(
- "LEON_SUPABASE_CLIENT_FACTORY",
- "tests.test_storage_runtime_wiring:_build_fake_supabase_client",
- )
- monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
-
- captured = _capture_create_leon_agent(monkeypatch)
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
- container = captured["storage_container"]
- assert isinstance(container.checkpoint_repo(), SupabaseCheckpointRepo)
-
-
-def test_create_agent_sync_repo_override_sqlite_with_supabase_default(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
- monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", '{"eval_repo":"sqlite"}')
- monkeypatch.setenv(
- "LEON_SUPABASE_CLIENT_FACTORY",
- "tests.test_storage_runtime_wiring:_build_fake_supabase_client",
- )
- monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
- monkeypatch.setenv("LEON_EVAL_DB_PATH", str(tmp_path / "eval.db"))
-
- captured = _capture_create_leon_agent(monkeypatch)
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
- container = captured["storage_container"]
- assert isinstance(container.eval_repo(), SQLiteEvalRepo)
-
-
-@pytest.mark.skip(reason="pre-existing: storage wiring/factory API mismatch")
-def test_create_agent_sync_all_sqlite_override_with_supabase_default_does_not_require_factory(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.setenv("LEON_STORAGE_STRATEGY", "supabase")
- monkeypatch.setenv(
- "LEON_STORAGE_REPO_PROVIDERS",
- (
- '{"checkpoint_repo":"sqlite","thread_config_repo":"sqlite","run_event_repo":"sqlite",'
- '"file_operation_repo":"sqlite","summary_repo":"sqlite","eval_repo":"sqlite",'
- '"queue_repo":"sqlite","workspace_repo":"sqlite"}'
- ),
- )
- monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
- monkeypatch.setenv("LEON_DB_PATH", str(tmp_path / "leon.db"))
- monkeypatch.setenv("LEON_EVAL_DB_PATH", str(tmp_path / "eval.db"))
-
- captured = _capture_create_leon_agent(monkeypatch)
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
- container = captured["storage_container"]
- assert isinstance(container.checkpoint_repo(), SQLiteCheckpointRepo)
-
-
-def test_create_agent_sync_repo_override_supabase_without_runtime_config_fails_loud(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.setenv("LEON_STORAGE_STRATEGY", "sqlite")
- monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", '{"checkpoint_repo":"supabase"}')
- monkeypatch.delenv("LEON_SUPABASE_CLIENT_FACTORY", raising=False)
-
- with pytest.raises(RuntimeError, match="LEON_SUPABASE_CLIENT_FACTORY"):
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
-
-
-def test_create_agent_sync_invalid_repo_override_json_fails_loud(
- monkeypatch: pytest.MonkeyPatch,
- tmp_path: Path,
-) -> None:
- monkeypatch.setenv("LEON_STORAGE_REPO_PROVIDERS", "not-json")
-
- with pytest.raises(RuntimeError, match="Invalid LEON_STORAGE_REPO_PROVIDERS"):
- agent_pool.create_agent_sync("local", workspace_root=tmp_path, model_name="leon:test")
-
-
-class _FakeRunEventRepo:
- def __init__(self) -> None:
- self.append_calls: list[dict[str, Any]] = []
- self.closed = False
-
- def append_event(
- self,
- thread_id: str,
- run_id: str,
- event_type: str,
- data: dict[str, Any],
- message_id: str | None = None,
- ) -> int:
- self.append_calls.append(
- {
- "thread_id": thread_id,
- "run_id": run_id,
- "event_type": event_type,
- "data": data,
- "message_id": message_id,
- }
- )
- return len(self.append_calls)
-
- def list_run_ids(self, thread_id: str) -> list[str]:
- return []
-
- def delete_runs(self, thread_id: str, run_ids: list[str]) -> int:
- return 0
-
- def close(self) -> None:
- self.closed = True
-
-
-class _FakeStorageContainer:
- def __init__(self, repo: _FakeRunEventRepo) -> None:
- self._repo = repo
-
- def run_event_repo(self) -> _FakeRunEventRepo:
- return self._repo
-
-
-class _FakeGraphAgent:
- checkpointer = None
-
- async def astream(self, *_args: Any, **_kwargs: Any):
- if False: # pragma: no cover
- yield None
-
-
-class _FakeRuntime:
- current_state = "IDLE"
-
- def get_pending_subagent_events(self) -> list[tuple[str, list[dict[str, Any]]]]:
- return []
-
- def get_status_dict(self) -> dict[str, Any]:
- return {}
-
- def set_event_callback(self, cb: Any) -> None:
- pass
-
- def set_activity_sink(self, sink: Any) -> None:
- pass
-
- def emit_activity_event(self, event: dict[str, Any]) -> None:
- pass
-
- def transition(self, new_state: Any) -> bool:
- return True
-
-
-class _FakeRuntimeAgent:
- def __init__(self, storage_container: Any = None) -> None:
- self.agent = _FakeGraphAgent()
- self.storage_container = storage_container
- self.runtime = _FakeRuntime()
-
-
-@pytest.mark.skip(reason="pre-existing: storage wiring/factory API mismatch")
-def test_run_runtime_consumes_storage_container_run_event_repo(monkeypatch: pytest.MonkeyPatch) -> None:
- async def _run() -> None:
- repo = _FakeRunEventRepo()
- agent = _FakeRuntimeAgent(storage_container=_FakeStorageContainer(repo))
- from unittest.mock import MagicMock
-
- qm = MagicMock()
- qm.dequeue.return_value = None
- app = SimpleNamespace(state=SimpleNamespace(thread_tasks={}, thread_event_buffers={}, subagent_buffers={}, queue_manager=qm))
- thread_buf = ThreadEventBuffer()
- run_id = "run-1"
-
- await _run_agent_to_buffer(agent, "thread-1", "hello", app, False, thread_buf, run_id)
-
- assert repo.append_calls, "run path should persist events through storage_container.run_event_repo()"
- assert any(c["event_type"] == "run_done" for c in repo.append_calls)
- assert repo.closed is True
-
- asyncio.run(_run())
-
-
-@pytest.mark.skip(reason="pre-existing: storage wiring/factory API mismatch")
-def test_run_runtime_without_storage_container_keeps_sqlite_event_store_path(monkeypatch: pytest.MonkeyPatch) -> None:
- async def _run() -> None:
- import backend.web.services.event_store as event_store
-
- calls: list[dict[str, Any]] = []
-
- async def _fake_append_event(
- thread_id: str,
- run_id: str,
- event: dict[str, Any],
- message_id: str | None = None,
- run_event_repo: Any | None = None,
- ) -> int:
- calls.append(
- {
- "thread_id": thread_id,
- "run_id": run_id,
- "event": event,
- "message_id": message_id,
- "run_event_repo": run_event_repo,
- }
- )
- return len(calls)
-
- async def _fake_cleanup_old_runs(
- thread_id: str,
- keep_latest: int = 1,
- run_event_repo: Any | None = None,
- ) -> int:
- return 0
-
- monkeypatch.setattr(event_store, "append_event", _fake_append_event)
- monkeypatch.setattr(event_store, "cleanup_old_runs", _fake_cleanup_old_runs)
-
- from unittest.mock import MagicMock
-
- qm = MagicMock()
- qm.dequeue.return_value = None
- agent = _FakeRuntimeAgent(storage_container=None)
- app = SimpleNamespace(state=SimpleNamespace(thread_tasks={}, thread_event_buffers={}, subagent_buffers={}, queue_manager=qm))
- thread_buf = ThreadEventBuffer()
- run_id = "run-1"
-
- await _run_agent_to_buffer(agent, "thread-1", "hello", app, False, thread_buf, run_id)
-
- assert calls, "sqlite event store path should still be used when no storage container is injected"
- assert all(call["run_event_repo"] is None for call in calls)
-
- asyncio.run(_run())
-
-
-@pytest.mark.skip(reason="pre-existing: thread_config_repo removed from StorageContainer")
-def test_purge_thread_deletes_all_repo_data(tmp_path: Path) -> None:
- from storage.container import StorageContainer
-
- db_path = tmp_path / "leon.db"
- eval_db = tmp_path / "eval.db"
- container = StorageContainer(main_db_path=db_path, eval_db_path=eval_db, strategy="sqlite")
-
- # Populate repos for thread t-1 and t-2
- tc = container.thread_config_repo()
- tc.save_metadata("t-1", "docker", "/ws")
- tc.save_metadata("t-2", "local", None)
- tc.close()
-
- re_repo = container.run_event_repo()
- re_repo.append_event("t-1", "r-1", "status", {"ok": True})
- re_repo.append_event("t-2", "r-2", "status", {"ok": True})
- re_repo.close()
-
- fo = container.file_operation_repo()
- fo.record("t-1", "cp-1", "write", "/a.txt", None, "x")
- fo.record("t-2", "cp-2", "write", "/b.txt", None, "y")
- fo.close()
-
- sr = container.summary_repo()
- sr.ensure_tables()
- sr.save_summary("s-1", "t-1", "summary", 10, 20, False, None, "2025-01-01")
- sr.close()
-
- # Purge t-1
- container.purge_thread("t-1")
-
- # Verify t-1 is gone, t-2 remains
- tc2 = container.thread_config_repo()
- assert tc2.lookup_metadata("t-1") is None
- assert tc2.lookup_metadata("t-2") == ("local", None)
- tc2.close()
-
- re2 = container.run_event_repo()
- assert re2.latest_seq("t-1") == 0
- assert re2.latest_seq("t-2") > 0
- re2.close()
-
- fo2 = container.file_operation_repo()
- assert fo2.get_operations_for_thread("t-1") == []
- assert len(fo2.get_operations_for_thread("t-2")) == 1
- fo2.close()
-
- sr2 = container.summary_repo()
- assert sr2.get_latest_summary_row("t-1") is None
- sr2.close()
diff --git a/tests/test_thread_config_repo.py b/tests/test_thread_config_repo.py
deleted file mode 100644
index 007d30c40..000000000
--- a/tests/test_thread_config_repo.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# TODO: thread_config_repo was removed in refactoring; update tests to use thread_repo / thread_launch_pref_repo
-import pytest
-
-pytest.skip("thread_config_repo module removed — needs migration to thread_repo", allow_module_level=True)
-
-import sqlite3 # noqa: E402
-from pathlib import Path # noqa: E402
-
-from storage.providers.sqlite.thread_config_repo import SQLiteThreadConfigRepo # noqa: F401
-from storage.providers.supabase.thread_config_repo import SupabaseThreadConfigRepo
-
-from backend.web.utils import helpers
-
-
-def test_migrate_thread_metadata_table(tmp_path):
- db_path = tmp_path / "leon.db"
- with sqlite3.connect(str(db_path)) as conn:
- conn.execute("CREATE TABLE thread_metadata (thread_id TEXT PRIMARY KEY, sandbox_type TEXT NOT NULL, cwd TEXT, model TEXT)")
- conn.execute(
- "INSERT INTO thread_metadata (thread_id, sandbox_type, cwd, model) VALUES (?, ?, ?, ?)",
- ("t-1", "local", "/tmp/ws", "m-1"),
- )
- conn.commit()
-
- repo = SQLiteThreadConfigRepo(db_path)
- try:
- assert repo.lookup_metadata("t-1") == ("local", "/tmp/ws")
- assert repo.lookup_model("t-1") == "m-1"
- finally:
- repo.close()
-
- with sqlite3.connect(str(db_path)) as conn:
- tables = {r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")}
- assert "thread_config" in tables
- assert "thread_metadata" not in tables
-
-
-def test_save_and_lookup_thread_config(tmp_path):
- db_path = tmp_path / "leon.db"
- repo = SQLiteThreadConfigRepo(db_path)
- try:
- repo.save_metadata("t-2", "docker", "/workspace")
- repo.save_model("t-2", "anthropic/claude-sonnet-4.6")
- assert repo.lookup_metadata("t-2") == ("docker", "/workspace")
- assert repo.lookup_model("t-2") == "anthropic/claude-sonnet-4.6"
- repo.update_fields("t-2", queue_mode="followup", observation_provider="langfuse")
- cfg = repo.lookup_config("t-2")
- assert cfg is not None
- assert cfg["queue_mode"] == "followup"
- assert cfg["observation_provider"] == "langfuse"
- finally:
- repo.close()
-
-
-def test_helpers_compatibility_api(tmp_path, monkeypatch):
- db_path = tmp_path / "leon.db"
- monkeypatch.setattr(helpers, "DB_PATH", Path(db_path))
-
- helpers.init_thread_config("t-3", "local", "/tmp/p")
- helpers.save_thread_model("t-3", "m-3")
-
- config = helpers.load_thread_config("t-3")
- assert config is not None
- assert (config.sandbox_type, config.cwd) == ("local", "/tmp/p")
- assert helpers.lookup_thread_model("t-3") == "m-3"
- helpers.save_thread_config("t-3", observation_provider="langsmith")
- config2 = helpers.load_thread_config("t-3")
- assert config2 is not None
- assert config2.observation_provider == "langsmith"
-
-
-from tests.fakes.supabase import FakeSupabaseClient
-
-
-def test_supabase_thread_config_repo_save_and_lookup():
- tables: dict[str, list[dict]] = {"thread_config": []}
- repo = SupabaseThreadConfigRepo(client=FakeSupabaseClient(tables=tables))
-
- repo.save_metadata("t-1", "docker", "/workspace")
- repo.save_model("t-1", "anthropic/claude-sonnet-4.6")
-
- assert repo.lookup_metadata("t-1") == ("docker", "/workspace")
- assert repo.lookup_model("t-1") == "anthropic/claude-sonnet-4.6"
-
- repo.save_model("t-2", "openai/gpt-5")
- assert repo.lookup_metadata("t-2") == ("local", None)
- assert repo.lookup_model("t-2") == "openai/gpt-5"
- repo.update_fields("t-1", queue_mode="followup", observation_provider="langfuse")
- cfg = repo.lookup_config("t-1")
- assert cfg is not None
- assert cfg["queue_mode"] == "followup"
- assert cfg["observation_provider"] == "langfuse"
-
-
-def test_supabase_thread_config_repo_delete():
- tables: dict[str, list[dict]] = {"thread_config": []}
- repo = SupabaseThreadConfigRepo(client=FakeSupabaseClient(tables=tables))
- repo.save_metadata("t-1", "docker", "/workspace")
- repo.save_metadata("t-2", "local", None)
-
- repo.delete_thread_config("t-1")
- assert repo.lookup_metadata("t-1") is None
- assert repo.lookup_metadata("t-2") == ("local", None)
-
-
-def test_sqlite_thread_config_repo_delete(tmp_path):
- db_path = tmp_path / "leon.db"
- repo = SQLiteThreadConfigRepo(db_path)
- try:
- repo.save_metadata("t-1", "docker", "/workspace")
- repo.save_metadata("t-2", "local", None)
- repo.delete_thread_config("t-1")
- assert repo.lookup_metadata("t-1") is None
- assert repo.lookup_metadata("t-2") == ("local", None)
- finally:
- repo.close()
-
-
-def test_supabase_thread_config_repo_requires_compatible_client():
- with pytest.raises(RuntimeError, match="table\\(name\\)"):
- SupabaseThreadConfigRepo(client=object())
diff --git a/tests/test_tool_registry_runner.py b/tests/test_tool_registry_runner.py
deleted file mode 100644
index 934ae93ca..000000000
--- a/tests/test_tool_registry_runner.py
+++ /dev/null
@@ -1,339 +0,0 @@
-"""Tests for ToolRegistry, ToolRunner, and ToolValidator (P0/P1 verification).
-
-Covers:
-- P0: Three-tier error normalization (Layer 1: validation, Layer 2: execution, Layer 3: soft)
-- P1: ToolRegistry inline/deferred mode
-- P1: ToolRunner dispatches registered tools and normalizes errors
-"""
-
-from __future__ import annotations
-
-from unittest.mock import MagicMock
-
-import pytest
-
-from core.runtime.errors import InputValidationError
-from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry
-from core.runtime.runner import ToolRunner
-from core.runtime.validator import ToolValidator
-
-# ---------------------------------------------------------------------------
-# ToolRegistry
-# ---------------------------------------------------------------------------
-
-
-class TestToolRegistry:
- def _make_entry(self, name: str, mode: ToolMode = ToolMode.INLINE) -> ToolEntry:
- return ToolEntry(
- name=name,
- mode=mode,
- schema={"name": name, "description": f"{name} tool"},
- handler=lambda: f"result:{name}",
- source="test",
- )
-
- def test_register_and_get(self):
- reg = ToolRegistry()
- entry = self._make_entry("Read")
- reg.register(entry)
- assert reg.get("Read") is entry
-
- def test_get_unknown_returns_none(self):
- reg = ToolRegistry()
- assert reg.get("NonExistent") is None
-
- def test_inline_tools_appear_in_get_inline_schemas(self):
- reg = ToolRegistry()
- reg.register(self._make_entry("Read", ToolMode.INLINE))
- reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
- schemas = reg.get_inline_schemas()
- names = [s["name"] for s in schemas]
- assert "Read" in names
- assert "TaskCreate" not in names # P1: deferred not in inline
-
- def test_deferred_tools_not_in_inline_schemas(self):
- reg = ToolRegistry()
- reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
- reg.register(self._make_entry("TaskUpdate", ToolMode.DEFERRED))
- assert reg.get_inline_schemas() == []
-
- def test_search_finds_by_name(self):
- reg = ToolRegistry()
- reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
- reg.register(self._make_entry("Read", ToolMode.INLINE))
- results = reg.search("task")
- names = [e.name for e in results]
- assert "TaskCreate" in names
-
- def test_search_includes_deferred_tools(self):
- """tool_search must discover deferred tools too."""
- reg = ToolRegistry()
- reg.register(self._make_entry("TaskCreate", ToolMode.DEFERRED))
- results = reg.search("TaskCreate")
- assert any(e.name == "TaskCreate" for e in results)
-
- def test_allowed_tools_filter(self):
- reg = ToolRegistry(allowed_tools={"Read", "Grep"})
- reg.register(self._make_entry("Read"))
- reg.register(self._make_entry("Grep"))
- reg.register(self._make_entry("Bash"))
- assert reg.get("Read") is not None
- assert reg.get("Grep") is not None
- assert reg.get("Bash") is None # filtered out
-
- def test_dynamic_schema_callable(self):
- call_count = 0
-
- def schema_fn() -> dict:
- nonlocal call_count
- call_count += 1
- return {"name": "DynTool", "description": "dynamic"}
-
- reg = ToolRegistry()
- entry = ToolEntry(
- name="DynTool",
- mode=ToolMode.INLINE,
- schema=schema_fn,
- handler=lambda: "ok",
- source="test",
- )
- reg.register(entry)
- schemas = reg.get_inline_schemas()
- assert call_count >= 1
- assert any(s["name"] == "DynTool" for s in schemas)
-
-
-# ---------------------------------------------------------------------------
-# ToolValidator
-# ---------------------------------------------------------------------------
-
-
-class TestToolValidator:
- def _schema(self, required: list[str], props: dict) -> dict:
- return {
- "name": "TestTool",
- "parameters": {
- "type": "object",
- "required": required,
- "properties": {k: {"type": v} for k, v in props.items()},
- },
- }
-
- def test_valid_args_pass(self):
- v = ToolValidator()
- schema = self._schema(["file_path"], {"file_path": "string"})
- result = v.validate(schema, {"file_path": "/tmp/x"})
- assert result.ok
-
- def test_missing_required_raises_layer1(self):
- v = ToolValidator()
- schema = self._schema(["file_path"], {"file_path": "string"})
- with pytest.raises(InputValidationError) as exc_info:
- v.validate(schema, {})
- assert "file_path" in str(exc_info.value)
- assert "missing" in str(exc_info.value)
-
- def test_wrong_type_raises_layer1(self):
- v = ToolValidator()
- schema = self._schema(["count"], {"count": "integer"})
- with pytest.raises(InputValidationError):
- v.validate(schema, {"count": "not-an-int"})
-
- def test_extra_params_allowed(self):
- v = ToolValidator()
- schema = self._schema(["a"], {"a": "string"})
- result = v.validate(schema, {"a": "hello", "extra": "ok"})
- assert result.ok
-
-
-# ---------------------------------------------------------------------------
-# ToolRunner — P0 error normalization
-# ---------------------------------------------------------------------------
-
-
-def _make_runner(entries: list[ToolEntry]) -> ToolRunner:
- reg = ToolRegistry()
- for e in entries:
- reg.register(e)
- return ToolRunner(registry=reg)
-
-
-def _make_tool_call_request(name: str, args: dict, call_id: str = "tc-1"):
- req = MagicMock()
- req.tool_call = {"name": name, "args": args, "id": call_id}
- return req
-
-
-class TestToolRunnerErrorNormalization:
- """P0: three-tier error normalization."""
-
- def test_layer1_missing_param_returns_input_validation_error(self):
- entry = ToolEntry(
- name="Read",
- mode=ToolMode.INLINE,
- schema={
- "name": "Read",
- "parameters": {
- "type": "object",
- "required": ["file_path"],
- "properties": {"file_path": {"type": "string"}},
- },
- },
- handler=lambda file_path: "content",
- source="test",
- )
- runner = _make_runner([entry])
- req = _make_tool_call_request("Read", {}) # missing file_path
-
- called_upstream = []
-
- def upstream(r):
- called_upstream.append(r)
- return MagicMock()
-
- result = runner.wrap_tool_call(req, upstream)
- # Layer 1 error format: InputValidationError: {name} failed due to...
- assert "InputValidationError" in result.content
- assert "Read" in result.content
- assert not called_upstream # must not fall through to upstream
-
- def test_layer2_handler_exception_returns_tool_use_error(self):
- def bad_handler(**kwargs):
- raise ValueError("disk full")
-
- entry = ToolEntry(
- name="Write",
- mode=ToolMode.INLINE,
- schema={
- "name": "Write",
- "parameters": {
- "type": "object",
- "required": [],
- "properties": {},
- },
- },
- handler=bad_handler,
- source="test",
- )
- runner = _make_runner([entry])
- req = _make_tool_call_request("Write", {})
- result = runner.wrap_tool_call(req, lambda r: MagicMock())
- # Layer 2 error format: ...
- assert "" in result.content
- assert "disk full" in result.content
-
- def test_layer3_handler_returns_soft_failure_text(self):
- def soft_fail(**kwargs):
- return "No files found"
-
- entry = ToolEntry(
- name="Glob",
- mode=ToolMode.INLINE,
- schema={
- "name": "Glob",
- "parameters": {
- "type": "object",
- "required": ["pattern"],
- "properties": {"pattern": {"type": "string"}},
- },
- },
- handler=soft_fail,
- source="test",
- )
- runner = _make_runner([entry])
- req = _make_tool_call_request("Glob", {"pattern": "**/*.xyz"})
- result = runner.wrap_tool_call(req, lambda r: MagicMock())
- # Layer 3: plain text, no tags
- assert result.content == "No files found"
- assert "" not in result.content
- assert "InputValidationError" not in result.content
-
- def test_unknown_tool_falls_through_to_upstream(self):
- runner = _make_runner([]) # empty registry
- req = _make_tool_call_request("UnknownMCPTool", {})
- upstream_called = []
-
- def upstream(r):
- upstream_called.append(r)
- msg = MagicMock()
- msg.content = "mcp result"
- return msg
-
- result = runner.wrap_tool_call(req, upstream)
- assert upstream_called
- assert result.content == "mcp result"
-
-
-class TestToolRunnerInlineInjection:
- """P1: ToolRunner injects inline schemas into model call."""
-
- def test_inline_schemas_injected(self):
- entry = ToolEntry(
- name="Read",
- mode=ToolMode.INLINE,
- schema={"name": "Read", "description": "read file"},
- handler=lambda: "ok",
- source="test",
- )
- runner = _make_runner([entry])
-
- # Build a mock ModelRequest
- request = MagicMock()
- request.tools = []
-
- captured = []
-
- def handler(req):
- captured.append(req)
- return MagicMock()
-
- request.override.return_value = request
- runner.wrap_model_call(request, handler)
-
- # Should have called override with tools containing Read
- assert request.override.called
- call_kwargs = request.override.call_args
- _tools_arg = call_kwargs[1].get("tools") or (call_kwargs[0][0] if call_kwargs[0] else None)
- # override was called — inline tools were injected
-
- def test_deferred_schemas_not_injected(self):
- deferred = ToolEntry(
- name="TaskCreate",
- mode=ToolMode.DEFERRED,
- schema={"name": "TaskCreate", "description": "create task"},
- handler=lambda: "ok",
- source="test",
- )
- runner = _make_runner([deferred])
- schemas = runner._registry.get_inline_schemas()
- assert all(s["name"] != "TaskCreate" for s in schemas)
-
-
-# ---------------------------------------------------------------------------
-# P1: tool_modes from config honored
-# ---------------------------------------------------------------------------
-
-
-class TestToolModeFromConfig:
- """Verify tool_modes config is applied during service init."""
-
- def test_task_service_registers_deferred(self, tmp_path):
- reg = ToolRegistry()
- from core.tools.task.service import TaskService
-
- _svc = TaskService(registry=reg, db_path=tmp_path / "test.db")
- # TaskCreate/TaskUpdate/TaskList/TaskGet should be DEFERRED
- for tool_name in ["TaskCreate", "TaskGet", "TaskList", "TaskUpdate"]:
- entry = reg.get(tool_name)
- assert entry is not None, f"{tool_name} not registered"
- assert entry.mode == ToolMode.DEFERRED, f"{tool_name} should be DEFERRED, got {entry.mode}"
-
- def test_search_service_registers_inline(self, tmp_path):
- reg = ToolRegistry()
- from core.tools.search.service import SearchService
-
- _svc = SearchService(registry=reg, workspace_root=tmp_path)
- for tool_name in ["Grep", "Glob"]:
- entry = reg.get(tool_name)
- assert entry is not None, f"{tool_name} not registered"
- assert entry.mode == ToolMode.INLINE, f"{tool_name} should be INLINE, got {entry.mode}"
diff --git a/uv.lock b/uv.lock
index 721e5c891..78f682840 100644
--- a/uv.lock
+++ b/uv.lock
@@ -375,6 +375,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/90/45/f458fa2c388e79dd9d8b9b0c99f1d31b568f27388f2fdba7bb66bbc0c6ed/cachetools-6.2.6-py3-none-any.whl", hash = "sha256:8c9717235b3c651603fff0076db52d6acbfd1b338b8ed50256092f7ce9c85bda", size = 11668, upload-time = "2026-01-27T20:32:58.527Z" },
]
+[[package]]
+name = "cattrs"
+version = "26.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "attrs" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a0/ec/ba18945e7d6e55a58364d9fb2e46049c1c2998b3d805f19b703f14e81057/cattrs-26.1.0.tar.gz", hash = "sha256:fa239e0f0ec0715ba34852ce813986dfed1e12117e209b816ab87401271cdd40", size = 495672, upload-time = "2026-02-18T22:15:19.406Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/80/56/60547f7801b97c67e97491dc3d9ade9fbccbd0325058fd3dfcb2f5d98d90/cattrs-26.1.0-py3-none-any.whl", hash = "sha256:d1e0804c42639494d469d08d4f26d6b9de9b8ab26b446db7b5f8c2e97f7c3096", size = 73054, upload-time = "2026-02-18T22:15:17.958Z" },
+]
+
[[package]]
name = "certifi"
version = "2026.1.4"
@@ -719,6 +732,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" },
]
+[[package]]
+name = "docstring-to-markdown"
+version = "0.17"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "importlib-metadata" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/52/d8/8abe80d62c5dce1075578031bcfde07e735bcf0afe2886dd48b470162ab4/docstring_to_markdown-0.17.tar.gz", hash = "sha256:df72a112294c7492487c9da2451cae0faeee06e86008245c188c5761c9590ca3", size = 32260, upload-time = "2025-05-02T15:09:07.932Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/56/7b/af3d0da15bed3a8665419bb3a630585756920f4ad67abfdfef26240ebcc0/docstring_to_markdown-0.17-py3-none-any.whl", hash = "sha256:fd7d5094aa83943bf5f9e1a13701866b7c452eac19765380dead666e36d3711c", size = 23479, upload-time = "2025-05-02T15:09:06.676Z" },
+]
+
[[package]]
name = "duckduckgo-search"
version = "8.1.1"
@@ -1089,6 +1115,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
+[[package]]
+name = "jedi"
+version = "0.19.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "parso" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" },
+]
+
+[[package]]
+name = "jedi-language-server"
+version = "0.41.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "cattrs" },
+ { name = "docstring-to-markdown" },
+ { name = "jedi" },
+ { name = "lsprotocol" },
+ { name = "pygls" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f3/34/4a35094c680040c8dd598b1ee9153a701289351c1dcbad1a0f2d196c524b/jedi_language_server-0.41.3.tar.gz", hash = "sha256:113ec22b95fadaceefbb704b5f365384bed296b82ede59026be375ecc97a9f8a", size = 29113, upload-time = "2024-02-26T04:28:05.521Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b6/67/2cf4419a8c418b0e5cba0b43dc1ea33a0bb42907694d6a786a3644889f32/jedi_language_server-0.41.3-py3-none-any.whl", hash = "sha256:7411f7479cdc9e9ea495f91e20b182a5d00170c0a8a4a87d3a147462282c06af", size = 27615, upload-time = "2024-02-26T04:28:02.084Z" },
+]
+
[[package]]
name = "jiter"
version = "0.12.0"
@@ -1422,10 +1476,12 @@ dependencies = [
{ name = "langgraph" },
{ name = "langgraph-checkpoint-postgres" },
{ name = "langgraph-checkpoint-sqlite" },
+ { name = "multilspy" },
{ name = "pillow" },
{ name = "psycopg", extra = ["binary"] },
{ name = "pydantic" },
{ name = "pyjwt" },
+ { name = "pyright" },
{ name = "pyyaml" },
{ name = "rich" },
{ name = "sse-starlette" },
@@ -1513,6 +1569,7 @@ requires-dist = [
{ name = "langgraph-checkpoint-sqlite", specifier = ">=2.0.0" },
{ name = "langsmith", marker = "extra == 'all'", specifier = ">=0.1.0" },
{ name = "langsmith", marker = "extra == 'langsmith'", specifier = ">=0.1.0" },
+ { name = "multilspy", specifier = ">=0.0.15" },
{ name = "opentelemetry-api", marker = "extra == 'otel'", specifier = ">=1.20.0" },
{ name = "opentelemetry-exporter-otlp", marker = "extra == 'otel'", specifier = ">=1.20.0" },
{ name = "opentelemetry-sdk", marker = "extra == 'otel'", specifier = ">=1.20.0" },
@@ -1523,6 +1580,7 @@ requires-dist = [
{ name = "pymupdf", marker = "extra == 'all'", specifier = ">=1.24.0" },
{ name = "pymupdf", marker = "extra == 'docs'", specifier = ">=1.24.0" },
{ name = "pymupdf", marker = "extra == 'pdf'", specifier = ">=1.24.0" },
+ { name = "pyright", specifier = ">=1.1.0" },
{ name = "python-pptx", marker = "extra == 'all'", specifier = ">=1.0.0" },
{ name = "python-pptx", marker = "extra == 'docs'", specifier = ">=1.0.0" },
{ name = "python-pptx", marker = "extra == 'pptx'", specifier = ">=1.0.0" },
@@ -1560,6 +1618,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c", size = 61595, upload-time = "2024-12-06T11:20:54.538Z" },
]
+[[package]]
+name = "lsprotocol"
+version = "2023.0.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "attrs" },
+ { name = "cattrs" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9d/f6/6e80484ec078d0b50699ceb1833597b792a6c695f90c645fbaf54b947e6f/lsprotocol-2023.0.1.tar.gz", hash = "sha256:cc5c15130d2403c18b734304339e51242d3018a05c4f7d0f198ad6e0cd21861d", size = 69434, upload-time = "2024-01-09T17:21:12.625Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8d/37/2351e48cb3309673492d3a8c59d407b75fb6630e560eb27ecd4da03adc9a/lsprotocol-2023.0.1-py3-none-any.whl", hash = "sha256:c75223c9e4af2f24272b14c6375787438279369236cd568f596d4951052a60f2", size = 70826, upload-time = "2024-01-09T17:21:14.491Z" },
+]
+
[[package]]
name = "lxml"
version = "6.0.2"
@@ -1876,6 +1947,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
]
+[[package]]
+name = "multilspy"
+version = "0.0.15"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "jedi-language-server" },
+ { name = "psutil" },
+ { name = "requests" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d8/a8/4d6ab48e624f911eb5229aa01b3524b916470c9d036a9e8cc96d6fb81673/multilspy-0.0.15.tar.gz", hash = "sha256:b27a0b7c5c5306216b31fe1df9b4a42d2797735d0a78928e0df9ef8dfbcc97c5", size = 120639, upload-time = "2025-04-03T07:01:27.216Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/97/4d/b9d3492d6a7a2536498fc7fd49c1cc7bc86a41acf93b0ad967d75dbe5cd6/multilspy-0.0.15-py3-none-any.whl", hash = "sha256:3fa88939b953ed5d39aba4688a34105ec1e5cf2b2f778167fee2b78b3c0e1427", size = 137361, upload-time = "2025-04-03T07:01:25.492Z" },
+]
+
[[package]]
name = "multipart"
version = "1.3.0"
@@ -2176,6 +2262,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
+[[package]]
+name = "parso"
+version = "0.8.6"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/81/76/a1e769043c0c0c9fe391b702539d594731a4362334cdf4dc25d0c09761e7/parso-0.8.6.tar.gz", hash = "sha256:2b9a0332696df97d454fa67b81618fd69c35a7b90327cbe6ba5c92d2c68a7bfd", size = 401621, upload-time = "2026-02-09T15:45:24.425Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl", hash = "sha256:2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff", size = 106894, upload-time = "2026-02-09T15:45:21.391Z" },
+]
+
[[package]]
name = "pillow"
version = "12.1.0"
@@ -2403,6 +2498,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/57/bf/2086963c69bdac3d7cff1cc7ff79b8ce5ea0bec6797a017e1be338a46248/protobuf-6.33.5-py3-none-any.whl", hash = "sha256:69915a973dd0f60f31a08b8318b73eab2bd6a392c79184b3612226b0a3f8ec02", size = 170687, upload-time = "2026-01-29T21:51:32.557Z" },
]
+[[package]]
+name = "psutil"
+version = "7.2.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" },
+ { url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" },
+ { url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" },
+ { url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" },
+ { url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" },
+ { url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" },
+ { url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" },
+ { url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" },
+ { url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" },
+ { url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" },
+ { url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" },
+ { url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" },
+ { url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" },
+ { url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" },
+ { url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" },
+ { url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" },
+ { url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" },
+ { url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" },
+ { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
+]
+
[[package]]
name = "psycopg"
version = "3.3.3"
@@ -2594,6 +2717,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9b/4d/b9add7c84060d4c1906abe9a7e5359f2a60f7a9a4f67268b2766673427d8/pyee-13.0.0-py3-none-any.whl", hash = "sha256:48195a3cddb3b1515ce0695ed76036b5ccc2ef3a9f963ff9f77aec0139845498", size = 15730, upload-time = "2025-03-17T18:53:14.532Z" },
]
+[[package]]
+name = "pygls"
+version = "1.3.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "cattrs" },
+ { name = "lsprotocol" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/86/b9/41d173dad9eaa9db9c785a85671fc3d68961f08d67706dc2e79011e10b5c/pygls-1.3.1.tar.gz", hash = "sha256:140edceefa0da0e9b3c533547c892a42a7d2fd9217ae848c330c53d266a55018", size = 45527, upload-time = "2024-03-26T18:44:25.679Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/11/19/b74a10dd24548e96e8c80226cbacb28b021bc3a168a7d2709fb0d0185348/pygls-1.3.1-py3-none-any.whl", hash = "sha256:6e00f11efc56321bdeb6eac04f6d86131f654c7d49124344a9ebb968da3dd91e", size = 56031, upload-time = "2024-03-26T18:44:24.249Z" },
+]
+
[[package]]
name = "pygments"
version = "2.19.2"
@@ -3030,7 +3166,7 @@ wheels = [
[[package]]
name = "requests"
-version = "2.32.5"
+version = "2.32.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi" },
@@ -3038,9 +3174,9 @@ dependencies = [
{ name = "idna" },
{ name = "urllib3" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218, upload-time = "2024-05-29T15:37:49.536Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
+ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928, upload-time = "2024-05-29T15:37:47.027Z" },
]
[[package]]