Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,9 @@ def _run_stream(
if not returned_final_answer and self.step_number == max_steps + 1:
final_answer = self._handle_max_steps_reached(task)
yield action_step
yield FinalAnswerStep(handle_agent_output_types(final_answer))
final_answer_step = FinalAnswerStep(handle_agent_output_types(final_answer))
self._finalize_step(final_answer_step)
yield final_answer_step

def _validate_final_answer(self, final_answer: Any):
for check_function in self.final_answer_checks:
Expand All @@ -614,8 +616,9 @@ def _validate_final_answer(self, final_answer: Any):
except Exception as e:
raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)

def _finalize_step(self, memory_step: ActionStep | PlanningStep):
memory_step.timing.end_time = time.time()
def _finalize_step(self, memory_step: ActionStep | PlanningStep | FinalAnswerStep):
if not isinstance(memory_step, FinalAnswerStep):
memory_step.timing.end_time = time.time()
self.step_callbacks.callback(memory_step, agent=self)

def _handle_max_steps_reached(self, task: str) -> Any:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,12 +1150,14 @@ def test_finalize_step_callbacks_by_type(self):
action_step_callback_2 = MagicMock()
planning_step_callback = MagicMock()
step_callback = MagicMock()
final_answer_step_callback = MagicMock()

# Register callbacks for different step types
step_callbacks = {
ActionStep: [action_step_callback, action_step_callback_2],
PlanningStep: planning_step_callback,
MemoryStep: step_callback,
FinalAnswerStep: final_answer_step_callback,
}
agent = DummyMultiStepAgent(tools=[], model=MagicMock(), step_callbacks=step_callbacks)

Expand All @@ -1167,6 +1169,7 @@ def test_finalize_step_callbacks_by_type(self):
model_output_message=ChatMessage(role="assistant", content="Test plan"),
plan="Test planning step",
)
final_answer_step=FinalAnswerStep(output="Sample output")

# Test with ActionStep
agent._finalize_step(action_step)
Expand All @@ -1176,12 +1179,14 @@ def test_finalize_step_callbacks_by_type(self):
action_step_callback_2.assert_called_once_with(action_step, agent=agent)
step_callback.assert_called_once_with(action_step, agent=agent)
planning_step_callback.assert_not_called()
final_answer_step_callback.assert_not_called()

# Reset mocks
action_step_callback.reset_mock()
action_step_callback_2.reset_mock()
planning_step_callback.reset_mock()
step_callback.reset_mock()
final_answer_step_callback.reset_mock()

# Test with PlanningStep
agent._finalize_step(planning_step)
Expand All @@ -1191,6 +1196,24 @@ def test_finalize_step_callbacks_by_type(self):
step_callback.assert_called_once_with(planning_step, agent=agent)
action_step_callback.assert_not_called()
action_step_callback_2.assert_not_called()
final_answer_step_callback.assert_not_called()

# Reset mocks
action_step_callback.reset_mock()
action_step_callback_2.reset_mock()
planning_step_callback.reset_mock()
step_callback.reset_mock()
final_answer_step_callback.reset_mock()

# Test with PlanningStep
agent._finalize_step(final_answer_step)

# Verify correct callbacks were called
planning_step_callback.assert_not_called()
step_callback.assert_called_once_with(final_answer_step, agent=agent)
action_step_callback.assert_not_called()
action_step_callback_2.assert_not_called()
final_answer_step_callback.assert_called_once_with(final_answer_step, agent=agent)

def test_logs_display_thoughts_even_if_error(self):
class FakeJsonModelNoCall(Model):
Expand Down