diff --git a/neuron-explainer/neuron_explainer/explanations/prompt_builder.py b/neuron-explainer/neuron_explainer/explanations/prompt_builder.py index 3782940..8776e7f 100644 --- a/neuron-explainer/neuron_explainer/explanations/prompt_builder.py +++ b/neuron-explainer/neuron_explainer/explanations/prompt_builder.py @@ -49,8 +49,13 @@ class Role(str, Enum): class PromptBuilder: """Class for accumulating components of a prompt and then formatting them into an output.""" - def __init__(self) -> None: + + def __init__(self, allow_extra_system_messages: bool = False) -> None: + """The `allow_extra_system_messages` instance variable allows the caller to specify that the prompt + should be allowed to contain system messages after the very first one.""" + self._messages: list[HarmonyMessage] = [] + self._allow_extra_system_messages = allow_extra_system_messages def add_message(self, role: Role, message: str) -> None: self._messages.append(HarmonyMessage(role=role, content=message)) @@ -93,7 +98,7 @@ def build( for message in messages: role = message["role"] assert role == expected_next_role or ( - allow_extra_system_messages and role == Role.SYSTEM + (self._allow_extra_system_messages or allow_extra_system_messages) and role == Role.SYSTEM ), f"Expected message from {expected_next_role} but got message from {role}" if role == Role.SYSTEM: expected_next_role = Role.USER