From 78873f16f5011bd1cebafb3e5f293576bb36a80f Mon Sep 17 00:00:00 2001 From: Federico C Date: Mon, 19 May 2025 13:04:29 +0200 Subject: [PATCH 1/2] Enhance OpenAI client configuration and response handling --- src/proxy_lite/agents/agent_base.py | 22 ++++++++++++++--- src/proxy_lite/client.py | 23 +++++++++++------- src/proxy_lite/history.py | 5 +++- src/proxy_lite/serializer.py | 5 +++- src/proxy_lite/solvers/simple_solver.py | 32 ++++++++++++++++++++----- 5 files changed, 68 insertions(+), 19 deletions(-) diff --git a/src/proxy_lite/agents/agent_base.py b/src/proxy_lite/agents/agent_base.py index b66a878..a6b015e 100644 --- a/src/proxy_lite/agents/agent_base.py +++ b/src/proxy_lite/agents/agent_base.py @@ -28,7 +28,9 @@ class BaseAgentConfig(BaseModel): client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig) - history_messages_limit: dict[MessageLabel, int] = Field(default_factory=lambda: dict()) + history_messages_limit: dict[MessageLabel, int] = Field( + default_factory=lambda: dict() + ) history_messages_include: Optional[dict[MessageLabel, int]] = Field( default=None, description="If set, overrides history_messages_limit by setting all message types to 0 except those specified", @@ -73,7 +75,9 @@ def tools(self) -> list[Tool]: ... def tool_descriptions(self) -> str: tool_descriptions = [] for tool in self.tools: - func_descriptions = "\n".join("- {name}: {description}".format(**schema) for schema in tool.schema) + func_descriptions = "\n".join( + "- {name}: {description}".format(**schema) for schema in tool.schema + ) tool_title = f"{tool.__class__.__name__}:\n" if len(self.tools) > 1 else "" tool_descriptions.append(f"{tool_title}{func_descriptions}") return "\n\n".join(tool_descriptions) @@ -108,9 +112,21 @@ async def generate_output( ) ).model_dump() response_content = response_content["choices"][0]["message"] + # Re‑inject tool_calls as text blocks so legacy parser keeps working + tool_blocks = "" + if response_content["tool_calls"]: + import json, uuid + + tool_blocks = "\n".join( + f"{json.dumps(tc)}" + for tc in response_content["tool_calls"] + ) + + content_text = (response_content["content"] or "") + tool_blocks + assistant_message = AssistantMessage( role=response_content["role"], - content=[Text(text=response_content["content"])] if response_content["content"] else [], + content=[Text(text=content_text)] if content_text else [], tool_calls=response_content["tool_calls"], ) if append_assistant_message: diff --git a/src/proxy_lite/client.py b/src/proxy_lite/client.py index f6a5903..fe3c085 100644 --- a/src/proxy_lite/client.py +++ b/src/proxy_lite/client.py @@ -51,6 +51,7 @@ async def create_completion( @classmethod def create(cls, config: BaseClientConfig) -> "BaseClient": supported_clients = { + "openai": OpenAIClient, "openai-azure": OpenAIClient, "convergence": ConvergenceClient, } @@ -72,7 +73,7 @@ def http_client(self) -> httpx.AsyncClient: class OpenAIClientConfig(BaseClientConfig): name: Literal["openai"] = "openai" - model_id: str = "gpt-4o" + model_id: str = "gpt-4.1" api_key: str = os.environ.get("OPENAI_API_KEY") @@ -103,8 +104,10 @@ async def create_completion( optional_params = { "seed": seed, "tools": self.serializer.serialize_tools(tools) if tools else None, - "tool_choice": "required" if tools else None, - "response_format": {"type": "json_object"} if response_format else {"type": "text"}, + "tool_choice": "auto" if tools else None, + "response_format": ( + {"type": "json_object"} if response_format else {"type": "text"} + ), } base_params.update({k: v for k, v in optional_params.items() if v is not None}) return await self.external_client.chat.completions.create(**base_params) @@ -125,11 +128,13 @@ class ConvergenceClient(OpenAIClient): async def _validate_model(self) -> None: try: response = await self.external_client.models.list() - assert self.config.model_id in [model.id for model in response.data], ( - f"Model {self.config.model_id} not found in {response.data}" - ) + assert self.config.model_id in [ + model.id for model in response.data + ], f"Model {self.config.model_id} not found in {response.data}" self._model_validated = True - logger.debug(f"Model {self.config.model_id} validated and connected to cluster") + logger.debug( + f"Model {self.config.model_id} validated and connected to cluster" + ) except Exception as e: logger.error(f"Error retrieving model: {e}") raise e @@ -160,7 +165,9 @@ async def create_completion( optional_params = { "seed": seed, "tools": self.serializer.serialize_tools(tools) if tools else None, - "tool_choice": "auto" if tools else None, # vLLM does not support "required" + "tool_choice": ( + "auto" if tools else None + ), # vLLM does not support "required" "response_format": response_format if response_format else {"type": "text"}, } base_params.update({k: v for k, v in optional_params.items() if v is not None}) diff --git a/src/proxy_lite/history.py b/src/proxy_lite/history.py index 13e2d98..db03596 100644 --- a/src/proxy_lite/history.py +++ b/src/proxy_lite/history.py @@ -13,6 +13,7 @@ class MessageLabel(str, Enum): USER_INPUT = "user_input" SCREENSHOT = "screenshot" AGENT_MODEL_RESPONSE = "agent_model_response" + TOOL_RESPONSE = "tool_response" MAX_MESSAGES_FOR_CONTEXT_WINDOW = { @@ -74,7 +75,9 @@ def from_media( if text is not None: text = Text(text=text) if image is not None: - base64_image = image if is_base64 else base64.b64encode(image).decode("utf-8") + base64_image = ( + image if is_base64 else base64.b64encode(image).decode("utf-8") + ) data_url = f"data:image/jpeg;base64,{base64_image}" image = Image(image_url=ImageUrl(url=data_url)) content = [text, image] if text is not None else [image] diff --git a/src/proxy_lite/serializer.py b/src/proxy_lite/serializer.py index 8394120..1b17ee4 100644 --- a/src/proxy_lite/serializer.py +++ b/src/proxy_lite/serializer.py @@ -35,5 +35,8 @@ def deserialize_messages(self, data: list[dict]) -> MessageHistory: ) def serialize_tools(self, tools: list[Tool]) -> list[dict]: - tool_schemas = [[{"type": "function", "function": schema} for schema in tool.schema] for tool in tools] + tool_schemas = [ + [{"type": "function", "function": schema} for schema in tool.schema] + for tool in tools + ] return list(itertools.chain.from_iterable(tool_schemas)) diff --git a/src/proxy_lite/solvers/simple_solver.py b/src/proxy_lite/solvers/simple_solver.py index d85dc6a..4ef2bc8 100644 --- a/src/proxy_lite/solvers/simple_solver.py +++ b/src/proxy_lite/solvers/simple_solver.py @@ -43,7 +43,8 @@ def agent(self) -> BaseAgent: @property def history(self) -> MessageHistory: return MessageHistory( - messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages, + messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + + self.agent.history.messages, ) async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None: @@ -56,6 +57,14 @@ async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> N self.logger.debug(f"Initialised with task: {task}") async def act(self, observation: Observation) -> Action: + # If the previous env step contained tool responses, convert them + if observation.state.tool_responses: + for resp in observation.state.tool_responses: + self.agent.receive_tool_message( + text=resp.content or "", + tool_id=resp.id, + label=MessageLabel.TOOL_RESPONSE, + ) self.agent.receive_user_message( image=observation.state.image, text=observation.state.text, @@ -68,7 +77,10 @@ async def act(self, observation: Observation) -> Action: self.logger.debug(f"Assistant message generated: {message}") # check tool calls for return_value - if any(tool_call.function["name"] == "return_value" for tool_call in message.tool_calls): + if any( + tool_call.function["name"] == "return_value" + for tool_call in message.tool_calls + ): self.complete = True arguments = json.loads(message.tool_calls[0].function["arguments"]) if isinstance(arguments, str): @@ -78,15 +90,23 @@ async def act(self, observation: Observation) -> Action: text_content = message.content[0].text - observation_match = re.search(r"(.*?)", text_content, re.DOTALL) - observation_content = observation_match.group(1).strip() if observation_match else "" + observation_match = re.search( + r"(.*?)", text_content, re.DOTALL + ) + observation_content = ( + observation_match.group(1).strip() if observation_match else "" + ) self.logger.info("🌐 [bold blue]Observation:[/]") await self.logger.stream_message(observation_content) # Extract text between thinking tags if present - thinking_match = re.search(r"(.*?)", text_content, re.DOTALL) - thinking_content = thinking_match.group(1).strip() if thinking_match else text_content + thinking_match = re.search( + r"(.*?)", text_content, re.DOTALL + ) + thinking_content = ( + thinking_match.group(1).strip() if thinking_match else text_content + ) self.logger.info("🧠 [bold purple]Thinking:[/]") await self.logger.stream_message(thinking_content) From af666dab1e4fcf8d4c5b0bf6cffa083e1a42d444 Mon Sep 17 00:00:00 2001 From: Federico C Date: Mon, 19 May 2025 13:21:42 +0200 Subject: [PATCH 2/2] Await tools --- src/proxy_lite/history.py | 1 - src/proxy_lite/solvers/simple_solver.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/proxy_lite/history.py b/src/proxy_lite/history.py index db03596..df632cb 100644 --- a/src/proxy_lite/history.py +++ b/src/proxy_lite/history.py @@ -13,7 +13,6 @@ class MessageLabel(str, Enum): USER_INPUT = "user_input" SCREENSHOT = "screenshot" AGENT_MODEL_RESPONSE = "agent_model_response" - TOOL_RESPONSE = "tool_response" MAX_MESSAGES_FOR_CONTEXT_WINDOW = { diff --git a/src/proxy_lite/solvers/simple_solver.py b/src/proxy_lite/solvers/simple_solver.py index 4ef2bc8..f29110f 100644 --- a/src/proxy_lite/solvers/simple_solver.py +++ b/src/proxy_lite/solvers/simple_solver.py @@ -60,10 +60,9 @@ async def act(self, observation: Observation) -> Action: # If the previous env step contained tool responses, convert them if observation.state.tool_responses: for resp in observation.state.tool_responses: - self.agent.receive_tool_message( + await self.agent.receive_tool_message( text=resp.content or "", tool_id=resp.id, - label=MessageLabel.TOOL_RESPONSE, ) self.agent.receive_user_message( image=observation.state.image,