Skip to content

Commit

Permalink
feat: Support intent detection (eosphoros-ai#1588)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored May 30, 2024
1 parent 73d175a commit a88af6f
Show file tree
Hide file tree
Showing 22 changed files with 881 additions and 54 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fmt: setup ## Format Python code
$(VENV_BIN)/blackdoc examples
# TODO: Use flake8 to enforce Python style guide.
# https://flake8.pycqa.org/en/latest/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/ dbgpt/experimental/
# TODO: More package checks with flake8.

.PHONY: fmt-check
Expand All @@ -56,7 +56,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
$(VENV_BIN)/isort --check-only --extend-skip="examples/notebook" examples
$(VENV_BIN)/black --check --extend-exclude="examples/notebook" .
$(VENV_BIN)/blackdoc --check dbgpt examples
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/ dbgpt/experimental/

.PHONY: pre-commit
pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing
Expand All @@ -72,7 +72,7 @@ test-doc: $(VENV)/.testenv ## Run doctests
.PHONY: mypy
mypy: $(VENV)/.testenv ## Run mypy checks
# https://github.com/python/mypy
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/ dbgpt/client/ dbgpt/agent/ dbgpt/vis/ dbgpt/experimental/
# rag depends on core and storage, so we not need to check it again.
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
Expand Down
3 changes: 3 additions & 0 deletions assets/schema/upgrade/v_0_5_7/upgrade_to_v0.5.7.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
USE dbgpt;
ALTER TABLE dbgpt_serve_flow
ADD COLUMN `define_type` varchar(32) null comment 'Flow define type(json or python)' after `version`;
395 changes: 395 additions & 0 deletions assets/schema/upgrade/v_0_5_7/v0.5.6.sql

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dbgpt/_private/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
PositiveInt,
PrivateAttr,
ValidationError,
WithJsonSchema,
field_validator,
model_validator,
root_validator,
Expand Down
7 changes: 0 additions & 7 deletions dbgpt/app/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,6 @@ def stream_plugin_call(self, text):
def stream_call_reinforce_fn(self, text):
return text

async def check_iterator_end(iterator):
try:
await asyncio.anext(iterator)
return False # 迭代器还有下一个元素
except StopAsyncIteration:
return True # 迭代器已经执行结束

def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
del metadata["prompt"]
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .operators.common_operator import (
BranchFunc,
BranchOperator,
BranchTaskType,
InputOperator,
JoinOperator,
MapOperator,
Expand Down Expand Up @@ -80,6 +81,7 @@
"BranchOperator",
"InputOperator",
"BranchFunc",
"BranchTaskType",
"WorkflowRunner",
"TaskState",
"is_empty_data",
Expand Down
8 changes: 6 additions & 2 deletions dbgpt/core/awel/dag/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
DAGLoader will load DAGs from dag_dirs or other sources.
Now only support load DAGs from local files.
"""

import hashlib
import logging
import os
Expand Down Expand Up @@ -98,15 +99,18 @@ def parse(mod_name, filepath):
return parse(mod_name, filepath)


def _process_modules(mods) -> List[DAG]:
def _process_modules(mods, show_log: bool = True) -> List[DAG]:
top_level_dags = (
(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
)
found_dags = []
for dag, mod in top_level_dags:
try:
# TODO validate dag params
logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
if show_log:
logger.info(
f"Found dag {dag} from mod {mod} and model file {mod.__file__}"
)
found_dags.append(dag)
except Exception:
msg = traceback.format_exc()
Expand Down
43 changes: 39 additions & 4 deletions dbgpt/core/awel/flow/flow_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

from typing_extensions import Annotated

from dbgpt._private.pydantic import (
BaseModel,
ConfigDict,
Field,
WithJsonSchema,
field_validator,
model_to_dict,
model_validator,
Expand Down Expand Up @@ -255,9 +259,27 @@ def value_of(cls, value: Optional[str]) -> "FlowCategory":
raise ValueError(f"Invalid flow category value: {value}")


_DAGModel = Annotated[
DAG,
WithJsonSchema(
{
"type": "object",
"properties": {
"task_name": {"type": "string", "description": "Dummy task name"}
},
"description": "DAG model, not used in the serialization.",
}
),
]


class FlowPanel(BaseModel):
"""Flow panel."""

model_config = ConfigDict(
arbitrary_types_allowed=True, json_encoders={DAG: lambda v: None}
)

uid: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Flow panel uid",
Expand All @@ -277,7 +299,8 @@ class FlowPanel(BaseModel):
description="Flow category",
examples=[FlowCategory.COMMON, FlowCategory.CHAT_AGENT],
)
flow_data: FlowData = Field(..., description="Flow data")
flow_data: Optional[FlowData] = Field(None, description="Flow data")
flow_dag: Optional[_DAGModel] = Field(None, description="Flow DAG", exclude=True)
description: Optional[str] = Field(
None,
description="Flow panel description",
Expand Down Expand Up @@ -305,6 +328,11 @@ class FlowPanel(BaseModel):
description="Version of the flow panel",
examples=["0.1.0", "0.2.0"],
)
define_type: Optional[str] = Field(
"json",
description="Define type of the flow panel",
examples=["json", "python"],
)
editable: bool = Field(
True,
description="Whether the flow panel is editable",
Expand Down Expand Up @@ -344,7 +372,7 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:

def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self)
return model_to_dict(self, exclude={"flow_dag"})


class FlowFactory:
Expand All @@ -356,7 +384,9 @@ def __init__(self, dag_prefix: str = "flow_dag"):

def build(self, flow_panel: FlowPanel) -> DAG:
"""Build the flow."""
flow_data = flow_panel.flow_data
if not flow_panel.flow_data:
raise ValueError("Flow data is required.")
flow_data = cast(FlowData, flow_panel.flow_data)
key_to_operator_nodes: Dict[str, FlowNodeData] = {}
key_to_resource_nodes: Dict[str, FlowNodeData] = {}
key_to_resource: Dict[str, ResourceMetadata] = {}
Expand Down Expand Up @@ -610,7 +640,10 @@ def pre_load_requirements(self, flow_panel: FlowPanel):
"""
from dbgpt.util.module_utils import import_from_string

flow_data = flow_panel.flow_data
if not flow_panel.flow_data:
return

flow_data = cast(FlowData, flow_panel.flow_data)
for node in flow_data.nodes:
if node.data.is_operator:
node_data = cast(ViewMetadata, node.data)
Expand Down Expand Up @@ -709,6 +742,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
Args:
flow_panel (FlowPanel): The flow panel to fill.
"""
if not flow_panel.flow_data:
return
for node in flow_panel.flow_data.nodes:
try:
parameters_map = {}
Expand Down
12 changes: 11 additions & 1 deletion dbgpt/core/awel/operators/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base classes for operators that can be executed within a workflow."""

import asyncio
import functools
from abc import ABC, ABCMeta, abstractmethod
Expand Down Expand Up @@ -265,7 +266,16 @@ async def call_stream(
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
)
return out_ctx.current_task_context.task_output.output_stream

task_output = out_ctx.current_task_context.task_output
if task_output.is_stream:
return out_ctx.current_task_context.task_output.output_stream
else:

async def _gen():
yield task_output.output

return _gen()

def _blocking_call_stream(
self,
Expand Down
32 changes: 28 additions & 4 deletions dbgpt/core/awel/operators/common_operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common operators of AWEL."""

import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union
Expand Down Expand Up @@ -171,6 +172,8 @@ async def map(self, input_value: IN) -> OUT:


BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
# Function that return the task name
BranchTaskType = Union[str, Callable[[IN], str], Callable[[IN], Awaitable[str]]]


class BranchOperator(BaseOperator, Generic[IN, OUT]):
Expand All @@ -187,7 +190,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):

def __init__(
self,
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
branches: Optional[Dict[BranchFunc[IN], BranchTaskType]] = None,
**kwargs,
):
"""Create a BranchDAGNode with a branching function.
Expand All @@ -208,6 +211,10 @@ def __init__(
if not value.node_name:
raise ValueError("branch node name must be set")
branches[branch_function] = value.node_name
elif callable(value):
raise ValueError(
"BranchTaskType must be str or BaseOperator on init"
)
self._branches = branches

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
Expand All @@ -234,14 +241,31 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
branches = await self.branches()

branch_func_tasks = []
branch_nodes: List[Union[BaseOperator, str]] = []
branch_name_tasks = []
# branch_nodes: List[Union[BaseOperator, str]] = []
for func, node_name in branches.items():
branch_nodes.append(node_name)
# branch_nodes.append(node_name)
branch_func_tasks.append(
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
)
if callable(node_name):

async def map_node_name(func) -> str:
input_context = await curr_task_ctx.task_input.map(func)
task_name = input_context.parent_outputs[0].task_output.output
return task_name

branch_name_tasks.append(map_node_name(node_name))

else:

async def _tmp_map_node_name(task_name: str) -> str:
return task_name

branch_name_tasks.append(_tmp_map_node_name(node_name))

branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
branch_nodes: List[str] = await asyncio.gather(*branch_name_tasks)
parent_output = task_input.parent_outputs[0].task_output
curr_task_ctx.set_task_output(parent_output)
skip_node_names = []
Expand All @@ -258,7 +282,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output

async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
async def branches(self) -> Dict[BranchFunc[IN], BranchTaskType]:
"""Return branch logic based on input data."""
raise NotImplementedError

Expand Down
20 changes: 15 additions & 5 deletions dbgpt/core/interface/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,24 @@ def get_printable_message(messages: List["ModelMessage"]) -> str:
return str_msg

@staticmethod
def messages_to_string(messages: List["ModelMessage"]) -> str:
def messages_to_string(
messages: List["ModelMessage"],
human_prefix: str = "Human",
ai_prefix: str = "AI",
system_prefix: str = "System",
) -> str:
"""Convert messages to str.
Args:
messages (List[ModelMessage]): The messages
human_prefix (str): The human prefix
ai_prefix (str): The ai prefix
system_prefix (str): The system prefix
Returns:
str: The str messages
"""
return _messages_to_str(messages)
return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)


_SingleRoundMessage = List[BaseMessage]
Expand Down Expand Up @@ -1211,9 +1219,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
content=ai_message.content,
index=ai_message.index,
round_index=ai_message.round_index,
additional_kwargs=ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {},
additional_kwargs=(
ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {}
),
)
current_round.append(view_message)
return sum(messages_by_round, [])
Loading

0 comments on commit a88af6f

Please sign in to comment.