async def on_audio_recorded(self, audio_bytes: bytes) -> None:
    """Forward microphone audio captured by the UI to the realtime session."""
    try:
        await self.session.send_audio(audio_bytes)
    except Exception as e:
        # The session may be closed or not yet connected; report the failure
        # in the UI log instead of crashing the audio callback.
        self.ui.log_message(f"Error sending audio: {e}")
self.ui.log_message(f"{event.data}, {type(event.data.data)}") + if event.data.data["type"] == "response.audio.delta": + self.ui.log_message("audio deltas") + delta_b64_string = event.data.data["delta"] + delta_bytes = base64.b64decode(delta_b64_string) + audio_data = np.frombuffer(delta_bytes, dtype=np.int16) + self.ui.play_audio(audio_data) + + # Handle audio from model + if event.type == "audio": + try: + # Convert bytes to numpy array for audio player + audio_data = np.frombuffer(event.audio.data, dtype=np.int16) + self.ui.play_audio(audio_data) + except Exception as e: + self.ui.log_message(f"Audio play error: {e}") + except Exception: + # This can happen if the UI has already exited + pass + + +if __name__ == "__main__": + example = Example() + asyncio.run(example.run()) diff --git a/examples/realtime/ui.py b/examples/realtime/ui.py new file mode 100644 index 000000000..d619ae8e6 --- /dev/null +++ b/examples/realtime/ui.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine +from typing import Any, Callable + +import numpy as np +import numpy.typing as npt +import sounddevice as sd +from textual import events +from textual.app import App, ComposeResult +from textual.containers import Container +from textual.reactive import reactive +from textual.widgets import RichLog, Static +from typing_extensions import override + +CHUNK_LENGTH_S = 0.05 # 50ms +SAMPLE_RATE = 24000 +FORMAT = np.int16 +CHANNELS = 1 + + +class Header(Static): + """A header widget.""" + + @override + def render(self) -> str: + return "Realtime Demo" + + +class AudioStatusIndicator(Static): + """A widget that shows the current audio recording status.""" + + is_recording = reactive(False) + + @override + def render(self) -> str: + status = ( + "🔴 Conversation started." 
def set_is_connected(self, is_connected: bool) -> None:
    """Record the connection state on the ``connected`` event.

    Args:
        is_connected: True sets the event (connected); False clears it.
    """
    # A conditional *expression* used purely for its side effects
    # (`self.connected.set() if ... else self.connected.clear()`) is an
    # anti-pattern; use an explicit branch instead.
    if is_connected:
        self.connected.set()
    else:
        self.connected.clear()
def log_message(self, message: str) -> None:
    """Add a message to the log pane."""
    try:
        self.query_one("#bottom-pane", RichLog).write(message)
    except Exception:
        # The widget may already be gone (e.g. during shutdown);
        # logging is best-effort, so swallow the failure.
        pass
async def on_key(self, event: events.Key) -> None:
    """React to key presses: 'q' quits the app, SPACE starts the conversation."""
    key = event.key
    self.log_message(f"Key pressed: {key}")

    if key == "q":
        # Shut the output stream down before leaving the app.
        self.audio_player.stop()
        self.audio_player.close()
        self.exit()
    elif key == "space" and not self.should_send_audio.is_set():
        self.should_send_audio.set()
        self.set_recording_status(True)
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
    """Fetch the tools currently exposed by this agent's MCP servers."""
    strict = self.mcp_config.get("convert_schemas_to_strict", False)
    return await MCPUtil.get_all_function_tools(self.mcp_servers, strict, run_context, self)
