diff --git a/docs/architecture.md b/docs/architecture.md index de4145c3..9b948061 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -16,11 +16,14 @@ 2. Initialize state with `_runtime_workspace` from `BubFramework.workspace`. 3. Merge all `load_state(message, session_id)` dicts. 4. Build prompt via `build_prompt(message, session_id, state)` (fallback to inbound `content` if empty). -5. Execute `run_model(prompt, session_id, state)`. -6. Always execute `save_state(...)` in a `finally` block. -7. Render outbound batches via `render_outbound(...)`, then flatten them. -8. If no outbound exists, emit one fallback outbound. -9. Dispatch each outbound via `dispatch_outbound(message)`. +5. Execute `run_model_stream(prompt, session_id, state)`. +6. For each stream event, call `OutboundChannelRouter.dispatch_event(...)`, which forwards to `channel.on_event(event, message)` when the target channel exists. +7. Always execute `save_state(...)` in a `finally` block. +8. Render outbound batches via `render_outbound(...)`, then flatten them. +9. If no outbound exists, emit one fallback outbound. +10. Dispatch each outbound via `dispatch_outbound(message)`. + +If no plugin implements `run_model_stream`, `HookRuntime` falls back to `run_model(prompt, session_id, state)` and adapts the returned text into a stream with a single text chunk. ## Hook Priority Semantics @@ -47,12 +50,23 @@ Builtin `BuiltinImpl` behavior includes: - `build_prompt`: supports comma command mode; non-command text may include `context_str`. -- `run_model`: delegates to `Agent.run()`. +- `run_model_stream`: delegates to `Agent.run()`. - `system_prompt`: combines a default prompt with workspace `AGENTS.md`. - `register_cli_commands`: installs `run`, `gateway`, `chat`, plus hidden diagnostic commands. - `provide_channels`: returns `telegram` and `cli` channel adapters. - `provide_tape_store`: returns a file-backed tape store under `~/.bub/tapes`. +## Channel Event Streaming + +Channels have two different outbound surfaces: + +- `send(message)`: handles the final rendered outbound message. +- `on_event(event, message)`: handles raw stream events while the model is still running. + +`on_event` is optional. Implement it when a channel can benefit from incremental rendering, typing indicators, progress updates, or partial text display. The `message` argument is the original inbound message, so channel implementations usually use it to recover routing metadata such as target channel, chat id, session id, or message kind. + +If a channel does not implement any special event behavior, it can ignore `on_event` and rely entirely on `send()`. + ## Boundaries - `Envelope` stays intentionally weakly typed (`Any` + accessor helpers). diff --git a/docs/channels/index.md b/docs/channels/index.md index 3a0b2390..3219df95 100644 --- a/docs/channels/index.md +++ b/docs/channels/index.md @@ -33,6 +33,17 @@ uv run bub gateway --enable-channel telegram - Telegram channel session id: `telegram:` - `chat` command default session id: `cli_session` (override with `--session-id`) +## Outbound Delivery Surfaces + +Channel adapters can receive outbound data in two forms: + +- `send(message)`: the final rendered outbound message +- `on_event(event, message)`: streaming events emitted while the model is still producing output + +Use `on_event` for incremental UX such as live text updates, typing indicators, progress bars, or chunk-level logging. Use `send` for the final durable outbound payload. + +`on_event` is optional. A channel that does not need streaming behavior can ignore it and only implement `send`. + ## Debounce Behavior - `cli` does not debounce; each input is processed immediately. diff --git a/docs/extension-guide.md b/docs/extension-guide.md index 04a85d8b..a13b67d2 100644 --- a/docs/extension-guide.md +++ b/docs/extension-guide.md @@ -100,11 +100,18 @@ Current `process_inbound()` hook usage: 1. `resolve_session` (`call_first`) 2. `load_state` (`call_many`, then merged by framework) 3. `build_prompt` (`call_first`) -4. `run_model` (`call_first`) +4. `run_model_stream` (`call_first`) 5. `save_state` (`call_many`, always executed in `finally`) 6. `render_outbound` (`call_many`) 7. `dispatch_outbound` (`call_many`, per outbound) +Compatibility note: + +- `run_model_stream` is the primary model hook. +- If no plugin implements `run_model_stream`, Bub falls back to `run_model`. +- The `run_model` return value is wrapped into a stream with exactly one text chunk. +- A plugin should implement one of these hooks, not both. + Other hook consumers: - `register_cli_commands`: called by `call_many_sync` @@ -150,6 +157,8 @@ class SessionPlugin: ```python from __future__ import annotations +from republic import AsyncStreamEvents, StreamEvent + from bub import hookimpl @@ -159,8 +168,11 @@ class EchoPlugin: return f"[echo] {message['content']}" @hookimpl - async def run_model(self, prompt, session_id, state): - return prompt + async def run_model_stream(self, prompt, session_id, state): + async def iterator(): + yield StreamEvent("text", {"delta": prompt}) + + return AsyncStreamEvents(iterator()) ``` Run and verify: @@ -170,9 +182,56 @@ uv run bub hooks uv run bub run "hello" ``` -Check that your plugin is listed for `build_prompt` / `run_model`, and output reflects your override. +Check that your plugin is listed for `build_prompt` / `run_model_stream`, and output reflects your override. +If you intentionally use the legacy compatibility hook, check for `run_model`. + +## 10) Listen To Parent Stream + +If you want to observe or transform the parent stream instead of fully replacing it, implement `run_model_stream` and wrap the parent hook's async iterator. + +This pattern uses `subset_hook_caller(...)` to call the same hook chain without the current plugin, then returns a new `AsyncStreamEvents` wrapper. + +```python +from __future__ import annotations + +from republic import AsyncStreamEvents, StreamEvent + +from bub import hookimpl + + +class StreamTapPlugin: + def __init__(self, framework) -> None: + self.framework = framework + + @hookimpl + async def run_model_stream(self, prompt, session_id, state): + parent_hook = self.framework._plugin_manager.subset_hook_caller( + "run_model_stream", + remove_plugins=[self], + ) + parent_stream = await parent_hook( + prompt=prompt, + session_id=session_id, + state=state, + ) + if parent_stream is None: + raise RuntimeError("no parent run_model_stream implementation found") + + async def iterator(): + async for event in parent_stream: + if event.kind == "text": + delta = str(event.data.get("delta", "")) + print(delta, end="") + yield event + + return AsyncStreamEvents(iterator(), state=parent_stream._state) +``` + +Use this when you need to log chunks, redact text, inject extra events, or measure stream timing without reimplementing the underlying model call. + +If you also need to support parents that only implement legacy `run_model`, add your own fallback path and wrap that text result into a one-chunk stream. -## 10) Common Pitfalls +## 11) Common Pitfalls - Defining `@tool` functions without importing the module from your plugin means the tools never register. - Returning awaitables from hooks invoked via sync paths (`call_many_sync` / `call_first_sync`) causes skip. diff --git a/docs/features.md b/docs/features.md index e5e7b067..9b81848d 100644 --- a/docs/features.md +++ b/docs/features.md @@ -5,7 +5,9 @@ Every turn stage is a [pluggy](https://pluggy.readthedocs.io/) hook. Builtins are ordinary plugins — override any stage by registering your own. Both first-result hooks (override) and broadcast hooks (observer) are supported. -Safe fallback to prompt text when `run_model` returns no value (with `on_error` notification). +`run_model_stream` is the primary model hook. +Legacy `run_model` hooks still work and are adapted into a single text chunk stream. +Safe fallback to prompt text when no model hook returns a value (with `on_error` notification). Automatic fallback outbound when `render_outbound` produces nothing. ## Tape-Based Context diff --git a/docs/index.md b/docs/index.md index 408781a9..4b9b9245 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,11 +31,13 @@ uv run bub gateway # channel listener mode Every inbound message goes through one turn pipeline. Each stage is a hook. +```text +resolve_session → load_state → build_prompt → run_model_stream + ↓ + dispatch_outbound ← render_outbound ← save_state ``` -resolve_session → load_state → build_prompt → run_model - ↓ - dispatch_outbound ← render_outbound ← save_state -``` + +`run_model` remains supported as a compatibility hook and is adapted into a single-chunk stream when `run_model_stream` is absent. Builtins are plugins registered first. Later plugins override earlier ones. No special cases. diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index 9097ccd2..a6363def 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -7,7 +7,8 @@ import re import shlex import time -from collections.abc import Collection +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Collection, Coroutine, Iterable +from contextlib import AsyncExitStack from dataclasses import dataclass, replace from datetime import UTC, datetime from functools import cached_property @@ -15,7 +16,15 @@ from typing import Any from loguru import logger -from republic import LLM, AsyncTapeStore, TapeContext, ToolAutoResult, ToolContext +from republic import ( + LLM, + AsyncStreamEvents, + AsyncTapeStore, + StreamEvent, + StreamState, + TapeContext, + ToolContext, +) from republic.tape import InMemoryTapeStore, Tape from bub.builtin.settings import AgentSettings, load_settings @@ -52,6 +61,25 @@ def tapes(self) -> TapeService: llm = _build_llm(self.settings, tape_store, self.framework.build_tape_context()) return TapeService(llm, self.settings.home / "tapes", tape_store) + @staticmethod + def _events_from_iterable(iterable: Iterable) -> AsyncStreamEvents: + async def generator() -> AsyncIterator: + for item in iterable: + yield item + + return AsyncStreamEvents(generator()) + + @staticmethod + def _events_with_callback( + events: AsyncStreamEvents, callback: Callable[[], Coroutine[Any, Any, Any]] + ) -> AsyncStreamEvents: + async def generator() -> AsyncIterator[StreamEvent]: + async for event in events: + yield event + await callback() + + return AsyncStreamEvents(generator(), state=events._state) + async def run( self, *, @@ -61,19 +89,33 @@ async def run( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> str: + ) -> AsyncStreamEvents: if not prompt: - return "error: empty prompt" + events = [ + StreamEvent("text", {"delta": "error: empty prompt"}), + StreamEvent("final", {"text": "error: empty prompt", "ok": False}), + ] + return self._events_from_iterable(events) + tape = self.tapes.session_tape(session_id, workspace_from_state(state)) tape.context = replace(tape.context, state=state) merge_back = not session_id.startswith("temp/") - async with self.tapes.fork_tape(tape.name, merge_back=merge_back): - await self.tapes.ensure_bootstrap_anchor(tape.name) - if isinstance(prompt, str) and prompt.strip().startswith(","): - return await self._run_command(tape=tape, line=prompt.strip()) - return await self._agent_loop( + stack = AsyncExitStack() + # the fork_tape context manager must not be exited until the last chunk of the stream is consumed. + # So we use an AsyncExitStack and inject a callback to the iterator. + await stack.enter_async_context(self.tapes.fork_tape(tape.name, merge_back=merge_back)) + await self.tapes.ensure_bootstrap_anchor(tape.name) + if isinstance(prompt, str) and prompt.strip().startswith(","): + result = await self._run_command(tape=tape, line=prompt.strip()) + events = self._events_from_iterable([ + StreamEvent("text", {"delta": result}), + StreamEvent("final", {"text": result, "ok": True}), + ]) + else: + events = await self._agent_loop( tape=tape, prompt=prompt, model=model, allowed_skills=allowed_skills, allowed_tools=allowed_tools ) + return self._events_with_callback(events, callback=stack.aclose) async def _run_command(self, tape: Tape, *, line: str) -> str: line = line[1:].strip() @@ -123,10 +165,9 @@ async def _agent_loop( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> str: + ) -> AsyncStreamEvents: next_prompt: str | list[dict] = prompt display_model = model or self.settings.model - auto_handoff_remaining = MAX_AUTO_HANDOFF_RETRIES await self.tapes.append_event( tape.name, "loop.start", @@ -137,34 +178,61 @@ async def _agent_loop( "allowed_tools": list(allowed_tools) if allowed_tools else None, }, ) + state = StreamState() + iterator = self._stream_events_with_auto_handoff( + tape=tape, + prompt=next_prompt, + state=state, + model=model, + allowed_skills=allowed_skills, + allowed_tools=allowed_tools, + ) + return AsyncStreamEvents(iterator, state=state) + + async def _stream_events_with_auto_handoff( + self, + tape: Tape, + prompt: str | list[dict], + state: StreamState, + model: str | None = None, + allowed_skills: Collection[str] | None = None, + allowed_tools: Collection[str] | None = None, + ) -> AsyncGenerator[StreamEvent, None]: + auto_handoff_remaining = MAX_AUTO_HANDOFF_RETRIES + display_model = model or self.settings.model + next_prompt = prompt for step in range(1, self.settings.max_steps + 1): start = time.monotonic() + outcome = _ToolAutoOutcome(kind="text", text="", error="") logger.info("loop.step step={} tape={} model={}", step, tape.name, display_model) await self.tapes.append_event(tape.name, "loop.step.start", {"step": step, "prompt": next_prompt}) - try: - output = await self._run_tools_once( - tape=tape, - prompt=next_prompt, - model=model, - allowed_skills=allowed_skills, - allowed_tools=allowed_tools, - ) - except Exception as exc: - elapsed_ms = int((time.monotonic() - start) * 1000) - await self.tapes.append_event( - tape.name, - "loop.step", - { - "step": step, - "elapsed_ms": elapsed_ms, - "status": "error", - "error": f"{exc!s}", - "date": datetime.now(UTC).isoformat(), - }, - ) - raise - - outcome = _resolve_tool_auto_result(output) + output = await self._run_once( + tape=tape, + prompt=next_prompt, + model=model, + allowed_skills=allowed_skills, + allowed_tools=allowed_tools, + ) + async for event in output: + yield event + if event.kind == "error": + elapsed_ms = int((time.monotonic() - start) * 1000) + await self.tapes.append_event( + tape.name, + "loop.step", + { + "step": step, + "elapsed_ms": elapsed_ms, + "status": "error", + "error": event.data.get("message", ""), + "date": datetime.now(UTC).isoformat(), + }, + ) + elif event.kind == "final": + outcome = _resolve_tool_auto_result(event.data) + + state.error = output.error + state.usage = output.usage elapsed_ms = int((time.monotonic() - start) * 1000) if outcome.kind == "text": await self.tapes.append_event( @@ -177,7 +245,7 @@ async def _agent_loop( "date": datetime.now(UTC).isoformat(), }, ) - return outcome.text + return if outcome.kind == "continue": if "context" in tape.context.state: next_prompt = f"{CONTINUE_PROMPT} [context: {tape.context.state['context']}]" @@ -247,7 +315,7 @@ def _load_skills_prompt(self, prompt: str, workspace: Path, allowed_skills: set[ expanded_skills = set(HINT_RE.findall(prompt)) & set(skill_index.keys()) return render_skills_prompt(list(skill_index.values()), expanded_skills=expanded_skills) - async def _run_tools_once( + async def _run_once( self, *, tape: Tape, @@ -255,7 +323,7 @@ async def _run_tools_once( model: str | None = None, allowed_tools: Collection[str] | None = None, allowed_skills: Collection[str] | None = None, - ) -> ToolAutoResult: + ) -> AsyncStreamEvents: prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt) if allowed_tools is not None: allowed_tools = {name.casefold() for name in allowed_tools} @@ -267,8 +335,8 @@ async def _run_tools_once( else: tools = list(REGISTRY.values()) async with asyncio.timeout(self.settings.model_timeout_seconds): - return await tape.run_tools_async( - prompt=prompt, # republic accepts list content parts at runtime + return await tape.stream_events_async( + prompt=prompt, system_prompt=self._system_prompt(prompt_text, state=tape.context.state, allowed_skills=allowed_skills), max_tokens=self.settings.max_tokens, tools=model_tools(tools), @@ -295,15 +363,12 @@ class _ToolAutoOutcome: error: str = "" -def _resolve_tool_auto_result(output: ToolAutoResult) -> _ToolAutoOutcome: - if output.kind == "text": - return _ToolAutoOutcome(kind="text", text=output.text or "") - if output.kind == "tools" or output.tool_calls or output.tool_results: +def _resolve_tool_auto_result(final_data: dict[str, Any]) -> _ToolAutoOutcome: + if (text := final_data.get("text")) is not None: + return _ToolAutoOutcome(kind="text", text=text) + if final_data.get("tool_calls") or final_data.get("tool_results"): return _ToolAutoOutcome(kind="continue") - if output.error is None: - return _ToolAutoOutcome(kind="error", error="tool_auto_error: unknown") - error_kind = getattr(output.error.kind, "value", str(output.error.kind)) - return _ToolAutoOutcome(kind="error", error=f"{error_kind}: {output.error.message}") + return _ToolAutoOutcome(kind="error", error="unknown error") def _build_llm(settings: AgentSettings, tape_store: AsyncTapeStore, tape_context: TapeContext) -> LLM: diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index ddba06b8..187886bb 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -5,7 +5,7 @@ import typer from loguru import logger -from republic import TapeContext +from republic import AsyncStreamEvents, TapeContext from republic.tape import TapeStore from bub.builtin.agent import Agent @@ -106,7 +106,7 @@ async def build_prompt(self, message: ChannelMessage, session_id: str, state: St return text @hookimpl - async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str: + async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: return await self.agent.run(session_id=session_id, prompt=prompt, state=state) @hookimpl diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index b3243089..ce5ebcd0 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -266,14 +266,20 @@ async def run_subagent(param: SubAgentInput, *, context: ToolContext) -> str: subagent_session = param.session state = {**context.state, "session_id": subagent_session} allowed_tools = resolve_tool_names(param.allowed_tools or None, exclude={"subagent"}) - return await agent.run( + output = "" + async for event in await agent.run( session_id=subagent_session, prompt=param.prompt, state=state, model=param.model, allowed_tools=allowed_tools, allowed_skills=param.allowed_skills, - ) + ): + if event.kind == "error": + output += f"[Error: {event.data.get('message', 'unknown error')}]" + elif event.kind == "text": + output += str(event.data.get("delta", "")) + return output @tool(name="help") diff --git a/src/bub/channels/base.py b/src/bub/channels/base.py index 939b37d6..5772e5e0 100644 --- a/src/bub/channels/base.py +++ b/src/bub/channels/base.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from typing import ClassVar +from republic import StreamEvent + from bub.channels.message import ChannelMessage @@ -32,3 +34,8 @@ async def send(self, message: ChannelMessage) -> None: """Send a message to the channel. Optional to implement.""" # Do nothing by default return + + async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None: + """Handle an event from the agent. Optional to implement.""" + # Do nothing by default + return diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 21606583..ea137f8d 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -1,6 +1,7 @@ import asyncio import contextlib from collections.abc import AsyncGenerator +from dataclasses import dataclass from datetime import datetime from hashlib import md5 from pathlib import Path @@ -11,18 +12,27 @@ from prompt_toolkit.history import FileHistory from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.patch_stdout import patch_stdout +from republic import StreamEvent from rich import get_console +from rich.live import Live from bub.builtin.agent import Agent from bub.builtin.tape import TapeInfo from bub.channels.base import Channel from bub.channels.cli.renderer import CliRenderer -from bub.channels.message import ChannelMessage -from bub.envelope import content_of, field_of +from bub.channels.message import ChannelMessage, MessageKind +from bub.envelope import field_of from bub.tools import REGISTRY from bub.types import MessageHandler +@dataclass +class _StreamRenderState: + live: Live + kind: MessageKind + text: str = "" + + class CliChannel(Channel): """A simple CLI channel for testing and debugging.""" @@ -65,15 +75,6 @@ async def stop(self) -> None: with contextlib.suppress(asyncio.CancelledError): await self._main_task - async def send(self, message: ChannelMessage) -> None: - match message.kind: - case "error": - self._renderer.error(content_of(message)) - case "command": - self._renderer.command_output(content_of(message)) - case _: - self._renderer.assistant_output(content_of(message)) - async def _main_loop(self) -> None: self._renderer.welcome(model=self._agent.settings.model, workspace=str(self._workspace)) await self._refresh_tape_info() @@ -103,9 +104,8 @@ async def _main_loop(self) -> None: content=request, lifespan=self.message_lifespan(request_completed), ) - with self._renderer.console.status("[cyan]Processing...[/cyan]", spinner="dots"): - await self._on_receive(message) - await request_completed.wait() + await self._on_receive(message) + await request_completed.wait() request_completed.clear() self._renderer.info("Bye.") @@ -131,6 +131,22 @@ def _prompt_message(self) -> FormattedText: symbol = ">" if self._mode == "agent" else "," return FormattedText([("bold", f"{cwd} {symbol} ")]) + async def on_event(self, event: StreamEvent, message: ChannelMessage) -> None: + streams = self._stream_render_states() + state = streams.get(message.session_id) + if event.kind == "text": + if state is None: + state = _StreamRenderState(live=self._renderer.start_stream(message.kind), kind=message.kind) + streams[message.session_id] = state + content = str(event.data.get("delta", "")) + state.text += content + self._renderer.update_stream(state.live, kind=message.kind, text=state.text) + elif event.kind == "final": + if state is None: + return + self._renderer.finish_stream(state.live, kind=state.kind, text=state.text) + streams.pop(message.session_id, None) + def _build_prompt(self, workspace: Path) -> PromptSession[str]: kb = KeyBindings() @@ -172,3 +188,10 @@ def _render_bottom_toolbar(self) -> FormattedText: def _history_file(home: Path, workspace: Path) -> Path: workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest() return home / "history" / f"{workspace_hash}.history" + + def _stream_render_states(self) -> dict[str, _StreamRenderState]: + states = getattr(self, "_active_stream_renders", None) + if states is None: + states = {} + self._active_stream_renders = states + return states diff --git a/src/bub/channels/cli/renderer.py b/src/bub/channels/cli/renderer.py index c89981bc..2db91d7c 100644 --- a/src/bub/channels/cli/renderer.py +++ b/src/bub/channels/cli/renderer.py @@ -5,9 +5,12 @@ from dataclasses import dataclass from rich.console import Console +from rich.live import Live from rich.panel import Panel from rich.text import Text +from bub.channels.message import MessageKind + @dataclass class CliRenderer: @@ -30,17 +33,50 @@ def info(self, text: str) -> None: return self.console.print(Text(text, style="bright_black")) + def panel(self, kind: MessageKind, text: str) -> Panel: + title, border_style = self._panel_style(kind) + return Panel(text, title=title, border_style=border_style) + def command_output(self, text: str) -> None: if not text.strip(): return - self.console.print(Panel(text, title="Command", border_style="green")) + self.console.print(self.panel("command", text)) def assistant_output(self, text: str) -> None: if not text.strip(): return - self.console.print(Panel(text, title="Assistant", border_style="blue")) + self.console.print(self.panel("normal", text)) def error(self, text: str) -> None: if not text.strip(): return - self.console.print(Panel(text, title="Error", border_style="red")) + self.console.print(self.panel("error", text)) + + def start_stream(self, kind: MessageKind) -> Live: + live = Live( + self.panel(kind, ""), + console=self.console, + auto_refresh=False, + transient=False, + vertical_overflow="visible", + ) + live.start() + live.refresh() + return live + + def update_stream(self, live: Live, *, kind: MessageKind, text: str) -> None: + live.update(self.panel(kind, text), refresh=True) + + def finish_stream(self, live: Live, *, kind: MessageKind, text: str) -> None: + live.update(self.panel(kind, text), refresh=True) + live.stop() + + @staticmethod + def _panel_style(kind: MessageKind) -> tuple[str, str]: + match kind: + case "error": + return "Error", "red" + case "command": + return "Command", "green" + case _: + return "Assistant", "blue" diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 08d2f7a5..8643e697 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -6,6 +6,7 @@ from loguru import logger from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict +from republic import StreamEvent from bub.channels.base import Channel from bub.channels.handler import BufferedMessageHandler @@ -72,7 +73,7 @@ async def on_receive(self, message: ChannelMessage) -> None: def get_channel(self, name: str) -> Channel | None: return self._channels.get(name) - async def dispatch(self, message: Envelope) -> bool: + async def dispatch_output(self, message: Envelope) -> bool: channel_name = field_of(message, "output_channel", field_of(message, "channel")) if channel_name is None: return False @@ -93,6 +94,18 @@ async def dispatch(self, message: Envelope) -> bool: await channel.send(outbound) return True + async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None: + channel_name = field_of(message, "output_channel", field_of(message, "channel")) + if channel_name is None: + return + + channel_key = str(channel_name) + channel = self.get_channel(channel_key) + if channel is None: + return + + await channel.on_event(event, message) + async def quit(self, session_id: str) -> None: tasks = self._ongoing_tasks.pop(session_id, set()) for task in tasks: diff --git a/src/bub/framework.py b/src/bub/framework.py index bd0e23e6..55bc8e58 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -10,7 +10,7 @@ import typer from dotenv import load_dotenv from loguru import logger -from republic import AsyncTapeStore, TapeContext +from republic import AsyncTapeStore, RepublicError, TapeContext from republic.tape import TapeStore from bub.envelope import content_of, field_of, unpack_batch @@ -108,18 +108,7 @@ async def process_inbound(self, inbound: Envelope) -> TurnResult: prompt = content_of(inbound) model_output = "" try: - model_output = await self._hook_runtime.call_first( - "run_model", prompt=prompt, session_id=session_id, state=state - ) - if model_output is None: - await self._hook_runtime.notify_error( - stage="run_model:fallback", - error=RuntimeError("no model skill returned output"), - message=inbound, - ) - model_output = prompt if isinstance(prompt, str) else content_of(inbound) - else: - model_output = str(model_output) + model_output = await self._run_model(inbound, prompt, session_id, state) finally: await self._hook_runtime.call_many( "save_state", @@ -138,6 +127,29 @@ async def process_inbound(self, inbound: Envelope) -> TurnResult: await self._hook_runtime.notify_error(stage="turn", error=exc, message=inbound) raise + async def _run_model( + self, inbound: Envelope, prompt: str | list[dict], session_id: str, state: dict[str, Any] + ) -> str: + stream = await self._hook_runtime.run_model_stream(prompt=prompt, session_id=session_id, state=state) + if stream is None: + await self._hook_runtime.notify_error( + stage="run_model", + error=RuntimeError("no model skill returned output"), + message=inbound, + ) + return prompt if isinstance(prompt, str) else content_of(inbound) + else: + parts: list[str] = [] + async for event in stream: + await self.dispatch_event_via_router(event, inbound) + if event.kind == "text": + parts.append(str(event.data.get("delta", ""))) + elif event.kind == "error": + await self._hook_runtime.notify_error( + stage="run_model", error=RepublicError(**event.data), message=inbound + ) + return "".join(parts) + def hook_report(self) -> dict[str, list[str]]: """Return hook implementation summary for diagnostics.""" @@ -149,7 +161,13 @@ def bind_outbound_router(self, router: OutboundChannelRouter | None) -> None: async def dispatch_via_router(self, message: Envelope) -> bool: if self._outbound_router is None: return False - return await self._outbound_router.dispatch(message) + return await self._outbound_router.dispatch_output(message) + + async def dispatch_event_via_router(self, event: Any, message: Envelope) -> bool: + if self._outbound_router is not None: + await self._outbound_router.dispatch_event(event, message) + return True + return False async def quit_via_router(self, session_id: str) -> None: if self._outbound_router is not None: diff --git a/src/bub/hook_runtime.py b/src/bub/hook_runtime.py index 22ab3abe..c69cd1db 100644 --- a/src/bub/hook_runtime.py +++ b/src/bub/hook_runtime.py @@ -3,10 +3,12 @@ from __future__ import annotations import inspect +from collections.abc import AsyncGenerator from typing import Any import pluggy from loguru import logger +from republic import AsyncStreamEvents, StreamEvent, StreamState from bub.types import Envelope @@ -158,12 +160,21 @@ def _iter_hookimpls(self, hook_name: str) -> list[Any]: def _kwargs_for_impl(impl: Any, kwargs: dict[str, Any]) -> dict[str, Any]: return {name: kwargs[name] for name in impl.argnames if name in kwargs} + async def run_model_stream( + self, prompt: str | list[dict], session_id: str, state: dict[str, Any] + ) -> AsyncStreamEvents | None: + """Run the first `run_model_stream` hook found and fallback to `run_model` hook.""" + for _, plugin in reversed(self._plugin_manager.list_name_plugin()): + if hasattr(plugin, "run_model_stream"): + return await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) + elif hasattr(plugin, "run_model"): -def _message_from_kwargs(kwargs: dict[str, Any]) -> Envelope | None: - message = kwargs.get("message") - if message is None: + async def iterator() -> AsyncGenerator[StreamEvent, None]: + result = await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) + yield StreamEvent("text", {"delta": result}) + + return AsyncStreamEvents(iterator(), state=StreamState()) return None - return message _SKIP_VALUE = object() diff --git a/src/bub/hookspecs.py b/src/bub/hookspecs.py index 237a0c16..47da2b66 100644 --- a/src/bub/hookspecs.py +++ b/src/bub/hookspecs.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any import pluggy -from republic import AsyncTapeStore, TapeContext +from republic import AsyncStreamEvents, AsyncTapeStore, TapeContext from republic.tape import TapeStore from bub.types import Envelope, MessageHandler, State @@ -42,7 +42,12 @@ def build_prompt(self, message: Envelope, session_id: str, state: State) -> str @hookspec(firstresult=True) def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str: - """Run model for one turn and return plain text output.""" + """Run model for one turn and return plain text output. Should not be implemented if `run_model_stream` is implemented.""" + raise NotImplementedError + + @hookspec(firstresult=True) + def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: + """Run model for one turn and return a stream of events. Should not be implemented if `run_model` is implemented.""" raise NotImplementedError @hookspec diff --git a/src/bub/types.py b/src/bub/types.py index a1f73c77..0f84ea32 100644 --- a/src/bub/types.py +++ b/src/bub/types.py @@ -6,6 +6,8 @@ from dataclasses import dataclass, field from typing import Any, Protocol +from republic import StreamEvent + type Envelope = Any type State = dict[str, Any] type MessageHandler = Callable[[Envelope], Coroutine[Any, Any, None]] @@ -13,7 +15,8 @@ class OutboundChannelRouter(Protocol): - async def dispatch(self, message: Envelope) -> bool: ... + async def dispatch_output(self, message: Envelope) -> bool: ... + async def dispatch_event(self, event: StreamEvent, message: Envelope) -> None: ... async def quit(self, session_id: str) -> None: ... diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py index ecb90544..f55169ea 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -7,7 +7,7 @@ import pytest import republic.auth.openai_codex as openai_codex -from republic import TapeContext, ToolAutoResult +from republic import AsyncStreamEvents, StreamEvent, TapeContext import bub.builtin.agent as agent_module from bub.builtin.agent import Agent @@ -90,11 +90,15 @@ def session_tape(self, session_id: str, workspace: Any) -> MagicMock: tape.name = "test-tape" tape.context = TapeContext(state={}) - async def fake_run_tools_async(**kwargs: Any) -> ToolAutoResult: + async def fake_stream_events_async(**kwargs: Any) -> AsyncStreamEvents: self.run_tools_model = kwargs.get("model") - return ToolAutoResult(kind="text", text="done", tool_calls=[], tool_results=[], error=None) - tape.run_tools_async = fake_run_tools_async + async def iterator(): + yield StreamEvent("final", {"text": "done"}) + + return AsyncStreamEvents(iterator()) + + tape.stream_events_async = fake_stream_events_async return tape async def ensure_bootstrap_anchor(self, tape_name: str) -> None: @@ -116,7 +120,8 @@ async def test_agent_run_regular_session_merges_back() -> None: fork_capture = _ForkCapture() agent.tapes = _FakeTapeService(fork_capture) # type: ignore[assignment] - await agent.run(session_id="user/session1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + result = await agent.run(session_id="user/session1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + [event async for event in result] assert fork_capture.merge_back_values == [True] @@ -128,20 +133,27 @@ async def test_agent_run_temp_session_does_not_merge_back() -> None: fork_capture = _ForkCapture() agent.tapes = _FakeTapeService(fork_capture) # type: ignore[assignment] - await agent.run(session_id="temp/abc123", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + result = await agent.run(session_id="temp/abc123", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + [event async for event in result] assert fork_capture.merge_back_values == [False] @pytest.mark.asyncio async def test_agent_run_passes_model_to_llm() -> None: - """The model parameter should be forwarded to run_tools_async.""" + """The model parameter should be forwarded to stream_events_async.""" agent = _make_agent() fork_capture = _ForkCapture() fake_tapes = _FakeTapeService(fork_capture) agent.tapes = fake_tapes # type: ignore[assignment] - await agent.run(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}, model="openai:gpt-4o") # noqa: S108 + result = await agent.run( + session_id="user/s1", + prompt="hello", + state={"_runtime_workspace": "/tmp"}, # noqa: S108 + model="openai:gpt-4o", + ) + [event async for event in result] assert fake_tapes.run_tools_model == "openai:gpt-4o" @@ -152,8 +164,12 @@ async def test_agent_run_empty_prompt_returns_error() -> None: agent.tapes = MagicMock() # type: ignore[assignment] result = await agent.run(session_id="user/s1", prompt="", state={}) + events = [event async for event in result] - assert result == "error: empty prompt" + assert [(event.kind, event.data) for event in events] == [ + ("text", {"delta": "error: empty prompt"}), + ("final", {"ok": False, "text": "error: empty prompt"}), + ] @pytest.mark.asyncio @@ -164,6 +180,7 @@ async def test_agent_run_model_defaults_to_none() -> None: fake_tapes = _FakeTapeService(fork_capture) agent.tapes = fake_tapes # type: ignore[assignment] - await agent.run(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + result = await agent.run(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + [event async for event in result] assert fake_tapes.run_tools_model is None diff --git a/tests/test_builtin_hook_impl.py b/tests/test_builtin_hook_impl.py index ce81a183..717004c3 100644 --- a/tests/test_builtin_hook_impl.py +++ b/tests/test_builtin_hook_impl.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import pytest +from republic import AsyncStreamEvents, StreamEvent from bub.builtin.hook_impl import AGENTS_FILE_NAME, DEFAULT_SYSTEM_PROMPT, BuiltinImpl from bub.builtin.store import FileTapeStore @@ -28,9 +29,13 @@ def __init__(self, home: Path) -> None: self.settings = SimpleNamespace(home=home) self.calls: list[tuple[str, str, dict[str, object]]] = [] - async def run(self, *, session_id: str, prompt: str, state: dict[str, object]) -> str: + async def run(self, *, session_id: str, prompt: str, state: dict[str, object]) -> AsyncStreamEvents: self.calls.append((session_id, prompt, state)) - return "agent-output" + + async def iterator(): + yield StreamEvent("text", {"delta": "agent-output"}) + + return AsyncStreamEvents(iterator()) def _raise_value_error() -> None: @@ -113,13 +118,14 @@ async def test_build_prompt_marks_commands_and_prefixes_context(tmp_path: Path) @pytest.mark.asyncio -async def test_run_model_delegates_to_agent(tmp_path: Path) -> None: +async def test_run_model_stream_delegates_to_agent(tmp_path: Path) -> None: _, impl, agent = _build_impl(tmp_path) state = {"context": "ctx"} - result = await impl.run_model(prompt="prompt", session_id="session", state=state) + stream = await impl.run_model_stream(prompt="prompt", session_id="session", state=state) + events = [event async for event in stream] - assert result == "agent-output" + assert [(event.kind, event.data) for event in events] == [("text", {"delta": "agent-output"})] assert agent.calls == [("session", "prompt", state)] diff --git a/tests/test_channels.py b/tests/test_channels.py index 6e0f2958..db095627 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -7,6 +7,7 @@ from types import SimpleNamespace import pytest +from republic import StreamEvent from bub.channels.cli import CliChannel from bub.channels.handler import BufferedMessageHandler @@ -98,7 +99,7 @@ async def test_channel_manager_dispatch_uses_output_channel_and_preserves_metada cli_channel = FakeChannel("cli") manager = ChannelManager(FakeFramework({"cli": cli_channel}), enabled_channels=["cli"]) - result = await manager.dispatch({ + result = await manager.dispatch_output({ "session_id": "session", "channel": "telegram", "output_channel": "cli", @@ -206,20 +207,32 @@ def test_cli_channel_normalize_input_prefixes_shell_commands() -> None: @pytest.mark.asyncio -async def test_cli_channel_send_routes_by_message_kind() -> None: +async def test_cli_channel_on_event_renders_stream_and_suppresses_followup_send() -> None: channel = CliChannel.__new__(CliChannel) - events: list[tuple[str, str]] = [] + events: list[tuple[str, str, str]] = [] + live_handle = object() channel._renderer = SimpleNamespace( - error=lambda content: events.append(("error", content)), - command_output=lambda content: events.append(("command", content)), - assistant_output=lambda content: events.append(("assistant", content)), + start_stream=lambda kind: events.append(("start", kind, "")) or live_handle, + update_stream=lambda live, *, kind, text: events.append(("update", kind, text)), + finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)), + error=lambda content: events.append(("error", "error", content)), + command_output=lambda content: events.append(("send", "command", content)), + assistant_output=lambda content: events.append(("send", "normal", content)), ) - await channel.send(_message("bad", channel="cli", kind="error")) - await channel.send(_message("ok", channel="cli", kind="command")) - await channel.send(_message("hi", channel="cli")) + message = _message("ignored", channel="cli", kind="command", session_id="cli:1") - assert events == [("error", "bad"), ("command", "ok"), ("assistant", "hi")] + await channel.on_event(StreamEvent("text", {"delta": "hel"}), message) + await channel.on_event(StreamEvent("text", {"delta": "lo"}), message) + await channel.on_event(StreamEvent("final", {}), message) + await channel.send(_message("hello", channel="cli", kind="command", session_id="cli:1")) + + assert events == [ + ("start", "command", ""), + ("update", "command", "hel"), + ("update", "command", "hello"), + ("finish", "command", "hello"), + ] def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: diff --git a/tests/test_subagent_tool.py b/tests/test_subagent_tool.py index 5991f58d..371592e9 100644 --- a/tests/test_subagent_tool.py +++ b/tests/test_subagent_tool.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock import pytest +from republic import AsyncStreamEvents, StreamEvent from bub.builtin.tools import run_subagent from bub.tools import REGISTRY, tool @@ -19,7 +20,13 @@ def __init__(self, state: dict[str, Any]) -> None: class FakeAgent: def __init__(self) -> None: - self.run = AsyncMock(return_value="agent result") + self.run = AsyncMock(side_effect=self._run) + + async def _run(self, **kwargs: Any) -> AsyncStreamEvents: + async def iterator(): + yield StreamEvent("text", {"delta": "agent result"}) + + return AsyncStreamEvents(iterator()) @pytest.mark.asyncio