diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 8696697d3a..4208fc18b6 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -9,6 +9,7 @@ from dspy.adapters.types.reasoning import Reasoning from dspy.adapters.types.tool import Tool, ToolCalls from dspy.experimental import Citations +from dspy.signatures.field import InputField, OutputField from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks @@ -452,13 +453,13 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]]) return messages - def _get_history_field_name(self, signature: type[Signature]) -> bool: + def _get_history_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.input_fields.items(): if field.annotation == History: return name return None - def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool: + def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.input_fields.items(): # Look for annotation `list[dspy.Tool]` or `dspy.Tool` origin = get_origin(field.annotation) @@ -468,54 +469,104 @@ def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool: return name return None - def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool: + def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.output_fields.items(): if field.annotation == ToolCalls: return name return None + def _serialize_kv_value(self, v: Any) -> Any: + """Safely serialize values for kv-mode formatting.""" + if isinstance(v, (str, int, float, bool)) or v is None: + return v + try: + return str(v) + except Exception: + return f"" + + def _make_dynamic_signature_for_inputs(self, keys: list[str]) -> type[Signature]: + """Create a dynamic signature with input fields only (no instructions).""" + return Signature({k: InputField() for k in keys}, instructions="") + + def _make_dynamic_signature_for_outputs(self, keys: list[str]) -> type[Signature]: + """Create a dynamic signature with output fields only (no instructions).""" + return Signature({k: OutputField() for k in keys}, instructions="") + def format_conversation_history( self, signature: type[Signature], history_field_name: str, inputs: dict[str, Any], ) -> list[dict[str, Any]]: - """Format the conversation history. - - This method formats the conversation history and the current input as multiturn messages. - - Args: - signature: The DSPy signature for which to format the conversation history. - history_field_name: The name of the history field in the signature. - inputs: The input arguments to the DSPy module. + """Format the conversation history as multiturn messages. - Returns: - A list of multiturn messages. + Supports four modes: + - raw: Direct LM messages → passed through as-is + - demo: {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs + - flat: Arbitrary kv pairs → single user message per dict (default) + - signature: Dict keys match signature fields → user/assistant pairs """ - conversation_history = inputs[history_field_name].messages if history_field_name in inputs else None - - if conversation_history is None: + history = inputs.get(history_field_name) + if history is None: return [] - messages = [] - for message in conversation_history: - messages.append( - { + del inputs[history_field_name] + + if history.mode == "raw": + return [dict(msg) for msg in history.messages] + if history.mode == "demo": + return self._format_demo_history(history.messages) + if history.mode == "signature": + return self._format_signature_history(signature, history.messages) + return self._format_flat_history(history.messages) + + def _format_demo_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Format demo-mode history (input_fields/output_fields → user/assistant).""" + result = [] + for msg in messages: + if "input_fields" in msg: + input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()} + sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys())) + result.append({ "role": "user", - "content": self.format_user_message_content(signature, message), - } - ) - messages.append( - { + "content": self.format_user_message_content(sig, input_dict), + }) + if "output_fields" in msg: + output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()} + sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys())) + result.append({ "role": "assistant", - "content": self.format_assistant_message_content(signature, message), - } - ) - - # Remove the history field from the inputs - del inputs[history_field_name] + "content": self.format_assistant_message_content(sig, output_dict), + }) + return result - return messages + def _format_signature_history( + self, signature: type[Signature], messages: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Format signature-mode history (signature fields → user/assistant pairs).""" + result = [] + for msg in messages: + result.append({ + "role": "user", + "content": self.format_user_message_content(signature, msg), + }) + result.append({ + "role": "assistant", + "content": self.format_assistant_message_content(signature, msg), + }) + return result + + def _format_flat_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Format flat-mode history (all kv pairs in single user message).""" + result = [] + for msg in messages: + serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()} + sig = self._make_dynamic_signature_for_inputs(list(serialized.keys())) + result.append({ + "role": "user", + "content": self.format_user_message_content(sig, serialized), + }) + return result def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]: """Parse the LM output into a dictionary of the output fields. diff --git a/dspy/adapters/types/history.py b/dspy/adapters/types/history.py index 2c39d5c4ab..e5738ea402 100644 --- a/dspy/adapters/types/history.py +++ b/dspy/adapters/types/history.py @@ -1,25 +1,47 @@ -from typing import Any +from typing import Any, Literal import pydantic class History(pydantic.BaseModel): - """Class representing the conversation history. - - The conversation history is a list of messages, each message entity should have keys from the associated signature. - For example, if you have the following signature: - - ``` - class MySignature(dspy.Signature): - question: str = dspy.InputField() - history: dspy.History = dspy.InputField() - answer: str = dspy.OutputField() - ``` - - Then the history should be a list of dictionaries with keys "question" and "answer". + """Class representing conversation history. + + History supports four message formats, with one mode per History instance: + + 1. **Raw mode**: Direct LM messages with `{"role": "...", "content": "..."}`. + Used for ReAct trajectories and native tool calling. + ```python + history = dspy.History.from_raw([ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ]) + ``` + + 2. **Demo mode**: Nested `{"input_fields": {...}, "output_fields": {...}}` pairs. + Used for few-shot demonstrations with explicit input/output separation. + ```python + history = dspy.History.from_demo([ + {"input_fields": {"question": "2+2?"}, "output_fields": {"answer": "4"}}, + ]) + ``` + + 3. **Flat mode** (default): Arbitrary key-value pairs in a single user message. + ```python + history = dspy.History(messages=[ + {"thought": "I need to search", "tool_name": "search", "observation": "Found it"}, + ]) + ``` + + 4. **Signature mode**: Dict keys match signature fields → user/assistant pairs. + Must be explicitly set. + ```python + history = dspy.History.from_signature([ + {"question": "What is 2+2?", "answer": "4"}, + ]) + ``` Example: - ``` + ```python import dspy dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) @@ -29,19 +51,16 @@ class MySignature(dspy.Signature): history: dspy.History = dspy.InputField() answer: str = dspy.OutputField() - history = dspy.History( - messages=[ - {"question": "What is the capital of France?", "answer": "Paris"}, - {"question": "What is the capital of Germany?", "answer": "Berlin"}, - ] - ) + history = dspy.History.from_signature([ + {"question": "What is the capital of France?", "answer": "Paris"}, + ]) predict = dspy.Predict(MySignature) outputs = predict(question="What is the capital of France?", history=history) ``` Example of capturing the conversation history: - ``` + ```python import dspy dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) @@ -59,6 +78,7 @@ class MySignature(dspy.Signature): """ messages: list[dict[str, Any]] + mode: Literal["signature", "demo", "flat", "raw"] = "flat" model_config = pydantic.ConfigDict( frozen=True, @@ -66,3 +86,86 @@ class MySignature(dspy.Signature): validate_assignment=True, extra="forbid", ) + + @staticmethod + def _infer_mode_from_msg(msg: dict) -> str: + """Infer the mode from a message's structure. + + Detection rules (conservative): + - Raw: has "role" key and ONLY LM-like keys (role, content, tool_calls, tool_call_id, name) + - Demo: keys are ONLY "input_fields" and/or "output_fields" + - Flat: everything else (signature mode must be explicit) + """ + keys = set(msg.keys()) + lm_keys = {"role", "content", "tool_calls", "tool_call_id", "name"} + + if "role" in keys and keys <= lm_keys: + return "raw" + + if keys <= {"input_fields", "output_fields"} and keys: + return "demo" + + return "flat" + + def _validate_msg_for_mode(self, msg: dict, mode: str) -> None: + """Validate a message conforms to the expected mode structure.""" + if mode == "raw": + if not isinstance(msg.get("role"), str): + raise ValueError(f"Raw mode: 'role' must be a string: {msg}") + content = msg.get("content") + if content is not None and not isinstance(content, str): + raise ValueError(f"Raw mode: 'content' must be a string or None: {msg}") + + elif mode == "demo": + if "input_fields" in msg and not isinstance(msg["input_fields"], dict): + raise ValueError(f"Demo mode: 'input_fields' must be a dict: {msg}") + if "output_fields" in msg and not isinstance(msg["output_fields"], dict): + raise ValueError(f"Demo mode: 'output_fields' must be a dict: {msg}") + + elif mode == "signature": + if not isinstance(msg, dict) or not msg: + raise ValueError(f"Signature mode: messages must be non-empty dicts: {msg}") + + @pydantic.model_validator(mode="after") + def _validate_messages(self) -> "History": + if not self.messages: + return self + + # Only infer if mode is the default "flat" and messages clearly match another mode + if self.mode == "flat": + inferred = self._infer_mode_from_msg(self.messages[0]) + if inferred in {"raw", "demo"}: + object.__setattr__(self, "mode", inferred) + + for msg in self.messages: + self._validate_msg_for_mode(msg, self.mode) + + return self + + def with_messages(self, messages: list[dict[str, Any]]) -> "History": + """Return a new History with additional messages appended.""" + return History(messages=[*self.messages, *messages], mode=self.mode) + + @classmethod + def from_demo(cls, messages: list[dict[str, Any]]) -> "History": + """Create a History with demo mode. + + Demo mode expects messages with "input_fields" and/or "output_fields" keys. + """ + return cls(messages=messages, mode="demo") + + @classmethod + def from_raw(cls, messages: list[dict[str, Any]]) -> "History": + """Create a History with raw mode. + + Raw mode expects direct LM messages with "role" and "content" keys. + """ + return cls(messages=messages, mode="raw") + + @classmethod + def from_signature(cls, messages: list[dict[str, Any]]) -> "History": + """Create a History with signature mode. + + Signature mode expects dicts with keys matching the signature's fields. + """ + return cls(messages=messages, mode="signature") diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 5f87879f80..3e3b1e1f39 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -1,9 +1,12 @@ +import json import logging +import uuid from typing import TYPE_CHECKING, Any, Callable, Literal from litellm import ContextWindowExceededError import dspy +from dspy.adapters.types.history import History from dspy.adapters.types.tool import Tool from dspy.primitives.module import Module from dspy.signatures.signature import ensure_signature @@ -15,29 +18,29 @@ class ReAct(Module): - def __init__(self, signature: type["Signature"], tools: list[Callable], max_iters: int = 10): - """ - ReAct stands for "Reasoning and Acting," a popular paradigm for building tool-using agents. - In this approach, the language model is iteratively provided with a list of tools and has - to reason about the current situation. The model decides whether to call a tool to gather more - information or to finish the task based on its reasoning process. The DSPy version of ReAct is - generalized to work over any signature, thanks to signature polymorphism. + """ReAct (Reasoning and Acting) agent module. - Args: - signature: The signature of the module, which defines the input and output of the react module. - tools (list[Callable]): A list of functions, callable objects, or `dspy.Tool` instances. - max_iters (Optional[int]): The maximum number of iterations to run. Defaults to 10. + ReAct iteratively reasons about the current situation and takes actions using tools. + The trajectory is stored as a History in raw LM message format. - Example: + Args: + signature: The signature defining input and output fields. + tools: List of callable tools the agent can use. + max_iters: Maximum reasoning iterations (default: 10). + Example: ```python def get_weather(city: str) -> str: return f"The weather in {city} is sunny." - react = dspy.ReAct(signature="question->answer", tools=[get_weather]) + react = dspy.ReAct("question -> answer", tools=[get_weather]) pred = react(question="What is the weather in Tokyo?") + print(pred.answer) + print(pred.trajectory) # History object with tool call messages ``` - """ + """ + + def __init__(self, signature: type["Signature"], tools: list[Callable], max_iters: int = 10): super().__init__() self.signature = signature = ensure_signature(signature) self.max_iters = max_iters @@ -49,21 +52,22 @@ def get_weather(city: str) -> str: outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) instr = [f"{signature.instructions}\n"] if signature.instructions else [] - instr.extend( - [ - f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.", - f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", - "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.", - "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", - "When writing next_thought, you may reason about the current situation and plan for future steps.", - "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", - ] - ) + instr.extend([ + f"You are an Agent. In each episode, you will be given the fields {inputs} as input. " + "And you can see your past trajectory so far.", + f"Your goal is to use one or more of the supplied tools to collect any necessary information " + f"for producing {outputs}.\n", + "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, " + "and also when finishing the task.", + "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", + "When writing next_thought, you may reason about the current situation and plan for future steps.", + "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", + ]) tools["finish"] = Tool( func=lambda: "Completed.", name="finish", - desc=f"Marks the task as complete. That is, signals that all information for producing the outputs, i.e. {outputs}, are now available to be extracted.", + desc=f"Marks the task as complete. Signals that all information for producing {outputs} is available.", args={}, ) @@ -73,115 +77,190 @@ def get_weather(city: str) -> str: react_signature = ( dspy.Signature({**signature.input_fields}, "\n".join(instr)) - .append("trajectory", dspy.InputField(), type_=str) + .append("trajectory", dspy.InputField(), type_=History) .append("next_thought", dspy.OutputField(), type_=str) .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) ) + extract_instructions = ( + f"You are an extraction agent. Extract the fields: {outputs} from the given trajectory.\n" + f"The original task was:\n{signature.instructions}\n" + "An executor agent has used tools to generate the conversation below. " + f"Given this trajectory, extract the fields: {outputs}." + ) fallback_signature = dspy.Signature( {**signature.input_fields, **signature.output_fields}, - signature.instructions, - ).append("trajectory", dspy.InputField(), type_=str) + extract_instructions, + ).append( + "trajectory", + dspy.InputField(desc="The conversation history with enough context to produce the output"), + type_=History, + ) self.tools = tools self.react = dspy.Predict(react_signature) self.extract = dspy.ChainOfThought(fallback_signature) - def _format_trajectory(self, trajectory: dict[str, Any]): - adapter = dspy.settings.adapter or dspy.ChatAdapter() - trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") - return adapter.format_user_message_content(trajectory_signature, trajectory) - - def forward(self, **input_args): - trajectory = {} + def forward(self, *, trajectory: History | None = None, **input_args): max_iters = input_args.pop("max_iters", self.max_iters) - for idx in range(max_iters): + trajectory = trajectory or History.from_raw([]) + + for _ in range(max_iters): try: - pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) - except ValueError as err: - logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") + pred, trajectory = self._call_with_retry(self.react, trajectory, **input_args) + except (ValueError, ContextWindowExceededError) as err: + logger.warning(f"Ending trajectory: {_fmt_exc(err)}") break - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args - - try: - trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) - except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + observation = self._run_tool(pred.next_tool_name, pred.next_tool_args) + trajectory = self._record_step(trajectory, pred, observation) if pred.next_tool_name == "finish": break - extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) + extract, trajectory = self._call_with_retry(self.extract, trajectory, **input_args) + trajectory = self._record_extract(trajectory, extract) + return dspy.Prediction(trajectory=trajectory, **extract) - async def aforward(self, **input_args): - trajectory = {} + async def aforward(self, *, trajectory: History | None = None, **input_args): max_iters = input_args.pop("max_iters", self.max_iters) - for idx in range(max_iters): + trajectory = trajectory or History.from_raw([]) + + for _ in range(max_iters): try: - pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) - except ValueError as err: - logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") + pred, trajectory = await self._acall_with_retry(self.react, trajectory, **input_args) + except (ValueError, ContextWindowExceededError) as err: + logger.warning(f"Ending trajectory: {_fmt_exc(err)}") break - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args - - try: - trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args) - except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + observation = await self._arun_tool(pred.next_tool_name, pred.next_tool_args) + trajectory = self._record_step(trajectory, pred, observation) if pred.next_tool_name == "finish": break - extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) + extract, trajectory = await self._acall_with_retry(self.extract, trajectory, **input_args) + trajectory = self._record_extract(trajectory, extract) + return dspy.Prediction(trajectory=trajectory, **extract) - def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): + # ------------------------------------------------------------------------- + # Tool execution + # ------------------------------------------------------------------------- + + def _run_tool(self, name: str, args: dict) -> str: + try: + result = self.tools[name](**args) + return self._serialize(result) + except Exception as err: + return f"Execution error in {name}: {_fmt_exc(err)}" + + async def _arun_tool(self, name: str, args: dict) -> str: + try: + result = await self.tools[name].acall(**args) + return self._serialize(result) + except Exception as err: + return f"Execution error in {name}: {_fmt_exc(err)}" + + def _serialize(self, value: Any) -> str: + if isinstance(value, str): + return value + try: + return json.dumps(value) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------------- + # Trajectory recording + # ------------------------------------------------------------------------- + + def _record_step(self, trajectory: History, pred, observation: str) -> History: + """Record a single agent step (action + observation) to the trajectory.""" + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + + action_msg = { + "role": "assistant", + "content": pred.next_thought, + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": pred.next_tool_name, + "arguments": json.dumps(pred.next_tool_args), + }, + }], + } + + observation_msg = { + "role": "tool", + "tool_call_id": tool_call_id, + "content": observation, + } + + return trajectory.with_messages([action_msg, observation_msg]) + + def _record_extract(self, trajectory: History, extract) -> History: + """Record the final extraction result to the trajectory.""" + extract_dict = dict(extract) + reasoning = extract_dict.pop("reasoning", None) + + parts = [] + if reasoning: + parts.append(f"Reasoning: {reasoning}") + for key, value in extract_dict.items(): + parts.append(f"{key}: {self._serialize(value)}") + + return trajectory.with_messages([{"role": "assistant", "content": "\n".join(parts)}]) + + # ------------------------------------------------------------------------- + # LM calls with truncation retry + # ------------------------------------------------------------------------- + + def _call_with_retry(self, module, trajectory: History, **input_args) -> tuple[Any, History]: + last_err = None for _ in range(3): try: - return module( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) - except ContextWindowExceededError: - logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") + return module(**input_args, trajectory=trajectory), trajectory + except ContextWindowExceededError as err: + last_err = err + logger.warning("Context window exceeded, truncating oldest step.") trajectory = self.truncate_trajectory(trajectory) - async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): + raise ContextWindowExceededError(f"Context window exceeded after 3 truncation attempts: {last_err}") + + async def _acall_with_retry(self, module, trajectory: History, **input_args) -> tuple[Any, History]: + last_err = None for _ in range(3): try: - return await module.acall( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) - except ContextWindowExceededError: - logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") + return await module.acall(**input_args, trajectory=trajectory), trajectory + except ContextWindowExceededError as err: + last_err = err + logger.warning("Context window exceeded, truncating oldest step.") trajectory = self.truncate_trajectory(trajectory) - def truncate_trajectory(self, trajectory): - """Truncates the trajectory so that it fits in the context window. + raise ContextWindowExceededError(f"Context window exceeded after 3 truncation attempts: {last_err}") + + def truncate_trajectory(self, trajectory: History) -> History: + """Remove the oldest tool call pair from the trajectory. - Users can override this method to implement their own truncation logic. + Override this method to implement custom truncation logic. """ - keys = list(trajectory.keys()) - if len(keys) < 4: - # Every tool call has 4 keys: thought, tool_name, tool_args, and observation. - raise ValueError( - "The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be " - "truncated because it only has one tool call." - ) + messages = list(trajectory.messages) - for key in keys[:4]: - trajectory.pop(key) + if len(messages) < 2: + raise ValueError("Trajectory too long but cannot truncate: only one step remains.") - return trajectory + # Remove assistant + following tool response(s) + if messages[0].get("role") == "assistant" and messages[0].get("tool_calls"): + messages.pop(0) + while messages and messages[0].get("role") == "tool": + messages.pop(0) + else: + messages.pop(0) + + return History(messages=messages, mode=trajectory.mode) def _fmt_exc(err: BaseException, *, limit: int = 5) -> str: @@ -191,7 +270,6 @@ def _fmt_exc(err: BaseException, *, limit: int = 5) -> str: """ import traceback - return "\n" + "".join(traceback.format_exception(type(err), err, err.__traceback__, limit=limit)).strip() diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py index 07934157fd..65a32fab25 100644 --- a/dspy/utils/inspect_history.py +++ b/dspy/utils/inspect_history.py @@ -10,6 +10,14 @@ def _blue(text: str, end: str = "\n"): return "\x1b[34m" + str(text) + "\x1b[0m" + end +def _yellow(text: str, end: str = "\n"): + return "\x1b[33m" + str(text) + "\x1b[0m" + end + + +def _cyan(text: str, end: str = "\n"): + return "\x1b[36m" + str(text) + "\x1b[0m" + end + + def pretty_print_history(history, n: int = 1): """Prints the last n prompts and their completions.""" @@ -22,37 +30,67 @@ def pretty_print_history(history, n: int = 1): print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n") for msg in messages: - print(_red(f"{msg['role'].capitalize()} message:")) - if isinstance(msg["content"], str): - print(msg["content"].strip()) - else: - if isinstance(msg["content"], list): - for c in msg["content"]: - if c["type"] == "text": - print(c["text"].strip()) - elif c["type"] == "image_url": - image_str = "" - if "base64" in c["image_url"].get("url", ""): - len_base64 = len(c["image_url"]["url"].split("base64,")[1]) - image_str = ( - f"<{c['image_url']['url'].split('base64,')[0]}base64," - f"" - ) - else: - image_str = f"" - print(_blue(image_str.strip())) - elif c["type"] == "input_audio": - audio_format = c["input_audio"]["format"] - len_audio = len(c["input_audio"]["data"]) - audio_str = f"