From 91c177d29af0d255ad249e108b16785f8ad26e74 Mon Sep 17 00:00:00 2001 From: Matvii Sakhnenko Date: Tue, 7 Apr 2026 13:03:57 +0200 Subject: [PATCH 01/11] fix: set misfire_grace_time=None for reliable heartbeats With misfire_grace_time=300, heartbeats were still missed when the bot was busy for >5 minutes (observed: 8m54s miss). Setting to None guarantees every job fires exactly once, combined with coalesce=True to prevent duplicate execution. --- src/scheduler/scheduler.py | 7 +++- .../test_scheduler/test_misfire_config.py | 37 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_scheduler/test_misfire_config.py diff --git a/src/scheduler/scheduler.py b/src/scheduler/scheduler.py index 98d90a553..8cbca1d67 100644 --- a/src/scheduler/scheduler.py +++ b/src/scheduler/scheduler.py @@ -32,7 +32,12 @@ def __init__( self.event_bus = event_bus self.db_manager = db_manager self.default_working_directory = default_working_directory - self._scheduler = AsyncIOScheduler() + self._scheduler = AsyncIOScheduler( + job_defaults={ + "misfire_grace_time": None, # Always run, no matter how late + "coalesce": True, # Merge multiple missed runs into one + } + ) async def start(self) -> None: """Load persisted jobs and start the scheduler.""" diff --git a/tests/unit/test_scheduler/test_misfire_config.py b/tests/unit/test_scheduler/test_misfire_config.py new file mode 100644 index 000000000..9f3f5f890 --- /dev/null +++ b/tests/unit/test_scheduler/test_misfire_config.py @@ -0,0 +1,37 @@ +"""Tests for scheduler misfire configuration.""" + +from unittest.mock import MagicMock + +import pytest + +from src.events.bus import EventBus +from src.scheduler.scheduler import JobScheduler +from src.storage.database import DatabaseManager + + +class TestSchedulerMisfireConfig: + """Verify APScheduler is configured for resilient job execution.""" + + @pytest.fixture + def scheduler(self, tmp_path): + event_bus = EventBus() + db_manager = 
MagicMock(spec=DatabaseManager) + return JobScheduler( + event_bus=event_bus, + db_manager=db_manager, + default_working_directory=tmp_path, + ) + + def test_misfire_grace_time_is_none(self, scheduler): + """misfire_grace_time=None ensures jobs always run, no matter how late.""" + job_defaults = scheduler._scheduler._job_defaults + assert job_defaults.get("misfire_grace_time") is None, ( + "misfire_grace_time must be None to guarantee late jobs still fire" + ) + + def test_coalesce_is_enabled(self, scheduler): + """coalesce=True merges multiple missed runs into a single execution.""" + job_defaults = scheduler._scheduler._job_defaults + assert job_defaults.get("coalesce") is True, ( + "coalesce must be True to prevent spam-firing missed runs" + ) From 25f216b69cd6cddfbbc0c07ae65884d826cfe65c Mon Sep 17 00:00:00 2001 From: Matvii Sakhnenko Date: Tue, 7 Apr 2026 13:06:24 +0200 Subject: [PATCH 02/11] fix: run scheduled jobs in background tasks to prevent bus blocking AgentHandler.handle_scheduled now dispatches work via asyncio.create_task and returns immediately, preventing long Claude executions from blocking the event bus and causing missed heartbeats. A semaphore (max 2) caps concurrent Claude executions. handle_webhook remains unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/events/handlers.py | 42 ++++-- .../test_events/test_concurrent_scheduled.py | 124 ++++++++++++++++++ 2 files changed, 156 insertions(+), 10 deletions(-) create mode 100644 tests/unit/test_events/test_concurrent_scheduled.py diff --git a/src/events/handlers.py b/src/events/handlers.py index 74ceaaad1..d2dcc334a 100644 --- a/src/events/handlers.py +++ b/src/events/handlers.py @@ -4,8 +4,9 @@ NotificationHandler: subscribes to AgentResponseEvent and delivers to Telegram. 
""" +import asyncio from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Set import structlog @@ -35,6 +36,8 @@ def __init__( self.claude = claude_integration self.default_working_directory = default_working_directory self.default_user_id = default_user_id + self._scheduled_semaphore = asyncio.Semaphore(2) + self._background_tasks: Set[asyncio.Task[Any]] = set() def register(self) -> None: """Subscribe to events that need agent processing.""" @@ -92,15 +95,27 @@ async def handle_scheduled(self, event: Event) -> None: job_name=event.job_name, ) - prompt = event.prompt - if event.skill_name: - prompt = ( - f"/{event.skill_name}\n\n{prompt}" if prompt else f"/{event.skill_name}" - ) + task = asyncio.create_task( + self._run_scheduled(event), + name=f"scheduled:{event.job_id or event.job_name or event.id}", + ) + self._background_tasks.add(task) + task.add_done_callback(self._task_done) + await asyncio.sleep(0) + + async def _run_scheduled(self, event: ScheduledEvent) -> None: + """Run scheduled Claude work in the background with concurrency limits.""" + async with self._scheduled_semaphore: + prompt = event.prompt + if event.skill_name: + prompt = ( + f"/{event.skill_name}\n\n{prompt}" + if prompt + else f"/{event.skill_name}" + ) - working_dir = event.working_directory or self.default_working_directory + working_dir = event.working_directory or self.default_working_directory - try: response = await self.claude.run_command( prompt=prompt, working_directory=working_dir, @@ -126,11 +141,18 @@ async def handle_scheduled(self, event: Event) -> None: originating_event_id=event.id, ) ) + + def _task_done(self, task: asyncio.Task[Any]) -> None: + """Clean up finished scheduled tasks and log failures.""" + self._background_tasks.discard(task) + try: + task.result() + except asyncio.CancelledError: + return except Exception: logger.exception( "Agent execution failed for scheduled event", - job_id=event.job_id, - event_id=event.id, + 
task_name=task.get_name(), ) def _build_webhook_prompt(self, event: WebhookEvent) -> str: diff --git a/tests/unit/test_events/test_concurrent_scheduled.py b/tests/unit/test_events/test_concurrent_scheduled.py new file mode 100644 index 000000000..b1d255069 --- /dev/null +++ b/tests/unit/test_events/test_concurrent_scheduled.py @@ -0,0 +1,124 @@ +"""Tests for concurrent scheduled event handling.""" + +import asyncio +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from src.events.bus import EventBus +from src.events.handlers import AgentHandler +from src.events.types import ScheduledEvent, WebhookEvent + + +@pytest.fixture +def event_bus() -> EventBus: + return EventBus() + + +@pytest.fixture +def mock_claude() -> AsyncMock: + mock = AsyncMock() + mock.run_command = AsyncMock() + return mock + + +@pytest.fixture +def agent_handler(event_bus: EventBus, mock_claude: AsyncMock) -> AgentHandler: + return AgentHandler( + event_bus=event_bus, + claude_integration=mock_claude, + default_working_directory=Path("/tmp/test"), + default_user_id=42, + ) + + +class TestConcurrentScheduledHandling: + """Tests for background dispatch of scheduled jobs.""" + + async def test_handle_scheduled_returns_immediately( + self, agent_handler: AgentHandler, mock_claude: AsyncMock + ) -> None: + """Scheduled jobs should dispatch to the background and return fast.""" + + async def slow_run_command(**_: object) -> MagicMock: + await asyncio.sleep(10) + response = MagicMock() + response.content = "done" + return response + + mock_claude.run_command.side_effect = slow_run_command + event = ScheduledEvent( + job_name="standup", + prompt="Generate daily standup", + target_chat_ids=[100], + ) + + await asyncio.wait_for(agent_handler.handle_scheduled(event), timeout=1.0) + + for task in list(getattr(agent_handler, "_background_tasks", set())): + task.cancel() + if getattr(agent_handler, "_background_tasks", None): + await 
asyncio.gather(*agent_handler._background_tasks, return_exceptions=True) + + def test_scheduled_semaphore_limits_concurrency( + self, agent_handler: AgentHandler + ) -> None: + """Scheduled jobs are capped to two concurrent Claude executions.""" + assert agent_handler._scheduled_semaphore._value == 2 + + async def test_scheduled_task_errors_are_logged_not_raised( + self, agent_handler: AgentHandler, mock_claude: AsyncMock + ) -> None: + """Background scheduled task failures should be logged, not propagated.""" + + async def boom(**_: object) -> MagicMock: + raise RuntimeError("SDK error") + + mock_claude.run_command.side_effect = boom + event = ScheduledEvent( + job_name="standup", + prompt="Generate daily standup", + target_chat_ids=[100], + ) + + with patch("src.events.handlers.logger.exception") as mock_log: + await agent_handler.handle_scheduled(event) + tasks = list(agent_handler._background_tasks) + assert tasks + await asyncio.gather(*tasks, return_exceptions=True) + await asyncio.sleep(0) + + mock_log.assert_called() + + async def test_webhook_handler_still_blocks( + self, agent_handler: AgentHandler, mock_claude: AsyncMock + ) -> None: + """Webhook handling should still await Claude directly.""" + release = asyncio.Event() + started = asyncio.Event() + + async def blocked_run_command(**_: object) -> MagicMock: + started.set() + await release.wait() + response = MagicMock() + response.content = "done" + return response + + mock_claude.run_command.side_effect = blocked_run_command + event = WebhookEvent( + provider="github", + event_type_name="push", + payload={"ref": "refs/heads/main"}, + delivery_id="del-1", + ) + + task = asyncio.create_task(agent_handler.handle_webhook(event)) + await started.wait() + await asyncio.sleep(0) + + assert not task.done() + + release.set() + await asyncio.wait_for(task, timeout=1.0) + mock_claude.run_command.assert_awaited_once() From 62122cb1c629ed2a1e077d12d21aa21c34acccf2 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?F=C3=A9lix=20D=C3=ADez=20Arias?= Date: Sun, 15 Mar 2026 10:16:05 +0100 Subject: [PATCH 03/11] feat(schedule-command): inject JobScheduler into bot dependencies Add "scheduler" key to the bot deps dict and assign the real JobScheduler instance after creation in run_application(), making it accessible to command handlers via context.bot_data["scheduler"]. --- src/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main.py b/src/main.py index 026607338..4683bcd66 100644 --- a/src/main.py +++ b/src/main.py @@ -181,6 +181,7 @@ async def create_application(config: Settings) -> Dict[str, Any]: "event_bus": event_bus, "project_registry": None, "project_threads_manager": None, + "scheduler": None, } bot = ClaudeCodeBot(config, dependencies) @@ -314,6 +315,7 @@ def signal_handler(signum: int, frame: Any) -> None: default_working_directory=config.approved_directory, ) await scheduler.start() + bot.deps["scheduler"] = scheduler logger.info("Job scheduler enabled") # Shutdown task From a5e14f95ebcf9213a03ac72a65e8feb5b4ed4930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20D=C3=ADez=20Arias?= Date: Sun, 15 Mar 2026 10:17:19 +0100 Subject: [PATCH 04/11] feat(schedule-command): add pause_job/resume_job to JobScheduler pause_job() removes job from APScheduler and sets is_active=0. resume_job() re-registers with APScheduler and sets is_active=1. list_jobs() now accepts include_paused flag to show all jobs. 
--- src/scheduler/scheduler.py | 83 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/src/scheduler/scheduler.py b/src/scheduler/scheduler.py index 8cbca1d67..025b23a36 100644 --- a/src/scheduler/scheduler.py +++ b/src/scheduler/scheduler.py @@ -121,12 +121,89 @@ async def remove_job(self, job_id: str) -> bool: logger.info("Scheduled job removed", job_id=job_id) return True - async def list_jobs(self) -> List[Dict[str, Any]]: - """List all scheduled jobs from the database.""" + async def pause_job(self, job_id: str) -> bool: + """Pause a scheduled job (remove from APScheduler, mark inactive in DB).""" + try: + self._scheduler.remove_job(job_id) + except Exception: + logger.warning("Job not found in scheduler", job_id=job_id) + + async with self.db_manager.get_connection() as conn: + cursor = await conn.execute( + "UPDATE scheduled_jobs SET is_active = 0 WHERE job_id = ?", + (job_id,), + ) + await conn.commit() + if cursor.rowcount == 0: + return False + + logger.info("Scheduled job paused", job_id=job_id) + return True + + async def resume_job(self, job_id: str) -> bool: + """Resume a paused job (re-register with APScheduler, mark active in DB).""" async with self.db_manager.get_connection() as conn: cursor = await conn.execute( - "SELECT * FROM scheduled_jobs WHERE is_active = 1 ORDER BY created_at" + "SELECT * FROM scheduled_jobs WHERE job_id = ?", + (job_id,), + ) + row = await cursor.fetchone() + + if not row: + return False + + row_dict = dict(row) + + # Re-register with APScheduler + try: + trigger = CronTrigger.from_crontab(row_dict["cron_expression"]) + chat_ids_str = row_dict.get("target_chat_ids", "") + chat_ids = ( + [int(x) for x in chat_ids_str.split(",") if x.strip()] + if chat_ids_str + else [] + ) + self._scheduler.add_job( + self._fire_event, + trigger=trigger, + kwargs={ + "job_name": row_dict["job_name"], + "prompt": row_dict["prompt"], + "working_directory": row_dict["working_directory"], + 
"target_chat_ids": chat_ids, + "skill_name": row_dict.get("skill_name"), + }, + id=job_id, + name=row_dict["job_name"], + replace_existing=True, + ) + except Exception: + logger.exception("Failed to re-register job with scheduler", job_id=job_id) + return False + + # Mark active in DB + async with self.db_manager.get_connection() as conn: + await conn.execute( + "UPDATE scheduled_jobs SET is_active = 1 WHERE job_id = ?", + (job_id,), ) + await conn.commit() + + logger.info("Scheduled job resumed", job_id=job_id) + return True + + async def list_jobs(self, include_paused: bool = False) -> List[Dict[str, Any]]: + """List scheduled jobs from the database.""" + async with self.db_manager.get_connection() as conn: + if include_paused: + cursor = await conn.execute( + "SELECT * FROM scheduled_jobs ORDER BY created_at" + ) + else: + cursor = await conn.execute( + "SELECT * FROM scheduled_jobs" + " WHERE is_active = 1 ORDER BY created_at" + ) rows = await cursor.fetchall() return [dict(row) for row in rows] From 6c46e3ef706f3963dc007bbe91312c7f3c9b2a58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20D=C3=ADez=20Arias?= Date: Sun, 15 Mar 2026 10:19:38 +0100 Subject: [PATCH 05/11] feat(schedule-command): add ScheduleCommandHandler Single /schedule entry point dispatching to subcommands: list, add, remove, pause, resume. Auto-populates chat_id, working_directory, and created_by from Telegram context. 
--- src/bot/handlers/schedule.py | 200 +++++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 src/bot/handlers/schedule.py diff --git a/src/bot/handlers/schedule.py b/src/bot/handlers/schedule.py new file mode 100644 index 000000000..19af84e63 --- /dev/null +++ b/src/bot/handlers/schedule.py @@ -0,0 +1,200 @@ +"""Schedule command handler for managing scheduled jobs via Telegram.""" + +from typing import List + +import structlog +from telegram import Update +from telegram.ext import ContextTypes + +from ...scheduler.scheduler import JobScheduler +from ..utils.html_format import escape_html + +logger = structlog.get_logger() + + +def _get_scheduler(context: ContextTypes.DEFAULT_TYPE) -> JobScheduler | None: + """Get the JobScheduler from bot dependencies.""" + return context.bot_data.get("scheduler") + + +async def schedule_command( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: + """Handle /schedule command with subcommands. + + Usage: + /schedule list + /schedule add + /schedule remove + /schedule pause + /schedule resume + """ + scheduler = _get_scheduler(context) + if not scheduler: + await update.message.reply_text( + "Scheduler is not enabled. Set ENABLE_SCHEDULER=true to use this command." 
+ ) + return + + args: List[str] = context.args or [] + + if not args: + await _show_usage(update) + return + + subcommand = args[0].lower() + sub_args = args[1:] + + if subcommand == "list": + await _handle_list(update, scheduler) + elif subcommand == "add": + await _handle_add(update, context, scheduler, sub_args) + elif subcommand == "remove": + await _handle_remove(update, scheduler, sub_args) + elif subcommand == "pause": + await _handle_pause(update, scheduler, sub_args) + elif subcommand == "resume": + await _handle_resume(update, scheduler, sub_args) + else: + await _show_usage(update) + + +async def _show_usage(update: Update) -> None: + """Show usage information.""" + await update.message.reply_html( + "Usage:\n" + "/schedule list\n" + "/schedule add <name> <cron> <prompt>\n" + "/schedule remove <job_id>\n" + "/schedule pause <job_id>\n" + "/schedule resume <job_id>\n\n" + "Cron format: min hour day month weekday\n" + "Example: /schedule add daily-report 0 9 * * * Check status" + ) + + +async def _handle_list(update: Update, scheduler: JobScheduler) -> None: + """List all scheduled jobs.""" + jobs = await scheduler.list_jobs(include_paused=True) + + if not jobs: + await update.message.reply_text("No scheduled jobs.") + return + + lines = ["Scheduled Jobs:\n"] + for job in jobs: + status = "active" if job.get("is_active") else "paused" + name = escape_html(job.get("job_name", "?")) + cron = escape_html(job.get("cron_expression", "?")) + job_id = escape_html(str(job.get("job_id", "?"))) + lines.append( + f"{name} [{status}]\n" + f" Cron: {cron}\n" + f" ID: {job_id}" + ) + + await update.message.reply_html("\n\n".join(lines)) + + +async def _handle_add( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + scheduler: JobScheduler, + args: List[str], +) -> None: + """Add a new scheduled job. + + Expected args: + The 5 cron fields are joined into a single cron expression. 
+ """ + # Need at least: name + 5 cron fields + 1 prompt word = 7 + if len(args) < 7: + await update.message.reply_html( + "Usage: /schedule add <name> " + "<min> <hour> <day> <month> <weekday> " + "<prompt>\n\n" + "Example: /schedule add daily-report 0 9 * * * Check status" + ) + return + + job_name = args[0] + cron_expression = " ".join(args[1:6]) + prompt = " ".join(args[6:]) + + # Auto-populate fields from context + chat_id = update.effective_chat.id + user_id = update.effective_user.id + settings = context.bot_data.get("settings") + working_dir = context.user_data.get( + "current_directory", + settings.approved_directory if settings else None, + ) + + try: + job_id = await scheduler.add_job( + job_name=job_name, + cron_expression=cron_expression, + prompt=prompt, + target_chat_ids=[chat_id], + working_directory=working_dir, + created_by=user_id, + ) + await update.message.reply_html( + f"Job {escape_html(job_name)} created.\n" + f"ID: {escape_html(job_id)}\n" + f"Cron: {escape_html(cron_expression)}" + ) + except Exception as e: + logger.exception("Failed to add scheduled job", error=str(e)) + await update.message.reply_text(f"Failed to create job: {e}") + + +async def _handle_remove( + update: Update, scheduler: JobScheduler, args: List[str] +) -> None: + """Remove a scheduled job.""" + if not args: + await update.message.reply_text("Usage: /schedule remove ") + return + + job_id = args[0] + await scheduler.remove_job(job_id) + await update.message.reply_html( + f"Job {escape_html(job_id)} removed." + ) + + +async def _handle_pause( + update: Update, scheduler: JobScheduler, args: List[str] +) -> None: + """Pause a scheduled job.""" + if not args: + await update.message.reply_text("Usage: /schedule pause ") + return + + job_id = args[0] + success = await scheduler.pause_job(job_id) + if success: + await update.message.reply_html( + f"Job {escape_html(job_id)} paused." 
+ ) + else: + await update.message.reply_text(f"Job '{job_id}' not found.") + + +async def _handle_resume( + update: Update, scheduler: JobScheduler, args: List[str] +) -> None: + """Resume a paused job.""" + if not args: + await update.message.reply_text("Usage: /schedule resume ") + return + + job_id = args[0] + success = await scheduler.resume_job(job_id) + if success: + await update.message.reply_html( + f"Job {escape_html(job_id)} resumed." + ) + else: + await update.message.reply_text(f"Job '{job_id}' not found or failed to resume.") From d234b36704b28487909686702b4fc05aa6809068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20D=C3=ADez=20Arias?= Date: Sun, 15 Mar 2026 10:27:49 +0100 Subject: [PATCH 06/11] feat(schedule-command): register /schedule in MessageOrchestrator Add /schedule handler to both agentic and classic mode registration, gated on enable_scheduler. Add to get_bot_commands() for Telegram command menu visibility. --- src/bot/orchestrator.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index 6d9719f0d..902b7b448 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -331,6 +331,10 @@ def _register_agentic_handlers(self, app: Application) -> None: ] if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) + if self.settings.enable_scheduler: + from .handlers import schedule + + handlers.append(("schedule", schedule.schedule_command)) # Derive known commands dynamically — avoids drift when new commands are added self._known_commands: frozenset[str] = frozenset(cmd for cmd, _ in handlers) @@ -420,6 +424,10 @@ def _register_classic_handlers(self, app: Application) -> None: ] if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) + if self.settings.enable_scheduler: + from .handlers import schedule + + handlers.append(("schedule", schedule.schedule_command)) for cmd, handler in 
handlers: app.add_handler(CommandHandler(cmd, self._inject_deps(handler))) @@ -464,6 +472,8 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] ] if self.settings.enable_project_threads: commands.append(BotCommand("sync_threads", "Sync project topics")) + if self.settings.enable_scheduler: + commands.append(BotCommand("schedule", "Manage scheduled jobs")) return commands else: commands = [ @@ -484,6 +494,8 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] ] if self.settings.enable_project_threads: commands.append(BotCommand("sync_threads", "Sync project topics")) + if self.settings.enable_scheduler: + commands.append(BotCommand("schedule", "Manage scheduled jobs")) return commands # --- Agentic handlers --- From e815333bf349af97b859135d002e8036f57cad1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20D=C3=ADez=20Arias?= Date: Sun, 15 Mar 2026 10:30:37 +0100 Subject: [PATCH 07/11] feat(schedule-command): add unit tests for schedule handler 16 tests covering all subcommands (list, add, remove, pause, resume), argument parsing edge cases, error handling, and the scheduler-not- available fallback path. 100% coverage on the handler module. 
--- tests/unit/test_schedule_handler.py | 214 ++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 tests/unit/test_schedule_handler.py diff --git a/tests/unit/test_schedule_handler.py b/tests/unit/test_schedule_handler.py new file mode 100644 index 000000000..408a41268 --- /dev/null +++ b/tests/unit/test_schedule_handler.py @@ -0,0 +1,214 @@ +"""Tests for the /schedule command handler.""" + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from src.bot.handlers.schedule import schedule_command + + +@pytest.fixture +def tmp_dir(): + with tempfile.TemporaryDirectory() as d: + yield Path(d) + + +@pytest.fixture +def mock_scheduler(): + scheduler = AsyncMock() + scheduler.list_jobs = AsyncMock(return_value=[]) + scheduler.add_job = AsyncMock(return_value="job-123") + scheduler.remove_job = AsyncMock(return_value=True) + scheduler.pause_job = AsyncMock(return_value=True) + scheduler.resume_job = AsyncMock(return_value=True) + return scheduler + + +@pytest.fixture +def update(): + update = MagicMock() + update.message = MagicMock() + update.message.reply_text = AsyncMock() + update.message.reply_html = AsyncMock() + update.effective_chat = MagicMock() + update.effective_chat.id = 12345 + update.effective_user = MagicMock() + update.effective_user.id = 67890 + return update + + +@pytest.fixture +def context(mock_scheduler, tmp_dir): + context = MagicMock() + context.bot_data = { + "scheduler": mock_scheduler, + "settings": MagicMock(approved_directory=tmp_dir), + } + context.user_data = {"current_directory": tmp_dir} + context.args = [] + return context + + +async def test_no_scheduler_shows_error(update, context): + """When scheduler is not injected, show error message.""" + context.bot_data["scheduler"] = None + await schedule_command(update, context) + update.message.reply_text.assert_called_once() + assert "not enabled" in update.message.reply_text.call_args[0][0] + + 
+async def test_no_args_shows_usage(update, context): + """No subcommand shows usage info.""" + context.args = [] + await schedule_command(update, context) + update.message.reply_html.assert_called_once() + assert "Usage" in update.message.reply_html.call_args[0][0] + + +async def test_unknown_subcommand_shows_usage(update, context): + """Unknown subcommand shows usage info.""" + context.args = ["unknown"] + await schedule_command(update, context) + update.message.reply_html.assert_called_once() + assert "Usage" in update.message.reply_html.call_args[0][0] + + +async def test_list_empty(update, context, mock_scheduler): + """List with no jobs shows empty message.""" + context.args = ["list"] + await schedule_command(update, context) + mock_scheduler.list_jobs.assert_called_once_with(include_paused=True) + update.message.reply_text.assert_called_once_with("No scheduled jobs.") + + +async def test_list_with_jobs(update, context, mock_scheduler): + """List with jobs formats them correctly.""" + mock_scheduler.list_jobs.return_value = [ + { + "job_id": "abc-123", + "job_name": "daily-report", + "cron_expression": "0 9 * * *", + "is_active": 1, + }, + { + "job_id": "def-456", + "job_name": "weekly-backup", + "cron_expression": "0 0 * * 0", + "is_active": 0, + }, + ] + context.args = ["list"] + await schedule_command(update, context) + html = update.message.reply_html.call_args[0][0] + assert "daily-report" in html + assert "active" in html + assert "weekly-backup" in html + assert "paused" in html + assert "abc-123" in html + + +async def test_add_success(update, context, mock_scheduler, tmp_dir): + """Add creates a job with correct parameters.""" + context.args = ["add", "my-job", "0", "9", "*", "*", "1-5", "Run", "status", "check"] + await schedule_command(update, context) + + mock_scheduler.add_job.assert_called_once_with( + job_name="my-job", + cron_expression="0 9 * * 1-5", + prompt="Run status check", + target_chat_ids=[12345], + working_directory=tmp_dir, + 
created_by=67890, + ) + html = update.message.reply_html.call_args[0][0] + assert "my-job" in html + assert "job-123" in html + + +async def test_add_too_few_args(update, context): + """Add with insufficient args shows usage.""" + context.args = ["add", "my-job", "0", "9"] + await schedule_command(update, context) + update.message.reply_html.assert_called_once() + assert "Usage" in update.message.reply_html.call_args[0][0] + + +async def test_add_failure(update, context, mock_scheduler): + """Add handles scheduler errors gracefully.""" + mock_scheduler.add_job.side_effect = ValueError("Invalid cron expression") + context.args = ["add", "bad-job", "x", "y", "z", "a", "b", "Do", "something"] + await schedule_command(update, context) + update.message.reply_text.assert_called_once() + assert "Failed" in update.message.reply_text.call_args[0][0] + + +async def test_remove(update, context, mock_scheduler): + """Remove calls scheduler.remove_job.""" + context.args = ["remove", "job-123"] + await schedule_command(update, context) + mock_scheduler.remove_job.assert_called_once_with("job-123") + html = update.message.reply_html.call_args[0][0] + assert "job-123" in html + assert "removed" in html + + +async def test_remove_no_id(update, context): + """Remove without job_id shows usage.""" + context.args = ["remove"] + await schedule_command(update, context) + update.message.reply_text.assert_called_once() + assert "Usage" in update.message.reply_text.call_args[0][0] + + +async def test_pause_success(update, context, mock_scheduler): + """Pause calls scheduler.pause_job.""" + context.args = ["pause", "job-123"] + await schedule_command(update, context) + mock_scheduler.pause_job.assert_called_once_with("job-123") + html = update.message.reply_html.call_args[0][0] + assert "paused" in html + + +async def test_pause_not_found(update, context, mock_scheduler): + """Pause with unknown job_id shows not found.""" + mock_scheduler.pause_job.return_value = False + context.args = 
["pause", "unknown-id"] + await schedule_command(update, context) + update.message.reply_text.assert_called_once() + assert "not found" in update.message.reply_text.call_args[0][0] + + +async def test_pause_no_id(update, context): + """Pause without job_id shows usage.""" + context.args = ["pause"] + await schedule_command(update, context) + update.message.reply_text.assert_called_once() + assert "Usage" in update.message.reply_text.call_args[0][0] + + +async def test_resume_success(update, context, mock_scheduler): + """Resume calls scheduler.resume_job.""" + context.args = ["resume", "job-123"] + await schedule_command(update, context) + mock_scheduler.resume_job.assert_called_once_with("job-123") + html = update.message.reply_html.call_args[0][0] + assert "resumed" in html + + +async def test_resume_not_found(update, context, mock_scheduler): + """Resume with unknown job_id shows failure message.""" + mock_scheduler.resume_job.return_value = False + context.args = ["resume", "unknown-id"] + await schedule_command(update, context) + update.message.reply_text.assert_called_once() + assert "not found" in update.message.reply_text.call_args[0][0] + + +async def test_resume_no_id(update, context): + """Resume without job_id shows usage.""" + context.args = ["resume"] + await schedule_command(update, context) + update.message.reply_text.assert_called_once() + assert "Usage" in update.message.reply_text.call_args[0][0] From a19ec877e97b78c86341abb1eb8cb93a76dc1550 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20D=C3=ADez=20Arias?= Date: Sun, 15 Mar 2026 10:49:35 +0100 Subject: [PATCH 08/11] docs: document /schedule command Update command lists in README and CLAUDE.md, expand the Job Scheduler section in docs/setup.md with usage examples, and add CHANGELOG entry. 
--- CHANGELOG.md | 3 +++ CLAUDE.md | 2 +- README.md | 8 ++++---- docs/setup.md | 18 +++++++++++++++++- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96760404a..b348f675f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- **`/schedule` command**: Create, list, pause, resume, and remove scheduled jobs directly from Telegram. Auto-populates chat, directory, and user from context. Requires `ENABLE_SCHEDULER=true` (#150) + ## [1.6.0] - 2026-03-30 ### Added diff --git a/CLAUDE.md b/CLAUDE.md index d162b829f..f0636e7e8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -181,7 +181,7 @@ All datetimes use timezone-aware UTC: `datetime.now(UTC)` (not `datetime.utcnow( ### Agentic mode -Agentic mode commands: `/start`, `/new`, `/status`, `/verbose`, `/repo`. If `ENABLE_PROJECT_THREADS=true`: `/sync_threads`. To add a new command: +Agentic mode commands: `/start`, `/new`, `/status`, `/verbose`, `/repo`. If `ENABLE_PROJECT_THREADS=true`: `/sync_threads`. If `ENABLE_SCHEDULER=true`: `/schedule`. To add a new command: 1. Add handler function in `src/bot/orchestrator.py` 2. Register in `MessageOrchestrator._register_agentic_handlers()` diff --git a/README.md b/README.md index e30bb05be..ce2a1297a 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ The bot supports two interaction modes: The default conversational mode. Just talk to Claude naturally -- no special commands required. **Commands:** `/start`, `/new`, `/status`, `/verbose`, `/repo` -If `ENABLE_PROJECT_THREADS=true`: `/sync_threads` +If `ENABLE_PROJECT_THREADS=true`: `/sync_threads` | If `ENABLE_SCHEDULER=true`: `/schedule` ``` You: What files are in this project? 
@@ -157,8 +157,8 @@ Use `/repo` to list cloned repos in your workspace, or `/repo ` to switch Set `AGENTIC_MODE=false` to enable the full 13-command terminal-like interface with directory navigation, inline keyboards, quick actions, git integration, and session export. -**Commands:** `/start`, `/help`, `/new`, `/continue`, `/end`, `/status`, `/cd`, `/ls`, `/pwd`, `/projects`, `/export`, `/actions`, `/git` -If `ENABLE_PROJECT_THREADS=true`: `/sync_threads` +**Commands:** `/start`, `/help`, `/new`, `/continue`, `/end`, `/status`, `/cd`, `/ls`, `/pwd`, `/projects`, `/export`, `/actions`, `/git` +If `ENABLE_PROJECT_THREADS=true`: `/sync_threads` | If `ENABLE_SCHEDULER=true`: `/schedule` ``` You: /cd my-web-app @@ -176,7 +176,7 @@ Bot: [Run Tests] [Install Deps] [Format Code] [Run Linter] Beyond direct chat, the bot can respond to external triggers: - **Webhooks** -- Receive GitHub events (push, PR, issues) and route them through Claude for automated summaries or code review -- **Scheduler** -- Run recurring Claude tasks on a cron schedule (e.g., daily code health checks) +- **Scheduler** -- Run recurring Claude tasks on a cron schedule (e.g., daily code health checks). Manage jobs directly from Telegram with `/schedule` - **Notifications** -- Deliver agent responses to configured Telegram chats Enable with `ENABLE_API_SERVER=true` and `ENABLE_SCHEDULER=true`. See [docs/setup.md](docs/setup.md) for configuration. diff --git a/docs/setup.md b/docs/setup.md index acb7f906d..23c201497 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -173,7 +173,23 @@ ENABLE_SCHEDULER=true NOTIFICATION_CHAT_IDS=123456789 # Where to deliver results ``` -Jobs are managed programmatically and persist in the SQLite database. +Jobs persist in the SQLite database and survive bot restarts. 
Manage them directly from Telegram with the `/schedule` command: + +``` +/schedule list # List all jobs (active + paused) +/schedule add <name> <cron> <prompt> +/schedule remove <job_id> # Remove a job +/schedule pause <job_id> # Pause without deleting +/schedule resume <job_id> # Resume a paused job +``` + +Example -- create a job that runs every weekday at 9 AM: + +``` +/schedule add daily-report 0 9 * * 1-5 Summarize yesterday's git commits +``` + +The `chat_id`, `working_directory`, and `created_by` fields are auto-populated from your current Telegram context. Results are delivered to the chat where the job was created. ### Voice Message Transcription From d9ad1f4baba488729c0be2be5e853e9a4462fdcc Mon Sep 17 00:00:00 2001 From: Talla86 Date: Fri, 20 Mar 2026 18:25:11 +0100 Subject: [PATCH 09/11] feat: add /model command for runtime model and effort switching Add /model command with inline keyboard UI for switching between Opus/Sonnet/Haiku models and effort levels (low/medium/high/max) at runtime. Model changes force a new session since the CLI doesn't support model switching on resumed sessions.
- Effort levels are model-aware: Haiku has none, Sonnet excludes "max", Opus supports all including "max" - Override is per-user via context.user_data (in-memory, resets on bot restart) - Threaded through all run_command call sites (orchestrator, classic message handler) into the SDK layer - Registered in both agentic and classic handler modes - Added to bot command menu and /help text - 17 new tests covering keyboard display, model/effort selection, label formatting, and effort-per-model configuration Closes #138 --- src/bot/handlers/callback.py | 7 + src/bot/handlers/command.py | 162 ++++++++++++++- src/bot/handlers/message.py | 8 + src/bot/orchestrator.py | 18 ++ src/claude/facade.py | 12 ++ src/claude/sdk_integration.py | 12 +- tests/unit/test_bot/test_model_command.py | 242 ++++++++++++++++++++++ tests/unit/test_orchestrator.py | 30 +-- 8 files changed, 475 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_bot/test_model_command.py diff --git a/src/bot/handlers/callback.py b/src/bot/handlers/callback.py index 66dd660c4..1a985e81b 100644 --- a/src/bot/handlers/callback.py +++ b/src/bot/handlers/callback.py @@ -57,6 +57,11 @@ async def handle_callback_query( action, param = data, None # Route to appropriate handler + from .command import _handle_model_selection + + async def _model_effort_handler(query, param, context): + await _handle_model_selection(query, f"{action}:{param}", context) + handlers = { "cd": handle_cd_callback, "action": handle_action_callback, @@ -66,6 +71,8 @@ async def handle_callback_query( "conversation": handle_conversation_callback, "git": handle_git_callback, "export": handle_export_callback, + "model": _model_effort_handler, + "effort": _model_effort_handler, } handler = handlers.get(action) diff --git a/src/bot/handlers/command.py b/src/bot/handlers/command.py index 651a08f8c..7bdab52cc 100644 --- a/src/bot/handlers/command.py +++ b/src/bot/handlers/command.py @@ -174,7 +174,8 @@ async def help_command(update: Update, 
context: ContextTypes.DEFAULT_TYPE) -> No "• /status - Show session and usage status\n" "• /export - Export session history\n" "• /actions - Show context-aware quick actions\n" - "• /git - Git repository information\n\n" + "• /git - Git repository information\n" + "• /model [name] - View or switch Claude model\n\n" "Session Behavior:\n" "• Sessions are automatically maintained per project directory\n" "• Switching directories with /cd resumes the session for that project\n" @@ -1232,6 +1233,165 @@ async def git_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non logger.error("Error in git_command", error=str(e), user_id=user_id) +# Model IDs mapped to user-friendly labels +_MODELS = { + "opus": "claude-opus-4-6", + "sonnet": "claude-sonnet-4-6", + "haiku": "claude-haiku-4-5-20251001", +} + +# Effort levels per model (Haiku doesn't support effort; "max" is Opus-only) +_EFFORT_BY_MODEL = { + "opus": ["low", "medium", "high", "max"], + "sonnet": ["low", "medium", "high"], + "haiku": [], +} + + +def _current_model_label(context: ContextTypes.DEFAULT_TYPE) -> str: + """Return a human-friendly label for the active model + effort.""" + override = context.user_data.get("model_override") + effort = context.user_data.get("effort_override") + # Reverse-map model ID to short name + model_id = override or "" + label = model_id + for short, full in _MODELS.items(): + if full == model_id: + label = short.capitalize() + break + if not override: + label = "Default" + parts = [label] + if effort: + parts.append(f"effort={effort}") + return " | ".join(parts) + + +async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle /model command - show model selection keyboard.""" + current = _current_model_label(context) + + keyboard = [ + [ + InlineKeyboardButton("Opus", callback_data="model:opus"), + InlineKeyboardButton("Sonnet", callback_data="model:sonnet"), + InlineKeyboardButton("Haiku", callback_data="model:haiku"), + ], + 
[InlineKeyboardButton("Reset to default", callback_data="model:default")], + ] + + await update.message.reply_text( + f"🤖 Current: {escape_html(current)}\n\n" + "Choose a model:\n" + "⚠️ Switching will start a new session.", + parse_mode="HTML", + reply_markup=InlineKeyboardMarkup(keyboard), + ) + + +async def _handle_model_selection(query, data: str, context) -> None: + """Shared logic for model/effort selection (used by both callback routes).""" + if data.startswith("model:"): + choice = data.split(":", 1)[1] + + if choice == "default": + context.user_data.pop("model_override", None) + context.user_data.pop("effort_override", None) + context.user_data["force_new_session"] = True + await query.edit_message_text( + "🤖 Model and effort reset to server defaults.\n" + "Next message starts a fresh session.", + parse_mode="HTML", + ) + logger.info("Model override cleared", user_id=query.from_user.id) + return + + model_id = _MODELS.get(choice) + if not model_id: + await query.edit_message_text("Unknown model.") + return + + context.user_data["model_override"] = model_id + # Clear stale effort when switching models + context.user_data.pop("effort_override", None) + # Force new session so the model change takes effect immediately + context.user_data["force_new_session"] = True + + logger.info( + "Model override set", + user_id=query.from_user.id, + model=model_id, + ) + + # Show effort level selection (if supported by this model) + effort_levels = _EFFORT_BY_MODEL.get(choice, []) + if not effort_levels: + # Model doesn't support effort (e.g. 
Haiku) + current = _current_model_label(context) + await query.edit_message_text( + f"🤖 {escape_html(current)} — ready.\n" + "New session will start with your next message.", + parse_mode="HTML", + ) + return + + rows = [] + row = [] + for level in effort_levels: + row.append( + InlineKeyboardButton(level.capitalize(), callback_data=f"effort:{level}") + ) + if len(row) == 2: + rows.append(row) + row = [] + if row: + rows.append(row) + rows.append( + [InlineKeyboardButton("Skip (keep current)", callback_data="effort:skip")] + ) + + await query.edit_message_text( + f"🤖 Model set to {escape_html(choice.capitalize())}.\n\n" + "Choose effort level:", + parse_mode="HTML", + reply_markup=InlineKeyboardMarkup(rows), + ) + + elif data.startswith("effort:"): + level = data.split(":", 1)[1] + + if level == "skip": + current = _current_model_label(context) + await query.edit_message_text( + f"🤖 {escape_html(current)} — ready.\n" + "New session will start with your next message.", + parse_mode="HTML", + ) + return + + all_effort_levels = {"low", "medium", "high", "max"} + if level in all_effort_levels: + context.user_data["effort_override"] = level + current = _current_model_label(context) + await query.edit_message_text( + f"🤖 {escape_html(current)} — ready.\n" + "New session will start with your next message.", + parse_mode="HTML", + ) + logger.info( + "Effort override set", + user_id=query.from_user.id, + effort=level, + ) + + +async def model_callback(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle model and effort selection callbacks (agentic mode route).""" + query = update.callback_query + await query.answer() + await _handle_model_selection(query, query.data, context) + + async def restart_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle /restart command - gracefully restart the bot process. 
diff --git a/src/bot/handlers/message.py b/src/bot/handlers/message.py index bbd240840..f0d87b7ff 100644 --- a/src/bot/handlers/message.py +++ b/src/bot/handlers/message.py @@ -393,6 +393,8 @@ async def stream_handler(update_obj): session_id=session_id, on_stream=stream_handler, force_new=force_new, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # New session created successfully — clear the one-shot flag @@ -818,6 +820,8 @@ async def handle_document(update: Update, context: ContextTypes.DEFAULT_TYPE) -> working_directory=current_dir, user_id=user_id, session_id=session_id, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # Update session ID @@ -945,6 +949,8 @@ async def handle_photo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No working_directory=current_dir, user_id=user_id, session_id=session_id, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # Update session ID @@ -1073,6 +1079,8 @@ async def handle_voice(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No working_directory=current_dir, user_id=user_id, session_id=session_id, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) context.user_data["claude_session_id"] = claude_response.session_id diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index 902b7b448..4d020ecb4 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -327,6 +327,7 @@ def _register_agentic_handlers(self, app: Application) -> None: ("status", self.agentic_status), ("verbose", self.agentic_verbose), ("repo", self.agentic_repo), + ("model", command.model_command), ("restart", command.restart_command), ] if self.settings.enable_project_threads: @@ -392,6 +393,14 @@ def _register_agentic_handlers(self, app: Application) 
-> None: ) ) + # Model/effort selection callbacks + app.add_handler( + CallbackQueryHandler( + self._inject_deps(command.model_callback), + pattern=r"^(model|effort):", + ) + ) + # Only cd: callbacks (for project selection), scoped by pattern app.add_handler( CallbackQueryHandler( @@ -420,6 +429,7 @@ def _register_classic_handlers(self, app: Application) -> None: ("export", command.export_session), ("actions", command.quick_actions), ("git", command.git_command), + ("model", command.model_command), ("restart", command.restart_command), ] if self.settings.enable_project_threads: @@ -468,6 +478,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] BotCommand("status", "Show session status"), BotCommand("verbose", "Set output verbosity (0/1/2)"), BotCommand("repo", "List repos / switch workspace"), + BotCommand("model", "Switch Claude model and effort"), BotCommand("restart", "Restart the bot"), ] if self.settings.enable_project_threads: @@ -490,6 +501,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] BotCommand("export", "Export current session"), BotCommand("actions", "Show quick actions"), BotCommand("git", "Git repository commands"), + BotCommand("model", "Switch Claude model and effort"), BotCommand("restart", "Restart the bot"), ] if self.settings.enable_project_threads: @@ -1026,6 +1038,8 @@ async def agentic_text( on_stream=on_stream, force_new=force_new, interrupt_event=interrupt_event, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # New session created successfully — clear the one-shot flag @@ -1276,6 +1290,8 @@ async def agentic_document( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) if force_new: @@ -1486,6 +1502,8 @@ async def _handle_agentic_media_message( on_stream=on_stream, force_new=force_new, 
images=images, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) finally: heartbeat.cancel() diff --git a/src/claude/facade.py b/src/claude/facade.py index b1cafba49..03b416c6a 100644 --- a/src/claude/facade.py +++ b/src/claude/facade.py @@ -40,6 +40,8 @@ async def run_command( force_new: bool = False, interrupt_event: Optional["asyncio.Event"] = None, images: Optional[List[Dict[str, str]]] = None, + model_override: Optional[str] = None, + effort_override: Optional[str] = None, ) -> ClaudeResponse: """Run Claude Code command with full integration.""" logger.info( @@ -49,6 +51,8 @@ async def run_command( session_id=session_id, prompt_length=len(prompt), force_new=force_new, + model_override=model_override, + effort_override=effort_override, ) # If no session_id provided, try to find an existing session for this @@ -90,6 +94,8 @@ async def run_command( stream_callback=on_stream, interrupt_event=interrupt_event, images=images, + model_override=model_override, + effort_override=effort_override, ) except Exception as resume_error: # If resume failed (e.g., session expired/missing on Claude's side), @@ -116,6 +122,8 @@ async def run_command( stream_callback=on_stream, interrupt_event=interrupt_event, images=images, + model_override=model_override, + effort_override=effort_override, ) else: raise @@ -161,6 +169,8 @@ async def _execute( stream_callback: Optional[Callable] = None, interrupt_event: Optional[asyncio.Event] = None, images: Optional[List[Dict[str, str]]] = None, + model_override: Optional[str] = None, + effort_override: Optional[str] = None, ) -> ClaudeResponse: """Execute command via SDK.""" return await self.sdk_manager.execute_command( @@ -171,6 +181,8 @@ async def _execute( stream_callback=stream_callback, interrupt_event=interrupt_event, images=images, + model_override=model_override, + effort_override=effort_override, ) async def _find_resumable_session( diff --git 
a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index 5a95f16da..0bf0f0dd2 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -277,6 +277,8 @@ async def execute_command( stream_callback: Optional[Callable[[StreamUpdate], None]] = None, interrupt_event: Optional[asyncio.Event] = None, images: Optional[List[Dict[str, str]]] = None, + model_override: Optional[str] = None, + effort_override: Optional[str] = None, ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() @@ -321,7 +323,7 @@ def _stderr_callback(line: str) -> None: # Build Claude Agent options options = ClaudeAgentOptions( max_turns=self.config.claude_max_turns, - model=self.config.claude_model or None, + model=model_override or self.config.claude_model or None, max_budget_usd=self.config.claude_max_cost_per_request, cwd=str(working_directory), allowed_tools=sdk_allowed_tools, @@ -334,10 +336,18 @@ def _stderr_callback(line: str) -> None: "excludedCommands": self.config.sandbox_excluded_commands or [], }, system_prompt=base_prompt, + effort=effort_override, setting_sources=["project"], stderr=_stderr_callback, ) + if model_override or effort_override: + logger.info( + "Runtime model/effort override active", + model_override=model_override, + effort_override=effort_override, + ) + # Pass MCP server configuration if enabled if self.config.enable_mcp and self.config.mcp_config_path: options.mcp_servers = self._load_mcp_config(self.config.mcp_config_path) diff --git a/tests/unit/test_bot/test_model_command.py b/tests/unit/test_bot/test_model_command.py new file mode 100644 index 000000000..c6ad0df0a --- /dev/null +++ b/tests/unit/test_bot/test_model_command.py @@ -0,0 +1,242 @@ +"""Tests for the /model command — runtime model and effort switching. 
+ +Covers: +- /model shows inline keyboard with model choices +- Model selection sets model_override and force_new_session +- Effort selection sets effort_override +- "default" clears all overrides +- Haiku skips effort keyboard (not supported) +- Opus shows "max" effort, Sonnet does not +- _current_model_label returns correct labels +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from telegram import InlineKeyboardMarkup + +from src.bot.handlers.command import ( + _EFFORT_BY_MODEL, + _MODELS, + _current_model_label, + _handle_model_selection, + model_command, +) +from src.config.settings import Settings + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def settings(tmp_path): + return Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + ) + + +@pytest.fixture +def context(settings): + ctx = MagicMock() + ctx.user_data = {} + ctx.bot_data = {"settings": settings} + ctx.args = None + return ctx + + +@pytest.fixture +def update(context): + upd = MagicMock() + upd.message = AsyncMock() + upd.effective_user.id = 12345 + return upd + + +@pytest.fixture +def callback_query(): + query = MagicMock() + query.answer = AsyncMock() + query.edit_message_text = AsyncMock() + query.from_user.id = 12345 + return query + + +# --------------------------------------------------------------------------- +# /model command (keyboard display) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_model_command_shows_keyboard(update, context): + """Verify /model sends an inline keyboard with model choices.""" + await model_command(update, context) + + update.message.reply_text.assert_called_once() + call_kwargs = update.message.reply_text.call_args + assert 
isinstance(call_kwargs.kwargs["reply_markup"], InlineKeyboardMarkup) + # Should contain the session warning + assert "new session" in call_kwargs.args[0].lower() + + +@pytest.mark.asyncio +async def test_model_command_shows_current_override(update, context): + """When an override is active, /model should show it.""" + context.user_data["model_override"] = _MODELS["sonnet"] + await model_command(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "Sonnet" in text + + +# --------------------------------------------------------------------------- +# Model selection callback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_select_opus_sets_override(callback_query, context): + """Selecting Opus sets model_override and force_new_session.""" + await _handle_model_selection(callback_query, "model:opus", context) + + assert context.user_data["model_override"] == _MODELS["opus"] + assert context.user_data["force_new_session"] is True + + +@pytest.mark.asyncio +async def test_select_sonnet_sets_override(callback_query, context): + """Selecting Sonnet sets the correct model ID.""" + await _handle_model_selection(callback_query, "model:sonnet", context) + + assert context.user_data["model_override"] == _MODELS["sonnet"] + + +@pytest.mark.asyncio +async def test_select_haiku_skips_effort(callback_query, context): + """Selecting Haiku should not show effort keyboard (not supported).""" + await _handle_model_selection(callback_query, "model:haiku", context) + + assert context.user_data["model_override"] == _MODELS["haiku"] + # Final message, no reply_markup (no effort keyboard) + call_kwargs = callback_query.edit_message_text.call_args + assert "reply_markup" not in call_kwargs.kwargs or call_kwargs.kwargs.get("reply_markup") is None + assert "ready" in call_kwargs.args[0].lower() + + +@pytest.mark.asyncio +async def test_select_opus_shows_effort_with_max(callback_query, context): 
+ """Opus should show effort keyboard including 'max'.""" + await _handle_model_selection(callback_query, "model:opus", context) + + call_kwargs = callback_query.edit_message_text.call_args + markup = call_kwargs.kwargs.get("reply_markup") + assert markup is not None + # Flatten button labels + labels = [btn.text for row in markup.inline_keyboard for btn in row] + assert "Max" in labels + assert "effort" in call_kwargs.args[0].lower() + + +@pytest.mark.asyncio +async def test_select_sonnet_shows_effort_without_max(callback_query, context): + """Sonnet should show effort keyboard without 'max'.""" + await _handle_model_selection(callback_query, "model:sonnet", context) + + call_kwargs = callback_query.edit_message_text.call_args + markup = call_kwargs.kwargs.get("reply_markup") + assert markup is not None + labels = [btn.text for row in markup.inline_keyboard for btn in row] + assert "Max" not in labels + assert "High" in labels + + +@pytest.mark.asyncio +async def test_default_clears_overrides(callback_query, context): + """Selecting 'default' clears model, effort, and forces new session.""" + context.user_data["model_override"] = _MODELS["opus"] + context.user_data["effort_override"] = "high" + + await _handle_model_selection(callback_query, "model:default", context) + + assert "model_override" not in context.user_data + assert "effort_override" not in context.user_data + assert context.user_data["force_new_session"] is True + + +# --------------------------------------------------------------------------- +# Effort selection callback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_effort_sets_override(callback_query, context): + """Selecting an effort level stores it in user_data.""" + context.user_data["model_override"] = _MODELS["opus"] + + await _handle_model_selection(callback_query, "effort:high", context) + + assert context.user_data["effort_override"] == "high" + + 
+@pytest.mark.asyncio +async def test_effort_skip_keeps_existing(callback_query, context): + """Selecting 'skip' should not set effort_override.""" + context.user_data["model_override"] = _MODELS["sonnet"] + + await _handle_model_selection(callback_query, "effort:skip", context) + + assert "effort_override" not in context.user_data + + +@pytest.mark.asyncio +async def test_model_switch_clears_stale_effort(callback_query, context): + """Switching models should clear any previous effort override.""" + context.user_data["effort_override"] = "high" + + await _handle_model_selection(callback_query, "model:haiku", context) + + assert "effort_override" not in context.user_data + + +# --------------------------------------------------------------------------- +# Label helper +# --------------------------------------------------------------------------- + + +def test_label_default(): + ctx = MagicMock() + ctx.user_data = {} + assert _current_model_label(ctx) == "Default" + + +def test_label_with_model_and_effort(): + ctx = MagicMock() + ctx.user_data = {"model_override": _MODELS["sonnet"], "effort_override": "medium"} + assert _current_model_label(ctx) == "Sonnet | effort=medium" + + +def test_label_model_only(): + ctx = MagicMock() + ctx.user_data = {"model_override": _MODELS["opus"]} + assert _current_model_label(ctx) == "Opus" + + +# --------------------------------------------------------------------------- +# Effort level configuration +# --------------------------------------------------------------------------- + + +def test_haiku_has_no_effort_levels(): + assert _EFFORT_BY_MODEL["haiku"] == [] + + +def test_sonnet_has_no_max(): + assert "max" not in _EFFORT_BY_MODEL["sonnet"] + assert "high" in _EFFORT_BY_MODEL["sonnet"] + + +def test_opus_has_max(): + assert "max" in _EFFORT_BY_MODEL["opus"] diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index ce5e419e9..9db8c8efe 100644 --- a/tests/unit/test_orchestrator.py +++ 
b/tests/unit/test_orchestrator.py @@ -82,8 +82,8 @@ def deps(): } -def test_agentic_registers_6_commands(agentic_settings, deps): - """Agentic mode registers start, new, status, verbose, repo, restart commands.""" +def test_agentic_registers_7_commands(agentic_settings, deps): + """Agentic mode registers start, new, status, verbose, repo, model, restart commands.""" orchestrator = MessageOrchestrator(agentic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -100,17 +100,18 @@ def test_agentic_registers_6_commands(agentic_settings, deps): ] commands = [h[0][0].commands for h in cmd_handlers] - assert len(cmd_handlers) == 6 + assert len(cmd_handlers) == 7 assert frozenset({"start"}) in commands assert frozenset({"new"}) in commands assert frozenset({"status"}) in commands assert frozenset({"verbose"}) in commands assert frozenset({"repo"}) in commands + assert frozenset({"model"}) in commands assert frozenset({"restart"}) in commands -def test_classic_registers_14_commands(classic_settings, deps): - """Classic mode registers all 14 commands.""" +def test_classic_registers_15_commands(classic_settings, deps): + """Classic mode registers all 15 commands.""" orchestrator = MessageOrchestrator(classic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -125,7 +126,7 @@ def test_classic_registers_14_commands(classic_settings, deps): if isinstance(call[0][0], CommandHandler) ] - assert len(cmd_handlers) == 14 + assert len(cmd_handlers) == 15 def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): @@ -151,30 +152,31 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): # 5 message handlers (text, document, photo, voice, unknown commands passthrough) assert len(msg_handlers) == 5 - # 2 callback handlers (stop: + cd:) - assert len(cb_handlers) == 2 + # 3 callback handlers (stop: + model/effort: + cd:) + assert len(cb_handlers) == 3 async def test_agentic_bot_commands(agentic_settings, 
deps): - """Agentic mode returns 6 bot commands.""" + """Agentic mode returns 7 bot commands.""" orchestrator = MessageOrchestrator(agentic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 6 + assert len(commands) == 7 cmd_names = [c.command for c in commands] - assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"] + assert cmd_names == ["start", "new", "status", "verbose", "repo", "model", "restart"] async def test_classic_bot_commands(classic_settings, deps): - """Classic mode returns 14 bot commands.""" + """Classic mode returns 15 bot commands.""" orchestrator = MessageOrchestrator(classic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 14 + assert len(commands) == 15 cmd_names = [c.command for c in commands] assert "start" in cmd_names assert "help" in cmd_names assert "git" in cmd_names + assert "model" in cmd_names assert "restart" in cmd_names @@ -338,7 +340,7 @@ async def test_agentic_callback_scoped_to_cd_pattern(agentic_settings, deps): if isinstance(call[0][0], CallbackQueryHandler) ] - assert len(cb_handlers) == 2 + assert len(cb_handlers) == 3 # Find the cd: handler by pattern cd_handler = [h for h in cb_handlers if h.pattern and h.pattern.match("cd:x")] assert len(cd_handler) == 1 From 7f32e0796449e6b6ee81174acd942a564de4cfbc Mon Sep 17 00:00:00 2001 From: Talla86 Date: Fri, 20 Mar 2026 18:30:29 +0100 Subject: [PATCH 10/11] fix: show server default model in /model display Instead of just "Default", show "Default (claude-sonnet-4-6)" or "Default (CLI default)" so users can verify what model is active after resetting. 
--- src/bot/handlers/command.py | 4 +++- tests/unit/test_bot/test_model_command.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/bot/handlers/command.py b/src/bot/handlers/command.py index 7bdab52cc..18f164285 100644 --- a/src/bot/handlers/command.py +++ b/src/bot/handlers/command.py @@ -1260,7 +1260,9 @@ def _current_model_label(context: ContextTypes.DEFAULT_TYPE) -> str: label = short.capitalize() break if not override: - label = "Default" + settings = context.bot_data.get("settings") + server_model = getattr(settings, "claude_model", None) if settings else None + label = f"Default ({server_model or 'CLI default'})" parts = [label] if effort: parts.append(f"effort={effort}") diff --git a/tests/unit/test_bot/test_model_command.py b/tests/unit/test_bot/test_model_command.py index c6ad0df0a..538f59946 100644 --- a/tests/unit/test_bot/test_model_command.py +++ b/tests/unit/test_bot/test_model_command.py @@ -206,10 +206,18 @@ async def test_model_switch_clears_stale_effort(callback_query, context): # --------------------------------------------------------------------------- -def test_label_default(): +def test_label_default_no_settings(): ctx = MagicMock() ctx.user_data = {} - assert _current_model_label(ctx) == "Default" + ctx.bot_data = {} + assert _current_model_label(ctx) == "Default (CLI default)" + + +def test_label_default_with_server_model(): + ctx = MagicMock() + ctx.user_data = {} + ctx.bot_data = {"settings": MagicMock(claude_model="claude-sonnet-4-6")} + assert _current_model_label(ctx) == "Default (claude-sonnet-4-6)" def test_label_with_model_and_effort(): From 1937557877abc308c99bb24b4868aa09367df26d Mon Sep 17 00:00:00 2001 From: Talla86 Date: Wed, 1 Apr 2026 12:26:37 +0200 Subject: [PATCH 11/11] =?UTF-8?q?fix:=20address=20reviewer=20feedback=20?= =?UTF-8?q?=E2=80=94=20type=20annotations,=20closure,=20model=20aliases?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 
callback.py: replace shared _model_effort_handler closure with two explicit lambdas (model:/effort:) — eliminates outer-scope capture; move import to module level - command.py: drop _MODELS dict with hardcoded version IDs; use _MODEL_FAMILIES list of short CLI aliases ("opus"/"sonnet"/"haiku") which the CLI resolves to current latest automatically - command.py: add CallbackQuery + ContextTypes type annotations to _handle_model_selection (fixes mypy strict mode) - command.py: simplify _current_model_label (no reverse-map needed) - command.py: add PR #165 compatibility comment on force_new_session - tests: update imports/assertions for _MODEL_FAMILIES; add 3 tests: closure regression guard, effort: prefix isolation, force_new_session Co-Authored-By: Claude Sonnet 4.6 --- src/bot/handlers/callback.py | 10 +--- src/bot/handlers/command.py | 48 ++++++++-------- tests/unit/test_bot/test_model_command.py | 70 +++++++++++++++++++---- 3 files changed, 84 insertions(+), 44 deletions(-) diff --git a/src/bot/handlers/callback.py b/src/bot/handlers/callback.py index 1a985e81b..e7ed002de 100644 --- a/src/bot/handlers/callback.py +++ b/src/bot/handlers/callback.py @@ -12,6 +12,7 @@ from ...security.audit import AuditLogger from ...security.validators import SecurityValidator from ..utils.html_format import escape_html +from .command import _handle_model_selection logger = structlog.get_logger() @@ -57,11 +58,6 @@ async def handle_callback_query( action, param = data, None # Route to appropriate handler - from .command import _handle_model_selection - - async def _model_effort_handler(query, param, context): - await _handle_model_selection(query, f"{action}:{param}", context) - handlers = { "cd": handle_cd_callback, "action": handle_action_callback, @@ -71,8 +67,8 @@ async def _model_effort_handler(query, param, context): "conversation": handle_conversation_callback, "git": handle_git_callback, "export": handle_export_callback, - "model": _model_effort_handler, - "effort": 
_model_effort_handler, + "model": lambda q, p, ctx: _handle_model_selection(q, f"model:{p}", ctx), + "effort": lambda q, p, ctx: _handle_model_selection(q, f"effort:{p}", ctx), } handler = handlers.get(action) diff --git a/src/bot/handlers/command.py b/src/bot/handlers/command.py index 18f164285..668e023cd 100644 --- a/src/bot/handlers/command.py +++ b/src/bot/handlers/command.py @@ -7,7 +7,7 @@ from typing import Optional import structlog -from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update +from telegram import CallbackQuery, InlineKeyboardButton, InlineKeyboardMarkup, Update from telegram.ext import ContextTypes from ...claude.facade import ClaudeIntegration @@ -1233,14 +1233,13 @@ async def git_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non logger.error("Error in git_command", error=str(e), user_id=user_id) -# Model IDs mapped to user-friendly labels -_MODELS = { - "opus": "claude-opus-4-6", - "sonnet": "claude-sonnet-4-6", - "haiku": "claude-haiku-4-5-20251001", -} +# Short CLI aliases passed directly to the Claude CLI, which resolves them to +# the current latest model of each family. No version numbers to maintain here. +# See: https://docs.anthropic.com/en/docs/about-claude/models/overview +_MODEL_FAMILIES = ["opus", "sonnet", "haiku"] -# Effort levels per model (Haiku doesn't support effort; "max" is Opus-only) +# Effort levels per model family. Haiku has none; "max" is Opus-only. +# Update here if a future model's effort support changes. 
_EFFORT_BY_MODEL = { "opus": ["low", "medium", "high", "max"], "sonnet": ["low", "medium", "high"], @@ -1250,23 +1249,15 @@ async def git_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non def _current_model_label(context: ContextTypes.DEFAULT_TYPE) -> str: """Return a human-friendly label for the active model + effort.""" - override = context.user_data.get("model_override") + override = context.user_data.get("model_override") # "opus", "sonnet", "haiku", or None effort = context.user_data.get("effort_override") - # Reverse-map model ID to short name - model_id = override or "" - label = model_id - for short, full in _MODELS.items(): - if full == model_id: - label = short.capitalize() - break if not override: settings = context.bot_data.get("settings") server_model = getattr(settings, "claude_model", None) if settings else None label = f"Default ({server_model or 'CLI default'})" - parts = [label] - if effort: - parts.append(f"effort={effort}") - return " | ".join(parts) + else: + label = override.capitalize() + return f"{label} | effort={effort}" if effort else label async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -1291,7 +1282,11 @@ async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N ) -async def _handle_model_selection(query, data: str, context) -> None: +async def _handle_model_selection( + query: CallbackQuery, + data: str, + context: ContextTypes.DEFAULT_TYPE, +) -> None: """Shared logic for model/effort selection (used by both callback routes).""" if data.startswith("model:"): choice = data.split(":", 1)[1] @@ -1299,6 +1294,7 @@ async def _handle_model_selection(query, data: str, context) -> None: if choice == "default": context.user_data.pop("model_override", None) context.user_data.pop("effort_override", None) + # Note: if PR #165 merges first, change this to context.chat_data context.user_data["force_new_session"] = True await query.edit_message_text( "🤖 Model and effort 
reset to server defaults.\n" @@ -1308,21 +1304,23 @@ async def _handle_model_selection(query, data: str, context) -> None: logger.info("Model override cleared", user_id=query.from_user.id) return - model_id = _MODELS.get(choice) - if not model_id: + if choice not in _MODEL_FAMILIES: await query.edit_message_text("Unknown model.") return - context.user_data["model_override"] = model_id + # Store short CLI alias ("opus"/"sonnet"/"haiku") — the CLI resolves it + # to the current latest model, so no version numbers to maintain. + context.user_data["model_override"] = choice # Clear stale effort when switching models context.user_data.pop("effort_override", None) # Force new session so the model change takes effect immediately + # Note: if PR #165 merges first, change this to context.chat_data context.user_data["force_new_session"] = True logger.info( "Model override set", user_id=query.from_user.id, - model=model_id, + model=choice, ) # Show effort level selection (if supported by this model) diff --git a/tests/unit/test_bot/test_model_command.py b/tests/unit/test_bot/test_model_command.py index 538f59946..681669f99 100644 --- a/tests/unit/test_bot/test_model_command.py +++ b/tests/unit/test_bot/test_model_command.py @@ -8,6 +8,8 @@ - Haiku skips effort keyboard (not supported) - Opus shows "max" effort, Sonnet does not - _current_model_label returns correct labels +- Regression: model: callback data prefix is never rewritten as effort: +- force_new_session is always set on model change """ from unittest.mock import AsyncMock, MagicMock @@ -17,7 +19,7 @@ from src.bot.handlers.command import ( _EFFORT_BY_MODEL, - _MODELS, + _MODEL_FAMILIES, _current_model_label, _handle_model_selection, model_command, @@ -85,7 +87,7 @@ async def test_model_command_shows_keyboard(update, context): @pytest.mark.asyncio async def test_model_command_shows_current_override(update, context): """When an override is active, /model should show it.""" - context.user_data["model_override"] = 
_MODELS["sonnet"] + context.user_data["model_override"] = "sonnet" await model_command(update, context) text = update.message.reply_text.call_args.args[0] @@ -99,19 +101,19 @@ async def test_model_command_shows_current_override(update, context): @pytest.mark.asyncio async def test_select_opus_sets_override(callback_query, context): - """Selecting Opus sets model_override and force_new_session.""" + """Selecting Opus sets model_override to short alias and force_new_session.""" await _handle_model_selection(callback_query, "model:opus", context) - assert context.user_data["model_override"] == _MODELS["opus"] + assert context.user_data["model_override"] == "opus" assert context.user_data["force_new_session"] is True @pytest.mark.asyncio async def test_select_sonnet_sets_override(callback_query, context): - """Selecting Sonnet sets the correct model ID.""" + """Selecting Sonnet sets the correct short alias.""" await _handle_model_selection(callback_query, "model:sonnet", context) - assert context.user_data["model_override"] == _MODELS["sonnet"] + assert context.user_data["model_override"] == "sonnet" @pytest.mark.asyncio @@ -119,7 +121,7 @@ async def test_select_haiku_skips_effort(callback_query, context): """Selecting Haiku should not show effort keyboard (not supported).""" await _handle_model_selection(callback_query, "model:haiku", context) - assert context.user_data["model_override"] == _MODELS["haiku"] + assert context.user_data["model_override"] == "haiku" # Final message, no reply_markup (no effort keyboard) call_kwargs = callback_query.edit_message_text.call_args assert "reply_markup" not in call_kwargs.kwargs or call_kwargs.kwargs.get("reply_markup") is None @@ -156,7 +158,7 @@ async def test_select_sonnet_shows_effort_without_max(callback_query, context): @pytest.mark.asyncio async def test_default_clears_overrides(callback_query, context): """Selecting 'default' clears model, effort, and forces new session.""" - context.user_data["model_override"] = 
_MODELS["opus"] + context.user_data["model_override"] = "opus" context.user_data["effort_override"] = "high" await _handle_model_selection(callback_query, "model:default", context) @@ -174,7 +176,7 @@ async def test_default_clears_overrides(callback_query, context): @pytest.mark.asyncio async def test_effort_sets_override(callback_query, context): """Selecting an effort level stores it in user_data.""" - context.user_data["model_override"] = _MODELS["opus"] + context.user_data["model_override"] = "opus" await _handle_model_selection(callback_query, "effort:high", context) @@ -184,7 +186,7 @@ async def test_effort_sets_override(callback_query, context): @pytest.mark.asyncio async def test_effort_skip_keeps_existing(callback_query, context): """Selecting 'skip' should not set effort_override.""" - context.user_data["model_override"] = _MODELS["sonnet"] + context.user_data["model_override"] = "sonnet" await _handle_model_selection(callback_query, "effort:skip", context) @@ -201,6 +203,44 @@ async def test_model_switch_clears_stale_effort(callback_query, context): assert "effort_override" not in context.user_data +# --------------------------------------------------------------------------- +# Regression: callback data prefix integrity (closure-bug guard) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_model_callback_sets_model_not_effort(callback_query, context): + """model: callback must set model_override, not effort_override. + + Regression guard: a shared closure capturing 'action' from outer scope + could silently rewrite 'model:opus' as 'effort:opus'. This test would + have caught that. 
+ """ + await _handle_model_selection(callback_query, "model:sonnet", context) + + assert context.user_data.get("model_override") == "sonnet" + assert "effort_override" not in context.user_data + + +@pytest.mark.asyncio +async def test_effort_callback_sets_effort_not_model(callback_query, context): + """effort: callback must set effort_override and not overwrite model_override.""" + context.user_data["model_override"] = "sonnet" + + await _handle_model_selection(callback_query, "effort:high", context) + + assert context.user_data.get("effort_override") == "high" + assert context.user_data.get("model_override") == "sonnet" # unchanged + + +@pytest.mark.asyncio +async def test_force_new_session_set_on_model_switch(callback_query, context): + """force_new_session must be True after any model switch.""" + await _handle_model_selection(callback_query, "model:opus", context) + + assert context.user_data.get("force_new_session") is True + + # --------------------------------------------------------------------------- # Label helper # --------------------------------------------------------------------------- @@ -222,13 +262,13 @@ def test_label_default_with_server_model(): def test_label_with_model_and_effort(): ctx = MagicMock() - ctx.user_data = {"model_override": _MODELS["sonnet"], "effort_override": "medium"} + ctx.user_data = {"model_override": "sonnet", "effort_override": "medium"} assert _current_model_label(ctx) == "Sonnet | effort=medium" def test_label_model_only(): ctx = MagicMock() - ctx.user_data = {"model_override": _MODELS["opus"]} + ctx.user_data = {"model_override": "opus"} assert _current_model_label(ctx) == "Opus" @@ -248,3 +288,9 @@ def test_sonnet_has_no_max(): def test_opus_has_max(): assert "max" in _EFFORT_BY_MODEL["opus"] + + +def test_model_families_contains_expected(): + assert "opus" in _MODEL_FAMILIES + assert "sonnet" in _MODEL_FAMILIES + assert "haiku" in _MODEL_FAMILIES