asyncio.gather(*(_check_tool_enabled(t) for t in self.tools)) - enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok] - return [*mcp_tools, *enabled] diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 26af94ba3..e22e6e8a0 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -17,9 +17,9 @@ class _OmitTypeAnnotation: @classmethod def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, ) -> core_schema.CoreSchema: def validate_from_none(value: None) -> _Omit: return _Omit() @@ -39,12 +39,14 @@ def validate_from_none(value: None) -> _Omit: from_none_schema, ] ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: None - ), + serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None), ) + + Omit = Annotated[_Omit, _OmitTypeAnnotation] Headers: TypeAlias = Mapping[str, Union[str, Omit]] +ToolChoice: TypeAlias = Literal["auto", "required", "none"] | str | None + @dataclass class ModelSettings: diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py index 7f99ac8a8..7e5d4932a 100644 --- a/src/agents/realtime/__init__.py +++ b/src/agents/realtime/__init__.py @@ -1,3 +1,51 @@ from .agent import RealtimeAgent, RealtimeAgentHooks, RealtimeRunHooks +from .config import APIKeyOrKeyFunc +from .events import ( + RealtimeAgentEndEvent, + RealtimeAgentStartEvent, + RealtimeAudio, + RealtimeAudioEnd, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeHandoffEvent, + RealtimeHistoryAdded, + RealtimeHistoryUpdated, + RealtimeRawTransportEvent, + RealtimeSessionEvent, + RealtimeToolEnd, + RealtimeToolStart, +) +from .session import RealtimeSession +from .transport import ( + RealtimeModelName, + RealtimeSessionTransport, + RealtimeTransportConnectionOptions, + RealtimeTransportListener, +) -__all__ = ["RealtimeAgent", 
# NOTE(review): "RealtimeSessionListener" and "RealtimeSessionListenerFunc" were
# listed here but are never imported into this module, so
# `from agents.realtime import *` raised AttributeError. They are dropped from
# the public surface; if they exist in `.session`, import them and re-add the
# entries instead.
__all__ = [
    "RealtimeAgent",
    "RealtimeAgentHooks",
    "RealtimeRunHooks",
    "RealtimeSession",
    "APIKeyOrKeyFunc",
    "RealtimeModelName",
    "RealtimeSessionTransport",
    "RealtimeTransportListener",
    "RealtimeTransportConnectionOptions",
    "RealtimeSessionEvent",
    "RealtimeAgentStartEvent",
    "RealtimeAgentEndEvent",
    "RealtimeHandoffEvent",
    "RealtimeToolStart",
    "RealtimeToolEnd",
    "RealtimeRawTransportEvent",
    "RealtimeAudioEnd",
    "RealtimeAudio",
    "RealtimeAudioInterrupted",
    "RealtimeError",
    "RealtimeHistoryUpdated",
    "RealtimeHistoryAdded",
]
agent tools, including MCP tools and function tools.""" - mcp_tools = await self.get_mcp_tools(run_context) - - async def _check_tool_enabled(tool: Tool) -> bool: - if not isinstance(tool, FunctionTool): - return True - - attr = tool.is_enabled - if isinstance(attr, bool): - return attr - res = attr(run_context, self) - if inspect.isawaitable(res): - return bool(await res) - return bool(res) - - results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools)) - enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok] - return [*mcp_tools, *enabled] diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py new file mode 100644 index 000000000..aa15c837d --- /dev/null +++ b/src/agents/realtime/config.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import inspect +from typing import ( + Any, + Callable, + Literal, + Union, +) + +from typing_extensions import NotRequired, TypeAlias, TypedDict + +from ..model_settings import ToolChoice +from ..tool import FunctionTool +from ..util._types import MaybeAwaitable + + +class RealtimeClientMessage(TypedDict): + type: str # explicitly required + other_data: NotRequired[dict[str, Any]] + + +class UserInputText(TypedDict): + type: Literal["input_text"] + text: str + + +class RealtimeUserInputMessage(TypedDict): + type: Literal["message"] + role: Literal["user"] + content: list[UserInputText] + + +RealtimeUserInput: TypeAlias = Union[str, RealtimeUserInputMessage] + + +RealtimeAudioFormat: TypeAlias = Union[Literal["pcm16", "g711_ulaw", "g711_alaw"], str] + + +class RealtimeInputAudioTranscriptionConfig(TypedDict): + language: NotRequired[str] + model: NotRequired[Literal["gpt-4o-transcribe", "gpt-4o-mini-transcribe", "whisper-1"] | str] + prompt: NotRequired[str] + + +class RealtimeTurnDetectionConfig(TypedDict): + """Turn detection config. 
async def get_api_key(key: APIKeyOrKeyFunc | None) -> str | None:
    """Resolve *key* to an API key string.

    Args:
        key: Either a literal API key, a zero-argument callable returning a
            key (sync or async), or None.

    Returns:
        The resolved API key, or None when *key* is None.
    """
    if key is None:
        return None
    if isinstance(key, str):
        return key

    # key is a callable; it may return the key directly or an awaitable of it.
    result = key()
    if inspect.isawaitable(result):
        return await result
    return result
    # NOTE: a commented-out `tracing: NotRequired[RealtimeTracingConfig | None]`
    # fragment was stranded here after the return; it belongs to
    # RealtimeSessionConfig and has been removed as dead commented-out code.
context: RunContextWrapper + """The context for the event.""" + + +@dataclass +class RealtimeAgentStartEvent: + """A new agent has started.""" + + agent: RealtimeAgent + """The new agent.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["agent_start"] = "agent_start" + + +@dataclass +class RealtimeAgentEndEvent: + """An agent has ended.""" + + agent: RealtimeAgent + """The agent that ended.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["agent_end"] = "agent_end" + + +@dataclass +class RealtimeHandoffEvent: + """An agent has handed off to another agent.""" + + from_agent: RealtimeAgent + """The agent that handed off.""" + + to_agent: RealtimeAgent + """The agent that was handed off to.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["handoff"] = "handoff" + + +@dataclass +class RealtimeToolStart: + """An agent is starting a tool call.""" + + agent: RealtimeAgent + """The agent that updated.""" + + tool: Tool + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["tool_start"] = "tool_start" + + +@dataclass +class RealtimeToolEnd: + """An agent has ended a tool call.""" + + agent: RealtimeAgent + """The agent that ended the tool call.""" + + tool: Tool + """The tool that was called.""" + + output: Any + """The output of the tool call.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["tool_end"] = "tool_end" + + +@dataclass +class RealtimeRawTransportEvent: + """Forwards raw events from the transport layer.""" + + data: RealtimeTransportEvent + """The raw data from the transport layer.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["raw_transport_event"] = "raw_transport_event" + + +@dataclass +class RealtimeAudioEnd: + """Triggered 
when the agent stops generating audio.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["audio_end"] = "audio_end" + + +@dataclass +class RealtimeAudio: + """Triggered when the agent generates new audio to be played.""" + + audio: RealtimeTransportAudioEvent + """The audio event from the transport layer.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["audio"] = "audio" + + +@dataclass +class RealtimeAudioInterrupted: + """Triggered when the agent is interrupted. Can be listened to by the user to stop audio + playback or give visual indicators to the user. + """ + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["audio_interrupted"] = "audio_interrupted" + + +@dataclass +class RealtimeError: + """An error has occurred.""" + + error: Any + """The error that occurred.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["error"] = "error" + + +@dataclass +class RealtimeHistoryUpdated: + """The history has been updated. 
Contains the full history of the session.""" + + history: list[RealtimeItem] + """The full history of the session.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["history_updated"] = "history_updated" + + +@dataclass +class RealtimeHistoryAdded: + """A new item has been added to the history.""" + + item: RealtimeItem + """The new item that was added to the history.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["history_added"] = "history_added" + + +# TODO (rm) Add guardrails + +RealtimeSessionEvent: TypeAlias = Union[ + RealtimeAgentStartEvent, + RealtimeAgentEndEvent, + RealtimeHandoffEvent, + RealtimeToolStart, + RealtimeToolEnd, + RealtimeRawTransportEvent, + RealtimeAudioEnd, + RealtimeAudio, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeHistoryUpdated, + RealtimeHistoryAdded, +] +"""An event emitted by the realtime session.""" diff --git a/src/agents/realtime/items.py b/src/agents/realtime/items.py new file mode 100644 index 000000000..117a35a02 --- /dev/null +++ b/src/agents/realtime/items.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import Annotated, Literal, Union + +from pydantic import BaseModel, ConfigDict, Field + + +class InputText(BaseModel): + type: Literal["input_text"] = "input_text" + text: str + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class InputAudio(BaseModel): + type: Literal["input_audio"] = "input_audio" + audio: str | None = None + transcript: str | None = None + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class AssistantText(BaseModel): + type: Literal["text"] = "text" + text: str + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class AssistantAudio(BaseModel): + type: Literal["audio"] = "audio" + audio: str | None = None + transcript: str | None = None + + # Allow extra data + model_config = 
ConfigDict(extra="allow") + + +class SystemMessageItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["message"] = "message" + role: Literal["system"] = "system" + content: list[InputText] + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class UserMessageItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["message"] = "message" + role: Literal["user"] = "user" + content: list[InputText | InputAudio] + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class AssistantMessageItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + status: Literal["in_progress", "completed", "incomplete"] | None = None + content: list[AssistantText | AssistantAudio] + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +RealtimeMessageItem = Annotated[ + Union[SystemMessageItem, UserMessageItem, AssistantMessageItem], + Field(discriminator="role"), +] + + +class RealtimeToolCallItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["function_call"] = "function_call" + status: Literal["in_progress", "completed"] + arguments: str + name: str + output: str | None = None + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +RealtimeItem = RealtimeMessageItem | RealtimeToolCallItem + + +class RealtimeResponse(BaseModel): + id: str + output: list[RealtimeMessageItem] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py new file mode 100644 index 000000000..57347a84b --- /dev/null +++ b/src/agents/realtime/openai_realtime.py @@ -0,0 +1,354 @@ +import asyncio +import base64 +import json +import os +from datetime import datetime +from typing import Any + +import websockets +from openai.types.beta.realtime.realtime_server_event import ( + RealtimeServerEvent as OpenAIRealtimeServerEvent, +) 
def __init__(self) -> None:
    """Initialize the transport in a disconnected state.

    No network activity happens here; ``connect()`` opens the WebSocket and
    starts the background listener task.
    """
    self.model = "gpt-4o-realtime-preview"  # Default model
    # Populated by connect(); None means "not connected".
    self._websocket: ClientConnection | None = None
    self._websocket_task: asyncio.Task[None] | None = None
    # Listeners notified of transport events via _emit_event().
    self._listeners: list[RealtimeTransportListener] = []
    # Tracking for the assistant audio currently being streamed; interrupt()
    # uses these to truncate the current item at the elapsed playback offset.
    self._current_item_id: str | None = None
    self._audio_start_time: datetime | None = None
    self._audio_length_ms: float = 0.0
    # True while a model response is in flight (set on response.created,
    # cleared on response.done); checked by _cancel_response().
    self._ongoing_response: bool = False
    self._current_audio_content_index: int | None = None
async def send_event(self, event: RealtimeClientMessage) -> None:
    """Serialize *event* and send it over the realtime WebSocket.

    The wire payload is the event's ``type`` merged with any ``other_data``.
    """
    assert self._websocket is not None, "Not connected"
    payload = {"type": event["type"], **event.get("other_data", {})}
    await self._websocket.send(json.dumps(payload))
async def send_tool_output(
    self, tool_call: RealtimeTransportToolCallEvent, output: str, start_response: bool
) -> None:
    """Send tool output to the model.

    Creates a ``function_call_output`` conversation item answering the given
    tool call, mirrors the completed call to listeners, and optionally asks
    the model to start a new response.

    Args:
        tool_call: The tool call event this output answers (its ``id`` is
            used as the ``call_id``).
        output: The tool's string output.
        start_response: When True, a ``response.create`` event is sent so the
            model responds to the tool output immediately.
    """
    await self.send_event(
        {
            "type": "conversation.item.create",
            "other_data": {
                "item": {
                    "type": "function_call_output",
                    "output": output,
                    "call_id": tool_call.id,
                },
            },
        }
    )

    # Reflect the completed call (including its output) into the local item
    # stream so session listeners see the tool result.
    tool_item = RealtimeToolCallItem(
        item_id=tool_call.id or "",
        previous_item_id=tool_call.previous_item_id,
        type="function_call",
        status="completed",
        arguments=tool_call.arguments,
        name=tool_call.name,
        output=output,
    )
    await self._emit_event(RealtimeTransportItemUpdatedEvent(item=tool_item))

    if start_response:
        await self.send_event({"type": "response.create"})
async def close(self) -> None:
    """Close the session."""
    if self._websocket:
        ws = self._websocket
        await ws.close()
        self._websocket = None
    if self._websocket_task:
        task = self._websocket_task
        task.cancel()
        self._websocket_task = None
self._emit_event(RealtimeTransportTurnEndedEvent()) + elif parsed.type == "session.created": + # TODO (rm) tracing stuff here + pass + elif parsed.type == "error": + await self._emit_event(RealtimeTransportErrorEvent(error=parsed.error)) + elif parsed.type == "conversation.item.deleted": + await self._emit_event(RealtimeTransportItemDeletedEvent(item_id=parsed.item_id)) + elif ( + parsed.type == "conversation.item.created" + or parsed.type == "conversation.item.retrieved" + ): + item = parsed.item + previous_item_id = ( + parsed.previous_item_id if parsed.type == "conversation.item.created" else None + ) + message_item: RealtimeMessageItem = TypeAdapter(RealtimeMessageItem).validate_python( + { + "item_id": item.id or "", + "previous_item_id": previous_item_id, + "type": item.type, + "role": item.role, + "content": item.content, + "status": "in_progress", + } + ) + await self._emit_event(RealtimeTransportItemUpdatedEvent(item=message_item)) + elif ( + parsed.type == "conversation.item.input_audio_transcription.completed" + or parsed.type == "conversation.item.truncated" + ): + await self.send_event( + { + "type": "conversation.item.retrieve", + "other_data": { + "item_id": self._current_item_id, + }, + } + ) + if parsed.type == "conversation.item.input_audio_transcription.completed": + await self._emit_event( + RealtimeTransportInputAudioTranscriptionCompletedEvent( + item_id=parsed.item_id, transcript=parsed.transcript + ) + ) + elif parsed.type == "response.audio_transcript.delta": + await self._emit_event( + RealtimeTransportTranscriptDelta( + item_id=parsed.item_id, delta=parsed.delta, response_id=parsed.response_id + ) + ) + elif ( + parsed.type == "conversation.item.input_audio_transcription.delta" + or parsed.type == "response.text.delta" + or parsed.type == "response.function_call_arguments.delta" + ): + # No support for partials yet + pass + elif ( + parsed.type == "response.output_item.added" + or parsed.type == "response.output_item.done" + ): + item = 
parsed.item + if item.type == "function_call" and item.status == "completed": + tool_call = RealtimeToolCallItem( + item_id=item.id or "", + previous_item_id=None, + type="function_call", + # We use the same item for tool call and output, so it will be completed by the + # output being added + status="in_progress", + arguments=item.arguments or "", + name=item.name or "", + output=None, + ) + await self._emit_event(RealtimeTransportItemUpdatedEvent(item=tool_call)) + await self._emit_event( + RealtimeTransportToolCallEvent( + call_id=item.id or "", + name=item.name or "", + arguments=item.arguments or "", + id=item.id or "", + ) + ) + elif item.type == "message": + message_item = TypeAdapter(RealtimeMessageItem).validate_python( + { + "item_id": item.id or "", + "type": item.type, + "role": item.role, + "content": item.content, + "status": "in_progress", + } + ) + await self._emit_event(RealtimeTransportItemUpdatedEvent(item=message_item)) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py new file mode 100644 index 000000000..60ae83f0c --- /dev/null +++ b/src/agents/realtime/session.py @@ -0,0 +1,370 @@ +"""Minimal realtime session implementation for voice agents.""" + +from __future__ import annotations + +import abc +import asyncio +from collections.abc import Awaitable +from typing import Any, Callable, Literal + +from typing_extensions import TypeAlias, assert_never + +from agents.handoffs import Handoff +from agents.tool_context import ToolContext + +from ..run_context import RunContextWrapper +from ..tool import FunctionTool +from .agent import RealtimeAgent +from .config import APIKeyOrKeyFunc, RealtimeSessionConfig, RealtimeUserInput +from .events import ( + RealtimeAgentEndEvent, + RealtimeAgentStartEvent, + RealtimeAudio, + RealtimeAudioEnd, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeEventInfo, + RealtimeHandoffEvent, # noqa: F401 + RealtimeHistoryAdded, + RealtimeHistoryUpdated, + RealtimeRawTransportEvent, + 
RealtimeSessionEvent, + RealtimeToolEnd, # noqa: F401 + RealtimeToolStart, # noqa: F401 +) +from .items import InputAudio, InputText, RealtimeItem +from .openai_realtime import OpenAIRealtimeWebSocketTransport +from .transport import ( + RealtimeModelName, + RealtimeSessionTransport, + RealtimeTransportConnectionOptions, + RealtimeTransportListener, +) +from .transport_events import ( + RealtimeTransportEvent, + RealtimeTransportInputAudioTranscriptionCompletedEvent, + RealtimeTransportToolCallEvent, +) + + +class RealtimeSessionListener(abc.ABC): + """A listener for realtime session events.""" + + @abc.abstractmethod + async def on_event(self, event: RealtimeSessionEvent) -> None: + """Called when an event is emitted by the realtime session.""" + pass + + +RealtimeSessionListenerFunc: TypeAlias = Callable[[RealtimeSessionEvent], Awaitable[None]] +"""A function that can be used as a listener for realtime session events.""" + + +class _RealtimeFuncListener(RealtimeSessionListener): + """A listener that wraps a function.""" + + def __init__(self, func: RealtimeSessionListenerFunc) -> None: + self._func = func + + async def on_event(self, event: RealtimeSessionEvent) -> None: + """Call the wrapped function with the event.""" + await self._func(event) + + +class RealtimeSession(RealtimeTransportListener): + """A `RealtimeSession` is the equivalent of `Runner` for realtime agents. It automatically + handles multiple turns by maintaining a persistent connection with the underlying transport + layer. + + The session manages the local history copy, executes tools, runs guardrails and facilitates + handoffs between agents. + + Since this code runs on your server, it uses WebSockets by default. You can optionally create + your own custom transport layer by implementing the `RealtimeSessionTransport` interface. 
+ """ + + def __init__( + self, + starting_agent: RealtimeAgent, + *, + context: Any | None = None, + transport: Literal["websocket"] | RealtimeSessionTransport = "websocket", + api_key: APIKeyOrKeyFunc | None = None, + model: RealtimeModelName | None = None, + config: RealtimeSessionConfig | None = None, + # TODO (rm) Add guardrail support + # TODO (rm) Add tracing support + # TODO (rm) Add history audio storage config + ) -> None: + """Initialize the realtime session. + + Args: + starting_agent: The agent to start the session with. + context: The context to use for the session. + transport: The transport to use for the session. Defaults to using websockets. + api_key: The API key to use for the session. + model: The model to use. Must be a realtime model. + config: Override parameters to use. + """ + self._current_agent = starting_agent + self._context_wrapper = RunContextWrapper(context) + self._event_info = RealtimeEventInfo(context=self._context_wrapper) + self._override_config = config + self._history: list[RealtimeItem] = [] + self._model = model + self._api_key = api_key + + self._listeners: list[RealtimeSessionListener] = [] + + if transport == "websocket": + self._transport: RealtimeSessionTransport = OpenAIRealtimeWebSocketTransport() + else: + self._transport = transport + + async def __aenter__(self) -> RealtimeSession: + """Async context manager entry.""" + await self.connect() + return self + + async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None: + """Async context manager exit.""" + await self.end() + + async def connect(self) -> None: + """Start the session: connect to the model and start the connection.""" + self._transport.add_listener(self) + + config = await self.create_session_config( + overrides=self._override_config, + ) + + options: RealtimeTransportConnectionOptions = { + "initial_session_config": config, + } + + if config.get("api_key") is not None: + options["api_key"] = config["api_key"] + elif self._api_key 
is not None:
+            options["api_key"] = self._api_key
+
+        if config.get("model") is not None:
+            options["model"] = config["model"]
+        elif self._model is not None:
+            options["model"] = self._model
+
+        await self._transport.connect(options)
+
+        await self._emit_event(
+            RealtimeHistoryUpdated(
+                history=self._history,
+                info=self._event_info,
+            )
+        )
+
+    async def end(self) -> None:
+        """End the session: disconnect from the model and close the connection."""
+        await self._transport.close()
+
+    def add_listener(self, listener: RealtimeSessionListener | RealtimeSessionListenerFunc) -> None:
+        """Add a listener to the session."""
+        if isinstance(listener, RealtimeSessionListener):
+            self._listeners.append(listener)
+        else:
+            self._listeners.append(_RealtimeFuncListener(listener))
+
+    def remove_listener(
+        self, listener: RealtimeSessionListener | RealtimeSessionListenerFunc
+    ) -> None:
+        """Remove a listener from the session."""
+        if isinstance(listener, RealtimeSessionListener):
+            self._listeners.remove(listener)
+        else:
+            for x in self._listeners:
+                if isinstance(x, _RealtimeFuncListener) and x._func == listener:
+                    self._listeners.remove(x)
+                    break
+
+    async def create_session_config(
+        self, overrides: RealtimeSessionConfig | None = None
+    ) -> RealtimeSessionConfig:
+        """Create the session config."""
+        agent = self._current_agent
+        instructions, tools = await asyncio.gather(
+            agent.get_system_prompt(self._context_wrapper),
+            agent.get_all_tools(self._context_wrapper),
+        )
+        config = RealtimeSessionConfig()
+
+        if self._model is not None:
+            config["model"] = self._model
+        if instructions is not None:
+            config["instructions"] = instructions
+        if tools is not None:
+            config["tools"] = [tool for tool in tools if isinstance(tool, FunctionTool)]
+
+        if overrides:
+            config.update(overrides)
+
+        return config
+
+    async def send_message(self, message: RealtimeUserInput) -> None:
+        """Send a message to the model."""
+        await self._transport.send_message(message)
+
+    async def send_audio(self, audio: 
bytes, *, commit: bool = False) -> None: + """Send a raw audio chunk to the model.""" + await self._transport.send_audio(audio, commit=commit) + + async def interrupt(self) -> None: + """Interrupt the model.""" + await self._transport.interrupt() + + async def on_event(self, event: RealtimeTransportEvent) -> None: + """Called when an event is emitted by the realtime transport.""" + await self._emit_event(RealtimeRawTransportEvent(data=event, info=self._event_info)) + + if event.type == "error": + await self._emit_event(RealtimeError(info=self._event_info, error=event.error)) + elif event.type == "function_call": + await self._handle_tool_call(event) + elif event.type == "audio": + await self._emit_event(RealtimeAudio(info=self._event_info, audio=event)) + elif event.type == "audio_interrupted": + await self._emit_event(RealtimeAudioInterrupted(info=self._event_info)) + elif event.type == "audio_done": + await self._emit_event(RealtimeAudioEnd(info=self._event_info)) + elif event.type == "conversation.item.input_audio_transcription.completed": + self._history = self._get_new_history(self._history, event) + await self._emit_event( + RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) + elif event.type == "transcript_delta": + # TODO (rm) Add guardrails + pass + elif event.type == "item_updated": + is_new = any(item.item_id == event.item.item_id for item in self._history) + self._history = self._get_new_history(self._history, event.item) + if is_new: + new_item = next( + item for item in self._history if item.item_id == event.item.item_id + ) + await self._emit_event(RealtimeHistoryAdded(info=self._event_info, item=new_item)) + else: + await self._emit_event( + RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) + pass + elif event.type == "item_deleted": + deleted_id = event.item_id + self._history = [item for item in self._history if item.item_id != deleted_id] + await self._emit_event( + 
RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) + elif event.type == "connection_status": + pass + elif event.type == "turn_started": + await self._emit_event( + RealtimeAgentStartEvent( + agent=self._current_agent, + info=self._event_info, + ) + ) + elif event.type == "turn_ended": + await self._emit_event( + RealtimeAgentEndEvent( + agent=self._current_agent, + info=self._event_info, + ) + ) + elif event.type == "other": + pass + else: + assert_never(event) + + async def _emit_event(self, event: RealtimeSessionEvent) -> None: + """Emit an event to the listeners.""" + await asyncio.gather(*[listener.on_event(event) for listener in self._listeners]) + + async def _handle_tool_call(self, event: RealtimeTransportToolCallEvent) -> None: + all_tools = await self._current_agent.get_all_tools(self._context_wrapper) + function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + handoff_map = {tool.name: tool for tool in all_tools if isinstance(tool, Handoff)} + + if event.name in function_map: + await self._emit_event( + RealtimeToolStart( + info=self._event_info, + tool=function_map[event.name], + agent=self._current_agent, + ) + ) + + func_tool = function_map[event.name] + tool_context = ToolContext.from_agent_context(self._context_wrapper, event.call_id) + result = await func_tool.on_invoke_tool(tool_context, event.arguments) + + await self._transport.send_tool_output(event, str(result), True) + + await self._emit_event( + RealtimeToolEnd( + info=self._event_info, + tool=func_tool, + output=result, + agent=self._current_agent, + ) + ) + elif event.name in handoff_map: + # TODO (rm) Add support for handoffs + pass + else: + # TODO (rm) Add error handling + pass + + def _get_new_history( + self, + old_history: list[RealtimeItem], + event: RealtimeTransportInputAudioTranscriptionCompletedEvent | RealtimeItem, + ) -> list[RealtimeItem]: + # Merge transcript into placeholder input_audio message. 
+        if isinstance(event, RealtimeTransportInputAudioTranscriptionCompletedEvent):
+            new_history: list[RealtimeItem] = []
+            for item in old_history:
+                if item.item_id == event.item_id and item.type == "message" and item.role == "user":
+                    content: list[InputText | InputAudio] = []
+                    for entry in item.content:
+                        if entry.type == "input_audio":
+                            copied_entry = entry.model_copy(update={"transcript": event.transcript})
+                            content.append(copied_entry)
+                        else:
+                            content.append(entry)  # type: ignore
+                    new_history.append(
+                        item.model_copy(update={"content": content, "status": "completed"})
+                    )
+                else:
+                    new_history.append(item)
+            return new_history
+
+        # Otherwise it's just a new item
+        # TODO (rm) Add support for audio storage config
+
+        # If the item already exists, update it
+        existing_index = next(
+            (i for i, item in enumerate(old_history) if item.item_id == event.item_id), None
+        )
+        if existing_index is not None:
+            new_history = old_history.copy()
+            new_history[existing_index] = event
+            return new_history
+        # Otherwise, insert it after the previous_item_id if that is set
+        elif event.previous_item_id:
+            # Insert the new item after the previous item
+            previous_index = next(
+                (i for i, item in enumerate(old_history) if item.item_id == event.previous_item_id),
+                None,
+            )
+            if previous_index is not None:
+                new_history = old_history.copy()
+                new_history.insert(previous_index + 1, event)
+                return new_history
+        # Otherwise, add it to the end
+        return old_history + [event]
diff --git a/src/agents/realtime/transport.py b/src/agents/realtime/transport.py
new file mode 100644
index 000000000..18290d128
--- /dev/null
+++ b/src/agents/realtime/transport.py
@@ -0,0 +1,107 @@
+import abc
+from typing import Any, Literal, Union
+
+from typing_extensions import NotRequired, TypeAlias, TypedDict
+
+from .config import APIKeyOrKeyFunc, RealtimeClientMessage, RealtimeSessionConfig, RealtimeUserInput
+from .transport_events import RealtimeTransportEvent, RealtimeTransportToolCallEvent
+
+RealtimeModelName: TypeAlias = Union[ + Literal[ + "gpt-4o-realtime-preview", + "gpt-4o-mini-realtime-preview", + "gpt-4o-realtime-preview-2025-06-03", + "gpt-4o-realtime-preview-2024-12-17", + "gpt-4o-realtime-preview-2024-10-01", + "gpt-4o-mini-realtime-preview-2024-12-17", + ], + str, +] +"""The name of a realtime model.""" + + +class RealtimeTransportListener(abc.ABC): + """A listener for realtime transport events.""" + + @abc.abstractmethod + async def on_event(self, event: RealtimeTransportEvent) -> None: + """Called when an event is emitted by the realtime transport.""" + pass + + +class RealtimeTransportConnectionOptions(TypedDict): + """Options for connecting to a realtime transport.""" + + api_key: NotRequired[APIKeyOrKeyFunc] + """The API key to use for the transport. If unset, the transport will attempt to use the + `OPENAI_API_KEY` environment variable. + """ + + model: NotRequired[str] + """The model to use.""" + + url: NotRequired[str] + """The URL to use for the transport. If unset, the transport will use the default OpenAI + WebSocket URL. 
+ """ + + initial_session_config: NotRequired[RealtimeSessionConfig] + + +class RealtimeSessionTransport(abc.ABC): + """A transport layer for realtime sessions.""" + + @abc.abstractmethod + async def connect(self, options: RealtimeTransportConnectionOptions) -> None: + """Establish a connection to the model and keep it alive.""" + pass + + @abc.abstractmethod + def add_listener(self, listener: RealtimeTransportListener) -> None: + """Add a listener to the transport.""" + pass + + @abc.abstractmethod + async def remove_listener(self, listener: RealtimeTransportListener) -> None: + """Remove a listener from the transport.""" + pass + + @abc.abstractmethod + async def send_event(self, event: RealtimeClientMessage) -> None: + """Send an event to the model.""" + pass + + @abc.abstractmethod + async def send_message( + self, message: RealtimeUserInput, other_event_data: dict[str, Any] | None = None + ) -> None: + """Send a message to the model.""" + pass + + @abc.abstractmethod + async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: + """Send a raw audio chunk to the model. + + Args: + audio: The audio data to send. + commit: Whether to commit the audio buffer to the model. If the model does not do turn + detection, this can be used to indicate the turn is completed. + """ + pass + + @abc.abstractmethod + async def send_tool_output( + self, tool_call: RealtimeTransportToolCallEvent, output: str, start_response: bool + ) -> None: + """Send tool output to the model.""" + pass + + @abc.abstractmethod + async def interrupt(self) -> None: + """Interrupt the model. 
For example, could be triggered by a guardrail.""" + pass + + @abc.abstractmethod + async def close(self) -> None: + """Close the session.""" + pass diff --git a/src/agents/realtime/transport_events.py b/src/agents/realtime/transport_events.py new file mode 100644 index 000000000..d65e32591 --- /dev/null +++ b/src/agents/realtime/transport_events.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal, TypeAlias, Union + +from agents.realtime.items import RealtimeItem + +RealtimeConnectionStatus: TypeAlias = Literal["connecting", "connected", "disconnected"] + + +@dataclass +class RealtimeTransportErrorEvent: + """Represents a transport‑layer error.""" + + error: Any + + type: Literal["error"] = "error" + + +@dataclass +class RealtimeTransportToolCallEvent: + """Model attempted a tool/function call.""" + + name: str + call_id: str + arguments: str + + id: str | None = None + previous_item_id: str | None = None + + type: Literal["function_call"] = "function_call" + + +@dataclass +class RealtimeTransportAudioEvent: + """Raw audio bytes emitted by the model.""" + + data: bytes + response_id: str + + type: Literal["audio"] = "audio" + + +@dataclass +class RealtimeTransportAudioInterruptedEvent: + """Audio interrupted.""" + + type: Literal["audio_interrupted"] = "audio_interrupted" + + +@dataclass +class RealtimeTransportAudioDoneEvent: + """Audio done.""" + + type: Literal["audio_done"] = "audio_done" + + +@dataclass +class RealtimeTransportInputAudioTranscriptionCompletedEvent: + """Input audio transcription completed.""" + + item_id: str + transcript: str + + type: Literal["conversation.item.input_audio_transcription.completed"] = ( + "conversation.item.input_audio_transcription.completed" + ) + + +@dataclass +class RealtimeTransportTranscriptDelta: + """Partial transcript update.""" + + item_id: str + delta: str + response_id: str + + type: Literal["transcript_delta"] = "transcript_delta" + + 
+@dataclass +class RealtimeTransportItemUpdatedEvent: + """Item added to the history or updated.""" + + item: RealtimeItem + + type: Literal["item_updated"] = "item_updated" + + +@dataclass +class RealtimeTransportItemDeletedEvent: + """Item deleted from the history.""" + + item_id: str + + type: Literal["item_deleted"] = "item_deleted" + + +@dataclass +class RealtimeTransportConnectionStatusEvent: + """Connection status changed.""" + + status: RealtimeConnectionStatus + + type: Literal["connection_status"] = "connection_status" + + +@dataclass +class RealtimeTransportTurnStartedEvent: + """Triggered when the model starts generating a response for a turn.""" + + type: Literal["turn_started"] = "turn_started" + + +@dataclass +class RealtimeTransportTurnEndedEvent: + """Triggered when the model finishes generating a response for a turn.""" + + type: Literal["turn_ended"] = "turn_ended" + + +@dataclass +class RealtimeTransportOtherEvent: + """Used as a catchall for vendor-specific events.""" + + data: Any + + type: Literal["other"] = "other" + + +# TODO (rm) Add usage events + + +RealtimeTransportEvent: TypeAlias = Union[ + RealtimeTransportErrorEvent, + RealtimeTransportToolCallEvent, + RealtimeTransportAudioEvent, + RealtimeTransportAudioInterruptedEvent, + RealtimeTransportAudioDoneEvent, + RealtimeTransportInputAudioTranscriptionCompletedEvent, + RealtimeTransportTranscriptDelta, + RealtimeTransportItemUpdatedEvent, + RealtimeTransportItemDeletedEvent, + RealtimeTransportConnectionStatusEvent, + RealtimeTransportTurnStartedEvent, + RealtimeTransportTurnEndedEvent, + RealtimeTransportOtherEvent, +] diff --git a/tests/realtime/test_transport_events.py b/tests/realtime/test_transport_events.py new file mode 100644 index 000000000..2219303d0 --- /dev/null +++ b/tests/realtime/test_transport_events.py @@ -0,0 +1,12 @@ +from typing import get_args + +from agents.realtime.transport_events import RealtimeTransportEvent + + +def test_all_events_have_type() -> None: + 
"""Test that all events have a type.""" + events = get_args(RealtimeTransportEvent) + assert len(events) > 0 + for event in events: + assert event.type is not None + assert isinstance(event.type, str)