Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ class Role(str, Enum):
class PromptBuilder:
"""Class for accumulating components of a prompt and then formatting them into an output."""

def __init__(self) -> None:

def __init__(self, allow_extra_system_messages: bool = False) -> None:
"""The `allow_extra_system_messages` instance variable allows the caller to specify that the prompt
should be allowed to contain system messages after the very first one."""

self._messages: list[HarmonyMessage] = []
self._allow_extra_system_messages = allow_extra_system_messages

def add_message(self, role: Role, message: str) -> None:
self._messages.append(HarmonyMessage(role=role, content=message))
Expand Down Expand Up @@ -93,7 +98,7 @@ def build(
for message in messages:
role = message["role"]
assert role == expected_next_role or (
allow_extra_system_messages and role == Role.SYSTEM
(self._allow_extra_system_messages or allow_extra_system_messages) and role == Role.SYSTEM
), f"Expected message from {expected_next_role} but got message from {role}"
if role == Role.SYSTEM:
expected_next_role = Role.USER
Expand Down