diff --git a/.github/workflows/deploy-staging.yml b/.github/workflows/deploy-staging.yml index ee18d0d38..54eee564c 100644 --- a/.github/workflows/deploy-staging.yml +++ b/.github/workflows/deploy-staging.yml @@ -7,6 +7,9 @@ name: Deploy Staging # Both update the staging apps to the target branch, then deploy. on: + push: + branches: + - pr188-agent-optimize pull_request: types: [labeled] workflow_dispatch: @@ -23,9 +26,12 @@ jobs: deploy-staging: # For label trigger: only run when the label is exactly "deploy-staging" if: > + github.event_name == 'push' || github.event_name == 'workflow_dispatch' || (github.event_name == 'pull_request' && github.event.label.name == 'deploy-staging') runs-on: ubuntu-latest + env: + STAGING_STACK_UUID: fasbsube26s75ag6qus5bpi2 steps: - name: Resolve target ref @@ -33,33 +39,99 @@ jobs: run: | if [ "${{ github.event_name }}" = "pull_request" ]; then echo "ref=${{ github.head_ref }}" >> "$GITHUB_OUTPUT" + elif [ "${{ github.event_name }}" = "push" ]; then + echo "ref=${{ github.ref_name }}" >> "$GITHUB_OUTPUT" else echo "ref=${{ inputs.ref }}" >> "$GITHUB_OUTPUT" fi - - name: Update staging backend branch + - name: Check out target ref + uses: actions/checkout@v4 + with: + ref: ${{ steps.ref.outputs.ref }} + + - name: Resolve target commit + id: target run: | - curl -s -X PATCH "${{ secrets.COOLIFY_URL }}/api/v1/applications/${{ secrets.COOLIFY_BACKEND_STAGING_UUID }}" \ - -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" \ - -H "Content-Type: application/json" \ - -d '{"git_branch": "${{ steps.ref.outputs.ref }}"}' + set -euo pipefail + echo "sha=$(git rev-parse HEAD)" >> "$GITHUB_OUTPUT" + + - name: Assert repo staging compose contract + run: | + set -euo pipefail + grep -F "leon-home:/root/.leon" docker-compose.yml >/dev/null + grep -F "volumes:" docker-compose.yml >/dev/null - - name: Update staging frontend branch + - name: Update staging stack branch run: | - curl -s -X PATCH "${{ secrets.COOLIFY_URL }}/api/v1/applications/${{ secrets.COOLIFY_FRONTEND_STAGING_UUID }}" \ + set -euo pipefail + body="$(curl -sS --fail-with-body -X PATCH "${{ secrets.COOLIFY_URL }}/api/v1/applications/${STAGING_STACK_UUID}" \ -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" \ -H "Content-Type: application/json" \ - -d '{"git_branch": "${{ steps.ref.outputs.ref }}"}' + -d "{\"git_branch\": \"${{ steps.ref.outputs.ref }}\"}")" + echo "$body" + printf '%s' "$body" | jq -e --arg uuid "$STAGING_STACK_UUID" '.uuid == $uuid' >/dev/null + + - name: Deploy staging stack + id: deploy + run: | + set -euo pipefail + body="$(curl -sS --fail-with-body "${{ secrets.COOLIFY_URL }}/api/v1/deploy?uuid=${STAGING_STACK_UUID}&force=false" \ + -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}")" + echo "$body" + printf '%s' "$body" | jq -e --arg uuid "$STAGING_STACK_UUID" '.deployments[0].resource_uuid == $uuid' >/dev/null + echo "deployment_uuid=$(printf '%s' "$body" | jq -r '.deployments[0].deployment_uuid')" >> "$GITHUB_OUTPUT" + + - name: Wait for staging deployment + run: | + set -euo pipefail + deployment_uuid="${{ steps.deploy.outputs.deployment_uuid }}" + for _ in $(seq 1 60); do + body="$(curl -sS --fail-with-body "${{ secrets.COOLIFY_URL }}/api/v1/deployments/${deployment_uuid}" \ + -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}")" + status="$(printf '%s' "$body" | jq -r '.status')" + echo "deployment status: $status" + if [ "$status" = "finished" ]; then + exit 0 + fi + if [ "$status" != "queued" ] && [ "$status" != "in_progress" ]; then + echo "$body" + exit 1 + fi + sleep 10 + done + echo "Timed out waiting for staging deployment ${deployment_uuid}" + exit 1 - - name: Deploy backend to staging + - name: Verify Coolify staging contract run: | - curl -sX GET "${{ secrets.COOLIFY_URL }}/api/v1/deploy?uuid=${{ secrets.COOLIFY_BACKEND_STAGING_UUID }}&force=false" \ - -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" + set -euo pipefail + body="$(curl -sS --fail-with-body "${{ secrets.COOLIFY_URL }}/api/v1/applications/${STAGING_STACK_UUID}" \ + -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}")" + echo "$body" | jq '{uuid,git_branch,docker_compose_location}' + printf '%s' "$body" | jq -e --arg ref "${{ steps.ref.outputs.ref }}" '.git_branch == $ref' >/dev/null + printf '%s' "$body" | jq -e '.docker_compose_raw | contains("leon-home:/root/.leon")' >/dev/null + printf '%s' "$body" | jq -e --arg volume "${STAGING_STACK_UUID}_leon-home:/root/.leon" '.docker_compose | contains($volume)' >/dev/null + printf '%s' "$body" | jq -e --arg sha "${{ steps.target.outputs.sha }}" '.docker_compose | contains($sha)' >/dev/null - - name: Deploy frontend to staging + - name: Verify staging health contract run: | - curl -sX GET "${{ secrets.COOLIFY_URL }}/api/v1/deploy?uuid=${{ secrets.COOLIFY_FRONTEND_STAGING_UUID }}&force=false" \ - -H "Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}" + set -euo pipefail + for attempt in $(seq 1 18); do + status="$(curl -sS -o /tmp/staging-health.json -w '%{http_code}' "https://app.staging.mycel.nextmind.space/api/monitor/health")" + echo "health attempt ${attempt}: status=${status}" + if [ "$status" = "200" ]; then + body="$(cat /tmp/staging-health.json)" + echo "$body" + printf '%s' "$body" | jq -e '.db.path == "/root/.leon/sandbox.db"' >/dev/null + printf '%s' "$body" | jq -e '.db.exists == true' >/dev/null + exit 0 + fi + cat /tmp/staging-health.json || true + sleep 10 + done + echo "Staging health contract did not become ready in time" + exit 1 - name: Comment on PR with staging URL if: github.event_name == 'pull_request' @@ -70,5 +142,5 @@ jobs: issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, - body: `🚀 **预发部署已触发**\n\n- 前端: https://app.staging.mycel.nextmind.space\n- 后端: https://api.staging.mycel.nextmind.space\n\n分支: \`${{ steps.ref.outputs.ref }}\`` + body: `🚀 **预发部署已触发**\n\n- 共享 Staging: https://app.staging.mycel.nextmind.space\n- API(同域反代): https://app.staging.mycel.nextmind.space/api\n\n分支: \`${{ steps.ref.outputs.ref }}\`` }) diff --git a/Dockerfile b/Dockerfile index e875ed19f..36bb7bf5a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,11 +7,13 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv # Install dependencies (cached layer before source copy) COPY pyproject.toml uv.lock ./ -RUN uv sync --frozen --no-dev --no-install-project +# @@@sandbox-sdk-image-parity - shared staging/provider inventory should reflect runtime truth, +# not "SDK missing from image" accidents while config files are present. +RUN uv sync --frozen --no-dev --extra sandbox --extra e2b --extra daytona --no-install-project # Copy source and install project COPY . . -RUN uv sync --frozen --no-dev +RUN uv sync --frozen --no-dev --extra sandbox --extra e2b --extra daytona ENV PATH="/app/.venv/bin:$PATH" diff --git a/README.md b/README.md index a7fdc9af7..f75571e6f 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Full-featured web platform for managing and interacting with agents: ### Multi-Agent Communication -Agents are first-class social entities. They can discover each other, send messages, and collaborate autonomously: +Agents are first-class social entities. They can list chats, read messages, send messages, and collaborate autonomously: ``` Member (template) @@ -103,8 +103,10 @@ Member (template) └→ Thread (agent brain / conversation) ``` -- **`chat_send`**: Agent A messages Agent B; B responds autonomously -- **`directory`**: Agents browse and discover other entities +- **`list_chats`**: List active conversations with unread counts and participants +- **`read_messages`**: Read message history before responding +- **`send_message`**: Agent A messages Agent B; B responds autonomously +- **`search_messages`**: Search message history across chats - **Real-time delivery**: SSE-based chat with typing indicators and read receipts Humans also have entities — agents can initiate conversations with humans, not just the other way around. diff --git a/README.zh.md b/README.zh.md index 12bb8981a..1b3d31c87 100644 --- a/README.zh.md +++ b/README.zh.md @@ -95,7 +95,7 @@ cd frontend/app && npm run dev ### 多 Agent 通讯 -Agent 是一等公民的社交实体,可以互相发现、发送消息、自主协作: +Agent 是一等公民的社交实体,可以列出对话、读取消息、发送消息、自主协作: ``` Member(模板) @@ -103,8 +103,10 @@ Member(模板) └→ Thread(Agent 大脑 / 对话) ``` -- **`chat_send`**:Agent A 给 Agent B 发消息,B 自主回复 -- **`directory`**:Agent 浏览和发现其他实体 +- **`list_chats`**:列出活跃对话、未读数和参与者 +- **`read_messages`**:先读取消息历史,再决定如何回复 +- **`send_message`**:Agent A 给 Agent B 发消息,B 自主回复 +- **`search_messages`**:跨对话搜索消息历史 - **实时投递**:基于 SSE 的聊天,支持输入提示和已读回执 人类也有 Entity——Agent 可以主动找人类对话,而不只是被动响应。 diff --git a/backend/web/core/dependencies.py b/backend/web/core/dependencies.py index 52bc277a0..ef099c3c8 100644 --- a/backend/web/core/dependencies.py +++ b/backend/web/core/dependencies.py @@ -1,7 +1,6 @@ """FastAPI dependency injection functions.""" import asyncio -import os from typing import Annotated, Any from fastapi import Depends, FastAPI, HTTPException, Request @@ -9,18 +8,6 @@ from backend.web.services.agent_pool import get_or_create_agent, resolve_thread_sandbox from sandbox.thread_context import set_current_thread_id -# Dev bypass: set LEON_DEV_SKIP_AUTH=1 to skip JWT verification and inject a mock identity. -# WARNING: this bypasses ALL auth — never set in production. -_DEV_SKIP_AUTH = os.environ.get("LEON_DEV_SKIP_AUTH", "").lower() in ("1", "true", "yes") -_DEV_PAYLOAD = {"user_id": "dev-user"} - -if _DEV_SKIP_AUTH: - import logging as _logging - - _logging.getLogger(__name__).warning( - "LEON_DEV_SKIP_AUTH is active — JWT auth is BYPASSED for all requests. This must never be enabled in production." - ) - async def get_app(request: Request) -> FastAPI: """Get FastAPI app instance from request.""" @@ -36,9 +23,7 @@ def _get_auth_service(app: FastAPI): def _extract_jwt_payload(request: Request) -> dict: - """Extract and verify JWT payload from Bearer token. Returns {user_id}.""" - if _DEV_SKIP_AUTH: - return _DEV_PAYLOAD + """Extract and verify JWT payload from Bearer token. Returns {user_id, entity_id}.""" auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): raise HTTPException(401, "Missing or invalid Authorization header") @@ -52,14 +37,29 @@ def _extract_jwt_payload(request: Request) -> dict: async def get_current_user_id(request: Request) -> str: """Extract user_id from JWT and verify user exists. Returns 401 if user was deleted (e.g. DB reset).""" user_id = _extract_jwt_payload(request)["user_id"] - if _DEV_SKIP_AUTH: - return user_id member_repo = getattr(request.app.state, "member_repo", None) if member_repo and member_repo.get_by_id(user_id) is None: raise HTTPException(401, "User no longer exists — please re-login") return user_id +async def get_current_entity_id(request: Request) -> str: + """Derive entity_id for the authenticated human user. + + Supabase JWTs may omit custom entity claims, so keep the older + direct-claim path when present and otherwise derive the stable + human entity convention: f"{user_id}-1". + """ + payload = _extract_jwt_payload(request) + entity_id = payload.get("entity_id") + if entity_id: + return entity_id + user_id = payload.get("user_id") + if not user_id: + raise HTTPException(401, "Token missing user_id") + return f"{user_id}-1" + + async def verify_thread_owner( thread_id: str, user_id: Annotated[str, Depends(get_current_user_id)], diff --git a/backend/web/core/lifespan.py b/backend/web/core/lifespan.py index 13a76a4b2..e2860f177 100644 --- a/backend/web/core/lifespan.py +++ b/backend/web/core/lifespan.py @@ -14,84 +14,6 @@ from core.runtime.middleware.queue import MessageQueueManager -def _seed_dev_user(app: FastAPI) -> None: - """Create dev-user human member + initial agents if not yet seeded. - - Mirrors AuthService.register() but uses the fixed 'dev-user' ID that - matches _DEV_PAYLOAD, so list_members('dev-user') returns results. - """ - import logging - import time - from pathlib import Path - - from backend.web.services.member_service import MEMBERS_DIR, _write_agent_md, _write_json - from storage.contracts import MemberRow, MemberType - from storage.providers.sqlite.member_repo import generate_member_id - - log = logging.getLogger(__name__) - member_repo = app.state.member_repo - - dev_user_id = "dev-user" - - if member_repo.get_by_id(dev_user_id) is not None: - return # already seeded - - log.info("DEV: seeding dev-user member + initial agents") - now = time.time() - - # Human member row - member_repo.create( - MemberRow( - id=dev_user_id, - name="Dev", - type=MemberType.HUMAN, - created_at=now, - ) - ) - - # Initial agents (same as register()) - initial_agents = [ - {"name": "Toad", "description": "Curious and energetic assistant", "avatar": "toad.jpeg"}, - {"name": "Morel", "description": "Thoughtful senior analyst", "avatar": "morel.jpeg"}, - ] - assets_dir = Path(__file__).resolve().parents[3] / "assets" - - for agent_def in initial_agents: - agent_id = generate_member_id() - agent_dir = MEMBERS_DIR / agent_id - agent_dir.mkdir(parents=True, exist_ok=True) - _write_agent_md(agent_dir / "agent.md", name=agent_def["name"], description=agent_def["description"]) - _write_json( - agent_dir / "meta.json", - { - "status": "active", - "version": "1.0.0", - "created_at": int(now * 1000), - "updated_at": int(now * 1000), - }, - ) - member_repo.create( - MemberRow( - id=agent_id, - name=agent_def["name"], - type=MemberType.MYCEL_AGENT, - description=agent_def["description"], - config_dir=str(agent_dir), - owner_user_id=dev_user_id, - created_at=now, - ) - ) - src_avatar = assets_dir / agent_def["avatar"] - if src_avatar.exists(): - try: - from backend.web.routers.entities import process_and_save_avatar - - avatar_path = process_and_save_avatar(src_avatar, agent_id) - member_repo.update(agent_id, avatar=avatar_path, updated_at=now) - except Exception as e: - log.warning("DEV: avatar copy failed for %s: %s", agent_def["name"], e) - - @asynccontextmanager async def lifespan(app: FastAPI): """FastAPI lifespan context manager for startup and shutdown.""" @@ -114,7 +36,7 @@ async def lifespan(app: FastAPI): _storage_strategy = os.getenv("LEON_STORAGE_STRATEGY", "sqlite") if _storage_strategy == "supabase": - from backend.web.core.supabase_factory import create_supabase_client + from backend.web.core.supabase_factory import create_supabase_auth_client, create_supabase_client from storage.container import StorageContainer from storage.providers.supabase import ( SupabaseAccountRepo, @@ -144,6 +66,7 @@ async def lifespan(app: FastAPI): app.state.invite_code_repo = SupabaseInviteCodeRepo(_supabase_client) app.state.user_settings_repo = SupabaseUserSettingsRepo(_supabase_client) app.state._supabase_client = _supabase_client + app.state._supabase_auth_client_factory = create_supabase_auth_client app.state._storage_container = StorageContainer(strategy="supabase", supabase_client=_supabase_client) else: from storage.providers.sqlite.chat_repo import SQLiteChatEntityRepo, SQLiteChatMessageRepo, SQLiteChatRepo @@ -175,6 +98,7 @@ async def lifespan(app: FastAPI): accounts=app.state.account_repo, entities=app.state.entity_repo, supabase_client=_supabase_client, + supabase_auth_client_factory=create_supabase_auth_client, invite_codes=app.state.invite_code_repo, ) else: @@ -185,12 +109,6 @@ async def lifespan(app: FastAPI): supabase_client=None, ) - # Dev bypass: seed dev-user + initial agents on first startup - from backend.web.core.dependencies import _DEV_SKIP_AUTH - - if _DEV_SKIP_AUTH: - _seed_dev_user(app) - from backend.web.services.chat_events import ChatEventBus from backend.web.services.typing_tracker import TypingTracker @@ -259,30 +177,6 @@ async def lifespan(app: FastAPI): await cron_svc.start() app.state.cron_service = cron_svc - # @@@wechat-registry — create registry with delivery callback, auto-start all - from backend.web.services.wechat_service import WeChatConnectionRegistry, migrate_entity_id_dirs - from core.runtime.middleware.queue.formatters import format_wechat_message - - migrate_entity_id_dirs() - - async def _wechat_deliver(conn, msg): - """Delivery callback — routes WeChat messages to configured thread/chat.""" - routing = conn.routing - if not routing.type or not routing.id: - return - sender_name = msg.from_user_id.split("@")[0] or msg.from_user_id - if routing.type == "thread": - from backend.web.services.message_routing import route_message_to_brain - - content = format_wechat_message(sender_name, msg.from_user_id, msg.text) - await route_message_to_brain(app, routing.id, content, source="owner", sender_name=sender_name) - elif routing.type == "chat": - content = format_wechat_message(sender_name, msg.from_user_id, msg.text) - app.state.chat_service.send_message(routing.id, conn.user_id, content) - - app.state.wechat_registry = WeChatConnectionRegistry(delivery_fn=_wechat_deliver) - app.state.wechat_registry.auto_start_all() - yield finally: # @@@background-task-shutdown-order - cancel monitor/reaper before provider cleanup. @@ -295,10 +189,6 @@ async def _wechat_deliver(conn, msg): except asyncio.CancelledError: pass - # Cleanup: stop WeChat connections - if hasattr(app.state, "wechat_registry"): - await app.state.wechat_registry.shutdown() - # Cleanup: stop cron scheduler if app.state.cron_service: await app.state.cron_service.stop() @@ -312,3 +202,8 @@ async def _wechat_deliver(conn, msg): agent.close() except Exception as e: print(f"[web] Agent cleanup error: {e}") + + # Cleanup: stop LSP language servers + from core.tools.lsp.service import lsp_pool + + await lsp_pool.close_all() diff --git a/backend/web/core/storage_factory.py b/backend/web/core/storage_factory.py index 8e189dd9d..caba25f04 100644 --- a/backend/web/core/storage_factory.py +++ b/backend/web/core/storage_factory.py @@ -45,10 +45,8 @@ def make_cron_job_repo() -> Any: def make_sandbox_monitor_repo() -> Any: - if _strategy() == "supabase": - from storage.providers.supabase.sandbox_monitor_repo import SupabaseSandboxMonitorRepo - - return SupabaseSandboxMonitorRepo(_supabase_client()) + # @@@sandbox-runtime-truth-stays-local - sandbox lifecycle facts still live in local sandbox.db. + # Auth/member/thread metadata can be Supabase-backed without moving lease/session/terminal monitoring there. from storage.providers.sqlite.sandbox_monitor_repo import SQLiteSandboxMonitorRepo return SQLiteSandboxMonitorRepo() diff --git a/backend/web/core/supabase_factory.py b/backend/web/core/supabase_factory.py index c8dc9abd1..6afd00655 100644 --- a/backend/web/core/supabase_factory.py +++ b/backend/web/core/supabase_factory.py @@ -1,4 +1,4 @@ -"""Runtime Supabase client factory for storage wiring.""" +"""Runtime Supabase client factories for storage and auth wiring.""" from __future__ import annotations @@ -6,6 +6,19 @@ import httpx from supabase import ClientOptions, create_client +from supabase_auth import SyncGoTrueClient + + +def _resolve_supabase_url() -> str: + url = os.getenv("SUPABASE_INTERNAL_URL") or os.getenv("SUPABASE_PUBLIC_URL") + if not url: + raise RuntimeError("SUPABASE_INTERNAL_URL or SUPABASE_PUBLIC_URL is required.") + return url + + +def _resolve_supabase_auth_url() -> str: + url = os.getenv("SUPABASE_AUTH_URL") or _resolve_supabase_url() + return url def create_supabase_client(): @@ -16,13 +29,30 @@ def create_supabase_client(): httpx client never routes through any system/VPN proxy. """ # Prefer internal URL (same-host direct connection) over public tunnel URL. - url = os.getenv("SUPABASE_INTERNAL_URL") or os.getenv("SUPABASE_PUBLIC_URL") + url = _resolve_supabase_url() key = os.getenv("LEON_SUPABASE_SERVICE_ROLE_KEY") - if not url: - raise RuntimeError("SUPABASE_INTERNAL_URL or SUPABASE_PUBLIC_URL is required.") if not key: raise RuntimeError("LEON_SUPABASE_SERVICE_ROLE_KEY is required for Supabase storage runtime.") schema = os.getenv("LEON_DB_SCHEMA", "public") timeout = httpx.Timeout(30.0, connect=10.0) http_client = httpx.Client(timeout=timeout, trust_env=False) return create_client(url, key, options=ClientOptions(httpx_client=http_client, schema=schema)) + + +def create_supabase_auth_client(): + """Build a supabase-py auth client for end-user auth flows. + + Uses the anon key rather than service-role credentials so auth endpoints + behave like real caller traffic instead of admin/server traffic. + """ + url = _resolve_supabase_auth_url() + key = os.getenv("SUPABASE_ANON_KEY") + if not key: + raise RuntimeError("SUPABASE_ANON_KEY is required for Supabase auth runtime.") + timeout = httpx.Timeout(30.0, connect=10.0) + http_client = httpx.Client(timeout=timeout, trust_env=False) + auth_url = os.getenv("SUPABASE_AUTH_URL") + if auth_url: + # @@@direct-gotrue - local auth may bypass Kong and hit GoTrue directly at /token. + return SyncGoTrueClient(url=auth_url, headers={"apikey": key}, http_client=http_client) + return create_client(url, key, options=ClientOptions(httpx_client=http_client)) diff --git a/backend/web/main.py b/backend/web/main.py index 64f60e0a5..a457e017b 100644 --- a/backend/web/main.py +++ b/backend/web/main.py @@ -83,8 +83,6 @@ def _sqlite_root_supports_wal(root: Path) -> bool: from backend.web.routers import ( # noqa: E402 auth, chats, - connections, - debug, entities, invite_codes, marketplace, @@ -118,11 +116,9 @@ def _sqlite_root_supports_wal(root: Path) -> bool: app.include_router(entities.members_router) app.include_router(sandbox.router) app.include_router(webhooks.router) -app.include_router(connections.router) app.include_router(thread_files.router) app.include_router(thread_files._public) app.include_router(settings.router) -app.include_router(debug.router) app.include_router(panel.router) app.include_router(monitor.router) app.include_router(marketplace.router) diff --git a/backend/web/models/requests.py b/backend/web/models/requests.py index 05a108bf0..582ec7f4c 100644 --- a/backend/web/models/requests.py +++ b/backend/web/models/requests.py @@ -1,8 +1,8 @@ """Pydantic request models for Leon web API.""" -from typing import Literal +from typing import Any, Literal -from pydantic import BaseModel, Field +from pydantic import AliasChoices, BaseModel, Field from sandbox.config import MountSpec @@ -20,7 +20,7 @@ class RecipeSnapshotRequest(BaseModel): class CreateThreadRequest(BaseModel): member_id: str # which agent template to create thread from - sandbox: str = "local" + sandbox: str = Field(default="local", validation_alias=AliasChoices("sandbox", "sandbox_type")) recipe: RecipeSnapshotRequest | None = None lease_id: str | None = None cwd: str | None = None @@ -53,3 +53,22 @@ class RunRequest(BaseModel): class SendMessageRequest(BaseModel): message: str attachments: list[str] = Field(default_factory=list) + + +class AskUserAnswerRequest(BaseModel): + header: str | None = None + question: str | None = None + selected_options: list[str] = Field(default_factory=list) + free_text: str | None = None + + +class ResolvePermissionRequest(BaseModel): + decision: Literal["allow", "deny"] + message: str | None = None + answers: list[AskUserAnswerRequest] | None = None + annotations: dict[str, Any] | None = None + + +class ThreadPermissionRuleRequest(BaseModel): + behavior: Literal["allow", "deny", "ask"] + tool_name: str diff --git a/backend/web/routers/chats.py b/backend/web/routers/chats.py index 5e7e3ff9e..fc9e45482 100644 --- a/backend/web/routers/chats.py +++ b/backend/web/routers/chats.py @@ -180,15 +180,12 @@ async def stream_chat_events( app: Annotated[Any, Depends(get_app)] = None, ): """SSE stream for chat events. Uses ?token= for auth.""" - from backend.web.core.dependencies import _DEV_SKIP_AUTH - - if not _DEV_SKIP_AUTH: - if not token: - raise HTTPException(401, "Missing token") - try: - app.state.auth_service.verify_token(token) - except ValueError as e: - raise HTTPException(401, str(e)) + if not token: + raise HTTPException(401, "Missing token") + try: + app.state.auth_service.verify_token(token) + except ValueError as e: + raise HTTPException(401, str(e)) event_bus = app.state.chat_event_bus queue = event_bus.subscribe(chat_id) diff --git a/backend/web/routers/connections.py b/backend/web/routers/connections.py deleted file mode 100644 index c5fa0adc2..000000000 --- a/backend/web/routers/connections.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Connection endpoints — manage external platform connections (WeChat, etc.). - -@@@per-user — all endpoints scoped by user_id (the user's social identity). -""" - -from typing import Annotated, Any - -from fastapi import APIRouter, Depends, HTTPException - -from backend.web.core.dependencies import get_app, get_current_user_id -from backend.web.services.wechat_service import ( - QrPollRequest, - RoutingConfig, - RoutingSetRequest, - WeChatConnectionRegistry, -) - -router = APIRouter(prefix="/api/connections", tags=["connections"]) - - -def _get_registry(app: Any) -> WeChatConnectionRegistry: - return app.state.wechat_registry - - -# --- WeChat --- - - -@router.get("/wechat/state") -async def wechat_state( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - return _get_registry(app).get(user_id).get_state() - - -@router.post("/wechat/qrcode") -async def wechat_qrcode( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - conn = _get_registry(app).get(user_id) - if conn.connected: - raise HTTPException(400, "Already connected. Disconnect first.") - return await conn.get_qr_code() - - -@router.post("/wechat/qrcode/poll") -async def wechat_qrcode_poll( - body: QrPollRequest, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - registry = _get_registry(app) - conn = registry.get(user_id) - result = await conn.poll_qr_status(body.qrcode) - # Evict duplicates after successful connection - if result.get("status") == "confirmed" and conn._credentials: - registry.evict_duplicates(conn._credentials.account_id, user_id) - return result - - -@router.post("/wechat/disconnect") -async def wechat_disconnect( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).disconnect() - return {"ok": True} - - -@router.post("/wechat/polling/start") -async def wechat_start_polling( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - conn = _get_registry(app).get(user_id) - if not conn.connected: - raise HTTPException(400, "Not connected") - conn.start_polling() - return {"ok": True, "polling": True} - - -@router.post("/wechat/polling/stop") -async def wechat_stop_polling( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).stop_polling() - return {"ok": True, "polling": False} - - -# --- Routing config --- - - -@router.get("/wechat/routing") -async def wechat_get_routing( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - return _get_registry(app).get(user_id).routing.model_dump() - - -@router.post("/wechat/routing") -async def wechat_set_routing( - body: RoutingSetRequest, - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).set_routing(RoutingConfig(type=body.type, id=body.id, label=body.label)) - return {"ok": True} - - -@router.delete("/wechat/routing") -async def wechat_clear_routing( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - _get_registry(app).get(user_id).set_routing(RoutingConfig()) - return {"ok": True} - - -# --- List targets for routing picker --- - - -@router.get("/wechat/routing/targets") -async def wechat_routing_targets( - user_id: Annotated[str, Depends(get_current_user_id)], - app: Annotated[Any, Depends(get_app)], -) -> dict: - """List available threads and chats for the routing picker.""" - from backend.web.utils.serializers import avatar_url - - raw_threads = app.state.thread_repo.list_by_owner_user_id(user_id) - threads = [ - { - "id": t["id"], - "label": t.get("entity_name") or t.get("member_name") or t["id"][:12], - "avatar_url": avatar_url(t.get("member_id"), bool(t.get("member_avatar"))), - } - for t in raw_threads - ] - - raw_chats = app.state.chat_service.list_chats_for_user(user_id) - chats = [] - for c in raw_chats: - others = [e for e in c.get("entities", []) if e["id"] != user_id] - name = ", ".join(e["name"] for e in others) or "Unknown" - chats.append({"id": c["id"], "label": name}) - - return {"threads": threads, "chats": chats} diff --git a/backend/web/routers/debug.py b/backend/web/routers/debug.py deleted file mode 100644 index 57299f219..000000000 --- a/backend/web/routers/debug.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Debug logging endpoints.""" - -from fastapi import APIRouter -from pydantic import BaseModel - -router = APIRouter(prefix="/api/debug", tags=["debug"]) - - -class LogMessage(BaseModel): - message: str - timestamp: str - - -@router.post("/log") -async def log_frontend_message(payload: LogMessage) -> dict: - """Receive frontend console logs and write to file.""" - with open("/tmp/leon-frontend-console.log", "a") as f: - f.write(f"[{payload.timestamp}] {payload.message}\n") - return {"status": "ok"} diff --git a/backend/web/routers/entities.py b/backend/web/routers/entities.py index 96f636955..bf64c2e9d 100644 --- a/backend/web/routers/entities.py +++ b/backend/web/routers/entities.py @@ -194,6 +194,9 @@ async def list_entities( member = member_map.get(entity.member_id) owner = member_map.get(member.owner_user_id) if member and member.owner_user_id else None thread = app.state.thread_repo.get_by_id(entity.thread_id) if entity.thread_id else None + # @@@chat-discovery-surface - branch/subagent entities are runtime artifacts, not top-level chat picker entries. + if entity.type == "agent" and thread and not thread["is_main"]: + continue items.append( { "id": entity.id, # entity.id = member_id = social identity for agents diff --git a/backend/web/routers/panel.py b/backend/web/routers/panel.py index 3fe2f481b..700fe1d2f 100644 --- a/backend/web/routers/panel.py +++ b/backend/web/routers/panel.py @@ -315,8 +315,12 @@ async def update_resource_content(resource_type: str, resource_id: str, req: Upd @router.get("/profile") -async def get_profile() -> dict[str, Any]: - return await asyncio.to_thread(profile_service.get_profile) +async def get_profile( + user_id: Annotated[str, Depends(get_current_user_id)], + request: Request, +) -> dict[str, Any]: + member = request.app.state.member_repo.get_by_id(user_id) + return await asyncio.to_thread(profile_service.get_profile, member) @router.put("/profile") diff --git a/backend/web/routers/threads.py b/backend/web/routers/threads.py index 33a75b8aa..653674c44 100644 --- a/backend/web/routers/threads.py +++ b/backend/web/routers/threads.py @@ -21,8 +21,10 @@ from backend.web.models.requests import ( CreateThreadRequest, ResolveMainThreadRequest, + ResolvePermissionRequest, SaveThreadLaunchConfigRequest, SendMessageRequest, + ThreadPermissionRuleRequest, ) from backend.web.services import sandbox_service from backend.web.services.agent_pool import get_or_create_agent, resolve_thread_sandbox @@ -50,6 +52,7 @@ from backend.web.utils.serializers import avatar_url, serialize_message from core.runtime.middleware.monitor import AgentState from sandbox.config import MountSpec +from sandbox.manager import bind_thread_to_existing_lease from sandbox.recipes import normalize_recipe_snapshot, provider_type_from_name from sandbox.thread_context import set_current_thread_id from storage.contracts import EntityRow @@ -59,6 +62,18 @@ router = APIRouter(prefix="/api/threads", tags=["threads"]) +class _NoopAsyncLock: + async def __aenter__(self) -> None: + return None + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + +def _is_internal_child_thread(thread_id: str) -> bool: + return thread_id.startswith("subagent-") + + def _invalidate_resource_overview_cache() -> None: # @@@resource-overview-invalidation - thread/lease mutations change the monitor topology immediately. # Clear the overview snapshot so the next /api/monitor/resources read reflects the fresh binding/state. @@ -179,6 +194,71 @@ async def _validate_mount_capability_gate( ) +def _provider_unavailable_response(sandbox_type: str) -> JSONResponse: + return JSONResponse( + status_code=400, + content={ + "error": "sandbox_provider_unavailable", + "provider": sandbox_type, + }, + ) + + +def _format_ask_user_question_followup( + pending_request: dict[str, Any], + *, + answers: list[dict[str, Any]], + annotations: dict[str, Any] | None, +) -> str: + payload: dict[str, Any] = { + "questions": (pending_request.get("args") or {}).get("questions", []), + "answers": answers, + } + if annotations is not None: + payload["annotations"] = annotations + # @@@ask-user-followup-payload - keep this as one narrow, structured owner reply + # so the resumed run can continue from the user's choices without inventing + # a bespoke second continuation channel. + return ( + "The user answered your AskUserQuestion prompt. Continue the task using these answers.\n" + "\n" + f"{json.dumps(payload, ensure_ascii=False, indent=2)}\n" + "" + ) + + +def _serialize_permission_answers(payload: Any) -> list[dict[str, Any]] | None: + raw_answers = getattr(payload, "answers", None) + if raw_answers is None: + return None + serialized: list[dict[str, Any]] = [] + for item in raw_answers: + if hasattr(item, "model_dump"): + serialized.append(item.model_dump(exclude_none=True)) + elif isinstance(item, dict): + serialized.append({key: value for key, value in item.items() if value is not None}) + else: + serialized.append({key: value for key, value in vars(item).items() if value is not None}) + return serialized + + +def _validate_sandbox_provider_gate(app: Any, owner_user_id: str, payload: CreateThreadRequest) -> JSONResponse | None: + sandbox_type = payload.sandbox or "local" + if payload.lease_id: + owned_lease = next( + (lease for lease in sandbox_service.list_user_leases(owner_user_id) if lease["lease_id"] == payload.lease_id), + None, + ) + if owned_lease is not None: + sandbox_type = str(owned_lease["provider_name"] or sandbox_type) + if sandbox_type == "local": + return None + provider = sandbox_service.build_provider_from_config_name(sandbox_type) + if provider is not None: + return None + return _provider_unavailable_response(sandbox_type) + + def _get_agent_for_thread(app: Any, thread_id: str) -> Any | None: """Get agent instance for a thread from the agent pool.""" pool = getattr(app.state, "agent_pool", None) @@ -210,7 +290,165 @@ def _thread_payload(app: Any, thread_id: str, sandbox_type: str) -> dict[str, An } -def _create_thread_sandbox_resources(thread_id: str, sandbox_type: str, recipe: dict[str, Any] | None) -> None: +_IDLE_REPLAYABLE_RUN_EVENTS = frozenset({"error", "cancelled", "retry"}) + + +def _checkpoint_tail_is_pending_owner_turn(messages: list[dict[str, Any]]) -> bool: + if not messages: + return False + tail = messages[-1] + if tail.get("type") != "HumanMessage": + return False + meta = tail.get("metadata") or {} + return meta.get("source") not in {"system", "external"} + + +async def _get_thread_display_entries(app: Any, thread_id: str) -> list[dict[str, Any]]: + display_builder = app.state.display_builder + entries = display_builder.get_entries(thread_id) + if entries is not None: + _normalize_blocking_subagent_terminal_status(entries) + sandbox_type = resolve_thread_sandbox(app, thread_id) + agent = await get_or_create_agent(app, sandbox_type, thread_id=thread_id) + if entries is not None and getattr(agent.runtime, "current_state", None) != AgentState.IDLE: + return entries + + set_current_thread_id(thread_id) + config = {"configurable": {"thread_id": thread_id}} + state = await agent.agent.aget_state(config) + values = getattr(state, "values", {}) if state else {} + messages = values.get("messages", []) if isinstance(values, dict) else [] + serialized = [serialize_message(msg) for msg in messages] + + from core.runtime.visibility import annotate_owner_visibility + + annotated, _ = annotate_owner_visibility(serialized) + if entries is not None and not _display_entries_need_idle_rebuild(entries, annotated): + return entries + entries = display_builder.build_from_checkpoint(thread_id, annotated) + if _checkpoint_tail_is_pending_owner_turn(annotated): + await _replay_latest_run_failure_events( + thread_id=thread_id, + display_builder=display_builder, + ) + entries = display_builder.get_entries(thread_id) or entries + _normalize_blocking_subagent_terminal_status(entries) + return entries + + +def _display_entries_need_idle_rebuild(entries: list[dict[str, Any]], messages: list[dict[str, Any]]) -> bool: + if not messages: + return bool(entries) + if not entries: + return True + # @@@idle-cache-honesty - idle detail must not trust cached assistant shells after + # clear/restart. Rebuild only when cache is visibly impossible for the persisted checkpoint. + return any(entry.get("role") == "assistant" and not entry.get("segments") for entry in entries) + + +def _normalize_blocking_subagent_terminal_status(entries: list[dict[str, Any]]) -> None: + for entry in entries: + if entry.get("role") != "assistant": + continue + for seg in entry.get("segments", []): + if seg.get("type") != "tool": + continue + step = seg.get("step") or {} + if step.get("name") != "Agent" or step.get("status") != "done": + continue + stream = step.get("subagent_stream") + if not isinstance(stream, dict): + continue + result_text = step.get("result") + existing_status = str(stream.get("status") or "").lower() + terminal_status = ( + existing_status + if existing_status in {"completed", "error", "cancelled"} + else ("error" if isinstance(result_text, str) and result_text.startswith("") else "completed") + ) + if stream.get("status") != terminal_status: + # @@@blocking-subagent-terminal-honesty - a finished blocking Agent tool + # must not keep exposing a stale running child status on refresh/detail/tasks. + stream["status"] = terminal_status + if terminal_status == "error" and not stream.get("error") and isinstance(result_text, str): + stream["error"] = result_text + + +def _collect_display_subagent_tasks(entries: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: + tasks: dict[str, dict[str, Any]] = {} + for entry in entries: + if entry.get("role") != "assistant": + continue + for seg in entry.get("segments", []): + if seg.get("type") != "tool": + continue + step = seg.get("step") or {} + if step.get("name") != "Agent": + continue + stream = step.get("subagent_stream") + if not isinstance(stream, dict) or not stream.get("task_id"): + continue + task_id = str(stream["task_id"]) + raw_args = step.get("args") + args: dict[str, Any] = raw_args if isinstance(raw_args, dict) else {} + description = stream.get("description") or args.get("description") or args.get("prompt") + status = str(stream.get("status") or ("completed" if step.get("status") == "done" else "running")) + result_text = step.get("result") or stream.get("text") + # @@@dual-source-task-surface - blocking Agent subagents never enter parent _background_runs, + # so /tasks must also project persisted subagent_stream state from display history. + tasks[task_id] = { + "task_id": task_id, + "task_type": "agent", + "status": status, + "command_line": None, + "description": description, + "exit_code": None, + "error": stream.get("error"), + "result": result_text, + "text": result_text, + "thread_id": stream.get("thread_id"), + } + return tasks + + +async def _replay_latest_run_failure_events( + *, + thread_id: str, + display_builder: Any, +) -> None: + from backend.web.services.event_store import get_latest_run_id, read_events_after + + run_id = await get_latest_run_id(thread_id) + if not run_id or run_id.startswith("activity_"): + return + + events = await read_events_after(thread_id, run_id, 0) + if not any(event.get("event") in _IDLE_REPLAYABLE_RUN_EVENTS for event in events): + return + + # @@@idle-run-error-replay - checkpoint can stop at the owner's input when + # the run dies before first persisted AI/Tool message. Rebuild must replay + # the latest run-level failure events so refresh/detail stays honest. + for event in events: + event_type = event.get("event", "") + if event_type not in {"run_start", "run_done", *_IDLE_REPLAYABLE_RUN_EVENTS}: + continue + raw_data = event.get("data", "{}") + try: + payload = json.loads(raw_data) if isinstance(raw_data, str) else raw_data + except (json.JSONDecodeError, TypeError): + payload = {} + if not isinstance(payload, dict): + payload = {} + display_builder.apply_event(thread_id, event_type, payload) + + +def _create_thread_sandbox_resources( + thread_id: str, + sandbox_type: str, + recipe: dict[str, Any] | None, + cwd: str | None = None, +) -> None: """Create volume, lease, and terminal eagerly so volume exists before file uploads.""" from datetime import datetime @@ -250,11 +488,11 @@ def _create_thread_sandbox_resources(thread_id: str, sandbox_type: str, recipe: terminal_repo = SQLiteTerminalRepo(db_path=sandbox_db) try: terminal_id = f"term-{uuid.uuid4().hex[:12]}" - # @@@initial-cwd - use project root for local, provider default for remote + # @@@initial-cwd - local threads own their requested cwd; remote threads start from provider defaults. from backend.web.core.config import LOCAL_WORKSPACE_ROOT if sandbox_type == "local": - initial_cwd = str(LOCAL_WORKSPACE_ROOT) + initial_cwd = cwd or str(LOCAL_WORKSPACE_ROOT) else: from backend.web.services.sandbox_service import build_provider_from_config_name from sandbox.manager import resolve_provider_cwd @@ -271,43 +509,6 @@ def _create_thread_sandbox_resources(thread_id: str, sandbox_type: str, recipe: terminal_repo.close() -def _resolve_existing_lease_cwd(lease_id: str, fallback_cwd: str | None) -> str: - if fallback_cwd: - return fallback_cwd - - from backend.web.core.config import LOCAL_WORKSPACE_ROOT - from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path - from storage.providers.sqlite.terminal_repo import SQLiteTerminalRepo - - terminal_repo = SQLiteTerminalRepo(db_path=resolve_role_db_path(SQLiteDBRole.SANDBOX)) - try: - row = terminal_repo.get_latest_by_lease(lease_id) - finally: - terminal_repo.close() - if row and row.get("cwd"): - return str(row["cwd"]) - - return str(LOCAL_WORKSPACE_ROOT) - - -def _bind_thread_to_existing_lease(thread_id: str, lease_id: str, *, cwd: str | None) -> str: - from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path - from storage.providers.sqlite.terminal_repo import SQLiteTerminalRepo - - initial_cwd = _resolve_existing_lease_cwd(lease_id, cwd) - terminal_repo = SQLiteTerminalRepo(db_path=resolve_role_db_path(SQLiteDBRole.SANDBOX)) - try: - terminal_repo.create( - terminal_id=f"term-{uuid.uuid4().hex[:12]}", - thread_id=thread_id, - lease_id=lease_id, - initial_cwd=initial_cwd, - ) - finally: - terminal_repo.close() - return initial_cwd - - def _create_owned_thread( app: Any, owner_user_id: str, @@ -390,7 +591,7 @@ def _create_owned_thread( if selected_lease_id: # @@@reuse-lease-binding - Reuse an existing lease by attaching a fresh terminal for the new thread. - bound_cwd = _bind_thread_to_existing_lease( + bound_cwd = bind_thread_to_existing_lease( new_thread_id, selected_lease_id, cwd=payload.cwd, @@ -403,6 +604,7 @@ def _create_owned_thread( new_thread_id, sandbox_type, payload.recipe.model_dump() if payload.recipe else None, + payload.cwd, ) if selected_lease_id and owned_lease is not None: @@ -448,6 +650,9 @@ async def create_thread( app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any] | JSONResponse: """Create a new child thread for an agent member.""" + provider_error = _validate_sandbox_provider_gate(app, user_id, payload) + if provider_error is not None: + return provider_error # Validate bind_mounts capability before creating thread sandbox_type = payload.sandbox or "local" requested_mounts = payload.bind_mounts if payload.bind_mounts else [] @@ -477,7 +682,15 @@ async def resolve_main_thread( existing = app.state.thread_repo.get_main_thread(payload.member_id) if existing is None: return {"thread": None} - return {"thread": _thread_payload(app, existing["id"], existing.get("sandbox_type", "local"))} + try: + return {"thread": _thread_payload(app, existing["id"], existing.get("sandbox_type", "local"))} + except HTTPException as exc: + # @@@orphan-main-thread - stale bootstrap data can leave the member pointing at a thread whose + # member/entity rows are gone. Treat that as "no resolvable main thread" instead of surfacing a 500. + if exc.status_code == 500 and "missing member/entity" in str(exc.detail): + logger.warning("resolve_main_thread ignored orphaned main thread %s for member %s", existing["id"], payload.member_id) + return {"thread": None} + raise @router.get("/default-config") @@ -518,6 +731,8 @@ async def list_threads( threads = [] for t in raw: tid = t["id"] + if _is_internal_child_thread(tid): + continue sandbox_type = t.get("sandbox_type", "local") # Check if agent is currently running — pool key is "{thread_id}:{sandbox_type}" running = False @@ -562,26 +777,10 @@ async def get_thread_messages( @@@display-builder — returns pre-computed ChatEntry[] from DisplayBuilder. Hot path: return in-memory state. Cold path: rebuild from checkpoint. """ - display_builder = app.state.display_builder sandbox_type = resolve_thread_sandbox(app, thread_id) agent = await get_or_create_agent(app, sandbox_type, thread_id=thread_id) - - # Hot path: return cached display entries - entries = display_builder.get_entries(thread_id) - if entries is None: - # Cold path: rebuild from checkpoint - set_current_thread_id(thread_id) - config = {"configurable": {"thread_id": thread_id}} - state = await agent.agent.aget_state(config) - values = getattr(state, "values", {}) if state else {} - messages = values.get("messages", []) if isinstance(values, dict) else [] - serialized = [serialize_message(msg) for msg in messages] - - from core.runtime.visibility import annotate_owner_visibility - - annotated, _ = annotate_owner_visibility(serialized) - entries = display_builder.build_from_checkpoint(thread_id, annotated) - + display_builder = app.state.display_builder + entries = await _get_thread_display_entries(app, thread_id) sandbox_info = get_sandbox_info(agent, thread_id, sandbox_type) return { "thread_id": thread_id, @@ -647,6 +846,28 @@ async def delete_thread( return {"ok": True, "thread_id": thread_id} +@router.post("/{thread_id}/clear") +async def clear_thread_history( + thread_id: str, + user_id: Annotated[str, Depends(verify_thread_owner)], + app: Annotated[Any, Depends(get_app)] = None, +) -> dict[str, Any]: + """Clear replayable thread history while preserving the thread itself.""" + sandbox_type = resolve_thread_sandbox(app, thread_id) + + lock = await get_thread_lock(app, thread_id) + async with lock: + agent = await get_or_create_agent(app, sandbox_type, thread_id=thread_id) + if hasattr(agent, "runtime") and agent.runtime.current_state == AgentState.ACTIVE: + raise HTTPException(status_code=409, detail="Cannot clear thread while run is in progress") + await agent.aclear_thread(thread_id) + + app.state.display_builder.clear(thread_id) + app.state.thread_event_buffers.pop(thread_id, None) + app.state.queue_manager.clear_all(thread_id) + return {"ok": True, "thread_id": thread_id} + + @router.post("/{thread_id}/messages") async def send_message( thread_id: str, @@ -705,7 +926,7 @@ async def get_thread_history( thread_id: str, limit: int = 20, truncate: int = 300, - user_id: Annotated[str, Depends(verify_thread_owner)] = None, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any]: """Compact conversation history for debugging — no raw LangChain noise. @@ -759,7 +980,7 @@ def _expand(msg: Any) -> list[dict[str, Any]]: text = extract_text_content(msg.content) if text: entries.append({"role": "assistant", "text": _trunc(text)}) - return entries or [{"role": "assistant", "text": ""}] + return entries if cls == "ToolMessage": return [ { @@ -782,11 +1003,148 @@ def _expand(msg: Any) -> list[dict[str, Any]]: } +@router.get("/{thread_id}/permissions") +async def get_thread_permissions( + thread_id: str, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + thread_lock: Annotated[asyncio.Lock | None, Depends(get_thread_lock)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, +) -> dict[str, Any]: + # @@@permission-state-lock - owner polling and resolve can race on idle + # threads. Serialize the lightweight /permissions read with resolve/persist + # so stale checkpoint hydration cannot resurrect an already-resolved request. + async with thread_lock or _NoopAsyncLock(): + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + rule_state = agent.get_thread_permission_rules(thread_id) + return { + "thread_id": thread_id, + "requests": agent.get_pending_permission_requests(thread_id), + "session_rules": rule_state["rules"], + "managed_only": rule_state["managed_only"], + } + + +@router.post("/{thread_id}/permissions/{request_id}/resolve") +async def resolve_thread_permission_request( + thread_id: str, + request_id: str, + payload: ResolvePermissionRequest, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, + app: Annotated[Any, Depends(get_app)] = None, + thread_lock: Annotated[asyncio.Lock | None, Depends(get_thread_lock)] = None, +) -> dict[str, Any]: + async with thread_lock or _NoopAsyncLock(): + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + pending_requests = { + item.get("request_id"): item + for item in agent.get_pending_permission_requests(thread_id) + if isinstance(item, dict) and item.get("request_id") + } + pending_request = pending_requests.get(request_id) + is_ask_user_question = bool(pending_request and pending_request.get("tool_name") == "AskUserQuestion") + answers = _serialize_permission_answers(payload) + if is_ask_user_question and payload.decision == "allow" and not answers: + raise HTTPException(status_code=400, detail="AskUserQuestion answers are required when approving the request") + ok = agent.resolve_permission_request( + request_id, + decision=payload.decision, + message=payload.message, + answers=answers, + annotations=getattr(payload, "annotations", None), + ) + if not ok: + raise HTTPException(status_code=404, detail="Permission request not found") + await agent.agent.apersist_state(thread_id) + if is_ask_user_question and payload.decision == "allow" and answers is not None: + # @@@ask-user-lifecycle - the owner's answer is about to become a + # real follow-up user message. Clear the old request before that + # run starts so checkpoint replay cannot resurrect the popup. + agent.drop_permission_request(request_id) + await agent.agent.apersist_state(thread_id) + + followup: dict[str, Any] | None = None + if is_ask_user_question and payload.decision == "allow" and pending_request is not None and answers is not None: + from backend.web.services.message_routing import route_message_to_brain + + followup = await route_message_to_brain( + app, + thread_id, + _format_ask_user_question_followup( + pending_request, + answers=answers, + annotations=getattr(payload, "annotations", None), + ), + source="owner", + ) + + response = {"ok": True, "thread_id": thread_id, "request_id": request_id} + if followup is not None: + response["followup"] = followup + return response + + +@router.post("/{thread_id}/permissions/rules") +async def add_thread_permission_rule( + thread_id: str, + payload: ThreadPermissionRuleRequest, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, +) -> dict[str, Any]: + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + rule_state = agent.get_thread_permission_rules(thread_id) + if rule_state["managed_only"]: + raise HTTPException(status_code=409, detail="Managed permission rules only; session overrides are disabled") + ok = agent.add_thread_permission_rule( + thread_id, + behavior=payload.behavior, + tool_name=payload.tool_name, + ) + if not ok: + raise HTTPException(status_code=400, detail="Could not add thread permission rule") + await agent.agent.apersist_state(thread_id) + updated = agent.get_thread_permission_rules(thread_id) + return { + "ok": True, + "thread_id": thread_id, + "scope": "session", + "rules": updated["rules"], + "managed_only": updated["managed_only"], + } + + +@router.delete("/{thread_id}/permissions/rules/{behavior}/{tool_name}") +async def delete_thread_permission_rule( + thread_id: str, + behavior: str, + tool_name: str, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, + agent: Annotated[Any, Depends(get_thread_agent)] = None, +) -> dict[str, Any]: + await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + ok = agent.remove_thread_permission_rule( + thread_id, + behavior=behavior, + tool_name=tool_name, + ) + if not ok: + raise HTTPException(status_code=404, detail="Thread permission rule not found") + await agent.agent.apersist_state(thread_id) + updated = agent.get_thread_permission_rules(thread_id) + return { + "ok": True, + "thread_id": thread_id, + "scope": "session", + "rules": updated["rules"], + "managed_only": updated["managed_only"], + } + + @router.get("/{thread_id}/runtime") async def get_thread_runtime( thread_id: str, stream: bool = False, - user_id: Annotated[str, Depends(verify_thread_owner)] = None, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, app: Annotated[Any, Depends(get_app)] = None, ) -> dict[str, Any]: """Get runtime status for a thread.""" @@ -931,17 +1289,12 @@ async def stream_thread_events( app: Annotated[Any, Depends(get_app)] = None, ) -> EventSourceResponse: """Persistent SSE event stream — uses ?token= for auth (EventSource can't set headers).""" - from backend.web.core.dependencies import _DEV_PAYLOAD, _DEV_SKIP_AUTH - - if _DEV_SKIP_AUTH: - sse_user_id = _DEV_PAYLOAD["user_id"] - else: - if not token: - raise HTTPException(401, "Missing token") - try: - sse_user_id = app.state.auth_service.verify_token(token)["user_id"] - except ValueError as e: - raise HTTPException(401, str(e)) + if not token: + raise HTTPException(401, "Missing token") + try: + sse_user_id = app.state.auth_service.verify_token(token)["user_id"] + except ValueError as e: + raise HTTPException(401, str(e)) thread = app.state.thread_repo.get_by_id(thread_id) if not thread: raise HTTPException(404, "Thread not found") @@ -995,7 +1348,7 @@ async def stream_thread_events( @router.post("/{thread_id}/runs/cancel") async def cancel_run( thread_id: str, - user_id: Annotated[str, Depends(verify_thread_owner)] = None, + user_id: Annotated[str | None, Depends(verify_thread_owner)] = None, app: Annotated[Any, Depends(get_app)] = None, ): """Cancel an active run for the given thread.""" @@ -1016,6 +1369,33 @@ def _get_background_runs(app: Any, thread_id: str) -> dict: return getattr(agent, "_background_runs", {}) if agent else {} +def _background_run_type(run: Any) -> str: + return "bash" if run.__class__.__name__ == "_BashBackgroundRun" else "agent" + + +def _serialize_background_run(task_id: str, run: Any, *, include_result: bool) -> dict[str, Any]: + run_type = _background_run_type(run) + result_text = run.get_result() if include_result and run.is_done else None + payload = { + "task_id": task_id, + "task_type": run_type, + "status": "completed" if run.is_done else "running", + "command_line": getattr(run, "command", None) if run_type == "bash" else None, + } + if include_result: + payload["result"] = result_text + payload["text"] = result_text + return payload + payload["description"] = getattr(run, "description", None) + payload["exit_code"] = getattr(getattr(run, "_cmd", None), "exit_code", None) if run_type == "bash" else None + payload["error"] = None + return payload + + +async def _get_display_task_map(app: Any, thread_id: str) -> dict[str, dict[str, Any]]: + return _collect_display_subagent_tasks(await _get_thread_display_entries(app, thread_id)) + + @router.get("/{thread_id}/tasks") async def list_tasks( thread_id: str, @@ -1023,18 +1403,20 @@ async def list_tasks( ) -> list[dict]: """列出线程的所有后台 run(bash + agent)""" runs = _get_background_runs(request.app, thread_id) - result = [] - for task_id, run in runs.items(): - run_type = "bash" if run.__class__.__name__ == "_BashBackgroundRun" else "agent" + result = [_serialize_background_run(task_id, run, include_result=False) for task_id, run in runs.items()] + seen_task_ids = set(runs) + for task_id, task in (await _get_display_task_map(request.app, thread_id)).items(): + if task_id in seen_task_ids: + continue result.append( { - "task_id": task_id, - "task_type": run_type, - "status": "completed" if run.is_done else "running", - "command_line": getattr(run, "command", None) if run_type == "bash" else None, - "description": getattr(run, "description", None), - "exit_code": getattr(getattr(run, "_cmd", None), "exit_code", None) if run_type == "bash" else None, - "error": None, + "task_id": task["task_id"], + "task_type": task["task_type"], + "status": task["status"], + "command_line": task["command_line"], + "description": task["description"], + "exit_code": task["exit_code"], + "error": task["error"], } ) return result @@ -1050,18 +1432,19 @@ async def get_task( runs = _get_background_runs(request.app, thread_id) run = runs.get(task_id) if not run: - raise HTTPException(status_code=404, detail="Task not found") + task = (await _get_display_task_map(request.app, thread_id)).get(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + return { + "task_id": task["task_id"], + "task_type": task["task_type"], + "status": task["status"], + "command_line": task["command_line"], + "result": task["result"], + "text": task["text"], + } - run_type = "bash" if run.__class__.__name__ == "_BashBackgroundRun" else "agent" - result_text = run.get_result() if run.is_done else None - return { - "task_id": task_id, - "task_type": run_type, - "status": "completed" if run.is_done else "running", - "command_line": getattr(run, "command", None) if run_type == "bash" else None, - "result": result_text, - "text": result_text, - } + return _serialize_background_run(task_id, run, include_result=True) @router.post("/{thread_id}/tasks/{task_id}/cancel") @@ -1074,7 +1457,16 @@ async def cancel_task( runs = _get_background_runs(request.app, thread_id) run = runs.get(task_id) if not run: - raise HTTPException(status_code=404, detail="Task not found") + task = (await _get_display_task_map(request.app, thread_id)).get(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + if task["status"] != "running": + raise HTTPException(status_code=400, detail="Task is not running") + thread_task = request.app.state.thread_tasks.get(thread_id) + if thread_task is None or thread_task.done(): + raise HTTPException(status_code=400, detail="Task is not independently cancellable") + thread_task.cancel() + return {"ok": True, "message": "Run cancellation requested", "task_id": task_id} if run.is_done: raise HTTPException(status_code=400, detail="Task is not running") @@ -1112,7 +1504,7 @@ async def _notify_task_cancelled(app: Any, thread_id: str, task_id: str, run: An agent_id=task_id, agent_name=f"cancel-{task_id[:8]}", ) - await emit_fn( + emission = emit_fn( { "event": "task_done", "data": json.dumps( @@ -1125,6 +1517,8 @@ async def _notify_task_cancelled(app: Any, thread_id: str, task_id: str, run: An ), } ) + if asyncio.iscoroutine(emission): + await emission except Exception: logger.warning("Failed to emit task_done for cancelled task %s", task_id, exc_info=True) diff --git a/backend/web/services/agent_pool.py b/backend/web/services/agent_pool.py index 50ecb5dbf..e49b70135 100644 --- a/backend/web/services/agent_pool.py +++ b/backend/web/services/agent_pool.py @@ -1,18 +1,22 @@ """Agent pool management service.""" import asyncio +import logging import os from pathlib import Path from typing import Any from fastapi import FastAPI +from config.user_paths import preferred_existing_user_home_path from core.identity.agent_registry import get_or_create_agent_id from core.runtime.agent import create_leon_agent from sandbox.manager import lookup_sandbox_for_thread from sandbox.thread_context import set_current_thread_id from storage.runtime import build_storage_container +logger = logging.getLogger(__name__) + # Thread lock for config updates _config_update_locks: dict[str, asyncio.Lock] = {} _agent_create_locks: dict[str, asyncio.Lock] = {} @@ -23,9 +27,14 @@ def create_agent_sync( workspace_root: Path | None = None, model_name: str | None = None, agent: str | None = None, + bundle_dir: Path | None = None, + thread_repo: Any = None, + entity_repo: Any = None, + member_repo: Any = None, queue_manager: Any = None, chat_repos: dict | None = None, extra_allowed_paths: list[str] | None = None, + web_app: Any = None, ) -> Any: """Create a LeonAgent with the given sandbox. Runs in a thread.""" storage_container = build_storage_container( @@ -41,10 +50,16 @@ def create_agent_sync( workspace_root=workspace_root or Path.cwd(), sandbox=sandbox_name if sandbox_name != "local" else None, storage_container=storage_container, + permission_resolver_scope="thread", + thread_repo=thread_repo, + entity_repo=entity_repo, + member_repo=member_repo, queue_manager=queue_manager, chat_repos=chat_repos, + web_app=web_app, verbose=True, agent=agent, + bundle_dir=bundle_dir, extra_allowed_paths=extra_allowed_paths, ) @@ -76,11 +91,27 @@ async def get_or_create_agent(app_obj: FastAPI, sandbox_type: str, thread_id: st thread_data = app_obj.state.thread_repo.get_by_id(thread_id) if hasattr(app_obj.state, "thread_repo") else None if sandbox_type == "local": cwd = app_obj.state.thread_cwd.get(thread_id) + cwd_from_live_map = cwd is not None if not cwd and thread_data and thread_data.get("cwd"): cwd = thread_data["cwd"] - app_obj.state.thread_cwd[thread_id] = cwd if cwd: - workspace_root = Path(cwd).resolve() + path = Path(cwd).expanduser() + # @@@fresh-local-cwd-owns-workspace - a cwd chosen in this live backend session is + # the caller contract for local threads; create it instead of silently falling + # back to the repo root. Persisted paths from another host stay advisory. + if cwd_from_live_map: + path.mkdir(parents=True, exist_ok=True) + workspace_root = path.resolve() + app_obj.state.thread_cwd[thread_id] = str(workspace_root) + # @@@host-local-cwd-is-advisory - persisted local thread cwd can come from another + # host (for example a macOS path stored in shared Supabase but replayed inside a + # Linux staging container). Only pin workspace_root when that path exists here. + elif path.exists() and path.is_dir(): + workspace_root = path.resolve() + app_obj.state.thread_cwd[thread_id] = str(workspace_root) + else: + app_obj.state.thread_cwd.pop(thread_id, None) + logger.warning("Ignoring unavailable local cwd for thread %s: %s", thread_id, cwd) # Look up model for this thread (threads table → preferences default) model_name = thread_data.get("model") if thread_data else None @@ -93,6 +124,11 @@ async def get_or_create_agent(app_obj: FastAPI, sandbox_type: str, thread_id: st # @@@agent-vs-member - thread_config.agent stores a member ID (e.g. "__leon__") for display, # NOT an agent type name ("bash", "general", etc.). Never pass it to create_leon_agent. agent_name = agent # explicit caller-provided type only; None → default Leon agent + bundle_dir = None + if thread_data and thread_data.get("member_id"): + member_dir = preferred_existing_user_home_path("members", str(thread_data["member_id"])) + if member_dir.is_dir(): + bundle_dir = member_dir.resolve() # @@@chat-repos - construct chat_repos for ChatToolService if entity system is available chat_repos = None @@ -136,12 +172,24 @@ async def get_or_create_agent(app_obj: FastAPI, sandbox_type: str, thread_id: st except FileNotFoundError: pass - extra_allowed_paths = extra_allowed_paths or None + extra_allowed_paths_or_none: list[str] | None = extra_allowed_paths or None # @@@ agent-init-thread - LeonAgent.__init__ uses run_until_complete, must run in thread qm = getattr(app_obj.state, "queue_manager", None) agent_obj = await asyncio.to_thread( - create_agent_sync, sandbox_type, workspace_root, model_name, agent_name, qm, chat_repos, extra_allowed_paths + create_agent_sync, + sandbox_name=sandbox_type, + workspace_root=workspace_root, + model_name=model_name, + agent=agent_name, + bundle_dir=bundle_dir, + thread_repo=getattr(app_obj.state, "thread_repo", None), + entity_repo=getattr(app_obj.state, "entity_repo", None), + member_repo=getattr(app_obj.state, "member_repo", None), + queue_manager=qm, + chat_repos=chat_repos, + extra_allowed_paths=extra_allowed_paths_or_none, + web_app=app_obj, ) member = agent_name or "leon" agent_id = get_or_create_agent_id( diff --git a/backend/web/services/auth_service.py b/backend/web/services/auth_service.py index 85c9c21c6..9467b7e4a 100644 --- a/backend/web/services/auth_service.py +++ b/backend/web/services/auth_service.py @@ -5,6 +5,7 @@ import logging import os import time +from collections.abc import Callable import jwt @@ -22,12 +23,16 @@ def __init__( accounts: AccountRepo, entities: EntityRepo, supabase_client=None, + supabase_auth_client=None, + supabase_auth_client_factory: Callable[[], object] | None = None, invite_codes: InviteCodeRepo | None = None, ) -> None: self._members = members self._accounts = accounts self._entities = entities - self._sb = supabase_client # None in sqlite-only mode + self._sb = supabase_client # storage/service-role client + self._sb_auth = supabase_auth_client # end-user auth client + self._sb_auth_factory = supabase_auth_client_factory self._invite_codes = invite_codes # ------------------------------------------------------------------ @@ -39,6 +44,7 @@ def __init__( def send_otp(self, email: str, password: str, invite_code: str) -> None: """Validate invite code, create user via signUp (sends confirmation OTP to email).""" + auth_client = self._auth_api(self._require_auth_client()) if self._sb is None: raise RuntimeError("Supabase client required.") if self._invite_codes is None or not self._invite_codes.is_valid(invite_code): @@ -46,7 +52,7 @@ def send_otp(self, email: str, password: str, invite_code: str) -> None: from supabase_auth.errors import AuthApiError try: - self._sb.auth.sign_up({"email": email, "password": password}) + auth_client.sign_up({"email": email, "password": password}) except AuthApiError as e: msg = e.message or "" if "already registered" in msg or "already exists" in msg: @@ -55,12 +61,13 @@ def send_otp(self, email: str, password: str, invite_code: str) -> None: def verify_register_otp(self, email: str, token: str) -> dict: """Verify signup OTP. Returns temp_token to be used in complete_register.""" + auth_client = self._auth_api(self._require_auth_client()) if self._sb is None: raise RuntimeError("Supabase client required.") from supabase_auth.errors import AuthApiError try: - resp = self._sb.auth.verify_otp({"email": email, "token": token, "type": "signup"}) + resp = auth_client.verify_otp({"email": email, "token": token, "type": "signup"}) except AuthApiError as e: raise ValueError(f"验证码错误: {e.message}") from e if resp.user is None or resp.session is None: @@ -129,8 +136,7 @@ def complete_register(self, temp_token: str, invite_code: str) -> dict: def login(self, identifier: str, password: str) -> dict: """Login with email or mycel_id + password.""" - if self._sb is None: - raise RuntimeError("Supabase client required for login. Set LEON_STORAGE_STRATEGY=supabase.") + auth_client = self._auth_api(self._require_auth_client()) # Resolve email email = self._resolve_email(identifier) @@ -139,7 +145,7 @@ def login(self, identifier: str, password: str) -> dict: # Sign in via Supabase try: - resp = self._sb.auth.sign_in_with_password({"email": email, "password": password}) + resp = auth_client.sign_in_with_password({"email": email, "password": password}) except AuthApiError: raise ValueError("邮箱或密码错误") if resp.user is None or resp.session is None: @@ -174,7 +180,17 @@ def login(self, identifier: str, password: str) -> dict: } def verify_token(self, token: str) -> dict: - """Verify Supabase JWT. Returns {user_id}.""" + """Verify Supabase JWT. Returns {user_id, entity_id}.""" + auth_client = self._sb_auth_factory() if self._sb_auth_factory is not None else self._sb_auth + if auth_client is not None: + auth_api = self._auth_api(auth_client) + try: + user_resp = auth_api.get_user(token) + except Exception as e: + raise ValueError(f"Token 无效: {e}") from e + if user_resp is None or getattr(user_resp, "user", None) is None: + raise ValueError("Token 无效: user not found") + return {"user_id": str(user_resp.user.id), "entity_id": None} jwt_secret = os.getenv("SUPABASE_JWT_SECRET") if not jwt_secret: raise RuntimeError("SUPABASE_JWT_SECRET env var required for token verification.") @@ -204,6 +220,16 @@ def _resolve_email(self, identifier: str) -> str: return member.email return identifier.strip() + def _require_auth_client(self): + if self._sb_auth_factory is not None: + return self._sb_auth_factory() + if self._sb_auth is None: + raise RuntimeError("Supabase auth client required. Configure SUPABASE_ANON_KEY for auth runtime.") + return self._sb_auth + + def _auth_api(self, auth_client): + return getattr(auth_client, "auth", auth_client) + def _create_initial_agents(self, owner_user_id: str, now: float) -> dict | None: """Create Toad and Morel agents for a new user. Returns first agent info.""" from pathlib import Path diff --git a/backend/web/services/display_builder.py b/backend/web/services/display_builder.py index 25f034ed5..c6b24bc5f 100644 --- a/backend/web/services/display_builder.py +++ b/backend/web/services/display_builder.py @@ -38,18 +38,46 @@ # Helpers — ported from message-mapper.ts # --------------------------------------------------------------------------- -_CHAT_MESSAGE_RE = re.compile(r"]*>([\s\S]*?)") - - -def _extract_chat_message(text: str) -> str | None: - m = _CHAT_MESSAGE_RE.search(text) - return m.group(1).strip() if m else None +_TASK_NOTIFICATION_RUN_ID_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) +_TASK_NOTIFICATION_STATUS_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) def _make_id(prefix: str = "db") -> str: return f"{prefix}-{uuid.uuid4().hex[:12]}" +def _extract_terminal_task_status(notification_type: str | None, text: str) -> tuple[str | None, str | None]: + if notification_type != "agent" or "" not in text: + return None, None + task_match = _TASK_NOTIFICATION_RUN_ID_RE.search(text) + status_match = _TASK_NOTIFICATION_STATUS_RE.search(text) + task_id = task_match.group(1).strip() if task_match else None + status = status_match.group(1).strip().lower() if status_match else None + return task_id, status + + +def _reconcile_subagent_stream_status( + entries: list[dict], + current_turn: dict | None, + task_id: str, + status: str, +) -> None: + # @@@checkpoint-status-reconcile - idle detail rebuild only sees persisted + # checkpoint messages, not live task_done events. If a later persisted + # terminal notification names the child task, reconcile the earlier Agent + # subagent_stream status so cold rebuild does not regress it back to running. + turns: list[dict] = [] + if current_turn is not None: + turns.append(current_turn) + turns.extend(entry for entry in reversed(entries) if entry.get("role") == "assistant" and entry is not current_turn) + for turn in turns: + for seg in turn.get("segments", []): + stream = seg.get("step", {}).get("subagent_stream") + if seg.get("type") == "tool" and stream and stream.get("task_id") == task_id: + stream["status"] = status + return + + # --------------------------------------------------------------------------- # Entry builders # --------------------------------------------------------------------------- @@ -89,6 +117,23 @@ def _append_to_turn(turn: dict, msg_id: str, segments: list[dict]) -> None: turn.setdefault("messageIds", []).append(msg_id) +def _build_subagent_stream( + *, + task_id: str, + thread_id: str, + description: str | None, + status: str, +) -> dict[str, Any]: + return { + "task_id": task_id, + "thread_id": thread_id, + "description": description, + "text": "", + "tool_calls": [], + "status": status, + } + + # --------------------------------------------------------------------------- # ThreadDisplay — per-thread in-memory state # --------------------------------------------------------------------------- @@ -242,6 +287,9 @@ def _handle_human( if source == "system" or (source == "external" and ntype == "chat"): content = _extract_text_content(msg.get("content")) msg_run_id = meta.get("run_id") or None + task_id, task_status = _extract_terminal_task_status(ntype, content) + if task_id and task_status: + _reconcile_subagent_stream_status(entries, current_turn, task_id, task_status) # Fold into current turn if same run if current_turn and (not msg_run_id or msg_run_id == current_run_id): @@ -332,19 +380,12 @@ def _handle_tool(self, msg: dict, _i: int, current_turn: dict | None, _now: int) seg["step"]["result"] = content_str seg["step"]["status"] = "done" - # Restore subagent_stream from metadata meta = msg.get("metadata") or {} - task_id = meta.get("task_id") - sub_thread = meta.get("subagent_thread_id") or (f"subagent-{task_id}" if task_id else None) - - if not task_id and seg["step"].get("name") == "Agent": - try: - parsed = json.loads(content_str) - if isinstance(parsed, dict) and parsed.get("task_id"): - task_id = parsed["task_id"] - sub_thread = parsed.get("thread_id") or f"subagent-{task_id}" - except (json.JSONDecodeError, TypeError): - pass + task_id, sub_thread, task_status = _extract_subagent_stream_identity( + seg["step"].get("name"), + meta, + content_str, + ) if sub_thread and not seg["step"].get("subagent_stream"): seg["step"]["subagent_stream"] = { @@ -353,7 +394,7 @@ def _handle_tool(self, msg: dict, _i: int, current_turn: dict | None, _now: int) "description": meta.get("description"), "text": "", "tool_calls": [], - "status": "completed", + "status": task_status, } break @@ -502,18 +543,18 @@ def _handle_tool_result(td: ThreadDisplay, data: dict) -> dict | None: seg["step"]["result"] = result seg["step"]["status"] = "done" - # Subagent stream tracking - task_id = metadata.get("task_id") - sub_thread = metadata.get("subagent_thread_id") or (f"subagent-{task_id}" if task_id else None) + task_id, sub_thread, task_status = _extract_subagent_stream_identity( + seg["step"].get("name"), + metadata, + result, + ) if sub_thread and not seg["step"].get("subagent_stream"): - seg["step"]["subagent_stream"] = { - "task_id": task_id or "", - "thread_id": sub_thread, - "description": metadata.get("description"), - "text": "", - "tool_calls": [], - "status": "running", - } + seg["step"]["subagent_stream"] = _build_subagent_stream( + task_id=task_id or "", + thread_id=sub_thread, + description=metadata.get("description"), + status=task_status, + ) return { "type": "update_segment", @@ -526,8 +567,15 @@ def _handle_tool_result(td: ThreadDisplay, data: dict) -> dict | None: def _handle_notice(td: ThreadDisplay, data: dict) -> dict | None: content = data.get("content", "") ntype = data.get("notification_type") + task_id, task_status = _extract_terminal_task_status(ntype, content) turn = _get_current_turn(td) + if task_id and task_status: + # @@@live-notice-status-reconcile - live parent detail stays on the + # in-memory display while the followthrough run is still active, so the + # terminal notice must reconcile the earlier Agent step immediately + # instead of waiting for a later cold rebuild from checkpoint. + _reconcile_subagent_stream_status(td.entries, turn, task_id, task_status) if turn: # Fold into current turn seg = {"type": "notice", "content": content, "notification_type": ntype} @@ -629,22 +677,18 @@ def _handle_task_start(td: ThreadDisplay, data: dict) -> dict | None: task_id = data["task_id"] sub_thread = data.get("thread_id") or f"subagent-{task_id}" - # Find most recent Agent tool call without subagent_stream + # @@@late-task-start-race - background Agent tools can return their + # immediate "started" ToolMessage before the async task_start activity + # reaches the parent thread. Still patch the newest Agent step that + # has no child stream, even if its tool_result already marked it done. for seg in reversed(turn["segments"]): - if ( - seg.get("type") == "tool" - and seg.get("step", {}).get("name") == "Agent" - and seg.get("step", {}).get("status") == "calling" - and not seg.get("step", {}).get("subagent_stream") - ): - seg["step"]["subagent_stream"] = { - "task_id": task_id, - "thread_id": sub_thread, - "description": data.get("description"), - "text": "", - "tool_calls": [], - "status": "running", - } + if seg.get("type") == "tool" and seg.get("step", {}).get("name") == "Agent" and not seg.get("step", {}).get("subagent_stream"): + seg["step"]["subagent_stream"] = _build_subagent_stream( + task_id=task_id, + thread_id=sub_thread, + description=data.get("description"), + status="running", + ) idx = _find_seg_index(turn, seg["step"]["id"]) return { "type": "update_segment", @@ -679,6 +723,28 @@ def _find_seg_index(turn: dict, tc_id: str) -> int: return -1 +def _extract_subagent_stream_identity(step_name: str | None, metadata: dict, content: str) -> tuple[str | None, str | None, str]: + task_id = metadata.get("task_id") + sub_thread = metadata.get("subagent_thread_id") or (f"subagent-{task_id}" if task_id else None) + task_status = "completed" if task_id else "running" + + if task_id or step_name != "Agent": + return task_id, sub_thread, task_status + + try: + parsed = json.loads(content) + except (json.JSONDecodeError, TypeError): + return task_id, sub_thread, task_status + + if not isinstance(parsed, dict) or not parsed.get("task_id"): + return task_id, sub_thread, task_status + + task_id = parsed["task_id"] + sub_thread = parsed.get("thread_id") or f"subagent-{task_id}" + task_status = parsed.get("status") or "running" + return task_id, sub_thread, task_status + + # Event type → handler _EVENT_HANDLERS: dict[str, Any] = { "user_message": _handle_user_message, diff --git a/backend/web/services/idle_reaper.py b/backend/web/services/idle_reaper.py index 90651365a..a739aa9fb 100644 --- a/backend/web/services/idle_reaper.py +++ b/backend/web/services/idle_reaper.py @@ -40,7 +40,7 @@ async def idle_reaper_loop(app_obj: FastAPI) -> None: try: count = await asyncio.to_thread(run_idle_reaper_once, app_obj) if count > 0: - print(f"[idle-reaper] paused+closed {count} expired chat session(s)") + print(f"[idle-reaper] reclaimed+closed {count} expired chat session(s)") except Exception as e: print(f"[idle-reaper] error: {e}") await asyncio.sleep(IDLE_REAPER_INTERVAL_SEC) diff --git a/backend/web/services/message_routing.py b/backend/web/services/message_routing.py index 7984e9552..328b10750 100644 --- a/backend/web/services/message_routing.py +++ b/backend/web/services/message_routing.py @@ -26,6 +26,7 @@ async def route_message_to_brain( ACTIVE → enqueue as steer """ from backend.web.services.agent_pool import get_or_create_agent, resolve_thread_sandbox + from backend.web.services.resource_cache import clear_resource_overview_cache from backend.web.services.streaming_service import start_agent_run sandbox_type = resolve_thread_sandbox(app, thread_id) @@ -74,4 +75,7 @@ async def route_message_to_brain( if attachments: meta["attachments"] = attachments run_id = start_agent_run(agent, thread_id, run_content, app, message_metadata=meta) + # @@@resource-cache-run-start - a fresh run can create or resume a lease immediately. + # Drop the cached resource snapshot so the next Resources read reflects the live topology. + clear_resource_overview_cache() return {"status": "started", "routing": "direct", "run_id": run_id, "thread_id": thread_id} diff --git a/backend/web/services/profile_service.py b/backend/web/services/profile_service.py index c6b755bde..60359431a 100644 --- a/backend/web/services/profile_service.py +++ b/backend/web/services/profile_service.py @@ -1,10 +1,11 @@ -"""Profile CRUD — config.json based.""" +"""Profile CRUD — config.json based, with auth-member override for signed-in shell.""" import json from pathlib import Path from typing import Any from config.user_paths import preferred_existing_user_home_path, user_home_path +from storage.contracts import MemberRow LEON_HOME = user_home_path() CONFIG_PATH = LEON_HOME / "config.json" @@ -24,7 +25,23 @@ def _write_json(path: Path, data: Any) -> None: path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") -def get_profile() -> dict[str, Any]: +def _initials_from_name(name: str) -> str: + stripped = name.strip() + if not stripped: + return "U" + compact = "".join(part[:1] for part in stripped.split() if part) + if len(compact) >= 2: + return compact[:2].upper() + return stripped[:2].upper() + + +def get_profile(member: MemberRow | None = None) -> dict[str, Any]: + if member is not None: + return { + "name": member.name or "用户", + "initials": _initials_from_name(member.name or ""), + "email": member.email or "", + } cfg = _read_json(preferred_existing_user_home_path("config.json"), {}) profile = cfg.get("profile", {}) return { diff --git a/backend/web/services/resource_cache.py b/backend/web/services/resource_cache.py index 4b1d5f5fe..67875b4e8 100644 --- a/backend/web/services/resource_cache.py +++ b/backend/web/services/resource_cache.py @@ -55,6 +55,23 @@ def _with_refresh_metadata( return payload +def _snapshot_drifted_from_live_sessions(snapshot: dict[str, Any]) -> bool: + live_stats = resource_service.visible_resource_session_stats() + for provider in snapshot.get("providers") or []: + provider_id = str(provider.get("id") or "") + current = live_stats.get(provider_id, {"sessions": 0, "running": 0}) + cached_running = int(((provider.get("telemetry") or {}).get("running") or {}).get("used") or 0) + cached_sessions = len(provider.get("sessions") or []) + if cached_running != current["running"] or cached_sessions != current["sessions"]: + return True + for provider_id, current in live_stats.items(): + if current["running"] or current["sessions"]: + cached = next((item for item in snapshot.get("providers") or [] if str(item.get("id") or "") == provider_id), None) + if cached is None: + return True + return False + + def refresh_resource_overview_sync() -> dict[str, Any]: """Refresh cached overview snapshot and return latest payload.""" global _snapshot_cache @@ -84,6 +101,11 @@ def get_resource_overview_snapshot() -> dict[str, Any]: with _snapshot_lock: cached = copy.deepcopy(_snapshot_cache) if cached is not None: + # @@@resource-cache-live-drift - durable session truth lands in sandbox.db after a run + # starts; if the cached Resources snapshot no longer matches visible lease/session + # counts, refresh synchronously instead of serving a stale zero-sandbox card. + if _snapshot_drifted_from_live_sessions(cached): + return refresh_resource_overview_sync() return cached # @@@cold-start-cache-fill - route fallback fills cache once to keep first call deterministic. return refresh_resource_overview_sync() diff --git a/backend/web/services/resource_service.py b/backend/web/services/resource_service.py index 236db63ab..6c0738215 100644 --- a/backend/web/services/resource_service.py +++ b/backend/web/services/resource_service.py @@ -23,6 +23,8 @@ probe_and_upsert_for_instance, ) from storage.models import map_lease_to_session_status +from storage.providers.sqlite.kernel import SQLiteDBRole, resolve_role_db_path +from storage.runtime import build_member_repo, build_thread_repo _CONFIG_LOADER = SandboxConfigLoader(SANDBOXES_DIR) @@ -72,7 +74,8 @@ def _resolve_console_url(provider_name: str, config_name: str, *, sandboxes_dir: if provider_name == "e2b": return "https://e2b.dev" if provider_name == "daytona": - daytona = payload.get("daytona") if isinstance(payload.get("daytona"), dict) else {} + raw_daytona = payload.get("daytona") + daytona = raw_daytona if isinstance(raw_daytona, dict) else {} target = str(daytona.get("target") or "").strip().lower() if target == "cloud": return "https://app.daytona.io" @@ -216,17 +219,13 @@ def _to_session_metrics(snapshot: dict[str, Any] | None) -> dict[str, Any] | Non def _member_meta_map(member_repo: Any = None) -> dict[str, dict[str, str | None]]: """Build member_id → display metadata map from DB.""" + repo = member_repo + own_repo = False + if repo is None: + repo = build_member_repo(main_db_path=resolve_role_db_path(SQLiteDBRole.MAIN)) + own_repo = True try: - if member_repo is not None: - members = member_repo.list_all() - else: - from storage.providers.sqlite.member_repo import SQLiteMemberRepo - - repo = SQLiteMemberRepo() - try: - members = repo.list_all() - finally: - repo.close() + members = repo.list_all() return { m.id: { "member_name": m.name, @@ -237,6 +236,9 @@ def _member_meta_map(member_repo: Any = None) -> dict[str, dict[str, str | None] } except Exception: return {} + finally: + if own_repo: + repo.close() def _thread_agent_refs(thread_ids: list[str], thread_repo: Any = None) -> dict[str, str]: @@ -244,14 +246,11 @@ def _thread_agent_refs(thread_ids: list[str], thread_repo: Any = None) -> dict[s unique = sorted({tid for tid in thread_ids if tid}) if not unique: return {} - if thread_repo is None: - from storage.providers.sqlite.thread_repo import SQLiteThreadRepo - - repo = SQLiteThreadRepo() + repo = thread_repo + own_repo = False + if repo is None: + repo = build_thread_repo(main_db_path=resolve_role_db_path(SQLiteDBRole.MAIN)) own_repo = True - else: - repo = thread_repo - own_repo = False try: refs: dict[str, str] = {} for tid in unique: @@ -350,6 +349,69 @@ def _resolve_card_cpu_metric(provider_type: str, telemetry: dict[str, Any]) -> d return cpu +def _is_resource_visible_thread(thread_id: str | None) -> bool: + raw = str(thread_id or "").strip() + if raw.startswith("subagent-"): + return False + return True + + +def _resource_session_identity(session: dict[str, Any]) -> str: + lease_id = str(session.get("lease_id") or "") + thread_id = str(session.get("thread_id") or "") + if lease_id and thread_id: + # @@@resource-session-contract - resource cards are lease/thread scoped, not chat-session scoped. + # Terminal fallback rows can carry distinct session ids for the same visible lease+thread binding. + return f"{lease_id}:{thread_id}" + session_id = str(session.get("session_id") or "") + if session_id: + return session_id + return f"{lease_id}:{thread_id or 'unbound'}" + + +def _project_user_visible_resource_sessions(repo: Any, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Project raw monitor rows into the user-visible resource surface. + + @@@user-visible-resource-projection - raw monitor rows may be bound to a newer + subagent terminal even though the lease still belongs to a user-visible parent + thread. Keep raw monitor truth in the repo; only the Resources UI gets this + parent-thread preference. + """ + grouped: dict[str, list[dict[str, Any]]] = {} + for row in rows: + lease_id = str(row.get("lease_id") or "") + grouped.setdefault(lease_id, []).append(dict(row)) + + projected: list[dict[str, Any]] = [] + for lease_id, group in grouped.items(): + visible_rows = [row for row in group if _is_resource_visible_thread(row.get("thread_id"))] + if visible_rows: + projected.extend(visible_rows) + continue + + if not lease_id: + continue + + try: + thread_rows = repo.query_lease_threads(lease_id) + except Exception: + thread_rows = [] + + preferred_thread_id = next( + (str(item.get("thread_id") or "").strip() for item in thread_rows if _is_resource_visible_thread(item.get("thread_id"))), + "", + ) + if not preferred_thread_id: + continue + + base = dict(group[0]) + base["thread_id"] = preferred_thread_id + base["session_id"] = None + projected.append(base) + + return projected + + # --------------------------------------------------------------------------- # Public API: resource overview # --------------------------------------------------------------------------- @@ -359,7 +421,8 @@ def list_resource_providers() -> dict[str, Any]: # @@@overview-fast-path - avoid provider-network calls; overview uses DB session snapshot. repo = make_sandbox_monitor_repo() try: - sessions = repo.list_sessions_with_leases() + raw_sessions = repo.list_sessions_with_leases() + sessions = _project_user_visible_resource_sessions(repo, raw_sessions) finally: repo.close() @@ -386,6 +449,7 @@ def list_resource_providers() -> dict[str, Any]: provider_sessions = grouped.get(config_name, []) normalized_sessions: list[dict[str, Any]] = [] + seen_session_ids: set[str] = set() running_count = 0 # @@@running-dedup - lease-driven query may yield multiple rows per lease (one per crew member). # Count each running lease only once. @@ -402,11 +466,18 @@ def list_resource_providers() -> dict[str, Any]: seen_running_leases.add(lease_id) session_metrics = _to_session_metrics(snapshot_by_lease.get(lease_id)) owner = owners.get(thread_id, {"member_id": None, "member_name": "未绑定Agent"}) + session_identity = _resource_session_identity(session) + # @@@resource-session-dedup - terminal fallback can surface multiple + # monitor rows for the same lease/thread binding. The overview + # contract is one session row per stable session identity. + if session_identity in seen_session_ids: + continue + seen_session_ids.add(session_identity) normalized_sessions.append( { # @@@resource-session-identity - monitor rows can legitimately have empty chat session ids. # Use stable lease+thread identity so React keys do not collapse when one lease has multiple threads. - "id": str(session.get("session_id") or f"{lease_id}:{thread_id or 'unbound'}"), + "id": session_identity, "leaseId": lease_id, "threadId": thread_id, "memberId": str(owner.get("member_id") or ""), @@ -469,6 +540,36 @@ def list_resource_providers() -> dict[str, Any]: return {"summary": summary, "providers": providers} +def visible_resource_session_stats() -> dict[str, dict[str, int]]: + """Return the current user-visible session/running counts per provider.""" + repo = make_sandbox_monitor_repo() + try: + raw_sessions = repo.list_sessions_with_leases() + sessions = _project_user_visible_resource_sessions(repo, raw_sessions) + finally: + repo.close() + + stats: dict[str, dict[str, int]] = {} + seen_session_ids: set[str] = set() + seen_running_leases: set[tuple[str, str]] = set() + for session in sessions: + provider_instance = str(session.get("provider") or "local") + provider_stats = stats.setdefault(provider_instance, {"sessions": 0, "running": 0}) + session_identity = _resource_session_identity(session) + if session_identity not in seen_session_ids: + seen_session_ids.add(session_identity) + provider_stats["sessions"] += 1 + + lease_id = str(session.get("lease_id") or "") + normalized = map_lease_to_session_status(session.get("observed_state"), session.get("desired_state")) + running_identity = (provider_instance, lease_id) + if normalized == "running" and lease_id and running_identity not in seen_running_leases: + seen_running_leases.add(running_identity) + provider_stats["running"] += 1 + + return stats + + # --------------------------------------------------------------------------- # Public API: sandbox filesystem browse # --------------------------------------------------------------------------- diff --git a/backend/web/services/sandbox_service.py b/backend/web/services/sandbox_service.py index 2e5e06cf0..d43227225 100644 --- a/backend/web/services/sandbox_service.py +++ b/backend/web/services/sandbox_service.py @@ -77,10 +77,11 @@ def list_user_leases( "cwd": row.get("cwd"), "thread_ids": [], "agents": [], + "_seen_member_ids": set(), }, ) thread_id = str(row.get("thread_id") or "").strip() - if not thread_id or thread_id in group["thread_ids"]: + if not _is_user_visible_lease_thread(thread_id) or thread_id in group["thread_ids"]: continue thread = _thread_repo.get_by_id(thread_id) if thread is None: @@ -89,18 +90,21 @@ def list_user_leases( if member is None or member.owner_user_id != user_id: continue group["thread_ids"].append(thread_id) - group["agents"].append( - { - "member_id": member.id, - "member_name": member.name, - "avatar_url": avatar_url(member.id, bool(member.avatar)), - } - ) + if member.id not in group["_seen_member_ids"]: + group["_seen_member_ids"].add(member.id) + group["agents"].append( + { + "member_id": member.id, + "member_name": member.name, + "avatar_url": avatar_url(member.id, bool(member.avatar)), + } + ) if not group["cwd"] and row.get("cwd"): group["cwd"] = row.get("cwd") leases: list[dict[str, Any]] = [] for lease in grouped.values(): + lease.pop("_seen_member_ids", None) if not lease["thread_ids"]: continue provider_name = lease["provider_name"] @@ -123,6 +127,17 @@ def list_user_leases( monitor_repo.close() +def _is_user_visible_lease_thread(thread_id: str | None) -> bool: + raw = str(thread_id or "").strip() + if not raw: + return False + if raw.startswith("subagent-"): + return False + if is_virtual_thread_id(raw): + return False + return True + + def available_sandbox_types() -> list[dict[str, Any]]: """Scan ~/.leon/sandboxes/ for configured providers.""" providers, _ = init_providers_and_managers() @@ -142,6 +157,16 @@ def available_sandbox_types() -> list[dict[str, Any]]: try: config = SandboxConfig.load(name) provider_obj = providers.get(name) + if provider_obj is None: + types.append( + { + "name": name, + "provider": config.provider, + "available": False, + "reason": f"Provider {name} is configured but unavailable in the current process", + } + ) + continue item: dict[str, Any] = { "name": name, "provider": config.provider, @@ -194,6 +219,8 @@ def _build_providers_and_managers() -> tuple[dict[str, Any], dict[str, Any]]: default_context_path=config.agentbay.context_path, image_id=config.agentbay.image_id, provider_name=name, + supports_pause=config.agentbay.supports_pause, + supports_resume=config.agentbay.supports_resume, ) elif config.provider == "docker": from sandbox.providers.docker import DockerProvider diff --git a/backend/web/services/streaming_service.py b/backend/web/services/streaming_service.py index 9e6e71a77..5992e4ca7 100644 --- a/backend/web/services/streaming_service.py +++ b/backend/web/services/streaming_service.py @@ -13,11 +13,22 @@ from backend.web.services.event_store import cleanup_old_runs from backend.web.utils.serializers import extract_text_content from core.runtime.middleware.monitor import AgentState +from core.runtime.notifications import is_terminal_background_notification from sandbox.thread_context import set_current_run_id, set_current_thread_id from storage.contracts import RunEventRepo logger = logging.getLogger(__name__) +_TERMINAL_FOLLOWTHROUGH_SYSTEM_NOTE = ( + "Terminal background completion notifications require an explicit assistant followthrough. " + "Treat these notifications as fresh inputs that need a visible assistant reply. " + "You must produce at least one visible assistant message for them; " + "do not stay silent and do not end the run after only surfacing a notice. " + "Do not call TaskOutput or TaskStop for a terminal notification. " + "If no further tool is truly needed, answer directly in natural language " + "and briefly acknowledge the completion, failure, or cancellation honestly." +) + def _resolve_run_event_repo(agent: Any) -> RunEventRepo | None: storage_container = getattr(agent, "storage_container", None) @@ -28,6 +39,18 @@ def _resolve_run_event_repo(agent: Any) -> RunEventRepo | None: return storage_container.run_event_repo() +def _augment_system_prompt_for_terminal_followthrough(system_prompt: Any) -> Any: + content = getattr(system_prompt, "content", None) + if not isinstance(content, str): + return system_prompt + if _TERMINAL_FOLLOWTHROUGH_SYSTEM_NOTE in content: + return system_prompt + # @@@terminal-followthrough-system-note - live models can otherwise treat + # terminal background notifications as internal reminders and emit no + # assistant text, leaving caller surfaces notice-only. + return system_prompt.__class__(content=f"{content}\n\n{_TERMINAL_FOLLOWTHROUGH_SYSTEM_NOTE}") + + async def prime_sandbox(agent: Any, thread_id: str) -> None: """Prime sandbox session before tool calls to avoid race conditions.""" @@ -256,8 +279,7 @@ def _ensure_thread_handlers(agent: Any, thread_id: str, app: Any) -> None: runtime = getattr(agent, "runtime", None) if not runtime: return - # Already bound? Skip. - if getattr(runtime, "_activity_sink", None) is not None: + if getattr(runtime, "_bound_thread_id", None) == thread_id and getattr(runtime, "_bound_thread_app", None) is app: return # Runtime must support bind_thread (AgentRuntime does, test fakes may not) if not hasattr(runtime, "bind_thread"): @@ -288,6 +310,7 @@ async def activity_sink(event: dict) -> None: if event_type and isinstance(data, dict): delta = display_builder_ref.apply_event(thread_id, event_type, data) if delta: + delta["_seq"] = seq await thread_buf.put( { "event": "display_delta", @@ -373,6 +396,8 @@ async def _start_run(): agent.runtime.transition(AgentState.IDLE) runtime.bind_thread(activity_sink=activity_sink) + runtime._bound_thread_id = thread_id + runtime._bound_thread_app = app qm.register_wake(thread_id, wake_handler) # Subscribe to EventBus so sub-agent events (spawned via AgentService) @@ -380,11 +405,221 @@ async def _start_run(): try: from backend.web.event_bus import get_event_bus - get_event_bus().subscribe(thread_id, activity_sink) + unsubscribe = getattr(runtime, "_thread_event_unsubscribe", None) + if callable(unsubscribe): + unsubscribe() + runtime._thread_event_unsubscribe = get_event_bus().subscribe(thread_id, activity_sink) except ImportError: pass +def _is_terminal_background_notification_message( + message: str, + *, + source: str | None, + notification_type: str | None, +) -> bool: + return is_terminal_background_notification( + message, + source=source, + notification_type=notification_type, + ) + + +def _partition_terminal_followups(items: list[Any]) -> tuple[list[Any], list[Any]]: + terminal = [] + passthrough = [] + for item in items: + if _is_terminal_background_notification_message( + item.content, + source=item.source or "system", + notification_type=item.notification_type, + ): + terminal.append(item) + else: + passthrough.append(item) + return terminal, passthrough + + +def _message_metadata_dict(message_metadata: dict[str, Any] | None) -> dict[str, Any]: + return dict(message_metadata or {}) + + +def _message_already_persisted(message: Any, *, content: str, metadata: dict[str, Any]) -> bool: + if message.__class__.__name__ != "HumanMessage": + return False + if getattr(message, "content", None) != content: + return False + return (getattr(message, "metadata", None) or {}) == metadata + + +async def _persist_cancelled_run_input_if_missing( + *, + agent: Any, + config: dict[str, Any], + message: str, + message_metadata: dict[str, Any] | None, +) -> None: + graph = getattr(agent, "agent", None) + if graph is None or not hasattr(graph, "aget_state") or not hasattr(graph, "aupdate_state"): + return + + from langchain_core.messages import HumanMessage + + metadata = _message_metadata_dict(message_metadata) + state = await graph.aget_state(config) + persisted = list((getattr(state, "values", None) or {}).get("messages", [])) + if persisted and _message_already_persisted(persisted[-1], content=message, metadata=metadata): + return + + # @@@cancelled-run-input-persist - a started run has already accepted this + # input at the caller boundary. If cancellation lands before the next loop + # checkpoint save, persist the input here so later turns do not pretend it + # never happened. + candidate = HumanMessage(content=message, metadata=metadata) if metadata else HumanMessage(content=message) + await graph.aupdate_state(config, {"messages": [candidate]}) + + +def _is_owner_steer_followup_message( + *, + source: str | None, + notification_type: str | None, +) -> bool: + return source == "owner" and notification_type == "steer" + + +async def _persist_cancelled_owner_steers( + *, + agent: Any, + config: dict[str, Any], + items: list[dict[str, str | None]], +) -> None: + graph = getattr(agent, "agent", None) + if graph is None or not hasattr(graph, "aupdate_state") or not items: + return + + from langchain_core.messages import HumanMessage + + # @@@cancelled-steer-persist - accepted steer is a real user turn. If the + # active run is cancelled before the next model call, we must checkpoint it + # now instead of letting it silently relaunch as a ghost instruction. + await graph.aupdate_state( + config, + { + "messages": [ + HumanMessage( + content=str(item["content"] or ""), + metadata={ + "source": "owner", + "notification_type": "steer", + "is_steer": True, + }, + ) + for item in items + ] + }, + ) + + +async def _flush_cancelled_owner_steers( + *, + agent: Any, + config: dict[str, Any], + thread_id: str, + app: Any, +) -> None: + qm = app.state.queue_manager + queued_items = qm.drain_all(thread_id) + if not queued_items: + return + + owner_steers: list[dict[str, str | None]] = [] + passthrough: list[Any] = [] + for item in queued_items: + if _is_owner_steer_followup_message( + source=item.source, + notification_type=item.notification_type, + ): + owner_steers.append( + { + "content": item.content, + "source": item.source or "owner", + "notification_type": item.notification_type, + } + ) + else: + passthrough.append(item) + + await _persist_cancelled_owner_steers(agent=agent, config=config, items=owner_steers) + + for item in passthrough: + qm.enqueue( + item.content, + thread_id, + notification_type=item.notification_type, + source=item.source, + sender_id=item.sender_id, + sender_name=item.sender_name, + sender_avatar_url=item.sender_avatar_url, + is_steer=item.is_steer, + ) + + +async def _emit_queued_terminal_followups( + *, + app: Any, + thread_id: str, + emit: Any, +) -> list[dict[str, str | None]]: + emitted_terminal: list[dict[str, str | None]] = [] + + async def _drain_once() -> bool: + queued_items = app.state.queue_manager.drain_all(thread_id) + extra_terminal, passthrough = _partition_terminal_followups(queued_items) + for item in passthrough: + app.state.queue_manager.enqueue( + item.content, + thread_id, + notification_type=item.notification_type, + source=item.source, + sender_id=item.sender_id, + sender_name=item.sender_name, + sender_avatar_url=item.sender_avatar_url, + is_steer=item.is_steer, + ) + for item in extra_terminal: + await emit( + { + "event": "notice", + "data": json.dumps( + { + "content": item.content, + "source": item.source or "system", + "notification_type": item.notification_type, + }, + ensure_ascii=False, + ), + } + ) + emitted_terminal.append( + { + "content": item.content, + "source": item.source or "system", + "notification_type": item.notification_type, + } + ) + return bool(extra_terminal) + + # @@@terminal-followup-race-window - multiple background tasks can finish + # while the first notice-only followthrough run is being emitted. Drain once + # for already-persisted notices, yield one loop tick, then drain again so + # same-turn terminal completions are folded into the same stable followthrough. + await _drain_once() + await asyncio.sleep(0) + await _drain_once() + return emitted_terminal + + # --------------------------------------------------------------------------- # Producer: runs agent, writes events to ThreadEventBuffer # --------------------------------------------------------------------------- @@ -399,7 +634,8 @@ async def _run_agent_to_buffer( thread_buf: ThreadEventBuffer, run_id: str, message_metadata: dict[str, Any] | None = None, -) -> None: + input_messages: list[Any] | None = None, +) -> str: """Run agent execution and write all SSE events into *thread_buf*.""" from backend.web.services.event_store import append_event @@ -428,12 +664,16 @@ async def emit(event: dict, message_id: str | None = None) -> None: event = {**event, "data": json.dumps(data, ensure_ascii=False)} await thread_buf.put(event) - # Compute display delta and emit it (no _seq — avoids dedup conflict - # with the raw event that shares the same seq) + # Compute display delta and emit it alongside the raw event. event_type = event.get("event", "") if event_type and isinstance(data, dict): delta = display_builder.apply_event(thread_id, event_type, data) if delta: + # @@@display-delta-source-seq - replay after-filter only knows raw + # event seqs. Carry the source seq onto the derived delta so a + # reconnect after GET /thread can skip stale display_delta + # replays instead of rebuilding the same thread a second time. + delta["_seq"] = seq await thread_buf.put( { "event": "display_delta", @@ -444,6 +684,7 @@ async def emit(event: dict, message_id: str | None = None) -> None: task = None stream_gen = None pending_tool_calls: dict[str, dict] = {} + output_parts: list[str] = [] try: config = {"configurable": {"thread_id": thread_id, "run_id": run_id}} if hasattr(agent, "_current_model_config"): @@ -625,9 +866,10 @@ def on_activity_event(event: dict) -> None: ) # @@@run-notice — emit notice right after run_start so frontend folds it - # into the (re)opened turn. Only for external notifications (not owner steer). + # into the (re)opened turn. Mirror the cold-path DisplayBuilder rule: + # any source=system message is a notice; external notices stay chat-only. ntype = meta.get("notification_type") - if src and src != "owner" and ntype == "chat": + if src == "system" or (src == "external" and ntype == "chat"): await emit( { "event": "notice", @@ -642,7 +884,46 @@ def on_activity_event(event: dict) -> None: } ) - if message_metadata: + terminal_followthrough_items: list[dict[str, str | None]] | None = None + original_system_prompt = None + # @@@terminal-followthrough-reentry - terminal background completions + # still surface as durable notices first, but they must then re-enter the + # model as a real followthrough turn instead of terminating at notice-only. + if _is_terminal_background_notification_message( + message, + source=src, + notification_type=ntype, + ): + terminal_followthrough_items = [ + { + "content": message, + "source": src or "system", + "notification_type": ntype, + } + ] + terminal_followthrough_items.extend(await _emit_queued_terminal_followups(app=app, thread_id=thread_id, emit=emit)) + if hasattr(agent, "agent") and hasattr(agent.agent, "system_prompt"): + original_system_prompt = agent.agent.system_prompt + agent.agent.system_prompt = _augment_system_prompt_for_terminal_followthrough(original_system_prompt) + + if terminal_followthrough_items: + from langchain_core.messages import HumanMessage + + _initial_input = { + "messages": [ + HumanMessage( + content=str(item["content"] or ""), + metadata={ + "source": item["source"] or "system", + "notification_type": item["notification_type"], + }, + ) + for item in terminal_followthrough_items + ] + } + elif input_messages is not None: + _initial_input = {"messages": input_messages} + elif message_metadata: from langchain_core.messages import HumanMessage _initial_input: dict | None = {"messages": [HumanMessage(content=message, metadata=message_metadata)]} @@ -725,7 +1006,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: mode, data = chunk if mode == "messages": - msg_chunk, metadata = data + msg_chunk, _metadata = data msg_class = msg_chunk.__class__.__name__ if msg_class == "AIMessageChunk": # @@@compact-leak-guard — skip chunks from compact's summary LLM call. @@ -735,6 +1016,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: content = extract_text_content(getattr(msg_chunk, "content", "")) chunk_msg_id = getattr(msg_chunk, "id", None) if content: + output_parts.append(content) await emit( { "event": "text", @@ -792,14 +1074,13 @@ def _is_retryable_stream_error(err: Exception) -> bool: msg_class = msg.__class__.__name__ if msg_class == "HumanMessage": - # @@@mid-turn-chat-notice — emit notice for chat - # notifications injected by before_model. display_builder - # folds it into the current turn as a segment (same as - # cold-path checkpoint rebuild behavior). + # @@@mid-turn-notice-parity — hot streaming must use the + # same notice contract as cold checkpoint rebuild: + # source=system always folds as notice; external stays + # limited to chat notifications. meta = getattr(msg, "metadata", None) or {} - if meta.get("notification_type") == "chat" and meta.get("source") in ( - "external", - "system", + if meta.get("source") == "system" or ( + meta.get("source") == "external" and meta.get("notification_type") == "chat" ): await emit( { @@ -808,7 +1089,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: { "content": msg.content if isinstance(msg.content, str) else str(msg.content), "source": meta.get("source", "external"), - "notification_type": "chat", + "notification_type": meta.get("notification_type"), }, ensure_ascii=False, ), @@ -861,8 +1142,11 @@ def _is_retryable_stream_error(err: Exception) -> bool: continue if tc_id: pending_tool_calls.pop(tc_id, None) - if hasattr(msg, "metadata") and isinstance(msg.metadata, dict): - msg.metadata["run_id"] = run_id + merged_meta = dict(getattr(msg, "metadata", None) or {}) + tool_result_meta = getattr(msg, "additional_kwargs", {}).get("tool_result_meta") + if isinstance(tool_result_meta, dict): + merged_meta = {**tool_result_meta, **merged_meta} + merged_meta["run_id"] = run_id tool_name = getattr(msg, "name", "") or "" await emit( { @@ -872,7 +1156,7 @@ def _is_retryable_stream_error(err: Exception) -> bool: "tool_call_id": tc_id, "name": tool_name, "content": str(getattr(msg, "content", "")), - "metadata": getattr(msg, "metadata", None) or {}, + "metadata": merged_meta, "showing": True, }, ensure_ascii=False, @@ -954,8 +1238,21 @@ def _is_retryable_stream_error(err: Exception) -> bool: # A5: emit run_done instead of done (persistent buffer — no mark_done) await emit({"event": "run_done", "data": json.dumps({"thread_id": thread_id, "run_id": run_id})}) + return "".join(output_parts).strip() except asyncio.CancelledError: cancelled_tool_call_ids = await write_cancellation_markers(agent, config, pending_tool_calls) + await _persist_cancelled_run_input_if_missing( + agent=agent, + config=config, + message=message, + message_metadata=message_metadata, + ) + await _flush_cancelled_owner_steers( + agent=agent, + config=config, + thread_id=thread_id, + app=app, + ) await emit( { "event": "cancelled", @@ -969,11 +1266,15 @@ def _is_retryable_stream_error(err: Exception) -> bool: ) # Also emit run_done so frontend knows the run ended await emit({"event": "run_done", "data": json.dumps({"thread_id": thread_id, "run_id": run_id})}) + return "" except Exception as e: traceback.print_exc() await emit({"event": "error", "data": json.dumps({"error": str(e)}, ensure_ascii=False)}) await emit({"event": "run_done", "data": json.dumps({"thread_id": thread_id, "run_id": run_id})}) + return "" finally: + if original_system_prompt is not None and hasattr(agent, "agent") and hasattr(agent.agent, "system_prompt"): + agent.agent.system_prompt = original_system_prompt # @@@typing-lifecycle-stop — guaranteed cleanup even on crash/cancel typing_tracker = getattr(app.state, "typing_tracker", None) if typing_tracker is not None: @@ -1036,22 +1337,29 @@ async def _consume_followup_queue(agent: Any, thread_id: str, app: Any) -> None: item = None try: qm = app.state.queue_manager + if not qm.peek(thread_id) or not app: + return + if not (hasattr(agent, "runtime") and agent.runtime.transition(AgentState.ACTIVE)): + return item = qm.dequeue(thread_id) - if item and app: - if hasattr(agent, "runtime") and agent.runtime.transition(AgentState.ACTIVE): - start_agent_run( - agent, - thread_id, - item.content, - app, - message_metadata={ - "source": item.source or "system", - "notification_type": item.notification_type, - "sender_name": item.sender_name, - "sender_avatar_url": item.sender_avatar_url, - "is_steer": getattr(item, "is_steer", False), - }, - ) + if item is None: + logger.warning("followup dequeue lost race for thread %s; reverting to IDLE", thread_id) + if hasattr(agent, "runtime"): + agent.runtime.transition(AgentState.IDLE) + return + start_agent_run( + agent, + thread_id, + item.content, + app, + message_metadata={ + "source": item.source or "system", + "notification_type": item.notification_type, + "sender_name": item.sender_name, + "sender_avatar_url": item.sender_avatar_url, + "is_steer": getattr(item, "is_steer", False), + }, + ) except Exception: logger.exception("Failed to consume followup queue for thread %s", thread_id) # Re-enqueue the message if it was already dequeued to prevent data loss @@ -1074,18 +1382,90 @@ def start_agent_run( app: Any, enable_trajectory: bool = False, message_metadata: dict[str, Any] | None = None, + input_messages: list[Any] | None = None, ) -> str: """Launch agent producer on the persistent ThreadEventBuffer. Returns run_id.""" thread_buf = get_or_create_thread_buffer(app, thread_id) run_id = str(_uuid.uuid4()) bg_task = asyncio.create_task( - _run_agent_to_buffer(agent, thread_id, message, app, enable_trajectory, thread_buf, run_id, message_metadata) + _run_agent_to_buffer( + agent, + thread_id, + message, + app, + enable_trajectory, + thread_buf, + run_id, + message_metadata, + input_messages, + ) ) # Store the background task so cancel_run can still cancel it app.state.thread_tasks[thread_id] = bg_task return run_id +async def run_child_thread_live( + agent: Any, + thread_id: str, + message: str, + app: Any, + *, + input_messages: list[Any], +) -> str: + """Run a spawned child agent through the normal web thread bridge.""" + from backend.web.services.agent_pool import resolve_thread_sandbox + from backend.web.utils.serializers import extract_text_content + + sandbox_type = resolve_thread_sandbox(app, thread_id) + app.state.agent_pool[f"{thread_id}:{sandbox_type}"] = agent + thread_buf = get_or_create_thread_buffer(app, thread_id) + error_cursor = thread_buf.total_count + _ensure_thread_handlers(agent, thread_id, app) + if not (hasattr(agent, "runtime") and agent.runtime.transition(AgentState.ACTIVE)): + raise RuntimeError(f"Child thread {thread_id} could not transition to active") + + start_agent_run( + agent, + thread_id, + message, + app, + input_messages=input_messages, + ) + task = app.state.thread_tasks[thread_id] + result = await task + recent_events, _ = await thread_buf.read_with_timeout(error_cursor, timeout=0.01) + if recent_events: + # @@@child-live-error-surfacing - child live runs can emit an error event + # and still return an empty string from _run_agent_to_buffer(); treat that + # as a real child failure instead of laundering it into fake completion. + for event in recent_events: + if event.get("event") != "error": + continue + try: + payload = json.loads(event.get("data", "{}")) + except (json.JSONDecodeError, TypeError): + payload = {} + error_text = payload.get("error") if isinstance(payload, dict) else None + raise RuntimeError(error_text or f"Child thread {thread_id} failed") + if isinstance(result, str) and result.strip(): + return result.strip() + + state = await agent.agent.aget_state({"configurable": {"thread_id": thread_id}}) + values = getattr(state, "values", {}) if state else {} + messages = values.get("messages", []) if isinstance(values, dict) else [] + visible_ai = [ + extract_text_content(getattr(msg, "content", "")).strip() + for msg in messages + if msg.__class__.__name__ == "AIMessage" and extract_text_content(getattr(msg, "content", "")).strip() + ] + runtime_status = agent.runtime.get_status_dict() if hasattr(agent, "runtime") and hasattr(agent.runtime, "get_status_dict") else {} + runtime_calls = runtime_status.get("calls") if isinstance(runtime_status, dict) else None + if not visible_ai and runtime_calls == 0: + raise RuntimeError(f"Child thread {thread_id} failed before first model call") + return "\n".join(visible_ai) if visible_ai else "(Agent completed with no text output)" + + # --------------------------------------------------------------------------- # Consumer: persistent thread event stream # --------------------------------------------------------------------------- @@ -1101,40 +1481,12 @@ async def observe_thread_events( disconnect (or server shutdown) closes the connection. run_done is a flow event, not a terminal signal. """ - yield {"retry": 5000} - # Always start from the beginning of the ring buffer. # For after=0 (new connection): replay all buffered events so we never miss # events emitted between postRun and SSE connect (race condition fix). # For after>0 (reconnect): start from ring start, filter by _seq below. - cursor = 0 - - while True: - events, cursor = await thread_buf.read_with_timeout(cursor, timeout=30) - if events is None: - yield {"comment": "keepalive"} - continue - if not events: - continue - for event in events: - parsed_data = None - try: - parsed_data = json.loads(event.get("data", "{}")) - except (json.JSONDecodeError, TypeError): - pass - - # @@@after-filter — skip events already seen on reconnect. - # Events without _seq (e.g. display_delta) are never filtered — - # they are ephemeral derivatives of persisted events. - if after > 0 and isinstance(parsed_data, dict) and "_seq" in parsed_data: - if parsed_data["_seq"] <= after: - continue - - seq_id = str(parsed_data["_seq"]) if isinstance(parsed_data, dict) and "_seq" in parsed_data else None - if seq_id: - yield {**event, "id": seq_id} - else: - yield event + async for event in _observe_sse_buffer(thread_buf, after=after, stop_on_finish=False): + yield event async def observe_run_events( @@ -1142,6 +1494,17 @@ async def observe_run_events( after: int = 0, ) -> AsyncGenerator[dict[str, str], None]: """Consume events from a RunEventBuffer (subagent streams only). Yields SSE event dicts.""" + async for event in _observe_sse_buffer(buf, after=after, stop_on_finish=True): + yield event + + +async def _observe_sse_buffer( + buf: ThreadEventBuffer | RunEventBuffer, + *, + after: int, + stop_on_finish: bool, +) -> AsyncGenerator[dict[str, str], None]: + """Shared SSE observer loop for thread and run buffers.""" yield {"retry": 5000} cursor = 0 @@ -1150,7 +1513,7 @@ async def observe_run_events( if events is None and not buf.finished.is_set(): yield {"comment": "keepalive"} continue - if not events and buf.finished.is_set(): + if stop_on_finish and not events and buf.finished.is_set(): break if not events: continue @@ -1162,8 +1525,8 @@ async def observe_run_events( pass # @@@after-filter — skip events already seen on reconnect. - # Events without _seq (e.g. display_delta) are never filtered — - # they are ephemeral derivatives of persisted events. + # display_delta now carries the source raw-event seq too, so stale + # derived deltas are filtered together with their persisted source. if after > 0 and isinstance(parsed_data, dict) and "_seq" in parsed_data: if parsed_data["_seq"] <= after: continue diff --git a/backend/web/services/task_service.py b/backend/web/services/task_service.py index 86197b584..af041dc03 100644 --- a/backend/web/services/task_service.py +++ b/backend/web/services/task_service.py @@ -3,6 +3,7 @@ from typing import Any from backend.web.core.storage_factory import make_panel_task_repo +from storage.runtime import build_thread_repo def _repo() -> Any: @@ -12,11 +13,35 @@ def _repo() -> Any: def list_tasks() -> list[dict[str, Any]]: repo = _repo() try: - return repo.list_all() + return _enrich_task_thread_members(repo.list_all()) finally: repo.close() +def _enrich_task_thread_members(tasks: list[dict[str, Any]]) -> list[dict[str, Any]]: + thread_ids = [str(task.get("thread_id") or "").strip() for task in tasks] + thread_ids = [thread_id for thread_id in dict.fromkeys(thread_ids) if thread_id] + if not thread_ids: + return tasks + + # @@@task-thread-member-enrichment - panel tasks persist thread_id only, so enrich member_id + # from canonical thread metadata before frontend deep-links are rendered. + thread_repo = build_thread_repo() + try: + member_ids = {thread_id: (thread_repo.get_by_id(thread_id) or {}).get("member_id") for thread_id in thread_ids} + finally: + thread_repo.close() + + enriched: list[dict[str, Any]] = [] + for task in tasks: + thread_id = str(task.get("thread_id") or "").strip() + if thread_id and member_ids.get(thread_id): + enriched.append({**task, "member_id": member_ids[thread_id]}) + else: + enriched.append(task) + return enriched + + def get_task(task_id: str) -> dict[str, Any] | None: repo = _repo() try: diff --git a/backend/web/services/wechat_service.py b/backend/web/services/wechat_service.py deleted file mode 100644 index b19261d79..000000000 --- a/backend/web/services/wechat_service.py +++ /dev/null @@ -1,517 +0,0 @@ -"""WeChat connection service — ilink API client + connection lifecycle + background poll. - -Uses the official WeChat ClawBot ilink API at ilinkai.weixin.qq.com. -Protocol: HTTP/JSON long-polling, modeled after Telegram Bot API. -Auth: Bearer token obtained via QR code scan. - -@@@per-user — each human user_id gets its own WeChatConnection. -user_id is the social identity in Leon's network (Supabase auth UUID for humans). -Polling auto-starts at backend boot via lifespan.py for all users with saved credentials. - -@@@no-globals — WeChatConnectionRegistry lives on app.state, not module-level. -""" - -import asyncio -import json -import logging -import os -import random -import struct -import time -from base64 import b64encode -from collections.abc import Awaitable, Callable -from pathlib import Path -from typing import Literal - -import httpx -from pydantic import BaseModel - -from config.user_paths import user_home_path, user_home_read_candidates - -logger = logging.getLogger(__name__) - -DEFAULT_BASE_URL = "https://ilinkai.weixin.qq.com" -BOT_TYPE = "3" -CHANNEL_VERSION = "0.1.0" -LONG_POLL_TIMEOUT_S = 35 -SEND_TIMEOUT_S = 15 - -MSG_TYPE_USER = 1 -MSG_TYPE_BOT = 2 -MSG_ITEM_TEXT = 1 -MSG_ITEM_VOICE = 3 -MSG_STATE_FINISH = 2 - -CONNECTIONS_BASE = user_home_path("connections", "wechat") - -RoutingType = Literal["thread", "chat"] - -# @@@delivery-callback — injected at construction, avoids circular import of app -DeliveryFn = Callable[["WeChatConnection", "WeChatMessage"], Awaitable[None]] - - -# --- Pydantic models for API --- - - -class WeChatCredentials(BaseModel): - token: str - base_url: str = DEFAULT_BASE_URL - account_id: str - user_id: str = "" - saved_at: str = "" - - -class RoutingConfig(BaseModel): - type: RoutingType | None = None - id: str | None = None - label: str = "" - - -class QrPollRequest(BaseModel): - qrcode: str - - -class RoutingSetRequest(BaseModel): - type: RoutingType - id: str - label: str = "" - - -class WeChatMessage(BaseModel): - from_user_id: str - text: str - context_token: str - - class Config: - frozen = True - - -class WeChatAPIError(Exception): - pass - - -class SessionExpiredError(WeChatAPIError): - pass - - -# --- ilink protocol helpers --- - - -def _random_wechat_uin() -> str: - val = struct.unpack(">I", os.urandom(4))[0] - return b64encode(str(val).encode()).decode() - - -def _build_headers(token: str | None = None, body: str | None = None) -> dict[str, str]: - headers: dict[str, str] = { - "Content-Type": "application/json", - "AuthorizationType": "ilink_bot_token", - "X-WECHAT-UIN": _random_wechat_uin(), - } - if body: - headers["Content-Length"] = str(len(body.encode())) - if token: - headers["Authorization"] = f"Bearer {token.strip()}" - return headers - - -def _extract_text(msg: dict) -> str: - items = msg.get("item_list") or [] - for item in items: - if item.get("type") == MSG_ITEM_TEXT: - text = (item.get("text_item") or {}).get("text", "") - ref = item.get("ref_msg") - if ref and ref.get("title"): - return f"[引用: {ref['title']}]\n{text}" - return text - if item.get("type") == MSG_ITEM_VOICE: - return (item.get("voice_item") or {}).get("text", "") - return "" - - -# --- Per-user persistence (keyed by user_id) --- - - -def _user_dir(user_id: str) -> Path: - return CONNECTIONS_BASE / user_id - - -def _user_dir_candidates(user_id: str) -> tuple[Path, ...]: - return tuple(path / user_id for path in user_home_read_candidates("connections", "wechat")) - - -def _save_json(user_id: str, filename: str, data: dict) -> None: - d = _user_dir(user_id) - d.mkdir(parents=True, exist_ok=True) - path = d / filename - path.write_text(json.dumps(data, indent=2)) - if filename == "credentials.json": - path.chmod(0o600) - - -def _load_json(user_id: str, filename: str) -> dict | None: - for path in reversed(_user_dir_candidates(user_id)): - candidate = path / filename - if not candidate.exists(): - continue - try: - return json.loads(candidate.read_text()) - except (json.JSONDecodeError, KeyError) as e: - logger.error("Failed to load %s for %s: %s", filename, user_id[:12], e) - return None - - -def _delete_file(user_id: str, filename: str) -> None: - seen: set[Path] = set() - for user_dir in _user_dir_candidates(user_id): - path = user_dir / filename - if path in seen: - continue - seen.add(path) - if path.exists(): - path.unlink() - - -def migrate_entity_id_dirs() -> None: - """Startup migration: rename {user_id}-1/ → {user_id}/ for existing connections.""" - if not CONNECTIONS_BASE.exists(): - return - for user_dir in list(CONNECTIONS_BASE.iterdir()): - if not user_dir.is_dir(): - continue - name = user_dir.name - # Old entity_id format was "{user_id}-1" — strip the suffix - if name.endswith("-1"): - new_name = name[:-2] - new_dir = CONNECTIONS_BASE / new_name - if not new_dir.exists(): - try: - user_dir.rename(new_dir) - logger.info("Migrated WeChat dir: %s → %s", name, new_name) - except Exception as e: - logger.error("Failed to migrate WeChat dir %s: %s", name, e) - - -# --- WeChatConnection (one per human user) --- - - -class WeChatConnection: - """A single user's WeChat connection. Keyed by user_id.""" - - def __init__(self, user_id: str, delivery_fn: DeliveryFn | None = None) -> None: - self.user_id = user_id - self._delivery_fn = delivery_fn - self._credentials: WeChatCredentials | None = None - self._context_tokens: dict[str, str] = {} - self._sync_buf: str = "" - self._poll_task: asyncio.Task | None = None - self._routing = RoutingConfig() - # @@@no-proxy — trust_env=False prevents httpx from inheriting - # http_proxy/all_proxy which causes bimodal latency on long-poll. - self._http = httpx.AsyncClient( - timeout=httpx.Timeout(LONG_POLL_TIMEOUT_S + 5), - trust_env=False, - ) - - # Load persisted state - routing_data = _load_json(user_id, "routing.json") - if routing_data: - try: - self._routing = RoutingConfig(**routing_data) - except Exception: - pass - - ctx = _load_json(user_id, "context_tokens.json") - if ctx: - self._context_tokens = ctx - - creds_data = _load_json(user_id, "credentials.json") - if creds_data: - try: - self._credentials = WeChatCredentials(**creds_data) - logger.info("Loaded WeChat credentials for user=%s", user_id[:12]) - except Exception as e: - logger.error("Invalid WeChat credentials for %s: %s", user_id[:12], e) - - @property - def connected(self) -> bool: - return self._credentials is not None - - @property - def polling(self) -> bool: - return self._poll_task is not None and not self._poll_task.done() - - @property - def routing(self) -> RoutingConfig: - return self._routing - - def set_routing(self, config: RoutingConfig) -> None: - self._routing = config - _save_json(self.user_id, "routing.json", config.model_dump()) - - def get_state(self) -> dict: - if not self._credentials: - return {"connected": False, "routing": self._routing.model_dump()} - return { - "connected": True, - "polling": self.polling, - "account_id": self._credentials.account_id, - "user_id": self._credentials.user_id, - "contact_count": len(self._context_tokens), - "contacts": self.list_contacts(), - "routing": self._routing.model_dump(), - } - - def list_contacts(self) -> list[dict[str, str]]: - return [{"user_id": uid, "display_name": uid.split("@")[0] or uid} for uid in self._context_tokens] - - # --- QR Login --- - - async def get_qr_code(self) -> dict: - url = f"{DEFAULT_BASE_URL}/ilink/bot/get_bot_qrcode?bot_type={BOT_TYPE}" - resp = await self._http.get(url, timeout=10) - resp.raise_for_status() - data = resp.json() - return {"qrcode": data["qrcode"], "qrcode_img_url": data["qrcode_img_content"]} - - async def poll_qr_status(self, qrcode: str) -> dict: - url = f"{DEFAULT_BASE_URL}/ilink/bot/get_qrcode_status?qrcode={qrcode}" - try: - resp = await self._http.get( - url, - headers={"iLink-App-ClientVersion": "1"}, - timeout=LONG_POLL_TIMEOUT_S + 5, - ) - resp.raise_for_status() - data = resp.json() - except httpx.TimeoutException: - return {"status": "wait"} - - status = data.get("status", "wait") - if status == "confirmed": - bot_token = data.get("bot_token") - bot_id = data.get("ilink_bot_id") - if not bot_token or not bot_id: - return {"status": "error", "message": "Missing bot credentials in response"} - creds = WeChatCredentials( - token=bot_token, - base_url=data.get("baseurl") or DEFAULT_BASE_URL, - account_id=bot_id, - user_id=data.get("ilink_user_id", ""), - saved_at=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - ) - self._credentials = creds - _save_json(self.user_id, "credentials.json", creds.model_dump()) - logger.info("WeChat connected for user=%s account=%s", self.user_id[:12], creds.account_id) - self.start_polling() - return {"status": "confirmed", "account_id": creds.account_id} - return {"status": status} - - # --- Disconnect --- - - def disconnect(self) -> None: - self.stop_polling() - self._credentials = None - self._context_tokens.clear() - self._sync_buf = "" - _delete_file(self.user_id, "credentials.json") - _delete_file(self.user_id, "context_tokens.json") - logger.info("WeChat disconnected for user=%s", self.user_id[:12]) - - async def close(self) -> None: - """Shutdown: stop polling + close HTTP client.""" - self.stop_polling() - await self._http.aclose() - - # --- Polling --- - - def start_polling(self) -> None: - if self.polling: - return - if not self._credentials: - raise RuntimeError("Cannot start polling: not connected") - self._poll_task = asyncio.create_task(self._poll_loop()) - logger.info("WeChat polling started for user=%s", self.user_id[:12]) - - def stop_polling(self) -> None: - if self._poll_task and not self._poll_task.done(): - self._poll_task.cancel() - self._poll_task = None - - async def _deliver_message(self, msg: WeChatMessage) -> None: - """Deliver via injected callback. No circular imports.""" - if not self._delivery_fn: - logger.warning("No delivery function configured for user=%s", self.user_id[:12]) - return - if not self._routing.type or not self._routing.id: - logger.debug("WeChat message not delivered — no routing configured") - return - try: - await self._delivery_fn(self, msg) - except Exception: - logger.exception("Failed to deliver WeChat message") - - async def _poll_loop(self) -> None: - consecutive_failures = 0 - while True: - try: - messages = await self._get_updates() - consecutive_failures = 0 - for msg in messages: - logger.info("WeChat[%s] from=%s: %s", self.user_id[:8], msg.from_user_id[:20], msg.text[:60]) - asyncio.create_task(self._deliver_message(msg)) - except asyncio.CancelledError: - return - except SessionExpiredError: - logger.error("WeChat session expired for user=%s", self.user_id[:12]) - self._credentials = None - _delete_file(self.user_id, "credentials.json") - return - except Exception: - consecutive_failures += 1 - logger.exception("WeChat poll error #%d user=%s", consecutive_failures, self.user_id[:12]) - if consecutive_failures >= 3: - consecutive_failures = 0 - await asyncio.sleep(30) - else: - await asyncio.sleep(2) - - async def _get_updates(self) -> list[WeChatMessage]: - if not self._credentials: - raise RuntimeError("Not connected") - body = json.dumps( - { - "get_updates_buf": self._sync_buf, - "base_info": {"channel_version": CHANNEL_VERSION}, - } - ) - headers = _build_headers(self._credentials.token, body) - try: - resp = await self._http.post( - f"{self._credentials.base_url}/ilink/bot/getupdates", - content=body, - headers=headers, - timeout=LONG_POLL_TIMEOUT_S + 5, - ) - resp.raise_for_status() - data = resp.json() - except httpx.TimeoutException: - return [] - - if data.get("ret", 0) != 0 or data.get("errcode", 0) != 0: - errcode = data.get("errcode", 0) - errmsg = data.get("errmsg", "") - if errcode == -14: - raise SessionExpiredError("Session expired") - raise WeChatAPIError(f"getUpdates: errcode={errcode} {errmsg}") - - if data.get("get_updates_buf"): - self._sync_buf = data["get_updates_buf"] - - messages = [] - tokens_changed = False - for msg in data.get("msgs") or []: - if msg.get("message_type") != MSG_TYPE_USER: - continue - text = _extract_text(msg) - if not text: - continue - sender = msg.get("from_user_id", "unknown") - ctx_token = msg.get("context_token", "") - if ctx_token: - self._context_tokens[sender] = ctx_token - tokens_changed = True - messages.append( - WeChatMessage( - from_user_id=sender, - text=text, - context_token=ctx_token, - ) - ) - if tokens_changed: - await asyncio.to_thread(_save_json, self.user_id, "context_tokens.json", self._context_tokens) - return messages - - # --- Send --- - - async def send_message(self, to_user_id: str, text: str) -> str: - if not self._credentials: - raise RuntimeError("WeChat not connected") - context_token = self._context_tokens.get(to_user_id) - if not context_token: - raise RuntimeError(f"No context_token for {to_user_id}. The user needs to message the bot first.") - client_id = f"leon:{int(time.time())}-{random.randint(0, 0xFFFF):04x}" - body = json.dumps( - { - "msg": { - "from_user_id": "", - "to_user_id": to_user_id, - "client_id": client_id, - "message_type": MSG_TYPE_BOT, - "message_state": MSG_STATE_FINISH, - "item_list": [{"type": MSG_ITEM_TEXT, "text_item": {"text": text}}], - "context_token": context_token, - }, - "base_info": {"channel_version": CHANNEL_VERSION}, - } - ) - headers = _build_headers(self._credentials.token, body) - resp = await self._http.post( - f"{self._credentials.base_url}/ilink/bot/sendmessage", - content=body, - headers=headers, - timeout=SEND_TIMEOUT_S, - ) - resp.raise_for_status() - return client_id - - -# --- WeChatConnectionRegistry (lives on app.state) --- - - -class WeChatConnectionRegistry: - """Manages per-user WeChatConnections. Lives on app.state, not module-level.""" - - def __init__(self, delivery_fn: DeliveryFn | None = None) -> None: - self._connections: dict[str, WeChatConnection] = {} - self._delivery_fn = delivery_fn - - def get(self, user_id: str) -> WeChatConnection: - if user_id not in self._connections: - self._connections[user_id] = WeChatConnection(user_id, self._delivery_fn) - return self._connections[user_id] - - def auto_start_all(self) -> None: - """Resume polling for all users with saved credentials on disk.""" - if not CONNECTIONS_BASE.exists(): - return - for user_dir in CONNECTIONS_BASE.iterdir(): - if user_dir.is_dir() and (user_dir / "credentials.json").exists(): - conn = self.get(user_dir.name) - if conn.connected and not conn.polling: - conn.start_polling() - - def evict_duplicates(self, account_id: str, keep_user_id: str) -> None: - """@@@unique-wechat — one WeChat account → one Leon user. Last one wins.""" - for uid, conn in list(self._connections.items()): - if uid == keep_user_id: - continue - if conn._credentials and conn._credentials.account_id == account_id: - logger.info("Evicting WeChat: user=%s (same account=%s)", uid[:12], account_id[:12]) - conn.disconnect() - - if CONNECTIONS_BASE.exists(): - for user_dir in CONNECTIONS_BASE.iterdir(): - if not user_dir.is_dir() or user_dir.name == keep_user_id: - continue - data = _load_json(user_dir.name, "credentials.json") - if data and data.get("account_id") == account_id: - logger.info("Evicting persisted WeChat: user=%s", user_dir.name[:12]) - _delete_file(user_dir.name, "credentials.json") - _delete_file(user_dir.name, "context_tokens.json") - - async def shutdown(self) -> None: - """Close all connections gracefully.""" - for conn in self._connections.values(): - await conn.close() - self._connections.clear() diff --git a/backend/web/utils/serializers.py b/backend/web/utils/serializers.py index 4c070f285..abeb8a856 100644 --- a/backend/web/utils/serializers.py +++ b/backend/web/utils/serializers.py @@ -38,7 +38,15 @@ def extract_text_content(raw_content: Any) -> str: def serialize_message(msg: Any) -> dict[str, Any]: """Serialize a LangChain message to a JSON-compatible dict.""" content = getattr(msg, "content", "") - metadata = getattr(msg, "metadata", None) or {} + metadata = dict(getattr(msg, "metadata", None) or {}) + additional_kwargs = getattr(msg, "additional_kwargs", None) or {} + tool_result_meta = additional_kwargs.get("tool_result_meta") + # @@@tool-result-meta-bridge - LangChain ToolMessage keeps durable tool + # metadata in additional_kwargs, but Leon display rebuild consumes + # serialized metadata. Merge the exact structured tool_result_meta here so + # checkpoint rebuild can recover blocking subagent identity honestly. + if isinstance(tool_result_meta, dict): + metadata = {**tool_result_meta, **metadata} # Strip system tags from owner HumanMessages (context-shift hints). # External HumanMessages keep their so frontend can diff --git a/config/defaults/tool_catalog.py b/config/defaults/tool_catalog.py index 294293874..1c2e67d2e 100644 --- a/config/defaults/tool_catalog.py +++ b/config/defaults/tool_catalog.py @@ -21,7 +21,9 @@ class ToolGroup(StrEnum): COMMAND = "command" WEB = "web" AGENT = "agent" + CHAT = "chat" TODO = "todo" + CRON = "cron" SKILLS = "skills" SYSTEM = "system" TASKBOARD = "taskboard" @@ -62,16 +64,26 @@ class ToolDef(BaseModel): ToolDef(name="TaskOutput", desc="获取后台任务输出", group=ToolGroup.AGENT), ToolDef(name="TaskStop", desc="停止后台任务", group=ToolGroup.AGENT), ToolDef(name="Agent", desc="启动子 Agent 执行任务", group=ToolGroup.AGENT), - ToolDef(name="SendMessage", desc="向其他 Agent 发送消息", group=ToolGroup.AGENT), + ToolDef(name="SendMessage", desc="向运行中的 Agent 发送排队消息", group=ToolGroup.AGENT), + # chat + ToolDef(name="list_chats", desc="列出当前实体可访问的聊天会话", group=ToolGroup.CHAT), + ToolDef(name="read_messages", desc="读取聊天消息并标记为已读", group=ToolGroup.CHAT), + ToolDef(name="send_message", desc="向聊天对象发送消息", group=ToolGroup.CHAT), + ToolDef(name="search_messages", desc="搜索历史聊天消息", group=ToolGroup.CHAT), # todo ToolDef(name="TaskCreate", desc="创建待办任务", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), ToolDef(name="TaskGet", desc="获取任务详情", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), ToolDef(name="TaskList", desc="列出所有任务", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), ToolDef(name="TaskUpdate", desc="更新任务状态", group=ToolGroup.TODO, mode=ToolMode.DEFERRED), + # cron — backed by existing cron_jobs substrate; off by default until explicitly enabled + ToolDef(name="CronCreate", desc="创建定时任务", group=ToolGroup.CRON, mode=ToolMode.DEFERRED, default=False), + ToolDef(name="CronDelete", desc="删除定时任务", group=ToolGroup.CRON, mode=ToolMode.DEFERRED, default=False), + ToolDef(name="CronList", desc="列出定时任务", group=ToolGroup.CRON, mode=ToolMode.DEFERRED, default=False), # skills ToolDef(name="load_skill", desc="加载 Skill", group=ToolGroup.SKILLS), # system ToolDef(name="tool_search", desc="搜索可用工具", group=ToolGroup.SYSTEM), + ToolDef(name="LSP", desc="Language Server Protocol 操作", group=ToolGroup.SYSTEM, mode=ToolMode.DEFERRED, default=False), # taskboard — all off by default; enable on dedicated scheduler members ToolDef(name="ListBoardTasks", desc="列出任务板上的任务", group=ToolGroup.TASKBOARD, default=False), ToolDef(name="ClaimTask", desc="认领一个任务板任务", group=ToolGroup.TASKBOARD, default=False), diff --git a/config/loader.py b/config/loader.py index 7b2f3190c..7dccb1c00 100644 --- a/config/loader.py +++ b/config/loader.py @@ -153,7 +153,7 @@ def _load_agents_from_members(self, members_dir: Path) -> None: continue config = self.parse_agent_file(agent_md) if config: - # source_dir is already set to member_dir by parse_agent_file + config.source_dir = member_dir.resolve() self._agents[config.name] = config @staticmethod @@ -184,7 +184,7 @@ def parse_agent_file(path: Path) -> AgentConfig | None: tools=fm.get("tools", ["*"]), system_prompt=parts[2].strip(), model=fm.get("model"), - source_dir=path.resolve().parent, + source_dir=None, ) def get_agent(self, name: str) -> AgentConfig | None: diff --git a/config/schema.py b/config/schema.py index 53a0cc8ea..62ba9f7df 100644 --- a/config/schema.py +++ b/config/schema.py @@ -215,6 +215,10 @@ class ToolsConfig(BaseModel): class MCPServerConfig(BaseModel): """Configuration for a single MCP server.""" + transport: str | None = Field( + None, + description="MCP transport type: stdio | streamable_http | sse | websocket", + ) command: str | None = Field(None, description="Command to run the MCP server") args: list[str] = Field(default_factory=list, description="Command arguments") env: dict[str, str] = Field(default_factory=dict, description="Environment variables") diff --git a/config/types.py b/config/types.py index 9731d5aff..0c49458fd 100644 --- a/config/types.py +++ b/config/types.py @@ -20,10 +20,12 @@ class AgentConfig(BaseModel): class McpServerConfig(BaseModel): """Single MCP server entry from .mcp.json.""" + transport: str | None = None command: str | None = None args: list[str] = Field(default_factory=list) env: dict[str, str] = Field(default_factory=dict) url: str | None = None + instructions: str | None = None allowed_tools: list[str] | None = None disabled: bool = False diff --git a/core/agents/communication/chat_tool_service.py b/core/agents/communication/chat_tool_service.py index f5464abb4..66078d7f6 100644 --- a/core/agents/communication/chat_tool_service.py +++ b/core/agents/communication/chat_tool_service.py @@ -1,4 +1,4 @@ -"""Chat tool service — 7 tools for entity-to-entity communication. +"""Chat tool service — Mycel-native tools for entity-to-entity communication. Tools use user_ids as parameters (human = Supabase auth UUID, agent = member_id). Two users share at most one chat; the system auto-resolves user_id → chat. @@ -12,11 +12,11 @@ from datetime import UTC, datetime from typing import Any -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema logger = logging.getLogger(__name__) -# @@@range-parser — parse range strings for chat_read history queries. +# @@@range-parser — parse range strings for read_messages history queries. # Supports: negative index (-10:-1), relative time (-2h:, -1d:-6h), ISO dates (2026-03-20:2026-03-22). _RELATIVE_RE = re.compile(r"^-(\d+)([hdm])$") @@ -89,7 +89,7 @@ def _parse_time_endpoint(s: str, now: float) -> float | None: class ChatToolService: - """Registers 5 chat tools into ToolRegistry. + """Registers the chat tool surface into ToolRegistry. Each tool closure captures user_id (the calling agent's social identity = member_id). """ @@ -120,19 +120,44 @@ def __init__( self._register(registry) def _register(self, registry: ToolRegistry) -> None: - self._register_chats(registry) - self._register_chat_read(registry) - self._register_chat_send(registry) - self._register_chat_search(registry) - self._register_directory(registry) + self._register_list_chats(registry) + self._register_read_messages(registry) + self._register_send_message(registry) + self._register_search_messages(registry) + + def _latest_notified_chat_id(self, request: Any) -> str | None: + state = getattr(request, "state", None) + messages = getattr(state, "messages", None) + if not isinstance(messages, list): + return None + for message in reversed(messages): + metadata = getattr(message, "metadata", None) or {} + if metadata.get("source") != "external" or metadata.get("notification_type") != "chat": + continue + content = getattr(message, "content", "") + text = content if isinstance(content, str) else str(content) + match = re.search(r'read_messages\(chat_id="([^"]+)"\)', text) + if match: + return match.group(1) + return None + + def _fill_missing_chat_target(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + if args.get("user_id"): + return args + if isinstance(args.get("chat_id"), str) and args["chat_id"].strip(): + return args + notified_chat_id = self._latest_notified_chat_id(request) + if notified_chat_id: + return {**args, "chat_id": notified_chat_id} + return args def _resolve_name(self, user_id: str) -> str: """Resolve display name: entity_repo (agents) → member_repo (humans).""" - e = self._entities.get_by_id(user_id) - if e: - return e.name - m = self._members.get_by_id(user_id) if self._members else None - return m.name if m else "unknown" + entity = self._entities.get_by_id(user_id) if self._entities else None + if entity: + return entity.name + member = self._members.get_by_id(user_id) if self._members else None + return member.name if member else "unknown" def _format_msgs(self, msgs: list, eid: str) -> str: lines = [] @@ -159,309 +184,262 @@ def _fetch_by_range(self, chat_id: str, parsed: dict) -> list: before=parsed["before"], ) - def _register_chats(self, registry: ToolRegistry) -> None: + def _handle_list_chats(self, unread_only: bool = False, limit: int = 20) -> str: eid = self._user_id + chats = self._chat_service.list_chats_for_user(eid) + if unread_only: + chats = [c for c in chats if c.get("unread_count", 0) > 0] + chats = chats[:limit] + if not chats: + return "No chats found." + lines = [] + for c in chats: + others = [e for e in c.get("entities", []) if e["id"] != eid] + name = ", ".join(e["name"] for e in others) or "Unknown" + unread = c.get("unread_count", 0) + last = c.get("last_message") + last_preview = f' — last: "{last["content"][:50]}"' if last else "" + unread_str = f" ({unread} unread)" if unread > 0 else "" + is_group = len(others) >= 2 + if is_group: + id_str = f" [chat_id: {c['id']}]" + else: + other_id = others[0]["id"] if others else "" + id_str = f" [user_id: {other_id}]" if other_id else "" + lines.append(f"- {name}{id_str}{unread_str}{last_preview}") + return "\n".join(lines) + + def _handle_read_messages(self, user_id: str | None = None, chat_id: str | None = None, range: str | None = None) -> str: + eid = self._user_id + if chat_id: + pass # use chat_id directly + elif user_id: + chat_id = self._chat_entities.find_chat_between(eid, user_id) + if not chat_id: + name = self._resolve_name(user_id) + return f"No chat history with {name}." + else: + return "Provide user_id or chat_id." + + # @@@range-dispatch — if range is provided, use it regardless of unread state. + if range: + try: + parsed = _parse_range(range) + except ValueError as e: + return str(e) + msgs = self._fetch_by_range(chat_id, parsed) + if not msgs: + return "No messages in that range." + # @@@range-marks-read — WORKAROUND: unblock send_message by pushing + # last_read_at to now. This marks ALL messages as read, not just + # the requested range. Proper fix needs per-message read tracking + # instead of the current single-timestamp waterline model. + self._chat_entities.update_last_read(chat_id, eid, time.time()) + return self._format_msgs(msgs, eid) + + # @@@read-unread-only — default to unread messages only. + msgs = self._messages.list_unread(chat_id, eid) + if msgs: + self._chat_entities.update_last_read(chat_id, eid, time.time()) + return self._format_msgs(msgs, eid) + + # Nothing unread — prompt agent to use range parameter + return ( + "No unread messages. To read history, call again with range:\n" + " range='-10:-1' (last 10 messages)\n" + " range='-5:' (last 5 messages)\n" + " range='-1h:' (last hour)\n" + " range='-2d:-1d' (yesterday)\n" + " range='2026-03-20:2026-03-22' (date range)" + ) + + def _handle_send_message( + self, + content: str, + user_id: str | None = None, + chat_id: str | None = None, + signal: str = "open", + mentions: list[str] | None = None, + ) -> str: + eid = self._user_id + # @@@read-before-write — resolve chat_id, then check unread + resolved_chat_id = chat_id + target_name = "chat" + + if chat_id: + if not self._chat_entities.is_participant_in_chat(chat_id, eid): + raise RuntimeError(f"You are not a member of chat {chat_id}") + elif user_id: + if user_id == eid: + raise RuntimeError("Cannot send a message to yourself.") + target_name = self._resolve_name(user_id) + resolved_chat_id = self._chat_entities.find_chat_between(eid, user_id) + if not resolved_chat_id: + # New chat — no unread possible, create and send + chat = self._chat_service.find_or_create_chat([eid, user_id]) + resolved_chat_id = chat.id + else: + raise RuntimeError("Provide user_id (for 1:1) or chat_id (for group)") + + # @@@read-before-write-gate — reject if unread messages exist + unread = self._messages.count_unread(resolved_chat_id, eid) + if unread > 0: + raise RuntimeError(f"You have {unread} unread message(s). Call read_messages(chat_id='{resolved_chat_id}') first.") - def handle(unread_only: bool = False, limit: int = 20) -> str: - chats = self._chat_service.list_chats_for_user(eid) - if unread_only: - chats = [c for c in chats if c.get("unread_count", 0) > 0] - chats = chats[:limit] - if not chats: - return "No chats found." - lines = [] - for c in chats: - others = [e for e in c.get("entities", []) if e["id"] != eid] - name = ", ".join(e["name"] for e in others) or "Unknown" - unread = c.get("unread_count", 0) - last = c.get("last_message") - last_preview = f' — last: "{last["content"][:50]}"' if last else "" - unread_str = f" ({unread} unread)" if unread > 0 else "" - is_group = len(others) >= 2 - if is_group: - id_str = f" [chat_id: {c['id']}]" - else: - other_id = others[0]["id"] if others else "" - id_str = f" [user_id: {other_id}]" if other_id else "" - lines.append(f"- {name}{id_str}{unread_str}{last_preview}") - return "\n".join(lines) + # Append signal to content (for read_messages) + pass through chain (for notification) + effective_signal = signal if signal in ("yield", "close") else None + if effective_signal: + content = f"{content}\n[signal: {effective_signal}]" + self._chat_service.send_message(resolved_chat_id, eid, content, mentions, signal=effective_signal) + return f"Message sent to {target_name}." + + def _handle_search_messages(self, query: str, user_id: str | None = None) -> str: + eid = self._user_id + chat_id = None + if user_id: + chat_id = self._chat_entities.find_chat_between(eid, user_id) + results = self._messages.search(query, chat_id=chat_id, limit=20) + if not results: + return f"No messages matching '{query}'." + lines = [] + for m in results: + name = self._resolve_name(m.sender_id) + lines.append(f"[{name}] {m.content[:100]}") + return "\n".join(lines) + + def _register_list_chats(self, registry: ToolRegistry) -> None: registry.register( ToolEntry( - name="chats", + name="list_chats", mode=ToolMode.INLINE, - schema={ - "name": "chats", - "description": "List your chats. Returns chat summaries with user_ids of participants.", - "parameters": { - "type": "object", - "properties": { - "unread_only": { - "type": "boolean", - "description": "Only show chats with unread messages", - "default": False, - }, - "limit": {"type": "integer", "description": "Max number of chats to return", "default": 20}, + schema=make_tool_schema( + name="list_chats", + description="List your chats. Returns chat summaries with user_ids of participants.", + properties={ + "unread_only": { + "type": "boolean", + "description": "Only show chats with unread messages", + "default": False, }, + "limit": {"type": "integer", "description": "Max number of chats to return", "default": 20}, }, - }, - handler=handle, + ), + handler=self._handle_list_chats, source="chat", + is_read_only=True, + is_concurrency_safe=True, ) ) - def _register_chat_read(self, registry: ToolRegistry) -> None: - eid = self._user_id - - def handle(user_id: str | None = None, chat_id: str | None = None, range: str | None = None) -> str: - if chat_id: - pass # use chat_id directly - elif user_id: - chat_id = self._chat_entities.find_chat_between(eid, user_id) - if not chat_id: - name = self._resolve_name(user_id) - return f"No chat history with {name}." - else: - return "Provide user_id or chat_id." - - # @@@range-dispatch — if range is provided, use it regardless of unread state. - if range: - try: - parsed = _parse_range(range) - except ValueError as e: - return str(e) - msgs = self._fetch_by_range(chat_id, parsed) - if not msgs: - return "No messages in that range." - # @@@range-marks-read — WORKAROUND: unblock chat_send by pushing - # last_read_at to now. This marks ALL messages as read, not just - # the requested range. Proper fix needs per-message read tracking - # instead of the current single-timestamp waterline model. - self._chat_entities.update_last_read(chat_id, eid, time.time()) - return self._format_msgs(msgs, eid) - - # @@@read-unread-only — default to unread messages only. - msgs = self._messages.list_unread(chat_id, eid) - if msgs: - self._chat_entities.update_last_read(chat_id, eid, time.time()) - return self._format_msgs(msgs, eid) - - # Nothing unread — prompt agent to use range parameter - return ( - "No unread messages. To read history, call again with range:\n" - " range='-10:-1' (last 10 messages)\n" - " range='-5:' (last 5 messages)\n" - " range='-1h:' (last hour)\n" - " range='-2d:-1d' (yesterday)\n" - " range='2026-03-20:2026-03-22' (date range)" - ) - + def _register_read_messages(self, registry: ToolRegistry) -> None: registry.register( ToolEntry( - name="chat_read", + name="read_messages", mode=ToolMode.INLINE, - schema={ - "name": "chat_read", - "description": ( + schema=make_tool_schema( + name="read_messages", + description=( "Read chat messages. Returns unread messages by default.\n" "If nothing unread, use range to read history:\n" " Negative index: '-10:-1' (last 10), '-5:' (last 5)\n" " Time interval: '-1h:', '-2d:-1d', '2026-03-20:2026-03-22'\n" "Positive indices are NOT allowed." ), - "parameters": { - "type": "object", - "properties": { - "user_id": {"type": "string", "description": "user_id for 1:1 chat history"}, - "chat_id": {"type": "string", "description": "Chat_id for group chat history"}, - "range": { - "type": "string", - "description": ( - "History range. Negative index '-X:-Y' or time '-1h:', '2026-03-20:'. Positive indices NOT allowed." - ), - }, + properties={ + "user_id": {"type": "string", "description": "user_id for 1:1 chat history"}, + "chat_id": {"type": "string", "description": "Chat_id for group chat history"}, + "range": { + "type": "string", + "description": ( + "History range. Negative index '-X:-Y' or time '-1h:', '2026-03-20:'. Positive indices NOT allowed." + ), }, }, - }, - handler=handle, + parameter_overrides={ + "x-leon-required-any-of": [ + ["user_id"], + ["chat_id"], + ], + }, + ), + handler=self._handle_read_messages, source="chat", + search_hint="read chat messages history conversation", + is_read_only=True, + is_concurrency_safe=True, + validate_input=self._fill_missing_chat_target, ) ) - def _register_chat_send(self, registry: ToolRegistry) -> None: - eid = self._user_id - - def handle( - content: str, - user_id: str | None = None, - chat_id: str | None = None, - signal: str = "open", - mentions: list[str] | None = None, - ) -> str: - # @@@read-before-write — resolve chat_id, then check unread - resolved_chat_id = chat_id - target_name = "chat" - - if chat_id: - if not self._chat_entities.is_participant_in_chat(chat_id, eid): - raise RuntimeError(f"You are not a member of chat {chat_id}") - elif user_id: - if user_id == eid: - raise RuntimeError("Cannot send a message to yourself.") - target_name = self._resolve_name(user_id) - resolved_chat_id = self._chat_entities.find_chat_between(eid, user_id) - if not resolved_chat_id: - # New chat — no unread possible, create and send - chat = self._chat_service.find_or_create_chat([eid, user_id]) - resolved_chat_id = chat.id - else: - raise RuntimeError("Provide user_id (for 1:1) or chat_id (for group)") - - # @@@read-before-write-gate — reject if unread messages exist - unread = self._messages.count_unread(resolved_chat_id, eid) - if unread > 0: - raise RuntimeError(f"You have {unread} unread message(s). Call chat_read(chat_id='{resolved_chat_id}') first.") - - # Append signal to content (for chat_read) + pass through chain (for notification) - effective_signal = signal if signal in ("yield", "close") else None - if effective_signal: - content = f"{content}\n[signal: {effective_signal}]" - - self._chat_service.send_message(resolved_chat_id, eid, content, mentions, signal=effective_signal) - return f"Message sent to {target_name}." - + def _register_send_message(self, registry: ToolRegistry) -> None: registry.register( ToolEntry( - name="chat_send", + name="send_message", mode=ToolMode.INLINE, - schema={ - "name": "chat_send", - "description": ( + schema=make_tool_schema( + name="send_message", + description=( "Send a message. Use user_id for 1:1 chats, chat_id for group chats.\n\n" - "You MUST call chat_read() first if you have unread messages — sending will fail otherwise.\n\n" + "You MUST call read_messages() first if you have unread messages — sending will fail otherwise.\n\n" "Signal protocol — append to content:\n" " (no tag) = I expect a reply from you\n" " ::yield = I'm done with my turn; reply only if you want to\n" " ::close = conversation over, do NOT reply\n\n" "For games/turns: do NOT append ::yield — just send the move and expect a reply." ), - "parameters": { - "type": "object", - "properties": { - "content": {"type": "string", "description": "Message content"}, - "user_id": {"type": "string", "description": "Target user_id (for 1:1 chat)"}, - "chat_id": {"type": "string", "description": "Target chat_id (for group chat)"}, - "signal": { - "type": "string", - "enum": ["open", "yield", "close"], - "description": "Signal intent to recipient", - "default": "open", - }, - "mentions": { - "type": "array", - "items": {"type": "string"}, - "description": "Entity IDs to @mention (overrides mute for these recipients)", - }, + properties={ + "content": {"type": "string", "description": "Message content"}, + "user_id": {"type": "string", "description": "Target user_id (for 1:1 chat)"}, + "chat_id": {"type": "string", "description": "Target chat_id (for group chat)"}, + "signal": { + "type": "string", + "enum": ["open", "yield", "close"], + "description": "Signal intent to recipient", + "default": "open", }, - "required": ["content"], - }, - }, - handler=handle, - source="chat", - ) - ) - - def _register_chat_search(self, registry: ToolRegistry) -> None: - eid = self._user_id - - def handle(query: str, user_id: str | None = None) -> str: - chat_id = None - if user_id: - chat_id = self._chat_entities.find_chat_between(eid, user_id) - results = self._messages.search(query, chat_id=chat_id, limit=20) - if not results: - return f"No messages matching '{query}'." - lines = [] - for m in results: - name = self._resolve_name(m.sender_id) - lines.append(f"[{name}] {m.content[:100]}") - return "\n".join(lines) - - registry.register( - ToolEntry( - name="chat_search", - mode=ToolMode.INLINE, - schema={ - "name": "chat_search", - "description": "Search messages. Optionally filter by user_id.", - "parameters": { - "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"}, - "user_id": { - "type": "string", - "description": "Optional: only search in chat with this user", - }, + "mentions": { + "type": "array", + "items": {"type": "string"}, + "description": "Entity IDs to @mention (overrides mute for these recipients)", }, - "required": ["query"], }, - }, - handler=handle, + required=["content"], + parameter_overrides={ + "x-leon-required-any-of": [ + ["content", "user_id"], + ["content", "chat_id"], + ], + }, + ), + handler=self._handle_send_message, source="chat", + search_hint="send message reply chat entity", + validate_input=self._fill_missing_chat_target, ) ) - def _register_directory(self, registry: ToolRegistry) -> None: - eid = self._user_id - - def handle(search: str | None = None, type: str | None = None) -> str: - lines = [] - all_members = self._members.list_all() if self._members else [] - member_map = {m.id: m for m in all_members} - - if type is None or type == "human": - for m in all_members: - if m.id == eid or m.type != "human": - continue - if search and search.lower() not in m.name.lower(): - continue - lines.append(f"- {m.name} [human] user_id={m.id}") - - if type is None or type == "agent": - all_entities = self._entities.list_all() - for e in all_entities: - if e.id == eid or e.type != "agent": - continue - if search and search.lower() not in e.name.lower(): - continue - member = member_map.get(e.member_id) - owner_info = "" - if member and member.owner_user_id: - owner = member_map.get(member.owner_user_id) - if owner: - owner_info = f" (owner: {owner.name})" - lines.append(f"- {e.name} [{e.type}] user_id={e.id}{owner_info}") - - if not lines: - return "No users found." - return "\n".join(lines) - + def _register_search_messages(self, registry: ToolRegistry) -> None: registry.register( ToolEntry( - name="directory", + name="search_messages", mode=ToolMode.INLINE, - schema={ - "name": "directory", - "description": "Browse the user directory. Returns user_ids for use with chat_send, chat_read.", - "parameters": { - "type": "object", - "properties": { - "search": {"type": "string", "description": "Search by name"}, - "type": {"type": "string", "description": "Filter by type: 'human' or 'agent'"}, + schema=make_tool_schema( + name="search_messages", + description="Search messages. Optionally filter by user_id.", + properties={ + "query": {"type": "string", "description": "Search query"}, + "user_id": { + "type": "string", + "description": "Optional: only search in chat with this user", }, }, - }, - handler=handle, + required=["query"], + ), + handler=self._handle_search_messages, source="chat", + search_hint="search messages query chat history", + is_read_only=True, + is_concurrency_safe=True, ) ) diff --git a/core/agents/communication/delivery.py b/core/agents/communication/delivery.py index c14ee6025..7e0a502bf 100644 --- a/core/agents/communication/delivery.py +++ b/core/agents/communication/delivery.py @@ -1,12 +1,13 @@ """Chat delivery — enqueues lightweight notifications for agent threads. -v3: no full message text injected. Agent must chat_read to see content. +v3: no full message text injected. Agent must read_messages to see content. ChatService._deliver_to_agents calls the delivery function for each non-sender agent entity. """ from __future__ import annotations +import functools import logging from typing import Any @@ -41,18 +42,20 @@ def _deliver( loop, ) - def _on_done(f): - exc = f.exception() - if exc: - logger.error("[delivery] async delivery failed for %s: %s", entity.id, exc, exc_info=exc) - else: - logger.info("[delivery] async delivery completed for %s", entity.id) - - future.add_done_callback(_on_done) + future.add_done_callback(functools.partial(_log_delivery_result, entity.id)) return _deliver +def _log_delivery_result(entity_id: str, f: Any) -> None: + """Done-callback for async delivery futures.""" + exc = f.exception() + if exc: + logger.error("[delivery] async delivery failed for %s: %s", entity_id, exc, exc_info=exc) + else: + logger.info("[delivery] async delivery completed for %s", entity_id) + + async def _async_deliver( app: Any, entity: EntityRow, @@ -64,7 +67,7 @@ async def _async_deliver( ) -> None: """Enqueue chat notification to an agent's brain thread. - @@@v3-notification-only — no message content. Agent calls chat_read to see it. + @@@v3-notification-only — no message content. Agent calls read_messages to see it. """ # @@@context-isolation — clear inherited LangChain ContextVar so the recipient # agent's astream doesn't inherit the sender's StreamMessagesHandler callbacks. diff --git a/core/agents/registry.py b/core/agents/registry.py index f74f4f4ec..cb208641d 100644 --- a/core/agents/registry.py +++ b/core/agents/registry.py @@ -59,6 +59,33 @@ async def get_by_id(self, agent_id: str) -> AgentEntry | None: subagent_type=row[5], ) + async def list_running_by_name(self, name: str) -> list[AgentEntry]: + rows = self._repo.list_running_by_name(name) + return [ + AgentEntry( + agent_id=row[0], + name=row[1], + thread_id=row[2], + status=row[3], + parent_agent_id=row[4], + subagent_type=row[5], + ) + for row in rows + ] + + async def get_latest_by_name_and_parent(self, name: str, parent_agent_id: str | None) -> AgentEntry | None: + row = self._repo.get_latest_by_name_and_parent(name, parent_agent_id) + if row is None: + return None + return AgentEntry( + agent_id=row[0], + name=row[1], + thread_id=row[2], + status=row[3], + parent_agent_id=row[4], + subagent_type=row[5], + ) + async def update_status(self, agent_id: str, status: str) -> None: async with self._lock: self._repo.update_status(agent_id, status) diff --git a/core/agents/service.py b/core/agents/service.py index e7baff89b..823d37a4e 100644 --- a/core/agents/service.py +++ b/core/agents/service.py @@ -11,89 +11,306 @@ import asyncio import json import logging +import os +import time import uuid +from collections.abc import Awaitable, Callable from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, cast +from config.loader import AgentLoader from core.agents.registry import AgentEntry, AgentRegistry -from core.runtime.middleware.queue.formatters import format_background_notification -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.middleware.queue.formatters import ( + format_agent_message, + format_background_notification, + format_progress_notification, +) +from core.runtime.permissions import ToolPermissionContext +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema +from core.runtime.state import BootstrapConfig, ToolUseContext +from core.runtime.tool_result import tool_error, tool_permission_request, tool_success +from storage.contracts import EntityRow logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from core.runtime.agent import LeonAgent -AGENT_SCHEMA = { - "name": "Agent", - "description": ( - "Launch a new agent to handle complex tasks autonomously. " - "Use subagent_type to select a specialized agent, or omit for default. " - "Agents run independently with their own tool stack." + +EventEmitter = Callable[[dict[str, Any]], Awaitable[None] | None] +ChildAgentFactory = Callable[..., "LeonAgent"] + + +def _resolve_default_child_agent_factory() -> ChildAgentFactory: + from core.runtime.agent import create_leon_agent + + return cast(ChildAgentFactory, create_leon_agent) + + +# ── Sub-agent tool filtering (CC alignment) ────────────────────────────────── +# Tools that sub-agents must never access (prevents controlling parent). +AGENT_DISALLOWED: set[str] = {"TaskOutput", "TaskStop", "Agent"} + +# Per-type allowed tool sets. Tools not in the set are blocked. +EXPLORE_ALLOWED: set[str] = {"Read", "Grep", "Glob", "list_dir", "WebSearch", "WebFetch", "tool_search"} +PLAN_ALLOWED: set[str] = EXPLORE_ALLOWED # plan agents are also read-only +BASH_ALLOWED: set[str] = {"Bash", "Read", "Grep", "Glob", "list_dir", "tool_search"} + + +def _get_tool_filters(subagent_type: str) -> tuple[set[str], set[str] | None]: + """Return (extra_blocked_tools, allowed_tools) for a sub-agent type. + + For explore/plan/bash: use allowed_tools whitelist (ToolRegistry skips unmatched). + For general: only block AGENT_DISALLOWED, no whitelist. + """ + agent_type = subagent_type.lower() + allowed_map: dict[str, set[str]] = { + "explore": EXPLORE_ALLOWED, + "plan": PLAN_ALLOWED, + "bash": BASH_ALLOWED, + } + + if agent_type in allowed_map: + return AGENT_DISALLOWED, allowed_map[agent_type] + + # general: only block parent-controlling tools, no whitelist + return AGENT_DISALLOWED, None + + +def _get_subagent_agent_name(subagent_type: str) -> str: + return subagent_type.lower() + + +def _resolve_subagent_model( + workspace_root: Path, + subagent_type: str, + requested_model: str | None, + inherited_model: str, + fallback_model: str | None = None, +) -> str: + def _is_inherit_marker(value: str | None) -> bool: + return value is None or value.lower() in {"default", "inherit"} + + env_model = os.getenv("CLAUDE_CODE_SUBAGENT_MODEL") + if env_model: + return env_model + if requested_model and not _is_inherit_marker(requested_model): + return requested_model + + agent_def = AgentLoader(workspace_root=workspace_root).load_all_agents().get(_get_subagent_agent_name(subagent_type)) + if agent_def and agent_def.model: + return agent_def.model + + if inherited_model and not _is_inherit_marker(inherited_model): + return inherited_model + if fallback_model and not _is_inherit_marker(fallback_model): + return fallback_model + return inherited_model + + +def _normalize_child_workspace_prompt(prompt: str, workspace_root: Path) -> str: + workspace_text = str(workspace_root) + for suffix in ("current working directory", "working directory"): + prompt = prompt.replace(f"{workspace_text}/{suffix}", workspace_text) + return prompt + + +def _filter_fork_messages(messages: list) -> list: + """Filter parent messages for forkContext sub-agent spawning. + + Equivalent to CC's yF0: removes assistant messages whose tool_use blocks + have no matching tool_result in a subsequent user message (orphan tool_use). + Orphan tool_use blocks cause Anthropic API validation errors. + """ + # Collect all tool_use_ids that have a corresponding tool_result + answered: set[str] = set() + for msg in messages: + # ToolMessage or user message with tool_result content + tool_call_id = getattr(msg, "tool_call_id", None) + if tool_call_id: + answered.add(tool_call_id) + content = getattr(msg, "content", None) + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + tid = block.get("tool_use_id") or block.get("tool_call_id") + if tid: + answered.add(tid) + + result = [] + for msg in messages: + content = getattr(msg, "content", None) + if isinstance(content, list): + tool_uses = [b for b in content if isinstance(b, dict) and b.get("type") == "tool_use"] + if tool_uses and any(b.get("id") not in answered for b in tool_uses): + continue # skip assistant message with unanswered tool_use + result.append(msg) + return result + + +AGENT_SCHEMA = make_tool_schema( + name="Agent", + description=( + "Launch a sub-agent for independent task execution. " + "Types: explore (read-only codebase search), plan (architecture design, read-only), " + "bash (shell commands only), general (broad tool access except Agent, TaskOutput, and TaskStop). " + "Use for: multi-step tasks, parallel work, tasks needing isolation. " + "Do NOT use for simple file reads or single grep searches — use the tools directly." ), - "parameters": { - "type": "object", - "properties": { - "subagent_type": { - "type": "string", - "description": "Type of agent to spawn (e.g. 'Explore', 'Coder'). Omit for general-purpose.", - }, - "prompt": { - "type": "string", - "description": "Task for the agent", - }, - "name": { - "type": "string", - "description": "Name for the agent (used for SendMessage routing)", - }, - "description": { - "type": "string", - "description": ( - "Short description of what agent will do. Required when run_in_background is true; " - "shown in the background task indicator." - ), - }, - "run_in_background": { - "type": "boolean", - "default": False, - "description": "Fire-and-forget: return immediately with task_id instead of waiting for completion", - }, - "max_turns": { - "type": "integer", - "description": "Maximum turns the agent can take", - }, + properties={ + "subagent_type": { + "type": "string", + "enum": ["explore", "plan", "general", "bash"], + "description": "Type of agent to spawn. Omit for general-purpose.", + }, + "prompt": { + "type": "string", + "description": "Task for the agent", + }, + "name": { + "type": "string", + "description": "Optional display name for the spawned agent", + }, + "description": { + "type": "string", + "description": ( + "Short description of what agent will do. Required when run_in_background is true; shown in the background task indicator." + ), + }, + "run_in_background": { + "type": "boolean", + "default": False, + "description": "Fire-and-forget: return immediately with task_id instead of waiting for completion", + }, + "model": { + "type": "string", + "description": "Optional sub-agent model override. Priority: env > this field > agent frontmatter > inherit.", + }, + "max_turns": { + "type": "integer", + "description": "Maximum turns the agent can take", + }, + "fork_context": { + "type": "boolean", + "default": False, + "description": ( + "Inherit parent conversation history as read-only context. " + "Use when the sub-agent needs background from the parent's work. " + "Adds a ### ENTERING SUB-AGENT ROUTINE ### marker so the sub-agent " + "knows which messages are context vs its actual task." + ), }, - "required": ["prompt"], }, -} - -TASK_OUTPUT_SCHEMA = { - "name": "TaskOutput", - "description": "Get the output of a background agent task by its task_id.", - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID returned when starting a background agent", - }, + required=["prompt", "description"], +) + +TASK_OUTPUT_SCHEMA = make_tool_schema( + name="TaskOutput", + description=( + "Get output of a background task (agent or bash). Blocks until task completes by default. Returns full text output or error." + ), + properties={ + "task_id": { + "type": "string", + "description": "The task ID returned when starting a background agent", + }, + "block": { + "type": "boolean", + "default": True, + "description": "Whether to wait for completion. Use false for a non-blocking status check.", + }, + "timeout": { + "type": "integer", + "default": 30000, + "minimum": 0, + "maximum": 600000, + "description": "Maximum wait time in milliseconds when block=true (default: 30000, max: 600000).", }, - "required": ["task_id"], }, -} - -TASK_STOP_SCHEMA = { - "name": "TaskStop", - "description": "Stop a running background agent task.", - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to stop", + required=["task_id"], +) + +TASK_STOP_SCHEMA = make_tool_schema( + name="TaskStop", + description="Cancel a running background task. Sends cancellation signal; task may take a moment to stop.", + properties={ + "task_id": { + "type": "string", + "description": "The task ID to stop", + }, + }, + required=["task_id"], +) + +SEND_MESSAGE_SCHEMA = make_tool_schema( + name="SendMessage", + description="Send a queued message to another running agent by name. Delivered before that agent's next model turn.", + properties={ + "target_name": { + "type": "string", + "description": "Display name of the running target agent", + }, + "message": { + "type": "string", + "description": "Message body to deliver", + }, + "sender_name": { + "type": "string", + "description": "Optional sender label for the delivered message", + }, + }, + required=["target_name", "message"], +) + +ASK_USER_QUESTION_SCHEMA = make_tool_schema( + name="AskUserQuestion", + description=( + "Ask the user one or more structured questions when progress requires their choice or clarification. " + "Use for genuine ambiguity, preference selection, or approval that needs an explicit answer before continuing." + ), + properties={ + "questions": { + "type": "array", + "description": "Questions to present to the user.", + "minItems": 1, + "items": { + "type": "object", + "properties": { + "header": {"type": "string", "description": "Short UI label for the question."}, + "question": {"type": "string", "description": "Full question text shown to the user."}, + "multiSelect": { + "type": "boolean", + "default": False, + "description": "Whether the user may pick multiple options.", + }, + "options": { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "properties": { + "label": {"type": "string"}, + "description": {"type": "string"}, + "preview": {"type": "string"}, + }, + "required": ["label", "description"], + }, + }, + }, + "required": ["header", "question", "options"], }, }, - "required": ["task_id"], + "annotations": { + "type": "object", + "description": "Optional structured annotations kept with the question request.", + }, + "metadata": { + "type": "object", + "description": "Optional metadata describing the source of the question request.", + }, }, -} + required=["questions"], +) class _RunningTask: @@ -150,6 +367,33 @@ def get_result(self) -> str | None: BackgroundRun = _RunningTask | _BashBackgroundRun +def _background_run_running_message(running: BackgroundRun) -> str: + return "Command is still running." if isinstance(running, _BashBackgroundRun) else "Agent is still running." + + +def _background_run_result_status(result: str | None) -> str: + return "error" if (result and result.startswith("")) else "completed" + + +async def _wait_for_background_run(running: BackgroundRun, timeout_ms: int) -> bool: + timeout_s = max(timeout_ms, 0) / 1000.0 + if isinstance(running, _RunningTask): + try: + await asyncio.wait_for(asyncio.shield(running.task), timeout=timeout_s) + return True + except TimeoutError: + return running.is_done + + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout_s + while True: + if running.is_done: + return True + if loop.time() >= deadline: + return False + await asyncio.sleep(0.1) + + class AgentService: """Registers Agent, TaskOutput, TaskStop tools into ToolRegistry. @@ -170,11 +414,25 @@ def __init__( model_name: str, queue_manager: Any | None = None, shared_runs: dict[str, BackgroundRun] | None = None, + background_progress_interval_s: float = 30.0, + thread_repo: Any = None, + entity_repo: Any = None, + member_repo: Any = None, + web_app: Any = None, + child_agent_factory: ChildAgentFactory | None = None, ): self._agent_registry = agent_registry self._workspace_root = workspace_root self._model_name = model_name self._queue_manager = queue_manager + self._background_progress_interval_s = background_progress_interval_s + self._thread_repo = thread_repo + self._entity_repo = entity_repo + self._member_repo = member_repo + self._web_app = web_app + self._child_agent_factory = child_agent_factory or _resolve_default_child_agent_factory() + self._parent_bootstrap: BootstrapConfig | None = None + self._parent_tool_context: Any | None = None # Shared with CommandService so TaskOutput covers both bash and agent runs. self._tasks: dict[str, BackgroundRun] = shared_runs if shared_runs is not None else {} @@ -185,6 +443,7 @@ def __init__( schema=AGENT_SCHEMA, handler=self._handle_agent, source="AgentService", + search_hint="launch sub-agent spawn parallel task independent", ) ) tool_registry.register( @@ -194,6 +453,9 @@ def __init__( schema=TASK_OUTPUT_SCHEMA, handler=self._handle_task_output, source="AgentService", + search_hint="get background task output result poll", + is_read_only=True, + is_concurrency_safe=True, ) ) tool_registry.register( @@ -203,9 +465,97 @@ def __init__( schema=TASK_STOP_SCHEMA, handler=self._handle_task_stop, source="AgentService", + search_hint="stop cancel background task agent", + ) + ) + tool_registry.register( + ToolEntry( + name="SendMessage", + mode=ToolMode.INLINE, + schema=SEND_MESSAGE_SCHEMA, + handler=self._handle_send_message, + source="AgentService", + search_hint="send message running agent delivery queue", + ) + ) + tool_registry.register( + ToolEntry( + name="AskUserQuestion", + mode=ToolMode.INLINE, + schema=ASK_USER_QUESTION_SCHEMA, + handler=self._handle_ask_user_question, + source="AgentService", + search_hint="ask user question clarification choice preference", + is_read_only=True, + is_concurrency_safe=True, ) ) + @staticmethod + def _normalize_child_sandbox(sandbox_type: str | None) -> str | None: + return None if not sandbox_type or sandbox_type == "local" else sandbox_type + + def _ensure_subagent_thread_metadata( + self, + *, + thread_id: str, + parent_thread_id: str | None, + agent_name: str, + model_name: str, + ) -> None: + if self._thread_repo is None or self._entity_repo is None or self._member_repo is None or not parent_thread_id: + return + existing_thread = self._thread_repo.get_by_id(thread_id) + if existing_thread is not None: + if self._entity_repo.get_by_thread_id(thread_id) is None: + self._entity_repo.create( + EntityRow( + id=thread_id, + type="agent", + member_id=existing_thread["member_id"], + name=agent_name, + thread_id=thread_id, + created_at=time.time(), + ) + ) + return + + parent_thread = self._thread_repo.get_by_id(parent_thread_id) + if parent_thread is None: + return + + member_id = parent_thread["member_id"] + member = self._member_repo.get_by_id(member_id) + if member is None: + return + + created_at = time.time() + branch_index = self._thread_repo.get_next_branch_index(member_id) + sandbox_type = parent_thread.get("sandbox_type") or "local" + cwd = parent_thread.get("cwd") + self._thread_repo.create( + thread_id=thread_id, + member_id=member_id, + sandbox_type=sandbox_type, + cwd=cwd, + created_at=created_at, + model=model_name or parent_thread.get("model"), + is_main=False, + branch_index=branch_index, + ) + + if self._entity_repo.get_by_thread_id(thread_id) is None: + self._entity_repo.create( + EntityRow( + id=thread_id, + type="agent", + member_id=member_id, + name=agent_name, + thread_id=thread_id, + created_at=created_at, + ) + ) + async def _handle_agent( self, prompt: str, @@ -213,15 +563,22 @@ async def _handle_agent( name: str | None = None, description: str | None = None, run_in_background: bool = False, + model: str | None = None, max_turns: int | None = None, - ) -> str: + fork_context: bool = False, + tool_context: ToolUseContext | None = None, + ) -> Any: """Spawn an independent LeonAgent and run it with the given prompt.""" from sandbox.thread_context import get_current_thread_id task_id = uuid.uuid4().hex[:8] agent_name = name or f"agent-{task_id}" - thread_id = f"subagent-{task_id}" parent_thread_id = get_current_thread_id() + existing_child = None + lookup_existing_child = getattr(self._agent_registry, "get_latest_by_name_and_parent", None) + if name and parent_thread_id and lookup_existing_child is not None: + existing_child = await lookup_existing_child(name, parent_thread_id) + thread_id = existing_child.thread_id if existing_child is not None and existing_child.status != "running" else f"subagent-{task_id}" # Register in AgentRegistry immediately entry = AgentEntry( @@ -233,6 +590,12 @@ async def _handle_agent( subagent_type=subagent_type, ) await self._agent_registry.register(entry) + self._ensure_subagent_thread_metadata( + thread_id=thread_id, + parent_thread_id=parent_thread_id, + agent_name=agent_name, + model_name=model or self._model_name, + ) # Create async task (independent LeonAgent runs inside) task = asyncio.create_task( @@ -243,33 +606,57 @@ async def _handle_agent( prompt, subagent_type, max_turns, + model=model, description=description or "", run_in_background=run_in_background, + fork_context=fork_context, + parent_tool_context=tool_context, ) ) if run_in_background: # True fire-and-forget: track in self._tasks for TaskOutput/TaskStop running = _RunningTask(task=task, agent_id=task_id, thread_id=thread_id, description=description or "") self._tasks[task_id] = running - return json.dumps( - { + return tool_success( + json.dumps( + { + "task_id": task_id, + "agent_name": agent_name, + "thread_id": thread_id, + "status": "running", + "message": "Agent started in background. Use TaskOutput to get result.", + }, + ensure_ascii=False, + ), + metadata={ "task_id": task_id, - "agent_name": agent_name, - "thread_id": thread_id, - "status": "running", - "message": "Agent started in background. Use TaskOutput to get result.", + "subagent_thread_id": thread_id, + "description": description or agent_name, }, - ensure_ascii=False, ) # Default: parent blocks until sub-agent completes (does not block frontend event loop) try: result = await task await self._agent_registry.update_status(task_id, "completed") - return result + return tool_success( + result, + metadata={ + "task_id": task_id, + "subagent_thread_id": thread_id, + "description": description or agent_name, + }, + ) except Exception as e: await self._agent_registry.update_status(task_id, "error") - return f"Agent failed: {e}" + return tool_error( + f"Agent failed: {e}", + metadata={ + "task_id": task_id, + "subagent_thread_id": thread_id, + "description": description or agent_name, + }, + ) async def _run_agent( self, @@ -279,8 +666,11 @@ async def _run_agent( prompt: str, subagent_type: str, max_turns: int | None, + model: str | None = None, description: str = "", run_in_background: bool = False, + fork_context: bool = False, + parent_tool_context: ToolUseContext | None = None, ) -> str: """Create and run an independent LeonAgent, collect its text output.""" # Isolate this sub-agent from the parent's LangChain callback chain. @@ -294,48 +684,164 @@ async def _run_agent( var_child_runnable_config.set(None) - # Lazy import avoids circular dependency (agent.py imports AgentService) - from core.runtime.agent import create_leon_agent from sandbox.thread_context import get_current_thread_id, set_current_thread_id parent_thread_id = get_current_thread_id() + self._ensure_subagent_thread_metadata( + thread_id=thread_id, + parent_thread_id=parent_thread_id, + agent_name=agent_name, + model_name=model or self._model_name, + ) # emit_fn is set if EventBus is available; used for task lifecycle SSE events - emit_fn = None + emit_fn: EventEmitter | None = None try: from backend.web.event_bus import get_event_bus - event_bus = get_event_bus() - emit_fn = event_bus.make_emitter( - thread_id=parent_thread_id, - agent_id=task_id, - agent_name=agent_name, - ) + if parent_thread_id: + event_bus = get_event_bus() + emit_fn = event_bus.make_emitter( + thread_id=parent_thread_id, + agent_id=task_id, + agent_name=agent_name, + ) except ImportError: pass # backend not available in standalone core usage - agent = None + agent: LeonAgent | None = None + progress_task: asyncio.Task | None = None + progress_stop: asyncio.Event | None = None + child_bootstrap_start_cost = 0.0 + child_bootstrap_start_tool_duration_ms = 0 try: - agent = create_leon_agent( - model_name=self._model_name, - workspace_root=self._workspace_root, - verbose=False, - ) + # Sub-agent context trimming: each spawn creates a fresh LeonAgent + # with its own _build_system_prompt(). No CLAUDE.md content or + # gitStatus is injected into the prompt pipeline (core/runtime/prompts + # has no such injection). Therefore explore/plan/bash sub-agents + # already run lightweight — no extra trimming is needed. + # + # Try to use context fork from parent agent's BootstrapConfig. + # Falls back to create_leon_agent when bootstrap is not available. + # Compute tool filtering for this sub-agent type + extra_blocked, allowed = _get_tool_filters(subagent_type) + agent_name_for_role = _get_subagent_agent_name(subagent_type) + + try: + from core.runtime.fork import create_subagent_context + from core.runtime.fork import fork_context as fork_bootstrap + + # Parent bootstrap is stored on the ToolUseContext or agent instance. + # AgentService stores workspace_root and model_name directly; use those + # to check if a richer bootstrap is available via a shared reference. + # _parent_bootstrap is injected by LeonAgent when building AgentService. + parent_bootstrap = getattr(self, "_parent_bootstrap", None) + child_tool_context = None + if parent_tool_context is not None: + child_tool_context = create_subagent_context(parent_tool_context) + child_bootstrap = child_tool_context.bootstrap + elif parent_bootstrap is not None: + child_bootstrap = fork_bootstrap(parent_bootstrap) + selected_model = _resolve_subagent_model( + self._workspace_root, + subagent_type, + model, + child_bootstrap.model_name, + self._model_name, + ) + agent = self._child_agent_factory( + model_name=selected_model, + workspace_root=child_bootstrap.workspace_root, + sandbox=self._normalize_child_sandbox(getattr(child_bootstrap, "sandbox_type", None)), + agent=agent_name_for_role, + web_app=self._web_app, + extra_blocked_tools=extra_blocked, + allowed_tools=allowed, + verbose=False, + ) + else: + raise AttributeError("no parent bootstrap") + child_bootstrap_start_cost = float(getattr(child_bootstrap, "total_cost_usd", 0.0)) + child_bootstrap_start_tool_duration_ms = int(getattr(child_bootstrap, "total_tool_duration_ms", 0)) + if parent_tool_context is not None: + # @@@sa-05-subagent-policy-resolution + # Role-specific tool envelopes and model priority order must + # be resolved explicitly here instead of leaking through + # prompt text or whichever defaults happen to win later. + selected_model = _resolve_subagent_model( + self._workspace_root, + subagent_type, + model, + child_bootstrap.model_name, + self._model_name, + ) + agent = self._child_agent_factory( + model_name=selected_model, + workspace_root=child_bootstrap.workspace_root, + sandbox=self._normalize_child_sandbox(getattr(child_bootstrap, "sandbox_type", None)), + agent=agent_name_for_role, + web_app=self._web_app, + extra_blocked_tools=extra_blocked, + allowed_tools=allowed, + verbose=False, + ) + # @@@sa-04-child-bootstrap-wiring + # Keep the forked bootstrap/context handoff behind an explicit + # LeonAgent API so AgentService stops reaching into QueryLoop + # internals directly. + assert agent is not None + agent.apply_forked_child_context( + child_bootstrap, + tool_context=child_tool_context, + ) + except (AttributeError, ImportError): + inherited_model = getattr(parent_tool_context.bootstrap, "model_name", None) if parent_tool_context else None + selected_model = _resolve_subagent_model( + self._workspace_root, + subagent_type, + model, + inherited_model or self._model_name, + self._model_name, + ) + agent = self._child_agent_factory( + model_name=selected_model, + workspace_root=self._workspace_root, + sandbox=self._normalize_child_sandbox( + getattr(parent_tool_context.bootstrap, "sandbox_type", None) if parent_tool_context else None + ), + agent=agent_name_for_role, + web_app=self._web_app, + extra_blocked_tools=extra_blocked, + allowed_tools=allowed, + verbose=False, + ) # In async context LeonAgent defers checkpointer init; call ainit() to # ensure state is persisted (and loadable via GET /api/threads/{thread_id}). + assert agent is not None await agent.ainit() + # @@@subagent-prompt-path-sanitize - Parent models sometimes satisfy + # "use absolute paths" by appending natural-language cwd labels onto the + # real workspace path. Normalize the obvious fake suffix before dispatch. + child_workspace_root = Path(getattr(agent, "workspace_root", self._workspace_root)) + prompt = _normalize_child_workspace_prompt(prompt, child_workspace_root) + + if parent_thread_id and parent_thread_id != thread_id: + from sandbox.manager import bind_thread_to_existing_thread_lease + + bind_thread_to_existing_thread_lease(thread_id, parent_thread_id) # Wire child agent events to the parent's EventBus subscription # so the parent SSE stream shows sub-agent activity. if emit_fn is not None: - if hasattr(agent, "runtime") and hasattr(agent.runtime, "bind_thread"): - agent.runtime.bind_thread(activity_sink=emit_fn) + runtime = getattr(agent, "runtime", None) + if runtime is not None and hasattr(runtime, "bind_thread"): + runtime.bind_thread(activity_sink=emit_fn) set_current_thread_id(thread_id) # Notify frontend: task started if emit_fn is not None: - await emit_fn( + emission = emit_fn( { "event": "task_start", "data": json.dumps( @@ -350,38 +856,95 @@ async def _run_agent( ), } ) + if asyncio.iscoroutine(emission): + await emission config = {"configurable": {"thread_id": thread_id}} output_parts: list[str] = [] + latest_progress = description or agent_name + + if run_in_background and self._queue_manager and parent_thread_id and self._background_progress_interval_s > 0: + progress_stop = asyncio.Event() + progress_task = asyncio.create_task( + self._emit_background_progress( + task_id=task_id, + agent_name=agent_name, + parent_thread_id=parent_thread_id, + latest_progress=lambda: latest_progress, + stop_event=progress_stop, + ) + ) + + # Build initial input — with or without forked parent context + if fork_context: + from sandbox.thread_context import get_current_messages + + # @@@pt-04-fork-context-source + # The Agent tool already has an explicit parent ToolUseContext on + # the live ToolRunner path. Forked sub-agents must prefer that + # concrete message snapshot over ambient ContextVar state, or the + # direct runner path silently drops parent context. + parent_msgs = list(parent_tool_context.messages) if parent_tool_context is not None else get_current_messages() + fork_marker = ( + "\n\n### ENTERING SUB-AGENT ROUTINE ###\n" + "Messages above are from the parent thread (read-only context).\n" + "Only complete the specific task assigned below.\n\n" + ) + initial_messages: list = [ + *_filter_fork_messages(parent_msgs), + {"role": "user", "content": fork_marker + prompt}, + ] + else: + initial_messages = [{"role": "user", "content": prompt}] + + if self._web_app is not None: + from backend.web.services.streaming_service import run_child_thread_live - async for chunk in agent.agent.astream( - {"messages": [{"role": "user", "content": prompt}]}, - config=config, - stream_mode="updates", - ): - for _, node_update in chunk.items(): - if not isinstance(node_update, dict): - continue - msgs = node_update.get("messages", []) - if not isinstance(msgs, list): - msgs = [msgs] - for msg in msgs: - if msg.__class__.__name__ == "AIMessage": - content = getattr(msg, "content", "") - if isinstance(content, str) and content: - output_parts.append(content) - elif isinstance(content, list): - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - text = block.get("text", "") - if text: - output_parts.append(text) + result = await run_child_thread_live( + agent, + thread_id, + prompt, + self._web_app, + input_messages=initial_messages, + ) + if result: + output_parts.append(result) + latest_progress = self._summarize_progress(result, description or agent_name) + else: + async for chunk in agent.agent.astream( + {"messages": initial_messages}, + config=config, + stream_mode="updates", + ): + for _, node_update in chunk.items(): + if not isinstance(node_update, dict): + continue + msgs = node_update.get("messages", []) + if not isinstance(msgs, list): + msgs = [msgs] + for msg in msgs: + if msg.__class__.__name__ == "AIMessage": + content = getattr(msg, "content", "") + if isinstance(content, str) and content: + output_parts.append(content) + latest_progress = self._summarize_progress(content, description or agent_name) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + if text: + output_parts.append(text) + latest_progress = self._summarize_progress(text, description or agent_name) await self._agent_registry.update_status(task_id, "completed") result = "\n".join(output_parts) or "(Agent completed with no text output)" + if progress_stop is not None: + progress_stop.set() + if progress_task is not None: + await progress_task # Notify frontend: task done if emit_fn is not None: - await emit_fn( + emission = emit_fn( { "event": "task_done", "data": json.dumps( @@ -393,6 +956,8 @@ async def _run_agent( ), } ) + if asyncio.iscoroutine(emission): + await emission # Queue notification only for background runs — blocking callers already # received the result as the tool's return value; sending a notification # would trigger a spurious new parent turn. @@ -402,18 +967,23 @@ async def _run_agent( task_id=task_id, status="completed", summary=label, + result=result, description=label, ) self._queue_manager.enqueue(notification, parent_thread_id, notification_type="agent") return result except Exception: + if progress_stop is not None: + progress_stop.set() + if progress_task is not None: + await progress_task logger.exception("[AgentService] Agent %s failed", agent_name) await self._agent_registry.update_status(task_id, "error") # Notify frontend: task error if emit_fn is not None: try: - await emit_fn( + emission = emit_fn( { "event": "task_error", "data": json.dumps( @@ -425,6 +995,8 @@ async def _run_agent( ), } ) + if asyncio.iscoroutine(emission): + await emission except Exception: pass if run_in_background and self._queue_manager and parent_thread_id: @@ -433,6 +1005,7 @@ async def _run_agent( task_id=task_id, status="error", summary=label, + result="Agent failed", description=label, ) self._queue_manager.enqueue(notification, parent_thread_id, notification_type="agent") @@ -440,37 +1013,252 @@ async def _run_agent( finally: if agent is not None: try: - agent.close() + self._merge_child_bootstrap_accumulators( + getattr(self, "_parent_bootstrap", None), + getattr(agent, "_bootstrap", None), + child_bootstrap_start_cost=child_bootstrap_start_cost, + child_bootstrap_start_tool_duration_ms=child_bootstrap_start_tool_duration_ms, + ) + if hasattr(agent, "_agent_service") and hasattr(agent._agent_service, "cleanup_background_runs"): + await agent._agent_service.cleanup_background_runs() + # @@@web-child-persistence - web child threads are user-visible + # thread surfaces. Closing the LeonAgent here marks runtime + # terminated and drops its live/checkpoint bridge right after + # completion, so the child tab collapses to an empty shell. + if self._web_app is None: + # @@@subagent-sandbox-close-skip - Child agents can share the + # parent's lease; closing the child sandbox here can pause the + # shared lease mid-owner-turn. + agent.close(cleanup_sandbox=False) except Exception: pass - async def _handle_task_output(self, task_id: str) -> str: + @staticmethod + def _merge_child_bootstrap_accumulators( + parent_bootstrap: Any, + child_bootstrap: Any, + *, + child_bootstrap_start_cost: float, + child_bootstrap_start_tool_duration_ms: int, + ) -> None: + if parent_bootstrap is None or child_bootstrap is None or parent_bootstrap is child_bootstrap: + return + # @@@sa-03-bootstrap-rollup + # Sub-agent loops start from a forked bootstrap snapshot. At join time we + # need to preserve both the parent's concurrent growth and the child's + # post-fork delta instead of letting one side overwrite the other. + child_cost_delta = max( + 0.0, + float(getattr(child_bootstrap, "total_cost_usd", 0.0)) - child_bootstrap_start_cost, + ) + child_tool_duration_delta = max( + 0, + int(getattr(child_bootstrap, "total_tool_duration_ms", 0)) - child_bootstrap_start_tool_duration_ms, + ) + parent_bootstrap.total_cost_usd = float(getattr(parent_bootstrap, "total_cost_usd", 0.0)) + child_cost_delta + parent_bootstrap.total_tool_duration_ms = int(getattr(parent_bootstrap, "total_tool_duration_ms", 0)) + child_tool_duration_delta + + @staticmethod + def _summarize_progress(text: str, fallback: str) -> str: + collapsed = " ".join(text.split()).strip() + if not collapsed: + return fallback + return collapsed[:120] + + async def _emit_background_progress( + self, + *, + task_id: str, + agent_name: str, + parent_thread_id: str, + latest_progress: Any, + stop_event: asyncio.Event, + ) -> None: + # @@@sa-06-progress-loop - keep prompt-facing coordinator updates on the + # real thread delivery queue instead of inventing a detached parallel channel. + while True: + try: + await asyncio.wait_for(stop_event.wait(), timeout=self._background_progress_interval_s) + return + except TimeoutError: + pass + + if self._queue_manager is None: + return + + notification = format_progress_notification( + task_id, + latest_progress(), + step="running", + ) + self._queue_manager.enqueue( + notification, + parent_thread_id, + notification_type="agent", + source="system", + sender_name=agent_name, + ) + + async def _handle_task_output(self, task_id: str, block: bool = True, timeout: int = 30_000) -> str: """Get output of a background agent task.""" running = self._tasks.get(task_id) if not running: return f"Error: task '{task_id}' not found" + if not block: + if not running.is_done: + return json.dumps( + { + "task_id": task_id, + "status": "running", + "message": _background_run_running_message(running), + }, + ensure_ascii=False, + ) + + result = running.get_result() + return json.dumps( + { + "task_id": task_id, + "status": _background_run_result_status(result), + "result": result, + }, + ensure_ascii=False, + ) + + if not running.is_done: + completed = await _wait_for_background_run(running, min(timeout, 600_000)) + if not completed and not running.is_done: + return json.dumps( + { + "task_id": task_id, + "status": "timeout", + "message": _background_run_running_message(running), + }, + ensure_ascii=False, + ) + if not running.is_done: return json.dumps( { "task_id": task_id, "status": "running", - "message": "Agent is still running.", + "message": _background_run_running_message(running), }, ensure_ascii=False, ) result = running.get_result() - status = "error" if (result and result.startswith("")) else "completed" return json.dumps( { "task_id": task_id, - "status": status, + "status": _background_run_result_status(result), "result": result, }, ensure_ascii=False, ) + async def _handle_send_message( + self, + target_name: str, + message: str, + sender_name: str | None = None, + ) -> str: + if self._queue_manager is None: + return "SendMessage requires queue_manager" + + matches = await self._agent_registry.list_running_by_name(target_name) + if not matches: + return f"Running agent '{target_name}' not found" + if len(matches) > 1: + return ( + f"Running agent name '{target_name}' is ambiguous. " + "Use a unique name before calling SendMessage." + ) + target = matches[0] + + delivered = format_agent_message(sender_name or "agent", message) + self._queue_manager.enqueue( + delivered, + target.thread_id, + notification_type="agent", + source="system", + sender_name=sender_name or "agent", + ) + return f"Message sent to {target.name}." + + async def _handle_ask_user_question( + self, + questions: list[dict[str, Any]], + annotations: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + tool_context: ToolUseContext | None = None, + ) -> Any: + if tool_context is None or tool_context.request_permission is None: + return tool_error("AskUserQuestion requires an interactive owner resolver") + + payload: dict[str, Any] = {"questions": questions} + if annotations is not None: + payload["annotations"] = annotations + if metadata is not None: + payload["metadata"] = metadata + + request_result = tool_context.request_permission( + "AskUserQuestion", + payload, + ToolPermissionContext(is_read_only=True, is_destructive=False), + None, + "Please answer the following questions so Leon can continue.", + ) + request_id = request_result.get("request_id") if isinstance(request_result, dict) else request_result + if not isinstance(request_id, str) or not request_id: + return tool_error("AskUserQuestion could not create a user-facing request") + + return tool_permission_request( + "User input required to continue.", + metadata={ + "decision": "ask", + "request_id": request_id, + "request_kind": "ask_user_question", + }, + ) + + async def _stop_background_run(self, task_id: str, running: BackgroundRun) -> None: + if isinstance(running, _RunningTask): + was_running = not running.task.done() + if was_running: + running.task.cancel() + try: + await running.task + except asyncio.CancelledError: + pass + await self._agent_registry.update_status(running.agent_id, "error") + self._tasks.pop(task_id, None) + return + + if not running.is_done: + process = getattr(running._cmd, "process", None) + wait = getattr(process, "wait", None) if process is not None else None + terminate = getattr(process, "terminate", None) if process is not None else None + kill = getattr(process, "kill", None) if process is not None else None + + if callable(terminate): + terminate() + if callable(wait): + wait_fn = cast(Callable[[], Awaitable[Any]], wait) + try: + await asyncio.wait_for(wait_fn(), timeout=1.0) + except TimeoutError: + if callable(kill): + kill() + await wait_fn() + + self._tasks.pop(task_id, None) + + async def cleanup_background_runs(self) -> None: + for task_id, running in list(self._tasks.items()): + await self._stop_background_run(task_id, running) + async def _handle_task_stop(self, task_id: str) -> str: """Stop a running background agent task.""" running = self._tasks.get(task_id) @@ -480,6 +1268,5 @@ async def _handle_task_stop(self, task_id: str) -> str: if running.is_done: return f"Task {task_id} already completed" - running.task.cancel() - await self._agent_registry.update_status(running.agent_id, "error") + await self._stop_background_run(task_id, running) return f"Task {task_id} cancelled" diff --git a/core/runner.py b/core/runner.py index 6c3902e3c..fddd6b135 100644 --- a/core/runner.py +++ b/core/runner.py @@ -153,7 +153,7 @@ def _print_memory_stats(self, status: dict) -> None: def _process_chunk(self, chunk: dict, result: dict) -> None: """Process streaming chunk, extract tool calls and response""" - for node_name, node_update in chunk.items(): + for _node_name, node_update in chunk.items(): if not isinstance(node_update, dict): continue diff --git a/core/runtime/abort.py b/core/runtime/abort.py new file mode 100644 index 000000000..f95ca4e2f --- /dev/null +++ b/core/runtime/abort.py @@ -0,0 +1,48 @@ +"""Minimal abort controller tree for runtime lifecycle wiring.""" + +from __future__ import annotations + +from collections.abc import Callable + + +class AbortController: + def __init__(self) -> None: + self._aborted = False + self._listeners: dict[int, Callable[[], None]] = {} + self._next_listener_id = 0 + + def abort(self) -> None: + if self._aborted: + return + self._aborted = True + listeners = list(self._listeners.values()) + self._listeners.clear() + for listener in listeners: + listener() + + def is_aborted(self) -> bool: + return self._aborted + + def on_abort(self, listener: Callable[[], None]) -> Callable[[], None]: + if self._aborted: + listener() + return lambda: None + + listener_id = self._next_listener_id + self._next_listener_id += 1 + self._listeners[listener_id] = listener + + def unsubscribe() -> None: + self._listeners.pop(listener_id, None) + + return unsubscribe + + +def create_child_abort_controller(parent: AbortController | None) -> AbortController: + child = AbortController() + if parent is None: + return child + + unsubscribe = parent.on_abort(child.abort) + child.on_abort(unsubscribe) + return child diff --git a/core/runtime/agent.py b/core/runtime/agent.py index e4d7299c6..89c8eb172 100644 --- a/core/runtime/agent.py +++ b/core/runtime/agent.py @@ -18,18 +18,18 @@ All paths must be absolute. Full security mechanisms and audit logging. """ +import asyncio +import concurrent.futures +import inspect +import logging import os -import threading from pathlib import Path from typing import Any -from langchain.agents import create_agent from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver -from config.schema import DEFAULT_MODEL - # Load .env file _env_file = Path(__file__).parent / ".env" if _env_file.exists(): @@ -53,6 +53,11 @@ # Import file operation recorder for time travel from core.operations import get_recorder # noqa: E402 + +# New architecture: ToolRegistry + ToolRunner + Services +from core.runtime.cleanup import CleanupRegistry # noqa: E402 +from core.runtime.loop import QueryLoop # noqa: E402 +from core.runtime.middleware.mcp_instructions import McpInstructionsDeltaMiddleware # noqa: E402 from core.runtime.middleware.memory import MemoryMiddleware # noqa: E402 from core.runtime.middleware.monitor import MonitorMiddleware, apply_usage_patches # noqa: E402 from core.runtime.middleware.prompt_caching import PromptCachingMiddleware # noqa: E402 @@ -60,10 +65,9 @@ # Middleware imports (migrated paths) from core.runtime.middleware.spill_buffer import SpillBufferMiddleware # noqa: E402 - -# New architecture: ToolRegistry + ToolRunner + Services -from core.runtime.registry import ToolRegistry # noqa: E402 +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema # noqa: E402 from core.runtime.runner import ToolRunner # noqa: E402 +from core.runtime.state import AppState, BootstrapConfig # noqa: E402 from core.runtime.validator import ToolValidator # noqa: E402 # Hooks (used by Services) @@ -71,7 +75,9 @@ from core.tools.command.hooks.file_access_logger import FileAccessLoggerHook # noqa: E402 from core.tools.command.hooks.file_permission import FilePermissionHook # noqa: E402 from core.tools.command.service import CommandService # noqa: E402 +from core.tools.cron.service import CronToolService # noqa: E402 from core.tools.filesystem.service import FileSystemService # noqa: E402 +from core.tools.mcp_resources.service import McpResourceToolService # noqa: E402 from core.tools.search.service import SearchService # noqa: E402 from core.tools.skills.service import SkillsService # noqa: E402 from core.tools.task.service import TaskService # noqa: E402 @@ -82,10 +88,41 @@ from core.tools.web.service import WebService # noqa: E402 from storage.container import StorageContainer # noqa: E402 +logger = logging.getLogger(__name__) + # @@@langchain-anthropic-streaming-usage-regression apply_usage_patches() +def _make_mcp_tool_entry(tool) -> ToolEntry: + schema_model = getattr(tool, "tool_call_schema", None) + if schema_model is not None and hasattr(schema_model, "model_json_schema"): + parameters = schema_model.model_json_schema() + else: + parameters = { + "type": "object", + "properties": getattr(tool, "args", {}) or {}, + } + + async def mcp_handler(**kwargs): + if hasattr(tool, "ainvoke"): + return await tool.ainvoke(kwargs) + return await asyncio.to_thread(tool.invoke, kwargs) + + return ToolEntry( + name=tool.name, + mode=ToolMode.INLINE, + schema=make_tool_schema( + name=tool.name, + description=getattr(tool, "description", "") or tool.name, + properties={}, + parameter_overrides=parameters, + ), + handler=mcp_handler, + source="mcp", + ) + + class LeonAgent: """ Leon Agent - AI Coding Assistant @@ -108,6 +145,7 @@ def __init__( workspace_root: str | Path | None = None, *, agent: str | None = None, + bundle_dir: str | Path | None = None, allowed_file_extensions: list[str] | None = None, block_dangerous_commands: bool | None = None, block_network_commands: bool | None = None, @@ -119,9 +157,16 @@ def __init__( jina_api_key: str | None = None, sandbox: Any = None, storage_container: StorageContainer | None = None, + thread_repo: Any = None, + entity_repo: Any = None, + member_repo: Any = None, queue_manager: MessageQueueManager | None = None, chat_repos: dict | None = None, + web_app: Any = None, extra_allowed_paths: list[str] | None = None, + extra_blocked_tools: set[str] | None = None, + allowed_tools: set[str] | None = None, + permission_resolver_scope: str = "none", verbose: bool = False, ): """ @@ -138,7 +183,11 @@ def __init__( enable_audit_log: Whether to enable audit logging enable_web_tools: Whether to enable web search and content fetching tools sandbox: Sandbox instance, name string, or None for local + thread_repo: Optional thread metadata repo for backend-integrated subagent registration + entity_repo: Optional entity repo for backend-integrated subagent registration + member_repo: Optional member repo for backend-integrated subagent registration queue_manager: Shared MessageQueueManager instance (created if not provided) + permission_resolver_scope: Permission request surface for this agent ("none" or "thread") verbose: Whether to output detailed logs (default False) """ self.agent_id: str | None = None @@ -146,11 +195,23 @@ def __init__( self.extra_allowed_paths = extra_allowed_paths self.queue_manager = queue_manager or MessageQueueManager() self._chat_repos: dict | None = chat_repos + self._thread_repo = thread_repo + self._entity_repo = entity_repo + self._member_repo = member_repo + self._web_app = web_app + self._session_started = False + self._session_ended = False + self._closing = False + self._closed = False + requested_sandbox_name = sandbox if isinstance(sandbox, str) else getattr(sandbox, "name", None) + self._explicit_model_name = model_name is not None # New config system mode self.config, self.models_config = self._load_config( agent_name=agent, + bundle_dir=bundle_dir, workspace_root=workspace_root, + sandbox_name=requested_sandbox_name, model_name=model_name, api_key=api_key, allowed_file_extensions=allowed_file_extensions, @@ -167,8 +228,9 @@ def __init__( from config.schema import DEFAULT_MODEL # noqa: E402 active_model = DEFAULT_MODEL - # Member model override: agent.md's model field takes precedence over global config - if hasattr(self, "_agent_override") and self._agent_override and self._agent_override.model: + # Agent frontmatter model applies only when the caller did not explicitly + # request a model at construction time. + if not self._explicit_model_name and hasattr(self, "_agent_override") and self._agent_override and self._agent_override.model: active_model = self._agent_override.model resolved_model, model_overrides = self.models_config.resolve_model(active_model) self.model_name = resolved_model @@ -177,6 +239,7 @@ def __init__( # Resolve API key (prefer resolved provider from mapping) provider_name = self._resolve_provider_name(resolved_model, model_overrides) p = self.models_config.get_provider(provider_name) if provider_name else None + self._explicit_api_key = api_key is not None self.api_key = api_key or (p.api_key if p else None) or self.models_config.get_api_key() if not self.api_key: @@ -213,6 +276,7 @@ def __init__( } # Initialize checkpointer and MCP tools + self.checkpointer = None self._aiosqlite_conn, mcp_tools = self._init_async_components() # If in async context (running loop detected), _init_async_components @@ -225,51 +289,61 @@ def __init__( self.checkpointer = None # Initialize ToolRegistry and Services (new architecture) - self._tool_registry = ToolRegistry(blocked_tools=self._get_member_blocked_tools()) + blocked = self._get_member_blocked_tools() + if extra_blocked_tools: + blocked = blocked | extra_blocked_tools + self._tool_registry = ToolRegistry( + blocked_tools=blocked, + allowed_tools=allowed_tools, + ) self._init_services() + self._register_mcp_tools(mcp_tools) # Build middleware stack middleware = self._build_middleware_stack() - # Ensure ToolNode is created (middleware tools need at least one BaseTool) + # Ensure the bound model still sees at least one BaseTool-compatible entry. if not mcp_tools and not self._has_middleware_tools(middleware): mcp_tools = [self._create_placeholder_tool()] - # Build system prompt - self.system_prompt = self._build_system_prompt() - custom_prompt = self.config.system_prompt - if custom_prompt: - self.system_prompt += f"\n\n**Custom Instructions:**\n{custom_prompt}" - - # @@@entity-identity — inject chat identity so agent knows who it is in the social layer - if self._chat_repos: - repos = self._chat_repos - uid = repos.get("user_id") - owner_uid = repos.get("owner_user_id", "") - if uid: - entity_repo = repos.get("entity_repo") - entity = entity_repo.get_by_id(uid) if entity_repo else None - member_repo = repos.get("member_repo") - owner_row = member_repo.get_by_id(owner_uid) if member_repo and owner_uid else None - name = entity.name if entity else uid - owner_name = owner_row.name if owner_row else "unknown" - self.system_prompt += ( - f"\n\n**Chat Identity:**\n" - f"- Your name: {name}\n" - f"- Your user_id: {uid}\n" - f"- Your owner: {owner_name} (user_id: {owner_uid})\n" - f"- When you receive a chat notification, READ the message with chat_read(), " - f"then REPLY with chat_send(). Your text output goes to your owner's thread, " - f"not to the chat — only chat_send() delivers to the other party.\n" - ) + self._system_prompt_section_cache: dict[str, str] = {} + self.system_prompt = self._compose_system_prompt() - # Create agent - self.agent = create_agent( + # Build BootstrapConfig for sub-agent forking + self._bootstrap = BootstrapConfig( + workspace_root=self.workspace_root, + original_cwd=Path.cwd(), + project_root=self.workspace_root, + cwd=self.workspace_root, + model_name=self.model_name, + api_key=self.api_key, + sandbox_type=self._sandbox.name, + permission_resolver_scope=permission_resolver_scope, + block_dangerous_commands=self.block_dangerous_commands, + block_network_commands=self.block_network_commands, + enable_audit_log=self.enable_audit_log, + enable_web_tools=self.enable_web_tools, + allowed_file_extensions=self.allowed_file_extensions, + extra_allowed_paths=self.extra_allowed_paths, + model_provider=self._current_model_config.get("model_provider"), + base_url=self._current_model_config.get("base_url"), + ) + self._app_state = AppState() + self.app_state = self._app_state + # Inject bootstrap into AgentService so sub-agents can fork from it + if hasattr(self, "_agent_service"): + self._agent_service._parent_bootstrap = self._bootstrap + + # Create agent via QueryLoop (replaces LangGraph create_agent) + self.agent = QueryLoop( model=self.model, - tools=mcp_tools, system_prompt=SystemMessage(content=[{"type": "text", "text": self.system_prompt}]), middleware=middleware, - checkpointer=self.checkpointer if not self._needs_async_init else None, + checkpointer=self.checkpointer, + registry=self._tool_registry, + app_state=self._app_state, + runtime=self._monitor_middleware.runtime, + bootstrap=self._bootstrap, ) # Get runtime from MonitorMiddleware @@ -286,13 +360,39 @@ def __init__( print("[LeonAgent] Initialized successfully") print(f"[LeonAgent] Workspace: {self.workspace_root}") print(f"[LeonAgent] Audit log: {self.enable_audit_log}") - if self._needs_async_init: + if self.checkpointer is None: print("[LeonAgent] Note: Async components need initialization via ainit()") - # Mark agent as ready (if not needing async init) - if not self._needs_async_init: + # Wire CleanupRegistry for priority-ordered resource teardown + self._cleanup_registry = CleanupRegistry() + self._cleanup_registry.register(self._cleanup_sandbox, priority=2) + self._cleanup_registry.register(self._mark_terminated, priority=3) + self._cleanup_registry.register(self._cleanup_mcp_client, priority=4) + self._cleanup_registry.register(self._cleanup_sqlite_connection, priority=5) + + # Mark agent as ready (checkpointer is None when async init still pending) + if self.checkpointer is not None: self._monitor_middleware.mark_ready() + def apply_forked_child_context( + self, + bootstrap: BootstrapConfig, + *, + tool_context: Any | None = None, + ) -> None: + # @@@subagent-fork-wiring + # AgentService should not reach through LeonAgent and mutate QueryLoop + # internals directly. Keep the child bootstrap + abort-controller wiring + # behind one explicit LeonAgent seam. + self._bootstrap = bootstrap + self.agent._bootstrap = bootstrap + if hasattr(self, "_agent_service"): + self._agent_service._parent_bootstrap = bootstrap + if tool_context is not None: + self._agent_service._parent_tool_context = tool_context + if tool_context is not None: + self.agent._tool_abort_controller = tool_context.abort_controller + async def ainit(self): """Complete async initialization (call this if initialized in async context). @@ -300,22 +400,23 @@ async def ainit(self): agent = LeonAgent(sandbox=sandbox) await agent.ainit() """ - if not self._needs_async_init: - return # Already initialized + if self.checkpointer is None: + # Initialize async components + self._aiosqlite_conn = await self._init_checkpointer() + _mcp_tools = await self._init_mcp_tools() + self._register_mcp_tools(_mcp_tools) - # Initialize async components - self._aiosqlite_conn = await self._init_checkpointer() - _mcp_tools = await self._init_mcp_tools() + # Update agent with checkpointer + self.agent.checkpointer = self.checkpointer - # Update agent with checkpointer - self.agent.checkpointer = self.checkpointer + self._monitor_middleware.mark_ready() - # Mark as initialized - self._needs_async_init = False - self._monitor_middleware.mark_ready() + if self.verbose: + print("[LeonAgent] Async initialization completed") - if self.verbose: - print("[LeonAgent] Async initialization completed") + if not self._session_started: + await self._run_session_hooks("SessionStart") + self._session_started = True def _init_async_components(self) -> tuple[Any, list]: """Initialize async components (checkpointer and MCP tools). @@ -350,13 +451,22 @@ def _has_middleware_tools(self, middleware: list) -> bool: """Check if any middleware has BaseTool instances.""" return any(getattr(m, "tools", None) for m in middleware) + def _register_mcp_tools(self, mcp_tools: list) -> None: + if not mcp_tools: + return + for tool in mcp_tools: + try: + self._tool_registry.register(_make_mcp_tool_entry(tool)) + except Exception as exc: + logger.warning("[LeonAgent] Failed to register MCP tool %s: %s", getattr(tool, "name", ""), exc) + def _create_placeholder_tool(self): - """Create placeholder tool to ensure ToolNode is created.""" + """Create placeholder tool so the bound model still has a BaseTool.""" from langchain_core.tools import tool @tool def _placeholder() -> str: - """Internal placeholder - ensures ToolNode is created for middleware tools.""" + """Internal placeholder for the empty-tool edge.""" return "" return _placeholder @@ -391,10 +501,26 @@ def _get_member_blocked_tools(self) -> set[str]: return blocked + def _get_mcp_server_configs(self) -> dict[str, Any]: + if hasattr(self, "_agent_bundle") and self._agent_bundle and self._agent_bundle.mcp: + return {name: srv for name, srv in self._agent_bundle.mcp.items() if not srv.disabled} + return self.config.mcp.servers + + def _get_mcp_instruction_blocks(self) -> dict[str, str]: + blocks: dict[str, str] = {} + for name, cfg in self._get_mcp_server_configs().items(): + instructions = getattr(cfg, "instructions", None) + if not isinstance(instructions, str) or not instructions.strip(): + continue + blocks[name] = instructions.strip() + return blocks + def _load_config( self, agent_name: str | None, + bundle_dir: str | Path | None, workspace_root: str | Path | None, + sandbox_name: str | None, model_name: str | None, api_key: str | None, allowed_file_extensions: list[str] | None, @@ -410,8 +536,14 @@ def _load_config( """ # Build CLI overrides for runtime config cli_overrides: dict = {} - - if workspace_root is not None: + use_workspace_override = sandbox_name in (None, "", "local") + + if workspace_root is not None and use_workspace_override: + # @@@remote-sandbox-config-root + # Remote child agents may inherit a sandbox cwd like /home/daytona, + # which is valid inside the sandbox but not on the host. Feeding that + # path into LeonSettings makes config validation fail before sandbox + # init ever runs, so only local sandboxes pin workspace_root here. cli_overrides["workspace_root"] = str(workspace_root) # Runtime overrides go into "runtime" section @@ -441,8 +573,14 @@ def _load_config( models_loader = ModelsLoader(workspace_root=workspace_root) models_config = models_loader.load(cli_overrides=models_cli if models_cli else None) + # @@@bundle-dir-wins - member-backed top-level agents need their own bundle even when + # no explicit agent type name is passed through the thread runtime wiring. + if bundle_dir is not None: + bundle_path = Path(bundle_dir).expanduser().resolve() + self._agent_bundle = loader.load_bundle(bundle_path) + self._agent_override = self._agent_bundle.agent.model_copy(update={"source_dir": bundle_path}) # If agent specified, load agent definition to override system_prompt and tools - if agent_name: + elif agent_name: all_agents = loader.load_all_agents() agent_def = all_agents.get(agent_name) if not agent_def: @@ -609,7 +747,16 @@ def _build_model_kwargs(self) -> dict: # Get credentials from the resolved provider p = self.models_config.get_provider(provider) if provider else None - base_url = (p.base_url if p else None) or self.models_config.get_base_url() + env_base_url = os.getenv("ANTHROPIC_BASE_URL") or os.getenv("OPENAI_BASE_URL") + + # @@@explicit-api-key-base-url + # Real-model verification must not be silently redirected to a provider + # config endpoint when the caller explicitly injected credentials for a + # different OpenAI-compatible endpoint. + if self._explicit_api_key and env_base_url: + base_url = env_base_url + else: + base_url = (p.base_url if p else None) or self.models_config.get_base_url() if base_url: kwargs["base_url"] = self._normalize_base_url(base_url, provider) @@ -714,12 +861,71 @@ def update_observation(self, **overrides) -> None: if self.verbose: print(f"[LeonAgent] Observation updated: active={self._observation_config.active}") - def close(self): - """Clean up resources.""" - self._cleanup_sandbox() - self._mark_terminated() - self._cleanup_mcp_client() - self._cleanup_sqlite_connection() + def close(self, *, cleanup_sandbox: bool = True): + """Clean up resources via CleanupRegistry (priority-ordered). + + Falls back to direct cleanup if CleanupRegistry is not initialized. + """ + # @@@close-idempotent - child agents may explicitly skip sandbox cleanup + # and later still hit __del__ on GC; never let a second close silently + # re-enable default sandbox teardown on a shared lease. + if getattr(self, "_closed", False) or getattr(self, "_closing", False): + return + + self._closing = True + session_end_error: Exception | None = None + try: + if getattr(self, "_session_started", False) and not getattr(self, "_session_ended", False): + try: + self._run_async_cleanup(lambda: self._run_session_hooks("SessionEnd"), "SessionEnd hooks") + except Exception as exc: + session_end_error = exc + finally: + self._session_ended = True + + if hasattr(self, "_cleanup_registry") and cleanup_sandbox: + self._run_async_cleanup(self._cleanup_registry.run_cleanup, "CleanupRegistry") + else: + # Fallback for edge cases where __init__ did not complete fully + cleanup_steps = [ + ("monitor", self._mark_terminated), + ("MCP client", self._cleanup_mcp_client), + ("SQLite connection", self._cleanup_sqlite_connection), + ] + if cleanup_sandbox: + cleanup_steps.insert(0, ("sandbox", self._cleanup_sandbox)) + + for step_name, step_fn in cleanup_steps: + try: + step_fn() + except Exception as e: + print(f"[LeonAgent] {step_name} cleanup error: {e}") + + if session_end_error is not None: + raise session_end_error + finally: + self._closed = True + self._closing = False + + def _build_session_hook_payload(self, event: str) -> dict[str, Any]: + return { + "event": event, + "session_id": self._bootstrap.session_id, + "workspace_root": str(self.workspace_root), + "cwd": str(self._bootstrap.cwd or self.workspace_root), + "sandbox": self._sandbox.name, + } + + async def _run_session_hooks(self, event: str) -> None: + hooks = self._app_state.get_session_hooks(event) + if not hooks: + return + + payload = self._build_session_hook_payload(event) + for hook in hooks: + result = hook(payload) + if inspect.isawaitable(result): + await result def _cleanup_sandbox(self) -> None: """Clean up sandbox resources.""" @@ -734,32 +940,29 @@ def _mark_terminated(self) -> None: if hasattr(self, "_monitor_middleware"): self._monitor_middleware.mark_terminated() + _CLEANUP_TIMEOUT: float = 10.0 # seconds; prevents hanging on stuck I/O + @staticmethod def _run_async_cleanup(coro_factory, label: str) -> None: import asyncio try: - running_loop = asyncio.get_running_loop() + asyncio.get_running_loop() except RuntimeError: - running_loop = None - - if running_loop is None: asyncio.run(coro_factory()) return - error: list[Exception] = [] - - def _runner() -> None: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro_factory()) try: - asyncio.run(coro_factory()) + future.result(timeout=LeonAgent._CLEANUP_TIMEOUT) + except concurrent.futures.TimeoutError: + raise RuntimeError( + f"{label} cleanup timed out after {LeonAgent._CLEANUP_TIMEOUT}s — " + f"possible stuck I/O; resource abandoned to prevent hang" + ) except Exception as exc: - error.append(exc) - - thread = threading.Thread(target=_runner, daemon=True) - thread.start() - thread.join() - if error: - raise RuntimeError(f"{label} cleanup failed: {error[0]}") from error[0] + raise RuntimeError(f"{label} cleanup failed: {exc}") from exc def _cleanup_mcp_client(self) -> None: """Clean up MCP client.""" @@ -767,35 +970,23 @@ def _cleanup_mcp_client(self) -> None: return try: - self._run_async_cleanup(lambda: self._mcp_client.close(), "MCP client") + close_fn = getattr(self._mcp_client, "close", None) + if callable(close_fn): + self._run_async_cleanup(close_fn, "MCP client") except Exception as e: print(f"[LeonAgent] MCP cleanup error: {e}") self._mcp_client = None def _cleanup_sqlite_connection(self) -> None: - """Clean up SQLite connection. - - Properly closes aiosqlite connection using asyncio.run() to avoid - hanging on process exit. - """ + """Clean up SQLite connection.""" if not hasattr(self, "_aiosqlite_conn") or not self._aiosqlite_conn: return - + conn = self._aiosqlite_conn + self._aiosqlite_conn = None try: - import asyncio - - # Close the connection asynchronously - async def _close(): - if self._aiosqlite_conn: - await self._aiosqlite_conn.close() - - # Use asyncio.run() to properly close the connection - asyncio.run(_close()) + self._run_async_cleanup(conn.close, "SQLite connection") except Exception: - # Ignore errors during cleanup pass - finally: - self._aiosqlite_conn = None def __del__(self): self.close() @@ -830,11 +1021,19 @@ def _build_middleware_stack(self) -> list: if memory_enabled: self._add_memory_middleware(middleware) - # 4. Steering — injects queued messages before model call + # 4. MCP instructions delta — thread-scoped reminder when MCP guidance changes + middleware.append( + McpInstructionsDeltaMiddleware( + get_instruction_blocks=self._get_mcp_instruction_blocks, + get_app_state=lambda: self.app_state, + ) + ) + + # 5. Steering — injects queued messages before model call self._steering_middleware = SteeringMiddleware(queue_manager=self.queue_manager) middleware.append(self._steering_middleware) - # 5. ToolRunner (innermost — routes all ToolRegistry-registered tool calls) + # 6. ToolRunner (innermost — routes all ToolRegistry-registered tool calls) self._tool_runner = ToolRunner( registry=self._tool_registry, validator=ToolValidator(), @@ -843,7 +1042,7 @@ def _build_middleware_stack(self) -> list: # 0. SpillBuffer (outermost — catches oversized tool outputs) # Must be inserted at index 0 AFTER building the list: - # LangChain wraps middlewares as "first = outermost". + # QueryLoop composes middleware so the first entry remains outermost. if self.config.tools.spill_buffer.enabled: spill_cfg = self.config.tools.spill_buffer middleware.insert( @@ -993,6 +1192,17 @@ def _init_services(self) -> None: workspace_root=self.workspace_root, ) + # Cron tools (DEFERRED - backed by existing panel cron_jobs substrate) + self._cron_tool_service = CronToolService( + registry=self._tool_registry, + ) + + self._mcp_resource_tool_service = McpResourceToolService( + registry=self._tool_registry, + client_fn=lambda: getattr(self, "_mcp_client", None), + server_configs_fn=self._get_mcp_server_configs, + ) + # ToolSearch (INLINE - always available for discovering DEFERRED tools) self._tool_search_service = ToolSearchService( registry=self._tool_registry, @@ -1005,8 +1215,13 @@ def _init_services(self) -> None: agent_registry=self._agent_registry, workspace_root=self.workspace_root, model_name=self.model_name, + thread_repo=self._thread_repo, + entity_repo=self._entity_repo, + member_repo=self._member_repo, queue_manager=self.queue_manager, shared_runs=self._background_runs, + web_app=self._web_app, + child_agent_factory=create_leon_agent, ) # Team coordination (TeamCreate/TeamDelete — deferred mode) @@ -1046,28 +1261,17 @@ def _init_services(self) -> None: runtime_fn=lambda: getattr(self, "runtime", None), ) - # @@@wechat-tools — register WeChat tools via lazy connection lookup - owner_uid = self._chat_repos.get("owner_user_id", "") if self._chat_repos else "" - if owner_uid: - try: - from core.tools.wechat.service import WeChatToolService - - def _get_wechat_conn(uid=owner_uid): - """Lazy lookup — returns None if registry not on app.state yet.""" - try: - from backend.web.main import app - - registry = getattr(app.state, "wechat_registry", None) - return registry.get(uid) if registry else None - except Exception: - return None + # LSP tools — DEFERRED, always registered, multilspy checked at call time + self._lsp_service = None + try: + from core.tools.lsp.service import LSPService - self._wechat_tool_service = WeChatToolService( - registry=self._tool_registry, - connection_fn=_get_wechat_conn, - ) - except ImportError: - self._wechat_tool_service = None + self._lsp_service = LSPService( + registry=self._tool_registry, + workspace_root=self.workspace_root, + ) + except Exception as e: + logger.debug("[LeonAgent] LSPService init skipped: %s", e) if self.verbose: all_tools = self._tool_registry.list_all() @@ -1078,11 +1282,7 @@ def _get_wechat_conn(uid=owner_uid): async def _init_mcp_tools(self) -> list: mcp_enabled = self.config.mcp.enabled - # Use member bundle MCP config if available, else fall back to global config - if hasattr(self, "_agent_bundle") and self._agent_bundle and self._agent_bundle.mcp: - mcp_servers = {name: srv for name, srv in self._agent_bundle.mcp.items() if not srv.disabled} - else: - mcp_servers = self.config.mcp.servers + mcp_servers = self._get_mcp_server_configs() if not mcp_enabled or not mcp_servers: return [] @@ -1091,10 +1291,21 @@ async def _init_mcp_tools(self) -> list: configs = {} for name, cfg in mcp_servers.items(): + transport = getattr(cfg, "transport", None) if cfg.url: - config = {"transport": "streamable_http", "url": cfg.url} + # @@@mcp-transport-honesty - api-04 requires explicit transport + # config to survive loader -> runtime. URL-based MCP is not + # always streamable_http; websocket/sse must stay explicit. + config = { + "transport": transport or "streamable_http", + "url": cfg.url, + } else: - config = {"transport": "stdio", "command": cfg.command, "args": cfg.args} + config = { + "transport": transport or "stdio", + "command": cfg.command, + "args": cfg.args, + } if cfg.env: config["env"] = cfg.env configs[name] = config @@ -1190,155 +1401,105 @@ def _build_system_prompt(self) -> str: return prompt - def _build_context_section(self) -> str: - """Build the context section based on sandbox mode.""" - if self._sandbox.name != "local": - env_label = self._sandbox.env_label - working_dir = self._sandbox.working_dir - if self._sandbox.name == "docker": - mode_label = "Sandbox (isolated local container)" - else: - mode_label = "Sandbox (isolated cloud environment)" - return f"""- Environment: {env_label} -- Working Directory: {working_dir} -- Mode: {mode_label}""" - else: - import platform - - os_name = platform.system() - if os_name == "Windows": - shell_name = "powershell" - else: - shell_name = os.environ.get("SHELL", "/bin/bash").split("/")[-1] - return f"""- Workspace: `{self.workspace_root}` -- OS: {os_name} -- Shell: {shell_name} -- Mode: Local""" + def _compose_system_prompt(self) -> str: + prompt = self._build_system_prompt() - def _build_rules_section(self) -> str: - """Build shared rules section for all modes.""" - is_sandbox = self._sandbox.name != "local" - working_dir = self._sandbox.working_dir if is_sandbox else self.workspace_root + custom_prompt = self.config.system_prompt + if custom_prompt: + prompt += f"\n\n**Custom Instructions:**\n{custom_prompt}" - rules = [] + # @@@entity-identity — inject chat identity so agent knows who it is in the social layer + if self._chat_repos: + repos = self._chat_repos + uid = repos.get("user_id") + owner_uid = repos.get("owner_user_id", "") + if uid: + entity_repo = repos.get("entity_repo") + member_repo = repos.get("member_repo") + entity = entity_repo.get_by_id(uid) if entity_repo else None + self_member = member_repo.get_by_id(uid) if member_repo and not entity else None + owner_row = member_repo.get_by_id(owner_uid) if member_repo and owner_uid else None + name = entity.name if entity else (self_member.name if self_member else uid) + owner_name = owner_row.name if owner_row else "unknown" + prompt += ( + f"\n\n**Chat Identity:**\n" + f"- Your name: {name}\n" + f"- Your user_id: {uid}\n" + f"- Your owner: {owner_name} (user_id: {owner_uid})\n" + f"- When you receive a chat notification, you MUST read it with read_messages() before deciding what to do.\n" + f"- If that notification already gives you a chat_id, prefer using that exact chat_id directly.\n" + f"- If you reply to the other party, you MUST call send_message(). Never claim you replied unless send_message() succeeded.\n" + f"- Your normal text output goes to your owner's thread, not to the chat — only send_message() delivers to the other party.\n" + ) + return prompt - # Rule 1: Environment-specific - if is_sandbox: - if self._sandbox.name == "docker": - location_rule = "All file and command operations run in a local Docker container, NOT on the user's host filesystem." - else: - location_rule = "All file and command operations run in a remote sandbox, NOT on the user's local machine." - rules.append(f"1. **Sandbox Environment**: {location_rule} The sandbox is an isolated Linux environment.") - else: - rules.append("1. **Workspace**: File operations are restricted to: " + str(self.workspace_root)) + def _invalidate_system_prompt_cache(self) -> None: + self._system_prompt_section_cache.clear() - # Rule 2: Absolute paths - rules.append(f"""2. **Absolute Paths**: All file paths must be absolute paths. - - ✅ Correct: `{working_dir}/project/test.py` - - ❌ Wrong: `test.py` or `./test.py`""") + def _get_cached_prompt_section(self, key: str, builder) -> str: + cached = self._system_prompt_section_cache.get(key) + if cached is not None: + return cached + value = builder() + self._system_prompt_section_cache[key] = value + return value - # Rule 3: Security - if is_sandbox: - rules.append("3. **Security**: The sandbox is isolated. You can install packages, run any commands, and modify files freely.") - else: - rules.append("3. **Security**: Dangerous commands are blocked. All operations are logged.") + def _build_context_section(self) -> str: + from core.runtime.prompts import build_context_section + + def _build() -> str: + is_sandbox = self._sandbox.name != "local" + if is_sandbox: + return build_context_section( + sandbox_name=self._sandbox.name, + sandbox_env_label=self._sandbox.env_label, + sandbox_working_dir=self._sandbox.working_dir, + ) + import platform - # Rule 4: Tool priority - rules.append( - """4. **Tool Priority**: When a built-in tool and an MCP tool (`mcp__*`) have the same functionality, use the built-in tool.""" - ) + os_name = platform.system() + shell_name = "powershell" if os_name == "Windows" else os.environ.get("SHELL", "/bin/bash").split("/")[-1] + return build_context_section( + sandbox_name="local", + workspace_root=str(self.workspace_root), + os_name=os_name, + shell_name=shell_name, + ) - # Rule 5: Dedicated tools over shell - rules.append("""5. **Use Dedicated Tools Instead of Shell Commands**: Do NOT use `Bash` for tasks that have dedicated tools: - - File search → use `Grep` (NOT `rg`, `grep`, or `find` via Bash) - - File listing → use `Glob` (NOT `find` or `ls` via Bash) - - File reading → use `Read` (NOT `cat`, `head`, `tail` via Bash) - - File editing → use `Edit` (NOT `sed` or `awk` via Bash) - - Reserve `Bash` for: git, package managers, build tools, tests, and other system operations.""") + return self._get_cached_prompt_section("context", _build) - # Rule 6: Background task description - rules.append("""6. **Background Task Description**: When using `Bash` or `Agent` with `run_in_background: true`, always include a clear `description` parameter. # noqa: E501 - - The description is shown to the user in the background task indicator. - - Keep it concise (5–10 words), action-oriented, e.g. "Run test suite", "Analyze API codebase". - - Without a description, the raw command or agent name is shown, which is hard to read.""") + def _build_rules_section(self) -> str: + from core.runtime.prompts import build_rules_section + + def _build() -> str: + is_sandbox = self._sandbox.name != "local" + working_dir = self._sandbox.working_dir if is_sandbox else str(self.workspace_root) + return build_rules_section( + is_sandbox=is_sandbox, + sandbox_name=self._sandbox.name, + working_dir=working_dir, + workspace_root=str(self.workspace_root), + spill_buffer_enabled=self.config.tools.spill_buffer.enabled, + spill_keep_recent=self.config.memory.pruning.protect_recent, + ) - return "\n\n".join(rules) + return self._get_cached_prompt_section("rules", _build) def _build_base_prompt(self) -> str: - """Build the base system prompt (context + rules), shared by all modes.""" - context = self._build_context_section() - rules = self._build_rules_section() + from core.runtime.prompts import build_base_prompt - return f"""You are a highly capable AI assistant with access to file and system tools. - -**Context:** -{context} - -**Important Rules:** - -{rules} -""" + return self._get_cached_prompt_section( + "base_prompt", + lambda: build_base_prompt(self._build_context_section(), self._build_rules_section()), + ) def _build_common_prompt_sections(self) -> str: - """Build common prompt sections for both sandbox and local modes.""" - prompt = """ -**Agent Tool (Sub-agent Orchestration):** - -Use the Agent tool to launch specialized sub-agents for complex tasks: -- `explore`: Read-only codebase exploration. Use for: finding files, searching code, understanding implementations. -- `plan`: Design implementation plans. Use for: architecture decisions, multi-step planning. -- `bash`: Execute shell commands. Use for: git operations, running tests, system commands. -- `general`: Full tool access. Use for: independent multi-step tasks requiring file modifications. - -When to use Agent: -- Open-ended searches that may require multiple rounds of exploration -- Tasks that can run independently while you continue other work -- Complex operations that benefit from specialized focus - -When NOT to use Agent: -- Simple file reads (use Read directly) -- Specific searches with known patterns (use Grep directly) -- Quick operations that don't need isolation - -**Todo Tools (Task Management):** - -Use Todo tools to track progress on complex, multi-step tasks: -- `TaskCreate`: Create a new task with subject, description, and activeForm (present continuous for spinner) -- `TaskList`: View all tasks and their status -- `TaskGet`: Get full details of a specific task -- `TaskUpdate`: Update task status (pending → in_progress → completed) or details - -When to use Todo: -- Complex tasks with 3+ distinct steps -- When the user provides multiple tasks to complete -- To show progress on non-trivial work - -When NOT to use Todo: -- Single, straightforward tasks -- Trivial operations that don't need tracking -""" + from core.runtime.prompts import build_common_sections - # Add Skills section if skills are enabled - skills_enabled = self.config.skills.enabled and self.config.skills.paths - - if skills_enabled: - prompt += """ -**Skills (Specialized Knowledge):** - -Use the `load_skill` tool to access specialized domain knowledge and workflows: -- Skills provide focused instructions for specific tasks (e.g., TDD, debugging, git workflows) -- Call `load_skill(skill_name)` to load a skill's content into context -- Available skills are listed in the load_skill tool description - -When to use load_skill: -- When you need specialized guidance for a specific workflow -- To access domain-specific best practices -- When the user mentions a skill by name (e.g., "use TDD skill") - -Progressive disclosure: Skills are loaded on-demand to save tokens. -""" - - return prompt + return self._get_cached_prompt_section( + "common_sections", + lambda: build_common_sections(bool(self.config.skills.enabled and self.config.skills.paths)), + ) def invoke(self, message: str, thread_id: str = "default") -> dict: """Invoke agent with a message (sync version). @@ -1388,6 +1549,174 @@ async def ainvoke(self, message: str, thread_id: str = "default") -> dict: self._monitor_middleware.mark_error(e) raise + async def astream( + self, + message: str, + thread_id: str = "default", + stream_mode: str | list[str] = "updates", + max_budget_usd: float | None = None, + ): + """Stream agent output through a caller-owned LeonAgent surface.""" + try: + async for chunk in self.agent.astream( + {"messages": [{"role": "user", "content": message}]}, + config={"configurable": {"thread_id": thread_id}}, + stream_mode=stream_mode, + ): + yield chunk + if max_budget_usd is not None and self.runtime.cost > max_budget_usd: + raise RuntimeError(f"max_budget_usd exceeded: cost={self.runtime.cost:.6f} budget={max_budget_usd:.6f}") + except Exception as e: + self._monitor_middleware.mark_error(e) + raise + + async def aclear_thread(self, thread_id: str = "default") -> None: + """Clear turn-scoped state for a thread while preserving session accumulators.""" + try: + await self.agent.aclear(thread_id) + self._invalidate_system_prompt_cache() + self.system_prompt = self._compose_system_prompt() + self.agent.system_prompt = SystemMessage(content=[{"type": "text", "text": self.system_prompt}]) + except Exception as e: + self._monitor_middleware.mark_error(e) + raise + + def clear_thread(self, thread_id: str = "default") -> None: + """Sync wrapper for aclear_thread().""" + import asyncio + + async def _aclear(): + await self.aclear_thread(thread_id) + + try: + if hasattr(self, "_event_loop") and self._event_loop: + self._event_loop.run_until_complete(_aclear()) + else: + asyncio.run(_aclear()) + except Exception as e: + self._monitor_middleware.mark_error(e) + raise + + def get_pending_permission_requests(self, thread_id: str | None = None) -> list[dict]: + requests = list(self._app_state.pending_permission_requests.values()) + if thread_id is not None: + requests = [item for item in requests if item.get("thread_id") == thread_id] + return requests + + def get_thread_permission_rules(self, thread_id: str | None = None) -> dict[str, Any]: + state = self._app_state.tool_permission_context + return { + "thread_id": thread_id, + "scope": "session", + "managed_only": state.allowManagedPermissionRulesOnly, + "rules": { + "allow": list(state.alwaysAllowRules.get("session", [])), + "deny": list(state.alwaysDenyRules.get("session", [])), + "ask": list(state.alwaysAskRules.get("session", [])), + }, + } + + def add_thread_permission_rule(self, thread_id: str, *, behavior: str, tool_name: str) -> bool: + if self._app_state.tool_permission_context.allowManagedPermissionRulesOnly: + return False + + def _update(state: AppState) -> AppState: + permission_state = state.tool_permission_context.model_copy(deep=True) + for bucket in ( + permission_state.alwaysAllowRules.setdefault("session", []), + permission_state.alwaysDenyRules.setdefault("session", []), + permission_state.alwaysAskRules.setdefault("session", []), + ): + while tool_name in bucket: + bucket.remove(tool_name) + target_bucket = { + "allow": permission_state.alwaysAllowRules.setdefault("session", []), + "deny": permission_state.alwaysDenyRules.setdefault("session", []), + "ask": permission_state.alwaysAskRules.setdefault("session", []), + }[behavior] + if tool_name not in target_bucket: + target_bucket.append(tool_name) + return state.model_copy(update={"tool_permission_context": permission_state}) + + self._app_state.set_state(_update) + return True + + def remove_thread_permission_rule(self, thread_id: str, *, behavior: str, tool_name: str) -> bool: + removed = False + + def _update(state: AppState) -> AppState: + nonlocal removed + permission_state = state.tool_permission_context.model_copy(deep=True) + bucket = { + "allow": permission_state.alwaysAllowRules.setdefault("session", []), + "deny": permission_state.alwaysDenyRules.setdefault("session", []), + "ask": permission_state.alwaysAskRules.setdefault("session", []), + }[behavior] + if tool_name in bucket: + bucket.remove(tool_name) + removed = True + return state.model_copy(update={"tool_permission_context": permission_state}) + + self._app_state.set_state(_update) + return removed + + def resolve_permission_request( + self, + request_id: str, + *, + decision: str, + message: str | None = None, + answers: list[dict[str, Any]] | None = None, + annotations: dict[str, Any] | None = None, + ) -> bool: + pending = self._app_state.pending_permission_requests.get(request_id) + if pending is None: + return False + + resolved = dict(self._app_state.resolved_permission_requests) + payload = { + **pending, + "decision": decision, + "message": message or pending.get("message"), + } + if answers is not None: + payload["answers"] = answers + if annotations is not None: + payload["annotations"] = annotations + resolved[request_id] = payload + still_pending = dict(self._app_state.pending_permission_requests) + still_pending.pop(request_id, None) + self._app_state.set_state( + lambda prev: prev.model_copy( + update={ + "pending_permission_requests": still_pending, + "resolved_permission_requests": resolved, + } + ) + ) + return True + + def drop_permission_request(self, request_id: str) -> bool: + had_pending = request_id in self._app_state.pending_permission_requests + had_resolved = request_id in self._app_state.resolved_permission_requests + if not had_pending and not had_resolved: + return False + + def _drop(state: AppState) -> AppState: + pending = dict(state.pending_permission_requests) + resolved = dict(state.resolved_permission_requests) + pending.pop(request_id, None) + resolved.pop(request_id, None) + return state.model_copy( + update={ + "pending_permission_requests": pending, + "resolved_permission_requests": resolved, + } + ) + + self._app_state.set_state(_drop) + return True + def get_response(self, message: str, thread_id: str = "default", **kwargs) -> str: """Get agent's text response. @@ -1411,7 +1740,7 @@ def cleanup(self): def create_leon_agent( - model_name: str = DEFAULT_MODEL, + model_name: str | None = None, api_key: str | None = None, workspace_root: str | Path | None = None, sandbox: Any = None, @@ -1421,7 +1750,7 @@ def create_leon_agent( """Create Leon Agent. Args: - model_name: Model name + model_name: Model name. None means "let LeonAgent resolve defaults". api_key: API key workspace_root: Workspace directory sandbox: Sandbox instance, name string, or None for local diff --git a/core/runtime/cleanup.py b/core/runtime/cleanup.py new file mode 100644 index 000000000..d55600684 --- /dev/null +++ b/core/runtime/cleanup.py @@ -0,0 +1,116 @@ +"""CleanupRegistry — priority-ordered async cleanup for LeonAgent lifecycle. + +Aligned with CC Pattern 5: Lifecycle & Cleanup. +Priority numbers: lower = runs first. +""" + +from __future__ import annotations + +import asyncio +import logging +import signal +from collections.abc import Awaitable, Callable +from itertools import groupby + +logger = logging.getLogger(__name__) + + +class CleanupRegistry: + """Registry of async cleanup functions executed in priority order on shutdown. + + Usage: + registry = CleanupRegistry() + registry.register(close_db, priority=1) + registry.register(close_sandbox, priority=2) + await registry.run_cleanup() + """ + + def __init__(self): + # List of (priority, fn) — not a dict because same priority can have multiple fns + self._entries: list[tuple[int, Callable[[], Awaitable[None] | None]]] = [] + self._timeout_s = 2.0 + self._cleanup_task: asyncio.Task[None] | None = None + self._shutdown_in_progress = False + self._signal_loop: asyncio.AbstractEventLoop | None = None + self._setup_signal_handlers() + + def register(self, fn: Callable[[], Awaitable[None] | None], priority: int = 5) -> Callable[[], None]: + """Register a cleanup function. + + Args: + fn: Sync or async callable that releases resources. + priority: Execution order — lower number runs first (1 before 2). + """ + entry = (priority, fn) + self._entries.append(entry) + + def unregister() -> None: + try: + self._entries.remove(entry) + except ValueError: + return + + return unregister + + async def run_cleanup(self) -> None: + """Execute all registered cleanup functions in priority order. + + Different priority tiers run in order. Entries inside the same priority + tier run concurrently so one slow cleanup does not serialize its peers. + """ + if self._cleanup_task is not None: + await asyncio.shield(self._cleanup_task) + return + + async def _run_all() -> None: + sorted_entries = sorted(self._entries, key=lambda x: x[0]) + for priority, grouped_entries in groupby(sorted_entries, key=lambda x: x[0]): + await asyncio.gather( + *(self._run_entry(priority, fn) for _, fn in grouped_entries), + return_exceptions=True, + ) + + self._shutdown_in_progress = True + self._cleanup_task = asyncio.create_task(_run_all()) + await asyncio.shield(self._cleanup_task) + + def is_shutting_down(self) -> bool: + return self._shutdown_in_progress + + async def _run_entry(self, priority: int, fn: Callable[[], Awaitable[None] | None]) -> None: + try: + result = fn() + if asyncio.iscoroutine(result): + await asyncio.wait_for(result, timeout=self._timeout_s) + except TimeoutError: + logger.warning("CleanupRegistry: cleanup fn %s timed out after %.2fs", fn, self._timeout_s) + except Exception: + logger.exception("CleanupRegistry: error in cleanup fn %s (priority=%d)", fn, priority) + + def _setup_signal_handlers(self) -> None: + """Register SIGINT/SIGTERM handlers to trigger async cleanup.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + return # No running loop yet — signal handlers set up later + self._signal_loop = loop + + signals = [signal.SIGINT, signal.SIGTERM] + if hasattr(signal, "SIGHUP"): + signals.append(signal.SIGHUP) + + for sig in signals: + try: + loop.add_signal_handler(sig, self._handle_signal) + except (NotImplementedError, RuntimeError): + # Windows or non-main thread — skip signal handler setup + pass + + def _handle_signal(self) -> None: + loop = self._signal_loop + if loop is None: + return + if loop.is_running(): + loop.create_task(self.run_cleanup()) + return + loop.run_until_complete(self.run_cleanup()) diff --git a/core/runtime/errors.py b/core/runtime/errors.py index 74ffbfc1e..591ff3090 100644 --- a/core/runtime/errors.py +++ b/core/runtime/errors.py @@ -1,4 +1,13 @@ class InputValidationError(Exception): """Tool parameter validation failed.""" - pass + def __init__( + self, + message: str, + *, + error_code: str | None = None, + details: list[dict[str, object]] | None = None, + ) -> None: + super().__init__(message) + self.error_code = error_code + self.details = [] if details is None else details diff --git a/core/runtime/fork.py b/core/runtime/fork.py new file mode 100644 index 000000000..c3992cf74 --- /dev/null +++ b/core/runtime/fork.py @@ -0,0 +1,91 @@ +"""Context fork for sub-agent spawning. + +When a sub-agent is spawned, it inherits workspace/model/permission configuration +from the parent but gets its own isolated messages and session identity. + +Aligned with CC createSubagentContext() field-by-field fork table. +""" + +from __future__ import annotations + +import copy +import uuid + +from .abort import create_child_abort_controller +from .state import BootstrapConfig, ToolUseContext + + +def fork_context(parent: BootstrapConfig) -> BootstrapConfig: + """Create a child BootstrapConfig for a sub-agent. + + Inherits all workspace identity, model settings, and security flags + from parent. Generates a fresh session_id and sets parent_session_id. + Messages, cost, and turn_count live in AppState — not here. + """ + return BootstrapConfig( + workspace_root=parent.workspace_root, + original_cwd=parent.original_cwd, + project_root=parent.project_root, + cwd=parent.cwd, + model_name=parent.model_name, + api_key=parent.api_key, + sandbox_type=parent.sandbox_type, + block_dangerous_commands=parent.block_dangerous_commands, + block_network_commands=parent.block_network_commands, + enable_audit_log=parent.enable_audit_log, + enable_web_tools=parent.enable_web_tools, + allowed_file_extensions=parent.allowed_file_extensions, + extra_allowed_paths=parent.extra_allowed_paths, + max_turns=parent.max_turns, + # Fresh session identity + session_id=uuid.uuid4().hex, + parent_session_id=parent.session_id, + total_cost_usd=parent.total_cost_usd, + total_tool_duration_ms=parent.total_tool_duration_ms, + # Model settings + model_provider=parent.model_provider, + base_url=parent.base_url, + context_limit=parent.context_limit, + ) + + +def create_subagent_context( + parent: ToolUseContext, + *, + share_set_app_state: bool = False, +) -> ToolUseContext: + """Create a minimally isolated ToolUseContext for sub-agents. + + Default contract: + - bootstrap: fresh fork + - set_app_state: NO-OP + - set_app_state_for_tasks: always reaches the root/session store + - turn-local refs: fresh + - file cache/messages: cloned snapshots + """ + read_file_state = parent.read_file_state + if hasattr(read_file_state, "clone") and callable(read_file_state.clone): + cloned_read_file_state = read_file_state.clone() + else: + # @@@sa-04-read-file-state-clone + # Subagent fork boundaries must isolate nested file cache state too; + # a shallow dict copy leaks child edits back into the parent cache. + cloned_read_file_state = copy.deepcopy(read_file_state) + return ToolUseContext( + bootstrap=fork_context(parent.bootstrap), + get_app_state=parent.get_app_state, + set_app_state=parent.set_app_state if share_set_app_state else (lambda updater: None), + set_app_state_for_tasks=parent.set_app_state_for_tasks or parent.set_app_state, + refresh_tools=parent.refresh_tools, + can_use_tool=parent.can_use_tool, + request_permission=parent.request_permission, + consume_permission_resolution=parent.consume_permission_resolution, + read_file_state=cloned_read_file_state, + loaded_nested_memory_paths=set(), + discovered_skill_names=set(), + discovered_tool_names=set(), + nested_memory_attachment_triggers=set(), + abort_controller=create_child_abort_controller(getattr(parent, "abort_controller", None)), + messages=list(parent.messages), + thread_id=parent.thread_id, + ) diff --git a/core/runtime/loop.py b/core/runtime/loop.py new file mode 100644 index 000000000..f27527e29 --- /dev/null +++ b/core/runtime/loop.py @@ -0,0 +1,2158 @@ +"""QueryLoop — self-managing agentic tool loop replacing LangGraph create_agent. + +Implements CC Pattern 1: Agentic Tool Loop (queryLoop). + +Design: +- AsyncGenerator that alternates LLM sampling and tool execution. +- Exposes the same .astream(input, config, stream_mode) interface as CompiledStateGraph. +- Middleware chain (SpillBuffer/Monitor/PromptCaching/Memory/Steering/ToolRunner) is + preserved exactly — awrap_model_call and awrap_tool_call pass through in order. +- is_concurrency_safe tools execute in parallel; others execute serially. +- Checkpointer (AsyncSqliteSaver) stores/restores message history across calls. +""" + +from __future__ import annotations + +import asyncio +import copy +import inspect +import json +import logging +import re +import uuid +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from enum import StrEnum +from types import SimpleNamespace +from typing import Any + +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, RemoveMessage, SystemMessage, ToolMessage + +from core.runtime.middleware import ( + AgentMiddleware, + ModelRequest, + ModelResponse, + ToolCallRequest, +) + +from .abort import AbortController +from .permissions import ToolPermissionContext, evaluate_permission_rules +from .registry import ToolMode, ToolRegistry +from .state import AppState, BootstrapConfig, ToolPermissionState, ToolUseContext +from .validator import _required_sets_match + +logger = logging.getLogger(__name__) + +_NOOP_HANDLER: Any = None # placeholder for innermost "handler" in middleware chain +_ESCALATED_MAX_OUTPUT_TOKENS = 64000 +_FLOOR_OUTPUT_TOKENS = 3000 +_CONTEXT_OVERFLOW_SAFETY_BUFFER = 1000 +_TRANSIENT_API_MAX_RETRIES = 3 +_TRANSIENT_API_BASE_DELAY_SECONDS = 0.5 +_PROMPT_TOO_LONG_NOTICE_TEXT = "Prompt is too long. Automatic recovery exhausted. Clear the thread or start a new one." + + +class TerminalReason(StrEnum): + completed = "completed" + aborted_streaming = "aborted_streaming" + aborted_tools = "aborted_tools" + model_error = "model_error" + max_turns = "max_turns" + prompt_too_long = "prompt_too_long" + blocking_limit = "blocking_limit" + image_error = "image_error" + hook_stopped = "hook_stopped" + stop_hook_prevented = "stop_hook_prevented" + + +class ContinueReason(StrEnum): + next_turn = "next_turn" + api_retry = "api_retry" + collapse_drain_retry = "collapse_drain_retry" + reactive_compact_retry = "reactive_compact_retry" + max_output_tokens_escalate = "max_output_tokens_escalate" + max_output_tokens_recovery = "max_output_tokens_recovery" + stop_hook_blocking = "stop_hook_blocking" + token_budget_continuation = "token_budget_continuation" + + +@dataclass(frozen=True) +class TerminalState: + reason: TerminalReason + turn_count: int + error: str | None = None + + +@dataclass(frozen=True) +class ContinueState: + reason: ContinueReason + + +@dataclass(frozen=True) +class _ModelErrorRecoveryResult: + messages: list + transition: ContinueState | None + max_output_tokens_recovery_count: int + has_attempted_reactive_compact: bool + max_output_tokens_override: int | None + transient_api_retry_count: int + terminal: TerminalState | None + + +@dataclass +class _TrackedTool: + order: int + tool_call: dict[str, Any] + is_concurrency_safe: bool + status: str = "queued" + task: asyncio.Task[ToolMessage] | None = None + result: ToolMessage | None = None + + +class QueryLoop: + """Self-managing query loop replacing create_agent. + + The .astream() method is an AsyncGenerator that yields dicts compatible + with LangGraph's stream_mode="updates": + {"agent": {"messages": [AIMessage(...)]}} + {"tools": {"messages": [ToolMessage(...), ...]}} + + The checkpointer attribute is set post-construction (mirrors create_agent pattern). + """ + + def __init__( + self, + model: Any, + system_prompt: SystemMessage, + middleware: list[AgentMiddleware], + checkpointer: Any, + registry: ToolRegistry, + app_state: AppState | None = None, + runtime: Any = None, + bootstrap: BootstrapConfig | None = None, + refresh_tools: Any = None, + max_turns: int = 100, + ): + self.model = model + self.system_prompt = system_prompt + self.middleware = middleware + self.checkpointer = checkpointer + self._registry = registry + self._app_state = app_state + self._runtime = runtime + self._bootstrap = bootstrap + self._refresh_tools = refresh_tools + self._memory_middleware = next( + (mw for mw in middleware if hasattr(mw, "compact_boundary_index")), + None, + ) + # @@@sa-02-session-tool-refs + # These refs must survive across turns within the same loop/session, + # while turn-local attachment triggers stay ephemeral per ToolUseContext. + self._tool_read_file_state: dict[str, Any] = {} + self._tool_loaded_nested_memory_paths: set[str] = set() + self._tool_discovered_skill_names: set[str] = set() + self._tool_discovered_tool_names_by_thread: dict[str, set[str]] = {} + self._tool_abort_controller = AbortController() + self.max_turns = max_turns + self.last_terminal: TerminalState | None = None + self.last_continue: ContinueState | None = None + + # ------------------------------------------------------------------------- + # Public streaming interface (LangGraph-compatible) + # ------------------------------------------------------------------------- + + async def query( + self, + input: dict, + config: dict | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """Raw loop generator with an explicit final terminal event.""" + config = config or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + + # Set thread context so MemoryMiddleware can find thread_id via ContextVar + from sandbox.thread_context import set_current_thread_id + + set_current_thread_id(thread_id) + + # Load message history and thread-scoped runtime state from checkpointer + persisted = await self._hydrate_thread_state_from_checkpoint(thread_id) + messages = list(persisted["messages"]) + self._restore_discovered_tool_names_from_messages(thread_id, messages) + + # Parse and append new input messages + new_msgs = self._parse_input(input) + messages.extend(new_msgs) + self._sync_app_state(messages=messages, turn_count=0) + + terminal: TerminalState | None = None + transition: ContinueState | None = None + pending_system_notices: list[HumanMessage] = [] + max_output_tokens_recovery_count = 0 + has_attempted_reactive_compact = False + max_output_tokens_override: int | None = None + transient_api_retry_count = 0 + + turn = 0 + try: + while turn < self.max_turns: + turn += 1 + tool_context = self._build_tool_use_context(messages, thread_id=thread_id) + + messages_for_query, injected_messages = await self._build_query_messages(messages, config) + if injected_messages: + # @@@steer-persist - queue/steer messages accepted before the + # next model call must become durable conversation state, not + # request-only hints, or later replay/history lies about what + # the user actually said mid-run. + messages.extend(injected_messages) + self._sync_app_state(messages=messages, turn_count=turn) + self._sync_tool_context_messages(tool_context, messages_for_query) + + # --- Call model through middleware chain --- + streamed_tool_results: list[ToolMessage] = [] + pending_tool_results: list[ToolMessage] = [] + used_streaming_overlap = False + response: ModelResponse | None = None + ai_msg: AIMessage | None = None + tool_calls: list[dict[str, Any]] = [] + try: + if self._can_stream_tools(): + used_streaming_overlap = True + async for stream_event in self._stream_model_with_tool_overlap( + messages_for_query, + config, + thread_id=thread_id, + tool_context=tool_context, + max_output_tokens_override=max_output_tokens_override, + ): + if stream_event["type"] == "message_chunk": + yield {"message_chunk": stream_event["chunk"]} + continue + if stream_event["type"] == "tools": + chunk_messages = stream_event["messages"] + streamed_tool_results.extend(chunk_messages) + yield {"tools": {"messages": chunk_messages}} + continue + response = stream_event["response"] + ai_msg = stream_event["ai_message"] + tool_calls = stream_event["tool_calls"] + pending_tool_results = stream_event["remaining_tool_results"] + else: + response = await self._invoke_model( + messages_for_query, + config, + thread_id=thread_id, + max_output_tokens_override=max_output_tokens_override, + ) + except Exception as exc: + self._collect_memory_system_notices(pending_system_notices) + handled = await self._handle_model_error_recovery( + exc=exc, + thread_id=thread_id, + messages=messages, + turn=turn, + transition=transition, + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + ) + if handled is not None: + messages = handled.messages + transition = handled.transition + max_output_tokens_recovery_count = handled.max_output_tokens_recovery_count + has_attempted_reactive_compact = handled.has_attempted_reactive_compact + max_output_tokens_override = handled.max_output_tokens_override + transient_api_retry_count = handled.transient_api_retry_count + if handled.terminal is not None: + terminal = handled.terminal + break + self._sync_app_state(messages=messages, turn_count=turn) + continue + terminal = TerminalState( + reason=TerminalReason.model_error, + turn_count=turn, + error=str(exc), + ) + break + + if response is None or ai_msg is None: + ai_messages = [m for m in (response.result if response else []) if isinstance(m, AIMessage)] + if not ai_messages: + # No AI message — unexpected; treat as terminal + terminal = TerminalState( + reason=TerminalReason.model_error, + turn_count=turn, + error="model returned no AIMessage", + ) + break + ai_msg = ai_messages[0] + self._collect_memory_system_notices(pending_system_notices) + self._sync_tool_context_messages( + tool_context, + response.request_messages or messages_for_query, + ) + + truncated = self._handle_truncated_response_recovery( + ai_msg=ai_msg, + messages=messages, + turn=turn, + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + max_output_tokens_override=max_output_tokens_override, + ) + if truncated is not None: + messages = truncated["messages"] + transition = truncated["transition"] + max_output_tokens_recovery_count = truncated["max_output_tokens_recovery_count"] + max_output_tokens_override = truncated["max_output_tokens_override"] + self._sync_app_state(messages=messages, turn_count=turn) + if truncated["yield_ai"]: + yield {"agent": {"messages": [ai_msg]}} + if truncated["terminal"] is not None: + terminal = truncated["terminal"] + break + continue + + self._sync_app_state(messages=messages, turn_count=turn) + + if not tool_calls: + tool_calls = getattr(ai_msg, "tool_calls", None) or [] + if not tool_calls: + # Also check additional_kwargs for older message formats + tool_calls = ai_msg.additional_kwargs.get("tool_calls", []) + + if not tool_calls and not self._ai_message_has_visible_content(ai_msg): + terminal_followthrough_notice = self._get_terminal_followthrough_notice(messages) + if terminal_followthrough_notice is not None: + ai_msg = self._build_terminal_followthrough_fallback(terminal_followthrough_notice) + else: + chat_followthrough_notice = self._get_chat_followthrough_notice(messages) + if chat_followthrough_notice is not None: + ai_msg = self._build_chat_followthrough_fallback(chat_followthrough_notice) + + # Yield agent update (stream_mode="updates" format) + yield {"agent": {"messages": [ai_msg]}} + + if not tool_calls: + # No tool calls → agent is done + if self._ai_message_has_visible_content(ai_msg): + messages.append(ai_msg) + terminal = TerminalState( + reason=TerminalReason.completed, + turn_count=turn, + ) + break + + # Expose current messages for forkContext sub-agent spawning + from sandbox.thread_context import set_current_messages + + set_current_messages(messages + [ai_msg]) + + if used_streaming_overlap: + if pending_tool_results: + yield {"tools": {"messages": pending_tool_results}} + tool_results = streamed_tool_results + pending_tool_results + else: + # --- Execute tools through middleware chain --- + try: + tool_results = await self._execute_tools(tool_calls, response, tool_context) + except Exception as exc: + terminal = TerminalState( + reason=TerminalReason.aborted_tools, + turn_count=turn, + error=str(exc), + ) + break + + # Yield tools update + yield {"tools": {"messages": tool_results}} + + # Advance message history for next turn + messages.append(ai_msg) + messages.extend(tool_results) + await self._refresh_tools_between_turns(tool_context) + transition = ContinueState(reason=ContinueReason.next_turn) + max_output_tokens_recovery_count = 0 + has_attempted_reactive_compact = False + max_output_tokens_override = None + transient_api_retry_count = 0 + self._sync_app_state(messages=messages, turn_count=turn) + except asyncio.CancelledError: + # @@@cancel-persists-live-state - accepted user input from the + # current run must not evaporate just because the run is cancelled + # before the next terminal save. + messages = self._append_system_notices(messages, pending_system_notices) + await self._save_messages(thread_id, messages) + self._sync_app_state(messages=messages, turn_count=turn) + raise + + if terminal is None: + terminal = TerminalState( + reason=TerminalReason.max_turns, + turn_count=turn, + ) + + # Persist message history + self._collect_memory_system_notices(pending_system_notices) + visible_terminal_error = self._build_visible_terminal_error_message(terminal, messages) + if visible_terminal_error is not None: + messages.append(visible_terminal_error) + terminal_notice = self._build_terminal_notice(terminal) + if terminal_notice is not None: + pending_system_notices.append(terminal_notice) + messages = self._append_system_notices(messages, pending_system_notices) + await self._save_messages(thread_id, messages) + self._sync_app_state(messages=messages, turn_count=turn) + self.last_terminal = terminal + self.last_continue = transition + yield {"terminal": terminal, "transition": transition} + + async def astream( + self, + input: dict, + config: dict | None = None, + stream_mode: str | list[str] = "updates", + ) -> AsyncGenerator[Any, None]: + """Stream agent execution chunks compatible with LangGraph stream modes.""" + requested_modes = [stream_mode] if isinstance(stream_mode, str) else list(stream_mode) + emitted_live_agent_chunks = False + async for event in self.query(input, config=config): + if "terminal" in event: + terminal = event["terminal"] + if terminal is not None and terminal.reason is not TerminalReason.completed: + # @@@astream-terminal-loud-fail + # query() always emits a terminal event, but caller-facing + # astream() must not turn runtime failures into a silent empty + # iterator. Propagate non-completed terminals back to the caller. + raise RuntimeError(self._terminal_error_text(terminal)) + continue + if isinstance(stream_mode, str): + if "message_chunk" in event: + continue + yield event + continue + + if "message_chunk" in event: + if "messages" in requested_modes: + yield ( + "messages", + ( + event["message_chunk"], + {"langgraph_node": "agent"}, + ), + ) + emitted_live_agent_chunks = True + continue + + if "messages" in requested_modes and "agent" in event: + if not emitted_live_agent_chunks: + for msg in event["agent"].get("messages", []): + if not isinstance(msg, AIMessage): + continue + yield ( + "messages", + ( + AIMessageChunk(**msg.model_dump(exclude={"type"})), + {"langgraph_node": "agent"}, + ), + ) + emitted_live_agent_chunks = False + + if "updates" in requested_modes: + yield ("updates", event) + + async def ainvoke( + self, + input: dict, + config: dict | None = None, + stream_mode: str = "updates", + ) -> dict[str, Any]: + """Drain query and return messages plus explicit terminal state.""" + drained_messages: list[Any] = [] + terminal: TerminalState | None = None + transition: ContinueState | None = None + + # @@@ainvoke-drains-astream + # QueryLoop is generator-first. ainvoke exists only as a compatibility + # adapter for callers like LeonAgent.invoke/ainvoke and must not invent + # a separate execution path. + async for event in self.query(input, config=config): + if "terminal" in event: + terminal = event["terminal"] + transition = event.get("transition") + continue + for section in ("agent", "tools"): + drained_messages.extend(event.get(section, {}).get("messages", [])) + + return { + "messages": drained_messages, + "reason": terminal.reason.value if terminal else TerminalReason.completed.value, + "terminal": terminal, + "transition": transition, + } + + async def aget_state(self, config: dict | None = None) -> Any: + """Minimal graph-state bridge for backend/web callers.""" + config = config or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + if self._is_runtime_active(): + # @@@active-state-no-clobber - caller surfaces like /permissions and + # /history can poll during an active run. Rehydrating from stale + # checkpoint here would erase live thread-scoped permission state. + values = self._snapshot_live_thread_state(thread_id) + return SimpleNamespace(values=values) + values = await self._hydrate_thread_state_from_checkpoint(thread_id) + return SimpleNamespace(values=values) + + async def aupdate_state( + self, + config: dict | None, + input_data: dict[str, Any] | None, + as_node: str | None = None, + ) -> Any: + """Minimal graph-state update bridge for resumed-thread callers.""" + config = config or {} + input_data = input_data or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + messages = await self._load_messages(thread_id) + raw_updates = input_data.get("messages", []) + + # @@@ql-06-state-bridge - backend/web still speaks the old graph-state + # contract. Only the live caller shapes are supported here: append + # resumed start messages, or apply RemoveMessage-based repairs before + # appending replacement messages. + if as_node == "__start__": + messages.extend(self._parse_input({"messages": raw_updates})) + else: + updates = raw_updates if isinstance(raw_updates, list) else [raw_updates] + remove_ids = {update.id for update in updates if isinstance(update, RemoveMessage) and getattr(update, "id", None)} + if remove_ids: + messages = [message for message in messages if getattr(message, "id", None) not in remove_ids] + messages.extend(update for update in updates if not isinstance(update, RemoveMessage)) + + await self._save_messages(thread_id, messages) + current_turn_count = self._app_state.turn_count if self._app_state is not None else 0 + self._sync_app_state(messages=messages, turn_count=current_turn_count) + self._restore_discovered_tool_names_from_messages(thread_id, messages) + return await self.aget_state(config) + + async def apersist_state(self, thread_id: str) -> None: + """Persist the current thread-scoped loop/app state to the checkpointer.""" + messages = list(self._app_state.messages) if self._app_state is not None else await self._load_messages(thread_id) + await self._save_messages(thread_id, messages) + + # ------------------------------------------------------------------------- + # Model invocation through middleware chain + # ------------------------------------------------------------------------- + + async def _invoke_model( + self, + messages: list, + config: dict, + *, + thread_id: str = "default", + max_output_tokens_override: int | None = None, + ) -> ModelResponse: + """Call model through the full middleware chain (awrap_model_call).""" + + async def innermost_handler(request: ModelRequest) -> ModelResponse: + """Actual model call — innermost of the chain.""" + tools = request.tools or [] + model = request.model + + # Bind tools to model if any + if tools: + try: + bound = model.bind_tools(tools) + except Exception: + bound = model + else: + bound = model + + if max_output_tokens_override is not None and hasattr(bound, "bind"): + try: + bound = bound.bind(max_tokens=max_output_tokens_override) + except Exception: + pass + + # Build message list: system + conversation + call_messages = [] + if request.system_message: + call_messages.append(request.system_message) + call_messages.extend(request.messages) + + result = await bound.ainvoke(call_messages) + if not isinstance(result, list): + result = [result] + return ModelResponse(result=result, request_messages=list(request.messages)) + + # Build ModelRequest + inline_schemas = self._registry.get_inline_schemas(self._get_discovered_tool_names(thread_id)) + request = ModelRequest( + model=self.model, + messages=messages, + system_message=self.system_prompt, + tools=inline_schemas, + ) + + # Walk middleware chain outside-in: each wraps the next. + # Only include middleware that actually overrides awrap_model_call OR wrap_model_call + # (not just inherits the base-class NotImplementedError stub). + handler = innermost_handler + for mw in reversed(self.middleware): + if _mw_overrides_model_call(mw): + handler = _make_model_wrapper(mw, handler) + + return await handler(request) + + def _bind_model( + self, + model: Any, + tools: list | None, + *, + max_output_tokens_override: int | None = None, + ) -> Any: + if tools: + try: + bound = model.bind_tools(tools) + except Exception: + bound = model + else: + bound = model + + if max_output_tokens_override is not None and hasattr(bound, "bind"): + try: + bound = bound.bind(max_tokens=max_output_tokens_override) + except Exception: + pass + return bound + + def _can_stream_tools(self) -> bool: + stream_fn = getattr(self.model, "astream", None) + if not callable(stream_fn): + return False + return type(self.model).__module__ != "unittest.mock" + + async def _prepare_streaming_request( + self, + messages: list, + *, + thread_id: str, + ) -> ModelRequest: + inline_schemas = self._registry.get_inline_schemas(self._get_discovered_tool_names(thread_id)) + request = ModelRequest( + model=self.model, + messages=messages, + system_message=self.system_prompt, + tools=inline_schemas, + ) + + async def prepare_handler(request: ModelRequest) -> ModelResponse: + return ModelResponse( + result=[], + request_messages=list(request.messages), + prepared_request=request, + ) + + handler = prepare_handler + for mw in reversed(self.middleware): + if _mw_overrides_model_call(mw): + handler = _make_model_wrapper(mw, handler) + + response = await handler(request) + return response.prepared_request or request + + async def _stream_model_with_tool_overlap( + self, + messages: list, + config: dict, + *, + thread_id: str, + tool_context: ToolUseContext | None, + max_output_tokens_override: int | None, + ) -> AsyncGenerator[dict[str, Any], None]: + prepared_request = await self._prepare_streaming_request(messages, thread_id=thread_id) + bound = self._bind_model( + prepared_request.model, + prepared_request.tools, + max_output_tokens_override=max_output_tokens_override, + ) + + call_messages = [] + if prepared_request.system_message: + call_messages.append(prepared_request.system_message) + call_messages.extend(prepared_request.messages) + + executor = _StreamingToolExecutor(loop=self, tool_context=tool_context) + aggregate: AIMessageChunk | None = None + seen_tool_ids: set[str] = set() + streamed_tool_calls: list[dict[str, Any]] = [] + + try: + async for chunk in bound.astream(call_messages): + if isinstance(chunk, AIMessage): + chunk = AIMessageChunk(**chunk.model_dump(exclude={"type"})) + elif not isinstance(chunk, AIMessageChunk): + continue + + # @@@stream-chunk-snapshot + # Some providers reuse and mutate the same chunk object across + # yields. Snapshot before yielding/aggregating so the final + # AIMessage cannot collapse to the last empty chunk. + chunk = AIMessageChunk(**chunk.model_dump(exclude={"type"})) + if ( + aggregate is not None + and getattr(chunk, "chunk_position", None) == "last" + and not chunk.content + and not getattr(chunk, "tool_calls", None) + and not getattr(chunk, "invalid_tool_calls", None) + and not getattr(chunk, "tool_call_chunks", None) + and getattr(chunk, "usage_metadata", None) == getattr(aggregate, "usage_metadata", None) + ): + chunk = chunk.model_copy(update={"usage_metadata": None}) + aggregate = chunk if aggregate is None else aggregate + chunk + + yield {"type": "message_chunk", "chunk": chunk} + + tool_call_chunks = getattr(aggregate, "tool_call_chunks", None) or [] + for tool_call in getattr(aggregate, "tool_calls", None) or []: + ready_tool_call = self._normalize_stream_tool_call(tool_call, tool_call_chunks) + if ready_tool_call is None: + continue + call_id = ready_tool_call.get("id") + if not call_id or call_id in seen_tool_ids: + continue + seen_tool_ids.add(call_id) + streamed_tool_calls.append(ready_tool_call) + await executor.add_tool(ready_tool_call) + + completed = await executor.get_completed_results() + if completed: + yield {"type": "tools", "messages": completed} + except Exception: + discarded = await executor.discard(reason="streaming_error") + if discarded: + yield {"type": "tools", "messages": discarded} + raise + + if aggregate is None: + raise RuntimeError("streaming model returned no AIMessageChunk") + + ai_message = AIMessage(**aggregate.model_dump(exclude={"type"})) + self._notify_stream_response(prepared_request, ai_message) + remaining = await executor.drain_remaining() + yield { + "type": "done", + "response": ModelResponse(result=[ai_message], request_messages=list(prepared_request.messages)), + "ai_message": ai_message, + "tool_calls": list(streamed_tool_calls), + "remaining_tool_results": remaining, + } + + def _notify_stream_response(self, request: ModelRequest, ai_message: AIMessage) -> None: + req_dict = {"messages": request.messages} + resp_dict = {"messages": [ai_message]} + for mw in self.middleware: + dispatch = getattr(mw, "_dispatch_monitors", None) + if callable(dispatch): + dispatch("on_response", req_dict, resp_dict) + + async def _build_query_messages(self, messages: list, config: dict) -> tuple[list, list]: + return await self._apply_before_model(list(messages), config) + + async def _apply_before_model(self, messages: list, config: dict) -> tuple[list, list]: + """Run middleware before_model/abefore_model hooks on the live path.""" + current_messages = list(messages) + injected_messages: list[Any] = [] + state = {"messages": current_messages} + + for mw in self.middleware: + update: dict[str, Any] | None = None + abefore = getattr(mw, "abefore_model", None) + before = getattr(mw, "before_model", None) + + if callable(abefore): + update = await abefore(state=state, runtime=None, config=config) + elif callable(before): + update = before(state=state, runtime=None, config=config) + + if not update: + continue + + new_messages = update.get("messages") + if new_messages: + if not isinstance(new_messages, list): + new_messages = [new_messages] + current_messages.extend(new_messages) + injected_messages.extend(new_messages) + state["messages"] = current_messages + + return current_messages, injected_messages + + def _sync_app_state(self, messages: list, turn_count: int) -> None: + """Keep runtime AppState aligned with the loop's live state.""" + if self._app_state is None: + return + + snapshot = list(messages) + current_cost = self._read_runtime_cost() + bootstrap_cost = self._bootstrap.total_cost_usd if self._bootstrap is not None else 0.0 + cumulative_cost = max(current_cost, self._app_state.total_cost, bootstrap_cost) + compact_boundary_index = self._read_compact_boundary_index() + + # @@@sa-03-cost-accumulator-monotonic + # /clear must preserve session accumulators, so loop sync cannot let a + # lower per-run observation overwrite the accumulated session total. + if self._bootstrap is not None: + self._bootstrap.total_cost_usd = cumulative_cost + + # @@@app-state-sync + # ql-02 needs the loop's local lifecycle to write back into AppState, + # but we still do not have compaction yet. Clamp the boundary so the + # store stays coherent without pretending compaction exists. + def _update(state: AppState) -> AppState: + return state.model_copy( + update={ + "messages": snapshot, + "turn_count": turn_count, + "total_cost": cumulative_cost, + "compact_boundary_index": compact_boundary_index, + } + ) + + self._app_state.set_state(_update) + + def _read_runtime_cost(self) -> float: + if self._runtime is None: + return self._app_state.total_cost if self._app_state is not None else 0.0 + try: + return float(self._runtime.cost) + except Exception: + return self._app_state.total_cost if self._app_state is not None else 0.0 + + def _read_compact_boundary_index(self) -> int: + if self._memory_middleware is None: + return 0 + try: + boundary = int(self._memory_middleware.compact_boundary_index) + except Exception: + return 0 + return max(boundary, 0) + + def _get_discovered_tool_names(self, thread_id: str) -> set[str]: + # @@@dt-03-thread-scoped-deferred-tools - deferred discovery must stay + # isolated per thread_id, or one thread's tool_search silently changes + # another thread's inline schema surface on the next turn. + return self._tool_discovered_tool_names_by_thread.setdefault(thread_id, set()) + + def _restore_discovered_tool_names_from_messages( + self, + thread_id: str, + messages: list, + ) -> None: + discovered: set[str] = set() + for message in messages: + if not isinstance(message, ToolMessage) or getattr(message, "name", None) != "tool_search": + continue + content = getattr(message, "content", None) + if not isinstance(content, str): + continue + try: + payload = json.loads(content) + except Exception: + continue + if not isinstance(payload, list): + continue + for item in payload: + if not isinstance(item, dict): + continue + name = item.get("name") + if not isinstance(name, str): + continue + entry = self._registry.get(name) + if entry is not None and entry.mode == ToolMode.DEFERRED: + discovered.add(name) + self._tool_discovered_tool_names_by_thread[thread_id] = discovered + + def _build_tool_use_context(self, messages: list, *, thread_id: str = "default") -> ToolUseContext | None: + if self._bootstrap is None or self._app_state is None: + return None + has_permission_resolver = self._bootstrap.permission_resolver_scope != "none" + return ToolUseContext( + bootstrap=self._bootstrap, + get_app_state=self._app_state.get_state, + set_app_state=self._app_state.set_state, + refresh_tools=self._refresh_tools, + can_use_tool=lambda name, args, permission_context, request: self._default_can_use_tool( + name=name, + permission_context=permission_context, + ), + request_permission=( + lambda name, args, context, request, message: self._request_permission( + thread_id=thread_id, + name=name, + args=args, + message=message, + ) + ) + if has_permission_resolver + else None, + consume_permission_resolution=lambda name, args, context, request: self._consume_permission_resolution( + thread_id=thread_id, + name=name, + args=args, + ), + read_file_state=self._tool_read_file_state, + loaded_nested_memory_paths=self._tool_loaded_nested_memory_paths, + discovered_skill_names=self._tool_discovered_skill_names, + discovered_tool_names=self._get_discovered_tool_names(thread_id), + nested_memory_attachment_triggers=set(), + abort_controller=self._tool_abort_controller, + messages=list(messages), + thread_id=thread_id, + ) + + def _default_can_use_tool( + self, + *, + name: str, + permission_context: ToolPermissionContext, + ) -> dict[str, Any] | None: + if self._app_state is None: + return None + permission_state = self._app_state.tool_permission_context + merged_context = ToolPermissionContext( + is_read_only=permission_context.is_read_only, + is_destructive=permission_context.is_destructive, + alwaysAllowRules=permission_state.alwaysAllowRules, + alwaysDenyRules=permission_state.alwaysDenyRules, + alwaysAskRules=permission_state.alwaysAskRules, + allowManagedPermissionRulesOnly=permission_state.allowManagedPermissionRulesOnly, + ) + decision = evaluate_permission_rules(name, merged_context) + if ( + decision is not None + and decision.get("decision") == "ask" + and self._bootstrap is not None + and self._bootstrap.permission_resolver_scope == "none" + ): + # @@@permission-headless-fail-loud - ask is only a real product mode + # when this run has an owner-facing resolver. Otherwise fail loudly + # instead of creating a dead-end pending request in hidden state. + return { + "decision": "deny", + "message": f"{decision.get('message')}. No interactive permission resolver is available for this run.", + } + return decision + + def _request_permission( + self, + *, + thread_id: str, + name: str, + args: dict[str, Any], + message: str | None, + ) -> str | None: + if self._app_state is None: + return None + + request_id = uuid.uuid4().hex[:8] + payload = { + "request_id": request_id, + "thread_id": thread_id, + "tool_name": name, + "args": copy.deepcopy(args), + "message": message, + } + + def _store(state: AppState) -> AppState: + pending = dict(state.pending_permission_requests) + pending[request_id] = payload + return state.model_copy(update={"pending_permission_requests": pending}) + + self._app_state.set_state(_store) + return request_id + + def _consume_permission_resolution( + self, + *, + thread_id: str, + name: str, + args: dict[str, Any], + ) -> dict[str, Any] | None: + if self._app_state is None: + return None + + resolved_items = list(self._app_state.resolved_permission_requests.items()) + matched_id: str | None = None + matched_payload: dict[str, Any] | None = None + for request_id, payload in resolved_items: + if payload.get("thread_id") != thread_id: + continue + if payload.get("tool_name") != name: + continue + if payload.get("args") != args: + continue + matched_id = request_id + matched_payload = payload + break + + if matched_id is None or matched_payload is None: + return None + + def _consume(state: AppState) -> AppState: + resolved = dict(state.resolved_permission_requests) + resolved.pop(matched_id, None) + return state.model_copy(update={"resolved_permission_requests": resolved}) + + self._app_state.set_state(_consume) + return { + "decision": matched_payload.get("decision"), + "message": matched_payload.get("message"), + } + + def _sync_tool_context_messages( + self, + tool_context: ToolUseContext | None, + messages: list, + ) -> None: + if tool_context is None: + return + tool_context.messages = list(messages) + + async def _refresh_tools_between_turns(self, tool_context: ToolUseContext | None) -> None: + refresh = self._refresh_tools + if refresh is None and tool_context is not None: + refresh = tool_context.refresh_tools + if refresh is None: + return + result = refresh() + if inspect.isawaitable(result): + await result + + async def _handle_model_error_recovery( + self, + *, + exc: Exception, + thread_id: str, + messages: list, + turn: int, + transition: ContinueState | None, + max_output_tokens_recovery_count: int, + has_attempted_reactive_compact: bool, + max_output_tokens_override: int | None, + transient_api_retry_count: int, + ) -> _ModelErrorRecoveryResult | None: + error_message = str(exc) + error_text = error_message.lower() + + parsed_overflow = self._parse_context_overflow_override(error_message) + if parsed_overflow is not None: + return _ModelErrorRecoveryResult( + messages=messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_escalate), + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=parsed_overflow, + transient_api_retry_count=transient_api_retry_count, + terminal=None, + ) + + if self._is_transient_api_error(exc, error_text): + if transient_api_retry_count >= _TRANSIENT_API_MAX_RETRIES: + return None + delay_seconds = self._retry_delay_seconds(exc, transient_api_retry_count) + if delay_seconds > 0: + await asyncio.sleep(delay_seconds) + return _ModelErrorRecoveryResult( + messages=messages, + transition=ContinueState(reason=ContinueReason.api_retry), + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count + 1, + terminal=None, + ) + + if "max_output_tokens" in error_text: + if max_output_tokens_override is None: + return _ModelErrorRecoveryResult( + messages=messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_escalate), + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=_ESCALATED_MAX_OUTPUT_TOKENS, + transient_api_retry_count=transient_api_retry_count, + terminal=None, + ) + if max_output_tokens_recovery_count < 3: + recovered_messages = list(messages) + recovered_messages.append( + HumanMessage( + content="Output token limit hit. Resume directly with no apology or recap.", + ) + ) + return _ModelErrorRecoveryResult( + messages=recovered_messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_recovery), + max_output_tokens_recovery_count=max_output_tokens_recovery_count + 1, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + terminal=None, + ) + return _ModelErrorRecoveryResult( + messages=messages, + transition=ContinueState(reason=ContinueReason.max_output_tokens_recovery), + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + terminal=TerminalState( + reason=TerminalReason.model_error, + turn_count=turn, + error=str(exc), + ), + ) + + if self._is_prompt_too_long_error(error_text): + if transition is None or transition.reason is not ContinueReason.collapse_drain_retry: + drained = await self._recover_from_overflow(messages) + if drained is not None and drained["committed"] > 0: + return _ModelErrorRecoveryResult( + messages=drained["messages"], + transition=ContinueState(reason=ContinueReason.collapse_drain_retry), + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + terminal=None, + ) + if not has_attempted_reactive_compact: + compacted = await self._force_reactive_compact(messages, thread_id=thread_id) + if compacted is not None: + return _ModelErrorRecoveryResult( + messages=compacted, + transition=ContinueState(reason=ContinueReason.reactive_compact_retry), + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=True, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + terminal=None, + ) + return _ModelErrorRecoveryResult( + messages=messages, + transition=transition, + max_output_tokens_recovery_count=max_output_tokens_recovery_count, + has_attempted_reactive_compact=has_attempted_reactive_compact, + max_output_tokens_override=max_output_tokens_override, + transient_api_retry_count=transient_api_retry_count, + terminal=TerminalState( + reason=TerminalReason.prompt_too_long, + turn_count=turn, + error=str(exc), + ), + ) + + return None + + @staticmethod + def _parse_context_overflow_override(error_message: str) -> int | None: + match = re.search( + r"input length and `max_tokens` exceed context limit: (\d+) \+ (\d+) > (\d+)", + error_message, + ) + if match is None: + return None + input_tokens = int(match.group(1)) + context_limit = int(match.group(3)) + available_context = max(0, context_limit - input_tokens - _CONTEXT_OVERFLOW_SAFETY_BUFFER) + if available_context < _FLOOR_OUTPUT_TOKENS: + return None + return max(_FLOOR_OUTPUT_TOKENS, available_context) + + @staticmethod + def _is_transient_api_error(exc: Exception, error_text: str) -> bool: + status = getattr(exc, "status", None) + return status in {429, 529} or '"type":"overloaded_error"' in error_text + + @staticmethod + def _retry_delay_seconds(exc: Exception, transient_api_retry_count: int) -> float: + headers = getattr(exc, "headers", None) or {} + # @@@retry-after-shape + # Test doubles use plain dict headers while SDK errors expose a Headers-like + # object. Keep this probe shape-tolerant so the loop can honor retry-after + # without forcing a specific exception class. + if hasattr(headers, "get"): + retry_after = headers.get("retry-after") + else: + retry_after = None + try: + if retry_after is not None: + return max(0.0, float(retry_after)) + except (TypeError, ValueError): + pass + return _TRANSIENT_API_BASE_DELAY_SECONDS * (2**transient_api_retry_count) + + def _handle_truncated_response_recovery( + self, + *, + ai_msg: AIMessage, + messages: list, + turn: int, + max_output_tokens_recovery_count: int, + max_output_tokens_override: int | None, + ) -> dict[str, Any] | None: + if not self._is_max_output_truncated(ai_msg): + return None + + if max_output_tokens_override is None: + return { + "messages": messages, + "transition": ContinueState(reason=ContinueReason.max_output_tokens_escalate), + "max_output_tokens_recovery_count": max_output_tokens_recovery_count, + "max_output_tokens_override": _ESCALATED_MAX_OUTPUT_TOKENS, + "yield_ai": False, + "terminal": None, + } + + if max_output_tokens_recovery_count < 3: + recovered_messages = list(messages) + recovered_messages.append(ai_msg) + recovered_messages.append( + HumanMessage( + content="Output token limit hit. Resume directly with no apology or recap.", + ) + ) + return { + "messages": recovered_messages, + "transition": ContinueState(reason=ContinueReason.max_output_tokens_recovery), + "max_output_tokens_recovery_count": max_output_tokens_recovery_count + 1, + "max_output_tokens_override": max_output_tokens_override, + "yield_ai": False, + "terminal": None, + } + + surfaced_messages = list(messages) + surfaced_messages.append(ai_msg) + return { + "messages": surfaced_messages, + "transition": ContinueState(reason=ContinueReason.max_output_tokens_recovery), + "max_output_tokens_recovery_count": max_output_tokens_recovery_count, + "max_output_tokens_override": max_output_tokens_override, + "yield_ai": True, + "terminal": TerminalState( + reason=TerminalReason.model_error, + turn_count=turn, + error="max_output_tokens", + ), + } + + async def _force_reactive_compact(self, messages: list, *, thread_id: str) -> list | None: + if self._memory_middleware is None: + return None + compact = getattr(self._memory_middleware, "compact_messages_for_recovery", None) + if not callable(compact): + return None + signature = inspect.signature(compact) + if "thread_id" in signature.parameters: + return await compact(messages, thread_id=thread_id) + return await compact(messages) + + async def _recover_from_overflow(self, messages: list) -> dict[str, Any] | None: + # @@@collapse-drain-single-shot + # ql-04 needs collapse-drain and reactive-compact to stay as separate + # phases. The drain hook is optional, but if present it only gets one + # chance before prompt-too-long falls through to reactive compaction. + for middleware in self.middleware: + recover = getattr(middleware, "recover_from_overflow", None) + if not callable(recover): + continue + drained = recover(messages) + if inspect.isawaitable(drained): + drained = await drained + if drained is None: + return None + committed = int(getattr(drained, "get", lambda *_: 0)("committed", 0)) + updated_messages = getattr(drained, "get", lambda *_: None)("messages") + if committed <= 0 or not isinstance(updated_messages, list): + return None + return {"committed": committed, "messages": list(updated_messages)} + return None + + @staticmethod + def _is_prompt_too_long_error(error_text: str) -> bool: + return ( + "prompt is too long" in error_text + or "prompt too long" in error_text + or "context length" in error_text + or "maximum context length" in error_text + ) + + @staticmethod + def _is_max_output_truncated(message: AIMessage) -> bool: + response_metadata = getattr(message, "response_metadata", None) or {} + additional_kwargs = getattr(message, "additional_kwargs", None) or {} + finish_reason = ( + response_metadata.get("finish_reason") + or response_metadata.get("stop_reason") + or additional_kwargs.get("finish_reason") + or additional_kwargs.get("stop_reason") + ) + return finish_reason in {"length", "max_tokens", "max_output_tokens"} + + # ------------------------------------------------------------------------- + # Tool execution through middleware chain + # ------------------------------------------------------------------------- + + async def _execute_tools( + self, + tool_calls: list, + model_response: ModelResponse, + tool_context: ToolUseContext | None, + ) -> list[ToolMessage]: + """Execute tool calls respecting concurrency safety, via middleware chain.""" + results: dict[int, ToolMessage] = {} + + async def execute_batch(batch: list[tuple[int, dict]]) -> None: + if not batch: + return + batch_results = await asyncio.gather( + *[self._execute_single_tool(tool_call, tool_context) for _, tool_call in batch], + return_exceptions=True, + ) + for (idx, tool_call), result in zip(batch, batch_results): + if isinstance(result, Exception): + results[idx] = ToolMessage( + content=f"{result}", + tool_call_id=tool_call.get("id", ""), + name=tool_call.get("name", ""), + ) + continue + results[idx] = result + + safe_batch: list[tuple[int, dict]] = [] + for idx, tool_call in enumerate(tool_calls): + # @@@tool-order-boundary + # te-01 needs the non-streaming path to keep the same queue barrier + # semantics as the streaming executor: contiguous safe tools may fan + # out together, but any unsafe tool flushes the batch and blocks the + # next safe tool until it finishes. + if self._tool_is_concurrency_safe(tool_call): + safe_batch.append((idx, tool_call)) + continue + + await execute_batch(safe_batch) + safe_batch = [] + try: + results[idx] = await self._execute_single_tool(tool_call, tool_context) + except Exception as exc: + results[idx] = ToolMessage( + content=f"{exc}", + tool_call_id=tool_call.get("id", ""), + name=tool_call.get("name", ""), + ) + + await execute_batch(safe_batch) + return [results[i] for i in range(len(tool_calls))] + + async def _execute_single_tool( + self, + tool_call: dict, + tool_context: ToolUseContext | None, + ) -> ToolMessage: + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + call_id = tool_call.get("id", "") + args = tool_call.get("args", {}) or tool_call.get("function", {}).get("arguments", {}) + + if isinstance(args, str): + import json + + try: + args = json.loads(args) + except Exception: + args = {} + + normalized_call = {"name": name, "args": args, "id": call_id} + tc_request = ToolCallRequest( + tool_call=normalized_call, + tool=None, + state=tool_context, + runtime=self._runtime, # type: ignore[arg-type] + ) + + async def innermost_tool_handler(req: ToolCallRequest) -> ToolMessage: + tc = req.tool_call + t_name = tc.get("name", "") + t_id = tc.get("id", "") + t_args = tc.get("args", {}) + entry = self._registry.get(t_name) + if entry is None: + return ToolMessage( + content=f"Tool '{t_name}' not found", + tool_call_id=t_id, + name=t_name, + ) + try: + import asyncio as _asyncio + + if _asyncio.iscoroutinefunction(entry.handler): + result = await entry.handler(**t_args) + else: + result = await _asyncio.to_thread(entry.handler, **t_args) + return ToolMessage(content=str(result), tool_call_id=t_id, name=t_name) + except Exception as e: + return ToolMessage( + content=f"{e}", + tool_call_id=t_id, + name=t_name, + ) + + tool_handler = innermost_tool_handler + for mw in reversed(self.middleware): + if _mw_overrides_tool_call(mw): + tool_handler = _make_tool_wrapper(mw, tool_handler) + + return await tool_handler(tc_request) + + def _tool_is_concurrency_safe(self, tool_call: dict) -> bool: + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + entry = self._registry.get(name) + if entry is None: + return False + safety = entry.is_concurrency_safe + if callable(safety): + args = tool_call.get("args", {}) + if isinstance(args, str): + try: + import json as _json + + args = _json.loads(args) + except Exception: + args = {} + try: + return bool(safety(args if isinstance(args, dict) else {})) + except Exception: + return False + return bool(safety) + + def _tool_call_is_ready(self, tool_call: dict) -> bool: + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + entry = self._registry.get(name) + if entry is None: + return True + + args = tool_call.get("args", {}) + if isinstance(args, str): + try: + import json as _json + + args = _json.loads(args) + except Exception: + return False + if not isinstance(args, dict): + return False + + schema = entry.get_schema() or {} + parameters = schema.get("parameters", {}) if isinstance(schema, dict) else {} + return _required_sets_match(parameters, args) if isinstance(parameters, dict) else True + + def _normalize_stream_tool_call( + self, + tool_call: dict, + tool_call_chunks: list[dict[str, Any]], + ) -> dict[str, Any] | None: + call_id = tool_call.get("id") + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + args: Any = tool_call.get("args", {}) + if isinstance(args, str): + try: + import json as _json + + args = _json.loads(args) + except Exception: + args = {} + + raw_arg_chunks: list[str] = [] + for chunk in tool_call_chunks: + if chunk.get("id") != call_id: + continue + if chunk.get("name"): + name = chunk["name"] + raw_args = chunk.get("args") + if raw_args in (None, ""): + continue + if isinstance(raw_args, str): + raw_arg_chunks.append(raw_args) + else: + args = raw_args + + if raw_arg_chunks: + try: + import json as _json + + args = _json.loads("".join(raw_arg_chunks)) + except Exception: + return None + + normalized = {"name": name, "args": args, "id": call_id} + if not self._tool_call_is_ready(normalized): + return None + return normalized + + # ------------------------------------------------------------------------- + # Checkpointer persistence + # ------------------------------------------------------------------------- + + async def _load_messages(self, thread_id: str) -> list: + """Load message history from checkpointer (if available).""" + channel_values = await self._load_checkpoint_channel_values(thread_id) + return list(channel_values.get("messages", [])) + + async def _load_checkpoint_channel_values(self, thread_id: str) -> dict[str, Any]: + """Load raw channel values for one thread checkpoint.""" + if self.checkpointer is None: + return {} + try: + cfg = self._checkpoint_config(thread_id) + checkpoint = await self.checkpointer.aget(cfg) + if checkpoint is None: + return {} + return dict(checkpoint.get("channel_values", {}) or {}) + except Exception: + logger.debug("QueryLoop: could not load checkpoint for thread %s", thread_id) + return {} + + def _thread_permission_state_snapshot( + self, + thread_id: str, + ) -> tuple[dict[str, Any], dict[str, dict[str, Any]], dict[str, dict[str, Any]]]: + if self._app_state is None: + return {}, {}, {} + + permission_context = copy.deepcopy(self._app_state.tool_permission_context.model_dump()) + pending = { + key: copy.deepcopy(value) + for key, value in self._app_state.pending_permission_requests.items() + if value.get("thread_id") == thread_id + } + resolved = { + key: copy.deepcopy(value) + for key, value in self._app_state.resolved_permission_requests.items() + if value.get("thread_id") == thread_id + } + return permission_context, pending, resolved + + def _thread_memory_state_snapshot(self, thread_id: str) -> dict[str, Any]: + if self._memory_middleware is None: + return {} + snapshot = getattr(self._memory_middleware, "snapshot_thread_state", None) + if not callable(snapshot): + return {} + raw_snapshot = snapshot(thread_id) or {} + if not isinstance(raw_snapshot, dict): + return {} + return {str(key): value for key, value in raw_snapshot.items()} + + def _thread_mcp_instruction_state_snapshot(self, thread_id: str) -> dict[str, Any]: + if self._app_state is None: + return {} + announced_blocks = dict(self._app_state.announced_mcp_instruction_blocks.get(thread_id, {})) + return {"announced_blocks": announced_blocks} + + def _is_runtime_active(self) -> bool: + current_state = getattr(self._runtime, "current_state", None) + return getattr(current_state, "value", current_state) == "active" + + def _snapshot_live_thread_state(self, thread_id: str) -> dict[str, Any]: + messages = list(self._app_state.messages) if self._app_state is not None else [] + permission_context, pending, resolved = self._thread_permission_state_snapshot(thread_id) + memory_state = self._thread_memory_state_snapshot(thread_id) + return { + "messages": messages, + "tool_permission_context": permission_context, + "pending_permission_requests": pending, + "resolved_permission_requests": resolved, + "memory_compaction_state": memory_state, + "mcp_instruction_state": self._thread_mcp_instruction_state_snapshot(thread_id), + } + + def _restore_thread_permission_state( + self, + thread_id: str, + *, + permission_context: dict[str, Any], + pending: dict[str, dict[str, Any]], + resolved: dict[str, dict[str, Any]], + ) -> None: + if self._app_state is None: + return + + # @@@permission-checkpoint-bridge - pending/resolved permission requests + # are thread-scoped runtime state, not display-only metadata. They must + # survive checkpoint replay so backend/UI surfaces stay honest after an + # idle reload or agent recreation. + def _update(state: AppState) -> AppState: + kept_pending = {key: value for key, value in state.pending_permission_requests.items() if value.get("thread_id") != thread_id} + kept_pending.update(copy.deepcopy(pending)) + kept_resolved = {key: value for key, value in state.resolved_permission_requests.items() if value.get("thread_id") != thread_id} + kept_resolved.update(copy.deepcopy(resolved)) + return state.model_copy( + update={ + "tool_permission_context": ToolPermissionState.model_validate(copy.deepcopy(permission_context)), + "pending_permission_requests": kept_pending, + "resolved_permission_requests": kept_resolved, + } + ) + + self._app_state.set_state(_update) + + def _restore_thread_memory_state( + self, + thread_id: str, + *, + memory_state: dict[str, Any], + ) -> None: + if self._memory_middleware is None: + return + restore = getattr(self._memory_middleware, "restore_thread_state", None) + if callable(restore): + restore(thread_id, memory_state) + + def _restore_thread_mcp_instruction_state( + self, + thread_id: str, + *, + mcp_instruction_state: dict[str, Any], + ) -> None: + if self._app_state is None: + return + announced_blocks = mcp_instruction_state.get("announced_blocks", {}) + if not isinstance(announced_blocks, dict): + announced_blocks = {} + kept = {key: value for key, value in self._app_state.announced_mcp_instruction_blocks.items() if key != thread_id} + kept[thread_id] = {name: block for name, block in announced_blocks.items() if isinstance(name, str) and isinstance(block, str)} + self._app_state.announced_mcp_instruction_blocks = kept + + async def _hydrate_thread_state_from_checkpoint(self, thread_id: str) -> dict[str, Any]: + channel_values = await self._load_checkpoint_channel_values(thread_id) + messages = list(channel_values.get("messages", [])) + permission_context = dict(channel_values.get("tool_permission_context", {}) or {}) + pending = dict(channel_values.get("pending_permission_requests", {}) or {}) + resolved = dict(channel_values.get("resolved_permission_requests", {}) or {}) + memory_state = dict(channel_values.get("memory_compaction_state", {}) or {}) + mcp_instruction_state = dict(channel_values.get("mcp_instruction_state", {}) or {}) + turn_count = self._app_state.turn_count if self._app_state is not None else 0 + self._sync_app_state(messages=messages, turn_count=turn_count) + self._restore_thread_permission_state( + thread_id, + permission_context=permission_context, + pending=pending, + resolved=resolved, + ) + self._restore_thread_memory_state( + thread_id, + memory_state=memory_state, + ) + self._restore_thread_mcp_instruction_state( + thread_id, + mcp_instruction_state=mcp_instruction_state, + ) + return { + "messages": messages, + "tool_permission_context": permission_context, + "pending_permission_requests": pending, + "resolved_permission_requests": resolved, + "memory_compaction_state": memory_state, + "mcp_instruction_state": mcp_instruction_state, + } + + async def _save_messages(self, thread_id: str, messages: list) -> None: + """Persist message history to checkpointer.""" + if self.checkpointer is None: + return + try: + from langgraph.checkpoint.base import CheckpointMetadata, empty_checkpoint + + cfg = self._checkpoint_config(thread_id) + checkpoint = empty_checkpoint() + permission_context, pending_requests, resolved_requests = self._thread_permission_state_snapshot(thread_id) + memory_state = self._thread_memory_state_snapshot(thread_id) + mcp_instruction_state = self._thread_mcp_instruction_state_snapshot(thread_id) + checkpoint["channel_values"] = { + "messages": messages, + "tool_permission_context": permission_context, + "pending_permission_requests": pending_requests, + "resolved_permission_requests": resolved_requests, + "memory_compaction_state": memory_state, + "mcp_instruction_state": mcp_instruction_state, + } + metadata: CheckpointMetadata = { + "source": "loop", + "step": len(messages), + } + await self.checkpointer.aput(cfg, checkpoint, metadata, {}) + except Exception: + logger.debug("QueryLoop: could not save checkpoint for thread %s", thread_id, exc_info=True) + + def _collect_memory_system_notices(self, pending_notices: list[HumanMessage]) -> None: + if self._memory_middleware is None: + return + consume_many = getattr(self._memory_middleware, "consume_pending_notices", None) + notices: list[dict[str, Any]] = [] + if callable(consume_many): + notices = list(consume_many() or []) + else: + consume_one = getattr(self._memory_middleware, "consume_latest_compaction_notice", None) + if callable(consume_one): + notice = consume_one() + if notice: + notices = [notice] + for notice in notices: + pending_notices.append( + HumanMessage( + content=str(notice.get("content") or ""), + metadata={ + "source": "system", + "notification_type": str(notice.get("notification_type") or "compact"), + "compact_boundary_index": int(notice.get("compact_boundary_index") or 0), + }, + ) + ) + + def _append_system_notices(self, messages: list, notices: list[HumanMessage]) -> list: + if not notices: + return messages + # @@@compact-notice-persist - compaction changes the model-visible + # boundary, but the notice is for the owner surface only. Persist it + # after the run settles so replay stays honest without perturbing the + # same run's next model call. + return list(messages) + list(notices) + + def _build_terminal_notice(self, terminal: TerminalState | None) -> HumanMessage | None: + # @@@terminal-recovery-notice - recovery exhaustion must survive cold + # rebuilds. Persist one owner-visible system notice instead of leaving + # prompt-too-long as a hot-stream-only error. + if terminal is None or terminal.reason is not TerminalReason.prompt_too_long: + return None + return HumanMessage( + content=_PROMPT_TOO_LONG_NOTICE_TEXT, + metadata={"source": "system"}, + ) + + def _terminal_error_text(self, terminal: TerminalState) -> str: + if terminal.reason is TerminalReason.prompt_too_long: + return _PROMPT_TOO_LONG_NOTICE_TEXT + return terminal.error or terminal.reason.value + + def _build_visible_terminal_error_message( + self, + terminal: TerminalState, + messages: list[Any], + ) -> AIMessage | None: + if terminal.reason is TerminalReason.completed: + return None + error_text = self._terminal_error_text(terminal).strip() + if not error_text: + return None + last_message = messages[-1] if messages else None + if isinstance(last_message, AIMessage) and self._ai_message_has_visible_content(last_message): + return None + return AIMessage(content=f"Error: {error_text}") + + @staticmethod + def _checkpoint_config(thread_id: str) -> dict[str, Any]: + # @@@sa-03-real-checkpointer-config + # AsyncSqliteSaver requires checkpoint_ns even when we only use a + # single logical namespace; without it, aput() raises and replay dies. + return {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + + async def aclear(self, thread_id: str) -> None: + """Clear turn-scoped state for a thread while preserving session accumulators.""" + await self._save_messages(thread_id, []) + + self._tool_read_file_state.clear() + self._tool_loaded_nested_memory_paths.clear() + self._tool_discovered_skill_names.clear() + self._tool_discovered_tool_names_by_thread.pop(thread_id, None) + + if self._memory_middleware is not None: + summary_store = getattr(self._memory_middleware, "summary_store", None) + if summary_store is not None: + # @@@clear-thread-clears-summary-store - api-05 requires /clear + # to wipe replayable compaction state, not just in-memory cache. + summary_store.delete_thread_summaries(thread_id) + if hasattr(self._memory_middleware, "_cached_summary"): + self._memory_middleware._cached_summary = None + if hasattr(self._memory_middleware, "_summary_restored"): + self._memory_middleware._summary_restored = False + if hasattr(self._memory_middleware, "_summary_thread_id"): + self._memory_middleware._summary_thread_id = None + if hasattr(self._memory_middleware, "_compact_up_to_index"): + self._memory_middleware._compact_up_to_index = 0 + clear_thread_state = getattr(self._memory_middleware, "clear_thread_state", None) + if callable(clear_thread_state): + clear_thread_state(thread_id) + + if self._app_state is not None: + preserved_total_cost = self._app_state.total_cost + preserved_tool_overrides = dict(self._app_state.tool_overrides) + pending_requests = { + key: value for key, value in self._app_state.pending_permission_requests.items() if value.get("thread_id") != thread_id + } + resolved_requests = { + key: value for key, value in self._app_state.resolved_permission_requests.items() if value.get("thread_id") != thread_id + } + + def _reset(state: AppState) -> AppState: + return state.model_copy( + update={ + "messages": [], + "turn_count": 0, + "total_cost": preserved_total_cost, + "compact_boundary_index": 0, + "tool_overrides": preserved_tool_overrides, + "pending_permission_requests": pending_requests, + "resolved_permission_requests": resolved_requests, + } + ) + + self._app_state.set_state(_reset) + + await self._save_messages(thread_id, []) + + if self._bootstrap is not None: + old_session_id = self._bootstrap.session_id + self._bootstrap.parent_session_id = old_session_id + self._bootstrap.session_id = uuid.uuid4().hex + + # ------------------------------------------------------------------------- + # Input parsing + # ------------------------------------------------------------------------- + + @staticmethod + def _parse_input(input: dict | None) -> list: + """Convert input dict to list of LangChain message objects.""" + if input is None: + return [] + raw_messages = input.get("messages", []) + result = [] + for msg in raw_messages: + if hasattr(msg, "content"): + result.append(msg) + elif isinstance(msg, dict): + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "user": + result.append(HumanMessage(content=content)) + elif role == "assistant": + result.append(AIMessage(content=content)) + else: + result.append(HumanMessage(content=content)) + return result + + @staticmethod + def _ai_message_has_visible_content(message: AIMessage) -> bool: + content = getattr(message, "content", None) + if isinstance(content, str): + return content.strip() != "" + if isinstance(content, list): + for item in content: + if isinstance(item, str) and item.strip(): + return True + if isinstance(item, dict) and str(item.get("text", "")).strip(): + return True + return False + return bool(content) + + @staticmethod + def _get_terminal_followthrough_notice(messages: list[Any]) -> HumanMessage | None: + if not messages: + return None + last_message = messages[-1] + if last_message.__class__.__name__ != "HumanMessage": + return None + metadata = getattr(last_message, "metadata", None) or {} + if metadata.get("source") != "system": + return None + if metadata.get("notification_type") not in {"agent", "command"}: + return None + content = getattr(last_message, "content", "") + text = content if isinstance(content, str) else str(content) + if "CommandNotification" not in text and "task-notification" not in text: + return None + return last_message + + @staticmethod + def _get_chat_followthrough_notice(messages: list[Any]) -> HumanMessage | None: + if not messages: + return None + last_message = messages[-1] + if last_message.__class__.__name__ != "HumanMessage": + return None + metadata = getattr(last_message, "metadata", None) or {} + if metadata.get("source") != "external": + return None + if metadata.get("notification_type") != "chat": + return None + content = getattr(last_message, "content", "") + text = content if isinstance(content, str) else str(content) + if "New message from" not in text or "read_messages(chat_id=" not in text: + return None + return last_message + + @classmethod + def _build_terminal_followthrough_fallback(cls, notice: HumanMessage) -> AIMessage: + metadata = getattr(notice, "metadata", None) or {} + notification_type = str(metadata.get("notification_type") or "task") + content = getattr(notice, "content", "") + text = content if isinstance(content, str) else str(content) + status_match = re.search(r"(.*?)", text, flags=re.IGNORECASE | re.DOTALL) + status = status_match.group(1).strip().lower() if status_match else "" + subject = "command" if notification_type == "command" else "agent" + # @@@terminal-followthrough-fallback - terminal background notifications + # must never collapse into notice-only durable history when the model + # reentry stays silent; surface the silence explicitly instead. + if status == "completed": + reply = f"Background {subject} completed, but the followthrough assistant reply was empty." + elif status == "cancelled": + reply = f"Background {subject} was cancelled, but the followthrough assistant reply was empty." + elif status == "error": + reply = f"Background {subject} failed, but the followthrough assistant reply was empty." + else: + reply = f"Background {subject} update arrived, but the followthrough assistant reply was empty." + return AIMessage(content=reply) + + @classmethod + def _build_chat_followthrough_fallback(cls, notice: HumanMessage) -> AIMessage: + content = getattr(notice, "content", "") + text = content if isinstance(content, str) else str(content) + chat_id_match = re.search(r'read_messages\(chat_id="([^"]+)"\)', text) + if chat_id_match: + chat_id = chat_id_match.group(1) + reply = ( + f"I received a chat notification, but the followthrough assistant reply was empty. " + f'Read it with read_messages(chat_id="{chat_id}") before deciding whether to reply.' + ) + else: + reply = "I received a chat notification, but the followthrough assistant reply was empty." + return AIMessage(content=reply) + + +class _StreamingToolExecutor: + def __init__(self, loop: QueryLoop, tool_context: ToolUseContext | None): + self._loop = loop + self._tool_context = tool_context + self._tracked: list[_TrackedTool] = [] + self._discarded = False + + async def add_tool(self, tool_call: dict[str, Any]) -> None: + if self._discarded: + return + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + if self._loop._registry.get(name) is None: + self._tracked.append( + _TrackedTool( + order=len(self._tracked), + tool_call=tool_call, + is_concurrency_safe=False, + status="completed", + result=self._tool_error(tool_call, f"Tool '{name}' not found"), + ) + ) + return + tracked = _TrackedTool( + order=len(self._tracked), + tool_call=tool_call, + is_concurrency_safe=self._loop._tool_is_concurrency_safe(tool_call), + ) + self._tracked.append(tracked) + self._process_queue() + + async def get_completed_results(self) -> list[ToolMessage]: + await asyncio.sleep(0) + self._process_queue() + ready: list[ToolMessage] = [] + for tracked in self._tracked: + if tracked.status == "yielded": + continue + if tracked.status == "completed" and tracked.result is not None: + tracked.status = "yielded" + ready.append(tracked.result) + continue + break + return ready + + async def drain_remaining(self) -> list[ToolMessage]: + while True: + self._process_queue() + running = [tracked.task for tracked in self._tracked if tracked.status == "executing" and tracked.task is not None] + if not running: + break + await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED) + self._process_queue() + remaining: list[ToolMessage] = [] + for tracked in self._tracked: + if tracked.status == "yielded": + continue + if tracked.status == "completed" and tracked.result is not None: + tracked.status = "yielded" + remaining.append(tracked.result) + return remaining + + async def discard(self, reason: str) -> list[ToolMessage]: + # @@@streaming-tool-discard + # ql-05 must not leave orphaned tool tasks behind when streaming exits + # early. Synthetic error emission is still a later hardening pass, but + # task cleanup itself must happen now. + self._discarded = True + running: list[asyncio.Task[ToolMessage]] = [] + for tracked in self._tracked: + if tracked.status == "queued": + tracked.status = "completed" + tracked.result = self._synthetic_error(tracked.tool_call, reason) + continue + if tracked.status == "executing" and tracked.task is not None: + tracked.task.cancel() + running.append(tracked.task) + if running: + await asyncio.gather(*running, return_exceptions=True) + for tracked in self._tracked: + if tracked.status == "executing": + tracked.status = "completed" + tracked.result = self._synthetic_error(tracked.tool_call, reason) + return await self.drain_remaining() + + def _process_queue(self) -> None: + if self._discarded: + return + for tracked in self._tracked: + if tracked.status != "queued": + continue + if not self._can_execute(tracked): + break + tracked.status = "executing" + tracked.task = asyncio.create_task(self._run_tool(tracked)) + + def _can_execute(self, tracked: _TrackedTool) -> bool: + executing = [item for item in self._tracked if item.status == "executing"] + if not executing: + return True + if not tracked.is_concurrency_safe: + return False + return all(item.is_concurrency_safe for item in executing) + + async def _run_tool(self, tracked: _TrackedTool) -> None: + # @@@streaming-tool-task-exit + # ql-05 cannot let middleware-level exceptions disappear into a dead + # task. Every tool_use must resolve to a ToolMessage, and queue + # progression must re-run immediately when a task exits. + try: + tracked.result = await self._loop._execute_single_tool(tracked.tool_call, self._tool_context) + tracked.status = "completed" + except asyncio.CancelledError: + raise + except Exception as exc: + tracked.result = self._tool_error(tracked.tool_call, str(exc)) + tracked.status = "completed" + finally: + if self._should_abort_siblings(tracked): + await self._abort_siblings( + excluding=tracked, + reason="sibling aborted after bash error", + ) + if not self._discarded: + self._process_queue() + + def _should_abort_siblings(self, tracked: _TrackedTool) -> bool: + if tracked.result is None: + return False + name = tracked.tool_call.get("name") or tracked.tool_call.get("function", {}).get("name", "") + return name.lower() == "bash" and "" in tracked.result.content + + async def _abort_siblings(self, *, excluding: _TrackedTool, reason: str) -> None: + # @@@bash-sibling-abort + # Claude Code only fan-outs this abort for bash failures. Keep it + # local to the current executor iteration so the parent loop survives + # and later turns can continue with explicit tool errors. + self._discarded = True + running: list[asyncio.Task[ToolMessage]] = [] + for tracked in self._tracked: + if tracked is excluding or tracked.status in {"completed", "yielded"}: + continue + if tracked.status == "queued": + tracked.status = "completed" + tracked.result = self._tool_error(tracked.tool_call, reason) + continue + if tracked.status == "executing" and tracked.task is not None: + tracked.task.cancel() + running.append(tracked.task) + if running: + await asyncio.gather(*running, return_exceptions=True) + for tracked in self._tracked: + if tracked is excluding or tracked.status != "executing": + continue + tracked.status = "completed" + tracked.result = self._tool_error(tracked.tool_call, reason) + + def _synthetic_error(self, tool_call: dict[str, Any], reason: str) -> ToolMessage: + return self._tool_error( + tool_call, + f"streaming discarded: {reason}", + ) + + def _tool_error(self, tool_call: dict[str, Any], error_text: str) -> ToolMessage: + name = tool_call.get("name") or tool_call.get("function", {}).get("name", "") + call_id = tool_call.get("id", "") + return ToolMessage( + content=f"{error_text}", + tool_call_id=call_id, + name=name, + ) + + +# ------------------------------------------------------------------------- +# Closure helpers (avoid late-binding bugs in loop-built lambdas) +# ------------------------------------------------------------------------- + + +def _make_model_wrapper(mw: AgentMiddleware, next_handler): + """Build an awrap_model_call wrapper that correctly closes over mw and next_handler.""" + + async def wrapper(request: ModelRequest) -> ModelResponse: + return await mw.awrap_model_call(request, next_handler) + + return wrapper + + +def _make_tool_wrapper(mw: AgentMiddleware, next_handler): + """Build an awrap_tool_call wrapper that correctly closes over mw and next_handler.""" + + async def wrapper(request: ToolCallRequest) -> ToolMessage: + return await mw.awrap_tool_call(request, next_handler) + + return wrapper + + +# ------------------------------------------------------------------------- +# Middleware override detection helpers +def _mw_overrides_model_call(mw: AgentMiddleware) -> bool: + """True if mw actually overrides awrap_model_call (not just inherits the base stub).""" + mw_type = type(mw) + own_fn = mw_type.__dict__.get("awrap_model_call") + if own_fn is not None: + return True + own_sync = mw_type.__dict__.get("wrap_model_call") + return own_sync is not None + + +def _mw_overrides_tool_call(mw: AgentMiddleware) -> bool: + """True if mw actually overrides awrap_tool_call (not just inherits the base stub).""" + mw_type = type(mw) + own_fn = mw_type.__dict__.get("awrap_tool_call") + if own_fn is not None: + return True + own_sync = mw_type.__dict__.get("wrap_tool_call") + return own_sync is not None diff --git a/core/runtime/middleware/__init__.py b/core/runtime/middleware/__init__.py index e69de29bb..f777a7fde 100644 --- a/core/runtime/middleware/__init__.py +++ b/core/runtime/middleware/__init__.py @@ -0,0 +1,79 @@ +"""Local runtime middleware protocol and request/response types. + +This replaces the phantom `langchain.agents.middleware.types` dependency for +the current runtime stack. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, replace +from typing import Any, ClassVar + +from langchain_core.messages import ToolMessage + + +@dataclass(frozen=True) +class ModelRequest: + model: Any + messages: list + system_message: Any = None + tools: list | None = None + + def override(self, **changes: Any) -> ModelRequest: + return replace(self, **changes) + + +@dataclass(frozen=True) +class ModelResponse: + result: list + request_messages: list | None = None + prepared_request: ModelRequest | None = None + + +ModelCallResult = ModelResponse + + +@dataclass(frozen=True) +class ToolCallRequest: + tool_call: dict + tool: Any = None + state: Any = None + runtime: Any = None + + def override(self, **changes: Any) -> ToolCallRequest: + return replace(self, **changes) + + +class AgentMiddleware: + """Minimal chain-of-responsibility middleware base for the runtime stack.""" + + tools: ClassVar[tuple[Any, ...]] = () + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + return handler(request) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + return await handler(request) + + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage], + ) -> ToolMessage: + return handler(request) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage]], + ) -> ToolMessage: + return await handler(request) diff --git a/core/runtime/middleware/mcp_instructions.py b/core/runtime/middleware/mcp_instructions.py new file mode 100644 index 000000000..7cff4c7cb --- /dev/null +++ b/core/runtime/middleware/mcp_instructions.py @@ -0,0 +1,80 @@ +"""Thread-scoped MCP instruction delta injection. + +Mycel does not have CC's attachment plane. Keep this contract smaller: +- MCP server configs may carry `instructions` +- the loop stores which server names have already been announced per thread +- on the next turn after a change, inject one delta SystemMessage +""" + +from __future__ import annotations + +import json +from collections.abc import Callable +from typing import Any + +from langchain_core.messages import SystemMessage + +from core.runtime.middleware import AgentMiddleware +from core.runtime.state import AppState + +_DELTA_TAG = "mcp_instructions_delta" + + +def _format_instruction_block(server_name: str, instructions: str) -> str: + return f"## {server_name}\n{instructions.strip()}" + + +def _render_delta_message(*, added: dict[str, str], removed: list[str]) -> SystemMessage: + payload = { + "added_names": sorted(added), + "removed_names": sorted(removed), + } + blocks = [ + "", + f"<{_DELTA_TAG}>{json.dumps(payload, ensure_ascii=False)}", + "MCP server instructions changed for this thread.", + ] + if added: + blocks.append("Use the newly available MCP instructions below for subsequent turns:") + blocks.extend(_format_instruction_block(name, added[name]) for name in sorted(added)) + if removed: + blocks.append("The following MCP servers are no longer active for this thread:") + blocks.extend(f"- {name}" for name in sorted(removed)) + blocks.append("") + return SystemMessage(content="\n".join(blocks)) + + +class McpInstructionsDeltaMiddleware(AgentMiddleware): + """Injects MCP instruction deltas once per thread when the connected set changes.""" + + def __init__( + self, + *, + get_instruction_blocks: Callable[[], dict[str, str]], + get_app_state: Callable[[], AppState | None], + ) -> None: + self._get_instruction_blocks = get_instruction_blocks + self._get_app_state = get_app_state + + def before_model(self, state: dict[str, Any], runtime: Any = None, config: dict[str, Any] | None = None) -> dict[str, Any] | None: + app_state = self._get_app_state() + if app_state is None: + return None + + config = config or {} + thread_id = config.get("configurable", {}).get("thread_id", "default") + current_blocks = {name: block for name, block in self._get_instruction_blocks().items() if block.strip()} + announced_blocks = { + name: block + for name, block in app_state.announced_mcp_instruction_blocks.get(thread_id, {}).items() + if isinstance(name, str) and isinstance(block, str) and block.strip() + } + + added_names = sorted(name for name, block in current_blocks.items() if announced_blocks.get(name) != block) + removed_names = sorted(name for name in announced_blocks if name not in current_blocks) + if not added_names and not removed_names: + return None + + app_state.announced_mcp_instruction_blocks[thread_id] = dict(current_blocks) + added = {name: current_blocks[name] for name in added_names} + return {"messages": [_render_delta_message(added=added, removed=removed_names)]} diff --git a/core/runtime/middleware/memory/compactor.py b/core/runtime/middleware/memory/compactor.py index 67599b534..defbb7221 100644 --- a/core/runtime/middleware/memory/compactor.py +++ b/core/runtime/middleware/memory/compactor.py @@ -10,13 +10,22 @@ from langchain_core.messages import HumanMessage, SystemMessage +# CC L4b Legacy Compact: system prompt is simple (~200 tokens) — NOT inherited from parent. +# Using a distinct simple system prompt prevents reusing the parent conversation's cache +# (different system prompt → different prefix hash), and reduces input token cost. +COMPACT_SYSTEM_PROMPT = "You are a helpful AI assistant tasked with summarizing conversations." + SUMMARY_PROMPT = """\ -Provide a detailed summary for continuing our conversation. Include: -1. Key decisions made and their rationale -2. Files created, modified, or read and their current state -3. Errors encountered and how they were resolved -4. Outstanding tasks and current progress -5. Important context that would be needed to continue the work +Summarize this conversation in the following 9 sections: +1. Request/Intent — what the user asked for +2. Technical Concepts — key technologies and approaches discussed +3. Files/Code — files created or modified and their current state +4. Errors — errors encountered and how they were resolved +5. Problem Solving — decisions made and rationale +6. User Messages — key user inputs and feedback +7. Pending Tasks — unfinished work +8. Current Work — what was actively being done at the end +9. Next Step — the immediate next action needed Be concise but retain all information needed to continue seamlessly.""" SPLIT_TURN_PREFIX_PROMPT = """\ @@ -80,19 +89,41 @@ def split_messages(self, messages: list[Any]) -> tuple[list[Any], list[Any]]: return messages[:split_idx], messages[split_idx:] - async def compact(self, messages_to_summarize: list[Any], model: Any) -> str: + async def compact( + self, + messages_to_summarize: list[Any], + model: Any, + compact_boundary: int = 0, + ) -> str: """Generate a summary of the given messages using the LLM. + Aligned with CC L4b Legacy Compact: + - Uses COMPACT_SYSTEM_PROMPT (simple, ~200 tokens — NOT parent system prompt) + - No tools passed (extended thinking disabled, tools=[]) + - Slices from compact_boundary forward + - max_tokens capped at 20000 (CC max summary output) + Returns plain text summary string. """ - # Build the summarization request + # Slice from compact_boundary forward (CC: from last compact_boundary marker) + if compact_boundary > 0 and compact_boundary < len(messages_to_summarize): + messages_to_summarize = messages_to_summarize[compact_boundary:] + formatted = self._format_messages_for_summary(messages_to_summarize) + # CC L4b: system prompt is simple — does NOT inherit parent's system prompt. + # No tools, no extended thinking. summary_messages = [ - SystemMessage(content=SUMMARY_PROMPT), - HumanMessage(content=f"Here is the conversation to summarize:\n\n{formatted}"), + SystemMessage(content=COMPACT_SYSTEM_PROMPT), + HumanMessage(content=f"Summarize this conversation:\n\n{formatted}\n\n{SUMMARY_PROMPT}"), ] - response = await model.ainvoke(summary_messages) + # Bind max_tokens=20000 (CC max summary output), no tools + try: + bound_model = model.bind(max_tokens=20000) + except Exception: + bound_model = model + + response = await bound_model.ainvoke(summary_messages) return response.content if hasattr(response, "content") else str(response) def _estimate_msg_tokens(self, msg: Any) -> int: diff --git a/core/runtime/middleware/memory/middleware.py b/core/runtime/middleware/memory/middleware.py index 8775e1c21..6dfbc6e96 100644 --- a/core/runtime/middleware/memory/middleware.py +++ b/core/runtime/middleware/memory/middleware.py @@ -7,19 +7,20 @@ from __future__ import annotations +import json import logging from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any -from langchain.agents.middleware.types import ( +from langchain_core.messages import SystemMessage + +from core.runtime.middleware import ( AgentMiddleware, ModelCallResult, ModelRequest, ModelResponse, ) -from langchain_core.messages import SystemMessage - from storage.contracts import SummaryRepo from .compactor import ContextCompactor @@ -27,6 +28,7 @@ from .summary_store import SummaryStore logger = logging.getLogger(__name__) +_COMPACTION_BREAKER_THRESHOLD = 3 class MemoryMiddleware(AgentMiddleware): @@ -86,6 +88,10 @@ def __init__( self._cached_summary: str | None = None self._compact_up_to_index: int = 0 self._summary_restored: bool = False + self._summary_thread_id: str | None = None + self._pending_owner_notices: list[dict[str, Any]] = [] + self._compaction_failure_counts_by_thread: dict[str, int] = {} + self._compaction_breaker_open_by_thread: dict[str, bool] = {} if verbose: print("[MemoryMiddleware] Initialized") @@ -125,6 +131,10 @@ def set_runtime(self, runtime: Any) -> None: """Inject AgentRuntime reference (called by agent.py).""" self._runtime = runtime + @property + def compact_boundary_index(self) -> int: + return self._compact_up_to_index + # ========== AgentMiddleware interface ========== async def awrap_model_call( @@ -134,13 +144,18 @@ async def awrap_model_call( ) -> ModelCallResult: messages = list(request.messages) original_count = len(messages) + thread_id = self._extract_thread_id(request) # Restore summary from store if not already done if not self._summary_restored and self.summary_store: - thread_id = self._extract_thread_id(request) if thread_id: await self._restore_summary_from_store(thread_id) self._summary_restored = True + self._summary_thread_id = thread_id + elif self.summary_store and thread_id and self._summary_thread_id != thread_id: + await self._restore_summary_from_store(thread_id) + self._summary_restored = True + self._summary_thread_id = thread_id sys_tokens = self._estimate_system_tokens(request) @@ -173,8 +188,9 @@ async def awrap_model_call( ) if self.compactor.should_compact(estimated, self._context_limit, self._compaction_threshold) and self._model: - thread_id = self._extract_thread_id(request) - messages = await self._do_compact(messages, thread_id) + compacted = await self._attempt_compaction(messages, thread_id=thread_id) + if compacted is not None: + messages = compacted elif self._cached_summary and self._compact_up_to_index > 0: if self._compact_up_to_index <= len(messages): summary_msg = SystemMessage(content=f"[Conversation Summary]\n{self._cached_summary}") @@ -190,7 +206,14 @@ async def awrap_model_call( final_tokens = self._estimate_tokens(messages) + sys_tokens print(f"[Memory] Final: {len(messages)} msgs (~{final_tokens} tokens) sent to LLM (original: {original_count} msgs)") - return await handler(request.override(messages=messages)) + response = await handler(request.override(messages=messages)) + if response.request_messages is None: + return ModelResponse( + result=response.result, + request_messages=list(messages), + prepared_request=response.prepared_request, + ) + return response async def _do_compact(self, messages: list[Any], thread_id: str | None = None) -> list[Any]: """Execute compaction: summarize old messages, return compacted list.""" @@ -219,6 +242,9 @@ async def _do_compact(self, messages: list[Any], thread_id: str | None = None) - self._cached_summary = summary_text self._compact_up_to_index = len(messages) - len(to_keep) + self._summary_restored = True + self._summary_thread_id = thread_id + self._record_compaction_notice() if self.summary_store and thread_id: try: @@ -257,6 +283,7 @@ async def force_compact(self, messages: list[Any]) -> dict[str, Any] | None: summary_text = await self.compactor.compact(to_summarize, self._resolved_model) self._cached_summary = summary_text self._compact_up_to_index = len(messages) - len(to_keep) + self._record_compaction_notice() return { "stats": { "summarized": len(to_summarize), @@ -267,6 +294,24 @@ async def force_compact(self, messages: list[Any]) -> dict[str, Any] | None: if self._runtime: self._runtime.set_flag("is_compacting", False) + async def compact_messages_for_recovery(self, messages: list[Any], thread_id: str | None = None) -> list[Any] | None: + """Force a compaction pass and return the compacted message list.""" + if not self._model: + return None + + pruned = self.pruner.prune(messages) + to_summarize, to_keep = self.compactor.split_messages(pruned) + if len(to_summarize) < 2: + return None + + return await self._attempt_compaction( + pruned, + thread_id=thread_id or self._current_thread_id(), + respect_breaker=False, + record_failures=False, + clear_breaker_on_success=True, + ) + def _estimate_tokens(self, messages: list[Any]) -> int: """Estimate total tokens for messages (chars // 2).""" total = 0 @@ -306,6 +351,110 @@ def _extract_thread_id(self, request: ModelRequest) -> str | None: return configurable.get("thread_id") return getattr(configurable, "thread_id", None) if configurable else None + def consume_pending_notices(self) -> list[dict[str, Any]]: + notices = list(self._pending_owner_notices) + self._pending_owner_notices.clear() + return notices + + def snapshot_thread_state(self, thread_id: str) -> dict[str, Any]: + return { + "failure_count": int(self._compaction_failure_counts_by_thread.get(thread_id, 0)), + "breaker_open": bool(self._compaction_breaker_open_by_thread.get(thread_id, False)), + } + + def restore_thread_state(self, thread_id: str, state: dict[str, Any] | None) -> None: + payload = dict(state or {}) + failure_count = int(payload.get("failure_count") or 0) + breaker_open = bool(payload.get("breaker_open", False)) + if failure_count > 0: + self._compaction_failure_counts_by_thread[thread_id] = failure_count + else: + self._compaction_failure_counts_by_thread.pop(thread_id, None) + if breaker_open: + self._compaction_breaker_open_by_thread[thread_id] = True + else: + self._compaction_breaker_open_by_thread.pop(thread_id, None) + + def clear_thread_state(self, thread_id: str) -> None: + self._compaction_failure_counts_by_thread.pop(thread_id, None) + self._compaction_breaker_open_by_thread.pop(thread_id, None) + + def _record_compaction_notice(self) -> None: + content = f"Conversation compacted. Earlier {self._compact_up_to_index} message(s) are now represented by a summary." + self._queue_owner_notice( + { + "content": content, + "notification_type": "compact", + "compact_boundary_index": self._compact_up_to_index, + } + ) + + def _current_thread_id(self) -> str | None: + from sandbox.thread_context import get_current_thread_id + + return get_current_thread_id() + + async def _attempt_compaction( + self, + messages: list[Any], + *, + thread_id: str | None, + respect_breaker: bool = True, + record_failures: bool = True, + clear_breaker_on_success: bool = False, + ) -> list[Any] | None: + # @@@compaction-breaker-scope - match cc-src's narrower boundary: + # the breaker blocks later automatic compaction attempts, but reactive + # recovery may still try once and clear the breaker on success. + if respect_breaker and thread_id and self._compaction_breaker_open_by_thread.get(thread_id, False): + return None + try: + compacted = await self._do_compact(messages, thread_id) + except Exception as exc: + logger.error("[Memory] Compaction failed for thread %s: %s", thread_id or "", exc) + if record_failures: + self._record_compaction_failure(thread_id, exc) + return None + self._record_compaction_success(thread_id, clear_breaker=clear_breaker_on_success) + return compacted + + def _record_compaction_success(self, thread_id: str | None, *, clear_breaker: bool = False) -> None: + if not thread_id: + return + self._compaction_failure_counts_by_thread.pop(thread_id, None) + if clear_breaker: + self._compaction_breaker_open_by_thread.pop(thread_id, None) + + def _record_compaction_failure(self, thread_id: str | None, exc: Exception) -> None: + if not thread_id: + return + failures = int(self._compaction_failure_counts_by_thread.get(thread_id, 0)) + 1 + self._compaction_failure_counts_by_thread[thread_id] = failures + if failures < _COMPACTION_BREAKER_THRESHOLD or self._compaction_breaker_open_by_thread.get(thread_id, False): + return + self._compaction_breaker_open_by_thread[thread_id] = True + self._queue_owner_notice( + { + "content": "Automatic compaction disabled for this thread after repeated failures. Clear the thread or start a new one.", + "notification_type": "compact_breaker", + "failure_count": failures, + "error": str(exc), + } + ) + + def _queue_owner_notice(self, notice: dict[str, Any]) -> None: + self._pending_owner_notices.append(dict(notice)) + if self._runtime and hasattr(self._runtime, "emit_activity_event"): + # @@@memory-owner-notices - compaction boundary and breaker state are + # owner-facing runtime facts, so stream and cold rebuild must share + # the same notice payload instead of inventing separate surfaces. + self._runtime.emit_activity_event( + { + "event": "notice", + "data": json.dumps(notice, ensure_ascii=False), + } + ) + async def _restore_summary_from_store(self, thread_id: str) -> None: """Restore summary from SummaryStore.""" if not thread_id: @@ -314,6 +463,8 @@ async def _restore_summary_from_store(self, thread_id: str) -> None: ) try: + self._cached_summary = None + self._compact_up_to_index = 0 summary_data = self.summary_store.get_latest_summary(thread_id) if not summary_data: @@ -332,6 +483,7 @@ async def _restore_summary_from_store(self, thread_id: str) -> None: self._cached_summary = summary_data.summary_text self._compact_up_to_index = summary_data.compact_up_to_index + self._summary_thread_id = thread_id if self.verbose: print( @@ -342,6 +494,8 @@ async def _restore_summary_from_store(self, thread_id: str) -> None: ) except Exception as e: + self._cached_summary = None + self._compact_up_to_index = 0 logger.error(f"[Memory] Failed to restore summary: {e}") async def _rebuild_summary_from_checkpointer(self, thread_id: str) -> None: diff --git a/core/runtime/middleware/monitor/cost.py b/core/runtime/middleware/monitor/cost.py index 4b09c2a51..08615af02 100644 --- a/core/runtime/middleware/monitor/cost.py +++ b/core/runtime/middleware/monitor/cost.py @@ -112,7 +112,7 @@ def _load_cache() -> tuple[dict[str, dict[str, str]], dict[str, int], dict[str, if not cache_path.exists(): return None try: - data = json.loads(cache_path.read_text()) + data = json.loads(cache_path.read_text(encoding="utf-8")) if time.time() - data.get("timestamp", 0) > _CACHE_TTL: return None models = data.get("models", {}) @@ -128,7 +128,7 @@ def _save_cache(models: dict[str, dict[str, str]], context_limits: dict[str, int try: _CACHE_PATH.parent.mkdir(parents=True, exist_ok=True) data = {"timestamp": time.time(), "models": models, "context_limits": context_limits, "providers": providers} - _CACHE_PATH.write_text(json.dumps(data)) + _CACHE_PATH.write_text(json.dumps(data), encoding="utf-8") except Exception: pass @@ -163,11 +163,17 @@ def fetch_openrouter_pricing() -> dict[str, dict[str, Decimal]]: cached = _load_cache() if cached: models_raw, ctx, provs = cached - _pricing_data = _deserialize_costs(models_raw) - _context_limits = ctx - _model_providers = provs - _initialized = True - return _pricing_data + cached_costs = _deserialize_costs(models_raw) + # @@@pricing-cache-integrity - older CI caches can carry context/provider + # metadata with an empty model-pricing payload, which makes cost + # calculation silently degrade while context-limit tests still pass. + # Treat that cache as invalid and fall through to bundled/API reload. + if cached_costs: + _pricing_data = cached_costs + _context_limits = ctx + _model_providers = provs + _initialized = True + return _pricing_data _pricing_data = _fetch_from_openrouter() or _load_bundled() _initialized = True @@ -219,7 +225,10 @@ def _load_bundled() -> dict[str, dict[str, Decimal]]: if not _BUNDLED_PATH.exists(): return {} try: - data = json.loads(_BUNDLED_PATH.read_text()) + # @@@bundled-models-utf8 - Windows runners do not default to UTF-8. + # The bundled OpenRouter snapshot contains non-ASCII descriptions, so + # implicit decoding can fail and silently collapse pricing/context data. + data = json.loads(_BUNDLED_PATH.read_text(encoding="utf-8")) result: dict[str, dict[str, Decimal]] = {} ctx_result: dict[str, int] = {} prov_result: dict[str, str] = {} diff --git a/core/runtime/middleware/monitor/middleware.py b/core/runtime/middleware/monitor/middleware.py index 218ebcd06..899617379 100644 --- a/core/runtime/middleware/monitor/middleware.py +++ b/core/runtime/middleware/monitor/middleware.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Callable from typing import Any -from langchain.agents.middleware.types import ( +from core.runtime.middleware import ( AgentMiddleware, ModelCallResult, ModelRequest, @@ -113,6 +113,9 @@ async def awrap_model_call( self._state_monitor.mark_error(e) raise + if response.prepared_request is not None: + return response + messages = response.result if hasattr(response, "result") else [response] resp_dict = {"messages": messages} diff --git a/core/runtime/middleware/prompt_caching/__init__.py b/core/runtime/middleware/prompt_caching/__init__.py index 87f4e92b4..361b124a8 100644 --- a/core/runtime/middleware/prompt_caching/__init__.py +++ b/core/runtime/middleware/prompt_caching/__init__.py @@ -1,8 +1,8 @@ """Anthropic prompt caching middleware. Requires: - - `langchain`: For agent middleware framework - - `langchain-anthropic`: For `ChatAnthropic` model (already a dependency) + - local `core.runtime.middleware` protocol types + - `langchain-anthropic`: For `ChatAnthropic` model """ from collections.abc import Awaitable, Callable @@ -10,9 +10,10 @@ from warnings import warn from langchain_anthropic.chat_models import ChatAnthropic +from langchain_core.messages import SystemMessage try: - from langchain.agents.middleware.types import ( + from core.runtime.middleware import ( AgentMiddleware, ModelCallResult, ModelRequest, @@ -20,9 +21,9 @@ ) except ImportError as e: msg = ( - "AnthropicPromptCachingMiddleware requires 'langchain' to be installed. " - "This middleware is designed for use with LangChain agents. " - "Install it with: pip install langchain" + "AnthropicPromptCachingMiddleware requires the local " + "'core.runtime.middleware' protocol definitions and " + "'langchain-anthropic' to be importable." ) raise ImportError(msg) from e @@ -32,7 +33,7 @@ class PromptCachingMiddleware(AgentMiddleware): Optimizes API usage by caching conversation prefixes for Anthropic models. - Requires both `langchain` and `langchain-anthropic` packages to be installed. + Requires the local runtime middleware protocol plus `langchain-anthropic`. Learn more about Anthropic prompt caching [here](https://platform.claude.com/docs/en/build-with-claude/prompt-caching). @@ -68,6 +69,26 @@ def __init__( self.min_messages_to_cache = min_messages_to_cache self.unsupported_model_behavior = unsupported_model_behavior + def _apply_system_cache(self, request: ModelRequest) -> ModelRequest: + """Add cache_control to the first (static) block of system_message. + + Anthropic prompt caching requires cache_control on the system content + blocks, not on messages. Marking the first block caches the entire + static system prefix (identity + tool rules) across sessions. + """ + sm = request.system_message + if sm is None: + return request + content = sm.content + if isinstance(content, str): + new_content: list = [{"type": "text", "text": content, "cache_control": {"type": self.type}}] + elif isinstance(content, list) and content: + first = {**content[0], "cache_control": {"type": self.type}} + new_content = [first, *content[1:]] + else: + return request + return request.override(system_message=SystemMessage(content=new_content)) + def _should_apply_caching(self, request: ModelRequest) -> bool: """Check if caching should be applied to the request. @@ -112,12 +133,7 @@ def wrap_model_call( """ if not self._should_apply_caching(request): return handler(request) - - new_model_settings = { - **request.model_settings, - "cache_control": {"type": self.type, "ttl": self.ttl}, - } - return handler(request.override(model_settings=new_model_settings)) + return handler(self._apply_system_cache(request)) async def awrap_model_call( self, @@ -135,12 +151,7 @@ async def awrap_model_call( """ if not self._should_apply_caching(request): return await handler(request) - - new_model_settings = { - **request.model_settings, - "cache_control": {"type": self.type, "ttl": self.ttl}, - } - return await handler(request.override(model_settings=new_model_settings)) + return await handler(self._apply_system_cache(request)) __all__ = ["PromptCachingMiddleware"] diff --git a/core/runtime/middleware/queue/__init__.py b/core/runtime/middleware/queue/__init__.py index f3d08f337..cf97229dc 100644 --- a/core/runtime/middleware/queue/__init__.py +++ b/core/runtime/middleware/queue/__init__.py @@ -2,7 +2,12 @@ from storage.contracts import QueueItem -from .formatters import format_background_notification, format_chat_notification, format_wechat_message +from .formatters import ( + format_agent_message, + format_background_notification, + format_chat_notification, + format_progress_notification, +) from .manager import MessageQueueManager from .middleware import SteeringMiddleware @@ -10,7 +15,8 @@ "MessageQueueManager", "QueueItem", "SteeringMiddleware", + "format_agent_message", "format_background_notification", "format_chat_notification", - "format_wechat_message", + "format_progress_notification", ] diff --git a/core/runtime/middleware/queue/formatters.py b/core/runtime/middleware/queue/formatters.py index 1e7821187..85034f7b4 100644 --- a/core/runtime/middleware/queue/formatters.py +++ b/core/runtime/middleware/queue/formatters.py @@ -11,13 +11,51 @@ def format_chat_notification(sender_name: str, chat_id: str, unread_count: int, signal: str | None = None) -> str: - """Lightweight notification — agent must chat_read to see content. + """Lightweight notification — agent must read_messages to see content. @@@v3-notification-only — no message content injected. Agent calls - chat_read(chat_id=...) to read, then chat_send() to reply. + read_messages(chat_id=...) to read, then send_message() to reply. """ signal_hint = f" [signal: {signal}]" if signal and signal != "open" else "" - return f"\nNew message from {sender_name} in chat {chat_id} ({unread_count} unread).{signal_hint}\n" + return ( + "\n" + f"New message from {sender_name} in chat {chat_id} ({unread_count} unread).{signal_hint}\n" + f'Read it with read_messages(chat_id="{chat_id}").\n' + f'Reply with send_message(chat_id="{chat_id}", content="...").\n' + "Prefer using this exact chat_id directly.\n" + "Do not treat your normal assistant text as a chat reply.\n" + "" + ) + + +def format_agent_message(sender_name: str, message: str) -> str: + """Format inter-agent delivery for steering injection on the next turn.""" + return ( + "\n" + "\n" + f" {escape(sender_name)}\n" + f" {escape(message)}\n" + "\n" + "" + ) + + +def format_progress_notification( + agent_id: str, + description: str, + *, + step: str = "running", +) -> str: + """Format background worker progress for coordinator-style prompt injection.""" + return ( + "\n" + "\n" + f" {escape(agent_id)}\n" + f" {escape(step)}\n" + f" {escape(description)}\n" + "\n" + "" + ) def format_background_notification( @@ -31,7 +69,7 @@ def format_background_notification( """Format background task completion as system-reminder XML.""" parts = [ "", - "", + "", f" {task_id}", f" {status}", ] @@ -44,29 +82,11 @@ def format_background_notification( parts.append(f" {escape(truncated)}") if usage: parts.append(f" {json.dumps(usage)}") - parts.append("") + parts.append("") parts.append("") return "\n".join(parts) -def format_wechat_message(sender_name: str, user_id: str, text: str) -> str: - """Format incoming WeChat message for thread delivery. - - Agent sees: full message with user_id metadata (needed for wechat_send reply). - Frontend sees: just the message text (system-reminder stripped). - """ - return ( - f"{text}\n" - "\n" - "\n" - f" {escape(sender_name)}\n" - f" {escape(user_id)}\n" - "\n" - 'To reply, use wechat_send(user_id="' + escape(user_id) + '", text="...").\n' - "" - ) - - def format_command_notification( command_id: str, status: Literal["completed", "failed"], diff --git a/core/runtime/middleware/queue/middleware.py b/core/runtime/middleware/queue/middleware.py index ccb9c30be..c713c33bd 100644 --- a/core/runtime/middleware/queue/middleware.py +++ b/core/runtime/middleware/queue/middleware.py @@ -10,11 +10,13 @@ from collections.abc import Awaitable, Callable from typing import Any -from langchain_core.messages import HumanMessage, ToolMessage +from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableConfig +from core.runtime.notifications import is_terminal_background_notification + try: - from langchain.agents.middleware.types import ( + from core.runtime.middleware import ( AgentMiddleware, ModelCallResult, ModelRequest, @@ -35,6 +37,50 @@ class AgentMiddleware: logger = logging.getLogger(__name__) +_STEER_NON_PREEMPTIVE_SYSTEM_NOTE = ( + "Steer requests accepted during an active run are non-preemptive. " + "If any tool call from the interrupted run already started, it was allowed to finish and its side effects may " + "already have happened. Do not claim that prior work was interrupted, prevented, cancelled, or rolled back. " + "Treat the steer as instructions for what to do next after that completed work, and answer honestly about any " + "side effects that may already exist." +) + + +def _is_terminal_background_notification(item: Any) -> bool: + return is_terminal_background_notification( + getattr(item, "content", None), + source="system", + notification_type=getattr(item, "notification_type", None), + ) + + +def _is_owner_steer_message(message: Any) -> bool: + if message.__class__.__name__ != "HumanMessage": + return False + metadata = getattr(message, "metadata", {}) or {} + return bool(metadata.get("is_steer") or (metadata.get("source") == "owner" and metadata.get("notification_type") == "steer")) + + +def _apply_steer_contract(request: ModelRequest) -> ModelRequest: + if not any(_is_owner_steer_message(message) for message in request.messages): + return request + + system_message = request.system_message + if system_message is None: + return request.override(system_message=SystemMessage(content=_STEER_NON_PREEMPTIVE_SYSTEM_NOTE)) + + content = getattr(system_message, "content", None) + if isinstance(content, str): + if _STEER_NON_PREEMPTIVE_SYSTEM_NOTE in content: + return request + # @@@steer-honesty-contract - mid-run steer stays a real user message in + # durable history, but the live model call also needs an explicit + # non-preemptive contract so it cannot overclaim that already-started + # tool work was stopped or never produced side effects. + return request.override(system_message=SystemMessage(content=f"{content}\n\n{_STEER_NON_PREEMPTIVE_SYSTEM_NOTE}")) + + return request.override(messages=[SystemMessage(content=_STEER_NON_PREEMPTIVE_SYSTEM_NOTE), *request.messages]) + class SteeringMiddleware(AgentMiddleware): """Non-preemptive steering: let all tool calls finish, inject before next LLM call. @@ -66,6 +112,20 @@ async def awrap_tool_call( """Async pure passthrough — never skip tool calls.""" return await handler(request) + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + return handler(_apply_steer_contract(request)) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + return await handler(_apply_steer_contract(request)) + def before_model( self, state: Any, @@ -79,7 +139,27 @@ def before_model( return None items = self._queue_manager.drain_all(thread_id) - rt = self._agent_runtime + inject_now = [] + deferred = [] + for item in items: + if _is_terminal_background_notification(item): + deferred.append(item) + else: + inject_now.append(item) + # @@@followup-defer - terminal background notifications must never be + # injected inline into an active run. Their stable contract is a + # dedicated followthrough notice-only turn, regardless of the current + # run source. + for item in deferred: + self._queue_manager.enqueue( + item.content, + thread_id, + notification_type=item.notification_type, + source=item.source, + sender_id=item.sender_id, + sender_name=item.sender_name, + ) + items = inject_now if not items: return None @@ -109,14 +189,15 @@ def before_model( # breaks the turn at the steer injection point. # user_message is NOT emitted here — wake_handler already did it # at enqueue time (@@@steer-instant-feedback). - if has_steer and rt and hasattr(rt, "emit_activity_event"): - rt.emit_activity_event( + agent_runtime = self._agent_runtime + if has_steer and agent_runtime and hasattr(agent_runtime, "emit_activity_event"): + agent_runtime.emit_activity_event( { "event": "run_done", "data": json.dumps({"thread_id": thread_id}), } ) - rt.emit_activity_event( + agent_runtime.emit_activity_event( { "event": "run_start", "data": json.dumps({"thread_id": thread_id, "showing": True}), diff --git a/core/runtime/middleware/spill_buffer/middleware.py b/core/runtime/middleware/spill_buffer/middleware.py index ca519cb27..66390718d 100644 --- a/core/runtime/middleware/spill_buffer/middleware.py +++ b/core/runtime/middleware/spill_buffer/middleware.py @@ -2,28 +2,16 @@ from __future__ import annotations +import json +import mimetypes +import posixpath from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any from langchain_core.messages import ToolMessage -try: - from langchain.agents.middleware.types import ( - AgentMiddleware, - ModelRequest, - ModelResponse, - ToolCallRequest, - ) -except ImportError: - - class AgentMiddleware: # type: ignore[no-redef] - pass - - ModelRequest = Any - ModelResponse = Any - ToolCallRequest = Any - +from core.runtime.middleware import AgentMiddleware, ModelRequest, ModelResponse, ToolCallRequest from core.tools.filesystem.backend import FileSystemBackend from .spill import spill_if_needed @@ -57,6 +45,53 @@ def __init__( self.thresholds: dict[str, int] = thresholds or {} self.default_threshold = default_threshold + def _rewrite_mcp_blocks(self, content: Any, *, tool_call_id: str) -> Any: + if not isinstance(content, list): + return content + + lines: list[str] = [] + saw_mcp_blocks = False + + for index, block in enumerate(content): + if not isinstance(block, dict): + return content + + kind = block.get("type") + if kind == "text": + lines.append(str(block.get("text", ""))) + continue + + saw_mcp_blocks = True + mime_type = str(block.get("mime_type") or "application/octet-stream") + guessed_ext = mimetypes.guess_extension(mime_type.split(";", 1)[0].strip()) or ".bin" + + if isinstance(block.get("base64"), str): + payload_path = posixpath.join( + self.workspace_root, + ".leon", + "tool-results", + f"{tool_call_id}-{index}{guessed_ext}.base64", + ) + # @@@mcp-binary-handoff - api-04 keeps Leon's sandbox/file + # abstraction by persisting encoded payloads through fs_backend + # instead of writing host-local bytes behind the sandbox's back. + write_result = self.fs_backend.write_file(payload_path, block["base64"]) + if hasattr(write_result, "success") and not write_result.success: + raise RuntimeError(write_result.error or f"failed to persist MCP payload to {payload_path}") + lines.append(f"MCP binary content ({mime_type}) saved to {payload_path} as base64 payload.") + continue + + if isinstance(block.get("url"), str): + lines.append(f"MCP {kind} content available at {block['url']} ({mime_type})") + continue + + lines.append(json.dumps(block, ensure_ascii=False, default=str)) + + if not saw_mcp_blocks: + text_only = "\n".join(line for line in lines if line) + return text_only if text_only else content + return "\n".join(line for line in lines if line) + # -- model call: pass-through ------------------------------------------ def wrap_model_call( @@ -81,6 +116,19 @@ def _maybe_spill(self, request: ToolCallRequest, result: ToolMessage) -> ToolMes if tool_name in SKIP_TOOLS: return result + source = result.additional_kwargs.get("tool_result_meta", {}).get("source") + normalized_content = result.content + if source == "mcp": + normalized_content = self._rewrite_mcp_blocks( + normalized_content, + tool_call_id=request.tool_call.get("id", "unknown"), + ) + if normalized_content is not result.content: + result = result.model_copy(update={"content": normalized_content}) + + if isinstance(result.content, str) and not result.content.strip(): + return result.model_copy(update={"content": f"({tool_name} completed with no output)"}) + threshold = self.thresholds.get(tool_name, self.default_threshold) tool_call_id = request.tool_call.get("id", "unknown") @@ -93,10 +141,10 @@ def _maybe_spill(self, request: ToolCallRequest, result: ToolMessage) -> ToolMes ) if spilled is not result.content: - return ToolMessage( - content=spilled, - tool_call_id=result.tool_call_id, - ) + # @@@spill-message-preservation - replacing content must not discard + # metadata/name/id; te-03 is about persisted handoff, not rebuilding + # a thinner ToolMessage shell. + return result.model_copy(update={"content": spilled}) return result def wrap_tool_call( diff --git a/core/runtime/middleware/spill_buffer/spill.py b/core/runtime/middleware/spill_buffer/spill.py index 8246a4f33..58cfa470e 100644 --- a/core/runtime/middleware/spill_buffer/spill.py +++ b/core/runtime/middleware/spill_buffer/spill.py @@ -2,7 +2,7 @@ from __future__ import annotations -import os +import posixpath from typing import Any from core.tools.filesystem.backend import FileSystemBackend @@ -10,6 +10,14 @@ PREVIEW_BYTES = 2048 +def _format_preview(content: str) -> str: + preview = content[:PREVIEW_BYTES] + cutoff = preview.rfind("\n") + if cutoff >= PREVIEW_BYTES // 2: + return preview[:cutoff] + return preview + + def spill_if_needed( content: Any, threshold_bytes: int, @@ -36,8 +44,8 @@ def spill_if_needed( if size <= threshold_bytes: return content - spill_dir = os.path.join(workspace_root, ".leon", "tool-results") - spill_path = os.path.join(spill_dir, f"{tool_call_id}.txt") + spill_dir = posixpath.join(workspace_root, ".leon", "tool-results") + spill_path = posixpath.join(spill_dir, f"{tool_call_id}.txt") write_note = "" try: @@ -50,10 +58,15 @@ def spill_if_needed( write_note = f"\n\n(Warning: failed to save full output to disk: {exc})" spill_path = "" - preview = content[:PREVIEW_BYTES] + # @@@persisted-output-wrapper - te-03 is about durable handoff semantics, + # not "shorter string". The model must see an explicit persisted artifact + # boundary plus the re-read path, otherwise we silently amputate context. + preview = _format_preview(content) return ( - f"Output too large ({size} bytes). Full output saved to: {spill_path}" - f"\n\nUse read_file to view specific sections with offset and limit parameters." - f"\n\nPreview (first {PREVIEW_BYTES} bytes):\n{preview}\n..." - f"{write_note}" + f'' + f"\nSize: {size} bytes" + f"\nUse read_file to inspect the full persisted output." + f"\nPreview (first {PREVIEW_BYTES} bytes):\n{preview}\n..." + f"{write_note}\n" + f"" ) diff --git a/core/runtime/notifications.py b/core/runtime/notifications.py new file mode 100644 index 000000000..f70ffc1fa --- /dev/null +++ b/core/runtime/notifications.py @@ -0,0 +1,13 @@ +from __future__ import annotations + + +def is_terminal_background_notification( + content: str | None, + *, + source: str | None, + notification_type: str | None, +) -> bool: + if source != "system" or notification_type not in {"agent", "command"}: + return False + text = content or "" + return "" in text or "" in text diff --git a/core/runtime/permissions.py b/core/runtime/permissions.py new file mode 100644 index 000000000..37c182ed7 --- /dev/null +++ b/core/runtime/permissions.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +PERMISSION_RULE_SOURCES = ( + "userSettings", + "projectSettings", + "localSettings", + "flagSettings", + "policySettings", + "cliArg", + "session", +) + + +@dataclass(frozen=True) +class ToolPermissionContext: + is_read_only: bool + is_destructive: bool = False + # @@@camelcase-permission-surface - external state/routes already speak this camelCase shape. + alwaysAllowRules: dict[str, list[str]] | None = None # noqa: N815 + alwaysDenyRules: dict[str, list[str]] | None = None # noqa: N815 + alwaysAskRules: dict[str, list[str]] | None = None # noqa: N815 + allowManagedPermissionRulesOnly: bool = False # noqa: N815 + + +def can_auto_approve(context: ToolPermissionContext) -> bool: + return context.is_read_only and not context.is_destructive + + +def _active_sources(context: ToolPermissionContext) -> tuple[str, ...]: + if context.allowManagedPermissionRulesOnly: + return ("policySettings",) + return PERMISSION_RULE_SOURCES + + +def _extract_tool_name(rule: str) -> str: + rule = rule.strip() + open_paren = rule.find("(") + return rule if open_paren == -1 else rule[:open_paren] + + +def _find_matching_rule( + rule_buckets: dict[str, list[str]] | None, + tool_name: str, + *, + sources: tuple[str, ...], +) -> str | None: + if not rule_buckets: + return None + for source in sources: + for rule in rule_buckets.get(source, []): + if _extract_tool_name(rule) == tool_name: + return rule + return None + + +def evaluate_permission_rules( + tool_name: str, + context: ToolPermissionContext, +) -> dict[str, Any] | None: + sources = _active_sources(context) + + deny_rule = _find_matching_rule(context.alwaysDenyRules, tool_name, sources=sources) + if deny_rule is not None: + return {"decision": "deny", "message": f"Permission denied by rule: {deny_rule}"} + + ask_rule = _find_matching_rule(context.alwaysAskRules, tool_name, sources=sources) + if ask_rule is not None: + return {"decision": "ask", "message": f"Permission required by rule: {ask_rule}"} + + allow_rule = _find_matching_rule(context.alwaysAllowRules, tool_name, sources=sources) + if allow_rule is not None: + return {"decision": "allow", "message": f"Permission allowed by rule: {allow_rule}"} + + return None diff --git a/core/runtime/prompts.py b/core/runtime/prompts.py new file mode 100644 index 000000000..6077cf371 --- /dev/null +++ b/core/runtime/prompts.py @@ -0,0 +1,217 @@ +"""System prompt builders — pure functions, no agent state. + +Extracted from LeonAgent so agent.py stays lean. + +Middleware Stack +- MemoryMiddleware: trims/compacts conversation context before model calls. +- MonitorMiddleware: aggregates runtime metrics and observes model execution. +- PromptCachingMiddleware: enables Anthropic prompt caching for eligible requests. +- SteeringMiddleware: drains queued messages and injects them before the next model call. +- SpillBufferMiddleware: spills oversized tool outputs to disk and replaces them with previews. +""" + +from __future__ import annotations + +from typing import NamedTuple + + +class RuleSpec(NamedTuple): + title: str + body: str + details: tuple[str, ...] = () + + +def _render_rule(index: int, rule: RuleSpec) -> str: + rendered = f"{index}. **{rule.title}**: {rule.body}" + if not rule.details: + return rendered + return rendered + "\n" + "\n".join(f" - {detail}" for detail in rule.details) + + +def _build_core_rules(*, is_sandbox: bool, sandbox_name: str, workspace_root: str, working_dir: str) -> list[RuleSpec]: + rules: list[RuleSpec] = [] + if is_sandbox: + if sandbox_name == "docker": + location_rule = "All file and command operations run in a local Docker container, NOT on the user's host filesystem." + else: + location_rule = "All file and command operations run in a remote sandbox, NOT on the user's local machine." + rules.append(RuleSpec("Sandbox Environment", f"{location_rule} The sandbox is an isolated Linux environment.")) + else: + rules.append(RuleSpec("Workspace", "File operations are restricted to: " + workspace_root)) + + rules.append( + RuleSpec( + "Absolute Paths", + "All file paths must be absolute paths.", + ( + f"Correct: `{working_dir}/project/test.py`", + "Wrong: `test.py` or `./test.py`", + ), + ) + ) + + if is_sandbox: + security = "The sandbox is isolated. You can install packages, run any commands, and modify files freely." + else: + security = "Dangerous commands are blocked. All operations are logged." + rules.append(RuleSpec("Security", security)) + return rules + + +def _build_risk_rules() -> list[RuleSpec]: + return [ + RuleSpec( + "Risky Actions", + "Ask before destructive, hard-to-reverse, or shared-state actions.", + ( + "Examples: deleting files, force-pushing, dropping tables, killing unfamiliar processes, modifying shared infrastructure.", + "If you see unexpected state, investigate before deleting or overwriting it.", + ), + ), + RuleSpec( + "No URL Guessing", + "Do not guess URLs unless the user provided them or you are confident they are directly relevant to programming help.", + ), + RuleSpec( + "Minimal Change", + "Do not add features, refactor code, or make speculative abstractions beyond what the task requires.", + ( + "Don't create helpers, utilities, or abstractions for one-time operations.", + "Don't add error handling, fallbacks, or validation for scenarios that can't happen.", + ), + ), + ] + + +def _build_tool_preference_rules() -> list[RuleSpec]: + return [ + RuleSpec( + "Tool Priority", + "When a built-in tool and an MCP tool (`mcp__*`) have the same functionality, use the built-in tool.", + ), + RuleSpec( + "Tool Preference", + "Prefer dedicated tools over `Bash` when a built-in tool already matches the job.", + ( + "Use `Read` instead of `cat`, `head`, or `tail`.", + "Use `Edit` instead of shell text-munging for file edits.", + "Use `Write` instead of heredoc or echo redirection for file creation.", + "Use `Glob`/`Grep` for file discovery and content search before falling back to `Bash`.", + ), + ), + ] + + +def _build_interaction_rules() -> list[RuleSpec]: + return [] + + +def _build_function_result_clearing_rules(*, spill_buffer_enabled: bool, spill_keep_recent: int) -> list[RuleSpec]: + if not spill_buffer_enabled: + return [] + return [ + RuleSpec( + "Function Result Clearing", + f"Old tool results may be cleared from context to free up space. The {spill_keep_recent} most recent results are always kept.", + ( + "When working with tool results, write down any important information " + "you might need later in your response, as the original tool result " + "may be cleared later.", + ), + ) + ] + + +def _build_rule_specs( + *, + is_sandbox: bool, + sandbox_name: str, + workspace_root: str, + working_dir: str, + spill_buffer_enabled: bool, + spill_keep_recent: int, +) -> list[RuleSpec]: + rules: list[RuleSpec] = [] + rules.extend( + _build_core_rules( + is_sandbox=is_sandbox, + sandbox_name=sandbox_name, + workspace_root=workspace_root, + working_dir=working_dir, + ) + ) + rules.extend(_build_risk_rules()) + rules.extend(_build_tool_preference_rules()) + rules.extend( + _build_function_result_clearing_rules( + spill_buffer_enabled=spill_buffer_enabled, + spill_keep_recent=spill_keep_recent, + ) + ) + rules.extend(_build_interaction_rules()) + return rules + + +def build_context_section( + *, + sandbox_name: str, + sandbox_env_label: str = "", + sandbox_working_dir: str = "", + workspace_root: str = "", + os_name: str = "", + shell_name: str = "", +) -> str: + if sandbox_name != "local": + mode_label = "Sandbox (isolated local container)" if sandbox_name == "docker" else "Sandbox (isolated cloud environment)" + return f"""- Environment: {sandbox_env_label} +- Working Directory: {sandbox_working_dir} +- Mode: {mode_label}""" + return f"""- Workspace: `{workspace_root}` +- OS: {os_name} +- Shell: {shell_name} +- Mode: Local""" + + +def build_rules_section( + *, + is_sandbox: bool, + sandbox_name: str = "", + working_dir: str, + workspace_root: str, + spill_buffer_enabled: bool = False, + spill_keep_recent: int = 0, +) -> str: + rule_specs = _build_rule_specs( + is_sandbox=is_sandbox, + sandbox_name=sandbox_name, + workspace_root=workspace_root, + working_dir=working_dir, + spill_buffer_enabled=spill_buffer_enabled, + spill_keep_recent=spill_keep_recent, + ) + return "\n\n".join(_render_rule(index, rule) for index, rule in enumerate(rule_specs, start=1)) + + +def build_base_prompt(context: str, rules: str) -> str: + return f"""You are a highly capable AI assistant with access to file and system tools. + +**Context:** +{context} + +**Important Rules:** + +{rules} +""" + + +_AGENT_TOOL_SECTION = """ +**Sub-agent Types:** +- `explore`: Read-only codebase exploration (Grep, Glob, Read only) +- `plan`: Architecture design and planning (read-only tools) +- `bash`: Shell command execution (Bash + read tools) +- `general`: Full tool access for independent multi-step tasks +""" + + +def build_common_sections(skills_enabled: bool) -> str: + return _AGENT_TOOL_SECTION diff --git a/core/runtime/registry.py b/core/runtime/registry.py index f6a87f008..79cb48590 100644 --- a/core/runtime/registry.py +++ b/core/runtime/registry.py @@ -1,11 +1,46 @@ from __future__ import annotations from collections.abc import Awaitable, Callable +from copy import deepcopy from dataclasses import dataclass from enum import Enum +from typing import Any, NotRequired, Required, TypedDict, Unpack -Handler = Callable[..., str] | Callable[..., Awaitable[str]] -SchemaProvider = dict | Callable[[], dict] +from core.runtime.tool_result import ToolResultEnvelope + +type ToolSchema = dict[str, Any] +type ToolHandlerResult = str | ToolResultEnvelope +type ToolArgs = dict[str, Any] +type ToolPropertySchema = dict[str, Any] +type ToolProperties = dict[str, ToolPropertySchema] + +type Handler = Callable[..., ToolHandlerResult] | Callable[..., Awaitable[ToolHandlerResult]] +type SchemaProvider = ToolSchema | Callable[[], ToolSchema] +type ConcurrencySafety = bool | Callable[[ToolArgs], bool] +type ToolInputValidator = Callable[[ToolArgs, Any], ToolArgs | None] | Callable[[ToolArgs, Any], Awaitable[ToolArgs | None]] + + +class _ToolEntryDefaults(TypedDict): + search_hint: str + is_concurrency_safe: ConcurrencySafety + is_read_only: bool + is_destructive: bool + context_schema: ToolSchema | None + validate_input: ToolInputValidator | None + + +class _ToolEntryBuildArgs(TypedDict, total=False): + name: Required[str] + mode: Required[ToolMode] + schema: Required[SchemaProvider] + handler: Required[Handler] + source: Required[str] + search_hint: NotRequired[str] + is_concurrency_safe: NotRequired[ConcurrencySafety] + is_read_only: NotRequired[bool] + is_destructive: NotRequired[bool] + context_schema: NotRequired[ToolSchema | None] + validate_input: NotRequired[ToolInputValidator | None] class ToolMode(Enum): @@ -20,11 +55,56 @@ class ToolEntry: schema: SchemaProvider handler: Handler source: str - - def get_schema(self) -> dict: + search_hint: str = "" # 3-10 word capability description for ToolSearch matching + is_concurrency_safe: ConcurrencySafety = False # fail-closed: assume not safe + is_read_only: bool = False # fail-closed: assume write operation + is_destructive: bool = False # advisory metadata for permission/UI layers + context_schema: ToolSchema | None = None # fields this tool needs from ToolUseContext + validate_input: ToolInputValidator | None = None + + def get_schema(self) -> ToolSchema: return self.schema() if callable(self.schema) else self.schema +TOOL_DEFAULTS: _ToolEntryDefaults = { + "search_hint": "", + "is_concurrency_safe": False, + "is_read_only": False, + "is_destructive": False, + "context_schema": None, + "validate_input": None, +} + + +def build_tool(**kwargs: Unpack[_ToolEntryBuildArgs]) -> ToolEntry: + """Factory that fills in safety defaults. Fail-closed: assumes write + non-concurrent.""" + merged: _ToolEntryBuildArgs = {**TOOL_DEFAULTS, **kwargs} + return ToolEntry(**merged) + + +def make_tool_schema( + *, + name: str, + description: str, + properties: ToolProperties, + required: list[str] | None = None, + parameter_overrides: ToolSchema | None = None, +) -> ToolSchema: + parameters: ToolSchema = { + "type": "object", + "properties": properties, + } + if required: + parameters["required"] = required + if parameter_overrides: + parameters.update(parameter_overrides) + return { + "name": name, + "description": description, + "parameters": parameters, + } + + class ToolRegistry: """Central registry for all tools. @@ -55,23 +135,70 @@ def register(self, entry: ToolEntry) -> None: def get(self, name: str) -> ToolEntry | None: return self._tools.get(name) - def get_inline_schemas(self) -> list[dict]: - return [e.get_schema() for e in self._tools.values() if e.mode == ToolMode.INLINE] - - def search(self, query: str) -> list[ToolEntry]: - """Return all matching tools (including inline) for tool_search.""" - q = query.lower() - results = [] - for entry in self._tools.values(): + def get_inline_schemas(self, discovered_tool_names: set[str] | None = None) -> list[dict]: + discovered_tool_names = discovered_tool_names or set() + return [ + self._sanitize_schema_for_model(e.get_schema()) + for e in self._tools.values() + if e.mode == ToolMode.INLINE or e.name in discovered_tool_names + ] + + def _sanitize_schema_for_model(self, schema: dict) -> dict: + # @@@tool-schema-sanitize - runtime-only schema metadata is useful for + # validator/readiness, but provider tool schemas must stay within the + # subset the live model API accepts. + def _walk(value: Any) -> Any: + if isinstance(value, dict): + return {key: _walk(child) for key, child in value.items() if not (isinstance(key, str) and key.startswith("x-leon-"))} + if isinstance(value, list): + return [_walk(item) for item in value] + return value + + return _walk(deepcopy(schema)) + + def search(self, query: str, *, modes: set[ToolMode] | None = None) -> list[ToolEntry]: + """Return matching tools with ranked relevance. + + Supports ``select:Name1,Name2`` for exact selection. + Otherwise ranks by: search_hint > name > description. + """ + q = query.strip() + entries = [entry for entry in self._tools.values() if modes is None or entry.mode in modes] + + # --- select: exact lookup --- + if q.lower().startswith("select:"): + names = [n.strip() for n in q[len("select:") :].split(",") if n.strip()] + results = [self._tools[n] for n in names if n in self._tools and (modes is None or self._tools[n].mode in modes)] + return results + + # --- keyword search with ranking --- + keywords = q.lower().split() + if not keywords: + return list(entries) + + scored: list[tuple[int, ToolEntry]] = [] + for entry in entries: schema = entry.get_schema() - name = schema.get("name", "") - desc = schema.get("description", "") - if q in name.lower() or q in desc.lower(): - results.append(entry) - # If no match, return all - if not results: - results = list(self._tools.values()) - return results + name_lower = entry.name.lower() + hint_lower = entry.search_hint.lower() + desc_lower = schema.get("description", "").lower() + + score = 0 + for kw in keywords: + if kw in hint_lower: + score += 3 + if kw in name_lower: + score += 2 + if kw in desc_lower: + score += 1 + if score > 0: + scored.append((score, entry)) + + if not scored: + return [] + + scored.sort(key=lambda x: x[0], reverse=True) + return [entry for _, entry in scored] def list_all(self) -> list[ToolEntry]: return list(self._tools.values()) diff --git a/core/runtime/runner.py b/core/runtime/runner.py index ade917216..b40c7347a 100644 --- a/core/runtime/runner.py +++ b/core/runtime/runner.py @@ -1,23 +1,44 @@ from __future__ import annotations import asyncio +import copy +import inspect import json import logging +import threading from collections.abc import Awaitable, Callable +from typing import Any -from langchain.agents.middleware.types import ( +from langchain_core.messages import ToolMessage + +from core.runtime.middleware import ( AgentMiddleware, ModelRequest, ModelResponse, ToolCallRequest, ) -from langchain_core.messages import ToolMessage from .errors import InputValidationError +from .permissions import ToolPermissionContext from .registry import ToolRegistry +from .tool_result import ( + ToolResultEnvelope, + materialize_tool_message, + tool_error, + tool_permission_denied, + tool_permission_request, + tool_success, +) from .validator import ToolValidator logger = logging.getLogger(__name__) +DEFAULT_ASYNC_HOOK_TIMEOUT_S = 15.0 + + +class _ToolSpecificValidationError(Exception): + def __init__(self, message: str, error_code: str | None = None): + super().__init__(message) + self.error_code = error_code class ToolRunner(AgentMiddleware): @@ -48,9 +69,9 @@ def _inject_tools(self, request: ModelRequest) -> ModelRequest: def _extract_call_info(self, request: ToolCallRequest) -> tuple[str, dict, str]: tool_call = request.tool_call - name = tool_call.get("name") + name = tool_call.get("name") or "" args = tool_call.get("args", {}) - call_id = tool_call.get("id", "") + call_id = tool_call.get("id", "") or "" if isinstance(args, str): try: @@ -60,49 +81,896 @@ def _extract_call_info(self, request: ToolCallRequest) -> tuple[str, dict, str]: return name, args, call_id - def _validate_and_run(self, name: str, args: dict, call_id: str) -> ToolMessage: + @staticmethod + def _get_request_hook(request: ToolCallRequest, hook_name: str): + state = getattr(request, "state", None) + if state is None: + return None + if isinstance(state, dict): + hook = state.get(hook_name) + else: + hook = vars(state).get(hook_name) + if hook is None: + return None + if isinstance(hook, list): + return hook + return hook if callable(hook) else None + + @staticmethod + def _apply_result_hooks_sync( + hook_or_hooks, + payload: ToolMessage | ToolResultEnvelope, + request: ToolCallRequest, + ) -> ToolMessage | ToolResultEnvelope: + if hook_or_hooks is None: + return payload + hooks = hook_or_hooks if isinstance(hook_or_hooks, list) else [hook_or_hooks] + current = payload + for hook in hooks: + updated = hook(current, request) + if asyncio.iscoroutine(updated): + updated = ToolRunner._await_async_hook_with_timeout_sync( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + if updated is not None: + current = updated + return current + + @staticmethod + async def _apply_result_hooks( + hook_or_hooks, + payload: ToolMessage | ToolResultEnvelope, + request: ToolCallRequest, + ) -> ToolMessage | ToolResultEnvelope: + if hook_or_hooks is None: + return payload + hooks = hook_or_hooks if isinstance(hook_or_hooks, list) else [hook_or_hooks] + current = payload + + async def _invoke(hook): + updated = hook(copy.deepcopy(payload), request) + if asyncio.iscoroutine(updated): + updated = await ToolRunner._await_async_hook_with_timeout( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + return updated + + for updated in await asyncio.gather(*(_invoke(hook) for hook in hooks)): + if updated is not None: + current = updated + return current + + def _normalize_result(self, result: Any) -> ToolResultEnvelope: + if isinstance(result, ToolResultEnvelope): + return result + return tool_success(result) + + @staticmethod + def _resolve_context_path(state: Any, path: str) -> Any: + current = state + for segment in path.split("."): + if segment == "app_state": + current = current.get_app_state() + continue + if isinstance(current, dict): + current = current[segment] + else: + current = getattr(current, segment) + return current + + @staticmethod + def _inject_handler_context(entry, args: dict, request: ToolCallRequest) -> dict: + state = getattr(request, "state", None) + if state is None: + return args + try: + signature = inspect.signature(entry.handler) + except (TypeError, ValueError): + return args + accepts_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()) + injected = dict(args) + + context_schema = getattr(entry, "context_schema", None) or {} + if isinstance(context_schema, dict): + # @@@pt-02-context-schema-mapping + # Pattern 2 only becomes real once declared ToolUseContext field + # mappings are injected into handler kwargs on the live path. + for param_name, context_path in context_schema.items(): + if param_name in injected: + continue + if not accepts_kwargs and param_name not in signature.parameters: + continue + injected[param_name] = ToolRunner._resolve_context_path(state, context_path) + + if "tool_context" in injected: + return injected + if accepts_kwargs or "tool_context" in signature.parameters: + # @@@sa-04-tool-context-injection + # The sub-agent boundary only becomes real once the live ToolUseContext + # can cross the tool runner into handlers that explicitly opt in. + injected["tool_context"] = state + return injected + + @staticmethod + def _coerce_permission_response(result) -> tuple[str | None, str | None]: + if result is None: + return None, None + if isinstance(result, str): + return result, None + if isinstance(result, dict): + decision = result.get("decision") or result.get("permission") + message = result.get("message") + return decision, message + decision = getattr(result, "decision", None) or getattr(result, "permission", None) + message = getattr(result, "message", None) + return decision, message + + @staticmethod + def _permission_denied_result(decision: str, message: str | None) -> ToolResultEnvelope: + if decision == "ask": + text = message or "Permission required" + else: + text = message or "Permission denied" + return tool_permission_denied( + text, + metadata={"decision": decision, "error_type": "permission_resolution"}, + ) + + @staticmethod + def _permission_request_result(request_id: str, message: str | None) -> ToolResultEnvelope: + return tool_permission_request( + message or "Permission required", + metadata={ + "decision": "ask", + "request_id": request_id, + "error_type": "permission_resolution", + }, + ) + + @staticmethod + def _materialize_permission_ask( + request_id: str | None, + message: str | None, + ) -> ToolResultEnvelope: + # @@@permission-ask-materialization + # Ask is only honest when a concrete request surface exists. Otherwise + # fail loudly as a deny so caller metadata matches the actual runtime. + if request_id is not None: + return ToolRunner._permission_request_result(request_id, message) + return ToolRunner._permission_denied_result("deny", message) + + @staticmethod + def _run_awaitable_sync(awaitable): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(awaitable) + + result_box: list[Any] = [] + error_box: list[BaseException] = [] + + # @@@sync-awaitable-bridge - sync tool entrypoints still need to consume + # async permission checkers even when called from a live event loop. + def _runner() -> None: + try: + result_box.append(asyncio.run(awaitable)) + except BaseException as exc: # pragma: no cover - re-raised below + error_box.append(exc) + + thread = threading.Thread(target=_runner, daemon=True) + thread.start() + thread.join() + + if error_box: + raise error_box[0] + return result_box[0] if result_box else None + + @staticmethod + def _get_async_hook_timeout_s(request: ToolCallRequest) -> float: + state = getattr(request, "state", None) + if state is None: + return DEFAULT_ASYNC_HOOK_TIMEOUT_S + hook_timeout_ms = state.get("hook_timeout_ms") if isinstance(state, dict) else getattr(state, "hook_timeout_ms", None) + if isinstance(hook_timeout_ms, (int, float)) and hook_timeout_ms > 0: + return float(hook_timeout_ms) / 1000.0 + hook_timeout_s = state.get("hook_timeout_s") if isinstance(state, dict) else getattr(state, "hook_timeout_s", None) + if isinstance(hook_timeout_s, (int, float)) and hook_timeout_s > 0: + return float(hook_timeout_s) + return DEFAULT_ASYNC_HOOK_TIMEOUT_S + + @staticmethod + async def _await_async_hook_with_timeout( + request: ToolCallRequest, + awaitable, + *, + hook_name: str, + ): + timeout_s = ToolRunner._get_async_hook_timeout_s(request) + task = asyncio.create_task(awaitable) + try: + return await asyncio.wait_for(task, timeout=timeout_s) + except TimeoutError: + logger.warning("Async hook %s timed out after %.3fs; ignoring hook result", hook_name, timeout_s) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + return None + + @staticmethod + def _await_async_hook_with_timeout_sync( + request: ToolCallRequest, + awaitable, + *, + hook_name: str, + ): + return ToolRunner._run_awaitable_sync( + ToolRunner._await_async_hook_with_timeout( + request, + awaitable, + hook_name=hook_name, + ) + ) + + @staticmethod + def _get_state_callable(request: ToolCallRequest, name: str): + state = getattr(request, "state", None) + if state is None: + return None + return state.get(name) if isinstance(state, dict) else getattr(state, name, None) + + def _consume_permission_resolution_sync( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + ) -> tuple[str | None, str | None]: + consumer = self._get_state_callable(request, "consume_permission_resolution") + if not callable(consumer): + return None, None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + result = consumer(name, args, permission_context, request) + if asyncio.iscoroutine(result): + result = self._run_awaitable_sync(result) + return self._coerce_permission_response(result) + + async def _consume_permission_resolution_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + ) -> tuple[str | None, str | None]: + consumer = self._get_state_callable(request, "consume_permission_resolution") + if not callable(consumer): + return None, None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + result = consumer(name, args, permission_context, request) + if asyncio.iscoroutine(result): + result = await result + return self._coerce_permission_response(result) + + def _request_permission_sync( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + message: str | None, + ) -> str | None: + requester = self._get_state_callable(request, "request_permission") + if not callable(requester): + return None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + result = requester(name, args, permission_context, request, message) + if asyncio.iscoroutine(result): + result = self._run_awaitable_sync(result) + if isinstance(result, dict): + request_id = result.get("request_id") + return request_id if isinstance(request_id, str) else None + return result if isinstance(result, str) else None + + async def _request_permission_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + message: str | None, + ) -> str | None: + requester = self._get_state_callable(request, "request_permission") + if not callable(requester): + return None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + result = requester(name, args, permission_context, request, message) + if asyncio.iscoroutine(result): + result = await result + if isinstance(result, dict): + request_id = result.get("request_id") + return request_id if isinstance(request_id, str) else None + return result if isinstance(result, str) else None + + def _run_tool_specific_validation_sync(self, entry, args: dict, request: ToolCallRequest) -> dict: + validator = getattr(entry, "validate_input", None) + if validator is None: + return args + result = validator(dict(args), request) + if result is None: + return args + if isinstance(result, dict): + if result.get("result") is False or result.get("ok") is False: + raise _ToolSpecificValidationError( + result.get("message") or "Tool-specific validation failed", + result.get("errorCode") or result.get("error_code"), + ) + return result + raise InputValidationError(str(result)) + + async def _run_tool_specific_validation_async(self, entry, args: dict, request: ToolCallRequest) -> dict: + validator = getattr(entry, "validate_input", None) + if validator is None: + return args + result = validator(dict(args), request) + if asyncio.iscoroutine(result): + result = await result + if result is None: + return args + if isinstance(result, dict): + if result.get("result") is False or result.get("ok") is False: + raise _ToolSpecificValidationError( + result.get("message") or "Tool-specific validation failed", + result.get("errorCode") or result.get("error_code"), + ) + return result + raise InputValidationError(str(result)) + + def _run_pre_tool_use_sync(self, request: ToolCallRequest, *, name: str, args: dict, entry) -> tuple[dict, str | None, str | None]: + hooks = self._get_request_hook(request, "pre_tool_use") + if hooks is None: + return args, None, None + payload = {"name": name, "args": dict(args), "entry": entry} + permission: str | None = None + message: str | None = None + hook_list = hooks if isinstance(hooks, list) else [hooks] + for hook in hook_list: + updated = hook(payload, request) + if asyncio.iscoroutine(updated): + updated = self._await_async_hook_with_timeout_sync( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + if updated is None: + continue + if isinstance(updated, dict): + if "args" in updated: + payload["args"] = updated["args"] + if "name" in updated: + payload["name"] = updated["name"] + if "entry" in updated: + payload["entry"] = updated["entry"] + new_permission, new_message = self._coerce_permission_response(updated) + if new_permission is not None: + permission = new_permission + message = new_message + return payload["args"], permission, message + + async def _run_pre_tool_use_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + ) -> tuple[dict, str | None, str | None]: + hooks = self._get_request_hook(request, "pre_tool_use") + if hooks is None: + return args, None, None + payload = {"name": name, "args": dict(args), "entry": entry} + permission: str | None = None + message: str | None = None + hook_list = hooks if isinstance(hooks, list) else [hooks] + + async def _invoke(hook): + updated = hook({"name": name, "args": dict(args), "entry": entry}, request) + if asyncio.iscoroutine(updated): + updated = await self._await_async_hook_with_timeout( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + return updated + + # @@@pt-06-hook-fanout + # Pattern 6 requires hooks to fan out instead of impersonating a + # middleware chain. We still fold results back in hook-list order so + # the aggregation stays deterministic. + for updated in await asyncio.gather(*(_invoke(hook) for hook in hook_list)): + if updated is None: + continue + if isinstance(updated, dict): + if "args" in updated: + next_args = updated["args"] + if isinstance(next_args, dict): + payload["args"] = {**payload["args"], **next_args} + else: + payload["args"] = next_args + if "name" in updated: + payload["name"] = updated["name"] + if "entry" in updated: + payload["entry"] = updated["entry"] + new_permission, new_message = self._coerce_permission_response(updated) + if new_permission == "deny" and permission != "deny": + permission = new_permission + message = new_message + elif new_permission == "ask" and permission not in {"deny", "ask"}: + permission = new_permission + message = new_message + elif new_permission == "allow" and permission is None: + permission = new_permission + message = new_message + return payload["args"], permission, message + + def _run_permission_request_hooks_sync( + self, + request: ToolCallRequest, + *, + name: str, + entry, + message: str | None, + ) -> tuple[str | None, str | None]: + hooks = self._get_request_hook(request, "permission_request_hooks") + if hooks is None: + return None, message + payload = {"name": name, "entry": entry, "message": message} + permission: str | None = None + hook_message = message + hook_list = hooks if isinstance(hooks, list) else [hooks] + for hook in hook_list: + updated = hook(payload, request) + if asyncio.iscoroutine(updated): + updated = self._await_async_hook_with_timeout_sync( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + if updated is None: + continue + if isinstance(updated, dict): + new_permission, new_message = self._coerce_permission_response(updated) + if new_permission is not None: + permission = new_permission + if new_message is not None: + hook_message = new_message + return permission, hook_message + + async def _run_permission_request_hooks_async( + self, + request: ToolCallRequest, + *, + name: str, + entry, + message: str | None, + ) -> tuple[str | None, str | None]: + hooks = self._get_request_hook(request, "permission_request_hooks") + if hooks is None: + return None, message + payload = {"name": name, "entry": entry, "message": message} + permission: str | None = None + hook_message = message + hook_list = hooks if isinstance(hooks, list) else [hooks] + + async def _invoke(hook): + updated = hook(payload, request) + if asyncio.iscoroutine(updated): + updated = await self._await_async_hook_with_timeout( + request, + updated, + hook_name=getattr(hook, "__name__", type(hook).__name__), + ) + return updated + + for updated in await asyncio.gather(*(_invoke(hook) for hook in hook_list)): + if updated is None: + continue + if isinstance(updated, dict): + new_permission, new_message = self._coerce_permission_response(updated) + if new_permission == "deny" and permission != "deny": + permission = new_permission + elif new_permission == "ask" and permission not in {"deny", "ask"}: + permission = new_permission + elif new_permission == "allow" and permission is None: + permission = new_permission + if new_message is not None: + hook_message = new_message + return permission, hook_message + + def _resolve_permission( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + hook_permission: str | None, + hook_message: str | None, + ) -> ToolResultEnvelope | None: + if hook_permission == "deny": + return self._permission_denied_result("deny", hook_message) + + checker = self._get_state_callable(request, "can_use_tool") + rule_permission: str | None = None + rule_message: str | None = None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + if callable(checker): + result = checker(name, args, permission_context, request) + if asyncio.iscoroutine(result): + result = self._run_awaitable_sync(result) + rule_permission, rule_message = self._coerce_permission_response(result) + + # @@@permission-resolution-precedence - only consume one-shot approvals when current state still asks. + if rule_permission == "ask": + resolved_permission, resolved_message = self._consume_permission_resolution_sync( + request, + name=name, + args=args, + entry=entry, + ) + if resolved_permission == "allow": + return None + if resolved_permission in {"deny", "ask"}: + return self._permission_denied_result(resolved_permission, resolved_message) + request_hook_permission, request_hook_message = self._run_permission_request_hooks_sync( + request, + name=name, + entry=entry, + message=rule_message, + ) + if request_hook_permission == "allow": + return None + if request_hook_permission in {"deny", "ask"}: + return self._permission_denied_result(request_hook_permission, request_hook_message) + rule_message = request_hook_message + + if hook_permission == "allow": + if rule_permission in {"deny", "ask"}: + if rule_permission == "ask": + request_id = self._request_permission_sync( + request, + name=name, + args=args, + entry=entry, + message=rule_message, + ) + return self._materialize_permission_ask(request_id, rule_message) + return self._permission_denied_result(rule_permission, rule_message) + return None + + if rule_permission in {"deny", "ask"}: + if rule_permission == "ask": + request_id = self._request_permission_sync( + request, + name=name, + args=args, + entry=entry, + message=rule_message, + ) + return self._materialize_permission_ask(request_id, rule_message) + return self._permission_denied_result(rule_permission, rule_message) + return None + + async def _resolve_permission_async( + self, + request: ToolCallRequest, + *, + name: str, + args: dict, + entry, + hook_permission: str | None, + hook_message: str | None, + ) -> ToolResultEnvelope | None: + if hook_permission == "deny": + return self._permission_denied_result("deny", hook_message) + + checker = self._get_state_callable(request, "can_use_tool") + rule_permission: str | None = None + rule_message: str | None = None + permission_context = ToolPermissionContext( + is_read_only=bool(getattr(entry, "is_read_only", False)), + is_destructive=bool(getattr(entry, "is_destructive", False)), + ) + if callable(checker): + result = checker(name, args, permission_context, request) + if asyncio.iscoroutine(result): + result = await result + rule_permission, rule_message = self._coerce_permission_response(result) + + # @@@permission-resolution-precedence - only consume one-shot approvals when current state still asks. + if rule_permission == "ask": + resolved_permission, resolved_message = await self._consume_permission_resolution_async( + request, + name=name, + args=args, + entry=entry, + ) + if resolved_permission == "allow": + return None + if resolved_permission in {"deny", "ask"}: + return self._permission_denied_result(resolved_permission, resolved_message) + request_hook_permission, request_hook_message = await self._run_permission_request_hooks_async( + request, + name=name, + entry=entry, + message=rule_message, + ) + if request_hook_permission == "allow": + return None + if request_hook_permission in {"deny", "ask"}: + return self._permission_denied_result(request_hook_permission, request_hook_message) + rule_message = request_hook_message + + if hook_permission == "allow": + if rule_permission in {"deny", "ask"}: + if rule_permission == "ask": + request_id = await self._request_permission_async( + request, + name=name, + args=args, + entry=entry, + message=rule_message, + ) + return self._materialize_permission_ask(request_id, rule_message) + return self._permission_denied_result(rule_permission, rule_message) + return None + + if rule_permission in {"deny", "ask"}: + if rule_permission == "ask": + request_id = await self._request_permission_async( + request, + name=name, + args=args, + entry=entry, + message=rule_message, + ) + return self._materialize_permission_ask(request_id, rule_message) + return self._permission_denied_result(rule_permission, rule_message) + return None + + def _materialize_result( + self, + envelope: ToolResultEnvelope, + *, + name: str, + call_id: str, + source: str, + ) -> ToolMessage: + return materialize_tool_message( + envelope, + tool_call_id=call_id, + name=name, + source=source, + ) + + @staticmethod + def _entry_source(entry) -> str: + return "mcp" if getattr(entry, "source", None) == "mcp" else "local" + + def _finalize_registered_result( + self, + envelope: ToolResultEnvelope, + *, + name: str, + call_id: str, + source: str, + ) -> ToolMessage | ToolResultEnvelope: + if source == "mcp": + return envelope + return self._materialize_result( + envelope, + name=name, + call_id=call_id, + source=source, + ) + + @staticmethod + def _select_hook_name(kind: str) -> str: + if kind == "error": + return "post_tool_use_failure" + if kind == "permission_denied": + return "permission_denied_hooks" + return "post_tool_use" + + @staticmethod + def _input_validation_metadata(error: InputValidationError) -> dict[str, object]: + metadata: dict[str, object] = {"error_type": "input_validation"} + if error.error_code: + metadata["error_code"] = error.error_code + if error.details: + metadata["error_details"] = error.details + return metadata + + def _validate_and_run(self, request: ToolCallRequest, name: str, args: dict, call_id: str) -> ToolMessage | ToolResultEnvelope | None: entry = self._registry.get(name) if entry is None: return None # not our tool + source = self._entry_source(entry) schema = entry.get_schema() try: self._validator.validate(schema, args) except InputValidationError as e: - return ToolMessage( - content=f"InputValidationError: {name} failed due to the following issue:\n{e}", - tool_call_id=call_id, + return self._finalize_registered_result( + tool_error( + f"InputValidationError: {name} failed due to the following issue:\n{e}", + metadata=self._input_validation_metadata(e), + ), + name=name, + call_id=call_id, + source=source, + ) + try: + args = self._run_tool_specific_validation_sync(entry, args, request) + except _ToolSpecificValidationError as e: + return self._finalize_registered_result( + tool_error( + f"ToolValidationError: {name} failed due to the following issue:\n{e}", + metadata={"error_type": "tool_input_validation", "error_code": e.error_code}, + ), + name=name, + call_id=call_id, + source=source, + ) + except InputValidationError as e: + return self._finalize_registered_result( + tool_error( + f"ToolValidationError: {name} failed due to the following issue:\n{e}", + metadata={"error_type": "tool_input_validation"}, + ), name=name, + call_id=call_id, + source=source, + ) + args, hook_permission, hook_message = self._run_pre_tool_use_sync( + request, + name=name, + args=args, + entry=entry, + ) + permission_result = self._resolve_permission( + request, + name=name, + args=args, + entry=entry, + hook_permission=hook_permission, + hook_message=hook_message, + ) + if permission_result is not None: + return self._finalize_registered_result( + permission_result, + name=name, + call_id=call_id, + source=source, ) + args = self._inject_handler_context(entry, args, request) try: result = entry.handler(**args) if asyncio.iscoroutine(result): result = asyncio.get_event_loop().run_until_complete(result) - return ToolMessage(content=str(result), tool_call_id=call_id, name=name) + return self._finalize_registered_result( + self._normalize_result(result), + name=name, + call_id=call_id, + source=source, + ) except Exception as e: logger.exception("Tool %s execution failed", name) - return ToolMessage( - content=f"{e}", - tool_call_id=call_id, + return self._finalize_registered_result( + tool_error( + f"{e}", + metadata={"error_type": "tool_execution"}, + ), name=name, + call_id=call_id, + source=source, ) - async def _validate_and_run_async(self, name: str, args: dict, call_id: str) -> ToolMessage | None: + async def _validate_and_run_async( + self, + request: ToolCallRequest, + name: str, + args: dict, + call_id: str, + ) -> ToolMessage | ToolResultEnvelope | None: entry = self._registry.get(name) if entry is None: return None + source = self._entry_source(entry) schema = entry.get_schema() try: self._validator.validate(schema, args) except InputValidationError as e: - return ToolMessage( - content=f"InputValidationError: {name} failed due to the following issue:\n{e}", - tool_call_id=call_id, + return self._finalize_registered_result( + tool_error( + f"InputValidationError: {name} failed due to the following issue:\n{e}", + metadata=self._input_validation_metadata(e), + ), name=name, + call_id=call_id, + source=source, + ) + try: + args = await self._run_tool_specific_validation_async(entry, args, request) + except _ToolSpecificValidationError as e: + return self._finalize_registered_result( + tool_error( + f"ToolValidationError: {name} failed due to the following issue:\n{e}", + metadata={"error_type": "tool_input_validation", "error_code": e.error_code}, + ), + name=name, + call_id=call_id, + source=source, + ) + except InputValidationError as e: + return self._finalize_registered_result( + tool_error( + f"ToolValidationError: {name} failed due to the following issue:\n{e}", + metadata={"error_type": "tool_input_validation"}, + ), + name=name, + call_id=call_id, + source=source, ) + args, hook_permission, hook_message = await self._run_pre_tool_use_async( + request, + name=name, + args=args, + entry=entry, + ) + permission_result = await self._resolve_permission_async( + request, + name=name, + args=args, + entry=entry, + hook_permission=hook_permission, + hook_message=hook_message, + ) + if permission_result is not None: + return self._finalize_registered_result( + permission_result, + name=name, + call_id=call_id, + source=source, + ) + + args = self._inject_handler_context(entry, args, request) try: if asyncio.iscoroutinefunction(entry.handler): result = await entry.handler(**args) @@ -113,13 +981,22 @@ async def _validate_and_run_async(self, name: str, args: dict, call_id: str) -> result = await asyncio.to_thread(entry.handler, **args) if asyncio.iscoroutine(result): result = await result - return ToolMessage(content=str(result), tool_call_id=call_id, name=name) + return self._finalize_registered_result( + self._normalize_result(result), + name=name, + call_id=call_id, + source=source, + ) except Exception as e: logger.exception("Tool %s execution failed", name) - return ToolMessage( - content=f"{e}", - tool_call_id=call_id, + return self._finalize_registered_result( + tool_error( + f"{e}", + metadata={"error_type": "tool_execution"}, + ), name=name, + call_id=call_id, + source=source, ) # -- Model call wrappers -- @@ -146,10 +1023,26 @@ def wrap_tool_call( handler: Callable[[ToolCallRequest], ToolMessage], ) -> ToolMessage: name, args, call_id = self._extract_call_info(request) - result = self._validate_and_run(name, args, call_id) + entry = self._registry.get(name) + result = self._validate_and_run(request, name, args, call_id) if result is not None: - return result - return handler(request) + source = self._entry_source(entry) if entry is not None else "local" + if isinstance(result, ToolResultEnvelope): + hook_name = self._select_hook_name(result.kind) + hooks = self._get_request_hook(request, hook_name) + hooked = self._apply_result_hooks_sync(hooks, result, request) if hooks else result + if isinstance(hooked, ToolMessage): + return hooked + return self._materialize_result(hooked, name=name, call_id=call_id, source=source) + kind = result.additional_kwargs.get("tool_result_meta", {}).get("kind") + hook_name = self._select_hook_name(kind) + hooks = self._get_request_hook(request, hook_name) + maybe_updated = self._apply_result_hooks_sync(hooks, result, request) if hooks else result + if isinstance(maybe_updated, ToolMessage): + return maybe_updated + return self._materialize_result(maybe_updated, name=name, call_id=call_id, source=source) + upstream = handler(request) + return upstream async def awrap_tool_call( self, @@ -157,7 +1050,39 @@ async def awrap_tool_call( handler: Callable[[ToolCallRequest], Awaitable[ToolMessage]], ) -> ToolMessage: name, args, call_id = self._extract_call_info(request) - result = await self._validate_and_run_async(name, args, call_id) + entry = self._registry.get(name) + source = self._entry_source(entry) if entry is not None else "local" + result = await self._validate_and_run_async(request, name, args, call_id) if result is not None: - return result - return await handler(request) + # @@@tool-result-ordering + # te-02 keeps local tools materialize-first, but registered MCP + # tools must stay envelope-first so post hooks can see and modify + # structured output before final ToolMessage creation. + if isinstance(result, ToolResultEnvelope): + hook_name = self._select_hook_name(result.kind) + hooks = self._get_request_hook(request, hook_name) + hooked = await self._apply_result_hooks(hooks, result, request) + if isinstance(hooked, ToolMessage): + return hooked + return self._materialize_result(hooked, name=name, call_id=call_id, source=source) + meta = result.additional_kwargs.get("tool_result_meta", {}) + hook_name = self._select_hook_name(meta.get("kind")) + hooks = self._get_request_hook(request, hook_name) + hooked = await self._apply_result_hooks(hooks, result, request) + if isinstance(hooked, ToolMessage): + return hooked + return self._materialize_result(hooked, name=name, call_id=call_id, source=source) + + upstream = await handler(request) + post_tool_use = self._get_request_hook(request, "post_tool_use") + if isinstance(upstream, ToolResultEnvelope): + # MCP/upstream path: post hooks get first shot at the structured + # result, and only then do we materialize the ToolMessage. + hooked = await self._apply_result_hooks(post_tool_use, upstream, request) + if isinstance(hooked, ToolMessage): + return hooked + return self._materialize_result(hooked, name=name, call_id=call_id, source="mcp") + if isinstance(upstream, ToolMessage): + hooked = await self._apply_result_hooks(post_tool_use, upstream, request) + return hooked if isinstance(hooked, ToolMessage) else self._materialize_result(hooked, name=name, call_id=call_id, source="mcp") + return upstream diff --git a/core/runtime/state.py b/core/runtime/state.py new file mode 100644 index 000000000..80b53a4c2 --- /dev/null +++ b/core/runtime/state.py @@ -0,0 +1,172 @@ +"""Three-layer state models aligned with CC architecture. + +Layer 1: BootstrapConfig — survives /clear, process-level constants +Layer 2: AppState — per-session mutable state (Zustand-style store) +Layer 3: ToolUseContext — per-turn, holds live closures to AppState +""" + +from __future__ import annotations + +import uuid +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from .abort import AbortController +from .permissions import ToolPermissionContext + + +class ToolPermissionState(BaseModel): + # @@@camelcase-permission-surface - persisted/thread API surface already uses camelCase keys. + alwaysAllowRules: dict[str, list[str]] = Field(default_factory=dict) # noqa: N815 + alwaysDenyRules: dict[str, list[str]] = Field(default_factory=dict) # noqa: N815 + alwaysAskRules: dict[str, list[str]] = Field(default_factory=dict) # noqa: N815 + allowManagedPermissionRulesOnly: bool = False # noqa: N815 + + +class BootstrapConfig(BaseModel): + """Process-level configuration that survives /clear. + + Analogous to CC Bootstrap State (~85 fields). Contains workspace + identity, model config, security flags, and API credentials. + """ + + workspace_root: Path + original_cwd: Path | None = None + project_root: Path | None = None + cwd: Path | None = None + model_name: str + api_key: str | None = None + sandbox_type: str = "local" + permission_resolver_scope: str = "none" + + # Security flags (fail-closed defaults) + block_dangerous_commands: bool = True + block_network_commands: bool = False + enable_audit_log: bool = True + enable_web_tools: bool = False + + # File access + allowed_file_extensions: list[str] | None = None + extra_allowed_paths: list[str] | None = None + + # Turn limits + max_turns: int | None = None + + # Session identity + session_id: str = Field(default_factory=lambda: uuid.uuid4().hex) + parent_session_id: str | None = None + + # Session accumulators that survive turn-level resets + total_cost_usd: float = 0.0 + total_tool_duration_ms: int = 0 + + # Model settings + model_provider: str | None = None + base_url: str | None = None + context_limit: int | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def model_post_init(self, __context: Any) -> None: + self.workspace_root = Path(self.workspace_root) + self.original_cwd = Path(self.original_cwd) if self.original_cwd is not None else self.workspace_root + self.project_root = Path(self.project_root) if self.project_root is not None else self.workspace_root + self.cwd = Path(self.cwd) if self.cwd is not None else self.project_root + + +class AppState(BaseModel): + """Per-session mutable state. Analogous to CC AppState store. + + Implements a minimal Zustand-style store with getState/setState. + Not reactive — no subscriptions needed for Python backend. + """ + + messages: list = Field(default_factory=list) + turn_count: int = 0 + total_cost: float = 0.0 + compact_boundary_index: int = 0 + # Map of tool_name -> is_enabled (runtime overrides) + tool_overrides: dict[str, bool] = Field(default_factory=dict) + tool_permission_context: ToolPermissionState = Field(default_factory=ToolPermissionState) + pending_permission_requests: dict[str, dict[str, Any]] = Field(default_factory=dict) + resolved_permission_requests: dict[str, dict[str, Any]] = Field(default_factory=dict) + announced_mcp_instruction_blocks: dict[str, dict[str, str]] = Field(default_factory=dict) + # @@@session-hooks-not-watchers - keep this surface local and lifecycle-scoped. + # File watching remains a later outer-layer concern so Leon keeps the + # filesystem + terminal core decoupled. + session_hooks: dict[str, list[Any]] = Field(default_factory=dict) + + def get_state(self) -> AppState: + return self + + def set_state(self, updater: Callable[[AppState], AppState]) -> AppState: + updated = updater(self) + # Mutate in place (Python idiom — no immutable constraint needed here) + for field_name in AppState.model_fields: + setattr(self, field_name, getattr(updated, field_name)) + return self + + def add_session_hook(self, event: str, hook: Any) -> None: + hooks = list(self.session_hooks.get(event, [])) + hooks.append(hook) + self.session_hooks[event] = hooks + + def remove_session_hook(self, event: str, hook: Any) -> None: + hooks = [candidate for candidate in self.session_hooks.get(event, []) if candidate != hook] + if hooks: + self.session_hooks[event] = hooks + else: + self.session_hooks.pop(event, None) + + def get_session_hooks(self, event: str) -> list[Any]: + return list(self.session_hooks.get(event, [])) + + +AppStateUpdater = Callable[[AppState], AppState] +AppStateGetter = Callable[[], AppState] +AppStateSetter = Callable[[AppStateUpdater], AppState | None] +RefreshToolsHook = Callable[[], Awaitable[None] | None] +PermissionDecision = dict[str, Any] | None +PermissionChecker = Callable[ + [str, dict[str, Any], ToolPermissionContext, object], + PermissionDecision | Awaitable[PermissionDecision], +] +PermissionRequester = Callable[ + [str, dict[str, Any], ToolPermissionContext, object, str | None], + str | dict[str, Any] | None | Awaitable[str | dict[str, Any] | None], +] +PermissionResolutionConsumer = Callable[ + [str, dict[str, Any], ToolPermissionContext, object], + PermissionDecision | Awaitable[PermissionDecision], +] + + +class ToolUseContext(BaseModel): + """Per-turn context bag. Analogous to CC ToolUseContext. + + Carries live closures to AppState so tools can read/mutate session state. + Sub-agents receive a NO-OP set_app_state to prevent write-through. + """ + + bootstrap: BootstrapConfig + get_app_state: AppStateGetter = Field(exclude=True) + set_app_state: AppStateSetter = Field(exclude=True) + set_app_state_for_tasks: AppStateSetter | None = Field(default=None, exclude=True) + refresh_tools: RefreshToolsHook | None = Field(default=None, exclude=True) + can_use_tool: PermissionChecker | None = Field(default=None, exclude=True) + request_permission: PermissionRequester | None = Field(default=None, exclude=True) + consume_permission_resolution: PermissionResolutionConsumer | None = Field(default=None, exclude=True) + read_file_state: Any = Field(default_factory=dict, exclude=True) + loaded_nested_memory_paths: Any = Field(default_factory=set, exclude=True) + discovered_skill_names: Any = Field(default_factory=set, exclude=True) + discovered_tool_names: Any = Field(default_factory=set, exclude=True) + nested_memory_attachment_triggers: Any = Field(default_factory=set, exclude=True) + abort_controller: AbortController = Field(default_factory=AbortController, exclude=True) + messages: list = Field(default_factory=list) + thread_id: str = "default" + turn_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8]) + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/core/runtime/tool_result.py b/core/runtime/tool_result.py new file mode 100644 index 000000000..1ccd24288 --- /dev/null +++ b/core/runtime/tool_result.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from langchain_core.messages import ToolMessage + + +@dataclass +class ToolResultEnvelope: + kind: str + content: Any + is_error: bool = False + top_level_blocks: list[Any] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +def tool_success(content: Any, *, metadata: dict[str, Any] | None = None) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="success", + content=content, + metadata=dict(metadata or {}), + ) + + +def tool_error(content: str, *, metadata: dict[str, Any] | None = None) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="error", + content=content, + is_error=True, + metadata=dict(metadata or {}), + ) + + +def tool_permission_denied( + content: str, + *, + top_level_blocks: list[Any] | None = None, + metadata: dict[str, Any] | None = None, +) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="permission_denied", + content=content, + is_error=True, + top_level_blocks=list(top_level_blocks or []), + metadata=dict(metadata or {}), + ) + + +def tool_permission_request( + content: str, + *, + top_level_blocks: list[Any] | None = None, + metadata: dict[str, Any] | None = None, +) -> ToolResultEnvelope: + return ToolResultEnvelope( + kind="permission_request", + content=content, + top_level_blocks=list(top_level_blocks or []), + metadata=dict(metadata or {}), + ) + + +def materialize_tool_message( + envelope: ToolResultEnvelope, + *, + tool_call_id: str, + name: str, + source: str, +) -> ToolMessage: + additional_kwargs = { + "tool_result_meta": { + "kind": envelope.kind, + "source": source, + "top_level_blocks": list(envelope.top_level_blocks), + **dict(envelope.metadata), + } + } + return ToolMessage( + content=envelope.content, + tool_call_id=tool_call_id, + name=name, + additional_kwargs=additional_kwargs, + ) diff --git a/core/runtime/validator.py b/core/runtime/validator.py index 84e678d07..46fa6d963 100644 --- a/core/runtime/validator.py +++ b/core/runtime/validator.py @@ -1,8 +1,45 @@ import json +import re from .errors import InputValidationError +def _required_sets(parameters: dict, key: str) -> list[list[str]]: + value = parameters.get(key, []) + if not isinstance(value, list): + return [] + sets: list[list[str]] = [] + for item in value: + if isinstance(item, dict): + required = item.get("required", []) + else: + required = item + if isinstance(required, list): + sets.append([field for field in required if isinstance(field, str)]) + return sets + + +def _required_sets_match(parameters: dict, args: dict) -> bool: + required = parameters.get("required", []) + if any(field not in args for field in required): + return False + + # @@@required-set-contract - some tools need one of several identifier sets + # before they're valid. Keep that contract in runtime metadata so + # validator/readiness stay aligned without sending unsupported top-level + # anyOf/oneOf schema to live providers. + any_of = _required_sets(parameters, "x-leon-required-any-of") or _required_sets(parameters, "anyOf") + if any_of: + return any(all(field in args for field in required) for required in any_of) + + one_of = _required_sets(parameters, "x-leon-required-one-of") or _required_sets(parameters, "oneOf") + if one_of: + matches = [required for required in one_of if all(field in args for field in required)] + return len(matches) == 1 + + return True + + class ValidationResult: def __init__(self, ok: bool, params: dict): self.ok = ok @@ -13,14 +50,43 @@ class ToolValidator: """Three-phase tool argument validation.""" def validate(self, schema: dict, args: dict) -> ValidationResult: - properties = schema.get("parameters", {}).get("properties", {}) - required = schema.get("parameters", {}).get("required", []) + parameters = schema.get("parameters", {}) + properties = parameters.get("properties", {}) # Phase 1: required fields - missing = [f for f in required if f not in args] - if missing: - msgs = [f"The required parameter `{f}` is missing" for f in missing] - raise InputValidationError("\n".join(msgs)) + if not _required_sets_match(parameters, args): + required = parameters.get("required", []) + missing = [f for f in required if f not in args] + if missing: + details = [ + { + "field": field, + "error_code": "REQUIRED_FIELD_MISSING", + "message": f"The required parameter `{field}` is missing", + } + for field in missing + ] + raise InputValidationError( + "\n".join(detail["message"] for detail in details), + error_code="REQUIRED_FIELD_MISSING" if len(details) == 1 else "INPUT_CONSTRAINT_VIOLATION", + details=details, + ) + any_of = _required_sets(parameters, "x-leon-required-any-of") or _required_sets(parameters, "anyOf") + one_of = _required_sets(parameters, "x-leon-required-one-of") or _required_sets(parameters, "oneOf") + if any_of: + message = f"Arguments must satisfy one of these required sets: {any_of}" + raise InputValidationError( + message, + error_code="REQUIRED_SET_UNSATISFIED", + details=[{"error_code": "REQUIRED_SET_UNSATISFIED", "message": message}], + ) + if one_of: + message = f"Arguments must satisfy exactly one of these required sets: {one_of}" + raise InputValidationError( + message, + error_code="REQUIRED_SET_UNSATISFIED", + details=[{"error_code": "REQUIRED_SET_UNSATISFIED", "message": message}], + ) # Phase 2: type check for name, val in args.items(): @@ -28,12 +94,38 @@ def validate(self, schema: dict, args: dict) -> ValidationResult: expected = prop.get("type") if expected and not self._type_matches(val, expected): actual = type(val).__name__ - raise InputValidationError(f"The parameter `{name}` type is expected as `{expected}` but provided as `{actual}`") + message = f"The parameter `{name}` type is expected as `{expected}` but provided as `{actual}`" + raise InputValidationError( + message, + error_code="INVALID_TYPE", + details=[ + { + "field": name, + "error_code": "INVALID_TYPE", + "expected": expected, + "actual": actual, + "message": message, + } + ], + ) - # Phase 3: enum validation + # Phase 3: scalar constraints + issues = self._validate_scalar_constraints(properties, args) + if issues: + raise InputValidationError( + "\n".join(str(issue["message"]) for issue in issues), + error_code=str(issues[0]["error_code"]) if len(issues) == 1 else "INPUT_CONSTRAINT_VIOLATION", + details=issues, + ) + + # Phase 4: enum validation issues = self._validate_enum(properties, args) if issues: - raise InputValidationError(json.dumps(issues)) + raise InputValidationError( + json.dumps(issues), + error_code="INVALID_ENUM" if len(issues) == 1 else "INPUT_CONSTRAINT_VIOLATION", + details=issues, + ) return ValidationResult(ok=True, params=args) @@ -51,11 +143,77 @@ def _type_matches(self, val, expected: str) -> bool: return True return isinstance(val, expected_type) - def _validate_enum(self, properties: dict, args: dict) -> list: - issues = [] + def _validate_enum(self, properties: dict, args: dict) -> list[dict[str, object]]: + issues: list[dict[str, object]] = [] for name, val in args.items(): prop = properties.get(name, {}) enum_vals = prop.get("enum") if enum_vals and val not in enum_vals: - issues.append({"field": name, "expected": enum_vals, "got": val}) + issues.append( + { + "field": name, + "error_code": "INVALID_ENUM", + "expected": enum_vals, + "got": val, + "message": f"The parameter `{name}` must be one of {enum_vals}, got {val!r}", + } + ) + return issues + + def _validate_scalar_constraints(self, properties: dict, args: dict) -> list[dict[str, object]]: + issues: list[dict[str, object]] = [] + for name, val in args.items(): + prop = properties.get(name, {}) + if isinstance(val, str): + min_length = prop.get("minLength") + if isinstance(min_length, int) and len(val) < min_length: + issues.append( + { + "field": name, + "error_code": "STRING_TOO_SHORT", + "message": f"The parameter `{name}` must be at least {min_length} characters long", + "minimum": min_length, + } + ) + max_length = prop.get("maxLength") + if isinstance(max_length, int) and len(val) > max_length: + issues.append( + { + "field": name, + "error_code": "STRING_TOO_LONG", + "message": f"The parameter `{name}` must be at most {max_length} characters long", + "maximum": max_length, + } + ) + pattern = prop.get("pattern") + if isinstance(pattern, str) and re.search(pattern, val) is None: + issues.append( + { + "field": name, + "error_code": "PATTERN_MISMATCH", + "message": f"The parameter `{name}` must match pattern `{pattern}`", + "pattern": pattern, + } + ) + if isinstance(val, (int, float)) and not isinstance(val, bool): + minimum = prop.get("minimum") + if isinstance(minimum, (int, float)) and val < minimum: + issues.append( + { + "field": name, + "error_code": "NUMBER_TOO_SMALL", + "message": f"The parameter `{name}` must be at least {minimum}", + "minimum": minimum, + } + ) + maximum = prop.get("maximum") + if isinstance(maximum, (int, float)) and val > maximum: + issues.append( + { + "field": name, + "error_code": "NUMBER_TOO_LARGE", + "message": f"The parameter `{name}` must be at most {maximum}", + "maximum": maximum, + } + ) return issues diff --git a/core/tools/command/base.py b/core/tools/command/base.py index e716420b2..a13ee7654 100644 --- a/core/tools/command/base.py +++ b/core/tools/command/base.py @@ -8,3 +8,10 @@ from sandbox.interfaces.executor import AsyncCommand, BaseExecutor, ExecuteResult __all__ = ["BaseExecutor", "ExecuteResult", "AsyncCommand"] + + +def describe_execution_exception(exc: Exception) -> str: + detail = str(exc).strip() + if detail: + return detail + return exc.__class__.__name__ diff --git a/core/tools/command/hooks/dangerous_commands.py b/core/tools/command/hooks/dangerous_commands.py index 496251292..3abde2337 100644 --- a/core/tools/command/hooks/dangerous_commands.py +++ b/core/tools/command/hooks/dangerous_commands.py @@ -1,6 +1,7 @@ """Dangerous commands hook - blocks commands that may harm the system.""" import re +import shlex from pathlib import Path from typing import Any @@ -40,6 +41,32 @@ class DangerousCommandsHook(BashHook): r"\bssh\b", ] + DEFAULT_BLOCKED_BASE_COMMANDS = { + "rmdir", + "chmod", + "chown", + "sudo", + "su", + "kill", + "pkill", + "reboot", + "shutdown", + "mkfs", + "dd", + } + NETWORK_BASE_COMMANDS = { + "curl", + "wget", + "scp", + "sftp", + "rsync", + "ssh", + } + OPERATOR_TOKENS = {";", ";;", "&", "&&", "|", "||", "(", ")"} + ENV_ASSIGN_RE = re.compile(r"^[A-Za-z_]\w*=") + ANSI_C_QUOTE_RE = re.compile(r"\$'[^']*'") + LOCALE_QUOTE_RE = re.compile(r'\$"[^"]*"') + def __init__( self, workspace_root: Path | str | None = None, @@ -58,13 +85,140 @@ def __init__( patterns.extend(custom_blocked) self.compiled_patterns = [re.compile(p, re.IGNORECASE) for p in patterns] + self.blocked_base_commands = set(self.DEFAULT_BLOCKED_BASE_COMMANDS) + if block_network: + self.blocked_base_commands.update(self.NETWORK_BASE_COMMANDS) if verbose: print(f"[DangerousCommands] Loaded {len(self.compiled_patterns)} blocked command patterns") + @staticmethod + def _unquoted_command(command: str) -> str: + # @@@bash-hook-unquoted-scan - dangerous regexes should only inspect executable shell surface, + # not literal text inside quotes. + pieces: list[str] = [] + in_single = False + in_double = False + escaped = False + + for char in command: + if escaped: + if not in_single and not in_double: + pieces.append(char) + escaped = False + continue + + if char == "\\" and not in_single: + if not in_double: + pieces.append(char) + escaped = True + continue + + if char == "'" and not in_double: + in_single = not in_single + continue + + if char == '"' and not in_single: + in_double = not in_double + continue + + if not in_single and not in_double and char == "#": + prev = pieces[-1] if pieces else "" + if not prev or prev.isspace(): + break + + if not in_single and not in_double: + pieces.append(char) + + return "".join(pieces) + + @classmethod + def _has_dangerous_rm_flags(cls, tokens: list[str], start: int) -> bool: + recursive = False + force = False + + for token in tokens[start:]: + if token in cls.OPERATOR_TOKENS: + break + if token == "--": + break + lowered = token.lower() + if lowered == "--recursive": + recursive = True + elif lowered == "--force": + force = True + elif lowered.startswith("-"): + short_flags = lowered[1:] + recursive = recursive or "r" in short_flags + force = force or "f" in short_flags + if recursive and force: + return True + + return False + + def _find_dangerous_command_word(self, command: str) -> str | None: + try: + lexer = shlex.shlex(command, posix=True, punctuation_chars=";&|()<>") + except ValueError: + return None + lexer.whitespace_split = True + lexer.commenters = "#" + tokens = list(lexer) + command_position = True + + for index, token in enumerate(tokens): + if token in self.OPERATOR_TOKENS: + command_position = True + continue + + if token in {"<", ">", ">>", "<<", "<<<", "<>", ">|", "&>", "2>", "1>"}: + command_position = False + continue + + if not command_position: + continue + + if self.ENV_ASSIGN_RE.match(token): + continue + + if token in self.blocked_base_commands: + return token + + if token == "rm" and self._has_dangerous_rm_flags(tokens, index + 1): + return "rm -rf" + + command_position = False + + return None + def check_command(self, command: str, context: dict[str, Any]) -> HookResult: + stripped = command.strip() + if self.ANSI_C_QUOTE_RE.search(stripped) or self.LOCALE_QUOTE_RE.search(stripped): + return HookResult.block_command( + error_message=( + f"❌ SECURITY ERROR: Dangerous command detected\n" + f" Command: {command[:100]}\n" + f" Reason: Obfuscated shell quoting is blocked for security reasons\n" + f" Pattern: raw_obfuscation:$quote\n" + f" 💡 If you need to perform this operation, ask the user for permission." + ) + ) + + dangerous_word = self._find_dangerous_command_word(stripped) + if dangerous_word is not None: + return HookResult.block_command( + error_message=( + f"❌ SECURITY ERROR: Dangerous command detected\n" + f" Command: {command[:100]}\n" + f" Reason: This command is blocked for security reasons\n" + f" Pattern: command_word:{dangerous_word}\n" + f" 💡 If you need to perform this operation, ask the user for permission." + ) + ) + + scanned = self._unquoted_command(stripped) for pattern in self.compiled_patterns: - if pattern.search(command.strip()): + if pattern.search(scanned): return HookResult.block_command( error_message=( f"❌ SECURITY ERROR: Dangerous command detected\n" diff --git a/core/tools/command/middleware.py b/core/tools/command/middleware.py index dcd6453a4..5b4450c34 100644 --- a/core/tools/command/middleware.py +++ b/core/tools/command/middleware.py @@ -18,7 +18,7 @@ from sandbox.shell_output import normalize_pty_result -from .base import AsyncCommand, BaseExecutor +from .base import AsyncCommand, BaseExecutor, describe_execution_exception from .dispatcher import get_executor, get_shell_info logger = logging.getLogger(__name__) @@ -203,7 +203,7 @@ async def _execute_blocking(self, command_line: str, work_dir: str | None, timeo env=self.env, ) except Exception as e: - return f"Error executing command: {e}" + return f"Error executing command: {describe_execution_exception(e)}" return result.to_tool_result() def set_agent(self, agent: Any) -> None: @@ -219,7 +219,7 @@ async def _execute_async(self, command_line: str, work_dir: str | None, timeout: env=self.env, ) except Exception as e: - return f"Error starting async command: {e}" + return f"Error starting async command: {describe_execution_exception(e)}" # Emit task_start event runtime = getattr(self._agent, "runtime", None) if self._agent else None diff --git a/core/tools/command/service.py b/core/tools/command/service.py index 475289b9c..e1927b82b 100644 --- a/core/tools/command/service.py +++ b/core/tools/command/service.py @@ -15,11 +15,13 @@ import asyncio import json import logging +from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry -from core.tools.command.base import BaseExecutor +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema +from core.runtime.tool_result import ToolResultEnvelope, tool_permission_denied +from core.tools.command.base import BaseExecutor, describe_execution_exception from core.tools.command.dispatcher import get_executor logger = logging.getLogger(__name__) @@ -61,35 +63,39 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Bash", mode=ToolMode.INLINE, - schema={ - "name": "Bash", - "description": ("Execute shell command. OS auto-detects shell (mac->zsh, linux->bash, win->powershell)."), - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "Command to execute", - }, - "description": { - "type": "string", - "description": ( - "Human-readable description of what this command does. " - "Required when run_in_background is true; shown in the background task indicator." - ), - }, - "run_in_background": { - "type": "boolean", - "description": "Run in background (default: false). Returns task ID for status queries.", - }, - "timeout": { - "type": "integer", - "description": "Timeout in milliseconds (default: 120000)", - }, + schema=make_tool_schema( + name="Bash", + description=( + "Execute shell command (zsh on macOS, bash on Linux, PowerShell on Windows). " + "Default timeout 120s (max 600s). Dangerous commands are blocked. " + "Prefer dedicated tools over Bash: Read over cat, Grep over grep/rg, Glob over find/ls, Edit over sed/awk." + ), + properties={ + "command": { + "type": "string", + "description": "Command to execute", + "minLength": 1, + }, + "description": { + "type": "string", + "description": ( + "Human-readable description of what this command does. " + "Required when run_in_background is true; shown in the background task indicator." + ), + }, + "run_in_background": { + "type": "boolean", + "description": "Run in background (default: false). Returns task ID for status queries.", + }, + "timeout": { + "type": "integer", + "description": "Timeout in milliseconds (default: 120000)", + "minimum": 1, + "maximum": 600000, }, - "required": ["command"], }, - }, + required=["command"], + ), handler=self._bash, source="CommandService", ) @@ -113,10 +119,13 @@ async def _bash( description: str = "", run_in_background: bool = False, timeout: int = DEFAULT_TIMEOUT_MS, - ) -> str: + ) -> str | ToolResultEnvelope: allowed, error_msg = self._check_hooks(command) if not allowed: - return error_msg + return tool_permission_denied( + error_msg, + metadata={"policy": "command_hook"}, + ) work_dir = None if self._executor.runtime_owns_cwd else str(self.workspace_root) timeout_secs = timeout / 1000.0 @@ -135,7 +144,7 @@ async def _execute_blocking(self, command: str, work_dir: str | None, timeout_se env=self.env, ) except Exception as e: - return f"Error executing command: {e}" + return f"Error executing command: {describe_execution_exception(e)}" return result.to_tool_result() async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: float, description: str = "") -> str: @@ -146,7 +155,7 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: env=self.env, ) except Exception as e: - return f"Error starting async command: {e}" + return f"Error starting async command: {describe_execution_exception(e)}" task_id = async_cmd.command_id @@ -156,7 +165,7 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: self._background_runs[task_id] = _BashBackgroundRun(async_cmd, command, description=description) # Build emit_fn for SSE task lifecycle events - emit_fn = None + emit_fn: Callable[[dict[str, Any]], Awaitable[None] | None] | None = None parent_thread_id = None try: from backend.web.event_bus import get_event_bus @@ -178,7 +187,7 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: # Emit task_start so the frontend dot lights up immediately if emit_fn is not None: - await emit_fn( + emission = emit_fn( { "event": "task_start", "data": json.dumps( @@ -193,6 +202,8 @@ async def _execute_async(self, command: str, work_dir: str | None, timeout_secs: ), } ) + if asyncio.iscoroutine(emission): + await emission if parent_thread_id: asyncio.create_task( @@ -207,7 +218,7 @@ async def _notify_bash_completion( async_cmd: Any, command: str, parent_thread_id: str, - emit_fn: Any = None, + emit_fn: Callable[[dict[str, Any]], Awaitable[None] | None] | None = None, description: str = "", ) -> None: """Poll until async command finishes, then enqueue CommandNotification.""" @@ -220,7 +231,7 @@ async def _notify_bash_completion( # Emit task_done so the frontend dot updates in real time if emit_fn is not None: try: - await emit_fn( + emission = emit_fn( { "event": "task_done", "data": json.dumps( @@ -232,6 +243,8 @@ async def _notify_bash_completion( ), } ) + if asyncio.iscoroutine(emission): + await emission except Exception: pass diff --git a/core/tools/cron/service.py b/core/tools/cron/service.py new file mode 100644 index 000000000..026c7d9be --- /dev/null +++ b/core/tools/cron/service.py @@ -0,0 +1,102 @@ +"""CronToolService — agent-callable cron job CRUD on top of existing backend service.""" + +from __future__ import annotations + +import json +from typing import Any + +from croniter import croniter + +from backend.web.services import cron_job_service +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema + +CRON_CREATE_SCHEMA = make_tool_schema( + name="CronCreate", + description="Create a cron job using the existing Mycel cron_jobs substrate.", + properties={ + "name": {"type": "string", "description": "Human-readable cron job name", "minLength": 1}, + "cron_expression": { + "type": "string", + "description": "Standard 5-field cron expression", + "minLength": 1, + }, + "description": {"type": "string", "description": "Optional cron job description"}, + "task_template": { + "type": "string", + "description": "JSON string template used when the cron job creates a task", + }, + "enabled": {"type": "boolean", "description": "Whether the cron job starts enabled"}, + }, + required=["name", "cron_expression"], +) + +CRON_DELETE_SCHEMA = make_tool_schema( + name="CronDelete", + description="Delete a cron job by ID.", + properties={ + "job_id": {"type": "string", "description": "Cron job ID returned by CronCreate", "minLength": 1}, + }, + required=["job_id"], +) + +CRON_LIST_SCHEMA = make_tool_schema( + name="CronList", + description="List all cron jobs in the current Mycel cron_jobs substrate.", + properties={}, +) + + +class CronToolService: + def __init__(self, registry: ToolRegistry): + self._register(registry) + + def _register(self, registry: ToolRegistry) -> None: + for name, schema, handler, read_only in [ + ("CronCreate", CRON_CREATE_SCHEMA, self._create, False), + ("CronDelete", CRON_DELETE_SCHEMA, self._delete, False), + ("CronList", CRON_LIST_SCHEMA, self._list, True), + ]: + registry.register( + ToolEntry( + name=name, + mode=ToolMode.DEFERRED, + schema=schema, + handler=handler, + source="CronToolService", + is_concurrency_safe=read_only, + is_read_only=read_only, + ) + ) + + def _create(self, **args: Any) -> str: + name = str(args.get("name", "")).strip() + cron_expression = str(args.get("cron_expression", "")).strip() + if not croniter.is_valid(cron_expression): + raise ValueError(f"Invalid cron expression: {cron_expression!r}") + + task_template = args.get("task_template", "{}") + if isinstance(task_template, str): + try: + json.loads(task_template) + except json.JSONDecodeError as exc: + raise ValueError("task_template must be valid JSON") from exc + + item = cron_job_service.create_cron_job( + name=name, + cron_expression=cron_expression, + description=str(args.get("description", "")), + task_template=task_template, + enabled=int(bool(args.get("enabled", True))), + ) + return json.dumps({"item": item}, ensure_ascii=False, indent=2) + + def _delete(self, **args: Any) -> str: + job_id = str(args.get("job_id", "")).strip() + ok = cron_job_service.delete_cron_job(job_id) + if not ok: + raise ValueError(f"Cron job not found: {job_id}") + return json.dumps({"ok": True, "id": job_id}, ensure_ascii=False, indent=2) + + def _list(self, **_args: Any) -> str: + items = cron_job_service.list_cron_jobs() + return json.dumps({"items": items, "total": len(items)}, ensure_ascii=False, indent=2) diff --git a/core/tools/filesystem/local_backend.py b/core/tools/filesystem/local_backend.py index 2bad2d45b..50bbe58a0 100644 --- a/core/tools/filesystem/local_backend.py +++ b/core/tools/filesystem/local_backend.py @@ -18,14 +18,16 @@ class LocalBackend(FileSystemBackend): def read_file(self, path: str) -> FileReadResult: p = Path(path) - content = p.read_text(encoding="utf-8") + with p.open("r", encoding="utf-8", newline="") as f: + content = f.read() return FileReadResult(content=content, size=p.stat().st_size) def write_file(self, path: str, content: str) -> FileWriteResult: try: p = Path(path) p.parent.mkdir(parents=True, exist_ok=True) - p.write_text(content, encoding="utf-8") + with p.open("w", encoding="utf-8", newline="") as f: + f.write(content) return FileWriteResult(success=True) except Exception as e: return FileWriteResult(success=False, error=str(e)) diff --git a/core/tools/filesystem/middleware.py b/core/tools/filesystem/middleware.py index 0844d892a..ff31d0c1c 100644 --- a/core/tools/filesystem/middleware.py +++ b/core/tools/filesystem/middleware.py @@ -14,7 +14,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from pathlib import Path +from pathlib import Path, PurePosixPath from typing import TYPE_CHECKING, Any from langchain.agents.middleware.types import ( @@ -33,6 +33,13 @@ from core.operations import FileOperationRecorder +def _remote_path(path: str | Path) -> PurePosixPath: + # @@@remote-posix-path-contract - Middleware callers still hand us sandbox + # POSIX paths even when tests run on Windows, so keep validation and + # workspace comparisons in POSIX space instead of host-native path rules. + return PurePosixPath(str(path).replace("\\", "/")) + + class FileSystemMiddleware(AgentMiddleware): """FileSystem Middleware - pure middleware implementation of file operations. @@ -80,7 +87,7 @@ def __init__( backend = LocalBackend() self.backend = backend - self.workspace_root = Path(workspace_root) if backend.is_remote else Path(workspace_root).resolve() + self.workspace_root = _remote_path(workspace_root) if backend.is_remote else Path(workspace_root).resolve() self.max_file_size = max_file_size self.allowed_extensions = allowed_extensions self.hooks = hooks or [] @@ -91,10 +98,10 @@ def __init__( "multi_edit": True, "list_dir": True, } - self._read_files: dict[Path, float | None] = {} + self._read_files: dict[Path | PurePosixPath, float | None] = {} self.operation_recorder = operation_recorder self.verbose = verbose - self.extra_allowed_paths: list[Path] = [Path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] + self.extra_allowed_paths = [_remote_path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] if not backend.is_remote: self.workspace_root.mkdir(parents=True, exist_ok=True) @@ -105,17 +112,20 @@ def __init__( if self.hooks: print(f"[FileSystemMiddleware] Loaded {len(self.hooks)} hooks") - def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | None]: + def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | PurePosixPath | None]: """Validate path for file operations. Returns: (is_valid, error_message, resolved_path) """ - if not Path(path).is_absolute(): + if self.backend.is_remote: + if not _remote_path(path).is_absolute(): + return False, f"Path must be absolute: {path}", None + elif not Path(path).is_absolute(): return False, f"Path must be absolute: {path}", None try: - resolved = Path(path) if self.backend.is_remote else Path(path).resolve() + resolved = _remote_path(path) if self.backend.is_remote else Path(path).resolve() except Exception as e: return False, f"Invalid path: {path} ({e})", None @@ -146,7 +156,7 @@ def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | N return True, "", resolved - def _check_file_staleness(self, resolved: Path) -> str | None: + def _check_file_staleness(self, resolved: Path | PurePosixPath) -> str | None: """Check if file has been modified since last read. Returns: @@ -165,7 +175,7 @@ def _check_file_staleness(self, resolved: Path) -> str | None: return None - def _update_file_tracking(self, resolved: Path) -> None: + def _update_file_tracking(self, resolved: Path | PurePosixPath) -> None: """Update mtime tracking after successful file operation.""" self._read_files[resolved] = self.backend.file_mtime(str(resolved)) @@ -203,7 +213,7 @@ def _record_operation( except Exception as e: raise RuntimeError(f"[FileSystemMiddleware] Failed to record operation: {e}") from e - def _count_lines(self, resolved: Path) -> int: + def _count_lines(self, resolved: Path | PurePosixPath) -> int: """Count total lines in a file (for error messages).""" try: raw = self.backend.read_file(str(resolved)) @@ -571,12 +581,12 @@ def _get_tool_schemas(self) -> list[dict]: "parameters": { "type": "object", "properties": { - "directory_path": { + "path": { "type": "string", "description": "Absolute directory path (e.g., /path/to/dir). Do NOT use '.' or '..'", }, }, - "required": ["directory_path"], + "required": ["path"], }, }, }, @@ -633,7 +643,7 @@ def _handle_tool_call(self, tool_call: dict) -> ToolMessage | None: return ToolMessage(content=result, tool_call_id=tool_call_id) if tool_name == self.TOOL_LIST_DIR: - result = self._list_dir_impl(directory_path=args.get("directory_path", "")) + result = self._list_dir_impl(directory_path=args.get("path", "")) return ToolMessage(content=result, tool_call_id=tool_call_id) return None diff --git a/core/tools/filesystem/read/dispatcher.py b/core/tools/filesystem/read/dispatcher.py index f880e60e1..0119f424e 100644 --- a/core/tools/filesystem/read/dispatcher.py +++ b/core/tools/filesystem/read/dispatcher.py @@ -22,6 +22,7 @@ def read_file( limits: ReadLimits | None = None, offset: int | None = None, limit: int | None = None, + pages: str | None = None, ) -> ReadResult: """ Read file with type-specific handling. @@ -38,6 +39,7 @@ def read_file( limits: ReadLimits configuration (uses defaults if None) offset: Start line for text files (1-indexed) limit: Number of lines for text files + pages: Optional page range for document files, e.g. "1" or "3-5" Returns: ReadResult with content and metadata @@ -68,7 +70,8 @@ def read_file( return read_binary(path) if file_type == FileType.DOCUMENT: - return _read_document(path, limits, offset, limit) + start_page, limit_pages = _parse_pages_arg(pages, offset, limit) + return _read_document(path, limits, start_page, limit_pages) if file_type == FileType.NOTEBOOK: return read_notebook(path, limits, start_cell=offset, limit_cells=limit) @@ -79,6 +82,32 @@ def read_file( return read_text(path, limits, offset, limit) +def _parse_pages_arg( + pages: str | None, + offset: int | None, + limit: int | None, +) -> tuple[int | None, int | None]: + if pages is None: + return offset, limit + + raw = pages.strip() + if not raw: + raise ValueError("pages must not be empty") + + if "-" in raw: + start_raw, end_raw = raw.split("-", 1) + start_page = int(start_raw) + end_page = int(end_raw) + if start_page <= 0 or end_page < start_page: + raise ValueError(f"Invalid pages range: {pages}") + return start_page, end_page - start_page + 1 + + start_page = int(raw) + if start_page <= 0: + raise ValueError(f"Invalid page number: {pages}") + return start_page, 1 + + def _read_document( path: Path, limits: ReadLimits, diff --git a/core/tools/filesystem/service.py b/core/tools/filesystem/service.py index a8cf1c9c6..ecfa0b7c5 100644 --- a/core/tools/filesystem/service.py +++ b/core/tools/filesystem/service.py @@ -10,18 +10,90 @@ from __future__ import annotations import logging -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +import tempfile +import threading +from collections import OrderedDict +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path, PurePosixPath +from typing import TYPE_CHECKING, Any, Literal + +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema +from core.runtime.tool_result import ToolResultEnvelope, tool_success from core.tools.filesystem.backend import FileSystemBackend from core.tools.filesystem.read import ReadLimits from core.tools.filesystem.read import read_file as read_file_dispatch +from core.tools.filesystem.read.readers.binary import IMAGE_EXTENSIONS, MAX_IMAGE_SIZE +from core.tools.filesystem.read.types import FileType, detect_file_type if TYPE_CHECKING: from core.operations import FileOperationRecorder logger = logging.getLogger(__name__) +DEFAULT_READ_STATE_CACHE_SIZE = 100 +ABSOLUTE_PATH_PATTERN = r"^(?:/|[A-Za-z]:[\\/])" +type ResolvedPath = Path | PurePosixPath +type ValidationResult = tuple[Literal[True], str, ResolvedPath] | tuple[Literal[False], str, None] + + +def _remote_path(path: str | Path) -> PurePosixPath: + # @@@remote-posix-path-contract - Remote filesystem tools operate on sandbox + # POSIX paths, not host-native paths. Preserve forward-slash semantics even + # when the host process is running on Windows. + return PurePosixPath(str(path).replace("\\", "/")) + + +@dataclass +class _ReadFileState: + timestamp: float | None + is_partial: bool + + +class _ReadFileStateCache: + def __init__(self, max_entries: int = DEFAULT_READ_STATE_CACHE_SIZE): + self._max_entries = max_entries + self._entries: OrderedDict[ResolvedPath, _ReadFileState] = OrderedDict() + + @staticmethod + def make_state(*, timestamp: float | None, is_partial: bool) -> _ReadFileState: + return _ReadFileState(timestamp=timestamp, is_partial=is_partial) + + def get(self, path: ResolvedPath) -> _ReadFileState | None: + state = self._entries.get(path) + if state is None: + return None + self._entries.move_to_end(path) + return state + + def set(self, path: ResolvedPath, state: _ReadFileState) -> None: + self._entries[path] = state + self._entries.move_to_end(path) + while len(self._entries) > self._max_entries: + self._entries.popitem(last=False) + + def clone(self) -> _ReadFileStateCache: + clone = _ReadFileStateCache(max_entries=self._max_entries) + clone._entries = OrderedDict( + (path, _ReadFileState(timestamp=state.timestamp, is_partial=state.is_partial)) for path, state in self._entries.items() + ) + return clone + + def merge(self, other: _ReadFileStateCache) -> None: + for path, incoming in other._entries.items(): + existing = self._entries.get(path) + if existing is None or self._is_newer(incoming, existing): + self.set( + path, + _ReadFileState(timestamp=incoming.timestamp, is_partial=incoming.is_partial), + ) + + @staticmethod + def _is_newer(incoming: _ReadFileState, existing: _ReadFileState) -> bool: + if incoming.timestamp is None: + return False + if existing.timestamp is None: + return True + return incoming.timestamp >= existing.timestamp class FileSystemService: @@ -37,7 +109,9 @@ def __init__( hooks: list[Any] | None = None, operation_recorder: FileOperationRecorder | None = None, backend: FileSystemBackend | None = None, - extra_allowed_paths: list[str | Path] | None = None, + extra_allowed_paths: Sequence[str | Path] | None = None, + max_read_cache_entries: int = DEFAULT_READ_STATE_CACHE_SIZE, + max_edit_file_size: int | None = None, ): if backend is None: from core.tools.filesystem.local_backend import LocalBackend @@ -45,15 +119,17 @@ def __init__( backend = LocalBackend() self.backend = backend - self.workspace_root = Path(workspace_root) if backend.is_remote else Path(workspace_root).resolve() + self.workspace_root: ResolvedPath = _remote_path(workspace_root) if backend.is_remote else Path(workspace_root).resolve() self.max_file_size = max_file_size self.allowed_extensions = allowed_extensions self.hooks = hooks or [] - self._read_files: dict[Path, float | None] = {} + self._read_files = _ReadFileStateCache(max_entries=max_read_cache_entries) + self.max_edit_file_size = max_file_size if max_edit_file_size is None else max_edit_file_size self.operation_recorder = operation_recorder - self.extra_allowed_paths: list[Path] = [Path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] + self.extra_allowed_paths = [_remote_path(p) if backend.is_remote else Path(p).resolve() for p in (extra_allowed_paths or [])] + self._edit_critical_section = threading.Lock() - if not backend.is_remote: + if not backend.is_remote and isinstance(self.workspace_root, Path): self.workspace_root.mkdir(parents=True, exist_ok=True) self._register(registry) @@ -67,30 +143,42 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Read", mode=ToolMode.INLINE, - schema={ - "name": "Read", - "description": ("Read file content (text/code/images/PDF/PPTX/Notebook). Path must be absolute."), - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Absolute file path", - }, - "offset": { - "type": "integer", - "description": "Start line (1-indexed, optional)", - }, - "limit": { - "type": "integer", - "description": "Number of lines to read (optional)", - }, + schema=make_tool_schema( + name="Read", + description=( + "Read file content. Output uses cat -n format (line numbers starting at 1). " + "Default reads up to 2000 lines from start; use offset/limit for long files. " + "Supports images (PNG/JPG), PDF (use pages param for large PDFs), and Jupyter notebooks. " + "Path must be absolute." + ), + properties={ + "file_path": { + "type": "string", + "description": "Absolute file path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, + }, + "offset": { + "type": "integer", + "description": "Start line (1-indexed, optional)", + }, + "limit": { + "type": "integer", + "description": "Number of lines to read (optional)", + }, + "pages": { + "type": "string", + "description": "Page range for PDF files (e.g. '1-5'). Max 20 pages per request.", }, - "required": ["file_path"], }, - }, + required=["file_path"], + ), handler=self._read_file, + validate_input=self._validate_read_args, source="FileSystemService", + search_hint="read view file content text code image PDF notebook", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -98,26 +186,27 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Write", mode=ToolMode.INLINE, - schema={ - "name": "Write", - "description": "Create new file. Path must be absolute. Fails if file exists.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Absolute file path", - }, - "content": { - "type": "string", - "description": "File content", - }, + schema=make_tool_schema( + name="Write", + description="Create or overwrite a file with full content. Forces LF line endings. Path must be absolute.", + properties={ + "file_path": { + "type": "string", + "description": "Absolute file path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, + }, + "content": { + "type": "string", + "description": "File content", }, - "required": ["file_path", "content"], }, - }, + required=["file_path", "content"], + ), handler=self._write_file, + validate_input=self._validate_write_args, source="FileSystemService", + search_hint="create new file write content to disk", ) ) @@ -125,39 +214,39 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Edit", mode=ToolMode.INLINE, - schema={ - "name": "Edit", - "description": ( - "Edit existing file using exact string replacement. " - "MUST read file before editing. " - "old_string must be unique in file. " - "Set replace_all=true to replace all occurrences." + schema=make_tool_schema( + name="Edit", + description=( + "Edit file via exact string replacement. You MUST Read the file first. " + "old_string must match exactly one location (or use replace_all=true). " + "Does not support .ipynb files (use Write to overwrite full JSON). Path must be absolute." ), - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Absolute file path", - }, - "old_string": { - "type": "string", - "description": "Exact string to replace", - }, - "new_string": { - "type": "string", - "description": "Replacement string", - }, - "replace_all": { - "type": "boolean", - "description": "Replace all occurrences (default: false)", - }, + properties={ + "file_path": { + "type": "string", + "description": "Absolute file path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, + }, + "old_string": { + "type": "string", + "description": "Exact string to replace", + }, + "new_string": { + "type": "string", + "description": "Replacement string", + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences (default: false)", }, - "required": ["file_path", "old_string", "new_string"], }, - }, + required=["file_path", "old_string", "new_string"], + ), handler=self._edit_file, + validate_input=self._validate_edit_args, source="FileSystemService", + search_hint="edit modify replace string in existing file", ) ) @@ -165,22 +254,25 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="list_dir", mode=ToolMode.INLINE, - schema={ - "name": "list_dir", - "description": "List directory contents. Path must be absolute.", - "parameters": { - "type": "object", - "properties": { - "directory_path": { - "type": "string", - "description": "Absolute directory path", - }, + schema=make_tool_schema( + name="list_dir", + description="List directory contents (files and subdirectories, non-recursive). Path must be absolute.", + properties={ + "path": { + "type": "string", + "description": "Absolute directory path", + "minLength": 1, + "pattern": ABSOLUTE_PATH_PATTERN, }, - "required": ["directory_path"], }, - }, + required=["path"], + ), handler=self._list_dir, + validate_input=self._validate_list_dir_args, source="FileSystemService", + search_hint="list directory contents browse folder", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -188,12 +280,15 @@ def _register(self, registry: ToolRegistry) -> None: # Path validation (reused from middleware) # ------------------------------------------------------------------ - def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | None]: - if not Path(path).is_absolute(): + def _validate_path(self, path: str, operation: str) -> ValidationResult: + if self.backend.is_remote: + if not _remote_path(path).is_absolute(): + return False, f"Path must be absolute: {path}", None + elif not Path(path).is_absolute(): return False, f"Path must be absolute: {path}", None try: - resolved = Path(path) if self.backend.is_remote else Path(path).resolve() + resolved = _remote_path(path) if self.backend.is_remote else Path(path).resolve() except Exception as e: return False, f"Invalid path: {path} ({e})", None @@ -224,10 +319,159 @@ def _validate_path(self, path: str, operation: str) -> tuple[bool, str, Path | N return True, "", resolved - def _check_file_staleness(self, resolved: Path) -> str | None: - if resolved not in self._read_files: - return "File has not been read yet. Read it first before writing to it." - stored_mtime = self._read_files[resolved] + def _validation_error(self, message: str, error_code: str) -> dict[str, object]: + return { + "result": False, + "message": message, + "errorCode": error_code, + } + + def _path_validation_error(self, message: str) -> dict[str, object]: + # @@@filesystem-validation-codes - Keep the pre-execution path failure + # mapping centralized so the runner can surface stable structured + # codes instead of ad-hoc handler strings on the highest-traffic tools. + if message.startswith("Path must be absolute:"): + return self._validation_error(message, "PATH_NOT_ABSOLUTE") + if message.startswith("Invalid path:"): + return self._validation_error(message, "INVALID_PATH") + if message.startswith("Path outside workspace"): + return self._validation_error(message, "PATH_OUTSIDE_WORKSPACE") + if message.startswith("File type not allowed:"): + return self._validation_error(message, "FILE_TYPE_NOT_ALLOWED") + return self._validation_error(message, "INVALID_PATH") + + def _validate_existing_path(self, path: str, operation: str) -> tuple[dict[str, object] | None, ResolvedPath | None]: + is_valid, error, resolved = self._validate_path(path, operation) + if not is_valid: + return self._path_validation_error(error), None + assert resolved is not None + return None, resolved + + def _validation_message(self, error: dict[str, object]) -> str: + return str(error["message"]) + + def _read_preflight( + self, + *, + file_path: str, + offset: int = 0, + limit: int | None = None, + pages: str | None = None, + ) -> tuple[dict[str, object] | None, ResolvedPath | None]: + error, resolved = self._validate_existing_path(file_path, "read") + if error is not None: + return error, None + assert resolved is not None + + file_size = self.backend.file_size(str(resolved)) + if file_size is not None and file_size > self.max_file_size: + return ( + self._validation_error( + f"File too large: {file_size:,} bytes (max: {self.max_file_size:,} bytes)", + "FILE_TOO_LARGE", + ), + None, + ) + + has_pagination = offset > 0 or limit is not None or pages is not None + if not has_pagination and file_size is not None: + limits = ReadLimits() + if file_size > limits.max_size_bytes: + total_lines = self._count_lines(resolved) + return ( + self._validation_error( + ( + f"File content ({file_size:,} bytes) exceeds maximum allowed size ({limits.max_size_bytes:,} bytes).\n" + f"Use offset and limit parameters to read specific sections.\n" + f"Total lines: {total_lines}" + ), + "READ_REQUIRES_PAGINATION", + ), + None, + ) + estimated_tokens = file_size // 4 + if estimated_tokens > limits.max_tokens: + total_lines = self._count_lines(resolved) + return ( + self._validation_error( + ( + f"File content (~{estimated_tokens:,} tokens) exceeds maximum allowed tokens ({limits.max_tokens:,}).\n" + f"Use offset and limit parameters to read specific sections.\n" + f"Total lines: {total_lines}" + ), + "READ_REQUIRES_PAGINATION", + ), + None, + ) + + return None, resolved + + def _edit_preflight(self, *, file_path: str) -> tuple[dict[str, object] | None, ResolvedPath | None]: + error, resolved = self._validate_existing_path(file_path, "edit") + if error is not None: + return error, None + assert resolved is not None + + if resolved.suffix.lower() == ".ipynb": + return ( + self._validation_error( + "Notebook files (.ipynb) are not supported by Edit. Use Write to overwrite the full JSON.", + "NOTEBOOK_EDIT_UNSUPPORTED", + ), + None, + ) + + file_size = self.backend.file_size(str(resolved)) + if file_size is not None and file_size > self.max_edit_file_size: + return ( + self._validation_error( + f"File too large for Edit: {file_size:,} bytes (max: {self.max_edit_file_size:,} bytes)", + "FILE_TOO_LARGE", + ), + None, + ) + + return None, resolved + + def _list_dir_preflight(self, *, path: str) -> tuple[dict[str, object] | None, ResolvedPath | None]: + error, resolved = self._validate_existing_path(path, "list") + if error is not None: + return error, None + assert resolved is not None + if not self.backend.is_dir(str(resolved)): + if self.backend.file_exists(str(resolved)): + return self._validation_error(f"Not a directory: {path}", "NOT_A_DIRECTORY"), None + return self._validation_error(f"Directory not found: {path}", "DIRECTORY_NOT_FOUND"), None + return None, resolved + + def _validate_read_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._read_preflight( + file_path=args["file_path"], + offset=args.get("offset") or 0, + limit=args.get("limit"), + pages=args.get("pages"), + ) + return error or args + + def _validate_write_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._validate_existing_path(args["file_path"], "write") + return error or args + + def _validate_edit_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._edit_preflight(file_path=args["file_path"]) + return error or args + + def _validate_list_dir_args(self, args: dict[str, Any], request: Any) -> dict[str, Any]: + error, _ = self._list_dir_preflight(path=args["path"]) + return error or args + + def _check_file_staleness(self, resolved: ResolvedPath) -> str | None: + state = self._read_files.get(resolved) + if state is None: + return "File has not been read yet. Read the full file first before editing." + if state.is_partial: + return "File has only been read partially. Read the full file before editing." + stored_mtime = state.timestamp if stored_mtime is None: return None current_mtime = self.backend.file_mtime(str(resolved)) @@ -235,8 +479,70 @@ def _check_file_staleness(self, resolved: Path) -> str | None: return "File has been modified since last read. Read it again before editing." return None - def _update_file_tracking(self, resolved: Path) -> None: - self._read_files[resolved] = self.backend.file_mtime(str(resolved)) + def _update_file_tracking( + self, + resolved: ResolvedPath, + *, + is_partial: bool, + file_type: FileType | None = None, + ) -> None: + if file_type is None: + file_type = self._detect_file_type(resolved) + if file_type not in {FileType.TEXT, FileType.NOTEBOOK}: + return + self._read_files.set( + resolved, + _ReadFileState( + timestamp=self.backend.file_mtime(str(resolved)), + is_partial=is_partial, + ), + ) + + def _normalize_write_content(self, content: str) -> str: + return content.replace("\r\n", "\n").replace("\r", "\n") + + def _read_result_is_partial(self, result) -> bool: + if getattr(result, "truncated", False): + return True + if getattr(result, "file_type", None) == FileType.TEXT: + start_line = getattr(result, "start_line", None) or 1 + total_lines = getattr(result, "total_lines", None) + end_line = getattr(result, "end_line", None) or total_lines or start_line + if total_lines is not None: + return start_line > 1 or end_line < total_lines + return False + + def _detect_file_type(self, resolved: ResolvedPath) -> FileType: + return detect_file_type(Path(str(resolved))) + + def _structured_media_success( + self, + *, + resolved: ResolvedPath, + file_type: FileType, + content_blocks: list[dict[str, str]], + ) -> ToolResultEnvelope: + return tool_success( + [ + { + "type": "text", + "text": (f"Read file: {resolved.name}\nSpecial content is attached below as structured blocks."), + }, + *content_blocks, + ], + metadata={"file_type": file_type.value}, + ) + + def _restore_special_result_identity( + self, + *, + result, + resolved: ResolvedPath, + temp_path: Path, + ) -> None: + result.file_path = str(resolved) + if isinstance(getattr(result, "content", None), str): + result.content = result.content.replace(str(temp_path), str(resolved)).replace(temp_path.name, resolved.name) def _record_operation( self, @@ -267,7 +573,7 @@ def _record_operation( except Exception as e: raise RuntimeError(f"[FileSystemService] Failed to record operation: {e}") from e - def _count_lines(self, resolved: Path) -> int: + def _count_lines(self, resolved: ResolvedPath) -> int: try: raw = self.backend.read_file(str(resolved)) return raw.content.count("\n") + 1 @@ -278,50 +584,86 @@ def _count_lines(self, resolved: Path) -> int: # Tool handlers # ------------------------------------------------------------------ - def _read_file(self, file_path: str, offset: int = 0, limit: int | None = None) -> str: - is_valid, error, resolved = self._validate_path(file_path, "read") - if not is_valid: - return error - - file_size = self.backend.file_size(str(resolved)) - - if file_size is not None and file_size > self.max_file_size: - return f"File too large: {file_size:,} bytes (max: {self.max_file_size:,} bytes)" - - has_pagination = offset > 0 or limit is not None - if not has_pagination and file_size is not None: - limits = ReadLimits() - if file_size > limits.max_size_bytes: - total_lines = self._count_lines(resolved) - return ( - f"File content ({file_size:,} bytes) exceeds maximum allowed size ({limits.max_size_bytes:,} bytes).\n" - f"Use offset and limit parameters to read specific sections.\n" - f"Total lines: {total_lines}" - ) - estimated_tokens = file_size // 4 - if estimated_tokens > limits.max_tokens: - total_lines = self._count_lines(resolved) - return ( - f"File content (~{estimated_tokens:,} tokens) exceeds maximum allowed tokens ({limits.max_tokens:,}).\n" - f"Use offset and limit parameters to read specific sections.\n" - f"Total lines: {total_lines}" - ) + def _read_file(self, file_path: str, offset: int = 0, limit: int | None = None, pages: str | None = None) -> str | ToolResultEnvelope: + error, resolved = self._read_preflight( + file_path=file_path, + offset=offset, + limit=limit, + pages=pages, + ) + if error is not None: + return self._validation_message(error) + assert resolved is not None from core.tools.filesystem.local_backend import LocalBackend if isinstance(self.backend, LocalBackend): + assert isinstance(resolved, Path) limits = ReadLimits() result = read_file_dispatch( path=resolved, limits=limits, offset=offset if offset > 0 else None, limit=limit, + pages=pages, ) if not result.error: - self._update_file_tracking(resolved) + self._update_file_tracking( + resolved, + is_partial=self._read_result_is_partial(result), + file_type=result.file_type, + ) + if result.content_blocks: + return self._structured_media_success( + resolved=resolved, + file_type=result.file_type, + content_blocks=result.content_blocks, + ) return result.format_output() try: + file_type = self._detect_file_type(resolved) + download_bytes = getattr(self.backend, "download_bytes", None) + if callable(download_bytes) and file_type in {FileType.BINARY, FileType.DOCUMENT}: + # @@@dt-02-remote-special-file-bridge + # Remote providers expose raw-byte download hooks. Reuse the + # same local dispatcher for binary/document reads instead of + # degrading special files into placeholder text. + raw_bytes = download_bytes(str(resolved)) + if not isinstance(raw_bytes, (bytes, bytearray)): + raise TypeError(f"Remote special-file download returned {type(raw_bytes).__name__}, expected bytes.") + raw_bytes = bytes(raw_bytes) + if ( + file_type == FileType.BINARY + and resolved.suffix.lstrip(".").lower() in IMAGE_EXTENSIONS + and len(raw_bytes) > MAX_IMAGE_SIZE + ): + return f"Image exceeds size limit: {len(raw_bytes)} bytes" + with tempfile.NamedTemporaryFile(suffix=resolved.suffix, delete=False) as tmp: + tmp.write(raw_bytes) + tmp_path = Path(tmp.name) + try: + result = read_file_dispatch( + path=tmp_path, + limits=ReadLimits(), + offset=offset if offset > 0 else None, + limit=limit, + pages=pages, + ) + finally: + tmp_path.unlink(missing_ok=True) + self._restore_special_result_identity( + result=result, + resolved=resolved, + temp_path=tmp_path, + ) + if result.content_blocks: + return self._structured_media_success( + resolved=resolved, + file_type=result.file_type, + content_blocks=result.content_blocks, + ) + return result.format_output() raw = self.backend.read_file(str(resolved)) lines = raw.content.split("\n") total_lines = len(lines) @@ -331,7 +673,10 @@ def _read_file(self, file_path: str, offset: int = 0, limit: int | None = None) selected = lines[start:end] numbered = [f"{start + i + 1:>6}\t{line}" for i, line in enumerate(selected)] content = "\n".join(numbered) - self._update_file_tracking(resolved) + self._update_file_tracking( + resolved, + is_partial=start > 0 or end < total_lines, + ) return content except Exception as e: return f"Error reading file: {e}" @@ -340,88 +685,102 @@ def _write_file(self, file_path: str, content: str) -> str: is_valid, error, resolved = self._validate_path(file_path, "write") if not is_valid: return error - - if self.backend.file_exists(str(resolved)): - return f"File already exists: {file_path}\nUse Edit to modify existing files" + assert resolved is not None try: - result = self.backend.write_file(str(resolved), content) + normalized = self._normalize_write_content(content) + result = self.backend.write_file(str(resolved), normalized) if not result.success: return f"Error writing file: {result.error}" - self._update_file_tracking(resolved) + self._update_file_tracking(resolved, is_partial=False) self._record_operation( operation_type="write", file_path=file_path, before_content=None, - after_content=content, + after_content=normalized, ) - lines = content.count("\n") + 1 + lines = normalized.count("\n") + 1 return f"File created: {file_path}\n Lines: {lines}\n Size: {len(content)} bytes" except Exception as e: return f"Error writing file: {e}" def _edit_file(self, file_path: str, old_string: str, new_string: str, replace_all: bool = False) -> str: - is_valid, error, resolved = self._validate_path(file_path, "edit") - if not is_valid: - return error - - if not self.backend.file_exists(str(resolved)): - return f"File not found: {file_path}" - - staleness_error = self._check_file_staleness(resolved) - if staleness_error: - return staleness_error - - if old_string == new_string: - return "Error: old_string and new_string are identical (no-op edit)" + error, resolved = self._edit_preflight(file_path=file_path) + if error is not None: + return self._validation_message(error) + assert resolved is not None try: - raw = self.backend.read_file(str(resolved)) - content = raw.content - - if old_string not in content: - return f"String not found in file\n Looking for: {old_string[:100]}..." - - if replace_all: - count = content.count(old_string) - new_content = content.replace(old_string, new_string) - else: - count = content.count(old_string) - if count > 1: - return ( - f"String appears {count} times in file (not unique)\n" - f" Use replace_all=true or provide more context to make it unique" - ) - new_content = content.replace(old_string, new_string, 1) - count = 1 - - result = self.backend.write_file(str(resolved), new_content) - if not result.success: - return f"Error editing file: {result.error}" - - self._update_file_tracking(resolved) - self._record_operation( - operation_type="edit", - file_path=file_path, - before_content=content, - after_content=new_content, - changes=[{"old_string": old_string, "new_string": new_string}], - ) - return f"File edited: {file_path}\n Replaced {count} occurrence(s)" + # @@@edit-critical-lock + # dt-01 requires the reread -> stale check -> write path to be one + # synchronous critical section so two stale concurrent edits cannot + # both commit from the same prior read snapshot. + with self._edit_critical_section: + try: + raw = self.backend.read_file(str(resolved)) + except FileNotFoundError: + if old_string == "": + return self._write_file(file_path, new_string) + return f"File not found: {file_path}" + content = raw.content + + if old_string == "": + return "Cannot use empty old_string on an existing file. Use Write to replace the full file content." + staleness_error = self._check_file_staleness(resolved) + if staleness_error: + return staleness_error + + if old_string == new_string: + return "Error: old_string and new_string are identical (no-op edit)" + + # @@@edit-critical-staleness + # te-06 needs a second stale-read check inside the read->write + # critical section so an external write that lands after the + # preflight check cannot be silently overwritten. + staleness_error = self._check_file_staleness(resolved) + if staleness_error: + return staleness_error + + if old_string not in content: + return f"String not found in file\n Looking for: {old_string[:100]}..." + + if replace_all: + count = content.count(old_string) + new_content = content.replace(old_string, new_string) + else: + count = content.count(old_string) + if count > 1: + return ( + f"String appears {count} times in file (not unique)\n" + f" Use replace_all=true or provide more context to make it unique" + ) + new_content = content.replace(old_string, new_string, 1) + count = 1 + + result = self.backend.write_file(str(resolved), new_content) + if not result.success: + return f"Error editing file: {result.error}" + + self._update_file_tracking(resolved, is_partial=False) + self._record_operation( + operation_type="edit", + file_path=file_path, + before_content=content, + after_content=new_content, + changes=[{"old_string": old_string, "new_string": new_string}], + ) + return f"File edited: {file_path}\n Replaced {count} occurrence(s)" except Exception as e: return f"Error editing file: {e}" - def _list_dir(self, directory_path: str) -> str: - is_valid, error, resolved = self._validate_path(directory_path, "list") - if not is_valid: - return error - - if not self.backend.is_dir(str(resolved)): - if self.backend.file_exists(str(resolved)): - return f"Not a directory: {directory_path}" - return f"Directory not found: {directory_path}" + def _list_dir(self, path: str) -> str: + directory_path = path + error, resolved = self._list_dir_preflight(path=directory_path) + if error is not None: + return self._validation_message(error) + assert resolved is not None try: result = self.backend.list_dir(str(resolved)) diff --git a/core/tools/lsp/__init__.py b/core/tools/lsp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/tools/lsp/service.py b/core/tools/lsp/service.py new file mode 100644 index 000000000..dc480812d --- /dev/null +++ b/core/tools/lsp/service.py @@ -0,0 +1,838 @@ +"""LSP Service - Language Server Protocol code intelligence via multilspy. + +Registers a single DEFERRED `LSP` tool with 9 operations: + goToDefinition, findReferences, hover, documentSymbol, workspaceSymbol, + goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls + +Sessions are managed by the process-level _LSPSessionPool singleton — they +start lazily on first use and persist for the lifetime of the process, +surviving agent restarts. Call `await lsp_pool.close_all()` on process exit. + +Supported languages (via multilspy): + python, typescript, javascript, go, rust, java, ruby, kotlin, csharp +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import shutil +import subprocess +from pathlib import Path +from typing import Any + +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema + +_FILE_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MB — matches CC LSP limit + +logger = logging.getLogger(__name__) + +LSP_SCHEMA = make_tool_schema( + name="LSP", + description=( + "Language Server Protocol code intelligence. " + "Operations: goToDefinition, findReferences, hover, documentSymbol, workspaceSymbol, " + "goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls. " + "Language servers are auto-downloaded on first use. " + "Supports python, typescript, javascript, go, rust, java, ruby, kotlin. " + "file_path must be absolute. line/character are 1-based. " + "incomingCalls/outgoingCalls require 'item' from prepareCallHierarchy output." + ), + properties={ + "operation": { + "type": "string", + "enum": [ + "goToDefinition", + "findReferences", + "hover", + "documentSymbol", + "workspaceSymbol", + "goToImplementation", + "prepareCallHierarchy", + "incomingCalls", + "outgoingCalls", + ], + "description": "LSP operation to perform", + }, + "file_path": { + "type": "string", + "description": "Absolute path to file (required for all operations except workspaceSymbol)", + }, + "line": { + "type": "integer", + "description": "1-based line number (required for goToDefinition, findReferences, hover)", + }, + "character": { + "type": "integer", + "description": "1-based character offset (required for goToDefinition, findReferences, hover)", + }, + "query": { + "type": "string", + "description": "Symbol name to search (required for workspaceSymbol)", + }, + "language": { + "type": "string", + "description": "Language override. Auto-detected from file extension if omitted.", + }, + "item": { + "type": "object", + "description": "CallHierarchyItem from prepareCallHierarchy (required for incomingCalls/outgoingCalls).", + }, + }, + required=["operation"], +) + +# File extension → multilspy language identifier +_EXT_TO_LANG: dict[str, str] = { + ".py": "python", + ".ts": "typescript", + ".tsx": "typescript", + ".js": "javascript", + ".jsx": "javascript", + ".go": "go", + ".rs": "rust", + ".java": "java", + ".rb": "ruby", + ".kt": "kotlin", + ".cs": "csharp", +} + + +def _find_pyright() -> str | None: + """Locate pyright-langserver: venv-local first, then PATH.""" + for name in ("pyright-langserver", "pyright_langserver"): + # prefer the binary in the same venv as the current interpreter + venv_bin = Path(os.__file__).parent.parent.parent / "bin" / name + if venv_bin.exists(): + return str(venv_bin) + found = shutil.which(name) + if found: + return found + return None + + +class _PyrightSession: + """Minimal asyncio LSP client for pyright-langserver (stdio). + + Used for Python operations not supported by Jedi: + goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls. + + Requires pyright in the active venv: pip install pyright + """ + + def __init__(self, workspace_root: str) -> None: + self._workspace_root = workspace_root + self._proc: asyncio.subprocess.Process | None = None + self._pending: dict[int, asyncio.Future] = {} + self._next_id = 1 + self._reader_task: asyncio.Task | None = None + self._open_files: set[str] = set() + + async def start(self) -> None: + server = _find_pyright() + if not server: + raise RuntimeError("pyright-langserver not found. Install with: pip install pyright") + self._proc = await asyncio.create_subprocess_exec( + server, + "--stdio", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + self._reader_task = asyncio.create_task(self._read_loop(), name="pyright-reader") + + # LSP handshake + await self._request( + "initialize", + { + "processId": os.getpid(), + "rootUri": Path(self._workspace_root).as_uri(), + "capabilities": { + "textDocument": { + "synchronization": {"dynamicRegistration": False}, + "implementation": {"dynamicRegistration": False, "linkSupport": True}, + "callHierarchy": {"dynamicRegistration": False}, + } + }, + "initializationOptions": {}, + }, + ) + self._notify("initialized", {}) + + # ── I/O ─────────────────────────────────────────────────────────── + + async def _read_loop(self) -> None: + try: + while True: + assert self._proc and self._proc.stdout + # Read headers until blank line + content_length = 0 + while True: + raw = await self._proc.stdout.readline() + if not raw: + return + line = raw.decode().rstrip() + if not line: + break + if line.lower().startswith("content-length:"): + content_length = int(line.split(":", 1)[1].strip()) + if content_length == 0: + continue + body = await self._proc.stdout.readexactly(content_length) + msg = json.loads(body) + # Route response/error to waiting Future + msg_id = msg.get("id") + msg_method = msg.get("method", "") + if msg_id is not None and msg_method: + # Server-to-client request — must acknowledge with a response + self._write({"jsonrpc": "2.0", "id": msg_id, "result": None}) + await self._drain() + elif msg_id is not None and msg_id in self._pending: + fut = self._pending.pop(msg_id) + if not fut.done(): + if "error" in msg: + fut.set_exception(RuntimeError(f"{msg['error'].get('message', 'LSP error')} ({msg['error'].get('code', '')})")) + else: + fut.set_result(msg.get("result")) + # All other notifications ($/progress, diagnostics, etc.) are silently dropped + except Exception as exc: + for fut in self._pending.values(): + if not fut.done(): + fut.set_exception(exc) + + def _write(self, msg: dict) -> None: + """Encode and buffer one LSP message (call drain() to flush).""" + assert self._proc and self._proc.stdin + body = json.dumps(msg, separators=(",", ":")).encode() + header = f"Content-Length: {len(body)}\r\n\r\n".encode() + self._proc.stdin.write(header + body) + + async def _drain(self) -> None: + assert self._proc and self._proc.stdin + await self._proc.stdin.drain() + + def _notify(self, method: str, params: Any) -> None: + self._write({"jsonrpc": "2.0", "method": method, "params": params}) + + async def _request(self, method: str, params: Any, timeout: float = 30.0) -> Any: + req_id = self._next_id + self._next_id += 1 + loop = asyncio.get_event_loop() + fut: asyncio.Future = loop.create_future() + self._pending[req_id] = fut + self._write({"jsonrpc": "2.0", "id": req_id, "method": method, "params": params}) + await self._drain() + return await asyncio.wait_for(fut, timeout=timeout) + + # ── file lifecycle ──────────────────────────────────────────────── + + def _open_file(self, abs_path: str) -> None: + uri = Path(abs_path).as_uri() + if uri in self._open_files: + return + try: + text = Path(abs_path).read_text(encoding="utf-8", errors="replace") + except OSError: + text = "" + self._notify("textDocument/didOpen", {"textDocument": {"uri": uri, "languageId": "python", "version": 1, "text": text}}) + self._open_files.add(uri) + + def _close_file(self, abs_path: str) -> None: + uri = Path(abs_path).as_uri() + if uri not in self._open_files: + return + self._notify("textDocument/didClose", {"textDocument": {"uri": uri}}) + self._open_files.discard(uri) + + def _abs(self, rel_path: str) -> str: + return str(Path(self._workspace_root) / rel_path) + + # ── LSP operations ──────────────────────────────────────────────── + + async def request_implementation(self, rel_path: str, line: int, col: int) -> list: + abs_path = self._abs(rel_path) + self._open_file(abs_path) + await self._drain() + uri = Path(abs_path).as_uri() + response = await self._request( + "textDocument/implementation", + { + "textDocument": {"uri": uri}, + "position": {"line": line, "character": col}, + }, + ) + return self._normalise_locations(response) + + async def request_prepare_call_hierarchy(self, rel_path: str, line: int, col: int) -> list: + abs_path = self._abs(rel_path) + self._open_file(abs_path) + await self._drain() + uri = Path(abs_path).as_uri() + response = await self._request( + "textDocument/prepareCallHierarchy", + { + "textDocument": {"uri": uri}, + "position": {"line": line, "character": col}, + }, + ) + # File stays open — callHierarchy/incomingCalls and outgoingCalls may need it + return response or [] + + async def request_incoming_calls(self, item: dict) -> list: + response = await self._request("callHierarchy/incomingCalls", {"item": item}) + return response or [] + + async def request_outgoing_calls(self, item: dict) -> list: + response = await self._request("callHierarchy/outgoingCalls", {"item": item}) + return response or [] + + @staticmethod + def _normalise_locations(response: Any) -> list: + if not response: + return [] + if isinstance(response, dict): + response = [response] + out = [] + for loc in response: + uri = loc.get("uri") or loc.get("targetUri", "") + rng = loc.get("range") or loc.get("targetSelectionRange") or loc.get("targetRange") or {} + out.append({"uri": uri, "absolutePath": uri.replace("file://", ""), "range": rng}) + return out + + # ── shutdown ────────────────────────────────────────────────────── + + async def stop(self) -> None: + if self._proc: + try: + await asyncio.wait_for(self._request("shutdown", {}), timeout=5) + self._notify("exit", {}) + except Exception: + pass + try: + self._proc.terminate() + await asyncio.wait_for(self._proc.wait(), timeout=5) + except Exception: + self._proc.kill() + if self._reader_task and not self._reader_task.done(): + self._reader_task.cancel() + try: + await self._reader_task + except (asyncio.CancelledError, Exception): + pass + + +class _LSPSession: + """Holds a multilspy LanguageServer alive in a background asyncio task. + + Pattern: start_server() is an async context manager that must stay open + for the lifetime of the session. We enter it inside a background Task and + use an Event to signal readiness. Stopping sets a second Event that causes + the background task to exit the context and shut down the server process. + """ + + def __init__(self, language: str, workspace_root: str) -> None: + self.language = language + self._workspace_root = workspace_root + self._ready = asyncio.Event() + self._stop = asyncio.Event() + self._task: asyncio.Task | None = None + self._lsp: Any = None + self._error: Exception | None = None + + async def start(self) -> None: + self._task = asyncio.create_task(self._run(), name=f"lsp-{self.language}") + try: + await asyncio.wait_for(asyncio.shield(self._ready.wait()), timeout=60) + except TimeoutError: + raise TimeoutError(f"LSP server for '{self.language}' did not start within 60s") + if self._error: + raise self._error + + async def _run(self) -> None: + try: + from multilspy import LanguageServer # core dep — always available + from multilspy.multilspy_config import MultilspyConfig + from multilspy.multilspy_logger import MultilspyLogger + + config = MultilspyConfig.from_dict({"code_language": self.language}) + lsp_logger = MultilspyLogger() + self._lsp = LanguageServer.create(config, lsp_logger, self._workspace_root) + async with self._lsp.start_server(): + self._ready.set() + await self._stop.wait() + except Exception as e: + self._error = e + self._ready.set() # unblock any waiters + logger.error("[LSPService] %s server error: %s", self.language, e) + + async def stop(self) -> None: + self._stop.set() + if self._task and not self._task.done(): + try: + await asyncio.wait_for(self._task, timeout=5) + except (TimeoutError, asyncio.CancelledError): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + # ── request methods ─────────────────────────────────────────────── + + async def request_definition(self, rel_path: str, line: int, col: int) -> list: + try: + return await self._lsp.request_definition(rel_path, line, col) or [] + except AssertionError: + return [] # multilspy asserts on None response (no definition found) + + async def request_references(self, rel_path: str, line: int, col: int) -> list: + try: + return await self._lsp.request_references(rel_path, line, col) or [] + except AssertionError: + return [] + + async def request_hover(self, rel_path: str, line: int, col: int) -> Any: + try: + return await self._lsp.request_hover(rel_path, line, col) + except AssertionError: + return None + + async def request_document_symbols(self, rel_path: str) -> list: + try: + symbols, _ = await self._lsp.request_document_symbols(rel_path) + return symbols or [] + except AssertionError: + return [] + + async def request_workspace_symbol(self, query: str) -> list: + return await self._lsp.request_workspace_symbol(query) or [] + + # ── advanced ops (direct server.send, for servers that support them) ── + + async def request_implementation(self, rel_path: str, line: int, col: int) -> list: + abs_uri = Path(self._workspace_root, rel_path).as_uri() + with self._lsp.open_file(rel_path): + response = await self._lsp.server.send.implementation( + {"textDocument": {"uri": abs_uri}, "position": {"line": line, "character": col}} + ) + if not response: + return [] + if isinstance(response, dict): + response = [response] + out = [] + for item in response: + if "uri" in item and "range" in item: + item.setdefault("absolutePath", item["uri"].replace("file://", "")) + out.append(item) + elif "targetUri" in item: + out.append( + { + "uri": item["targetUri"], + "absolutePath": item["targetUri"].replace("file://", ""), + "range": item.get("targetSelectionRange", item.get("targetRange", {})), + } + ) + return out + + async def request_prepare_call_hierarchy(self, rel_path: str, line: int, col: int) -> list: + abs_uri = Path(self._workspace_root, rel_path).as_uri() + with self._lsp.open_file(rel_path): + response = await self._lsp.server.send.prepare_call_hierarchy( + {"textDocument": {"uri": abs_uri}, "position": {"line": line, "character": col}} + ) + return response or [] + + async def request_incoming_calls(self, item: dict) -> list: + response = await self._lsp.server.send.incoming_calls({"item": item}) + return response or [] + + async def request_outgoing_calls(self, item: dict) -> list: + response = await self._lsp.server.send.outgoing_calls({"item": item}) + return response or [] + + +class _LSPSessionPool: + """Process-level singleton managing LSP sessions across all agent instances. + + Sessions are keyed by (language, workspace_root) and survive agent restarts. + Call close_all() once at process exit (e.g. from backend lifespan shutdown). + """ + + def __init__(self) -> None: + # (language, workspace_root) → _LSPSession + self._sessions: dict[tuple[str, str], _LSPSession] = {} + # workspace_root → _PyrightSession + self._pyright: dict[str, _PyrightSession] = {} + # In-flight start tasks to prevent duplicate starts under concurrent requests + self._starting: dict[tuple[str, str], asyncio.Task] = {} + self._starting_pyright: dict[str, asyncio.Task] = {} + + async def get_session(self, language: str, workspace_root: str) -> _LSPSession: + key = (language, workspace_root) + if key in self._sessions: + return self._sessions[key] + if key not in self._starting: + + async def _start() -> _LSPSession: + logger.info("[LSPPool] starting %s language server (workspace=%s)...", language, workspace_root) + s = _LSPSession(language, workspace_root) + await s.start() + self._sessions[key] = s + self._starting.pop(key, None) + logger.info("[LSPPool] %s language server ready", language) + return s + + self._starting[key] = asyncio.create_task(_start(), name=f"lsp-start-{language}") + return await self._starting[key] + + async def get_pyright(self, workspace_root: str) -> _PyrightSession: + if workspace_root in self._pyright: + return self._pyright[workspace_root] + if workspace_root not in self._starting_pyright: + + async def _start() -> _PyrightSession: + logger.info("[LSPPool] starting pyright (workspace=%s)...", workspace_root) + s = _PyrightSession(workspace_root) + await s.start() + self._pyright[workspace_root] = s + self._starting_pyright.pop(workspace_root, None) + logger.info("[LSPPool] pyright ready") + return s + + self._starting_pyright[workspace_root] = asyncio.create_task(_start(), name="lsp-start-pyright") + return await self._starting_pyright[workspace_root] + + async def close_all(self) -> None: + """Stop all running language server processes. Call once at process exit.""" + for (lang, ws), session in list(self._sessions.items()): + try: + await session.stop() + logger.debug("[LSPPool] stopped %s server (workspace=%s)", lang, ws) + except Exception as e: + logger.debug("[LSPPool] error stopping %s: %s", lang, e) + self._sessions.clear() + for ws, session in list(self._pyright.items()): + try: + await session.stop() + logger.debug("[LSPPool] stopped pyright (workspace=%s)", ws) + except Exception as e: + logger.debug("[LSPPool] error stopping pyright: %s", e) + self._pyright.clear() + + +# Process-level singleton — import and use directly +lsp_pool = _LSPSessionPool() + + +class LSPService: + """Registers the LSP tool (DEFERRED) into ToolRegistry. + + Delegates all session management to the process-level lsp_pool singleton. + Language servers start lazily on first use and persist across agent restarts. + """ + + # Operations that Jedi doesn't support — routed to pyright for Python, + # or to the native server.send.* for other languages. + _ADVANCED_OPS: frozenset[str] = frozenset({"goToImplementation", "prepareCallHierarchy", "incomingCalls", "outgoingCalls"}) + + def __init__(self, registry: ToolRegistry, workspace_root: str | Path) -> None: + self._workspace_root = str(Path(workspace_root).resolve()) + registry.register( + ToolEntry( + name="LSP", + mode=ToolMode.DEFERRED, + schema=LSP_SCHEMA, + handler=self._handle, + source="LSPService", + search_hint="language server definition references hover symbols go-to", + is_read_only=True, + is_concurrency_safe=True, + ) + ) + logger.debug("[LSPService] registered (workspace=%s)", self._workspace_root) + + # ── session management (delegates to process-level pool) ────────── + + async def _get_session(self, language: str) -> _LSPSession: + return await lsp_pool.get_session(language, self._workspace_root) + + async def _get_pyright(self) -> _PyrightSession: + return await lsp_pool.get_pyright(self._workspace_root) + + def _detect_language(self, file_path: str) -> str | None: + return _EXT_TO_LANG.get(Path(file_path).suffix.lower()) + + def _to_relative(self, file_path: str) -> str: + try: + return str(Path(file_path).relative_to(self._workspace_root)) + except ValueError: + return file_path # fallback: pass as-is + + # ── pre-flight checks ───────────────────────────────────────────── + + @staticmethod + def _check_file(file_path: str) -> str | None: + """Return error string if file exceeds 10 MB limit, else None.""" + try: + size = Path(file_path).stat().st_size + except OSError: + return None # let LSP handle missing file errors + if size > _FILE_SIZE_LIMIT: + mb = size / (1024 * 1024) + return f"File too large ({mb:.1f} MB). LSP file size limit is 10 MB." + return None + + def _filter_gitignored(self, locations: list) -> list: + """Filter out locations inside gitignored paths (batches of 50, like CC).""" + if not locations: + return locations + abs_paths = [loc.get("absolutePath") or loc.get("uri", "").replace("file://", "") for loc in locations] + try: + # git check-ignore exits 0 if any path is ignored, 1 if none are + result = subprocess.run( + ["git", "check-ignore", "--stdin", "-z"], + input="\0".join(abs_paths), + capture_output=True, + text=True, + cwd=self._workspace_root, + timeout=5, + ) + ignored = set(result.stdout.split("\0")) if result.stdout else set() + except Exception: + return locations # on error, return all (fail-open) + return [loc for loc, p in zip(locations, abs_paths) if p not in ignored] + + def _filter_gitignored_batched(self, locations: list) -> list: + """Run _filter_gitignored in batches of 50 (matches CC batch size).""" + out = [] + for i in range(0, len(locations), 50): + out.extend(self._filter_gitignored(locations[i : i + 50])) + return out + + async def _filter_gitignored_batched_async(self, locations: list) -> list: + return await asyncio.to_thread(self._filter_gitignored_batched, locations) + + # ── output formatters ───────────────────────────────────────────── + + @staticmethod + def _fmt_location(loc: Any) -> dict: + start = loc.get("range", {}).get("start", {}) + return { + "file": loc.get("absolutePath") or loc.get("uri", ""), + "line": start.get("line", 0), + "column": start.get("character", 0), + } + + @staticmethod + def _fmt_hover(result: Any) -> str: + contents = result.get("contents", "") + if isinstance(contents, dict): + return contents.get("value", str(contents)) + if isinstance(contents, list): + parts = [] + for c in contents: + parts.append(c.get("value", str(c)) if isinstance(c, dict) else str(c)) + return "\n".join(parts) + return str(contents) + + @staticmethod + def _fmt_symbol(sym: Any) -> dict: + loc = sym.get("location") or {} + if loc: + # SymbolInformation (workspaceSymbol) — location.uri + location.range + start = loc.get("range", {}).get("start", {}) + uri = loc.get("uri", "") + file = loc.get("absolutePath") or (uri.replace("file://", "") if uri.startswith("file://") else uri) + else: + # DocumentSymbol (documentSymbol) — range/selectionRange at top level, no file + start = sym.get("selectionRange", sym.get("range", {})).get("start", {}) + file = "" + return { + "name": sym.get("name", ""), + "kind": sym.get("kind"), + "file": file, + "line": start.get("line"), + } + + @staticmethod + def _fmt_call_hierarchy_item(item: Any) -> dict: + uri = item.get("uri", "") + start = item.get("range", {}).get("start", {}) + return { + "name": item.get("name", ""), + "kind": item.get("kind"), + "file": uri.replace("file://", "") if uri.startswith("file://") else uri, + "line": start.get("line"), + "item": item, # pass-through for incomingCalls/outgoingCalls + } + + @staticmethod + def _fmt_call_hierarchy_call(call: Any, direction: str) -> dict: + item_key = "from" if direction == "incoming" else "to" + caller = call.get(item_key, {}) + uri = caller.get("uri", "") + start = caller.get("range", {}).get("start", {}) + ranges = [r.get("start", {}) for r in call.get(f"{item_key}Ranges", [])] + return { + "name": caller.get("name", ""), + "kind": caller.get("kind"), + "file": uri.replace("file://", "") if uri.startswith("file://") else uri, + "line": start.get("line"), + "call_sites": [{"line": r.get("line"), "column": r.get("character")} for r in ranges], + "item": caller, # pass-through for chaining + } + + # ── tool handler ────────────────────────────────────────────────── + + async def _handle( + self, + operation: str, + file_path: str | None = None, + line: int | None = None, + character: int | None = None, + query: str | None = None, + language: str | None = None, + item: dict | None = None, + ) -> str: + # Resolve language (incomingCalls/outgoingCalls carry language in item["uri"]) + lang = language + if not lang and file_path: + lang = self._detect_language(file_path) + if not lang and operation in ("incomingCalls", "outgoingCalls") and item: + uri = item.get("uri", "") + lang = self._detect_language(uri) + if not lang: + supported = ", ".join(sorted(set(_EXT_TO_LANG.values()))) + return f"Cannot detect language. Set 'language' parameter. Supported: {supported}" + + # 10 MB file size guard (matches CC LSP limit) + if file_path: + err = self._check_file(file_path) + if err: + return err + + # Python advanced ops → pyright; other languages → multilspy server.send.* + use_pyright = lang == "python" and operation in self._ADVANCED_OPS + + pyright: _PyrightSession | None = None + session: _LSPSession | None = None + + if use_pyright: + try: + pyright = await self._get_pyright() + except Exception as e: + return f"Failed to start pyright language server: {e}" + else: + try: + session = await self._get_session(lang) + except Exception as e: + return f"Failed to start {lang} language server: {e}" + + rel = self._to_relative(file_path) if file_path else "" + # @@@dt-04-lsp-position-contract - CC exposes editor-facing 1-based + # positions and converts at the tool boundary. Leon must do the same + # or every position-aware operation silently lands one symbol off. + zero_line = line - 1 if line is not None else None + zero_character = character - 1 if character is not None else None + + try: + if operation == "goToDefinition": + if not file_path or zero_line is None or zero_character is None: + return "goToDefinition requires: file_path, line, character" + assert session is not None + results = await session.request_definition(rel, zero_line, zero_character) + results = await self._filter_gitignored_batched_async(results) + if not results: + return "No definition found." + return json.dumps([self._fmt_location(r) for r in results], indent=2) + + elif operation == "findReferences": + if not file_path or zero_line is None or zero_character is None: + return "findReferences requires: file_path, line, character" + assert session is not None + results = await session.request_references(rel, zero_line, zero_character) + results = await self._filter_gitignored_batched_async(results) + if not results: + return "No references found." + return json.dumps([self._fmt_location(r) for r in results], indent=2) + + elif operation == "hover": + if not file_path or zero_line is None or zero_character is None: + return "hover requires: file_path, line, character" + assert session is not None + result = await session.request_hover(rel, zero_line, zero_character) + if not result: + return "No hover info." + return self._fmt_hover(result) + + elif operation == "documentSymbol": + if not file_path: + return "documentSymbol requires: file_path" + assert session is not None + symbols = await session.request_document_symbols(rel) + if not symbols: + return "No symbols found." + return json.dumps([self._fmt_symbol(s) for s in symbols], indent=2) + + elif operation == "workspaceSymbol": + if not query: + return "workspaceSymbol requires: query" + assert session is not None + symbols = await session.request_workspace_symbol(query) + if not symbols: + return f"No symbols matching '{query}'." + return json.dumps([self._fmt_symbol(s) for s in symbols], indent=2) + + elif operation == "goToImplementation": + if not file_path or zero_line is None or zero_character is None: + return "goToImplementation requires: file_path, line, character" + src = pyright if use_pyright else session + assert src is not None + results = await src.request_implementation(rel, zero_line, zero_character) + results = await self._filter_gitignored_batched_async(results) + if not results: + return "No implementation found." + return json.dumps([self._fmt_location(r) for r in results], indent=2) + + elif operation == "prepareCallHierarchy": + if not file_path or zero_line is None or zero_character is None: + return "prepareCallHierarchy requires: file_path, line, character" + src = pyright if use_pyright else session + assert src is not None + items = await src.request_prepare_call_hierarchy(rel, zero_line, zero_character) + if not items: + return "No call hierarchy items found." + return json.dumps([self._fmt_call_hierarchy_item(i) for i in items], indent=2) + + elif operation == "incomingCalls": + if not item: + return "incomingCalls requires: item (CallHierarchyItem from prepareCallHierarchy)" + src = pyright if use_pyright else session + assert src is not None + calls = await src.request_incoming_calls(item) + if not calls: + return "No incoming calls found." + return json.dumps([self._fmt_call_hierarchy_call(c, "incoming") for c in calls], indent=2) + + elif operation == "outgoingCalls": + if not item: + return "outgoingCalls requires: item (CallHierarchyItem from prepareCallHierarchy)" + src = pyright if use_pyright else session + assert src is not None + calls = await src.request_outgoing_calls(item) + if not calls: + return "No outgoing calls found." + return json.dumps([self._fmt_call_hierarchy_call(c, "outgoing") for c in calls], indent=2) + + else: + return ( + f"Unknown operation '{operation}'. " + "Valid: goToDefinition, findReferences, hover, documentSymbol, workspaceSymbol, " + "goToImplementation, prepareCallHierarchy, incomingCalls, outgoingCalls" + ) + + except Exception as e: + logger.exception("[LSPService] operation=%s failed", operation) + return f"LSP error: {e}" diff --git a/core/tools/mcp_resources/service.py b/core/tools/mcp_resources/service.py new file mode 100644 index 000000000..bf44c2cbc --- /dev/null +++ b/core/tools/mcp_resources/service.py @@ -0,0 +1,155 @@ +"""Expose MCP resource discovery and reading as agent-callable deferred tools.""" + +from __future__ import annotations + +import base64 +import json +from collections.abc import Callable +from typing import Any + +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema + +LIST_MCP_RESOURCES_SCHEMA = make_tool_schema( + name="ListMcpResources", + description="List MCP resources exposed by connected MCP servers.", + properties={ + "server": { + "type": "string", + "description": "Optional MCP server name to filter by.", + "minLength": 1, + } + }, +) + +READ_MCP_RESOURCE_SCHEMA = make_tool_schema( + name="ReadMcpResource", + description="Read a specific MCP resource by server name and URI.", + properties={ + "server": { + "type": "string", + "description": "MCP server name.", + "minLength": 1, + }, + "uri": { + "type": "string", + "description": "Resource URI to read.", + "minLength": 1, + }, + }, + required=["server", "uri"], +) + + +class McpResourceToolService: + def __init__( + self, + *, + registry: ToolRegistry, + client_fn: Callable[[], Any | None], + server_configs_fn: Callable[[], dict[str, Any]], + ) -> None: + self._client_fn = client_fn + self._server_configs_fn = server_configs_fn + if not self._server_configs_fn(): + return + self._register(registry) + + def _register(self, registry: ToolRegistry) -> None: + for name, schema, handler in [ + ("ListMcpResources", LIST_MCP_RESOURCES_SCHEMA, self._list_resources), + ("ReadMcpResource", READ_MCP_RESOURCE_SCHEMA, self._read_resource), + ]: + registry.register( + ToolEntry( + name=name, + mode=ToolMode.DEFERRED, + schema=schema, + handler=handler, + source="McpResourceToolService", + is_concurrency_safe=True, + is_read_only=True, + ) + ) + + def _get_client(self) -> Any: + client = self._client_fn() + if client is None: + raise ValueError("MCP client is not initialized") + return client + + def _available_servers(self) -> list[str]: + return list(self._server_configs_fn().keys()) + + @staticmethod + def _stringify_uri(value: Any) -> str | None: + if value is None: + return None + return str(value) + + async def _list_resources(self, server: str | None = None, **_kwargs: Any) -> str: + client = self._get_client() + server_names = [server] if server else self._available_servers() + if server and server not in self._available_servers(): + raise ValueError(f'MCP server not found: "{server}"') + + items: list[dict[str, Any]] = [] + for server_name in server_names: + async with client.session(server_name) as session: + result = await session.list_resources() + for resource in result.resources: + items.append( + { + "server": server_name, + "uri": self._stringify_uri(resource.uri), + "name": getattr(resource, "name", self._stringify_uri(resource.uri)), + "mime_type": getattr(resource, "mimeType", None), + "description": getattr(resource, "description", None), + } + ) + return json.dumps({"items": items, "total": len(items)}, ensure_ascii=False, indent=2) + + async def _read_resource(self, *, server: str, uri: str, **_kwargs: Any) -> str: + client = self._get_client() + if server not in self._available_servers(): + raise ValueError(f'MCP server not found: "{server}"') + + async with client.session(server) as session: + result = await session.read_resource(uri) + + contents: list[dict[str, Any]] = [] + for content in result.contents: + if hasattr(content, "text"): + contents.append( + { + "uri": self._stringify_uri(content.uri), + "mime_type": getattr(content, "mimeType", None), + "text": content.text, + } + ) + continue + if hasattr(content, "blob"): + blob_size = len(base64.b64decode(content.blob)) + contents.append( + { + "uri": self._stringify_uri(content.uri), + "mime_type": getattr(content, "mimeType", None), + "text": f"Binary MCP resource omitted from context ({blob_size} bytes).", + } + ) + continue + contents.append( + { + "uri": self._stringify_uri(getattr(content, "uri", uri)), + "mime_type": getattr(content, "mimeType", None), + } + ) + + return json.dumps( + { + "server": server, + "uri": uri, + "contents": contents, + }, + ensure_ascii=False, + indent=2, + ) diff --git a/core/tools/search/service.py b/core/tools/search/service.py index 4329de6e4..a6ff0a4d4 100644 --- a/core/tools/search/service.py +++ b/core/tools/search/service.py @@ -12,11 +12,16 @@ import subprocess from pathlib import Path -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema DEFAULT_EXCLUDES: list[str] = [ "node_modules", ".git", + ".svn", + ".hg", + ".bzr", + ".jj", + ".sl", "__pycache__", ".venv", "venv", @@ -50,67 +55,76 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Grep", mode=ToolMode.INLINE, - schema={ - "name": "Grep", - "description": "Search file contents using regex patterns.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Regex pattern to search for", - }, - "path": { - "type": "string", - "description": "File or directory (absolute). Defaults to workspace.", - }, - "glob": { - "type": "string", - "description": "Filter files by glob (e.g., '*.py')", - }, - "type": { - "type": "string", - "description": "Filter by file type (e.g., 'py', 'js')", - }, - "case_insensitive": { - "type": "boolean", - "description": "Case insensitive search", - }, - "after_context": { - "type": "integer", - "description": "Lines to show after each match", - }, - "before_context": { - "type": "integer", - "description": "Lines to show before each match", - }, - "context": { - "type": "integer", - "description": "Context lines before and after each match", - }, - "output_mode": { - "type": "string", - "enum": ["content", "files_with_matches", "count"], - "description": "Output format. Default: files_with_matches", - }, - "head_limit": { - "type": "integer", - "description": "Limit to first N entries", - }, - "offset": { - "type": "integer", - "description": "Skip first N entries", - }, - "multiline": { - "type": "boolean", - "description": "Allow pattern to span multiple lines", - }, + schema=make_tool_schema( + name="Grep", + description=( + "Regex search across files (ripgrep-based). " + "Default output_mode: files_with_matches (sorted by mtime). Default head_limit: 250 entries. " + "Auto-excludes .git/.svn/.hg dirs. Max column width 500 chars (suppresses minified/base64). " + "Use output_mode='content' with after_context/before_context/context for context lines." + ), + properties={ + "pattern": { + "type": "string", + "description": "Regex pattern to search for", + }, + "path": { + "type": "string", + "description": "File or directory (absolute). Defaults to workspace.", + }, + "glob": { + "type": "string", + "description": "Filter files by glob (e.g., '*.py')", + }, + "type": { + "type": "string", + "description": "Filter by file type (e.g., 'py', 'js')", + }, + "case_insensitive": { + "type": "boolean", + "description": "Case insensitive search", + }, + "after_context": { + "type": "integer", + "description": "Lines to show after each match", + }, + "before_context": { + "type": "integer", + "description": "Lines to show before each match", + }, + "context": { + "type": "integer", + "description": "Context lines before and after each match", + }, + "output_mode": { + "type": "string", + "enum": ["content", "files_with_matches", "count"], + "description": "Output format. Default: files_with_matches", + }, + "head_limit": { + "type": "integer", + "description": "Limit to first N entries", + }, + "offset": { + "type": "integer", + "description": "Skip first N entries", + }, + "multiline": { + "type": "boolean", + "description": "Allow pattern to span multiple lines", + }, + "line_numbers": { + "type": "boolean", + "description": "Show line numbers (default true). Only applies with output_mode='content'.", }, - "required": ["pattern"], }, - }, + required=["pattern"], + ), handler=self._grep, source="SearchService", + search_hint="search file contents regex pattern matching ripgrep", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -118,26 +132,30 @@ def _register(self, registry: ToolRegistry) -> None: ToolEntry( name="Glob", mode=ToolMode.INLINE, - schema={ - "name": "Glob", - "description": "Find files by glob pattern. Returns paths sorted by modification time.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Glob pattern (e.g., '**/*.py')", - }, - "path": { - "type": "string", - "description": "Directory to search (absolute). Defaults to workspace.", - }, + schema=make_tool_schema( + name="Glob", + description=( + "Fast file pattern matching (ripgrep-based). Returns paths sorted by modification time. " + "Includes hidden files, ignores .gitignore. Default limit 100 results. " + "Use '**/*.py' for recursive search. Path must be absolute." + ), + properties={ + "pattern": { + "type": "string", + "description": "Glob pattern (e.g., '**/*.py')", + }, + "path": { + "type": "string", + "description": "Directory to search (absolute). Defaults to workspace.", }, - "required": ["pattern"], }, - }, + required=["pattern"], + ), handler=self._glob, source="SearchService", + search_hint="find files by name glob pattern matching", + is_read_only=True, + is_concurrency_safe=True, ) ) @@ -183,9 +201,10 @@ def _grep( before_context: int | None = None, context: int | None = None, output_mode: str = "files_with_matches", - head_limit: int | None = None, + head_limit: int | None = 250, offset: int | None = None, multiline: bool = False, + line_numbers: bool = True, ) -> str: ok, error, resolved = self._validate_path(path) if not ok: @@ -209,6 +228,7 @@ def _grep( head_limit=head_limit, offset=offset, multiline=multiline, + line_numbers=line_numbers, ) except Exception: pass # fallback to Python @@ -238,8 +258,9 @@ def _ripgrep_search( head_limit: int | None, offset: int | None, multiline: bool, + line_numbers: bool = True, ) -> str: - cmd: list[str] = ["rg", pattern, str(path)] + cmd: list[str] = ["rg", pattern, str(path), "--max-columns", "500"] for excl in DEFAULT_EXCLUDES: cmd.extend(["--glob", f"!{excl}"]) @@ -258,7 +279,8 @@ def _ripgrep_search( elif output_mode == "count": cmd.append("--count") elif output_mode == "content": - cmd.extend(["--line-number", "--no-heading"]) + ln_flag = "--line-number" if line_numbers else "--no-line-number" + cmd.extend([ln_flag, "--no-heading"]) if context is not None: cmd.extend(["-C", str(context)]) else: diff --git a/core/tools/skills/service.py b/core/tools/skills/service.py index e65215a20..17c0b842a 100644 --- a/core/tools/skills/service.py +++ b/core/tools/skills/service.py @@ -9,9 +9,10 @@ from __future__ import annotations import re +from collections.abc import Sequence from pathlib import Path -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema class SkillsService: @@ -20,7 +21,7 @@ class SkillsService: def __init__( self, registry: ToolRegistry, - skill_paths: list[str | Path], + skill_paths: Sequence[str | Path], enabled_skills: dict[str, bool] | None = None, ): self.skill_paths = [Path(p).expanduser().resolve() for p in skill_paths] @@ -65,6 +66,8 @@ def _register(self, registry: ToolRegistry) -> None: schema=self._get_schema, handler=self._load_skill, source="SkillsService", + is_concurrency_safe=True, + is_read_only=True, ) ) @@ -72,24 +75,22 @@ def _get_schema(self) -> dict: available_skills = list(self._skills_index.keys()) skills_list = "\n".join(f"- {name}" for name in available_skills) - return { - "name": "load_skill", - "description": ( - f"Load a specialized skill to access domain-specific knowledge and workflows.\n\n" - f"Available skills:\n{skills_list}\n\n" - f"Returns the skill's instructions and context." + return make_tool_schema( + name="load_skill", + description=( + f"Load a skill for domain-specific guidance. " + f"Use when you need specialized workflows (TDD, debugging, git). " + f"Skills are loaded on-demand to save context.\n\n" + f"Available skills:\n{skills_list}" ), - "parameters": { - "type": "object", - "properties": { - "skill_name": { - "type": "string", - "description": f"Name of the skill to load. Available: {', '.join(self._skills_index.keys())}", - }, + properties={ + "skill_name": { + "type": "string", + "description": f"Name of the skill to load. Available: {', '.join(self._skills_index.keys())}", }, - "required": ["skill_name"], }, - } + required=["skill_name"], + ) def _load_skill(self, skill_name: str) -> str: if skill_name not in self._skills_index: diff --git a/core/tools/task/service.py b/core/tools/task/service.py index b6e9f6f96..114b2939d 100644 --- a/core/tools/task/service.py +++ b/core/tools/task/service.py @@ -13,117 +13,109 @@ from typing import Any from backend.web.core.storage_factory import make_tool_task_repo -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema from core.tools.task.types import Task, TaskStatus logger = logging.getLogger(__name__) DEFAULT_DB_PATH = Path.home() / ".leon" / "tasks.db" -TASK_CREATE_SCHEMA = { - "name": "TaskCreate", - "description": ("Create a new task to track work progress. Tasks are created with status 'pending'."), - "parameters": { - "type": "object", - "properties": { - "subject": { - "type": "string", - "description": "Brief task title in imperative form", - }, - "description": { - "type": "string", - "description": "Detailed description of what needs to be done", - }, - "active_form": { - "type": "string", - "description": "Present continuous form for spinner display", - }, - "metadata": { - "type": "object", - "description": "Optional metadata to attach to the task", - }, +TASK_CREATE_SCHEMA = make_tool_schema( + name="TaskCreate", + description=( + "Create a task to track multi-step work. " + "Use for complex tasks with 3+ steps or when managing multiple parallel workstreams. " + "Status starts as 'pending'." + ), + properties={ + "subject": { + "type": "string", + "description": "Brief task title in imperative form", }, - "required": ["subject", "description"], - }, -} - -TASK_GET_SCHEMA = { - "name": "TaskGet", - "description": "Get full details of a task including description and dependencies.", - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to retrieve", - }, + "description": { + "type": "string", + "description": "Detailed description of what needs to be done", + }, + "active_form": { + "type": "string", + "description": "Present continuous form for spinner display", + }, + "metadata": { + "type": "object", + "description": "Optional metadata to attach to the task", }, - "required": ["task_id"], }, -} - -TASK_LIST_SCHEMA = { - "name": "TaskList", - "description": ("List all tasks with summary info: id, subject, status, owner, blockedBy."), - "parameters": { - "type": "object", - "properties": {}, + required=["subject", "description"], +) + +TASK_GET_SCHEMA = make_tool_schema( + name="TaskGet", + description="Get full details of a task including description and dependencies.", + properties={ + "task_id": { + "type": "string", + "description": "The task ID to retrieve", + }, }, -} - -TASK_UPDATE_SCHEMA = { - "name": "TaskUpdate", - "description": ( + required=["task_id"], +) + +TASK_LIST_SCHEMA = make_tool_schema( + name="TaskList", + description="List all tasks with summary info: id, subject, status, owner, blockedBy.", + properties={}, +) + +TASK_UPDATE_SCHEMA = make_tool_schema( + name="TaskUpdate", + description=( "Update a task's status, dependencies, or other fields. " "Status flow: pending -> in_progress -> completed. " "Use status='deleted' to remove a task." ), - "parameters": { - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to update", - }, - "status": { - "type": "string", - "enum": ["pending", "in_progress", "completed", "deleted"], - "description": "New status for the task", - }, - "subject": { - "type": "string", - "description": "New subject for the task", - }, - "description": { - "type": "string", - "description": "New description for the task", - }, - "active_form": { - "type": "string", - "description": "New activeForm for the task", - }, - "owner": { - "type": "string", - "description": "Assign task to an agent", - }, - "add_blocks": { - "type": "array", - "items": {"type": "string"}, - "description": "Task IDs that this task blocks", - }, - "add_blocked_by": { - "type": "array", - "items": {"type": "string"}, - "description": "Task IDs that block this task", - }, - "metadata": { - "type": "object", - "description": "Metadata keys to merge (set key to null to delete)", - }, + properties={ + "task_id": { + "type": "string", + "description": "The task ID to update", + }, + "status": { + "type": "string", + "enum": ["pending", "in_progress", "completed", "deleted"], + "description": "New status for the task", + }, + "subject": { + "type": "string", + "description": "New subject for the task", + }, + "description": { + "type": "string", + "description": "New description for the task", + }, + "active_form": { + "type": "string", + "description": "New activeForm for the task", + }, + "owner": { + "type": "string", + "description": "Assign task to an agent", + }, + "add_blocks": { + "type": "array", + "items": {"type": "string"}, + "description": "Task IDs that this task blocks", + }, + "add_blocked_by": { + "type": "array", + "items": {"type": "string"}, + "description": "Task IDs that block this task", + }, + "metadata": { + "type": "object", + "description": "Metadata keys to merge (set key to null to delete)", }, - "required": ["task_id"], }, -} + required=["task_id"], +) class TaskService: @@ -139,7 +131,7 @@ class TaskService: def __init__( self, registry: ToolRegistry, - workspace_root: str | None = None, + workspace_root: str | Path | None = None, db_path: Path | None = None, thread_id: str | None = None, ): @@ -157,12 +149,14 @@ def _get_thread_id(self) -> str: return tid or "default" def _register(self, registry: ToolRegistry) -> None: + read_only = {"TaskGet", "TaskList"} for name, schema, handler in [ ("TaskCreate", TASK_CREATE_SCHEMA, self._create), ("TaskGet", TASK_GET_SCHEMA, self._get), ("TaskList", TASK_LIST_SCHEMA, self._list), ("TaskUpdate", TASK_UPDATE_SCHEMA, self._update), ]: + ro = name in read_only registry.register( ToolEntry( name=name, @@ -170,6 +164,8 @@ def _register(self, registry: ToolRegistry) -> None: schema=schema, handler=handler, source="TaskService", + is_concurrency_safe=ro, + is_read_only=ro, ) ) diff --git a/core/tools/tool_search/service.py b/core/tools/tool_search/service.py index 9b5ceba77..234007182 100644 --- a/core/tools/tool_search/service.py +++ b/core/tools/tool_search/service.py @@ -9,24 +9,26 @@ import json import logging -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema logger = logging.getLogger(__name__) -TOOL_SEARCH_SCHEMA = { - "name": "tool_search", - "description": ("Search for available tools. Use this to discover tools that might help with your task."), - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query - tool name or description of what you want to do", - }, +TOOL_SEARCH_SCHEMA = make_tool_schema( + name="tool_search", + description=( + "Search for available deferred tools by name or keyword. " + "Use 'select:ToolA,ToolB' for exact deferred-tool lookup (returns full schema). " + "Use keywords for fuzzy search (up to 5 results). " + "Deferred tools are only usable after discovery via this tool." + ), + properties={ + "query": { + "type": "string", + "description": "Search query. Use 'select:ToolA,ToolB' for exact deferred-tool lookup, or keywords for fuzzy search.", }, - "required": ["query"], }, -} + required=["query"], +) class ToolSearchService: @@ -41,11 +43,34 @@ def __init__(self, registry: ToolRegistry): schema=TOOL_SEARCH_SCHEMA, handler=self._search, source="ToolSearchService", + is_concurrency_safe=True, + is_read_only=True, ) ) logger.info("ToolSearchService initialized") - def _search(self, query: str = "", **kwargs) -> str: - results = self._registry.search(query) + def _search(self, query: str = "", tool_context=None, **kwargs) -> str: + select_names: list[str] = [] + normalized = query.strip() + if normalized.lower().startswith("select:"): + select_names = [name.strip() for name in normalized[len("select:") :].split(",") if name.strip()] + + results = self._registry.search(query, modes={ToolMode.DEFERRED}) + if select_names: + found_names = {entry.name for entry in results} + missing = [name for name in select_names if name not in found_names] + inline = [name for name in missing if (entry := self._registry.get(name)) is not None and entry.mode == ToolMode.INLINE] + unknown = [name for name in missing if self._registry.get(name) is None] + if inline or unknown: + parts: list[str] = [] + if inline: + parts.append(f"inline/already-available tools: {', '.join(inline)}") + if unknown: + parts.append(f"unknown tools: {', '.join(unknown)}") + raise ValueError("tool_search select: only supports deferred tools; " + "; ".join(parts)) + else: + results = results[:5] + if tool_context is not None and hasattr(tool_context, "discovered_tool_names"): + tool_context.discovered_tool_names.update(entry.name for entry in results) schemas = [e.get_schema() for e in results] return json.dumps(schemas, indent=2, ensure_ascii=False) diff --git a/core/tools/web/middleware.py b/core/tools/web/middleware.py index fedf1708e..f244a5bfb 100644 --- a/core/tools/web/middleware.py +++ b/core/tools/web/middleware.py @@ -103,8 +103,8 @@ async def _web_search_impl( self, Query: str, MaxResults: int | None = None, - IncludeDomains: list[str] | None = None, - ExcludeDomains: list[str] | None = None, + AllowedDomains: list[str] | None = None, + BlockedDomains: list[str] | None = None, ) -> SearchResult: """ 实现 web_search(多提供商降级) @@ -121,8 +121,8 @@ async def _web_search_impl( result = await searcher.search( query=Query, max_results=max_results, - include_domains=IncludeDomains, - exclude_domains=ExcludeDomains, + include_domains=AllowedDomains, + exclude_domains=BlockedDomains, ) if not result.error: return result @@ -217,12 +217,12 @@ def _get_tool_definitions(self) -> list[dict]: "type": "integer", "description": "Maximum number of results (default: 5)", }, - "IncludeDomains": { + "AllowedDomains": { "type": "array", "items": {"type": "string"}, "description": "Only include results from these domains", }, - "ExcludeDomains": { + "BlockedDomains": { "type": "array", "items": {"type": "string"}, "description": "Exclude results from these domains", @@ -281,8 +281,8 @@ async def _handle_tool_call(self, tool_name: str, args: dict, tool_call_id: str) result = await self._web_search_impl( Query=args.get("Query", ""), MaxResults=args.get("MaxResults"), - IncludeDomains=args.get("IncludeDomains"), - ExcludeDomains=args.get("ExcludeDomains"), + AllowedDomains=args.get("AllowedDomains"), + BlockedDomains=args.get("BlockedDomains"), ) return ToolMessage(content=result.format_output(), tool_call_id=tool_call_id) diff --git a/core/tools/web/service.py b/core/tools/web/service.py index 077db9b70..02d2f12e8 100644 --- a/core/tools/web/service.py +++ b/core/tools/web/service.py @@ -10,7 +10,7 @@ import asyncio from typing import Any -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry +from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry, make_tool_schema from core.tools.web.fetchers.jina import JinaFetcher from core.tools.web.fetchers.markdownify import MarkdownifyFetcher from core.tools.web.searchers.exa import ExaSearcher @@ -59,64 +59,74 @@ def _register(self, registry: ToolRegistry) -> None: registry.register( ToolEntry( name="WebSearch", - mode=ToolMode.INLINE, - schema={ - "name": "WebSearch", - "description": "Search the web for current information. Returns titles, URLs, and snippets.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query", - }, - "max_results": { - "type": "integer", - "description": "Maximum number of results (default: 5)", - }, - "include_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Only include results from these domains", - }, - "exclude_domains": { - "type": "array", - "items": {"type": "string"}, - "description": "Exclude results from these domains", - }, + mode=ToolMode.DEFERRED, + schema=make_tool_schema( + name="WebSearch", + description=( + "Search the web. Returns titles, URLs, and text snippets. " + "Use for current events, documentation lookups, or fact-checking. Max 10 results per query." + ), + properties={ + "query": { + "type": "string", + "description": "Search query", + "minLength": 1, + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results (default: 5)", + "minimum": 1, + "maximum": 10, + }, + "allowed_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Only include results from these domains", + }, + "blocked_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Exclude results from these domains", }, - "required": ["query"], }, - }, + required=["query"], + ), handler=self._web_search, source="WebService", + is_concurrency_safe=True, + is_read_only=True, ) ) registry.register( ToolEntry( name="WebFetch", - mode=ToolMode.INLINE, - schema={ - "name": "WebFetch", - "description": "Fetch a URL and extract specific information using AI. Returns processed content, not raw HTML.", - "parameters": { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "URL to fetch content from", - }, - "prompt": { - "type": "string", - "description": "What information to extract from the page", - }, + mode=ToolMode.DEFERRED, + schema=make_tool_schema( + name="WebFetch", + description=( + "Fetch a URL and extract specific information via AI. Returns processed text, not raw HTML. " + "Provide a focused prompt describing what to extract. " + "Useful for reading documentation pages, API references, or articles." + ), + properties={ + "url": { + "type": "string", + "description": "URL to fetch content from", + "minLength": 1, + }, + "prompt": { + "type": "string", + "description": "What information to extract from the page", + "minLength": 1, }, - "required": ["url", "prompt"], }, - }, + required=["url", "prompt"], + ), handler=self._web_fetch, source="WebService", + is_concurrency_safe=True, + is_read_only=True, ) ) @@ -124,8 +134,8 @@ async def _web_search( self, query: str, max_results: int | None = None, - include_domains: list[str] | None = None, - exclude_domains: list[str] | None = None, + allowed_domains: list[str] | None = None, + blocked_domains: list[str] | None = None, ) -> str: if not self._searchers: return "No search providers configured" @@ -137,8 +147,8 @@ async def _web_search( result: SearchResult = await searcher.search( query=query, max_results=effective_max, - include_domains=include_domains, - exclude_domains=exclude_domains, + include_domains=allowed_domains, + exclude_domains=blocked_domains, ) if not result.error: return result.format_output() diff --git a/core/tools/wechat/service.py b/core/tools/wechat/service.py deleted file mode 100644 index 9cb57e233..000000000 --- a/core/tools/wechat/service.py +++ /dev/null @@ -1,109 +0,0 @@ -"""WeChat tool service — registers wechat_send and wechat_contacts into ToolRegistry. - -Thin wrapper: actual API calls go through WeChatConnection (backend). -Tools are scoped to the agent's owner's user_id (the human who connected WeChat). -""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from typing import TYPE_CHECKING - -from core.runtime.registry import ToolEntry, ToolMode, ToolRegistry - -if TYPE_CHECKING: - from backend.web.services.wechat_service import WeChatConnection - -logger = logging.getLogger(__name__) - - -class WeChatToolService: - """Registers WeChat tools for agents to interact with WeChat contacts. - - @@@lazy-connection — connection_fn is called at tool invocation time, not registration. - This avoids import-time dependency on app.state. - """ - - def __init__(self, registry: ToolRegistry, connection_fn: Callable[[], WeChatConnection | None]) -> None: - self._get_conn = connection_fn - self._register(registry) - - def _register(self, registry: ToolRegistry) -> None: - self._register_wechat_send(registry) - self._register_wechat_contacts(registry) - - def _register_wechat_send(self, registry: ToolRegistry) -> None: - get_conn = self._get_conn - - async def handle(user_id: str, text: str) -> str: - conn = get_conn() - if not conn or not conn.connected: - return "Error: WeChat is not connected. Ask the owner to connect via the Connections page." - try: - await conn.send_message(user_id, text) - return f"Message sent to {user_id.split('@')[0]}" - except RuntimeError as e: - return f"Error: {e}" - - registry.register( - ToolEntry( - name="wechat_send", - mode=ToolMode.INLINE, - schema={ - "name": "wechat_send", - "description": ( - "Send a text message to a WeChat user via the connected WeChat bot.\n" - "Use wechat_contacts to find available user_ids.\n" - "The user must have messaged the bot first before you can reply.\n" - "Keep messages concise — WeChat is a chat app. Use plain text, no markdown." - ), - "parameters": { - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "WeChat user ID (format: xxx@im.wechat). Get from wechat_contacts.", - }, - "text": { - "type": "string", - "description": "Plain text message to send. No markdown — WeChat won't render it.", - }, - }, - "required": ["user_id", "text"], - }, - }, - handler=handle, - source="wechat", - ) - ) - - def _register_wechat_contacts(self, registry: ToolRegistry) -> None: - get_conn = self._get_conn - - def handle() -> str: - conn = get_conn() - if not conn or not conn.connected: - return "WeChat is not connected." - contacts = conn.list_contacts() - if not contacts: - return "No WeChat contacts yet. Users need to message the bot first." - lines = [f"- {c['display_name']} [user_id: {c['user_id']}]" for c in contacts] - return "\n".join(lines) - - registry.register( - ToolEntry( - name="wechat_contacts", - mode=ToolMode.INLINE, - schema={ - "name": "wechat_contacts", - "description": "List WeChat contacts who have messaged the bot. Returns user_ids for use with wechat_send.", - "parameters": { - "type": "object", - "properties": {}, - }, - }, - handler=handle, - source="wechat", - ) - ) diff --git a/docker-compose.yml b/docker-compose.yml index cb302edf3..15c3e7c7a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,6 +3,10 @@ services: build: context: . dockerfile: Dockerfile + volumes: + # @@@staging-leon-home-volume - staging runtime state (models/members/sandboxes) + # must survive container replacement, otherwise each deploy boots with an empty ~/.leon. + - leon-home:/root/.leon restart: unless-stopped frontend: @@ -14,3 +18,6 @@ services: depends_on: - backend restart: unless-stopped + +volumes: + leon-home: diff --git a/docs/en/introduction.mdx b/docs/en/introduction.mdx index 306238336..84e35bd7d 100644 --- a/docs/en/introduction.mdx +++ b/docs/en/introduction.mdx @@ -49,7 +49,7 @@ flowchart LR direction LR H["Human Entity"] A["Agent Entity"] - H <-->|chat_send / chat_read| A + H <-->|send_message / read_messages| A end subgraph Infra["Infrastructure"] diff --git a/docs/en/multi-agent-chat.mdx b/docs/en/multi-agent-chat.mdx index 6a10e8fec..2da8a8591 100644 --- a/docs/en/multi-agent-chat.mdx +++ b/docs/en/multi-agent-chat.mdx @@ -3,7 +3,7 @@ title: Multi-agent chat sidebarTitle: Social layer description: How humans and agents communicate on the Mycel social layer icon: comments -keywords: [entity, chat, agent communication, social, directory, chat_send, SSE] +keywords: [entity, chat, agent communication, social, list_chats, send_message, SSE] --- Mycel's social layer lets humans and agents coexist as equals in a shared messaging environment. Agents can initiate conversations, forward context to teammates, and collaborate autonomously — without any special orchestration code. @@ -19,7 +19,7 @@ flowchart LR direction TB HE["Human Entity"] AE["Agent Entity"] - HE <-->|"chat_send / chat_read"| AE + HE <-->|"send_message / read_messages"| AE end T --> Chat @@ -53,42 +53,33 @@ Every participant on the platform — human or agent — has an **Entity**. When ## Agent chat tools -Agents have five built-in tools for social interaction: +Agents have four built-in tools for social interaction: - - Browse all known Entities. Returns Entity IDs needed for other tools. - - ```text - directory(search="Alice", type="human") - → - Alice [human] entity_id=m_abc123-1 - ``` - - - + List the agent's active chats with unread counts and last message preview. ```text - chats(unread_only=true) + list_chats(unread_only=true) → - Alice [m_abc123-1] (3 unread) — last: "Can you help me with..." ``` - + Read message history in a chat. Automatically marks messages as read. ```text - chat_read(entity_id="m_abc123-1", limit=10) + read_messages(entity_id="m_abc123-1", limit=10) → [Alice]: Can you help me with this bug? [you]: Sure, let me take a look. ``` - + Send a message. The agent must read unread messages before sending (enforced by the system). ```text - chat_send(content="Here's the fix.", entity_id="m_abc123-1") + send_message(content="Here's the fix.", entity_id="m_abc123-1") ``` **Signal protocol** controls conversation flow: @@ -100,11 +91,11 @@ Agents have five built-in tools for social interaction: | `close` | "Conversation over, do not reply" | - + Search through message history across all chats or within a specific chat. ```text - chat_search(query="bug fix", entity_id="m_abc123-1") + search_messages(query="bug fix", entity_id="m_abc123-1") ``` @@ -124,15 +115,15 @@ sequenceDiagram API->>H: SSE push (message event) API->>Q: Enqueue notification Q->>T: Wake thread (if idle) - T->>API: chat_read (get actual message) + T->>API: read_messages (get actual message) T->>T: Process message - T->>API: chat_send (response) + T->>API: send_message (response) API->>DB: Store response API->>H: SSE push (message event) ``` - Notifications don't include message content — the agent must call `chat_read` to read them. This enforces a consistent **read → respond** pattern and prevents agents from acting on stale summaries. + Notifications don't include message content — the agent must call `read_messages` to read them. This enforces a consistent **read → respond** pattern and prevents agents from acting on stale summaries. ## Real-time updates diff --git a/docs/en/quickstart.mdx b/docs/en/quickstart.mdx index 91954831c..204f99163 100644 --- a/docs/en/quickstart.mdx +++ b/docs/en/quickstart.mdx @@ -100,7 +100,7 @@ Mycel's social layer lets agents message each other — and you — like a group - In the first agent's thread, tell it to message your code reviewer: "Ask the code reviewer to look at this function." The agent will call `chat_send` and the reviewer will respond autonomously. + In the first agent's thread, tell it to message your code reviewer: "Ask the code reviewer to look at this function." The agent will call `send_message` and the reviewer will respond autonomously. diff --git a/docs/zh/introduction.mdx b/docs/zh/introduction.mdx index fdc5e8693..9566e8cfe 100644 --- a/docs/zh/introduction.mdx +++ b/docs/zh/introduction.mdx @@ -49,7 +49,7 @@ flowchart LR direction LR H["人类 Entity"] A["Agent Entity"] - H <-->|"chat_send / chat_read"| A + H <-->|"send_message / read_messages"| A end subgraph Infra["基础设施"] diff --git a/docs/zh/multi-agent-chat.mdx b/docs/zh/multi-agent-chat.mdx index 3a44bd48c..4fb44940a 100644 --- a/docs/zh/multi-agent-chat.mdx +++ b/docs/zh/multi-agent-chat.mdx @@ -3,7 +3,7 @@ title: 多 Agent 通讯 sidebarTitle: 社交层 description: 人与 Agent 如何在 Mycel 社交层中通讯 icon: comments -keywords: [entity, chat, agent 通讯, 社交, directory, chat_send, SSE] +keywords: [entity, chat, agent 通讯, 社交, list_chats, send_message, SSE] --- Mycel 的社交层让人与 Agent 在共享的消息环境中平等共存。Agent 可以主动发起对话、把上下文转发给队友、自主协作 — 无需任何特殊的编排代码。 @@ -19,7 +19,7 @@ flowchart LR direction TB HE["人类 Entity"] AE["Agent Entity"] - HE <-->|"chat_send / chat_read"| AE + HE <-->|"send_message / read_messages"| AE end T --> Chat @@ -52,39 +52,30 @@ flowchart LR ## Agent 聊天工具 - - 浏览所有已知的 Entity,返回其他工具需要的 Entity ID。 - - ```text - directory(search="Alice", type="human") - → - Alice [human] entity_id=m_abc123-1 - ``` - - - + 列出 Agent 的活跃对话,包含未读数和最新消息预览。 ```text - chats(unread_only=true) + list_chats(unread_only=true) → - Alice [m_abc123-1] (3 条未读) — 最新:"能帮我看看..." ``` - + 读取对话消息历史,自动标记为已读。 ```text - chat_read(entity_id="m_abc123-1", limit=10) + read_messages(entity_id="m_abc123-1", limit=10) → [Alice]: 能帮我看看这个 bug 吗? [you]: 好的,我来看看。 ``` - + 发送消息。系统强制要求 Agent 先读取未读消息再发送。 ```text - chat_send(content="这是修复方案。", entity_id="m_abc123-1") + send_message(content="这是修复方案。", entity_id="m_abc123-1") ``` **信号协议**控制对话流转: @@ -96,11 +87,11 @@ flowchart LR | `close` | "对话结束,不需要回复" | - + 在所有对话或指定对话中搜索消息历史。 ```text - chat_search(query="bug 修复", entity_id="m_abc123-1") + search_messages(query="bug 修复", entity_id="m_abc123-1") ``` @@ -120,15 +111,15 @@ sequenceDiagram API->>H: SSE 推送(message 事件) API->>Q: 入队通知 Q->>T: 唤醒 Thread(若空闲) - T->>API: chat_read(读取实际消息) + T->>API: read_messages(读取实际消息) T->>T: 处理消息 - T->>API: chat_send(回复) + T->>API: send_message(回复) API->>DB: 存储回复 API->>H: SSE 推送(message 事件) ``` - 通知不包含消息内容 — Agent 必须调用 `chat_read` 才能读到。这强制执行「先读后发」的一致模式。 + 通知不包含消息内容 — Agent 必须调用 `read_messages` 才能读到。这强制执行「先读后发」的一致模式。 ## 联系人与投递设置 diff --git a/docs/zh/quickstart.mdx b/docs/zh/quickstart.mdx index 884bf09f4..37c67e8c8 100644 --- a/docs/zh/quickstart.mdx +++ b/docs/zh/quickstart.mdx @@ -100,7 +100,7 @@ Mycel 的社交层让 Agent 之间可以像群聊一样互相发消息。 - 在第一个 Agent 的 Thread 中,告诉它去联系代码审查员:「帮我把这个函数发给代码审查员看看。」Agent 会调用 `chat_send` 工具,审查员会自主回复。 + 在第一个 Agent 的 Thread 中,告诉它去联系代码审查员:「帮我把这个函数发给代码审查员看看。」Agent 会调用 `send_message` 工具,审查员会自主回复。 diff --git a/frontend/app/package-lock.json b/frontend/app/package-lock.json index 8af285c77..96b3f10b2 100644 --- a/frontend/app/package-lock.json +++ b/frontend/app/package-lock.json @@ -62,6 +62,7 @@ }, "devDependencies": { "@eslint/js": "^9.39.1", + "@testing-library/react": "^16.3.2", "@types/node": "^24.10.1", "@types/react": "^19.2.5", "@types/react-dom": "^19.2.3", @@ -71,6 +72,7 @@ "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.4.24", "globals": "^16.5.0", + "jsdom": "^28.1.0", "kimi-plugin-inspect-react": "^1.0.3", "postcss": "^8.5.6", "tailwindcss": "^3.4.19", @@ -78,9 +80,17 @@ "tw-animate-css": "^1.4.0", "typescript": "~5.9.3", "typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "^7.2.4", + "vitest": "^4.1.2" } }, + "node_modules/@acemir/cssom": { + "version": "0.9.31", + "resolved": "https://registry.npmjs.org/@acemir/cssom/-/cssom-0.9.31.tgz", + "integrity": "sha512-ZnR3GSaH+/vJ0YlHau21FjfLYjMpYVIzTD8M8vIEQvIGxeOXyXdzCI140rrCY862p/C/BbzWsjc1dgnM9mkoTA==", + "dev": true, + "license": "MIT" + }, "node_modules/@alloc/quick-lru": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", @@ -94,6 +104,64 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/@asamuzakjp/css-color": { + "version": "5.1.5", + "resolved": "https://registry.npmjs.org/@asamuzakjp/css-color/-/css-color-5.1.5.tgz", + "integrity": "sha512-8cMAA1bE66Mb/tfmkhcfJLjEPgyT7SSy6lW6id5XL113ai1ky76d/1L27sGnXCMsLfq66DInAU3OzuahB4lu9Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@csstools/css-calc": "^3.1.1", + "@csstools/css-color-parser": "^4.0.2", + "@csstools/css-parser-algorithms": "^4.0.0", + "@csstools/css-tokenizer": "^4.0.0", + "lru-cache": "^11.2.7" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, + "node_modules/@asamuzakjp/css-color/node_modules/lru-cache": { + "version": "11.2.7", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.7.tgz", + "integrity": "sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@asamuzakjp/dom-selector": { + "version": "6.8.1", + "resolved": "https://registry.npmjs.org/@asamuzakjp/dom-selector/-/dom-selector-6.8.1.tgz", + "integrity": "sha512-MvRz1nCqW0fsy8Qz4dnLIvhOlMzqDVBabZx6lH+YywFDdjXhMY37SmpV1XFX3JzG5GWHn63j6HX6QPr3lZXHvQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@asamuzakjp/nwsapi": "^2.3.9", + "bidi-js": "^1.0.3", + "css-tree": "^3.1.0", + "is-potential-custom-element-name": "^1.0.1", + "lru-cache": "^11.2.6" + } + }, + "node_modules/@asamuzakjp/dom-selector/node_modules/lru-cache": { + "version": "11.2.7", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.7.tgz", + "integrity": "sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@asamuzakjp/nwsapi": { + "version": "2.3.9", + "resolved": "https://registry.npmjs.org/@asamuzakjp/nwsapi/-/nwsapi-2.3.9.tgz", + "integrity": "sha512-n8GuYSrI9bF7FFZ/SjhwevlHc8xaVlb/7HmHelnc/PZXBD2ZR49NnN9sMMuDdEGPeeRQ5d0hqlSlEpgCX3Wl0Q==", + "dev": true, + "license": "MIT" + }, "node_modules/@babel/code-frame": { "version": "7.28.6", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz", @@ -125,7 +193,6 @@ "integrity": "sha512-e7jT4DxYvIDLk1ZHmU/m/mB19rex9sv0c2ftBtjSBv+kVM/902eh0fINUzD7UwLLNR+jU585GxUJ8/EBfAM5fw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@babel/code-frame": "^7.27.1", "@babel/generator": "^7.28.5", @@ -1846,6 +1913,159 @@ "node": ">=6.9.0" } }, + "node_modules/@bramus/specificity": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/@bramus/specificity/-/specificity-2.4.2.tgz", + "integrity": "sha512-ctxtJ/eA+t+6q2++vj5j7FYX3nRu311q1wfYH3xjlLOsczhlhxAg2FWNUXhpGvAw3BWo1xBcvOV6/YLc2r5FJw==", + "dev": true, + "license": "MIT", + "dependencies": { + "css-tree": "^3.0.0" + }, + "bin": { + "specificity": "bin/cli.js" + } + }, + "node_modules/@csstools/color-helpers": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/@csstools/color-helpers/-/color-helpers-6.0.2.tgz", + "integrity": "sha512-LMGQLS9EuADloEFkcTBR3BwV/CGHV7zyDxVRtVDTwdI2Ca4it0CCVTT9wCkxSgokjE5Ho41hEPgb8OEUwoXr6Q==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0", + "engines": { + "node": ">=20.19.0" + } + }, + "node_modules/@csstools/css-calc": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@csstools/css-calc/-/css-calc-3.1.1.tgz", + "integrity": "sha512-HJ26Z/vmsZQqs/o3a6bgKslXGFAungXGbinULZO3eMsOyNJHeBBZfup5FiZInOghgoM4Hwnmw+OgbJCNg1wwUQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=20.19.0" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^4.0.0", + "@csstools/css-tokenizer": "^4.0.0" + } + }, + "node_modules/@csstools/css-color-parser": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@csstools/css-color-parser/-/css-color-parser-4.0.2.tgz", + "integrity": "sha512-0GEfbBLmTFf0dJlpsNU7zwxRIH0/BGEMuXLTCvFYxuL1tNhqzTbtnFICyJLTNK4a+RechKP75e7w42ClXSnJQw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "dependencies": { + "@csstools/color-helpers": "^6.0.2", + "@csstools/css-calc": "^3.1.1" + }, + "engines": { + "node": ">=20.19.0" + }, + "peerDependencies": { + "@csstools/css-parser-algorithms": "^4.0.0", + "@csstools/css-tokenizer": "^4.0.0" + } + }, + "node_modules/@csstools/css-parser-algorithms": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@csstools/css-parser-algorithms/-/css-parser-algorithms-4.0.0.tgz", + "integrity": "sha512-+B87qS7fIG3L5h3qwJ/IFbjoVoOe/bpOdh9hAjXbvx0o8ImEmUsGXN0inFOnk2ChCFgqkkGFQ+TpM5rbhkKe4w==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=20.19.0" + }, + "peerDependencies": { + "@csstools/css-tokenizer": "^4.0.0" + } + }, + "node_modules/@csstools/css-syntax-patches-for-csstree": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@csstools/css-syntax-patches-for-csstree/-/css-syntax-patches-for-csstree-1.1.2.tgz", + "integrity": "sha512-5GkLzz4prTIpoyeUiIu3iV6CSG3Plo7xRVOFPKI7FVEJ3mZ0A8SwK0XU3Gl7xAkiQ+mDyam+NNp875/C5y+jSA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT-0", + "peerDependencies": { + "css-tree": "^3.2.1" + }, + "peerDependenciesMeta": { + "css-tree": { + "optional": true + } + } + }, + "node_modules/@csstools/css-tokenizer": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@csstools/css-tokenizer/-/css-tokenizer-4.0.0.tgz", + "integrity": "sha512-QxULHAm7cNu72w97JUNCBFODFaXpbDg+dP8b/oWFAZ2MTRppA3U00Y2L1HqaS4J6yBqxwa/Y3nMBaxVKbB/NsA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/csstools" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/csstools" + } + ], + "license": "MIT", + "engines": { + "node": ">=20.19.0" + } + }, "node_modules/@date-fns/tz": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/@date-fns/tz/-/tz-1.4.1.tgz", @@ -2451,6 +2671,24 @@ "node": "^18.18.0 || ^20.9.0 || >=21.1.0" } }, + "node_modules/@exodus/bytes": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/@exodus/bytes/-/bytes-1.15.0.tgz", + "integrity": "sha512-UY0nlA+feH81UGSHv92sLEPLCeZFjXOuHhrIo0HQydScuQc8s0A7kL/UdgwgDq8g8ilksmuoF35YVTNphV2aBQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + }, + "peerDependencies": { + "@noble/hashes": "^1.8.0 || ^2.0.0" + }, + "peerDependenciesMeta": { + "@noble/hashes": { + "optional": true + } + } + }, "node_modules/@floating-ui/core": { "version": "1.7.3", "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.7.3.tgz", @@ -4607,12 +4845,76 @@ "win32" ] }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "dev": true, + "license": "MIT" + }, "node_modules/@standard-schema/utils": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/@standard-schema/utils/-/utils-0.3.0.tgz", "integrity": "sha512-e7Mew686owMaPJVNNLs55PUvgz371nKgwsc4vxE49zsODpJEnxgxRo2y/OKrqueavXgZNMDVj3DdHFlaSAeU8g==", "license": "MIT" }, + "node_modules/@testing-library/dom": { + "version": "10.4.1", + "resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-10.4.1.tgz", + "integrity": "sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "@babel/code-frame": "^7.10.4", + "@babel/runtime": "^7.12.5", + "@types/aria-query": "^5.0.1", + "aria-query": "5.3.0", + "dom-accessibility-api": "^0.5.9", + "lz-string": "^1.5.0", + "picocolors": "1.1.1", + "pretty-format": "^27.0.2" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@testing-library/react": { + "version": "16.3.2", + "resolved": "https://registry.npmjs.org/@testing-library/react/-/react-16.3.2.tgz", + "integrity": "sha512-XU5/SytQM+ykqMnAnvB2umaJNIOsLF3PVv//1Ew4CTcpz0/BRyy/af40qqrt7SjKpDdT1saBMc42CUok5gaw+g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.12.5" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@testing-library/dom": "^10.0.0", + "@types/react": "^18.0.0 || ^19.0.0", + "@types/react-dom": "^18.0.0 || ^19.0.0", + "react": "^18.0.0 || ^19.0.0", + "react-dom": "^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@types/aria-query": { + "version": "5.0.4", + "resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz", + "integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==", + "dev": true, + "license": "MIT", + "peer": true + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -4658,6 +4960,17 @@ "@babel/types": "^7.28.2" } }, + "node_modules/@types/chai": { + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/@types/chai/-/chai-5.2.3.tgz", + "integrity": "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/deep-eql": "*", + "assertion-error": "^2.0.1" + } + }, "node_modules/@types/d3-array": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz", @@ -4730,6 +5043,13 @@ "@types/ms": "*" } }, + "node_modules/@types/deep-eql": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz", + "integrity": "sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/diff": { "version": "7.0.2", "resolved": "https://registry.npmjs.org/@types/diff/-/diff-7.0.2.tgz", @@ -4788,7 +5108,6 @@ "integrity": "sha512-vnDVpYPMzs4wunl27jHrfmwojOGKya0xyM3sH+UE5iv5uPS6vX7UIoh6m+vQc5LGBq52HBKPIn/zcSZVzeDEZg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~7.16.0" } @@ -4799,7 +5118,6 @@ "integrity": "sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "csstype": "^3.2.2" } @@ -4810,7 +5128,6 @@ "integrity": "sha512-jp2L/eY6fn+KgVVQAOqYItbF0VY/YApe5Mz2F0aykSO8gx31bYCZyvSeYxCHKvzHG5eZjc+zyaS5BrBWya2+kQ==", "devOptional": true, "license": "MIT", - "peer": true, "peerDependencies": { "@types/react": "^19.2.0" } @@ -4866,7 +5183,6 @@ "integrity": "sha512-iIACsx8pxRnguSYhHiMn2PvhvfpopO9FXHyn1mG5txZIsAaB6F0KwbFnUQN3KCiG3Jcuad/Cao2FAs1Wp7vAyg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.52.0", "@typescript-eslint/types": "8.52.0", @@ -5118,82 +5434,215 @@ "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, - "node_modules/acorn": { - "version": "8.15.0", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", - "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "node_modules/@vitest/expect": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-4.1.2.tgz", + "integrity": "sha512-gbu+7B0YgUJ2nkdsRJrFFW6X7NTP44WlhiclHniUhxADQJH5Szt9mZ9hWnJPJ8YwOK5zUOSSlSvyzRf0u1DSBQ==", "dev": true, "license": "MIT", - "peer": true, - "bin": { - "acorn": "bin/acorn" + "dependencies": { + "@standard-schema/spec": "^1.1.0", + "@types/chai": "^5.2.2", + "@vitest/spy": "4.1.2", + "@vitest/utils": "4.1.2", + "chai": "^6.2.2", + "tinyrainbow": "^3.1.0" }, - "engines": { - "node": ">=0.4.0" + "funding": { + "url": "https://opencollective.com/vitest" } }, - "node_modules/acorn-jsx": { - "version": "5.3.2", - "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", - "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "node_modules/@vitest/mocker": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-4.1.2.tgz", + "integrity": "sha512-Ize4iQtEALHDttPRCmN+FKqOl2vxTiNUhzobQFFt/BM1lRUTG7zRCLOykG/6Vo4E4hnUdfVLo5/eqKPukcWW7Q==", "dev": true, "license": "MIT", + "dependencies": { + "@vitest/spy": "4.1.2", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.21" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, "peerDependencies": { - "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + "msw": "^2.4.9", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } } }, - "node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "node_modules/@vitest/pretty-format": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-4.1.2.tgz", + "integrity": "sha512-dwQga8aejqeuB+TvXCMzSQemvV9hNEtDDpgUKDzOmNQayl2OG241PSWeJwKRH3CiC+sESrmoFd49rfnq7T4RnA==", "dev": true, "license": "MIT", "dependencies": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" + "tinyrainbow": "^3.1.0" }, "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" + "url": "https://opencollective.com/vitest" } }, - "node_modules/ansi-styles": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", - "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "node_modules/@vitest/runner": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-4.1.2.tgz", + "integrity": "sha512-Gr+FQan34CdiYAwpGJmQG8PgkyFVmARK8/xSijia3eTFgVfpcpztWLuP6FttGNfPLJhaZVP/euvujeNYar36OQ==", "dev": true, "license": "MIT", "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" + "@vitest/utils": "4.1.2", + "pathe": "^2.0.3" }, "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" + "url": "https://opencollective.com/vitest" } }, - "node_modules/any-promise": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", - "integrity": "sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==", - "dev": true, - "license": "MIT" - }, - "node_modules/anymatch": { - "version": "3.1.3", - "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", - "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "node_modules/@vitest/snapshot": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-4.1.2.tgz", + "integrity": "sha512-g7yfUmxYS4mNxk31qbOYsSt2F4m1E02LFqO53Xpzg3zKMhLAPZAjjfyl9e6z7HrW6LvUdTwAQR3HHfLjpko16A==", "dev": true, - "license": "ISC", + "license": "MIT", "dependencies": { - "normalize-path": "^3.0.0", - "picomatch": "^2.0.4" + "@vitest/pretty-format": "4.1.2", + "@vitest/utils": "4.1.2", + "magic-string": "^0.30.21", + "pathe": "^2.0.3" }, - "engines": { - "node": ">= 8" + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/spy": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-4.1.2.tgz", + "integrity": "sha512-DU4fBnbVCJGNBwVA6xSToNXrkZNSiw59H8tcuUspVMsBDBST4nfvsPsEHDHGtWRRnqBERBQu7TrTKskmjqTXKA==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/utils": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-4.1.2.tgz", + "integrity": "sha512-xw2/TiX82lQHA06cgbqRKFb5lCAy3axQ4H4SoUFhUsg+wztiet+co86IAMDtF6Vm1hc7J6j09oh/rgDn+JdKIQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.1.2", + "convert-source-map": "^2.0.0", + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/acorn": { + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "dev": true, + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/agent-base": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.4.tgz", + "integrity": "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "peer": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/any-promise": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", + "integrity": "sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==", + "dev": true, + "license": "MIT" + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dev": true, + "license": "ISC", + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" } }, "node_modules/anymatch/node_modules/picomatch": { @@ -5235,6 +5684,27 @@ "node": ">=10" } }, + "node_modules/aria-query": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.0.tgz", + "integrity": "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==", + "dev": true, + "license": "Apache-2.0", + "peer": true, + "dependencies": { + "dequal": "^2.0.3" + } + }, + "node_modules/assertion-error": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", + "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, "node_modules/autoprefixer": { "version": "10.4.23", "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.23.tgz", @@ -5341,6 +5811,16 @@ "baseline-browser-mapping": "dist/cli.js" } }, + "node_modules/bidi-js": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/bidi-js/-/bidi-js-1.0.3.tgz", + "integrity": "sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==", + "dev": true, + "license": "MIT", + "dependencies": { + "require-from-string": "^2.0.2" + } + }, "node_modules/binary-extensions": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", @@ -5398,7 +5878,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -5464,6 +5943,16 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/chai": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/chai/-/chai-6.2.2.tgz", + "integrity": "sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -5692,6 +6181,20 @@ "node": ">= 8" } }, + "node_modules/css-tree": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-3.2.1.tgz", + "integrity": "sha512-X7sjQzceUhu1u7Y/ylrRZFU2FS6LRiFVp6rKLPg23y3x3c3DOKAwuXGDp+PAGjh6CSnCjYeAul8pcT8bAl+lSA==", + "dev": true, + "license": "MIT", + "dependencies": { + "mdn-data": "2.27.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0" + } + }, "node_modules/cssesc": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", @@ -5705,6 +6208,32 @@ "node": ">=4" } }, + "node_modules/cssstyle": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/cssstyle/-/cssstyle-6.2.0.tgz", + "integrity": "sha512-Fm5NvhYathRnXNVndkUsCCuR63DCLVVwGOOwQw782coXFi5HhkXdu289l59HlXZBawsyNccXfWRYvLzcDCdDig==", + "dev": true, + "license": "MIT", + "dependencies": { + "@asamuzakjp/css-color": "^5.0.1", + "@csstools/css-syntax-patches-for-csstree": "^1.0.28", + "css-tree": "^3.1.0", + "lru-cache": "^11.2.6" + }, + "engines": { + "node": ">=20" + } + }, + "node_modules/cssstyle/node_modules/lru-cache": { + "version": "11.2.7", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.7.tgz", + "integrity": "sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": "20 || >=22" + } + }, "node_modules/csstype": { "version": "3.2.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", @@ -5832,6 +6361,20 @@ "node": ">=12" } }, + "node_modules/data-urls": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/data-urls/-/data-urls-7.0.0.tgz", + "integrity": "sha512-23XHcCF+coGYevirZceTVD7NdJOqVn+49IHyxgszm+JIiHLoB2TkmPtsYkNWT1pvRSGkc35L6NHs0yHkN2SumA==", + "dev": true, + "license": "MIT", + "dependencies": { + "whatwg-mimetype": "^5.0.0", + "whatwg-url": "^16.0.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, "node_modules/date-fns": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-4.1.0.tgz", @@ -5865,6 +6408,13 @@ } } }, + "node_modules/decimal.js": { + "version": "10.6.0", + "resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.6.0.tgz", + "integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==", + "dev": true, + "license": "MIT" + }, "node_modules/decimal.js-light": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/decimal.js-light/-/decimal.js-light-2.5.1.tgz", @@ -5942,6 +6492,14 @@ "dev": true, "license": "MIT" }, + "node_modules/dom-accessibility-api": { + "version": "0.5.16", + "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz", + "integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==", + "dev": true, + "license": "MIT", + "peer": true + }, "node_modules/dom-helpers": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/dom-helpers/-/dom-helpers-5.2.1.tgz", @@ -5963,8 +6521,7 @@ "version": "8.6.0", "resolved": "https://registry.npmjs.org/embla-carousel/-/embla-carousel-8.6.0.tgz", "integrity": "sha512-SjWyZBHJPbqxHOzckOfo8lHisEaJWmwd23XppYFYVh10bU66/Pn5tkVkbkCMZVdbUE5eTCI2nD8OyIP4Z+uwkA==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/embla-carousel-react": { "version": "8.6.0", @@ -6000,6 +6557,13 @@ "url": "https://github.com/fb55/entities?sponsor=1" } }, + "node_modules/es-module-lexer": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-2.0.0.tgz", + "integrity": "sha512-5POEcUuZybH7IdmGsD8wlf0AI55wMecM9rVBTI/qEAy2c1kTOm3DjFYjrBdI2K3BaJjJYfYFeRtM0t9ssnRuxw==", + "dev": true, + "license": "MIT" + }, "node_modules/esbuild": { "version": "0.27.2", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.2.tgz", @@ -6071,7 +6635,6 @@ "integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -6250,6 +6813,16 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/estree-walker": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -6266,6 +6839,16 @@ "integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==", "license": "MIT" }, + "node_modules/expect-type": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.3.0.tgz", + "integrity": "sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/extend": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", @@ -6697,6 +7280,19 @@ "hermes-estree": "0.25.1" } }, + "node_modules/html-encoding-sniffer": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-6.0.0.tgz", + "integrity": "sha512-CV9TW3Y3f8/wT0BRFc1/KAVQ3TUHiXmaAb6VW9vtiMFf7SLoMd1PdAc4W3KFOFETBJUb90KatHqlsZMWV+R9Gg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@exodus/bytes": "^1.6.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, "node_modules/html-url-attributes": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz", @@ -6717,6 +7313,34 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/http-proxy-agent": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-7.0.2.tgz", + "integrity": "sha512-T1gkAiYYDWYx3V5Bmyu7HcfcvL7mUrTWiM6yOfa3PIphViJ/gFPbvidQ+veqSOHci/PxBcDabeUNCzpOODJZig==", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.0", + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", + "integrity": "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, "node_modules/ignore": { "version": "5.3.2", "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", @@ -6897,6 +7521,13 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/is-potential-custom-element-name": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-potential-custom-element-name/-/is-potential-custom-element-name-1.0.1.tgz", + "integrity": "sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==", + "dev": true, + "license": "MIT" + }, "node_modules/isexe": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", @@ -6910,7 +7541,6 @@ "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "dev": true, "license": "MIT", - "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -6934,6 +7564,60 @@ "js-yaml": "bin/js-yaml.js" } }, + "node_modules/jsdom": { + "version": "28.1.0", + "resolved": "https://registry.npmjs.org/jsdom/-/jsdom-28.1.0.tgz", + "integrity": "sha512-0+MoQNYyr2rBHqO1xilltfDjV9G7ymYGlAUazgcDLQaUf8JDHbuGwsxN6U9qWaElZ4w1B2r7yEGIL3GdeW3Rug==", + "dev": true, + "license": "MIT", + "dependencies": { + "@acemir/cssom": "^0.9.31", + "@asamuzakjp/dom-selector": "^6.8.1", + "@bramus/specificity": "^2.4.2", + "@exodus/bytes": "^1.11.0", + "cssstyle": "^6.0.1", + "data-urls": "^7.0.0", + "decimal.js": "^10.6.0", + "html-encoding-sniffer": "^6.0.0", + "http-proxy-agent": "^7.0.2", + "https-proxy-agent": "^7.0.6", + "is-potential-custom-element-name": "^1.0.1", + "parse5": "^8.0.0", + "saxes": "^6.0.0", + "symbol-tree": "^3.2.4", + "tough-cookie": "^6.0.0", + "undici": "^7.21.0", + "w3c-xmlserializer": "^5.0.0", + "webidl-conversions": "^8.0.1", + "whatwg-mimetype": "^5.0.0", + "whatwg-url": "^16.0.0", + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + }, + "peerDependencies": { + "canvas": "^3.0.0" + }, + "peerDependenciesMeta": { + "canvas": { + "optional": true + } + } + }, + "node_modules/jsdom/node_modules/parse5": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/parse5/-/parse5-8.0.0.tgz", + "integrity": "sha512-9m4m5GSgXjL4AjumKzq1Fgfp3Z8rsvjRNbnkVwfu2ImRqE5D0LnY2QfDen18FSY9C573YU5XxSapdHZTZ2WolA==", + "dev": true, + "license": "MIT", + "dependencies": { + "entities": "^6.0.0" + }, + "funding": { + "url": "https://github.com/inikulin/parse5?sponsor=1" + } + }, "node_modules/jsesc": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", @@ -7121,6 +7805,17 @@ "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, + "node_modules/lz-string": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz", + "integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==", + "dev": true, + "license": "MIT", + "peer": true, + "bin": { + "lz-string": "bin/bin.js" + } + }, "node_modules/magic-string": { "version": "0.30.21", "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.21.tgz", @@ -7435,6 +8130,13 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/mdn-data": { + "version": "2.27.1", + "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.27.1.tgz", + "integrity": "sha512-9Yubnt3e8A0OKwxYSXyhLymGW4sCufcLG6VdiDdUGVkPhpqLxlvP5vl1983gQjJl3tqbrM731mjaZaP68AgosQ==", + "dev": true, + "license": "CC0-1.0" + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -8138,6 +8840,17 @@ "node": ">= 6" } }, + "node_modules/obug": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/obug/-/obug-2.1.1.tgz", + "integrity": "sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/sxzz", + "https://opencollective.com/debug" + ], + "license": "MIT" + }, "node_modules/optionator": { "version": "0.9.4", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", @@ -8265,6 +8978,13 @@ "dev": true, "license": "MIT" }, + "node_modules/pathe": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", + "integrity": "sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==", + "dev": true, + "license": "MIT" + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", @@ -8278,7 +8998,6 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -8326,7 +9045,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -8480,6 +9198,44 @@ "node": ">= 0.8.0" } }, + "node_modules/pretty-format": { + "version": "27.5.1", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-27.5.1.tgz", + "integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==", + "dev": true, + "license": "MIT", + "peer": true, + "dependencies": { + "ansi-regex": "^5.0.1", + "ansi-styles": "^5.0.0", + "react-is": "^17.0.1" + }, + "engines": { + "node": "^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0" + } + }, + "node_modules/pretty-format/node_modules/ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true, + "license": "MIT", + "peer": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/pretty-format/node_modules/react-is": { + "version": "17.0.2", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", + "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==", + "dev": true, + "license": "MIT", + "peer": true + }, "node_modules/prop-types": { "version": "15.8.1", "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", @@ -8552,7 +9308,6 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.2.3.tgz", "integrity": "sha512-Ku/hhYbVjOQnXDZFv2+RibmLFGwFdeeKHFcOTlrt7xplBnya5OGn/hIRDsqDiSUcfORsDC7MPxwork8jBwsIWA==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.10.0" } @@ -8583,7 +9338,6 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.3.tgz", "integrity": "sha512-yELu4WmLPw5Mr/lmeEpox5rw3RETacE++JgHqQzd2dg+YbJuat3jH4ingc+WPZhxaoFzdv9y33G+F7Nl5O0GBg==", "license": "MIT", - "peer": true, "dependencies": { "scheduler": "^0.27.0" }, @@ -8596,7 +9350,6 @@ "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.70.0.tgz", "integrity": "sha512-COOMajS4FI3Wuwrs3GPpi/Jeef/5W1DRR84Yl5/ShlT3dKVFUfoGiEZ/QE6Uw8P4T2/CLJdcTVYKvWBMQTEpvw==", "license": "MIT", - "peer": true, "engines": { "node": ">=18.0.0" }, @@ -9008,6 +9761,16 @@ "integrity": "sha512-4ZJgIB9EG9fQE41mOJCRHMmnxDTKHWawQoJWZyUbZuj680wVyogu2ihnj8Edqm7vh2mo/TWHyEZpn2kqeDvS7w==", "license": "Apache-2.0" }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/resolve": { "version": "1.22.11", "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.11.tgz", @@ -9119,6 +9882,19 @@ "queue-microtask": "^1.2.2" } }, + "node_modules/saxes": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/saxes/-/saxes-6.0.0.tgz", + "integrity": "sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==", + "dev": true, + "license": "ISC", + "dependencies": { + "xmlchars": "^2.2.0" + }, + "engines": { + "node": ">=v12.22.7" + } + }, "node_modules/scheduler": { "version": "0.27.0", "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.27.0.tgz", @@ -9164,6 +9940,13 @@ "node": ">=8" } }, + "node_modules/siginfo": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz", + "integrity": "sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==", + "dev": true, + "license": "ISC" + }, "node_modules/sonner": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/sonner/-/sonner-2.0.7.tgz", @@ -9194,6 +9977,20 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/stackback": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", + "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", + "dev": true, + "license": "MIT" + }, + "node_modules/std-env": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/std-env/-/std-env-4.0.0.tgz", + "integrity": "sha512-zUMPtQ/HBY3/50VbpkupYHbRroTRZJPRLvreamgErJVys0ceuzMkD44J/QjqhHjOzK42GQ3QZIeFG1OYfOtKqQ==", + "dev": true, + "license": "MIT" + }, "node_modules/streamdown": { "version": "2.4.0", "resolved": "https://registry.npmjs.org/streamdown/-/streamdown-2.4.0.tgz", @@ -9315,6 +10112,13 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/symbol-tree": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/symbol-tree/-/symbol-tree-3.2.4.tgz", + "integrity": "sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw==", + "dev": true, + "license": "MIT" + }, "node_modules/tailwind-merge": { "version": "3.4.0", "resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.4.0.tgz", @@ -9331,7 +10135,6 @@ "integrity": "sha512-3ofp+LL8E+pK/JuPLPggVAIaEuhvIz4qNcf3nA1Xn2o/7fb7s/TYpHhwGDv1ZU3PkBluUVaF8PyCHcm48cKLWQ==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@alloc/quick-lru": "^5.2.0", "arg": "^5.0.2", @@ -9403,6 +10206,23 @@ "integrity": "sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==", "license": "MIT" }, + "node_modules/tinybench": { + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz", + "integrity": "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyexec": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-1.0.4.tgz", + "integrity": "sha512-u9r3uZC0bdpGOXtlxUIdwf9pkmvhqJdrVCH9fapQtgy/OeTTMZ1nqH7agtvEfmGui6e1XxjcdrlxvxJvc3sMqw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -9420,6 +10240,36 @@ "url": "https://github.com/sponsors/SuperchupuDev" } }, + "node_modules/tinyrainbow": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-3.1.0.tgz", + "integrity": "sha512-Bf+ILmBgretUrdJxzXM0SgXLZ3XfiaUuOj/IKQHuTXip+05Xn+uyEYdVg0kYDipTBcLrCVyUzAPz7QmArb0mmw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tldts": { + "version": "7.0.28", + "resolved": "https://registry.npmjs.org/tldts/-/tldts-7.0.28.tgz", + "integrity": "sha512-+Zg3vWhRUv8B1maGSTFdev9mjoo8Etn2Ayfs4cnjlD3CsGkxXX4QyW3j2WJ0wdjYcYmy7Lx2RDsZMhgCWafKIw==", + "dev": true, + "license": "MIT", + "dependencies": { + "tldts-core": "^7.0.28" + }, + "bin": { + "tldts": "bin/cli.js" + } + }, + "node_modules/tldts-core": { + "version": "7.0.28", + "resolved": "https://registry.npmjs.org/tldts-core/-/tldts-core-7.0.28.tgz", + "integrity": "sha512-7W5Efjhsc3chVdFhqtaU0KtK32J37Zcr9RKtID54nG+tIpcY79CQK/veYPODxtD/LJ4Lue66jvrQzIX2Z2/pUQ==", + "dev": true, + "license": "MIT" + }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -9433,6 +10283,32 @@ "node": ">=8.0" } }, + "node_modules/tough-cookie": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-6.0.1.tgz", + "integrity": "sha512-LktZQb3IeoUWB9lqR5EWTHgW/VTITCXg4D21M+lvybRVdylLrRMnqaIONLVb5mav8vM19m44HIcGq4qASeu2Qw==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "tldts": "^7.0.5" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/tr46": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-6.0.0.tgz", + "integrity": "sha512-bLVMLPtstlZ4iMQHpFHTR7GAGj2jxi8Dg0s2h2MafAE4uSWF98FC/3MomU51iQAMf8/qDUbKWf5GxuvvVcXEhw==", + "dev": true, + "license": "MIT", + "dependencies": { + "punycode": "^2.3.1" + }, + "engines": { + "node": ">=20" + } + }, "node_modules/trim-lines": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", @@ -9508,7 +10384,6 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -9541,6 +10416,16 @@ "typescript": ">=4.8.4 <6.0.0" } }, + "node_modules/undici": { + "version": "7.24.7", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.24.7.tgz", + "integrity": "sha512-H/nlJ/h0ggGC+uRL3ovD+G0i4bqhvsDOpbDv7At5eFLlj2b41L8QliGbnl2H7SnDiYhENphh1tQFJZf+MyfLsQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=20.18.1" + } + }, "node_modules/undici-types": { "version": "7.16.0", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.16.0.tgz", @@ -9862,7 +10747,6 @@ "integrity": "sha512-dZwN5L1VlUBewiP6H9s2+B3e3Jg96D0vzN+Ry73sOefebhYr9f94wwkMNN/9ouoU8pV1BqA1d1zGk8928cx0rg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.27.0", "fdir": "^6.5.0", @@ -9932,6 +10816,101 @@ } } }, + "node_modules/vitest": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.1.2.tgz", + "integrity": "sha512-xjR1dMTVHlFLh98JE3i/f/WePqJsah4A0FK9cc8Ehp9Udk0AZk6ccpIZhh1qJ/yxVWRZ+Q54ocnD8TXmkhspGg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/expect": "4.1.2", + "@vitest/mocker": "4.1.2", + "@vitest/pretty-format": "4.1.2", + "@vitest/runner": "4.1.2", + "@vitest/snapshot": "4.1.2", + "@vitest/spy": "4.1.2", + "@vitest/utils": "4.1.2", + "es-module-lexer": "^2.0.0", + "expect-type": "^1.3.0", + "magic-string": "^0.30.21", + "obug": "^2.1.1", + "pathe": "^2.0.3", + "picomatch": "^4.0.3", + "std-env": "^4.0.0-rc.1", + "tinybench": "^2.9.0", + "tinyexec": "^1.0.2", + "tinyglobby": "^0.2.15", + "tinyrainbow": "^3.1.0", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^20.0.0 || ^22.0.0 || >=24.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@opentelemetry/api": "^1.9.0", + "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", + "@vitest/browser-playwright": "4.1.2", + "@vitest/browser-preview": "4.1.2", + "@vitest/browser-webdriverio": "4.1.2", + "@vitest/ui": "4.1.2", + "happy-dom": "*", + "jsdom": "*", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@opentelemetry/api": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser-playwright": { + "optional": true + }, + "@vitest/browser-preview": { + "optional": true + }, + "@vitest/browser-webdriverio": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + }, + "vite": { + "optional": false + } + } + }, + "node_modules/w3c-xmlserializer": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-5.0.0.tgz", + "integrity": "sha512-o8qghlI8NZHU1lLPrpi2+Uq7abh4GGPpYANlalzWxyWteJOCsr/P+oPBA49TOLu5FTZO4d3F9MnWJfiMo4BkmA==", + "dev": true, + "license": "MIT", + "dependencies": { + "xml-name-validator": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/web-namespaces": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/web-namespaces/-/web-namespaces-2.0.1.tgz", @@ -9942,6 +10921,41 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/webidl-conversions": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-8.0.1.tgz", + "integrity": "sha512-BMhLD/Sw+GbJC21C/UgyaZX41nPt8bUTg+jWyDeg7e7YN4xOM05YPSIXceACnXVtqyEw/LMClUQMtMZ+PGGpqQ==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=20" + } + }, + "node_modules/whatwg-mimetype": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-5.0.0.tgz", + "integrity": "sha512-sXcNcHOC51uPGF0P/D4NVtrkjSU2fNsm9iog4ZvZJsL3rjoDAzXZhkm2MWt1y+PUdggKAYVoMAIYcs78wJ51Cw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=20" + } + }, + "node_modules/whatwg-url": { + "version": "16.0.1", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-16.0.1.tgz", + "integrity": "sha512-1to4zXBxmXHV3IiSSEInrreIlu02vUOvrhxJJH5vcxYTBDAx51cqZiKdyTxlecdKNSjj8EcxGBxNf6Vg+945gw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@exodus/bytes": "^1.11.0", + "tr46": "^6.0.0", + "webidl-conversions": "^8.0.1" + }, + "engines": { + "node": "^20.19.0 || ^22.12.0 || >=24.0.0" + } + }, "node_modules/which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -9958,6 +10972,23 @@ "node": ">= 8" } }, + "node_modules/why-is-node-running": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz", + "integrity": "sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==", + "dev": true, + "license": "MIT", + "dependencies": { + "siginfo": "^2.0.0", + "stackback": "0.0.2" + }, + "bin": { + "why-is-node-running": "cli.js" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/word-wrap": { "version": "1.2.5", "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", @@ -9968,6 +10999,23 @@ "node": ">=0.10.0" } }, + "node_modules/xml-name-validator": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/xml-name-validator/-/xml-name-validator-5.0.0.tgz", + "integrity": "sha512-EvGK8EJ3DhaHfbRlETOWAS5pO9MZITeauHKJyb8wyajUfQUenkIg2MvLDTZ4T/TgIcm3HU0TFBgWWboAZ30UHg==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18" + } + }, + "node_modules/xmlchars": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz", + "integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==", + "dev": true, + "license": "MIT" + }, "node_modules/yallist": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", @@ -9993,7 +11041,6 @@ "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.5.tgz", "integrity": "sha512-k7Nwx6vuWx1IJ9Bjuf4Zt1PEllcwe7cls3VNzm4CQ1/hgtFUK2bRNG3rvnpPUhFjmqJKAKtjV576KnUkHocg/g==", "license": "MIT", - "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } diff --git a/frontend/app/package.json b/frontend/app/package.json index 52199cd30..b613f7a9f 100644 --- a/frontend/app/package.json +++ b/frontend/app/package.json @@ -7,7 +7,8 @@ "dev": "vite", "build": "tsc -b && vite build", "lint": "eslint .", - "preview": "vite preview" + "preview": "vite preview", + "test": "vitest run" }, "dependencies": { "@hookform/resolvers": "^5.2.2", @@ -64,6 +65,7 @@ }, "devDependencies": { "@eslint/js": "^9.39.1", + "@testing-library/react": "^16.3.2", "@types/node": "^24.10.1", "@types/react": "^19.2.5", "@types/react-dom": "^19.2.3", @@ -73,6 +75,7 @@ "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.4.24", "globals": "^16.5.0", + "jsdom": "^28.1.0", "kimi-plugin-inspect-react": "^1.0.3", "postcss": "^8.5.6", "tailwindcss": "^3.4.19", @@ -80,6 +83,7 @@ "tw-animate-css": "^1.4.0", "typescript": "~5.9.3", "typescript-eslint": "^8.46.4", - "vite": "^7.2.4" + "vite": "^7.2.4", + "vitest": "^4.1.2" } } diff --git a/frontend/app/src/api/client.ts b/frontend/app/src/api/client.ts index 2dd5c8c56..10bdb4f2d 100644 --- a/frontend/app/src/api/client.ts +++ b/frontend/app/src/api/client.ts @@ -11,7 +11,10 @@ import type { LeaseStatus, ThreadDetail, ThreadSummary, - SandboxChannelFilesResult, + ThreadPermissions, + ThreadPermissionRules, + PermissionRuleBehavior, + AskUserAnswer, SandboxFileResult, SandboxFilesListResult, SandboxUploadResult, @@ -99,26 +102,55 @@ export async function getThread(threadId: string): Promise { return request(`/api/threads/${encodeURIComponent(threadId)}`); } -export async function getThreadRuntime(threadId: string): Promise { - return request(`/api/threads/${encodeURIComponent(threadId)}/runtime`); +export async function getThreadPermissions(threadId: string, signal?: AbortSignal): Promise { + return request(`/api/threads/${encodeURIComponent(threadId)}/permissions`, { signal }); } -export async function sendMessage(threadId: string, message: string): Promise<{ status: string; routing: string }> { - return request(`/api/threads/${encodeURIComponent(threadId)}/messages`, { +export async function resolveThreadPermission( + threadId: string, + requestId: string, + decision: "allow" | "deny", + message?: string, + answers?: AskUserAnswer[], + annotations?: Record, +): Promise<{ ok: boolean; thread_id: string; request_id: string }> { + return request(`/api/threads/${encodeURIComponent(threadId)}/permissions/${encodeURIComponent(requestId)}/resolve`, { method: "POST", - body: JSON.stringify({ message }), + body: JSON.stringify({ decision, message, answers, annotations }), }); } -export async function queueMessage(threadId: string, message: string): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/queue`, { +export async function addThreadPermissionRule( + threadId: string, + behavior: PermissionRuleBehavior, + toolName: string, +): Promise<{ ok: boolean; thread_id: string; scope: string; rules: ThreadPermissionRules; managed_only: boolean }> { + return request(`/api/threads/${encodeURIComponent(threadId)}/permissions/rules`, { method: "POST", - body: JSON.stringify({ message }), + body: JSON.stringify({ behavior, tool_name: toolName }), }); } -export async function getQueue(threadId: string): Promise<{ messages: Array<{ id: number; content: string; created_at: string }> }> { - return request(`/api/threads/${encodeURIComponent(threadId)}/queue`); +export async function removeThreadPermissionRule( + threadId: string, + behavior: PermissionRuleBehavior, + toolName: string, +): Promise<{ ok: boolean; thread_id: string; scope: string; rules: ThreadPermissionRules; managed_only: boolean }> { + return request( + `/api/threads/${encodeURIComponent(threadId)}/permissions/rules/${encodeURIComponent(behavior)}/${encodeURIComponent(toolName)}`, + { method: "DELETE" }, + ); +} + +export async function getThreadRuntime(threadId: string): Promise { + return request(`/api/threads/${encodeURIComponent(threadId)}/runtime`); +} + +export async function sendMessage(threadId: string, message: string): Promise<{ status: string; routing: string }> { + return request(`/api/threads/${encodeURIComponent(threadId)}/messages`, { + method: "POST", + body: JSON.stringify({ message }), + }); } // --- Sandbox API --- @@ -163,32 +195,6 @@ export async function listMyLeases(signal?: AbortSignal): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/sandbox/pause`, { method: "POST" }); -} - -export async function resumeThreadSandbox(threadId: string): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/sandbox/resume`, { method: "POST" }); -} - -export async function destroyThreadSandbox(threadId: string): Promise { - await request(`/api/threads/${encodeURIComponent(threadId)}/sandbox`, { method: "DELETE" }); -} - -export async function pauseSandboxSession(sessionId: string, provider: string): Promise { - await request( - `/api/sandbox/sessions/${encodeURIComponent(sessionId)}/pause?provider=${encodeURIComponent(provider)}`, - { method: "POST" }, - ); -} - -export async function resumeSandboxSession(sessionId: string, provider: string): Promise { - await request( - `/api/sandbox/sessions/${encodeURIComponent(sessionId)}/resume?provider=${encodeURIComponent(provider)}`, - { method: "POST" }, - ); -} - export async function destroySandboxSession(sessionId: string, provider: string): Promise { await request( `/api/sandbox/sessions/${encodeURIComponent(sessionId)}?provider=${encodeURIComponent(provider)}`, @@ -225,12 +231,6 @@ export async function readSandboxFile(threadId: string, path: string): Promise { - return request(`${sandboxFilesBase(threadId)}/channel-files`); -} - export async function uploadSandboxFile( threadId: string, opts: { file: File; path?: string }, @@ -261,11 +261,6 @@ export function getSandboxDownloadUrl( // --- Settings API --- -export async function listSandboxConfigs(): Promise>> { - const payload = await request<{ sandboxes: Record> }>("/api/settings/sandboxes"); - return payload.sandboxes; -} - export async function saveSandboxConfig(name: string, config: Record): Promise { await request("/api/settings/sandboxes", { method: "POST", @@ -275,10 +270,6 @@ export async function saveSandboxConfig(name: string, config: Record> { - return request("/api/settings/observation"); -} - export async function saveObservationConfig( active: string | null, config?: Record, diff --git a/frontend/app/src/api/types.ts b/frontend/app/src/api/types.ts index 08d990935..c031f3582 100644 --- a/frontend/app/src/api/types.ts +++ b/frontend/app/src/api/types.ts @@ -45,6 +45,49 @@ export interface ThreadDetail { sandbox: SandboxInfo | null; } +export interface PermissionRequest { + request_id: string; + thread_id: string; + tool_name: string; + args: Record; + message?: string | null; +} + +export interface AskUserQuestionOption { + label: string; + description: string; + preview?: string | null; +} + +export interface AskUserQuestionPrompt { + header: string; + question: string; + options: AskUserQuestionOption[]; + multiSelect?: boolean; +} + +export interface AskUserAnswer { + header?: string; + question?: string; + selected_options: string[]; + free_text?: string | null; +} + +export type PermissionRuleBehavior = "allow" | "deny" | "ask"; + +export interface ThreadPermissionRules { + allow: string[]; + deny: string[]; + ask: string[]; +} + +export interface ThreadPermissions { + thread_id: string; + requests: PermissionRequest[]; + session_rules: ThreadPermissionRules; + managed_only: boolean; +} + export interface SandboxType { name: string; provider?: string; @@ -219,6 +262,7 @@ export interface StreamStatus { state: { state: string; flags: Record }; tokens: { total_tokens: number; input_tokens: number; output_tokens: number; cost: number }; context: { message_count: number; estimated_tokens: number; usage_percent: number; near_limit: boolean }; + model?: string; current_tool?: string; last_seq?: number; run_start_seq?: number; diff --git a/frontend/app/src/components/FileBrowser.tsx b/frontend/app/src/components/FileBrowser.tsx deleted file mode 100644 index 4cef7086a..000000000 --- a/frontend/app/src/components/FileBrowser.tsx +++ /dev/null @@ -1,101 +0,0 @@ -import { useState } from 'react'; -import { authFetch } from '@/store/auth-store'; -import { useFileList } from '@/hooks/useFileList'; -import { MoreVertical } from 'lucide-react'; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuTrigger, -} from '@/components/ui/dropdown-menu'; -import { Button } from '@/components/ui/button'; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from '@/components/ui/alert-dialog'; - -interface FileBrowserProps { - threadId: string; -} - -export function FileBrowser({ threadId }: FileBrowserProps) { - const { files, loading, error, refetch } = useFileList(threadId); - const [deleteTarget, setDeleteTarget] = useState(null); - const [deleting, setDeleting] = useState(false); - - const handleDownload = (path: string) => { - const url = `/api/threads/${threadId}/files/download?path=${encodeURIComponent(path)}`; - window.open(url, '_blank'); - }; - - const handleDelete = async () => { - if (!deleteTarget) return; - setDeleting(true); - try { - const res = await authFetch( - `/api/threads/${threadId}/files/files?path=${encodeURIComponent(deleteTarget)}`, - { method: 'DELETE' } - ); - if (!res.ok) throw new Error('Failed to delete file'); - await refetch(); - } catch (e) { - alert(e instanceof Error ? e.message : 'Failed to delete file'); - } finally { - setDeleting(false); - setDeleteTarget(null); - } - }; - - if (loading) return
加载文件中...
; - if (error) return
错误:{error}
; - if (files.length === 0) return
暂无已上传文件
; - - return ( - <> -
- {files.map((file) => ( -
- {file.relative_path} -
- {(file.size_bytes / 1024).toFixed(1)} KB - - - - - - handleDownload(file.relative_path)}>下载 - setDeleteTarget(file.relative_path)} disabled={deleting}>删除 - - -
-
- ))} -
- - setDeleteTarget(null)}> - - - 删除文件? - - 确定要删除 "{deleteTarget}" 吗?此操作无法撤销。 - - - - 取消 - - {deleting ? '删除中...' : '删除'} - - - - - - ); -} diff --git a/frontend/app/src/components/Header.tsx b/frontend/app/src/components/Header.tsx index 9273f8c7b..a4a5e07cd 100644 --- a/frontend/app/src/components/Header.tsx +++ b/frontend/app/src/components/Header.tsx @@ -1,4 +1,4 @@ -import { ChevronLeft, PanelLeft, Pause, Play } from "lucide-react"; +import { ChevronLeft, PanelLeft } from "lucide-react"; import { useNavigate } from "react-router-dom"; import type { SandboxInfo } from "../api"; import { useIsMobile } from "../hooks/use-mobile"; @@ -22,8 +22,6 @@ interface HeaderProps { sandboxInfo: SandboxInfo | null; currentModel?: string; onToggleSidebar: () => void; - onPauseSandbox: () => void; - onResumeSandbox: () => void; onModelChange?: (model: string) => void; } @@ -33,8 +31,6 @@ export default function Header({ sandboxInfo, currentModel = "leon:medium", onToggleSidebar, - onPauseSandbox, - onResumeSandbox, onModelChange, }: HeaderProps) { const isMobile = useIsMobile(); @@ -90,25 +86,6 @@ export default function Header({ threadId={activeThreadId} onModelChange={onModelChange} /> - - {hasRemote && sandboxInfo?.status === "running" && ( - - )} - {hasRemote && sandboxInfo?.status === "paused" && ( - - )} ); diff --git a/frontend/app/src/components/NewChatDialog.tsx b/frontend/app/src/components/NewChatDialog.tsx index 1a7ed3a29..c5eb6ff63 100644 --- a/frontend/app/src/components/NewChatDialog.tsx +++ b/frontend/app/src/components/NewChatDialog.tsx @@ -41,8 +41,8 @@ export default function NewChatDialog({ open, onOpenChange }: NewChatDialogProps - 发起会话 - 选择成员发起新对话 + 打开成员线程 + 选择成员打开专属线程
diff --git a/frontend/app/src/components/SandboxSessionsModal.test.tsx b/frontend/app/src/components/SandboxSessionsModal.test.tsx new file mode 100644 index 000000000..b6bcb10a8 --- /dev/null +++ b/frontend/app/src/components/SandboxSessionsModal.test.tsx @@ -0,0 +1,53 @@ +// @vitest-environment jsdom + +import { render, screen, waitFor } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import SandboxSessionsModal from "./SandboxSessionsModal"; +import type { SandboxSession } from "../api"; + +const { listSandboxSessions } = vi.hoisted(() => ({ + listSandboxSessions: vi.fn(), +})); + +vi.mock("../api", async () => { + const actual = await vi.importActual("../api"); + return { + ...actual, + listSandboxSessions, + destroySandboxSession: vi.fn(), + }; +}); + +describe("SandboxSessionsModal", () => { + beforeEach(() => { + listSandboxSessions.mockReset(); + }); + + it("does not render pause or resume controls for running or paused sessions", async () => { + const sessions: SandboxSession[] = [ + { + session_id: "session-running", + thread_id: "thread-running", + provider: "local", + status: "running", + }, + { + session_id: "session-paused", + thread_id: "thread-paused", + provider: "daytona_selfhost", + status: "paused", + }, + ]; + listSandboxSessions.mockResolvedValue(sessions); + + render(); + + await waitFor(() => { + expect(listSandboxSessions).toHaveBeenCalled(); + }); + + expect(screen.queryByTitle("暂停")).toBeNull(); + expect(screen.queryByTitle("恢复")).toBeNull(); + expect(screen.getAllByTitle("销毁")).toHaveLength(2); + }); +}); diff --git a/frontend/app/src/components/SandboxSessionsModal.tsx b/frontend/app/src/components/SandboxSessionsModal.tsx index 955a1b28c..48cae6a1e 100644 --- a/frontend/app/src/components/SandboxSessionsModal.tsx +++ b/frontend/app/src/components/SandboxSessionsModal.tsx @@ -1,10 +1,8 @@ -import { Loader2, Pause, Play, Trash2 } from "lucide-react"; -import { useEffect, useState } from "react"; +import { Loader2, Trash2 } from "lucide-react"; +import { useCallback, useEffect, useState } from "react"; import { destroySandboxSession, listSandboxSessions, - pauseSandboxSession, - resumeSandboxSession, type SandboxSession, } from "../api"; import { @@ -29,7 +27,7 @@ export default function SandboxSessionsModal({ isOpen, onClose, onSessionMutated const [busy, setBusy] = useState(null); const [error, setError] = useState(null); - async function refresh(opts?: { silent?: boolean }) { + const refresh = useCallback(async (opts?: { silent?: boolean }) => { const silent = opts?.silent ?? false; const showInitialLoading = !hasLoaded && !silent; if (showInitialLoading) { @@ -48,7 +46,7 @@ export default function SandboxSessionsModal({ isOpen, onClose, onSessionMutated setLoading(false); setRefreshing(false); } - } + }, [hasLoaded]); useEffect(() => { if (!isOpen) return; @@ -57,7 +55,7 @@ export default function SandboxSessionsModal({ isOpen, onClose, onSessionMutated void refresh({ silent: true }); }, 2500); return () => window.clearInterval(timer); - }, [isOpen]); + }, [isOpen, refresh]); async function withBusy(row: SandboxSession, fn: () => Promise) { setBusy(row.session_id); @@ -153,26 +151,6 @@ export default function SandboxSessionsModal({ isOpen, onClose, onSessionMutated
- {row.status === "running" && ( - - )} - {row.status === "paused" && ( - - )} diff --git a/frontend/app/src/components/computer-panel/AgentsView.tsx b/frontend/app/src/components/computer-panel/AgentsView.tsx index 51a537de0..d9866046f 100644 --- a/frontend/app/src/components/computer-panel/AgentsView.tsx +++ b/frontend/app/src/components/computer-panel/AgentsView.tsx @@ -2,6 +2,9 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { Loader2 } from "lucide-react"; import type { AssistantTurn, ToolStep } from "../../api"; import { useThreadData } from "../../hooks/use-thread-data"; +import { useDisplayDeltas } from "../../hooks/use-display-deltas"; +import { useThreadStream } from "../../hooks/use-thread-stream"; +import { resolveAgentVisualStatus, type AgentVisualStatus } from "./agent-visual-status"; import { parseAgentArgs } from "./utils"; import type { FlowItem } from "./utils"; import { FlowList } from "./flow-items"; @@ -22,12 +25,43 @@ export function AgentsView({ steps }: AgentsViewProps) { const dragStartX = useRef(0); const dragStartWidth = useRef(0); - const focused = steps.find((s) => s.id === selectedAgentId) ?? null; + const effectiveSelectedAgentId = useMemo(() => { + if (steps.length === 0) return null; + if (selectedAgentId && steps.some((step) => step.id === selectedAgentId)) return selectedAgentId; + return ( + [...steps].reverse().find((step) => { + const status = step.subagent_stream?.status; + return status === "running" || step.status === "calling"; + })?.id ?? steps[steps.length - 1].id + ); + }, [steps, selectedAgentId]); + + const focused = steps.find((s) => s.id === effectiveSelectedAgentId) ?? null; const stream = focused?.subagent_stream; const threadId = stream?.thread_id || undefined; - const isRunning = stream?.status === "running" || focused?.status === "calling"; - - const { entries, loading, refreshThread } = useThreadData(threadId); + const { entries, loading, refreshThread, setEntries, displaySeq } = useThreadData(threadId); + const refreshThreads = useCallback(async () => {}, []); + // @@@child-thread-live-bridge - the Agent pane must subscribe to the child + // thread's own SSE stream. Polling child detail alone misses the running + // window and makes the pane look empty until a later refresh. + const childStream = useThreadStream(threadId ?? "", { + loading: loading || !threadId, + refreshThreads, + }); + const childDisplay = useDisplayDeltas({ + threadId: threadId ?? "", + onUpdate: setEntries, + displaySeq, + stream: childStream, + }); + const focusedStatus = + focused + ? resolveAgentVisualStatus(focused, { + childDisplayRunning: childDisplay.isRunning, + childRuntimeState: childStream.runtimeStatus?.state?.state ?? null, + }) + : null; + const isRunning = focusedStatus === "running"; // Poll every second while sub-agent is running useEffect(() => { @@ -61,7 +95,7 @@ export function AgentsView({ steps }: AgentsViewProps) { id: tc.id, name: tc.name, args: tc.args, status: tc.status === "done" ? "done" : "calling", result: tc.result, - timestamp: Date.now(), + timestamp: focused?.timestamp ?? 0, }, turnId: "live", }); @@ -73,7 +107,7 @@ export function AgentsView({ steps }: AgentsViewProps) { } return items; - }, [entries]); + }, [entries, stream, focused?.timestamp]); const handleMouseDown = useCallback((e: React.MouseEvent) => { e.preventDefault(); @@ -118,7 +152,8 @@ export function AgentsView({ steps }: AgentsViewProps) { setSelectedAgentId(step.id)} /> ))} @@ -141,7 +176,7 @@ export function AgentsView({ steps }: AgentsViewProps) {
) : ( <> - + {loading ? (
@@ -164,14 +199,25 @@ export function AgentsView({ steps }: AgentsViewProps) { /* -- Agent list item -- */ -function AgentListItem({ step, isSelected, onClick }: { step: ToolStep; isSelected: boolean; onClick: () => void }) { +function AgentListItem({ + step, + visualStatus, + isSelected, + onClick, +}: { + step: ToolStep; + visualStatus: AgentVisualStatus | null; + isSelected: boolean; + onClick: () => void; +}) { const args = parseAgentArgs(step.args); const ss = step.subagent_stream; const displayName = ss?.description || args.description || args.prompt?.slice(0, 40) || "子任务"; const prompt = args.prompt || ""; - const isRunning = ss?.status === "running" || (step.status === "calling" && ss?.status !== "completed"); - const isError = step.status === "error" || ss?.status === "error"; - const isDone = !isRunning && !isError && (step.status === "done" || ss?.status === "completed"); + const status = resolveAgentVisualStatus(step, { statusOverride: visualStatus }); + const isRunning = status === "running"; + const isError = status === "error"; + const isDone = status === "completed"; const statusDot = isRunning ? "bg-success animate-pulse" : isError ? "bg-destructive" : isDone ? "bg-success" : "bg-warning animate-pulse"; return ( @@ -194,21 +240,27 @@ function AgentListItem({ step, isSelected, onClick }: { step: ToolStep; isSelect /* -- Agent detail header -- */ -function getStatusLabel(focused: ToolStep, stream: SubagentStream | undefined): string { - if (stream?.status === "running") return "运行中"; - if (stream?.status === "error") return "出错"; - if (focused.status === "calling") return "启动中"; +function getStatusLabel(status: AgentVisualStatus): string { + if (status === "running") return "运行中"; + if (status === "error") return "出错"; return "已完成"; } -function getStatusDotClass(focused: ToolStep, stream: SubagentStream | undefined): string { - if (stream?.status === "running") return "bg-success animate-pulse"; - if (stream?.status === "error") return "bg-destructive"; - if (focused.status === "calling") return "bg-warning animate-pulse"; +function getStatusDotClass(status: AgentVisualStatus): string { + if (status === "running") return "bg-success animate-pulse"; + if (status === "error") return "bg-destructive"; return "bg-success"; } -function AgentDetailHeader({ focused, stream }: { focused: ToolStep; stream: SubagentStream | undefined }) { +function AgentDetailHeader({ + focused, + stream, + visualStatus, +}: { + focused: ToolStep; + stream: SubagentStream | undefined; + visualStatus: AgentVisualStatus; +}) { const args = parseAgentArgs(focused.args); const displayName = stream?.description || args.description || args.prompt?.slice(0, 40) || "子任务"; const agentType = args.subagent_type; @@ -218,8 +270,8 @@ function AgentDetailHeader({ focused, stream }: { focused: ToolStep; stream: Sub {agentType} )}
{displayName}
- - {getStatusLabel(focused, stream)} + + {getStatusLabel(visualStatus)}
); } @@ -239,4 +291,3 @@ function AgentPromptSection({ args }: { args: unknown }) { ); } - diff --git a/frontend/app/src/components/computer-panel/PanelHeader.test.tsx b/frontend/app/src/components/computer-panel/PanelHeader.test.tsx new file mode 100644 index 000000000..c061dfe59 --- /dev/null +++ b/frontend/app/src/components/computer-panel/PanelHeader.test.tsx @@ -0,0 +1,31 @@ +// @vitest-environment jsdom + +import { render, screen } from "@testing-library/react"; +import { describe, expect, it, vi } from "vitest"; +import { PanelHeader } from "./PanelHeader"; + +describe("PanelHeader", () => { + it("does not render pause or resume controls for remote sandboxes", () => { + const onClose = vi.fn(); + + const { rerender } = render( + , + ); + + expect(screen.getAllByRole("button")).toHaveLength(1); + expect(screen.getByTitle("收起视窗")).toBeTruthy(); + + rerender( + , + ); + + expect(screen.getAllByRole("button")).toHaveLength(1); + expect(screen.getByTitle("收起视窗")).toBeTruthy(); + }); +}); diff --git a/frontend/app/src/components/computer-panel/PanelHeader.tsx b/frontend/app/src/components/computer-panel/PanelHeader.tsx index 8340d2634..871586479 100644 --- a/frontend/app/src/components/computer-panel/PanelHeader.tsx +++ b/frontend/app/src/components/computer-panel/PanelHeader.tsx @@ -1,17 +1,9 @@ -import { Pause, Play } from "lucide-react"; -import { pauseThreadSandbox, resumeThreadSandbox, type LeaseStatus } from "../../api"; - interface PanelHeaderProps { threadId: string | null; - isRemote: boolean; - lease: LeaseStatus | null; onClose: () => void; - onRefreshStatus: () => Promise; } -export function PanelHeader({ threadId, isRemote, lease, onClose, onRefreshStatus }: PanelHeaderProps) { - const instanceState = lease?.instance?.state; - +export function PanelHeader({ threadId, onClose }: PanelHeaderProps) { return (
@@ -21,22 +13,6 @@ export function PanelHeader({ threadId, isRemote, lease, onClose, onRefreshStatu

- {isRemote && instanceState === "running" && ( - - )} - {isRemote && instanceState === "paused" && ( - - )}