feat(agent):support chat agent scene use dbgpts (#1075)
yhjun1026 authored Jan 16, 2024
1 parent fb2d18a commit 74eb15e
Showing 82 changed files with 1,073 additions and 629 deletions.
File renamed without changes.
18 changes: 11 additions & 7 deletions dbgpt/agent/agents/agent.py
@@ -45,9 +45,13 @@ def describe(self) -> str:
"""Get the name of the agent."""
return self._describe

@property
def is_terminal_agent(self) -> bool:
return False
async def a_notification(
self,
message: Union[Dict, str],
recipient: Agent,
reviewer: Agent,
):
"""Notification a message to recipient agent(Receive a record message from the notification and process it according to your own process. You cannot send the message through send and directly return the current final result.)"""

async def a_send(
self,
@@ -57,7 +61,7 @@ async def a_send(
request_reply: Optional[bool] = True,
is_recovery: Optional[bool] = False,
) -> None:
"""(Abstract async method) Send a message to another agent."""
"""(Abstract async method) Send a message to recipient agent."""

async def a_receive(
self,
@@ -84,9 +88,6 @@ async def a_review(
Any: the censored message
"""

def reset(self) -> None:
"""(Abstract method) Reset the agent."""

async def a_generate_reply(
self,
message: Optional[Dict],
@@ -150,6 +151,9 @@ async def a_verify_reply(
"""

def reset(self) -> None:
"""(Abstract method) Reset the agent."""


@dataclasses.dataclass
class AgentResource:
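
The change to dbgpt/agent/agents/agent.py above turns a_notification into a one-way delivery channel alongside a_send and a_receive. As a minimal illustration, a subclass could satisfy the hook as sketched below; the AuditNotifier class is hypothetical, not code from this commit, and the import assumes Agent is exposed from this module the same way agents_manage.py imports it.

from typing import Dict, Union

from dbgpt.agent.agents.agent import Agent


class AuditNotifier(Agent):
    """Hypothetical agent used only to illustrate the one-way notification hook."""

    async def a_notification(self, message: Union[Dict, str], recipient: Agent, reviewer: Agent):
        # Per the docstring above: the recipient handles the notification record with
        # its own flow and returns its current result; nothing is replied via a_send.
        print(f"notify {recipient.name}: {message}")

ConversableAgent in base_agent.py below implements the same hook by handing the record straight to recipient.process_received_message().
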
90 changes: 90 additions & 0 deletions dbgpt/agent/agents/agents_manage.py
@@ -0,0 +1,90 @@
import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional, Type

from .agent import Agent
from .expand.code_assistant_agent import CodeAssistantAgent
from .expand.dashboard_assistant_agent import DashboardAssistantAgent
from .expand.data_scientist_agent import DataScientistAgent
from .expand.plugin_assistant_agent import PluginAssistantAgent
from .expand.sql_assistant_agent import SQLAssistantAgent
from .expand.summary_assistant_agent import SummaryAssistantAgent

logger = logging.getLogger(__name__)


def get_all_subclasses(cls):
all_subclasses = []
direct_subclasses = cls.__subclasses__()
all_subclasses.extend(direct_subclasses)

for subclass in direct_subclasses:
all_subclasses.extend(get_all_subclasses(subclass))
return all_subclasses


def participant_roles(agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
agents = agents

roles = []
for agent in agents:
if agent.system_message.strip() == "":
logger.warning(
f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."
)
roles.append(f"{agent.name}: {agent.describe}")
return "\n".join(roles)


def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict:
"""
Finds and counts agent mentions in the string message_content, taking word boundaries into account.
Returns: A dictionary mapping agent names to mention counts (to be included, at least one mention must occur)
"""
mentions = dict()
for agent in agents:
regex = (
r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
) # Finds agent mentions, taking word boundaries into account
count = len(
re.findall(regex, " " + message_content + " ")
) # Pad the message to help with matching
if count > 0:
mentions[agent.name] = count
return mentions


class AgentsManage:
def __init__(self):
self._agents = defaultdict()

def register_agent(self, cls):
self._agents[cls.NAME] = cls

def get_by_name(self, name: str) -> Optional[Type[Agent]]:
if name not in self._agents:
raise ValueError(f"Agent:{name} not register!")
return self._agents[name]

def get_describe_by_name(self, name: str) -> Optional[str]:
return self._agents[name].DEFAULT_DESCRIBE

def all_agents(self):
result = {}
for name, cls in self._agents.items():
result[name] = cls.DEFAULT_DESCRIBE
return result


agent_manage = AgentsManage()

agent_manage.register_agent(CodeAssistantAgent)
agent_manage.register_agent(DashboardAssistantAgent)
agent_manage.register_agent(DataScientistAgent)
agent_manage.register_agent(SQLAssistantAgent)
agent_manage.register_agent(SummaryAssistantAgent)
agent_manage.register_agent(PluginAssistantAgent)
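
For orientation, here is a small usage sketch of the registry and the mention helper defined above. The import path is assumed from the file location, and the _Named stand-in plus the example sentence are invented for illustration; they are not part of this commit.

from dbgpt.agent.agents.agents_manage import agent_manage, mentioned_agents

# all_agents() maps every registered NAME to its DEFAULT_DESCRIBE, while
# get_by_name() returns the class itself and raises ValueError for unknown names.
for name, describe in agent_manage.all_agents().items():
    cls = agent_manage.get_by_name(name)
    print(f"{name} ({cls.__name__}): {describe}")


# mentioned_agents() only needs objects that carry a .name, so a tiny stand-in
# is enough to show the word-boundary counting.
class _Named:
    def __init__(self, name: str):
        self.name = name


mentions = mentioned_agents(
    "Ask CodeEngineer to draft the script",
    [_Named("CodeEngineer"), _Named("Reporter")],
)
print(mentions)  # -> {'CodeEngineer': 1}; 'Reporter' never appears, so it is omitted.
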
158 changes: 81 additions & 77 deletions dbgpt/agent/agents/base_agent.py
@@ -29,11 +29,9 @@ def __init__(
memory: GptsMemory = GptsMemory(),
agent_context: AgentContext = None,
system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "TERMINATE",
default_auto_reply: Optional[Union[str, Dict, None]] = "",
is_terminal_agent: bool = False,
):
super().__init__(name, memory, describe)

@@ -43,11 +41,6 @@ def __init__(
{"content": system_message, "role": ModelMessageRoleType.SYSTEM}
]
self._rely_messages = []
self._is_termination_msg = (
is_termination_msg
if is_termination_msg is not None
else (lambda x: x.get("content") == "TERMINATE")
)

self.client = AIWrapper(llm_client=agent_context.llm_provider)

@@ -58,9 +51,8 @@ def __init__(
else self.MAX_CONSECUTIVE_AUTO_REPLY
)
self.consecutive_auto_reply_counter: int = 0
self._current_retry_counter: int = 0
self._max_retry_count: int = 5
self._is_terminal_agent = is_terminal_agent
self.current_retry_counter: int = 0
self.max_retry_count: int = 5

## By default, the memory of the last 5 rounds of dialogue is retained.
self.dialogue_memory_rounds = 5
@@ -93,9 +85,18 @@ def register_reply(
},
)

@property
def is_terminal_agent(self):
return self._is_terminal_agent
def is_termination_msg(self, message: Union[Dict, str, bool]):
if isinstance(message, dict):
if "is_termination" in message:
return message.get("is_termination", False)
else:
return message["content"].find("TERMINATE") >= 0
elif isinstance(message, bool):
return message
elif isinstance(message, str):
return message.find("TERMINATE") >= 0
else:
return False

@property
def system_message(self):
Expand Down Expand Up @@ -279,6 +280,70 @@ async def a_send(
is_recovery=is_recovery,
)

async def a_receive(
self,
message: Optional[Dict],
sender: Agent,
reviewer: "Agent",
request_reply: Optional[bool] = True,
silent: Optional[bool] = False,
is_recovery: Optional[bool] = False,
):
if not is_recovery:
self.consecutive_auto_reply_counter = (
sender.consecutive_auto_reply_counter + 1
)
self.process_received_message(message, sender, silent)
else:
logger.info("Process received retrying")
self.consecutive_auto_reply_counter = sender.consecutive_auto_reply_counter
if request_reply is False or request_reply is None:
logger.info("Messages that do not require a reply")
return
if self.is_termination_msg(message):
logger.info(f"TERMINATE!")
return

verify_paas, reply = await self.a_generate_reply(
message=message, sender=sender, reviewer=reviewer, silent=silent
)
if verify_paas:
await self.a_send(
message=reply, recipient=sender, reviewer=reviewer, silent=silent
)
else:
# Exit after the maximum number of rounds of self-optimization
if self.current_retry_counter >= self.max_retry_count:
# If the maximum number of retries is exceeded, the abnormal answer will be returned directly.
logger.warning(
f"More than {self.current_retry_counter} times and still no valid answer is output."
)
reply[
"content"
] = f"After trying {self.current_retry_counter} times, I still can't generate a valid answer. The current problem is:{reply['content']}!"
reply["is_termination"] = True
await self.a_send(
message=reply, recipient=sender, reviewer=reviewer, silent=silent
)
# raise ValueError(
# f"After {self.current_retry_counter} rounds of re-optimization, we still cannot get an effective answer."
# )
else:
self.current_retry_counter += 1
logger.info(
"The generated answer failed to verify, so send it to yourself for optimization."
)
await sender.a_send(
message=reply, recipient=self, reviewer=reviewer, silent=silent
)

async def a_notification(
self,
message: Union[Dict, str],
recipient: Agent,
):
recipient.process_received_message(message, self)

def _print_received_message(self, message: Union[Dict, str], sender: Agent):
# print the message received
print(
Expand Down Expand Up @@ -329,7 +394,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
print(colored(action_print, "blue"), flush=True)
print("\n", "-" * 80, flush=True, sep="")

def _process_received_message(self, message, sender, silent):
def process_received_message(self, message, sender, silent):
message = self._message_to_dict(message)
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
valid = self.append_message(message, None, sender)
@@ -440,7 +505,6 @@ async def a_generate_reply(

## 1. LLM Reasoning
await self.a_system_fill_param()
await asyncio.sleep(5) ##TODO Rate limit reached for gpt-3.5-turbo
current_messages = self.process_now_message(message, sender, rely_messages)
ai_reply, model = await self.a_reasoning_reply(messages=current_messages)
new_message["content"] = ai_reply
@@ -462,61 +526,6 @@
## 4.verify reply
return await self.a_verify_reply(new_message, sender, reviewer)

async def a_receive(
self,
message: Optional[Dict],
sender: Agent,
reviewer: "Agent",
request_reply: Optional[bool] = True,
silent: Optional[bool] = False,
is_recovery: Optional[bool] = False,
):
if not is_recovery:
self.consecutive_auto_reply_counter = (
sender.consecutive_auto_reply_counter + 1
)
self._process_received_message(message, sender, silent)

else:
logger.info("Process received retrying")
self.consecutive_auto_reply_counter = sender.consecutive_auto_reply_counter
if request_reply is False or request_reply is None:
logger.info("Messages that do not require a reply")
return
if self._is_termination_msg(message) or sender.is_terminal_agent:
logger.info(f"TERMINATE!")
return

verify_paas, reply = await self.a_generate_reply(
message=message, sender=sender, reviewer=reviewer, silent=silent
)

if verify_paas:
await self.a_send(
message=reply, recipient=sender, reviewer=reviewer, silent=silent
)
else:
# Exit after the maximum number of rounds of self-optimization
if self._current_retry_counter >= self._max_retry_count:
# If the maximum number of retries is exceeded, the abnormal answer will be returned directly.
logger.warning(
f"More than {self._current_retry_counter} times and still no valid answer is output."
)
reply[
"content"
] = f"After n optimizations, the following problems still exist:{reply['content']}"
await self.a_send(
message=reply, recipient=sender, reviewer=reviewer, silent=silent
)
else:
self._current_retry_counter += 1
logger.info(
"The generated answer failed to verify, so send it to yourself for optimization."
)
await sender.a_send(
message=reply, recipient=self, reviewer=reviewer, silent=silent
)

async def a_verify(self, message: Optional[Dict]):
return True, message

@@ -573,7 +582,7 @@ async def a_verify_reply(
return False, retry_message
else:
## The verification passes, the message is released, and the number of retries returns to 0.
self._current_retry_counter = 0
self.current_retry_counter = 0
return True, message

async def a_retry_chat(
@@ -622,7 +631,6 @@ async def a_initiate_chat(
await self.a_send(
{
"content": self.generate_init_message(**context),
"current_gogal": self.generate_init_message(**context),
},
recipient,
reviewer,
@@ -652,10 +660,6 @@ def clear_history(self, agent: Optional[Agent] = None):
agent: the agent with whom the chat history to clear. If None, clear the chat history with all agents.
"""
pass
# if agent is None:
# self._oai_messages.clear()
# else:
# self._oai_messages[agent].clear()

def _get_model_priority(self):
llm_models_priority = self.agent_context.model_priority
@@ -727,7 +731,7 @@ async def a_reasoning_reply(
retry_count += 1
last_model = llm_model
last_err = str(e)
await asyncio.sleep(10) ## TODO,Rate limit reached for gpt-3.5-turbo
await asyncio.sleep(15) ## TODO,Rate limit reached for gpt-3.5-turbo

if last_err:
raise ValueError(last_err)
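
Earlier in this file, the rewritten a_receive replaces the old is_termination_msg callback and the is_terminal_agent flag: each reply is verified, a failed reply is bounced back for another attempt up to max_retry_count (5) times, and the exchange stops once is_termination_msg() fires. Because that method never reads self, its rules can be illustrated with an unbound call; the import path is assumed from the file layout, and a real caller would use a constructed agent instance.

from dbgpt.agent.agents.base_agent import ConversableAgent


def check(message):
    # Unbound call purely for illustration; `self` is unused inside the method.
    return ConversableAgent.is_termination_msg(None, message)


assert check({"is_termination": True}) is True          # explicit flag wins over content
assert check({"content": "all done. TERMINATE"})        # marker found anywhere in the content
assert check("TERMINATE") and not check("keep going")   # plain strings are scanned too
assert check(True) and not check(3.14)                  # booleans pass through; other types are False

When retries are exhausted, the reply itself is flagged with "is_termination": True, so the counterpart stops on its next a_receive() call.
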
2 changes: 1 addition & 1 deletion dbgpt/agent/agents/base_team.py
@@ -117,7 +117,7 @@ async def a_run_chat(
pass


class MangerAgent(ConversableAgent, Team):
class ManagerAgent(ConversableAgent, Team):
def __init__(
self,
name: str,