import asyncio
import json
from typing import List, Tuple

from autogen_core import (
    FunctionCall,
    MessageContext,
    RoutedAgent,
    TopicId,
    message_handler,
)
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    SystemMessage,
)
from autogen_core.tools import Tool

from models import UserTask, AgentResponse


class AIAgent(RoutedAgent):
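    """An LLM-backed agent that streams its replies, executes its own tools,
    and can hand a task off to other agents via delegate tools.

    ``tools`` are executed locally; ``delegate_tools`` return the topic type of
    the agent to transfer the conversation to. Streamed text chunks and status
    updates are pushed onto ``response_queue`` for the caller to consume.
    """
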
    def __init__(
        self,
        description: str,
        system_message: SystemMessage,
        model_client: ChatCompletionClient,
        tools: List[Tool],
        delegate_tools: List[Tool],
        agent_topic_type: str,
        user_topic_type: str,
        # Streamed chunks and status updates go here; ``object`` leaves room
        # for sentinel values used elsewhere in the application.
        response_queue: asyncio.Queue[dict[str, str] | object],
    ) -> None:
        super().__init__(description)
        self._system_message = system_message
        self._model_client = model_client
        self._tools = {tool.name: tool for tool in tools}
        self._tool_schema = [tool.schema for tool in tools]
        self._delegate_tools = {tool.name: tool for tool in delegate_tools}
        self._delegate_tool_schema = [tool.schema for tool in delegate_tools]
        self._agent_topic_type = agent_topic_type
        self._user_topic_type = user_topic_type
        self._response_queue = response_queue

    @message_handler
    async def handle_task(self, message: UserTask, ctx: MessageContext) -> None:
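        """Handle a user task: stream the model's reply, execute any tool
        calls in a loop, and either delegate to another agent or publish the
        final response to the user topic."""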
        # Start streaming the LLM response: string chunks are forwarded to the
        # response queue, and the final yielded item is the complete result.
        llm_stream = self._model_client.create_stream(
            messages=[self._system_message] + message.context,
            tools=self._tool_schema + self._delegate_tool_schema,
            cancellation_token=ctx.cancellation_token,
        )
        final_response = None
        async for chunk in llm_stream:
            if isinstance(chunk, str):
                await self._response_queue.put({"type": "string", "message": chunk})
            else:
                final_response = chunk
        assert final_response is not None, "No response from model"
        print(f"{'-' * 80}\n{self.id.type}:\n{final_response.content}", flush=True)
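        # Agentic loop: as long as the model returns only tool calls, execute
        # them and feed the results back for another model call. Delegate
        # tools resolve to another agent's topic type; a handoff ends this
        # agent's turn immediately.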
        while isinstance(final_response.content, list) and all(
            isinstance(m, FunctionCall) for m in final_response.content
        ):
            tool_call_results: List[FunctionExecutionResult] = []
            delegate_targets: List[Tuple[str, UserTask]] = []
            # Process each function call.
            for call in final_response.content:
                arguments = json.loads(call.arguments)
                await self._response_queue.put({"type": "function", "message": f"Executing {call.name}"})
                if call.name in self._tools:
                    # Execute the tool directly and record the result.
                    result = await self._tools[call.name].run_json(arguments, ctx.cancellation_token)
                    result_as_str = self._tools[call.name].return_value_as_string(result)
                    tool_call_results.append(
                        FunctionExecutionResult(call_id=call.id, content=result_as_str, is_error=False, name=call.name)
                    )
                elif call.name in self._delegate_tools:
                    # Execute the delegate tool to get the target agent's topic type.
                    result = await self._delegate_tools[call.name].run_json(arguments, ctx.cancellation_token)
                    topic_type = self._delegate_tools[call.name].return_value_as_string(result)
                    # Build the context for the delegate agent, including the
                    # function call and the handoff result.
                    delegate_messages = list(message.context) + [
                        AssistantMessage(content=[call], source=self.id.type),
                        FunctionExecutionResultMessage(
                            content=[
                                FunctionExecutionResult(
                                    call_id=call.id,
                                    content=f"Transferred to {topic_type}. Adopt persona immediately.",
                                    is_error=False,
                                    name=call.name,
                                )
                            ]
                        ),
                    ]
                    delegate_targets.append((topic_type, UserTask(context=delegate_messages)))
                else:
                    raise ValueError(f"Unknown tool: {call.name}")
            if delegate_targets:
                # Delegate the task to other agents by publishing messages to the
                # corresponding topics.
                for topic_type, task in delegate_targets:
                    print(f"{'-' * 80}\n{self.id.type}:\nDelegating to {topic_type}", flush=True)
                    await self._response_queue.put({"type": "function", "message": f"You are now talking to {topic_type}"})
                    await self.publish_message(task, topic_id=TopicId(topic_type, source=self.id.key))
            if tool_call_results:
                print(f"{'-' * 80}\n{self.id.type}:\n{tool_call_results}", flush=True)
                # Make another streaming LLM call with the tool results appended.
                message.context.extend([
                    AssistantMessage(content=final_response.content, source=self.id.type),
                    FunctionExecutionResultMessage(content=tool_call_results),
                ])
                llm_stream = self._model_client.create_stream(
                    messages=[self._system_message] + message.context,
                    tools=self._tool_schema + self._delegate_tool_schema,
                    cancellation_token=ctx.cancellation_token,
                )
                final_response = None
                async for chunk in llm_stream:
                    if isinstance(chunk, str):
                        await self._response_queue.put({"type": "string", "message": chunk})
                    else:
                        final_response = chunk
                assert final_response is not None, "No response from model"
                print(f"{'-' * 80}\n{self.id.type}:\n{final_response.content}", flush=True)
            else:
                # The task has been delegated, so this agent's turn is done.
                return
        # The task has been completed; publish the final result to the user topic.
        assert isinstance(final_response.content, str)
        message.context.append(AssistantMessage(content=final_response.content, source=self.id.type))
        await self.publish_message(
            AgentResponse(context=message.context, reply_to_topic_type=self._agent_topic_type),
            topic_id=TopicId(self._user_topic_type, source=self.id.key),
        )
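

# ---------------------------------------------------------------------------
# Illustrative wiring sketch (an assumption, not part of this module): one way
# an AIAgent might be registered on an autogen_core runtime, subscribed to its
# topic, and sent a task. The "sales"/"user" topic names, the check_stock tool,
# and the OpenAI client are placeholders; any ChatCompletionClient and topic
# layout should work.
#
# from autogen_core import SingleThreadedAgentRuntime, TypeSubscription
# from autogen_core.models import UserMessage
# from autogen_core.tools import FunctionTool
# from autogen_ext.models.openai import OpenAIChatCompletionClient
#
# async def check_stock(item: str) -> str:
#     return f"{item}: 42 in stock"  # placeholder tool for the sketch
#
# async def main() -> None:
#     runtime = SingleThreadedAgentRuntime()
#     queue: asyncio.Queue[dict[str, str] | object] = asyncio.Queue()
#     sales = await AIAgent.register(
#         runtime,
#         type="sales",  # also used as this agent's topic type below
#         factory=lambda: AIAgent(
#             description="A sales agent.",
#             system_message=SystemMessage(content="You are a sales agent."),
#             model_client=OpenAIChatCompletionClient(model="gpt-4o-mini"),
#             tools=[FunctionTool(check_stock, description="Check stock for an item.")],
#             delegate_tools=[],
#             agent_topic_type="sales",
#             user_topic_type="user",
#             response_queue=queue,
#         ),
#     )
#     await runtime.add_subscription(TypeSubscription(topic_type="sales", agent_type=sales.type))
#     runtime.start()
#     await runtime.publish_message(
#         UserTask(context=[UserMessage(content="How many widgets are in stock?", source="User")]),
#         topic_id=TopicId("sales", source="session-1"),
#     )
#     await runtime.stop_when_idle()
#
# if __name__ == "__main__":
#     asyncio.run(main())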