Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/proxy_lite/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@

class BaseAgentConfig(BaseModel):
client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig)
history_messages_limit: dict[MessageLabel, int] = Field(default_factory=lambda: dict())
history_messages_limit: dict[MessageLabel, int] = Field(
default_factory=lambda: dict()
)
history_messages_include: Optional[dict[MessageLabel, int]] = Field(
default=None,
description="If set, overrides history_messages_limit by setting all message types to 0 except those specified",
Expand Down Expand Up @@ -73,7 +75,9 @@ def tools(self) -> list[Tool]: ...
def tool_descriptions(self) -> str:
tool_descriptions = []
for tool in self.tools:
func_descriptions = "\n".join("- {name}: {description}".format(**schema) for schema in tool.schema)
func_descriptions = "\n".join(
"- {name}: {description}".format(**schema) for schema in tool.schema
)
tool_title = f"{tool.__class__.__name__}:\n" if len(self.tools) > 1 else ""
tool_descriptions.append(f"{tool_title}{func_descriptions}")
return "\n\n".join(tool_descriptions)
Expand Down Expand Up @@ -108,9 +112,21 @@ async def generate_output(
)
).model_dump()
response_content = response_content["choices"][0]["message"]
# Re-inject tool_calls as text blocks so legacy parser keeps working
tool_blocks = ""
if response_content["tool_calls"]:
import json, uuid

tool_blocks = "\n".join(
f"<tool_call>{json.dumps(tc)}</tool_call>"
for tc in response_content["tool_calls"]
)

content_text = (response_content["content"] or "") + tool_blocks

assistant_message = AssistantMessage(
role=response_content["role"],
content=[Text(text=response_content["content"])] if response_content["content"] else [],
content=[Text(text=content_text)] if content_text else [],
tool_calls=response_content["tool_calls"],
)
if append_assistant_message:
Expand Down
23 changes: 15 additions & 8 deletions src/proxy_lite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async def create_completion(
@classmethod
def create(cls, config: BaseClientConfig) -> "BaseClient":
supported_clients = {
"openai": OpenAIClient,
"openai-azure": OpenAIClient,
"convergence": ConvergenceClient,
}
Expand All @@ -72,7 +73,7 @@ def http_client(self) -> httpx.AsyncClient:

class OpenAIClientConfig(BaseClientConfig):
name: Literal["openai"] = "openai"
model_id: str = "gpt-4o"
model_id: str = "gpt-4.1"
api_key: str = os.environ.get("OPENAI_API_KEY")


Expand Down Expand Up @@ -103,8 +104,10 @@ async def create_completion(
optional_params = {
"seed": seed,
"tools": self.serializer.serialize_tools(tools) if tools else None,
"tool_choice": "required" if tools else None,
"response_format": {"type": "json_object"} if response_format else {"type": "text"},
"tool_choice": "auto" if tools else None,
"response_format": (
{"type": "json_object"} if response_format else {"type": "text"}
),
}
base_params.update({k: v for k, v in optional_params.items() if v is not None})
return await self.external_client.chat.completions.create(**base_params)
Expand All @@ -125,11 +128,13 @@ class ConvergenceClient(OpenAIClient):
async def _validate_model(self) -> None:
try:
response = await self.external_client.models.list()
assert self.config.model_id in [model.id for model in response.data], (
f"Model {self.config.model_id} not found in {response.data}"
)
assert self.config.model_id in [
model.id for model in response.data
], f"Model {self.config.model_id} not found in {response.data}"
self._model_validated = True
logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
logger.debug(
f"Model {self.config.model_id} validated and connected to cluster"
)
except Exception as e:
logger.error(f"Error retrieving model: {e}")
raise e
Expand Down Expand Up @@ -160,7 +165,9 @@ async def create_completion(
optional_params = {
"seed": seed,
"tools": self.serializer.serialize_tools(tools) if tools else None,
"tool_choice": "auto" if tools else None, # vLLM does not support "required"
"tool_choice": (
"auto" if tools else None
), # vLLM does not support "required"
"response_format": response_format if response_format else {"type": "text"},
}
base_params.update({k: v for k, v in optional_params.items() if v is not None})
Expand Down
4 changes: 3 additions & 1 deletion src/proxy_lite/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def from_media(
if text is not None:
text = Text(text=text)
if image is not None:
base64_image = image if is_base64 else base64.b64encode(image).decode("utf-8")
base64_image = (
image if is_base64 else base64.b64encode(image).decode("utf-8")
)
data_url = f"data:image/jpeg;base64,{base64_image}"
image = Image(image_url=ImageUrl(url=data_url))
content = [text, image] if text is not None else [image]
Expand Down
5 changes: 4 additions & 1 deletion src/proxy_lite/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,8 @@ def deserialize_messages(self, data: list[dict]) -> MessageHistory:
)

def serialize_tools(self, tools: list[Tool]) -> list[dict]:
tool_schemas = [[{"type": "function", "function": schema} for schema in tool.schema] for tool in tools]
tool_schemas = [
[{"type": "function", "function": schema} for schema in tool.schema]
for tool in tools
]
return list(itertools.chain.from_iterable(tool_schemas))
31 changes: 25 additions & 6 deletions src/proxy_lite/solvers/simple_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def agent(self) -> BaseAgent:
@property
def history(self) -> MessageHistory:
return MessageHistory(
messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages,
messages=[SystemMessage.from_media(text=self.agent.system_prompt)]
+ self.agent.history.messages,
)

async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None:
Expand All @@ -56,6 +57,13 @@ async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> N
self.logger.debug(f"Initialised with task: {task}")

async def act(self, observation: Observation) -> Action:
# If the previous env step contained tool responses, convert them
if observation.state.tool_responses:
for resp in observation.state.tool_responses:
await self.agent.receive_tool_message(
text=resp.content or "",
tool_id=resp.id,
)
self.agent.receive_user_message(
image=observation.state.image,
text=observation.state.text,
Expand All @@ -68,7 +76,10 @@ async def act(self, observation: Observation) -> Action:
self.logger.debug(f"Assistant message generated: {message}")

# check tool calls for return_value
if any(tool_call.function["name"] == "return_value" for tool_call in message.tool_calls):
if any(
tool_call.function["name"] == "return_value"
for tool_call in message.tool_calls
):
self.complete = True
arguments = json.loads(message.tool_calls[0].function["arguments"])
if isinstance(arguments, str):
Expand All @@ -78,15 +89,23 @@ async def act(self, observation: Observation) -> Action:

text_content = message.content[0].text

observation_match = re.search(r"<observation>(.*?)</observation>", text_content, re.DOTALL)
observation_content = observation_match.group(1).strip() if observation_match else ""
observation_match = re.search(
r"<observation>(.*?)</observation>", text_content, re.DOTALL
)
observation_content = (
observation_match.group(1).strip() if observation_match else ""
)

self.logger.info("🌐 [bold blue]Observation:[/]")
await self.logger.stream_message(observation_content)

# Extract text between thinking tags if present
thinking_match = re.search(r"<thinking>(.*?)</thinking>", text_content, re.DOTALL)
thinking_content = thinking_match.group(1).strip() if thinking_match else text_content
thinking_match = re.search(
r"<thinking>(.*?)</thinking>", text_content, re.DOTALL
)
thinking_content = (
thinking_match.group(1).strip() if thinking_match else text_content
)

self.logger.info("🧠 [bold purple]Thinking:[/]")
await self.logger.stream_message(thinking_content)
Expand Down