diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index f9cdb619f..5b723e120 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -25,7 +25,7 @@ from collections import deque from logging import getLogger from pathlib import Path -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union import jinja2 import yaml @@ -36,6 +36,10 @@ from rich.rule import Rule from rich.text import Text + +if TYPE_CHECKING: + import PIL.Image + from .agent_types import AgentAudio, AgentImage, AgentType, handle_agent_output_types from .default_tools import TOOL_MAPPING, FinalAnswerTool from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor, PythonExecutor, fix_final_answer_code @@ -264,7 +268,7 @@ def run( task: str, stream: bool = False, reset: bool = True, - images: Optional[List[str]] = None, + images: Optional[List["PIL.Image.Image"]] = None, additional_args: Optional[Dict] = None, max_steps: Optional[int] = None, ): @@ -275,7 +279,7 @@ def run( task (`str`): Task to perform. stream (`bool`): Whether to run in a streaming way. reset (`bool`): Whether to reset the conversation or keep it going from previous run. - images (`list[str]`, *optional*): Paths to image(s). + images (`list[PIL.Image.Image]`, *optional*): Image(s) objects. additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names! max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. if not provided, will use the agent's default value. @@ -319,7 +323,7 @@ def run( return deque(self._run(task=self.task, max_steps=max_steps, images=images), maxlen=1)[0] def _run( - self, task: str, max_steps: int, images: List[str] | None = None + self, task: str, max_steps: int, images: List["PIL.Image.Image"] | None = None ) -> Generator[ActionStep | AgentType, None, None]: final_answer = None self.step_number = 1 @@ -344,7 +348,7 @@ def _run( yield memory_step yield handle_agent_output_types(final_answer) - def _create_memory_step(self, step_start_time: float, images: List[str] | None) -> ActionStep: + def _create_memory_step(self, step_start_time: float, images: List["PIL.Image.Image"] | None) -> ActionStep: return ActionStep(step_number=self.step_number, start_time=step_start_time, observations_images=images) def _execute_step(self, task: str, memory_step: ActionStep) -> Union[None, Any]: @@ -373,7 +377,7 @@ def _finalize_step(self, memory_step: ActionStep, step_start_time: float): memory_step, agent=self ) - def _handle_max_steps_reached(self, task: str, images: List[str], step_start_time: float) -> Any: + def _handle_max_steps_reached(self, task: str, images: List["PIL.Image.Image"], step_start_time: float) -> Any: final_answer = self.provide_final_answer(task, images) final_memory_step = ActionStep( step_number=self.step_number, error=AgentMaxStepsError("Reached max steps.", self.logger) @@ -557,13 +561,13 @@ def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str] ) return rationale.strip(), action.strip() - def provide_final_answer(self, task: str, images: Optional[list[str]]) -> str: + def provide_final_answer(self, task: str, images: Optional[list["PIL.Image.Image"]]) -> str: """ Provide the final answer to the task, based on the logs of the agent's interactions. Args: task (`str`): Task to perform. - images (`list[str]`, *optional*): Paths to image(s). + images (`list[PIL.Image.Image]`, *optional*): Image(s) objects. Returns: `str`: Final answer to the task. diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 7eabbea4f..d6489ae0a 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: + import PIL.Image + from smolagents.models import ChatMessage from smolagents.monitoring import AgentLogger @@ -58,7 +60,7 @@ class ActionStep(MemoryStep): model_output_message: ChatMessage = None model_output: str | None = None observations: str | None = None - observations_images: List[str] | None = None + observations_images: List["PIL.Image.Image"] | None = None action_output: Any = None def dict(self): @@ -169,7 +171,7 @@ def to_messages(self, summary_mode: bool, **kwargs) -> List[Message]: @dataclass class TaskStep(MemoryStep): task: str - task_images: List[str] | None = None + task_images: List["PIL.Image.Image"] | None = None def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]: content = [{"type": "text", "text": f"New task:\n{self.task}"}]