From 83b1552fc3d483eba4e91e8ebb8793ae25675665 Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Fri, 2 Feb 2024 15:41:58 +0800 Subject: [PATCH] Bump version to v0.2.0 (#60) * Refactor RemoteTool to support OpenAPI style spec * Refactor base tool, remote tool and tool server to support OpenAPI spec. * Update all tools to new style annotation * Fix wrappers and server * Update requirements * Remove mmengine and opencv requirements * Add missing file * Update requirements and langchain * Fix remote tool to support schema.AllOf * Update code style to increase line width * Support general file type inputs and outputs * Update `RemoteTool` and add docstring. * Add Gradio interface * Fix bugs * Support Lagent 0.2.0 and update gradio web ui * Fix internlm2 agent in webui * Make function tool can be used in server. * Move gradio webui to `webui` folder. * Add Chinese WebUI Readme * Update docs * Update search tool * Bump version to v0.2.0 --- .dev_scripts/generate_readme.py | 3 +- .gitignore | 7 + .pre-commit-config.yaml | 2 + README.md | 13 +- README_zh-CN.md | 11 +- agentlego/__init__.py | 1 + agentlego/apis/tool.py | 40 +- agentlego/parsers/default_parser.py | 95 ++- agentlego/schema.py | 51 +- agentlego/search.py | 24 +- agentlego/server/__init__.py | 0 agentlego/server/__main__.py | 4 + agentlego/server/server.py | 241 ++++++ agentlego/testing/setup_tool.py | 2 +- agentlego/tools/__init__.py | 21 +- agentlego/tools/base.py | 64 +- .../tools/calculator/python_calculator.py | 36 +- agentlego/tools/func.py | 108 +++ agentlego/tools/image_canny/README.md | 3 +- agentlego/tools/image_canny/canny_to_image.py | 39 +- agentlego/tools/image_canny/image_to_canny.py | 29 +- agentlego/tools/image_depth/depth_to_image.py | 39 +- agentlego/tools/image_depth/image_to_depth.py | 25 +- agentlego/tools/image_editing/__init__.py | 4 +- agentlego/tools/image_editing/expansion.py | 71 +- agentlego/tools/image_editing/remove.py | 53 +- agentlego/tools/image_editing/replace.py | 51 +- agentlego/tools/image_editing/stylization.py | 36 +- agentlego/tools/image_pose/README.md | 3 +- agentlego/tools/image_pose/facelandmark.py | 35 +- agentlego/tools/image_pose/image_to_pose.py | 32 +- agentlego/tools/image_pose/pose_to_image.py | 42 +- .../tools/image_scribble/image_to_scribble.py | 24 +- .../tools/image_scribble/scribble_to_image.py | 39 +- agentlego/tools/image_text/README.md | 11 +- agentlego/tools/image_text/__init__.py | 4 +- agentlego/tools/image_text/image_to_text.py | 28 +- agentlego/tools/image_text/text_to_image.py | 38 +- agentlego/tools/imagebind/__init__.py | 8 +- .../tools/imagebind/anything_to_image.py | 139 ++-- agentlego/tools/imagebind/data.py | 43 +- agentlego/tools/imagebind/models/helpers.py | 3 +- .../tools/imagebind/models/imagebind_model.py | 48 +- .../models/multimodal_preprocessors.py | 64 +- .../tools/imagebind/models/transformer.py | 10 +- .../object_detection/object_detection.py | 62 +- .../tools/object_detection/text_to_bbox.py | 75 +- agentlego/tools/ocr/ocr.py | 87 ++- agentlego/tools/remote.py | 299 ++++++-- agentlego/tools/search/google.py | 37 +- .../tools/segmentation/segment_anything.py | 127 ++-- .../segmentation/semantic_segmentation.py | 35 +- agentlego/tools/speech_text/speech_to_text.py | 47 +- agentlego/tools/speech_text/text_to_speech.py | 93 +-- agentlego/tools/translation/translation.py | 38 +- agentlego/tools/utils/diffusers.py | 13 +- agentlego/tools/utils/parameters.py | 111 +++ agentlego/tools/vqa/README.md | 6 +- agentlego/tools/vqa/__init__.py | 4 +- .../tools/vqa/visual_question_answering.py | 41 +- agentlego/tools/wrappers/lagent.py | 78 +- agentlego/tools/wrappers/langchain.py | 8 +- .../tools/wrappers/transformers_agent.py | 64 +- agentlego/types.py | 118 ++- agentlego/utils/__init__.py | 7 +- agentlego/utils/dependency.py | 3 +- agentlego/utils/file.py | 16 +- agentlego/utils/misc.py | 42 ++ agentlego/utils/module.py | 32 + agentlego/utils/openapi/__init__.py | 11 + agentlego/utils/openapi/api_model.py | 701 ++++++++++++++++++ agentlego/utils/openapi/extract.py | 78 ++ agentlego/utils/openapi/spec.py | 326 ++++++++ agentlego/utils/parse.py | 15 + agentlego/version.py | 2 +- docs/en/collect_docs.py | 25 +- docs/en/conf.py | 8 +- docs/en/get_started.md | 80 +- docs/en/index.rst | 1 + docs/en/modules/tool-server.md | 88 +++ docs/en/modules/tool.md | 68 +- docs/zh_cn/collect_docs.py | 31 +- docs/zh_cn/conf.py | 8 +- docs/zh_cn/get_started.md | 79 +- docs/zh_cn/index.rst | 1 + docs/zh_cn/modules/tool-server.md | 89 +++ docs/zh_cn/modules/tool.md | 66 +- examples/hf_agent/hf_agent_example.py | 5 +- examples/hf_agent/hf_agent_notebook.ipynb | 2 +- examples/lagent_example.py | 17 +- examples/langchain_example.py | 22 +- examples/remote_example.py | 3 +- examples/streamlit_demo.py | 22 +- examples/visual_chatgpt/visual_chatgpt.py | 26 +- pyproject.toml | 10 +- requirements/runtime.txt | 8 +- requirements/server.txt | 3 +- server.py | 175 ----- setup.cfg | 3 + setup.py | 3 + tests/test_apis.py | 2 +- tests/test_tools/test_basetool.py | 53 +- .../test_image_canny/test_canny_to_image.py | 30 - .../test_image_canny/test_image_to_canny.py | 27 - .../test_image_depth/test_depth_to_image.py | 27 - .../test_image_depth/test_image_to_depth.py | 27 - .../test_image_extension.py | 21 - .../test_image_stylization.py | 25 - .../test_image_editing/test_object_remove.py | 26 - .../test_image_editing/test_object_replace.py | 26 - .../test_image_pose/test_facelandmark.py | 32 - .../test_image_pose/test_image_to_pose.py | 30 - .../test_image_pose/test_pose_to_image.py | 30 - .../test_image_to_scribble.py | 27 - .../test_scribble_to_image.py | 29 - .../test_image_text/test_image_to_text.py | 38 - .../test_image_text/test_text_to_image.py | 34 - .../test_imagebind/test_anything_to_image.py | 37 - .../test_object_detection.py | 32 - .../test_text_to_bbox.py | 29 - tests/test_tools/test_ocr/test_ocr.py | 39 - .../test_segment_anything.py | 41 - .../test_semantic_segmentation.py | 31 - .../test_speech_test/test_speech_to_text.py | 38 - .../test_speech_test/test_text_to_speech.py | 34 - .../test_translation/test_translation.py | 39 - tests/test_tools/test_vqa/test_vqa.py | 38 - webui/README.md | 116 +++ webui/README_zh-CN.md | 110 +++ webui/agent_config.yml.example | 14 + webui/app.py | 98 +++ webui/css/chat.css | 140 ++++ webui/css/main.css | 655 ++++++++++++++++ webui/js/main.js | 180 +++++ webui/modules/__init__.py | 0 webui/modules/agents/__init__.py | 83 +++ webui/modules/agents/lagent_agent.py | 102 +++ webui/modules/agents/langchain_agent.py | 207 ++++++ webui/modules/chat.py | 220 ++++++ webui/modules/html_generator.py | 221 ++++++ webui/modules/logging.py | 13 + webui/modules/message_schema.py | 32 + webui/modules/settings.py | 121 +++ webui/modules/shared.py | 107 +++ webui/modules/text_generation.py | 43 ++ webui/modules/tools.py | 87 +++ webui/modules/ui.py | 103 +++ webui/modules/ui_agent.py | 113 +++ webui/modules/ui_chat.py | 118 +++ webui/modules/ui_tools.py | 172 +++++ webui/modules/utils.py | 154 ++++ webui/one_click.py | 170 +++++ webui/start_linux.sh | 67 ++ webui/tool_config.yml.example | 9 + 154 files changed, 6770 insertions(+), 2590 deletions(-) create mode 100644 agentlego/server/__init__.py create mode 100644 agentlego/server/__main__.py create mode 100644 agentlego/server/server.py create mode 100644 agentlego/tools/func.py create mode 100644 agentlego/tools/utils/parameters.py create mode 100644 agentlego/utils/misc.py create mode 100644 agentlego/utils/module.py create mode 100644 agentlego/utils/openapi/__init__.py create mode 100644 agentlego/utils/openapi/api_model.py create mode 100644 agentlego/utils/openapi/extract.py create mode 100644 agentlego/utils/openapi/spec.py create mode 100644 agentlego/utils/parse.py create mode 100644 docs/en/modules/tool-server.md create mode 100644 docs/zh_cn/modules/tool-server.md delete mode 100644 server.py create mode 100644 setup.cfg create mode 100644 setup.py delete mode 100644 tests/test_tools/test_image_canny/test_canny_to_image.py delete mode 100644 tests/test_tools/test_image_canny/test_image_to_canny.py delete mode 100644 tests/test_tools/test_image_depth/test_depth_to_image.py delete mode 100644 tests/test_tools/test_image_depth/test_image_to_depth.py delete mode 100644 tests/test_tools/test_image_editing/test_image_extension.py delete mode 100644 tests/test_tools/test_image_editing/test_image_stylization.py delete mode 100644 tests/test_tools/test_image_editing/test_object_remove.py delete mode 100644 tests/test_tools/test_image_editing/test_object_replace.py delete mode 100644 tests/test_tools/test_image_pose/test_facelandmark.py delete mode 100644 tests/test_tools/test_image_pose/test_image_to_pose.py delete mode 100644 tests/test_tools/test_image_pose/test_pose_to_image.py delete mode 100644 tests/test_tools/test_image_scribble/test_image_to_scribble.py delete mode 100644 tests/test_tools/test_image_scribble/test_scribble_to_image.py delete mode 100644 tests/test_tools/test_image_text/test_image_to_text.py delete mode 100644 tests/test_tools/test_image_text/test_text_to_image.py delete mode 100644 tests/test_tools/test_imagebind/test_anything_to_image.py delete mode 100644 tests/test_tools/test_object_detection/test_object_detection.py delete mode 100644 tests/test_tools/test_object_detection/test_text_to_bbox.py delete mode 100644 tests/test_tools/test_ocr/test_ocr.py delete mode 100644 tests/test_tools/test_segmentation/test_segment_anything.py delete mode 100644 tests/test_tools/test_segmentation/test_semantic_segmentation.py delete mode 100644 tests/test_tools/test_speech_test/test_speech_to_text.py delete mode 100644 tests/test_tools/test_speech_test/test_text_to_speech.py delete mode 100644 tests/test_tools/test_translation/test_translation.py delete mode 100644 tests/test_tools/test_vqa/test_vqa.py create mode 100644 webui/README.md create mode 100644 webui/README_zh-CN.md create mode 100644 webui/agent_config.yml.example create mode 100644 webui/app.py create mode 100644 webui/css/chat.css create mode 100644 webui/css/main.css create mode 100644 webui/js/main.js create mode 100644 webui/modules/__init__.py create mode 100644 webui/modules/agents/__init__.py create mode 100644 webui/modules/agents/lagent_agent.py create mode 100644 webui/modules/agents/langchain_agent.py create mode 100644 webui/modules/chat.py create mode 100644 webui/modules/html_generator.py create mode 100644 webui/modules/logging.py create mode 100644 webui/modules/message_schema.py create mode 100644 webui/modules/settings.py create mode 100644 webui/modules/shared.py create mode 100644 webui/modules/text_generation.py create mode 100644 webui/modules/tools.py create mode 100644 webui/modules/ui.py create mode 100644 webui/modules/ui_agent.py create mode 100644 webui/modules/ui_chat.py create mode 100644 webui/modules/ui_tools.py create mode 100644 webui/modules/utils.py create mode 100644 webui/one_click.py create mode 100755 webui/start_linux.sh create mode 100644 webui/tool_config.yml.example diff --git a/.dev_scripts/generate_readme.py b/.dev_scripts/generate_readme.py index cf51b6be..8b107a77 100644 --- a/.dev_scripts/generate_readme.py +++ b/.dev_scripts/generate_readme.py @@ -59,8 +59,7 @@ def parse_args(): parser = argparse.ArgumentParser(description=prog_description) - parser.add_argument( - 'tools', type=str, nargs='+', help='The tool class to generate.') + parser.add_argument('tools', type=str, nargs='+', help='The tool class to generate.') args = parser.parse_args() return args diff --git a/.gitignore b/.gitignore index e6450b28..594e4f4c 100644 --- a/.gitignore +++ b/.gitignore @@ -119,3 +119,10 @@ venv.bak/ # generated image, audio or video files. generated/ + +# gradio demo +webui/logs/ +webui/custom_tools/ +webui/installer_files/ +webui/generated/ +webui/*.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8b249bc6..9a399f32 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,7 @@ repos: rev: 6.1.0 hooks: - id: flake8 + args: ['--exclude', 'webui/*'] - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: @@ -12,6 +13,7 @@ repos: rev: v0.40.2 hooks: - id: yapf + args: ['--exclude', 'webui'] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: diff --git a/README.md b/README.md index 2d824e40..3644342a 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,8 @@ pip install agentlego Some tools requires extra packages, please check the readme file of the tool, and confirm all requirements are satisfied. -For example, if we want to use the `ImageCaption` tool. We need to check the **Set up** section of -[readme](agentlego/tools/image_text/README.md#ImageCaption) and install the requirements. +For example, if we want to use the `ImageDescription` tool. We need to check the **Set up** section of +[readme](agentlego/tools/image_text/README.md#ImageDescription) and install the requirements. ```bash pip install -U openmim @@ -62,7 +62,7 @@ from agentlego import list_tools, load_tool print(list_tools()) # list tools in AgentLego -image_caption_tool = load_tool('ImageCaption', device='cuda') +image_caption_tool = load_tool('ImageDescription', device='cuda') print(image_caption_tool.description) image = './examples/demo.png' caption = image_caption_tool(image) @@ -88,9 +88,9 @@ caption = image_caption_tool(image) **Image-processing related** -- [ImageCaption](agentlego/tools/image_text/README.md#ImageCaption): Describe the input image. +- [ImageDescription](agentlego/tools/image_text/README.md#ImageDescription): Describe the input image. - [OCR](agentlego/tools/ocr/README.md#OCR): Recognize the text from a photo. -- [VisualQuestionAnswering](agentlego/tools/vqa/README.md#VisualQuestionAnswering): Answer the question according to the image. +- [VQA](agentlego/tools/vqa/README.md#VQA): Answer the question according to the image. - [HumanBodyPose](agentlego/tools/image_pose/README.md#HumanBodyPose): Estimate the pose or keypoints of human in an image. - [HumanFaceLandmark](agentlego/tools/image_pose/README.md#HumanFaceLandmark): Estimate the landmark or keypoints of human faces in an image. - [ImageToCanny](agentlego/tools/image_canny/README.md#ImageToCanny): Extract the edge image from an image. @@ -100,8 +100,7 @@ caption = image_caption_tool(image) - [TextToBbox](agentlego/tools/object_detection/README.md#TextToBbox): Detect specific objects described by the given text in the image. - Segment Anything series - [SegmentAnything](agentlego/tools/segmentation/README.md#SegmentAnything): Segment all items in the image. - - [SegmentClicked](agentlego/tools/segmentation/README.md#SegmentClicked): Segment the masked region in the image. - - [ObjectSegmenting](agentlego/tools/segmentation/README.md#ObjectSegmenting): Segment the certain objects in the image according to the given object name. + - [SegmentObject](agentlego/tools/segmentation/README.md#SegmentObject): Segment the certain objects in the image according to the given object name. **AIGC related** diff --git a/README_zh-CN.md b/README_zh-CN.md index e70e456e..3132e42c 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -46,7 +46,7 @@ pip install agentlego 一些工具需要额外的软件包,请查看工具的自述文件,并确认所有要求都得到满足。 -例如,如果我们想要使用`ImageCaption`工具。我们需要查看工具 [readme](agentlego/tools/image_text/README.md#ImageCaption) 的 **Set up** 小节并安装所需的软件。 +例如,如果我们想要使用`ImageDescription`工具。我们需要查看工具 [readme](agentlego/tools/image_text/README.md#ImageDescription) 的 **Set up** 小节并安装所需的软件。 ```bash pip install -U openmim @@ -60,7 +60,7 @@ from agentlego import list_tools, load_tool print(list_tools()) # list tools in AgentLego -image_caption_tool = load_tool('ImageCaption', device='cuda') +image_caption_tool = load_tool('ImageDescription', device='cuda') print(image_caption_tool.description) image = './examples/demo.png' caption = image_caption_tool(image) @@ -86,9 +86,9 @@ caption = image_caption_tool(image) **图像处理相关** -- [ImageCaption](agentlego/tools/image_text/README.md#ImageCaption): 描述输入图像。 +- [ImageDescription](agentlego/tools/image_text/README.md#ImageDescription): 描述输入图像。 - [OCR](agentlego/tools/ocr/README.md#OCR): 从照片中识别文本。 -- [VisualQuestionAnswering](agentlego/tools/vqa/README.md#VisualQuestionAnswering): 根据图片回答问题。 +- [VQA](agentlego/tools/vqa/README.md#VQA): 根据图片回答问题。 - [HumanBodyPose](agentlego/tools/image_pose/README.md#HumanBodyPose): 估计图像中人体的姿态或关键点,并绘制人体姿态图像 - [HumanFaceLandmark](agentlego/tools/image_pose/README.md#HumanFaceLandmark): 识别图像中人脸的关键点,并绘制带有关键点的图像。 - [ImageToCanny](agentlego/tools/image_canny/README.md#ImageToCanny): 从图像中提取边缘图像。 @@ -98,8 +98,7 @@ caption = image_caption_tool(image) - [TextToBbox](agentlego/tools/object_detection/README.md#TextToBbox): 检测图像中的给定对象。 - Segment Anything 系列工具 - [SegmentAnything](agentlego/tools/segmentation/README.md#SegmentAnything): 分割图像中的所有物体。 - - [SegmentClicked](agentlego/tools/segmentation/README.md#SegmentClicked): 分割图像中指定区域的物体。 - - [ObjectSegmenting](agentlego/tools/segmentation/README.md#ObjectSegmenting): 根据给定的物体名称,在图像中分割出特定的物体。 + - [SegmentObject](agentlego/tools/segmentation/README.md#SegmentObject): 根据给定的物体名称,在图像中分割出特定的物体。 **AIGC 相关** diff --git a/agentlego/__init__.py b/agentlego/__init__.py index 51074691..6301c9c2 100644 --- a/agentlego/__init__.py +++ b/agentlego/__init__.py @@ -1,4 +1,5 @@ from .apis.tool import list_tools, load_tool from .search import search_tool +from .version import __version__ # noqa: F401, F403 __all__ = ['load_tool', 'list_tools', 'search_tool'] diff --git a/agentlego/apis/tool.py b/agentlego/apis/tool.py index 41f144bd..2b3d8497 100644 --- a/agentlego/apis/tool.py +++ b/agentlego/apis/tool.py @@ -1,21 +1,30 @@ import importlib import inspect +from typing import Optional, Union import agentlego.tools from agentlego.tools import BaseTool +from agentlego.tools.func import _FuncToolType from agentlego.utils.cache import load_or_build_object NAMES2TOOLS = {} -def register_all_tools(module): +def extract_all_tools(module): if isinstance(module, str): module = importlib.import_module(module) + tools = {} for k, v in module.__dict__.items(): - if (isinstance(v, type) and issubclass(v, BaseTool) - and (v is not BaseTool)): - NAMES2TOOLS[k] = v + if (isinstance(v, type) and issubclass(v, BaseTool) and (v is not BaseTool)): + tools[k] = v + elif isinstance(v, _FuncToolType): + tools[k] = v + return tools + + +def register_all_tools(module): + NAMES2TOOLS.update(extract_all_tools(module)) register_all_tools(agentlego.tools) @@ -39,15 +48,15 @@ def list_tools(with_description=False): ... print(name, description) """ if with_description: - return list((name, cls.DEFAULT_TOOLMETA.description) + return list((name, cls.get_default_toolmeta().description) for name, cls in NAMES2TOOLS.items()) else: return list(NAMES2TOOLS.keys()) def load_tool(tool_type: str, - name: str = None, - description: str = None, + name: Optional[str] = None, + description: Optional[str] = None, device=None, **kwargs) -> BaseTool: """Load a configurable callable tool for different task. @@ -56,7 +65,7 @@ def load_tool(tool_type: str, tool_name (str): tool name for specific task. You can find more description about supported tools by :func:`~agentlego.apis.list_tools`. - override_name (str | None): The name to override the default name. + name (str | None): The name to override the default name. Defaults to None. description (str): The description to override the default description. Defaults to None. @@ -72,22 +81,19 @@ def load_tool(tool_type: str, Examples: >>> from agentlego import load_tool >>> # load tool with tool name - >>> tool, meta = load_tool('object detection') - >>> # load a specific model - >>> tool, meta = load_tool( - >>> 'object detection', model='rtmdet_l_8xb32-300e_coco') + >>> tool, meta = load_tool('GoogleSearch', with_url=True) """ if tool_type not in NAMES2TOOLS: # Using ValueError to show error msg cross lines. raise ValueError(f'{tool_type} is not supported now, the available ' 'tools are:\n' + '\n'.join(NAMES2TOOLS.keys())) - tool_type = NAMES2TOOLS[tool_type] - if 'device' in inspect.getfullargspec(tool_type).args: + constructor: Union[type, _FuncToolType] = NAMES2TOOLS[tool_type] + if 'device' in inspect.getfullargspec(constructor).args: kwargs['device'] = device - if name or description: - tool_obj = tool_type(**kwargs) + if name or description or isinstance(constructor, _FuncToolType): + tool_obj = constructor(**kwargs) if name: tool_obj.name = name if description: @@ -95,5 +101,5 @@ def load_tool(tool_type: str, else: # Only enable cache if no overrode attribution # to avoid the cached tool is changed. - tool_obj = load_or_build_object(tool_type, **kwargs) + tool_obj = load_or_build_object(constructor, **kwargs) return tool_obj diff --git a/agentlego/parsers/default_parser.py b/agentlego/parsers/default_parser.py index 10b581ec..6fa7dab6 100644 --- a/agentlego/parsers/default_parser.py +++ b/agentlego/parsers/default_parser.py @@ -1,31 +1,27 @@ from typing import Tuple -from agentlego.types import CatgoryToIO, IOType +from agentlego.types import AudioIO, File, ImageIO, IOType from .base_parser import BaseParser class DefaultParser(BaseParser): - agent_cat2type = { - 'image': 'path', - 'text': 'string', - 'audio': 'path', - 'int': 'int', - 'bool': 'bool', - 'float': 'float', + agent_type2format = { + ImageIO: 'path', + AudioIO: 'path', + File: 'path', } def parse_inputs(self, *args, **kwargs) -> Tuple[tuple, dict]: - for arg, arg_name in zip(args, self.tool.parameters): - kwargs[arg_name] = arg + for arg, p in zip(args, self.tool.inputs): + kwargs[p.name] = arg parsed_kwargs = {} for k, v in kwargs.items(): - if k not in self.tool.parameters: + p = self.tool.arguments.get(k) + if p is None: raise TypeError(f'Got unexcepted keyword argument "{k}".') - p = self.tool.parameters[k] - tool_type = CatgoryToIO[p.category] - if not isinstance(v, tool_type): - tool_input = tool_type(v) + if not isinstance(v, p.type): + tool_input = p.type(v) else: tool_input = v parsed_kwargs[k] = tool_input @@ -35,18 +31,27 @@ def parse_inputs(self, *args, **kwargs) -> Tuple[tuple, dict]: def parse_outputs(self, outputs): if isinstance(outputs, tuple): assert len(outputs) == len(self.toolmeta.outputs) + parsed_outs = [] + for out in outputs: + format = self.agent_type2format.get(type(out)) + if isinstance(out, IOType) and format: + out = out.to(format) + parsed_outs.append(out) + parsed_outs = tuple(parsed_outs) + elif isinstance(outputs, dict): + parsed_outs = {} + for k, out in outputs.items(): + format = self.agent_type2format.get(type(out)) + if isinstance(out, IOType) and format: + out = out.to(format) + parsed_outs[k] = out else: - assert len(self.toolmeta.outputs) == 1 - outputs = [outputs] + format = self.agent_type2format.get(type(outputs)) + if isinstance(outputs, IOType) and format: + outputs = outputs.to(format) + parsed_outs = outputs - parsed_outs = [] - for tool_output, out_category in zip(outputs, self.toolmeta.outputs): - agent_type = self.agent_cat2type[out_category] - if isinstance(tool_output, IOType): - tool_output = tool_output.to(agent_type) - parsed_outs.append(tool_output) - - return parsed_outs[0] if len(parsed_outs) == 1 else tuple(parsed_outs) + return parsed_outs def refine_description(self) -> str: """Refine the tool description by replacing the input and output @@ -61,12 +66,40 @@ def refine_description(self) -> str: """ inputs_desc = [] - for p in self.tool.parameters.values(): - type_ = self.agent_cat2type[p.category] - default = f', Defaults to {p.default}' if p.optional else '' - inputs_desc.append(f'{p.name} ({p.category} {type_}{default})') - inputs_desc = 'Args: ' + ', '.join(inputs_desc) + for p in self.tool.inputs: + desc = f'{p.name}' + format = self.agent_type2format.get(p.type, p.type.__name__) + if p.description: + format += f', {p.description}' + if p.optional: + format += f'. Optional, Defaults to {p.default}' + desc += f' ({format})' + inputs_desc.append(desc) + if len(inputs_desc) > 0: + inputs_desc = 'Args: ' + '; '.join(inputs_desc) + else: + inputs_desc = 'No argument.' + + outputs_desc = [] + for p in self.tool.outputs: + format = self.agent_type2format.get(p.type, p.type.__name__) + if p.name and p.description: + desc = f'{p.name} ({format}, {p.description})' + elif p.name: + desc = f'{p.name} ({format})' + elif p.description: + desc = f'{format} ({p.description})' + else: + desc = f'{format}' + outputs_desc.append(desc) + if len(outputs_desc) > 0: + outputs_desc = 'Returns: ' + '; '.join(outputs_desc) + else: + outputs_desc = 'No returns.' - description = f'{self.toolmeta.description} {inputs_desc}' + description = '' + if self.toolmeta.description: + description += f'{self.toolmeta.description}\n' + description += f'{inputs_desc}\n{outputs_desc}' return description diff --git a/agentlego/schema.py b/agentlego/schema.py index 1661bfc2..286888d5 100644 --- a/agentlego/schema.py +++ b/agentlego/schema.py @@ -1,38 +1,45 @@ +import copy from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, Type @dataclass -class ToolMeta: - """Meta information for tool. +class Parameter: + """Meta information for parameters. Args: + type (type): The type of the value. name (str): tool name for agent to identify the tool. - description (str): Description for tool. - inputs (tuple[str, ...]): Input categories for tool. - outputs (tuple[str, ...]): Output categories for tool. + description (str): Description for the parameter. + optional (bool): Whether the parameter has a default value. + Defaults to False. + default (Any): The default value of the parameter. """ - name: str - description: str - inputs: Tuple[str, ...] - outputs: Tuple[str, ...] + type: Optional[Type] = None + name: Optional[str] = None + description: Optional[str] = None + optional: Optional[bool] = None + default: Optional[Any] = None + filetype: Optional[str] = None + + def update(self, other: 'Parameter'): + other = copy.deepcopy(other) + for k, v in copy.deepcopy(other.__dict__).items(): + if v is not None: + self.__dict__[k] = v @dataclass -class Parameter: - """Meta information for parameters. +class ToolMeta: + """Meta information for tool. Args: name (str): tool name for agent to identify the tool. - category (str): Category of the parameter. - description (Optional[str]): Description for the parameter. - Defaults to None. - optional (bool): Whether the parameter has a default value. - Defaults to False. - default (Any): The default value of the parameter. + description (str): Description for tool. + inputs (tuple[str | Parameter, ...]): Input categories for tool. + outputs (tuple[str | Parameter, ...]): Output categories for tool. """ - name: str - category: str + name: Optional[str] = None description: Optional[str] = None - optional: bool = False - default: Any = None + inputs: Optional[Tuple[Parameter, ...]] = None + outputs: Optional[Tuple[Parameter, ...]] = None diff --git a/agentlego/search.py b/agentlego/search.py index 9df58da8..cab66dbe 100644 --- a/agentlego/search.py +++ b/agentlego/search.py @@ -6,7 +6,7 @@ from .utils import load_or_build_object -def _cosine_similarity(a: np.array, b: np.array) -> list: +def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray: """Calculate the cosine similarity of a and b.""" dot_product = np.dot(b, a) norm_a = np.linalg.norm(a) @@ -15,10 +15,7 @@ def _cosine_similarity(a: np.array, b: np.array) -> list: return res -def _search_with_openai(query, - choices, - model='text-embedding-ada-002', - topk=5): +def _search_with_openai(query, choices, model='text-embedding-ada-002', topk=5): """Search tools with openai API. Note: @@ -36,25 +33,24 @@ def _search_with_openai(query, list: List of tool descriptions. """ try: - from openai.embeddings_utils import get_embeddings + from openai import OpenAI except ModuleNotFoundError: raise ModuleNotFoundError( 'please install openai to enable searching tools powered by ' 'openai') - embeddings = get_embeddings([query] + choices, engine=model) - similarity = _cosine_similarity( - np.array(embeddings[0]), np.array(embeddings[1:])) + client = OpenAI() + embeddings = client.embeddings.create(input=[query] + choices, model=model).data + similarity = _cosine_similarity(np.array(embeddings[0]), np.array(embeddings[1:])) indices = np.argsort(-similarity)[:topk] return [choices[i] for i in indices] -def _serach_with_sentence_transformers( - query, - choices, - model='sentence-transformers/all-mpnet-base-v2', - topk=5): +def _serach_with_sentence_transformers(query, + choices, + model='sentence-transformers/all-mpnet-base-v2', + topk=5): """Search tools with sentence-transformers. Args: diff --git a/agentlego/server/__init__.py b/agentlego/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentlego/server/__main__.py b/agentlego/server/__main__.py new file mode 100644 index 00000000..40456d55 --- /dev/null +++ b/agentlego/server/__main__.py @@ -0,0 +1,4 @@ +from .server import cli + +if __name__ == '__main__': + cli() diff --git a/agentlego/server/server.py b/agentlego/server/server.py new file mode 100644 index 00000000..da8688a6 --- /dev/null +++ b/agentlego/server/server.py @@ -0,0 +1,241 @@ +import base64 +import inspect +import logging +import sys +from contextlib import asynccontextmanager +from io import BytesIO +from pathlib import Path +from typing import List, Optional, Tuple + +from agentlego.apis.tool import (extract_all_tools, list_tools, load_tool, + register_all_tools) +from agentlego.parsers import NaiveParser +from agentlego.tools.base import BaseTool +from agentlego.types import AudioIO +from agentlego.types import File as FileType +from agentlego.types import ImageIO +from agentlego.utils import resolve_module + +try: + import rich + import typer + import uvicorn + from fastapi import FastAPI, File, Form, HTTPException, UploadFile + from fastapi.responses import RedirectResponse + from makefun import create_function + from pydantic import Field + from rich.table import Table + from typing_extensions import Annotated +except ImportError: + print('[Import Error] Failed to import server dependencies, ' + 'please install by `pip install agentlego[server]`') + sys.exit(1) + +cli = typer.Typer(add_completion=False, no_args_is_help=True) + + +def create_input_params(tool: BaseTool) -> List[inspect.Parameter]: + params = [] + for p in tool.inputs: + field_kwargs = {} + if p.description: + field_kwargs['description'] = p.description + if p.type is ImageIO: + field_kwargs['format'] = 'image;binary' + annotation = Annotated[UploadFile, File(**field_kwargs)] + elif p.type is AudioIO: + field_kwargs['format'] = 'audio;binary' + annotation = Annotated[UploadFile, File(**field_kwargs)] + elif p.type is FileType: + filetype = p.filetype or 'file' + field_kwargs['format'] = f'{filetype};binary' + annotation = Annotated[UploadFile, File(**field_kwargs)] + else: + annotation = Annotated[p.type, Form(**field_kwargs)] + + param = inspect.Parameter( + p.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=p.default if p.optional else inspect._empty, + annotation=annotation, + ) + params.append(param) + + return params + + +def create_output_annotation(tool: BaseTool): + output_schema = [] + + for p in tool.outputs: + field_kwargs = {} + if p.description: + field_kwargs['description'] = p.description + if p.type is ImageIO: + annotation = str + field_kwargs['format'] = 'image/png;base64' + elif p.type is AudioIO: + annotation = str + field_kwargs['format'] = 'audio/wav;base64' + elif p.type is FileType: + annotation = str + filetype = p.filetype or 'file' + field_kwargs['format'] = f'{filetype};base64' + else: + assert p.type is not None + annotation = p.type + + output_schema.append(Annotated[annotation, Field(**field_kwargs)]) + + if len(output_schema) == 0: + return None + elif len(output_schema) == 1: + return output_schema[0] + else: + return Tuple.copy_with(tuple(output_schema)) + + +def add_tool(tool: BaseTool, app: FastAPI): + tool_name = tool.name.replace(' ', '_') + + input_params = create_input_params(tool) + return_annotation = create_output_annotation(tool) + signature = inspect.Signature(input_params, return_annotation=return_annotation) + + def _call(**kwargs): + args = {} + for p in tool.inputs: + data = kwargs[p.name] + if p.type is ImageIO: + from PIL import Image + data = ImageIO(Image.open(data.file)) + elif p.type is AudioIO: + import torchaudio + file_format = data.filename.rpartition('.')[-1] or None + raw, sr = torchaudio.load(data.file, format=file_format) + data = AudioIO(raw, sampling_rate=sr) + elif p.type is FileType: + data = FileType(data.file.read()) + elif data is None: + continue + else: + data = p.type(data) + args[p.name] = data + + outs = tool(**args) + if not isinstance(outs, tuple): + outs = [outs] + + res = [] + for out, p in zip(outs, tool.outputs): + if p.type is ImageIO: + file = BytesIO() + out.to_pil().save(file, format='png') + out = base64.b64encode(file.getvalue()).decode() + elif p.type is AudioIO: + import torchaudio + file = BytesIO() + torchaudio.save(file, out.to_tensor(), out.sampling_rate, format='wav') + out = base64.b64encode(file.getvalue()).decode() + elif p.type is FileType: + out = base64.b64encode(out.to_bytes()).decode() + res.append(out) + + if len(res) == 0: + return None + elif len(res) == 1: + return res[0] + else: + return tuple(res) + + def call(**kwargs): + try: + return _call(**kwargs) + except Exception as e: + raise HTTPException(status_code=400, detail=repr(e)) + + app.add_api_route( + f'/{tool_name}', + endpoint=create_function(signature, call), + methods=['POST'], + operation_id=tool_name, + summary=tool.toolmeta.description, + ) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger = logging.getLogger('uvicorn.error') + logger.info(f'OpenAPI spec file at \x1b[1m{app.openapi_url}\x1b[0m') + yield + + +@cli.command(no_args_is_help=True) +def start( + tools: List[str] = typer.Argument( + help='The class name of tools to deploy.', show_default=False), + device: str = typer.Option( + 'cuda:0', help='The device to use to deploy the tools.'), + setup: bool = typer.Option(True, help='Setup tools during starting the server.'), + extra: Optional[List[Path]] = typer.Option( + None, + help='The extra Python source files or modules includes tools.', + file_okay=True, + dir_okay=True, + exists=True, + show_default=False), + host: str = typer.Option('127.0.0.1', help='The server address.'), + port: int = typer.Option(16180, help='The server port.'), + title: str = typer.Option('AgentLego', help='The title of the tool collection.'), +): + """Start a tool server with the specified tools.""" + app = FastAPI(title=title, openapi_url='/openapi.json', lifespan=lifespan) + + @app.get('/', include_in_schema=False) + async def root(): + return RedirectResponse(url='/openapi.json') + + if extra is not None: + for path in extra: + register_all_tools(resolve_module(path)) + + for name in tools: + tool = load_tool(name, device=device) + tool.set_parser(NaiveParser) + if setup: + tool.setup() + tool._is_setup = True + + add_tool(tool, app) + + uvicorn.run(app, host=host, port=port) + + +@cli.command(name='list') +def list_available_tools( + official: bool = typer.Option( + True, help='Whether to show AgentLego official tools.'), + extra: Optional[List[Path]] = typer.Option( + None, + help='The extra Python source files or modules includes tools.', + exists=True, + show_default=False, + resolve_path=True), +): + """List all available tools.""" + + table = Table('Class', 'source') + if official: + for name in sorted(list_tools()): + table.add_row(name, '[green]Official[/green]') + + if extra is not None: + for path in extra: + names2tools = extract_all_tools(resolve_module(path)) + for name in sorted(list(names2tools.keys())): + table.add_row(name, str(path)) + rich.print(table) + + +if __name__ == '__main__': + cli() diff --git a/agentlego/testing/setup_tool.py b/agentlego/testing/setup_tool.py index 83ab7309..f078ada9 100644 --- a/agentlego/testing/setup_tool.py +++ b/agentlego/testing/setup_tool.py @@ -10,4 +10,4 @@ def setup_tool(tool_type, **kwargs): return tool_type(**kwargs) else: domain = quote_plus(tool_type.DEFAULT_TOOLMETA.name.replace(' ', '')) - return RemoteTool(urljoin(remote_url, domain)) + return RemoteTool.from_url(urljoin(remote_url, domain)) diff --git a/agentlego/tools/__init__.py b/agentlego/tools/__init__.py index 22be75b7..66b98e41 100644 --- a/agentlego/tools/__init__.py +++ b/agentlego/tools/__init__.py @@ -1,29 +1,28 @@ from .base import BaseTool from .calculator import Calculator +from .func import make_tool from .image_canny import CannyTextToImage, ImageToCanny from .image_depth import DepthTextToImage, ImageToDepth -from .image_editing import (ImageExpansion, ImageStylization, ObjectRemove, - ObjectReplace) +from .image_editing import ImageExpansion, ImageStylization, ObjectRemove, ObjectReplace from .image_pose import HumanBodyPose, HumanFaceLandmark, PoseToImage from .image_scribble import ImageToScribble, ScribbleTextToImage -from .image_text import ImageCaption, TextToImage -from .imagebind import (AudioImageToImage, AudioTextToImage, AudioToImage, - ThermalToImage) +from .image_text import ImageDescription, TextToImage +from .imagebind import AudioImageToImage, AudioTextToImage, AudioToImage, ThermalToImage from .object_detection import ObjectDetection, TextToBbox from .ocr import OCR from .search import GoogleSearch from .segmentation import SegmentAnything, SegmentObject, SemanticSegmentation from .speech_text import SpeechToText, TextToSpeech from .translation import Translation -from .vqa import VisualQuestionAnswering +from .vqa import VQA __all__ = [ 'CannyTextToImage', 'ImageToCanny', 'DepthTextToImage', 'ImageToDepth', 'ImageExpansion', 'ObjectRemove', 'ObjectReplace', 'HumanFaceLandmark', 'HumanBodyPose', 'PoseToImage', 'ImageToScribble', 'ScribbleTextToImage', - 'ImageCaption', 'TextToImage', 'VisualQuestionAnswering', - 'ObjectDetection', 'TextToBbox', 'OCR', 'SegmentObject', 'SegmentAnything', - 'SemanticSegmentation', 'ImageStylization', 'AudioToImage', - 'ThermalToImage', 'AudioImageToImage', 'AudioTextToImage', 'SpeechToText', - 'TextToSpeech', 'Translation', 'GoogleSearch', 'Calculator', 'BaseTool' + 'ImageDescription', 'TextToImage', 'VQA', 'ObjectDetection', 'TextToBbox', 'OCR', + 'SegmentObject', 'SegmentAnything', 'SemanticSegmentation', 'ImageStylization', + 'AudioToImage', 'ThermalToImage', 'AudioImageToImage', 'AudioTextToImage', + 'SpeechToText', 'TextToSpeech', 'Translation', 'GoogleSearch', 'Calculator', + 'BaseTool', 'make_tool' ] diff --git a/agentlego/tools/base.py b/agentlego/tools/base.py index f6a73127..4a531404 100644 --- a/agentlego/tools/base.py +++ b/agentlego/tools/base.py @@ -1,19 +1,21 @@ import copy -import inspect from abc import ABCMeta, abstractmethod -from types import MethodType -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Mapping, Optional, Tuple, Union +from agentlego.parsers import DefaultParser from agentlego.schema import Parameter, ToolMeta +from .utils.parameters import extract_toolmeta class BaseTool(metaclass=ABCMeta): - - def __init__(self, toolmeta: Union[dict, ToolMeta], parser: Callable): - toolmeta = copy.deepcopy(toolmeta) - if isinstance(toolmeta, dict): - toolmeta = ToolMeta(**toolmeta) - self.toolmeta = toolmeta + default_desc: Optional[str] = None + + def __init__( + self, + toolmeta: Union[dict, ToolMeta, None] = None, + parser: Callable = DefaultParser, + ): + self.toolmeta = self.get_default_toolmeta(toolmeta) self.set_parser(parser) self._is_setup = False @@ -33,6 +35,33 @@ def description(self) -> str: def description(self, val: str): self.toolmeta.description = val + @property + def inputs(self) -> Tuple[Parameter, ...]: + return self.toolmeta.inputs + + @property + def arguments(self) -> Mapping[str, Parameter]: + return {i.name: i for i in self.toolmeta.inputs} + + @property + def outputs(self) -> Tuple[Parameter, ...]: + return self.toolmeta.outputs + + @classmethod + def get_default_toolmeta(cls, override=None) -> ToolMeta: + if isinstance(override, dict): + override = ToolMeta(**override) + override = ToolMeta() if override is None else copy.deepcopy(override) + + if override.name is None: + override.name = cls.__name__ + + if override.description is None: + doc = (cls.default_desc or '').partition('\n\n')[0].replace('\n', ' ') + override.description = doc.strip() + + return extract_toolmeta(cls.apply, override=override) + def set_parser(self, parser: Callable): self.parser = parser(self) self._parser_constructor = parser @@ -66,23 +95,6 @@ def __repr__(self) -> str: f'parser={type(self.parser).__name__})') return repr_str - @property - def parameters(self) -> Dict[str, Parameter]: - parameters = {} - for category, p in zip( - self.toolmeta.inputs, - inspect.signature(self.apply).parameters.values()): - if isinstance(self.apply, MethodType) and p.name == 'self': - continue - parameters[p.name] = Parameter( - name=p.name, - category=category, - description=None, - optional=p.default != inspect._empty, - default=p.default if p.default != inspect._empty else None, - ) - return parameters - def __copy__(self): obj = object.__new__(type(self)) obj.__dict__.update(self.__dict__) diff --git a/agentlego/tools/calculator/python_calculator.py b/agentlego/tools/calculator/python_calculator.py index 5fd98bd9..ddea3029 100644 --- a/agentlego/tools/calculator/python_calculator.py +++ b/agentlego/tools/calculator/python_calculator.py @@ -1,19 +1,13 @@ import math -from typing import Callable, Union import addict from func_timeout import func_timeout -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from ..base import BaseTool def safe_eval(expr): - math_methods = { - k: v - for k, v in math.__dict__.items() if not k.startswith('_') - } + math_methods = {k: v for k, v in math.__dict__.items() if not k.startswith('_')} allowed_methods = { 'math': addict.Addict(math_methods), 'max': max, @@ -30,26 +24,18 @@ class Calculator(BaseTool): """A calculator based on Python expression. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. + timeout (int): The timeout value to interrupt calculation. + Defaults to 2. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='Calculator', - description='A calculator tool. The input must be a single Python ' - 'expression and you cannot import packages. You can use functions ' - 'in the `math` package without import.', - inputs=['text'], - outputs=['text'], - ) - - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - timeout=2): - super().__init__(toolmeta=toolmeta, parser=parser) + default_desc = ('A calculator tool. The input must be a single Python ' + 'expression and you cannot import packages. You can use functions ' + 'in the `math` package without import.') + + def __init__(self, timeout=2, toolmeta=None): + super().__init__(toolmeta=toolmeta) self.timeout = timeout def apply(self, expression: str) -> str: diff --git a/agentlego/tools/func.py b/agentlego/tools/func.py new file mode 100644 index 00000000..cb7a5c21 --- /dev/null +++ b/agentlego/tools/func.py @@ -0,0 +1,108 @@ +from copy import deepcopy +from inspect import cleandoc +from typing import Callable, Optional, Union + +from agentlego.parsers import DefaultParser +from agentlego.schema import ToolMeta +from .base import BaseTool +from .utils.parameters import extract_toolmeta + + +class _FuncTool(BaseTool): + + def __init__(self, + func: Callable, + toolmeta: ToolMeta, + parser: Callable = DefaultParser): + self.func = func + self.toolmeta = deepcopy(toolmeta) + self.set_parser(parser) + self._is_setup = True + + def apply(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +class _FuncToolType: + + def __init__(self, func: Callable, toolmeta: ToolMeta): + self.func = func + self.toolmeta = toolmeta + + def __call__(self, + toolmeta: Union[dict, ToolMeta, None] = None, + parser: Callable = DefaultParser): + return _FuncTool(self.func, self.get_default_toolmeta(toolmeta), parser=parser) + + def get_default_toolmeta(self, override=None) -> ToolMeta: + if override is None: + return self.toolmeta + + override = deepcopy(override) + override = ToolMeta(**override) if isinstance(override, dict) else override + + if override.name is None: + override.name = self.toolmeta.name + if override.description is None: + override.description = self.toolmeta.description + if override.inputs is None: + override.inputs = self.toolmeta.inputs + if override.outputs is None: + override.outputs = self.toolmeta.outputs + + return override + + +def make_tool(func: Optional[Callable] = None, + toolmeta: Optional[ToolMeta] = None, + infer_meta: bool = True) -> Union[BaseTool, Callable]: + """Make tool from function. + + Args: + func (Callable | None): The execution function. If not specified, return a + function decorator. Defaults to None. + toolmeta (ToolMeta | dict | None): The meta information of the tool. + Defaults to None. + infer_meta (bool): Whether to infer the tool meta information. If False, directly + use the input ``toolmeta``. If True, try to extract meta information and + merge to the input toolmeta: Use function name as tool name; Use function + docstring as description; Use type hint to infer inputs and outputs + annotations. Defaults to True. + + Examples: + .. code-block:: python + from agentlego.tools import make_tool + + @make_tool + def multiply(a: int, b: int) -> int: + '''Multiply the input integers.''' + return a * b + + @make_tool(toolmeta=dict(name="GetTime", description='Return the current time.')) + def clock() -> str: + from datetime import datetime + return datetime.now().strftime('%Y/%m/%d %H:%M') + """ # noqa: E501 + if isinstance(toolmeta, dict): + toolmeta = ToolMeta(**toolmeta) + + def make_tool(func, override): + if infer_meta: + toolmeta = extract_toolmeta(func, override) + if toolmeta.name is None: + toolmeta.name = func.__name__ + if toolmeta.description is None and func.__doc__: + toolmeta.description = cleandoc(func.__doc__).partition('\n\n')[0] + else: + toolmeta = deepcopy(override) + tool = _FuncToolType(func, toolmeta=toolmeta) + return tool + + if func is None: + + def wrapper(func: Callable): + return make_tool(func, toolmeta) + + return wrapper + else: + return make_tool(func, override=toolmeta) diff --git a/agentlego/tools/image_canny/README.md b/agentlego/tools/image_canny/README.md index 192a540a..eaa00ff4 100644 --- a/agentlego/tools/image_canny/README.md +++ b/agentlego/tools/image_canny/README.md @@ -81,8 +81,7 @@ for step in ret.inner_steps[1:]: Before using the tool, please confirm you have installed the related dependencies by the below commands. ```bash -pip install -U openmim -mim install -U mmagic +pip install -U diffusers ``` ## Reference diff --git a/agentlego/tools/image_canny/canny_to_image.py b/agentlego/tools/image_canny/canny_to_image.py index eb7a77ae..9772aae1 100644 --- a/agentlego/tools/image_canny/canny_to_image.py +++ b/agentlego/tools/image_canny/canny_to_image.py @@ -1,8 +1,4 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import require from ..base import BaseTool from ..utils.diffusers import load_sd, load_sdxl @@ -12,31 +8,19 @@ class CannyTextToImage(BaseTool): """A tool to generate image according to a canny edge image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The canny controlnet model to use. You can choose from "sd" and "sdxl". Defaults to "sd". device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='CannyTextToImage', - description='This tool can generate an image from a canny edge ' - 'image and a text. The text should be a series of English keywords ' - 'separated by comma.', - inputs=['image', 'text'], - outputs=['image'], - ) + default_desc = ('This tool can generate an image from a canny edge ' + 'image and keywords.') @require('diffusers') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model: str = 'sd', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, model: str = 'sd', device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) assert model in ['sd', 'sdxl'] self.model = model self.device = device @@ -58,8 +42,13 @@ def setup(self): ' missing fingers, extra digit, fewer digits, '\ 'cropped, worst quality, low quality' - def apply(self, image: ImageIO, text: str) -> ImageIO: - prompt = f'{text}, {self.a_prompt}' + def apply( + self, + image: ImageIO, + keywords: Annotated[str, + Info('A series of English keywords separated by comma.')], + ) -> ImageIO: + prompt = f'{keywords}, {self.a_prompt}' image = self.pipe( prompt, image=image.to_pil(), diff --git a/agentlego/tools/image_canny/image_to_canny.py b/agentlego/tools/image_canny/image_to_canny.py index 5a7eb518..62f02ef1 100644 --- a/agentlego/tools/image_canny/image_to_canny.py +++ b/agentlego/tools/image_canny/image_to_canny.py @@ -1,11 +1,7 @@ -from typing import Callable, Union - -import cv2 import numpy as np -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO +from agentlego.utils import require from ..base import BaseTool @@ -13,29 +9,20 @@ class ImageToCanny(BaseTool): """A tool to do edge detection by canny algorithm on an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='EdgeDetectionOnImage', - description='This tool can extract the edge image from an image.', - inputs=['image'], - outputs=['image'], - ) + default_desc = 'This tool can extract the edge image from an image.' - def __init__( - self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - ): - super().__init__(toolmeta=toolmeta, parser=parser) + @require('opencv-python') + def __init__(self, toolmeta=None): + super().__init__(toolmeta=toolmeta) self.low_threshold = 100 self.high_threshold = 200 def apply(self, image: ImageIO) -> ImageIO: + import cv2 canny = cv2.Canny(image.to_array(), self.low_threshold, self.high_threshold)[:, :, None] canny = np.concatenate([canny] * 3, axis=2) diff --git a/agentlego/tools/image_depth/depth_to_image.py b/agentlego/tools/image_depth/depth_to_image.py index f3d96210..2e7e68ae 100644 --- a/agentlego/tools/image_depth/depth_to_image.py +++ b/agentlego/tools/image_depth/depth_to_image.py @@ -1,8 +1,4 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import require from ..base import BaseTool from ..utils.diffusers import load_sd, load_sdxl @@ -12,31 +8,19 @@ class DepthTextToImage(BaseTool): """A tool to generate image according to a depth image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The depth controlnet model to use. You can choose from "sd" and "sdxl". Defaults to "sd". device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='DepthTextToImage', - description='This tool can generate an image from a depth ' - 'image and a text. The text should be a series of English keywords ' - 'separated by comma.', - inputs=['image', 'text'], - outputs=['image'], - ) + default_desc = ('This tool can generate an image from a depth ' + 'image and keywords.') @require('diffusers') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model: str = 'sd', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, model: str = 'sd', device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) assert model in ['sd', 'sdxl'] self.model = model self.device = device @@ -58,8 +42,13 @@ def setup(self): ' missing fingers, extra digit, fewer digits, '\ 'cropped, worst quality, low quality' - def apply(self, image: ImageIO, text: str) -> ImageIO: - prompt = f'{text}, {self.a_prompt}' + def apply( + self, + image: ImageIO, + keywords: Annotated[str, + Info('A series of English keywords separated by comma.')], + ) -> ImageIO: + prompt = f'{keywords}, {self.a_prompt}' image = self.pipe( prompt, image=image.to_pil(), diff --git a/agentlego/tools/image_depth/image_to_depth.py b/agentlego/tools/image_depth/image_to_depth.py index 8c9d7b41..c82d66f2 100644 --- a/agentlego/tools/image_depth/image_to_depth.py +++ b/agentlego/tools/image_depth/image_to_depth.py @@ -1,9 +1,5 @@ -from typing import Callable, Union - import numpy as np -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -13,26 +9,17 @@ class ImageToDepth(BaseTool): """A tool to estimation depth of an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. device (str): The device to load the model. Defaults to 'cuda'. + device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ImageToDepth', - description='This tool can generate the depth image of an image.', - inputs=['image'], - outputs=['image'], - ) + default_desc = 'This tool can generate the depth image of an image.' @require('transformers') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.device = device def setup(self): diff --git a/agentlego/tools/image_editing/__init__.py b/agentlego/tools/image_editing/__init__.py index ef466efb..47211a80 100644 --- a/agentlego/tools/image_editing/__init__.py +++ b/agentlego/tools/image_editing/__init__.py @@ -3,6 +3,4 @@ from .replace import ObjectReplace from .stylization import ImageStylization -__all__ = [ - 'ImageExpansion', 'ObjectRemove', 'ObjectReplace', 'ImageStylization' -] +__all__ = ['ImageExpansion', 'ObjectRemove', 'ObjectReplace', 'ImageStylization'] diff --git a/agentlego/tools/image_editing/expansion.py b/agentlego/tools/image_editing/expansion.py index 081dac54..0c85a155 100644 --- a/agentlego/tools/image_editing/expansion.py +++ b/agentlego/tools/image_editing/expansion.py @@ -1,15 +1,10 @@ import math -from typing import Callable, Union -import cv2 import numpy as np from PIL import Image, ImageOps -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO -from agentlego.utils import require -from agentlego.utils.cache import load_or_build_object +from agentlego.types import Annotated, ImageIO, Info +from agentlego.utils import load_or_build_object, parse_multi_float, require from ..base import BaseTool from .replace import Inpainting @@ -29,6 +24,7 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): Returns: PIL.Image.Image: The blended image. """ + import cv2 new_size = new_image.size old_size = old_image.size easy_img = np.array(new_image) @@ -95,37 +91,24 @@ class ImageExpansion(BaseTool): """A tool to expand the given image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. caption_model (str): The model name used to inference. Which can be found in the ``MMPreTrain`` repository. Defaults to ``blip-base_3rdparty_caption``. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ImageExpansion', - description='This tool can expand the peripheral area of ' - 'an image based on its content, thus obtaining a larger image. ' - 'You need to provide the target image and the expand ratio. ' - 'The expand ratio can be a float string (for both width and ' - 'height expand ratio, like "1.25") or a string include two ' - 'float separated by comma (for width ratio and height ratio, ' - 'like "1.25, 1.0")', - inputs=['image', 'text'], - outputs=['image'], - ) + default_desc = ('This tool can expand the peripheral area of an image ' + 'based on its content, thus obtaining a larger image.') @require('mmpretrain') @require('diffusers') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, caption_model: str = 'blip-base_3rdparty_caption', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.caption_model_name = caption_model self.device = device @@ -134,18 +117,21 @@ def setup(self): from mmpretrain.apis import ImageCaptionInferencer self.caption_inferencer = load_or_build_object( - ImageCaptionInferencer, - model=self.caption_model_name, - device=self.device) + ImageCaptionInferencer, model=self.caption_model_name, device=self.device) - self.inpainting_inferencer = load_or_build_object( - Inpainting, device=self.device) + self.inpainting_inferencer = load_or_build_object(Inpainting, device=self.device) - def apply(self, image: ImageIO, scale: str) -> ImageIO: + def apply( + self, + image: ImageIO, + scale: Annotated[str, + Info('expand ratio, can be a float number or two ' + 'float number for width and height ratio.')], + ) -> ImageIO: old_img = image.to_pil().convert('RGB') expand_ratio = 4 # maximum expand ratio for a single round. - scale_w, scale_h = self.parse_scale(scale) + scale_w, scale_h = parse_multi_float(scale, 2) target_w = int(old_img.size[0] * scale_w) target_h = int(old_img.size[1] * scale_h) @@ -153,10 +139,8 @@ def apply(self, image: ImageIO, scale: str) -> ImageIO: caption = self.get_caption(old_img) # crop the some border to re-generation. - crop_w = 15 if (old_img.width != target_w - and old_img.width > 100) else 0 - crop_h = 15 if (old_img.height != target_h - and old_img.height > 100) else 0 + crop_w = 15 if (old_img.width != target_w and old_img.width > 100) else 0 + crop_h = 15 if (old_img.height != target_h and old_img.height > 100) else 0 old_img = ImageOps.crop(old_img, (crop_w, crop_h, crop_w, crop_h)) canvas_w = min(expand_ratio * old_img.width, target_w) @@ -184,21 +168,12 @@ def apply(self, image: ImageIO, scale: str) -> ImageIO: # Resize the generated image into the canvas size and # blend with the old image. - image = image.resize((canvas.width, canvas.height), - Image.ANTIALIAS) + image = image.resize((canvas.width, canvas.height), Image.ANTIALIAS) image = blend_gt2pt(old_img, image) old_img = image return ImageIO(old_img) - @staticmethod - def parse_scale(scale: str): - if isinstance(scale, str) and ',' in scale: - w_scale, h_scale = scale.split(',')[:2] - else: - w_scale, h_scale = scale, scale - return float(w_scale), float(h_scale) - def get_caption(self, image: Image.Image): image = np.array(image)[:, :, ::-1] return self.caption_inferencer(image)[0]['pred_caption'] diff --git a/agentlego/tools/image_editing/remove.py b/agentlego/tools/image_editing/remove.py index 929a7900..5363bb9d 100644 --- a/agentlego/tools/image_editing/remove.py +++ b/agentlego/tools/image_editing/remove.py @@ -1,18 +1,10 @@ -from typing import Callable, Union - -import cv2 import numpy as np from PIL import Image -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO -from agentlego.utils import is_package_available, load_or_build_object, require +from agentlego.types import Annotated, ImageIO, Info +from agentlego.utils import load_or_build_object, require from ..base import BaseTool -if is_package_available('torch'): - import torch - GLOBAL_SEED = 1912 @@ -20,10 +12,6 @@ class ObjectRemove(BaseTool): """A tool to remove the certain objects in the image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. sam_model (str): The model name used to inference. Which can be found in the ``segment_anything`` repository. Defaults to ``sam_vit_h_4b8939.pth``. @@ -31,27 +19,22 @@ class ObjectRemove(BaseTool): found in the ``MMdetection`` repository. Defaults to ``glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365``. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='RemoveObjectFromImage', - description='This tool can remove the specified object in the image. ' - 'You need to input the image and the object name to remove.', - inputs=['image', 'text'], - outputs=['image'], - ) + default_desc = 'This tool can remove the specified object in the image.' @require('mmdet') @require('segment_anything') @require('diffusers') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, sam_model: str = 'sam_vit_h_4b8939.pth', - grounding_model: - str = 'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365', - device: str = 'cuda'): - super().__init__(toolmeta, parser) + grounding_model: str = 'glip_atss_swin-t_a' + '_fpn_dyhead_pretrain_obj365', + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta) self.grounding_model = grounding_model self.sam_model = sam_model self.device = device @@ -69,24 +52,24 @@ def setup(self): self.inpainting = load_or_build_object(Inpainting, device=self.device) - def apply(self, image: ImageIO, text: str) -> ImageIO: + def apply( + self, + image: ImageIO, + text: Annotated[str, Info('The object to remove.')], + ) -> ImageIO: + import torch image_path = image.to_path() image_pil = image.to_pil() text1 = text text2 = 'background' results = self.grounding( - inputs=image_path, - texts=[text1], - no_save_vis=True, - return_datasamples=True) + inputs=image_path, texts=[text1], no_save_vis=True, return_datasamples=True) results = results['predictions'][0].pred_instances boxes_filt = results.bboxes - image = cv2.imread(image_path) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - self.sam_predictor.set_image(image) + self.sam_predictor.set_image(image.to_array()) masks = self.get_mask_with_boxes(image_pil, image, boxes_filt) mask = torch.sum(masks, dim=0).unsqueeze(0) mask = torch.where(mask > 0, True, False) diff --git a/agentlego/tools/image_editing/replace.py b/agentlego/tools/image_editing/replace.py index ed140111..04f65bd7 100644 --- a/agentlego/tools/image_editing/replace.py +++ b/agentlego/tools/image_editing/replace.py @@ -1,12 +1,7 @@ -from typing import Callable, Union - -import cv2 import numpy as np from PIL import Image -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import is_package_available, load_or_build_object, require from ..base import BaseTool @@ -66,10 +61,6 @@ class ObjectReplace(BaseTool): """A tool to replace the certain objects in the image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. sam_model (str): The model name used to inference. Which can be found in the ``segment_anything`` repository. Defaults to ``sam_vit_h_4b8939.pth``. @@ -77,29 +68,23 @@ class ObjectReplace(BaseTool): found in the ``MMdetection`` repository. Defaults to ``glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365``. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ReplaceObjectInImage', - description='This tool can replace the specified object in the ' - 'input image with another object, like replacing a cat in an image ' - 'with a dog. You need to input the image to edit, the object name ' - 'to be replaced, and the object to replace with.', - inputs=['image', 'text', 'text'], - outputs=['image'], - ) + default_desc = ('This tool can replace the specified object in the input ' + 'image with another object, like replacing a cat in an ' + 'image with a dog.') @require('mmdet') @require('segment_anything') @require('diffusers') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, sam_model: str = 'sam_vit_h_4b8939.pth', - grounding_model: - str = 'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365', - device: str = 'cuda'): - super().__init__(toolmeta, parser) + grounding_model: str = 'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365', + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta) self.sam_model = sam_model self.grounding_model = grounding_model self.device = device @@ -116,22 +101,22 @@ def setup(self): self.inpainting = load_or_build_object(Inpainting, device=self.device) - def apply(self, image: ImageIO, text1: str, text2: str) -> ImageIO: + def apply( + self, + image: Annotated[ImageIO, Info('The image to edit.')], + text1: Annotated[str, Info('The object to be replaced.')], + text2: Annotated[str, Info('The object to replace with.')], + ) -> ImageIO: image_path = image.to_path() image_pil = image.to_pil() results = self.grounding( - inputs=image_path, - texts=[text1], - no_save_vis=True, - return_datasamples=True) + inputs=image_path, texts=[text1], no_save_vis=True, return_datasamples=True) results = results['predictions'][0].pred_instances boxes_filt = results.bboxes - image = cv2.imread(image_path) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - self.sam_predictor.set_image(image) + self.sam_predictor.set_image(image.to_array()) masks = self.get_mask_with_boxes(image_pil, image, boxes_filt) mask = torch.sum(masks, dim=0).unsqueeze(0) mask = torch.where(mask > 0, True, False) diff --git a/agentlego/tools/image_editing/stylization.py b/agentlego/tools/image_editing/stylization.py index 427f3428..455cc7ec 100644 --- a/agentlego/tools/image_editing/stylization.py +++ b/agentlego/tools/image_editing/stylization.py @@ -1,16 +1,10 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO -from agentlego.utils import is_package_available, load_or_build_object, require +from agentlego.utils import load_or_build_object, require from ..base import BaseTool -if is_package_available('torch'): - import torch - def load_instruct_pix2pix(model, device): + import torch from diffusers import (EulerAncestralDiscreteScheduler, StableDiffusionInstructPix2PixPipeline) @@ -30,35 +24,27 @@ class ImageStylization(BaseTool): """A tool to stylize an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to inference. Which can be found in the ``diffusers`` repository. Defaults to 'timbrooks/instruct-pix2pix'. inference_steps (int): The number of inference steps. Defaults to 20. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ImageModification', - description='This tool can modify the input image according to the ' - 'input instruction. Here are some example instructions: ' - '"turn him into cyborg", "add fireworks to the sky", ' - '"make his jacket out of leather".', - inputs=['image', 'text'], - outputs=['image'], - ) + default_desc = ('This tool can modify the input image according to the ' + 'input instruction. Here are some example instructions: ' + '"turn him into cyborg", "add fireworks to the sky", ' + '"make his jacket out of leather".') @require('diffusers') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, model: str = 'timbrooks/instruct-pix2pix', inference_steps: int = 20, - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.model_name = model self.inference_steps = inference_steps self.device = device diff --git a/agentlego/tools/image_pose/README.md b/agentlego/tools/image_pose/README.md index 989f2565..b5e8d388 100644 --- a/agentlego/tools/image_pose/README.md +++ b/agentlego/tools/image_pose/README.md @@ -113,8 +113,7 @@ for step in ret.inner_steps[1:]: Before using the tool, please confirm you have installed the related dependencies by the below commands. ```bash -pip install -U openmim -mim install -U mmagic +pip install -U diffusers ``` ## Reference diff --git a/agentlego/tools/image_pose/facelandmark.py b/agentlego/tools/image_pose/facelandmark.py index bbef4b47..2ad787d8 100644 --- a/agentlego/tools/image_pose/facelandmark.py +++ b/agentlego/tools/image_pose/facelandmark.py @@ -1,8 +1,4 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -11,31 +7,19 @@ class HumanFaceLandmark(BaseTool): """A tool to extract human face landmarks from an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to inference. Which can be found - in the ``MMPose`` repository. - Defaults to 'face'. + in the ``MMPose`` repository. Defaults to 'face'. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='HumanFaceLandmark', - description='This tool can estimate the landmark or keypoints of ' - 'human faces in an image and draw the landmarks image.', - inputs=['image'], - outputs=['image'], - ) + default_desc = ('This tool can estimate the landmark or keypoints of ' + 'human faces in an image and draw the landmarks image.') @require('mmpose') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model: str = 'face', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, model: str = 'face', device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.model_name = model self.device = device @@ -44,7 +28,8 @@ def setup(self): self._inferencer = load_or_build_object( MMPoseInferencer, pose2d=self.model_name, device=self.device) - def apply(self, image: ImageIO) -> ImageIO: + def apply(self, image: ImageIO + ) -> Annotated[ImageIO, Info('The human face landmarks image.')]: image = image.to_array()[:, :, ::-1] results = next( self._inferencer( diff --git a/agentlego/tools/image_pose/image_to_pose.py b/agentlego/tools/image_pose/image_to_pose.py index f804463d..aa76d934 100644 --- a/agentlego/tools/image_pose/image_to_pose.py +++ b/agentlego/tools/image_pose/image_to_pose.py @@ -1,8 +1,4 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -11,31 +7,20 @@ class HumanBodyPose(BaseTool): """A tool to extract human body keypoints from an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to inference. Which can be found in the ``MMPose`` repository. Defaults to `human`. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='HumanBodyPoseDetectionOnImage', - description='This tool can estimate the pose or keypoints of ' - 'human in an image and draw the human pose image', - inputs=['image'], - outputs=['image'], - ) + default_desc = ('This tool can estimate the pose or keypoints of ' + 'human in an image and draw the human pose image.') @require('mmpose') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model: str = 'human', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, model: str = 'human', device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.model_name = model self.device = device @@ -44,7 +29,8 @@ def setup(self): self._inferencer = load_or_build_object( MMPoseInferencer, pose2d=self.model_name, device=self.device) - def apply(self, image: ImageIO) -> ImageIO: + def apply(self, image: ImageIO + ) -> Annotated[ImageIO, Info('The human pose keypoints image.')]: image = image.to_array()[:, :, ::-1] vis_params = self.adaptive_vis_params(*image.shape[:2]) results = next( diff --git a/agentlego/tools/image_pose/pose_to_image.py b/agentlego/tools/image_pose/pose_to_image.py index 217bc6fd..a36e7bc5 100644 --- a/agentlego/tools/image_pose/pose_to_image.py +++ b/agentlego/tools/image_pose/pose_to_image.py @@ -1,10 +1,8 @@ -from typing import Callable, Tuple, Union +from typing import Tuple from PIL import Image -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import require from ..base import BaseTool from ..utils.diffusers import load_sd, load_sdxl @@ -14,31 +12,19 @@ class PoseToImage(BaseTool): """A tool to generate image according to a human pose image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The pose controlnet model to use. You can choose from "sd" and "sdxl". Defaults to "sd". device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='PoseToImage', - description='This tool can generate an image from a human pose ' - 'image and a text. The text should be a series of English keywords ' - 'separated by comma.', - inputs=['image', 'text'], - outputs=['image'], - ) + default_desc = ('This tool can generate an image from a human pose ' + 'image and a text.') @require('diffusers') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model: str = 'sd', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, model: str = 'sd', device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) assert model in ['sd', 'sdxl'] self.model = model self.device = device @@ -61,10 +47,14 @@ def setup(self): ' missing fingers, extra digit, fewer digits, '\ 'cropped, worst quality, low quality' - def apply(self, image: ImageIO, text: str) -> ImageIO: - text = f'{text}, {self.a_prompt}' - width, height = self.get_image_size( - image.to_pil(), canvas_size=self.canvas_size) + def apply( + self, + image: ImageIO, + keywords: Annotated[str, + Info('A series of English keywords separated by comma.')], + ) -> ImageIO: + text = f'{keywords}, {self.a_prompt}' + width, height = self.get_image_size(image.to_pil(), canvas_size=self.canvas_size) image = self.pipe( text, image=image.to_pil(), diff --git a/agentlego/tools/image_scribble/image_to_scribble.py b/agentlego/tools/image_scribble/image_to_scribble.py index 8690a438..a6ab4bca 100644 --- a/agentlego/tools/image_scribble/image_to_scribble.py +++ b/agentlego/tools/image_scribble/image_to_scribble.py @@ -1,7 +1,3 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -11,26 +7,16 @@ class ImageToScribble(BaseTool): """A tool to convert image to a scribble sketch. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ImageToScribble', - description='This tool can generate a sketch scribble of an image.', - inputs=['image'], - outputs=['image'], - ) + default_desc = 'This tool can generate a sketch scribble of an image.' @require('controlnet_aux') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.device = device def setup(self): diff --git a/agentlego/tools/image_scribble/scribble_to_image.py b/agentlego/tools/image_scribble/scribble_to_image.py index 21f6c327..a101a560 100644 --- a/agentlego/tools/image_scribble/scribble_to_image.py +++ b/agentlego/tools/image_scribble/scribble_to_image.py @@ -1,8 +1,4 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import require from ..base import BaseTool from ..utils.diffusers import load_sd @@ -12,34 +8,22 @@ class ScribbleTextToImage(BaseTool): """A tool to generate image according to a scribble sketch. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to inference. Which can be found in the ``diffusers`` repository. Defaults to 'lllyasviel/sd-controlnet_scribble'. model (str): The scribble controlnet model to use. You can only choose "sd" by now. Defaults to "sd". device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ScribbleTextToImage', - description='This tool can generate an image from a sketch scribble ' - 'image and a text. The text should be a series of English keywords ' - 'separated by comma.', - inputs=['image', 'text'], - outputs=['image'], - ) + default_desc = ('This tool can generate an image from a sketch scribble ' + 'image and a text.') @require('diffusers') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model: str = 'sd', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, model: str = 'sd', device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) assert model in ['sd'] self.model_name = model self.device = device @@ -55,8 +39,13 @@ def setup(self): ' missing fingers, extra digit, fewer digits, '\ 'cropped, worst quality, low quality' - def apply(self, image: ImageIO, text: str) -> ImageIO: - prompt = f'{text}, {self.a_prompt}' + def apply( + self, + image: ImageIO, + keywords: Annotated[str, + Info('A series of English keywords separated by comma.')], + ) -> ImageIO: + prompt = f'{keywords}, {self.a_prompt}' image = self.pipe( prompt, image.to_pil(), diff --git a/agentlego/tools/image_text/README.md b/agentlego/tools/image_text/README.md index de881eaf..617a48b2 100644 --- a/agentlego/tools/image_text/README.md +++ b/agentlego/tools/image_text/README.md @@ -1,4 +1,4 @@ -# ImageCaption +# ImageDescription ## Examples @@ -8,7 +8,7 @@ from agentlego.apis import load_tool # load tool -tool = load_tool('ImageCaption', device='cuda') +tool = load_tool('ImageDescription', device='cuda') # apply tool caption = tool('examples/demo.png') @@ -23,7 +23,7 @@ from agentlego.apis import load_tool # load tools and build agent # please set `OPENAI_API_KEY` in your environment variable. -tool = load_tool('ImageCaption', device='cuda').to_lagent() +tool = load_tool('ImageDescription', device='cuda').to_lagent() agent = ReAct(GPTAPI(temperature=0.), action_executor=ActionExecutor([tool])) # agent running with the tool. @@ -42,7 +42,7 @@ from agentlego.apis import load_tool from PIL import Image # load tools and build transformers agent -tool = load_tool('ImageCaption', device='cuda').to_transformers_agent() +tool = load_tool('ImageDescription', device='cuda').to_transformers_agent() agent = HfAgent('https://api-inference.huggingface.co/models/bigcode/starcoder', additional_tools=[tool]) # agent running with the tool (For demo, we directly specify the tool name here.) @@ -127,8 +127,7 @@ print(image) Before using the tool, please confirm you have installed the related dependencies by the below commands. ```bash -pip install -U openmim -mim install -U mmagic +pip install -U diffusers ``` ## Reference diff --git a/agentlego/tools/image_text/__init__.py b/agentlego/tools/image_text/__init__.py index c51fd807..0cbad9a1 100644 --- a/agentlego/tools/image_text/__init__.py +++ b/agentlego/tools/image_text/__init__.py @@ -1,4 +1,4 @@ -from .image_to_text import ImageCaption +from .image_to_text import ImageDescription from .text_to_image import TextToImage -__all__ = ['ImageCaption', 'TextToImage'] +__all__ = ['ImageDescription', 'TextToImage'] diff --git a/agentlego/tools/image_text/image_to_text.py b/agentlego/tools/image_text/image_to_text.py index 28179e0d..4421c621 100644 --- a/agentlego/tools/image_text/image_to_text.py +++ b/agentlego/tools/image_text/image_to_text.py @@ -1,41 +1,29 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO from agentlego.utils import load_or_build_object, require from ..base import BaseTool -class ImageCaption(BaseTool): +class ImageDescription(BaseTool): """A tool to describe an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to inference. Which can be found in the ``MMPreTrain`` repository. Defaults to ``blip-base_3rdparty_caption``. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ImageDescription', - description=('A useful tool that returns a brief ' - 'description of the input image.'), - inputs=['image'], - outputs=['text'], - ) + default_desc = ('A useful tool that returns a brief ' + 'description of the input image.') @require('mmpretrain') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, model: str = 'blip-base_3rdparty_caption', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.model = model self.device = device diff --git a/agentlego/tools/image_text/text_to_image.py b/agentlego/tools/image_text/text_to_image.py index a2cffc3f..768e4c0a 100644 --- a/agentlego/tools/image_text/text_to_image.py +++ b/agentlego/tools/image_text/text_to_image.py @@ -1,8 +1,4 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import require from ..base import BaseTool from ..utils.diffusers import load_sd, load_sdxl @@ -12,31 +8,19 @@ class TextToImage(BaseTool): """A tool to generate image according to some keywords. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The stable diffusion model to use. You can choose from "sd" and "sdxl". Defaults to "sd". device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='TextToImage', - description='This tool can generate an image according to the ' - 'input text. The input text should be a series of keywords ' - 'separated by comma, and all keywords must be in English.', - inputs=['text'], - outputs=['image'], - ) + default_desc = ('This tool can generate an image according to the ' + 'input text.') @require('diffusers') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model: str = 'sd', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, model: str = 'sd', device: str = 'cuda', toolmeta=None): + super().__init__(toolmeta=toolmeta) assert model in ['sd', 'sdxl'] self.model = model self.device = device @@ -51,8 +35,12 @@ def setup(self): ' missing fingers, extra digit, fewer digits, '\ 'cropped, worst quality, low quality' - def apply(self, text: str) -> ImageIO: - prompt = f'{text}, {self.a_prompt}' + def apply( + self, + keywords: Annotated[str, + Info('A series of English keywords separated by comma.')], + ) -> ImageIO: + prompt = f'{keywords}, {self.a_prompt}' image = self.pipe( prompt, num_inference_steps=30, diff --git a/agentlego/tools/imagebind/__init__.py b/agentlego/tools/imagebind/__init__.py index a143cc6f..b28e6e20 100644 --- a/agentlego/tools/imagebind/__init__.py +++ b/agentlego/tools/imagebind/__init__.py @@ -1,6 +1,4 @@ -from .anything_to_image import (AudioImageToImage, AudioTextToImage, - AudioToImage, ThermalToImage) +from .anything_to_image import (AudioImageToImage, AudioTextToImage, AudioToImage, + ThermalToImage) -__all__ = [ - 'AudioToImage', 'ThermalToImage', 'AudioImageToImage', 'AudioTextToImage' -] +__all__ = ['AudioToImage', 'ThermalToImage', 'AudioImageToImage', 'AudioTextToImage'] diff --git a/agentlego/tools/imagebind/anything_to_image.py b/agentlego/tools/imagebind/anything_to_image.py index fff87b8a..98d12a0a 100644 --- a/agentlego/tools/imagebind/anything_to_image.py +++ b/agentlego/tools/imagebind/anything_to_image.py @@ -1,7 +1,3 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import AudioIO, ImageIO from agentlego.utils import is_package_available, load_or_build_object, require from ..base import BaseTool @@ -36,31 +32,21 @@ class AudioToImage(BaseTool): """A tool to generate image from an audio. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='AudioToImage', - description=('This tool can generate an image ' - 'according to the input audio'), - inputs=['audio'], - outputs=['image'], - ) + + default_desc = ('This tool can generate an image ' + 'according to the input audio.') @require(['diffusers', 'ftfy', 'iopath', 'timm', 'pytorchvideo']) - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - device: str = 'cpu'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, device: str = 'cpu', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.device = device def setup(self): - self._inferencer = load_or_build_object( - AnythingToImage, device=self.device) + self._inferencer = load_or_build_object(AnythingToImage, device=self.device) def apply(self, audio: AudioIO) -> ImageIO: from .data import load_and_transform_audio_data @@ -68,8 +54,7 @@ def apply(self, audio: AudioIO) -> ImageIO: audio_paths = [audio.to_path()] audio_data = load_and_transform_audio_data(audio_paths, self.device) - embeddings = self._inferencer.model.forward( - {ModalityType.AUDIO: audio_data}) + embeddings = self._inferencer.model.forward({ModalityType.AUDIO: audio_data}) embeddings = embeddings[ModalityType.AUDIO] images = self._inferencer.pipe( image_embeds=embeddings.half(), width=512, height=512).images @@ -82,41 +67,29 @@ class ThermalToImage(BaseTool): """A tool to generate image from an thermal image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='ThermalToImage', - description=('This tool can generate an image ' - 'according to the input thermal image.'), - inputs=['image'], - outputs=['image'], - ) + + default_desc = ('This tool can generate an image ' + 'according to the input thermal image.') @require(['diffusers', 'ftfy', 'iopath', 'timm']) - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - device: str = 'cpu'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, device: str = 'cpu', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.device = device def setup(self): - self._inferencer = load_or_build_object( - AnythingToImage, device=self.device) + self._inferencer = load_or_build_object(AnythingToImage, device=self.device) def apply(self, thermal: ImageIO) -> ImageIO: from .data import load_and_transform_thermal_data from .models.imagebind_model import ModalityType thermal_paths = [thermal.to_path()] - thermal_data = load_and_transform_thermal_data(thermal_paths, - self.device) - embeddings = self._inferencer.model.forward( - {ModalityType.THERMAL: thermal_data}) + thermal_data = load_and_transform_thermal_data(thermal_paths, self.device) + embeddings = self._inferencer.model.forward({ModalityType.THERMAL: thermal_data}) embeddings = embeddings[ModalityType.THERMAL] images = self._inferencer.pipe( image_embeds=embeddings.half(), width=512, height=512).images @@ -129,50 +102,36 @@ class AudioImageToImage(BaseTool): """A tool to generate image from an audio and an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='AudioImageToImage', - description=('This tool can generate an image according to ' - 'the input reference image and the input audio.'), - inputs=['image', 'audio'], - outputs=['image'], - ) + + default_desc = ('This tool can generate an image according to ' + 'the input reference image and the input audio.') @require(['diffusers', 'ftfy', 'iopath', 'timm', 'pytorchvideo']) - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - device: str = 'cpu'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, device: str = 'cpu', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.device = device def setup(self): - self._inferencer = load_or_build_object( - AnythingToImage, device=self.device) + self._inferencer = load_or_build_object(AnythingToImage, device=self.device) def apply(self, image: ImageIO, audio: AudioIO) -> ImageIO: - from .data import (load_and_transform_audio_data, - load_and_transform_vision_data) + from .data import load_and_transform_audio_data, load_and_transform_vision_data from .models.imagebind_model import ModalityType # process image data - vision_data = load_and_transform_vision_data([image.to_path()], - self.device) - embeddings = self._inferencer.model.forward( - {ModalityType.VISION: vision_data}, normalize=False) + vision_data = load_and_transform_vision_data([image.to_path()], self.device) + embeddings = self._inferencer.model.forward({ModalityType.VISION: vision_data}, + normalize=False) img_embeddings = embeddings[ModalityType.VISION] # process audio data - audio_data = load_and_transform_audio_data([audio.to_path()], - self.device) + audio_data = load_and_transform_audio_data([audio.to_path()], self.device) embeddings = self._inferencer.model.forward({ - ModalityType.AUDIO: - audio_data, + ModalityType.AUDIO: audio_data, }) audio_embeddings = embeddings[ModalityType.AUDIO] @@ -188,35 +147,24 @@ class AudioTextToImage(BaseTool): """A tool to generate image from an audio and texts. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='AudioTextToImage', - description=('This tool can generate an image according to ' - 'the input audio and the input description.'), - inputs=['audio', 'text'], - outputs=['image'], - ) + + default_desc = ('This tool can generate an image according to ' + 'the input audio and the input description.') @require(['diffusers', 'ftfy', 'iopath', 'timm', 'pytorchvideo']) - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - device: str = 'cpu'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, device: str = 'cpu', toolmeta=None): + super().__init__(toolmeta=toolmeta) self.device = device def setup(self): - self._inferencer = load_or_build_object( - AnythingToImage, device=self.device) + self._inferencer = load_or_build_object(AnythingToImage, device=self.device) def apply(self, audio: AudioIO, prompt: str) -> ImageIO: - from .data import (load_and_transform_audio_data, - load_and_transform_text) + from .data import load_and_transform_audio_data, load_and_transform_text from .models.imagebind_model import ModalityType audio_paths = [audio.to_path()] @@ -227,8 +175,7 @@ def apply(self, audio: AudioIO, prompt: str) -> ImageIO: audio_data = load_and_transform_audio_data(audio_paths, self.device) embeddings = self._inferencer.model.forward({ - ModalityType.AUDIO: - audio_data, + ModalityType.AUDIO: audio_data, }) audio_embeddings = embeddings[ModalityType.AUDIO] embeddings = text_embeddings * 0.5 + audio_embeddings * 0.5 diff --git a/agentlego/tools/imagebind/data.py b/agentlego/tools/imagebind/data.py index eecdbb68..30335137 100644 --- a/agentlego/tools/imagebind/data.py +++ b/agentlego/tools/imagebind/data.py @@ -67,8 +67,7 @@ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length): ) # cut and pad if p > 0: - fbank = torch.nn.functional.pad( - fbank, (0, p), mode='constant', value=0) + fbank = torch.nn.functional.pad(fbank, (0, p), mode='constant', value=0) elif p < 0: fbank = fbank[:, 0:target_length] # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 @@ -83,8 +82,7 @@ def get_clip_timepoints(clip_sampler, duration): is_last_clip = False end = 0.0 while not is_last_clip: - start, end, _, _, is_last_clip = clip_sampler( - end, duration, annotation=None) + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) all_clips_timepoints.append((start, end)) return all_clips_timepoints @@ -96,8 +94,7 @@ def load_and_transform_vision_data(image_paths, device): image_ouputs = [] for image_path in image_paths: data_transform = transforms.Compose([ - transforms.Resize( - 224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( @@ -120,8 +117,7 @@ def load_and_transform_depth_data(depth_paths, device): depth_ouputs = [] for depth_path in depth_paths: data_transform = transforms.Compose([ - transforms.Resize( - 224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), # if I use this normalization, I cannot get good results... @@ -142,8 +138,7 @@ def load_and_transform_thermal_data(thermal_paths, device): thermal_ouputs = [] for thermal_path in thermal_paths: data_transform = transforms.Compose([ - transforms.Resize( - 224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, )) @@ -188,9 +183,8 @@ def load_and_transform_audio_data( if sample_rate != sr: waveform = torchaudio.functional.resample( waveform, orig_freq=sr, new_freq=sample_rate) - all_clips_timepoints = get_clip_timepoints( - clip_sampler, - waveform.size(1) / sample_rate) + all_clips_timepoints = get_clip_timepoints(clip_sampler, + waveform.size(1) / sample_rate) all_clips = [] for clip_timepoints in all_clips_timepoints: waveform_clip = waveform[ @@ -198,8 +192,8 @@ def load_and_transform_audio_data( int(clip_timepoints[0] * sample_rate):int(clip_timepoints[1] * sample_rate), ] - waveform_melspec = waveform2melspec(waveform_clip, sample_rate, - num_mel_bins, target_length) + waveform_melspec = waveform2melspec(waveform_clip, sample_rate, num_mel_bins, + target_length) all_clips.append(waveform_melspec) normalize = transforms.Normalize(mean=mean, std=std) @@ -217,8 +211,7 @@ def get_clip_timepoints(clip_sampler, duration): # noqa is_last_clip = False end = 0.0 while not is_last_clip: - start, end, _, _, is_last_clip = clip_sampler( - end, duration, annotation=None) + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) all_clips_timepoints.append((start, end)) return all_clips_timepoints @@ -295,8 +288,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): elif spatial_idx == 2: x_offset = width - size cropped = images[:, :, y_offset:y_offset + size, x_offset:x_offset + size] - cropped_boxes = crop_boxes(boxes, x_offset, - y_offset) if boxes is not None else None + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None if ndim == 3: cropped = cropped.squeeze(0) return cropped, cropped_boxes @@ -331,8 +323,7 @@ def forward(self, videos): videos: A list with 3x the number of elements. Each video converted to C, T, H', W' by spatial cropping. """ - assert isinstance( - videos, list), 'Must be a list of videos after temporal crops' + assert isinstance(videos, list), 'Must be a list of videos after temporal crops' assert all([video.ndim == 4 for video in videos]), 'Must be (C,T,H,W)' res = [] for video in videos: @@ -342,9 +333,7 @@ def forward(self, videos): continue flipped_video = transforms.functional.hflip(video) for spatial_idx in self.flipped_crops_to_ext: - res.append( - uniform_crop(flipped_video, self.crop_size, - spatial_idx)[0]) + res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) return res @@ -369,8 +358,7 @@ def load_and_transform_video_data( clip_sampler = ConstantClipsPerVideoSampler( clip_duration=clip_duration, clips_per_video=clips_per_video) - frame_sampler = pv_transforms.UniformTemporalSubsample( - num_samples=clip_duration) + frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) for video_path in video_paths: video = EncodedVideo.from_path( @@ -380,8 +368,7 @@ def load_and_transform_video_data( **{'sample_rate': sample_rate}, ) - all_clips_timepoints = get_clip_timepoints(clip_sampler, - video.duration) + all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) all_video = [] for clip_timepoints in all_clips_timepoints: diff --git a/agentlego/tools/imagebind/models/helpers.py b/agentlego/tools/imagebind/models/helpers.py index 856c2490..0c17020e 100644 --- a/agentlego/tools/imagebind/models/helpers.py +++ b/agentlego/tools/imagebind/models/helpers.py @@ -46,8 +46,7 @@ def __init__( self.register_buffer('log_logit_scale', log_logit_scale) def forward(self, x): - return torch.clip( - self.log_logit_scale.exp(), max=self.max_logit_scale) * x + return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x def extra_repr(self): st = (f'logit_scale_init={self.logit_scale_init},' diff --git a/agentlego/tools/imagebind/models/imagebind_model.py b/agentlego/tools/imagebind/models/imagebind_model.py index 7fad89a4..1b7855b6 100644 --- a/agentlego/tools/imagebind/models/imagebind_model.py +++ b/agentlego/tools/imagebind/models/imagebind_model.py @@ -12,11 +12,10 @@ import torch import torch.nn as nn -from .helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize, - SelectElement, SelectEOSAndProject) -from .multimodal_preprocessors import (AudioPreprocessor, IMUPreprocessor, - PadIm2Video, PatchEmbedGeneric, - RGBDTPreprocessor, +from .helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize, SelectElement, + SelectEOSAndProject) +from .multimodal_preprocessors import (AudioPreprocessor, IMUPreprocessor, PadIm2Video, + PatchEmbedGeneric, RGBDTPreprocessor, SpatioTemporalPosEmbeddingHelper, TextPreprocessor, ThermalPreprocessor) from .transformer import MultiheadAttention, SimpleTransformer @@ -155,8 +154,7 @@ def _create_modality_preprocessors( rgbt_preprocessor = RGBDTPreprocessor( img_size=[3, video_frames, 224, 224], num_cls_tokens=1, - pos_embed_fn=partial( - SpatioTemporalPosEmbeddingHelper, learnable=True), + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), rgbt_stem=rgbt_stem, depth_stem=None, ) @@ -183,8 +181,7 @@ def _create_modality_preprocessors( audio_preprocessor = AudioPreprocessor( img_size=[1, audio_num_mel_bins, audio_target_len], num_cls_tokens=1, - pos_embed_fn=partial( - SpatioTemporalPosEmbeddingHelper, learnable=True), + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), audio_stem=audio_stem, ) @@ -204,8 +201,7 @@ def _create_modality_preprocessors( depth_preprocessor = RGBDTPreprocessor( img_size=[1, 224, 224], num_cls_tokens=1, - pos_embed_fn=partial( - SpatioTemporalPosEmbeddingHelper, learnable=True), + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), rgbt_stem=None, depth_stem=depth_stem, ) @@ -225,8 +221,7 @@ def _create_modality_preprocessors( thermal_preprocessor = ThermalPreprocessor( img_size=[1, 224, 224], num_cls_tokens=1, - pos_embed_fn=partial( - SpatioTemporalPosEmbeddingHelper, learnable=True), + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), thermal_stem=thermal_stem, ) @@ -246,8 +241,7 @@ def _create_modality_preprocessors( num_cls_tokens=1, kernel_size=8, embed_dim=imu_embed_dim, - pos_embed_fn=partial( - SpatioTemporalPosEmbeddingHelper, learnable=True), + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), imu_stem=imu_stem, ) @@ -288,8 +282,8 @@ def _create_modality_trunks( imu_drop_path=0.7, ): - def instantiate_trunk(embed_dim, num_blocks, num_heads, - pre_transformer_ln, add_bias_kv, drop_path): + def instantiate_trunk(embed_dim, num_blocks, num_heads, pre_transformer_ln, + add_bias_kv, drop_path): return SimpleTransformer( embed_dim=embed_dim, num_blocks=num_blocks, @@ -441,13 +435,11 @@ def _create_modality_postprocessors(self, out_embed_dim): def forward(self, inputs, normalize=True): outputs = {} for modality_key, modality_value in inputs.items(): - reduce_list = ( - modality_value.ndim - >= 5) # Audio and Video inputs consist of multiple clips + reduce_list = (modality_value.ndim + >= 5) # Audio and Video inputs consist of multiple clips if reduce_list: B, S = modality_value.shape[:2] - modality_value = modality_value.reshape( - B * S, *modality_value.shape[2:]) + modality_value = modality_value.reshape(B * S, *modality_value.shape[2:]) if modality_value is not None: modality_value = self.modality_preprocessors[modality_key]( @@ -456,14 +448,12 @@ def forward(self, inputs, normalize=True): }) trunk_inputs = modality_value['trunk'] head_inputs = modality_value['head'] - modality_value = self.modality_trunks[modality_key]( - **trunk_inputs) - modality_value = self.modality_heads[modality_key]( - modality_value, **head_inputs) + modality_value = self.modality_trunks[modality_key](**trunk_inputs) + modality_value = self.modality_heads[modality_key](modality_value, + **head_inputs) if normalize: - modality_value = self.modality_postprocessors[ - modality_key]( - modality_value) + modality_value = self.modality_postprocessors[modality_key]( + modality_value) if reduce_list: modality_value = modality_value.reshape(B, S, -1) diff --git a/agentlego/tools/imagebind/models/multimodal_preprocessors.py b/agentlego/tools/imagebind/models/multimodal_preprocessors.py index e8ae763d..337d3099 100644 --- a/agentlego/tools/imagebind/models/multimodal_preprocessors.py +++ b/agentlego/tools/imagebind/models/multimodal_preprocessors.py @@ -47,8 +47,7 @@ def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): return pos_embed dim = pos_embed.shape[-1] # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 # noqa - pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, - torch.float32) + pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) pos_embed = nn.functional.interpolate( pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), @@ -56,8 +55,7 @@ def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): mode='bicubic', ) if updated: - pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, - torch.bfloat16) + pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return pos_embed @@ -70,14 +68,12 @@ def interpolate_pos_encoding( first_patch_idx=1, ): assert first_patch_idx == 0 or first_patch_idx == 1, 'there is 1 CLS token or none' # noqa - N = pos_embed.shape[ - 1] - first_patch_idx # since it's 1 if cls_token exists # noqa + N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists # noqa if npatch_per_img == N: return pos_embed - assert ( - patches_layout[-1] == patches_layout[-2] - ), 'Interpolation of pos embed not supported for non-square layouts' + assert (patches_layout[-1] == patches_layout[-2] + ), 'Interpolation of pos embed not supported for non-square layouts' class_emb = pos_embed[:, :first_patch_idx] pos_embed = pos_embed[:, first_patch_idx:] @@ -93,8 +89,8 @@ def interpolate_pos_encoding( num_spatial_tokens = patches_layout[1] * patches_layout[2] pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) # interpolate embedding for zeroth frame - pos_embed = interpolate_pos_encoding_2d( - npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)) + pos_embed = interpolate_pos_encoding_2d(npatch_per_img, + pos_embed[0, 0, ...].unsqueeze(0)) else: raise ValueError("This type of interpolation isn't implemented") @@ -169,13 +165,11 @@ def __init__( self.num_tokens = num_cls_tokens + num_patches self.learnable = learnable if self.learnable: - self.pos_embed = nn.Parameter( - torch.zeros(1, self.num_tokens, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) trunc_normal_(self.pos_embed, std=0.02) else: - self.register_buffer( - 'pos_embed', - get_sinusoid_encoding_table(self.num_tokens, embed_dim)) + self.register_buffer('pos_embed', + get_sinusoid_encoding_table(self.num_tokens, embed_dim)) def get_pos_embedding(self, vision_input, all_vision_tokens): input_shape = vision_input.shape @@ -260,8 +254,7 @@ def tokenize_input_and_cls_pos(self, input, stem, mask): B, -1, -1) # stole class_tokens impl from Phil Wang, thanks tokens = torch.cat((class_tokens, tokens), dim=1) if self.use_pos_embed: - pos_embed = self.pos_embedding_helper.get_pos_embedding( - input, tokens) + pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) tokens = tokens + pos_embed if self.use_type_embed: tokens = tokens + self.type_embed.expand(B, -1, -1) @@ -272,12 +265,12 @@ def forward(self, vision=None, depth=None, patch_mask=None): raise NotImplementedError() if vision is not None: - vision_tokens = self.tokenize_input_and_cls_pos( - vision, self.rgbt_stem, patch_mask) + vision_tokens = self.tokenize_input_and_cls_pos(vision, self.rgbt_stem, + patch_mask) if depth is not None: - depth_tokens = self.tokenize_input_and_cls_pos( - depth, self.depth_stem, patch_mask) + depth_tokens = self.tokenize_input_and_cls_pos(depth, self.depth_stem, + patch_mask) # aggregate tokens if vision is not None and depth is not None: @@ -349,8 +342,7 @@ def __init__( self.embed_dim = embed_dim if num_cls_tokens > 0: assert self.causal_masking is False, "Masking + CLS token isn't implemented" # noqa - self.cls_token = nn.Parameter( - torch.zeros(1, self.num_cls_tokens, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, self.num_cls_tokens, embed_dim)) self.init_parameters(init_param_style) @@ -433,8 +425,7 @@ def forward(self, x): x = x.repeat(new_shape) elif self.pad_type == 'zero': padarg = [0, 0] * len(x.shape) - padarg[2 * self.time_dim + - 1] = self.ntimes - x.shape[self.time_dim] + padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] x = nn.functional.pad(x, padarg) return x @@ -534,8 +525,7 @@ def bpe(self, token): return token + '' while True: - bigram = min( - pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -551,8 +541,7 @@ def bpe(self, token): new_word.extend(word[i:]) break - if word[i] == first and i < len(word) - 1 and word[ - i + 1] == second: + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: new_word.append(first + second) i += 2 else: @@ -572,8 +561,7 @@ def encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] - for b in token.encode('utf-8')) + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) return bpe_tokens @@ -581,9 +569,9 @@ def encode(self, text): def decode(self, tokens): text = ''.join([self.decoder[token] for token in tokens]) text = ( - bytearray([self.byte_decoder[c] for c in text - ]).decode('utf-8', - errors='replace').replace('', ' ')) + bytearray([self.byte_decoder[c] + for c in text]).decode('utf-8', + errors='replace').replace('', ' ')) return text def __call__(self, texts, context_length=None): @@ -595,8 +583,7 @@ def __call__(self, texts, context_length=None): sot_token = self.encoder['<|startoftext|>'] eot_token = self.encoder['<|endoftext|>'] - all_tokens = [[sot_token] + self.encode(text) + [eot_token] - for text in texts] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): @@ -628,8 +615,7 @@ def __init__( self.num_cls_tokens = num_cls_tokens self.kernel_size = kernel_size self.pos_embed = nn.Parameter( - torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, - embed_dim)) + torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)) if self.num_cls_tokens > 0: self.cls_token = nn.Parameter( diff --git a/agentlego/tools/imagebind/models/transformer.py b/agentlego/tools/imagebind/models/transformer.py index ae8e7337..d033f9d6 100644 --- a/agentlego/tools/imagebind/models/transformer.py +++ b/agentlego/tools/imagebind/models/transformer.py @@ -96,8 +96,7 @@ def forward(self, x): class MultiheadAttention(nn.MultiheadAttention): def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): - return super().forward( - x, x, x, need_weights=False, attn_mask=attn_mask)[0] + return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] class ViTAttention(Attention): @@ -170,8 +169,7 @@ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): x = ( x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) * self.layer_scale_gamma1) - x = x + self.drop_path(self.mlp( - self.norm_2(x))) * self.layer_scale_gamma2 + x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 return x @@ -209,9 +207,7 @@ def __init__( super().__init__() self.pre_transformer_layer = pre_transformer_layer if drop_path_type == 'progressive': - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, num_blocks) - ] + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] elif drop_path_type == 'uniform': dpr = [drop_path_rate for i in range(num_blocks)] else: diff --git a/agentlego/tools/object_detection/object_detection.py b/agentlego/tools/object_detection/object_detection.py index 003f30f7..82c1a3be 100644 --- a/agentlego/tools/object_detection/object_detection.py +++ b/agentlego/tools/object_detection/object_detection.py @@ -1,8 +1,4 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -11,31 +7,22 @@ class ObjectDetection(BaseTool): """A tool to detection all objects defined in COCO 80 classes. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to detect texts. Which can be found in the ``MMDetection`` repository. Defaults to ``rtmdet_l_8xb32-300e_coco``. - device (str): The device to load the model. Defaults to 'cpu'. + device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='DetectAllObjects', - description=('A useful tool when you only want to detect the picture ' - 'or detect all objects in the picture. like: detect all ' - 'objects. '), - inputs=['image'], - outputs=['image'], - ) + + default_desc = 'The tool can detect all common objects in the picture.' @require('mmdet>=3.1.0') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, model: str = 'rtmdet_l_8xb32-300e_coco', - device: str = 'cpu'): - super().__init__(toolmeta=toolmeta, parser=parser) + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.model = model self.device = device @@ -43,9 +30,30 @@ def setup(self): from mmdet.apis import DetInferencer self._inferencer = load_or_build_object( DetInferencer, model=self.model, device=self.device) + self.classes = self._inferencer.model.dataset_meta['classes'] + + def apply( + self, + image: ImageIO, + ) -> Annotated[str, + Info('All detected objects, include object name, ' + 'bbox in (x1, y1, x2, y2) format, ' + 'and detection score.')]: + from mmdet.structures import DetDataSample - def apply(self, image: ImageIO) -> ImageIO: - image = image.to_path() - results = self._inferencer(image, return_vis=True) - output_image = results['visualization'][0] - return ImageIO(output_image) + results = self._inferencer( + image.to_array()[:, :, ::-1], + return_datasamples=True, + ) + data_sample = results['predictions'][0] + preds: DetDataSample = data_sample.pred_instances + preds = preds[preds.scores > 0.5] + pred_descs = [] + pred_tmpl = '{} ({:.0f}, {:.0f}, {:.0f}, {:.0f}), score {:.0f}' + for label, bbox, score in zip(preds.labels, preds.bboxes, preds.scores): + label = self.classes[label] + pred_descs.append(pred_tmpl.format(label, *bbox, score * 100)) + if len(pred_descs) == 0: + return 'No object found.' + else: + return '\n'.join(pred_descs) diff --git a/agentlego/tools/object_detection/text_to_bbox.py b/agentlego/tools/object_detection/text_to_bbox.py index bf181962..cfbec79b 100644 --- a/agentlego/tools/object_detection/text_to_bbox.py +++ b/agentlego/tools/object_detection/text_to_bbox.py @@ -1,8 +1,4 @@ -from typing import Callable, Tuple, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -11,33 +7,23 @@ class TextToBbox(BaseTool): """A tool to detection the given object. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to detect texts. Which can be found in the ``MMDetection`` repository. Defaults to ``glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365``. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='DetectGivenObject', - description='The tool can detect the object location according to ' - 'description in English. It will return an image with a bbox of the ' - 'detected object, and the coordinates of bbox. If specify ' - '`top1` to false, return all detected objects instead the single ' - 'object with highest score.', - inputs=['image', 'text', 'bool'], - outputs=['image', 'text'], - ) + + default_desc = ('The tool can detect the object location according to ' + 'description.') @require('mmdet>=3.1.0') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, model: str = 'glip_atss_swin-t_b_fpn_dyhead_pretrain_obj365', - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.model = model self.device = device @@ -47,10 +33,16 @@ def setup(self): DetInferencer, model=self.model, device=self.device) self._visualizer = self._inferencer.visualizer - def apply(self, - image: ImageIO, - text: str, - top1: bool = True) -> Tuple[ImageIO, str]: + def apply( + self, + image: ImageIO, + text: Annotated[str, Info('The object description in English.')], + top1: Annotated[bool, + Info('If true, return the object with highest score. ' + 'If false, return all detected objects.')] = True, + ) -> Annotated[str, + Info('Detected objects, include bbox in ' + '(x1, y1, x2, y2) format, and detection score.')]: from mmdet.structures import DetDataSample results = self._inferencer( @@ -61,24 +53,17 @@ def apply(self, data_sample = results['predictions'][0] preds: DetDataSample = data_sample.pred_instances - pred_tmpl = ('bbox ({:.0f}, {:.0f}, {:.0f}, {:.0f}), ' - 'score {:.0f}') if len(preds) == 0: - pred_str = 'No object found.' - output_image = image - else: - if top1: - preds = preds[preds.scores.topk(1).indices] - else: - preds = preds[preds.scores > 0.5] - pred_descs = [] - for bbox, score in zip(preds.bboxes, preds.scores): - pred_descs.append(pred_tmpl.format(*bbox, score * 100)) - pred_str = '\n'.join(pred_descs) + return 'No object found.' - data_sample.pred_instances = preds - self._visualizer.add_datasample( - 'vis', image.to_array(), data_sample, draw_gt=False) - output_image = ImageIO(self._visualizer.get_image()) + pred_tmpl = '({:.0f}, {:.0f}, {:.0f}, {:.0f}), score {:.0f}' + if top1: + preds = preds[preds.scores.topk(1).indices] + else: + preds = preds[preds.scores > 0.5] + pred_descs = [] + for bbox, score in zip(preds.bboxes, preds.scores): + pred_descs.append(pred_tmpl.format(*bbox, score * 100)) + pred_str = '\n'.join(pred_descs) - return output_image, pred_str + return pred_str diff --git a/agentlego/tools/ocr/ocr.py b/agentlego/tools/ocr/ocr.py index 6b38c65d..92a34ac8 100644 --- a/agentlego/tools/ocr/ocr.py +++ b/agentlego/tools/ocr/ocr.py @@ -1,8 +1,6 @@ -from typing import Callable, Sequence, Union +from typing import Sequence, Tuple, Union -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -11,33 +9,28 @@ class OCR(BaseTool): """A tool to recognize the optical characters on an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. lang (str | Sequence[str]): The language to be recognized. Defaults to 'en'. + line_group_tolerance (int): The line group tolerance threshold. + Defaults to -1, which means to disable the line group method. device (str | bool): The device to load the model. Defaults to True, which means automatically select device. **read_args: Other keyword arguments for read text. Please check the `EasyOCR docs `_. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='OCR', - description='This tool can recognize all text on the input image.', - inputs=['image'], - outputs=['text'], - ) + + default_desc = 'This tool can recognize all text on the input image.' @require('easyocr') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, lang: Union[str, Sequence[str]] = 'en', + line_group_tolerance: int = -1, device: Union[bool, str] = True, - line_group_tolerance = -1, + toolmeta=None, **read_args): - super().__init__(toolmeta=toolmeta, parser=parser) + super().__init__(toolmeta=toolmeta) if isinstance(lang, str): lang = [lang] self.lang = list(lang) @@ -56,32 +49,52 @@ def setup(self): self._reader: easyocr.Reader = load_or_build_object( easyocr.Reader, self.lang, gpu=self.device) - def apply(self, image: ImageIO) -> str: + def apply( + self, + image: ImageIO, + ) -> Annotated[str, + Info('OCR results, include bbox in x1, y1, x2, y2 format ' + 'and the recognized text.')]: image = image.to_array() + results = self._reader.readtext(image, detail=1, **self.read_args) + results = [(self.extract_bbox(item[0]), item[1]) for item in results] + if self.line_group_tolerance >= 0: - results = self._reader.readtext(image, **self.read_args) - results.sort(key=lambda x: x[0][0][1]) + results.sort(key=lambda x: x[0][1]) + + groups = [] + group = [] - lines = [] - line = [results[0]] + for item in results: + if not group: + group.append(item) + continue - for result in results[1:]: - if abs(result[0][0][1] - line[0][0][0][1]) <= self.line_group_tolerance: - line.append(result) + if abs(item[0][1] - group[-1][0][1]) <= self.line_group_tolerance: + group.append(item) else: - lines.append(line) - line = [result] + groups.append(group) + group = [item] - lines.append(line) + groups.append(group) - ocr_results = [] - for line in lines: + results = [] + for group in groups: # For each line, sort the elements by their left x-coordinate and join their texts - sorted_line = sorted(line, key=lambda x: x[0][0][0]) - text_line = ' '.join(item[1] for item in sorted_line) - ocr_results.append(text_line) - else: - ocr_results = self._reader.readtext(image, detail=0, **self.read_args) - outputs = '\n'.join(ocr_results) + line = sorted(group, key=lambda x: x[0][0]) + bboxes = [item[0] for item in line] + text = ' '.join(item[1] for item in line) + results.append((self.extract_bbox(bboxes), text)) + + outputs = [] + for item in results: + outputs.append('({}, {}, {}, {}) {}'.format(*item[0], item[1])) + outputs = '\n'.join(outputs) return outputs + + @staticmethod + def extract_bbox(char_boxes) -> Tuple[int, int, int, int]: + xs = [int(box[0]) for box in char_boxes] + ys = [int(box[1]) for box in char_boxes] + return min(xs), min(ys), max(xs), max(ys) diff --git a/agentlego/tools/remote.py b/agentlego/tools/remote.py index 31b1e055..e8d79864 100644 --- a/agentlego/tools/remote.py +++ b/agentlego/tools/remote.py @@ -1,110 +1,245 @@ import base64 -from io import BytesIO -from typing import Dict, List, Optional, Union -from urllib.parse import urljoin +from io import BytesIO, IOBase +from typing import Any, Dict, List, Optional +from urllib.parse import urljoin, urlsplit import requests from agentlego.parsers import DefaultParser -from agentlego.schema import Parameter, ToolMeta +from agentlego.schema import Parameter from agentlego.tools.base import BaseTool -from agentlego.types import AudioIO, ImageIO -from agentlego.utils import temp_path +from agentlego.types import AudioIO, File, ImageIO +from agentlego.utils.openapi import (APIOperation, APIResponseProperty, OpenAPISpec, + operation_toolmeta) class RemoteTool(BaseTool): + """Create a tool from an OpenAPI Specification (OAS). + + It supports `OpenAPI v3.1.0 `_ + + Examples: + 1. Construct a series of tools from an OAS. + + .. code::python + from agentlego.tools import RemoteTool + + tools = RemoteTool.from_openapi('http://localhost:16180/openapi.json') + + In this situation, you need to provide the path or URL of an OAS, and each + method will be constructed as a tool. + + 2. Construct a single tool from URL. + + .. code::python + from agentlego.tools import RemoteTool + + tool = RemoteTool.from_url('http://localhost:16180/ImageDescription') + + In this situation, you need to provide the URL of the tool endpoint. + By default, it will get the OAS from ``http://localhost:16180/openapi.json``, + and use the operation ``post`` at path ``/ImageDescription`` to construct the + tool. + + Notice: + The ``RemoteTool`` works well with the ``agentlego-server``. + """ # noqa: E501 def __init__( self, - url, - toolmeta: Union[dict, ToolMeta, None] = None, - parameters: Optional[Dict[str, Parameter]] = None, - parser=DefaultParser, + operation: APIOperation, + headers: Optional[dict] = None, + auth: Optional[tuple] = None, + toolkit: Optional[str] = None, ): - if not url.endswith('/'): - url += '/' - self.url = url - - if toolmeta is None or parameters is None: - toolmeta, parameters = self.request_meta() - - self._parameters = parameters - super().__init__(toolmeta, parser) - - def request_meta(self): - url = urljoin(self.url, 'meta') - response = requests.get(url).json() - toolmeta = response['toolmeta'] - parameters = { - p['name']: Parameter(**p) - for p in response['parameters'] - } - return toolmeta, parameters + self.operation = operation + self.url = urljoin(operation.base_url, operation.path) + self.headers = headers + self.auth = auth + self.method = operation.method.name + self.toolmeta = operation_toolmeta(operation) + self.toolkit = toolkit + self.set_parser(DefaultParser) + self._is_setup = False + + def _construct_path(self, kwargs: Dict[str, str]) -> str: + """Construct url according to path parameters from inputs.""" + path = self.url + for param in self.operation.path_params: + path = path.replace(f'{{{param}}}', str(kwargs.pop(param, ''))) + return path + + def _construct_query(self, kwargs: Dict[str, str]) -> Dict[str, str]: + """Construct query parameters from inputs.""" + query_params = {} + for param in self.operation.query_params: + if param in kwargs: + query_params[param] = kwargs.pop(param) + return query_params + + def _construct_body(self, kwargs: Dict[str, str]) -> Dict[str, Any]: + """Construct request body parameters from inputs.""" + if not self.operation.request_body or not self.operation.body_params: + return {} + + media_type = self.operation.request_body.media_type + + body = {} + for param in self.operation.body_params: + if param in kwargs: + value = kwargs.pop(param) + if isinstance(value, (ImageIO, AudioIO, File)): + value = value.to_file() + body[param] = value + + if media_type == 'multipart/form-data': + body = { + k: (k, v) if isinstance(v, IOBase) else (None, v) + for k, v in body.items() + } + return {'files': body} + elif media_type == 'application/json': + return {'json': body} + elif media_type == 'application/x-www-form-urlencoded': + return {'data': body} + else: + raise NotImplementedError(f'Unsupported media type `{media_type}`') + + @staticmethod + def _parse_output(out: Any, p: Parameter): + if p.type is ImageIO: + file = BytesIO(base64.b64decode(out)) + out = ImageIO.from_file(file) + elif p.type is AudioIO: + file = BytesIO(base64.b64decode(out)) + out = AudioIO.from_file(file) + elif p.type is File: + file = BytesIO(base64.b64decode(out)) + out = File.from_file(file, filetype=p.filetype) + return out def apply(self, *args, **kwargs): - for arg, arg_name in zip(args, self.parameters): - kwargs[arg_name] = arg - - form = {} - for k, v in kwargs.items(): - if isinstance(v, (ImageIO, AudioIO)): - file = v.to_path() - form[k] = (file, open(file, 'rb')) - else: - form[k] = (None, v) + for arg, p in zip(args, self.inputs): + kwargs[p.name] = arg + + request_args = { + 'url': self._construct_path(kwargs), + 'params': self._construct_query(kwargs), + **self._construct_body(kwargs) + } - url = urljoin(self.url, 'call') try: - response = requests.post(url, files=form).json() + response = requests.request( + method=self.method, + **request_args, + headers=self.headers, + auth=self.auth, + ) except requests.ConnectionError as e: raise ConnectionError( f'Failed to connect the remote tool `{self.name}`.') from e - except requests.JSONDecodeError: - raise RuntimeError('Unexcepted server response.') - - if isinstance(response, dict): - if 'error' in response: - # Tool internal error - raise RuntimeError(response['error']) - elif 'detail' in response: - # FastAPI validation error - msg = response['detail']['msg'] - err_type = response['detail']['type'] - raise ValueError(f'{err_type}({msg})') - - parsed_res = [] - for res in response: - if not isinstance(res, dict): - data = res - elif res['type'] == 'image': - from PIL import Image - file = BytesIO(base64.decodebytes(res['data'].encode('ascii'))) - data = ImageIO(Image.open(file)) - elif res['type'] == 'audio': - filename = temp_path('audio', '.wav') - with open(filename, 'wb') as f: - f.write(base64.decodebytes(res['data'].encode('ascii'))) - data = AudioIO(filename) - parsed_res.append(data) - - return parsed_res[0] if len(parsed_res) == 1 else tuple(parsed_res) + if response.status_code != 200: + if response.headers.get('Content-Type') == 'application/json': + content = response.json() + else: + content = response.content.decode() + raise RuntimeError(f'Failed to call the remote tool `{self.name}` ' + f'because of {response.reason}.\nResponse: {content}') + try: + response = response.json() + except requests.JSONDecodeError as e: + raise RuntimeError(f'Failed to call the remote tool `{self.name}` ' + 'because of unknown response.\n' + f'Response: {response.content.decode()}') from e + + response_schema = self.operation.responses + if response_schema is None or response_schema.get('200') is None: + # Directly use string if the response schema is not specified + return str(response) + + out_props = response_schema['200'].properties + + if isinstance(out_props, APIResponseProperty): + # Single output + return self._parse_output(response, self.outputs[0]) + elif isinstance(out_props, list): + # Multiple output + return tuple( + self._parse_output(out, p) for out, p in zip(response, self.outputs)) + else: + # Dict-style output + return { + p.name: self._parse_output(out, p) + for out, p in zip(response, self.outputs) + } + + @classmethod + def from_server(cls, url: str, **kwargs) -> List['RemoteTool']: + return cls.from_openapi(url=urljoin(url, '/openapi.json'), **kwargs) @classmethod - def from_server(cls, url: str) -> List['RemoteTool']: - response = requests.get(url).json() + def from_openapi( + cls, + url: str, + **kwargs, + ) -> List['RemoteTool']: + """Construct a series of remote tools from the specified OpenAPI + Specification. + + Args: + url (str): The path or URL of the OpenAPI Specification file. + headers (str | None): The headers to send in the requests. Defaults to None. + auth (tuple | None): Auth tuple to enable Basic/Digest/Custom HTTP Auth. + Defaults to None. + """ + if url.startswith('http'): + spec = OpenAPISpec.from_url(url) + else: + spec = OpenAPISpec.from_file(url) + + toolkit = spec.info.title.replace(' ', '_') + tools = [] - for tool_info in response: + for path, method in spec.iter_all_method(): + operation = APIOperation.from_openapi_spec(spec, path, method) tool = cls( - url=urljoin(url, tool_info['domain'] + '/'), - toolmeta=tool_info['toolmeta'], - parameters={ - p['name']: Parameter(**p) - for p in tool_info['parameters'] - }, + operation=operation, + toolkit=toolkit, + **kwargs, ) tools.append(tool) return tools - @property - def parameters(self) -> Dict[str, Parameter]: - return self._parameters + @classmethod + def from_url(cls, + url: str, + method: str = 'post', + openapi: Optional[str] = None, + path: Optional[str] = None, + **kwargs) -> 'RemoteTool': + """Construct a remote tool from the specified URL endpoint. + + Args: + url (str): The URL path of the remote tool. + method (str): The method of the operation. Defaults to 'post', + openapi (str | None): The OAS path or URL. Defaults to None, which means to + use ``/openapi.json``. + path (str | None): The path in the OAS. Defaults to None, which means to use + ``urlsplit(url).path``. + headers (str | None): The headers to send in the requests. Defaults to None. + auth (tuple | None): Auth tuple to enable Basic/Digest/Custom HTTP Auth. + Defaults to None. + """ + # The default openapi file for the tool server. + openapi = openapi or urljoin(url, '/openapi.json') + path = path or urlsplit(url).path + + if openapi.startswith('http'): + spec = OpenAPISpec.from_url(openapi) + else: + spec = OpenAPISpec.from_file(openapi) + + toolkit = spec.info.title.replace(' ', '_') + + operation = APIOperation.from_openapi_spec(spec, path, method) + return cls(operation=operation, toolkit=toolkit, **kwargs) diff --git a/agentlego/tools/search/google.py b/agentlego/tools/search/google.py index d601ffb3..060bf47d 100644 --- a/agentlego/tools/search/google.py +++ b/agentlego/tools/search/google.py @@ -1,10 +1,8 @@ import os -from typing import Callable, List, Tuple, Union +from typing import List, Tuple, Union import requests -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from ..base import BaseTool @@ -19,10 +17,6 @@ class GoogleSearch(BaseTool): To get an Serper.dev API key. you can create it at https://serper.dev Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. api_key (str): API key to use for serper google search API. Defaults to 'env', which means to use the `SERPER_API_KEY` in the environ variable. @@ -37,6 +31,8 @@ class GoogleSearch(BaseTool): Defaults to False. k (int): select first k results in the search results as response. Defaults to 10. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ result_key_for_type = { @@ -46,31 +42,24 @@ class GoogleSearch(BaseTool): 'search': 'organic', } - DEFAULT_TOOLMETA = ToolMeta( - name='GoogleSearch', - description=('The tool can search the input query text from Google ' - 'and return the related results'), - inputs=['text'], - outputs=['text'], - ) + default_desc = ('The tool can search the input query text from Google ' + 'and return the related results.') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, api_key: str = 'env', timeout: int = 5, search_type: str = 'search', max_out_len: int = 1500, with_url: bool = False, - k: int = 10) -> None: - super().__init__(toolmeta=toolmeta, parser=parser) + k: int = 10, + toolmeta=None) -> None: + super().__init__(toolmeta=toolmeta) if api_key == 'env': api_key = os.environ.get('SERPER_API_KEY', None) if not api_key: - raise ValueError( - 'Please set Serper API key either in the environment ' - ' as SERPER_API_KEY or pass it as `api_key` parameter.') + raise ValueError('Please set Serper API key either in the environment ' + ' as SERPER_API_KEY or pass it as `api_key` parameter.') self.api_key = api_key self.timeout = timeout @@ -123,13 +112,11 @@ def _parse_results(self, results: dict) -> Union[str, List[str]]: if kg.get('description'): content += kg['description'] if kg.get('attributes'): - attributes = ', '.join(f'{k}: {v}' - for k, v in kg['attributes'].items()) + attributes = ', '.join(f'{k}: {v}' for k, v in kg['attributes'].items()) content += f'({attributes})' snippets.append(content) - for item in results[self.result_key_for_type[ - self.search_type]][:self.k]: + for item in results[self.result_key_for_type[self.search_type]][:self.k]: content = '' if item.get('title'): content += item['title'] + ': ' diff --git a/agentlego/tools/segmentation/segment_anything.py b/agentlego/tools/segmentation/segment_anything.py index 030ab258..9ef56976 100644 --- a/agentlego/tools/segmentation/segment_anything.py +++ b/agentlego/tools/segmentation/segment_anything.py @@ -1,17 +1,13 @@ import random from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple -import cv2 import numpy as np from PIL import Image -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import (download_checkpoint, download_url_to_file, - is_package_available, load_or_build_object, - require) + is_package_available, load_or_build_object, require) from ..base import BaseTool if is_package_available('torch'): @@ -96,8 +92,8 @@ def set_image( # Transform the image to the form expected by the model input_image = self.transform.apply_image(image) input_image_torch = torch.as_tensor(input_image, device=self.device) - input_image_torch = input_image_torch.permute( - 2, 0, 1).contiguous()[None, :, :, :] + input_image_torch = input_image_torch.permute(2, 0, + 1).contiguous()[None, :, :, :] return self.set_torch_image(input_image_torch, image.shape[:2]) @@ -116,8 +112,7 @@ def set_torch_image( original_image_size (tuple(int, int)): The size of the image before transformation, in (H, W) format. """ - assert (len(transformed_image.shape) == 4 - and transformed_image.shape[1] == 3 + assert (len(transformed_image.shape) == 4 and transformed_image.shape[1] == 3 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size), ( 'set_torch_image input must be BCHW with long side' @@ -186,21 +181,18 @@ def predict( coords_torch, labels_torch = None, None box_torch, mask_input_torch = None, None if point_coords is not None: - assert ( - point_labels is not None - ), 'point_labels must be supplied if point_coords is supplied.' - point_coords = self.transform.apply_coords( - point_coords, features['original_size']) + assert (point_labels is not None + ), 'point_labels must be supplied if point_coords is supplied.' + point_coords = self.transform.apply_coords(point_coords, + features['original_size']) coords_torch = torch.as_tensor( point_coords, dtype=torch.float, device=self.device) labels_torch = torch.as_tensor( point_labels, dtype=torch.int, device=self.device) - coords_torch, labels_torch = coords_torch[ - None, :, :], labels_torch[None, :] + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] if box is not None: box = self.transform.apply_boxes(box, features['original_size']) - box_torch = torch.as_tensor( - box, dtype=torch.float, device=self.device) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) box_torch = box_torch[None, :] if mask_input is not None: mask_input_torch = torch.as_tensor( @@ -293,8 +285,7 @@ def predict_torch( ) # Upscale the masks to the original image resolution - masks = self.model.postprocess_masks(low_res_masks, - features['input_size'], + masks = self.model.postprocess_masks(low_res_masks, features['input_size'], features['original_size']) if not return_logits: @@ -314,30 +305,23 @@ class SegmentAnything(BaseTool): """A tool to segment all objects on an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. sam_model (str): The model name used to inference. Which can be found in the ``segment_anything`` repository. Defaults to ``sam_vit_h_4b8939.pth``. - device (str): The device to load the model. Defaults to 'cpu'. + device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='SegmentAnything', - description='This tool can segment all items in the image and ' - 'return a segmentation result image', - inputs=['image'], - outputs=['image'], - ) + + default_desc = ('This tool can segment all items in the image and ' + 'return a segmentation result image.') @require('segment_anything') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, sam_model: str = 'sam_vit_h_4b8939.pth', - device: str = 'cpu'): - super().__init__(toolmeta=toolmeta, parser=parser) + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.sam_model = sam_model self.device = device @@ -345,20 +329,17 @@ def setup(self): self.sam, self.sam_predictor = load_sam_and_predictor( self.sam_model, device=self.device) - def apply(self, image: ImageIO) -> ImageIO: - image = image.to_path() - annos = self.segment_anything(image) + def apply(self, image: ImageIO + ) -> Annotated[ImageIO, Info('The segmentation result image.')]: + annos = self.segment_anything(image.to_array()) full_img, _ = self.show_annos(annos) return ImageIO(full_img) - def segment_anything(self, img_path): + def segment_anything(self, img): if not self._is_setup: self.setup() self._is_setup = True - img = cv2.imread(img_path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - from segment_anything import SamAutomaticMaskGenerator mask_generator = SamAutomaticMaskGenerator(self.sam) @@ -386,8 +367,8 @@ def segment_by_mask(self, mask, features): return res_masks[np.argmax(scores), :, :] - def get_detection_map(self, img_path): - annos = self.segment_anything(img_path) + def get_detection_map(self, img): + annos = self.segment_anything(img) _, detection_map = self.show_anns(annos) return detection_map @@ -433,10 +414,6 @@ class SegmentObject(BaseTool): """A tool to segment all objects on an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. sam_model (str): The model name used to inference. Which can be found in the ``segment_anything`` repository. Defaults to ``sam_vit_h_4b8939.pth``. @@ -444,26 +421,23 @@ class SegmentObject(BaseTool): Which can be found in the ``MMDetection`` repository. Defaults to ``glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365``. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='SegmentSpecifiedObject', - description=('This tool can segment the specified kind of ' - 'objects in the input image, and return the ' - 'segmentation result image.'), - inputs=['image', 'text'], - outputs=['image'], - ) + + default_desc = ('This tool can segment the specified kind of objects in ' + 'the input image, and return the segmentation ' + 'result image.') @require('segment_anything') @require('mmdet>=3.1.0') - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - sam_model: str = 'sam_vit_h_4b8939.pth', - grounding_model: str = ( - 'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365'), - device: str = 'cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__( + self, + sam_model: str = 'sam_vit_h_4b8939.pth', + grounding_model: str = ('glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365'), + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.sam_model = sam_model self.grounding_model = grounding_model self.device = device @@ -477,7 +451,11 @@ def setup(self): self.sam, self.sam_predictor = load_sam_and_predictor( self.sam_model, device=self.device) - def apply(self, image: ImageIO, text: str) -> ImageIO: + def apply( + self, + image: ImageIO, + text: Annotated[str, Info('The object to segment.')], + ) -> Annotated[ImageIO, Info('The segmentation result image.')]: results = self.grounding( inputs=image.to_array()[:, :, ::-1], # Input BGR @@ -489,8 +467,8 @@ def apply(self, image: ImageIO, text: str) -> ImageIO: boxes_filt = results.bboxes pred_phrases = results.label_names - output_image = self.segment_image_with_boxes(image.to_array(), - boxes_filt, pred_phrases) + output_image = self.segment_image_with_boxes(image.to_array(), boxes_filt, + pred_phrases) return ImageIO(output_image) def get_mask_with_boxes(self, image, boxes_filt): @@ -523,10 +501,7 @@ def segment_image_with_boxes(self, image, boxes_filt, pred_phrases): # draw output image for mask in masks: image = self.show_mask( - mask[0].cpu().numpy(), - image, - random_color=True, - transparency=0.3) + mask[0].cpu().numpy(), image, random_color=True, transparency=0.3) return image @@ -546,6 +521,7 @@ def show_mask(self, visualized on top of the image. transparenccy: the transparency of the segmentation mask """ + import cv2 if random_color: color = np.concatenate([np.random.random(3)], axis=0) @@ -554,8 +530,7 @@ def show_mask(self, h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255 - image = cv2.addWeighted(image, 0.7, mask_image.astype('uint8'), - transparency, 0) + image = cv2.addWeighted(image, 0.7, mask_image.astype('uint8'), transparency, 0) return image def show_box(self, box, ax, label): diff --git a/agentlego/tools/segmentation/semantic_segmentation.py b/agentlego/tools/segmentation/semantic_segmentation.py index 381acd27..39da2785 100644 --- a/agentlego/tools/segmentation/semantic_segmentation.py +++ b/agentlego/tools/segmentation/semantic_segmentation.py @@ -1,7 +1,3 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO from agentlego.utils import load_or_build_object, require from ..base import BaseTool @@ -11,31 +7,24 @@ class SemanticSegmentation(BaseTool): """A tool to conduct semantic segmentation on an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. seg_model (str): The model name used to inference. Which can be found in the ``MMSegmentation`` repository. Defaults to ``mask2former_r50_8xb2-90k_cityscapes-512x1024``. - device (str): The device to load the model. Defaults to 'cpu'. + device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='SemanticSegmentOnUrbanScene', - description='This tool can segment all items in the input image and ' - 'return a segmentation result image. It focus on urban scene images.', - inputs=['image'], - outputs=['image'], - ) + + default_desc = ('This tool can segment all items in the input image and ' + 'return a segmentation result image. ' + 'It focus on urban scene images.') @require('mmsegmentation') - def __init__( - self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - seg_model: str = 'mask2former_r50_8xb2-90k_cityscapes-512x1024', - device: str = 'cpu'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, + seg_model: str = 'mask2former_r50_8xb2-90k_cityscapes-512x1024', + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.seg_model = seg_model self.device = device diff --git a/agentlego/tools/speech_text/speech_to_text.py b/agentlego/tools/speech_text/speech_to_text.py index d17df1f5..b4eaca74 100644 --- a/agentlego/tools/speech_text/speech_to_text.py +++ b/agentlego/tools/speech_text/speech_to_text.py @@ -1,11 +1,5 @@ -from typing import Callable, Union - -from mmengine.utils import apply_to - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import AudioIO -from agentlego.utils import is_package_available, load_or_build_object, require +from agentlego.utils import apply_to, is_package_available, load_or_build_object, require from ..base import BaseTool if is_package_available('torch'): @@ -15,11 +9,9 @@ import torchaudio -@require('torchaudio') def resampling_audio(audio: AudioIO, new_rate): tensor, ori_sampling_rate = audio.to_tensor(), audio.sampling_rate - tensor = torchaudio.functional.resample(tensor, ori_sampling_rate, - new_rate) + tensor = torchaudio.functional.resample(tensor, ori_sampling_rate, new_rate) return AudioIO(tensor, sampling_rate=new_rate) @@ -27,38 +19,25 @@ class SpeechToText(BaseTool): """A tool to recognize speech and convert to text. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to inference. Which can be found in the ``HuggingFace`` model page. Defaults to ``openai/whisper-base``. device (str): The device to load the model. Defaults to 'cpu'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='Transcriber', - description='This is a tool that transcribes an audio into text.', - inputs=['audio'], - outputs=['text'], - ) + default_desc = 'The tool can translate spoken language audio into text.' - @require(('torch', 'transformers')) - def __init__( - self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - model='openai/whisper-base', - device='cuda', - ): - super().__init__(toolmeta, parser) + @require(('torch', 'transformers', 'torchaudio')) + def __init__(self, model='openai/whisper-base', device='cuda', toolmeta=None): + super().__init__(toolmeta) self.model_name = model self.device = device def setup(self) -> None: - from transformers.models.whisper import ( - WhisperForConditionalGeneration, WhisperProcessor) + from transformers.models.whisper import (WhisperForConditionalGeneration, + WhisperProcessor) self.processor = load_or_build_object(WhisperProcessor.from_pretrained, self.model_name) self.model = load_or_build_object( @@ -73,11 +52,9 @@ def apply(self, audio: AudioIO) -> str: audio.to_tensor().numpy().reshape(-1), return_tensors='pt', sampling_rate=target_sampling_rate).input_features - encoded_inputs = apply_to(encoded_inputs, - lambda x: isinstance(x, torch.Tensor), + encoded_inputs = apply_to(encoded_inputs, lambda x: isinstance(x, torch.Tensor), lambda x: x.to(self.device)) outputs = self.model.generate(inputs=encoded_inputs) outputs = apply_to(outputs, lambda x: isinstance(x, torch.Tensor), lambda x: x.to('cpu')) - return self.processor.batch_decode( - outputs, skip_special_tokens=True)[0] + return self.processor.batch_decode(outputs, skip_special_tokens=True)[0] diff --git a/agentlego/tools/speech_text/text_to_speech.py b/agentlego/tools/speech_text/text_to_speech.py index 6983754f..9d0fd28e 100644 --- a/agentlego/tools/speech_text/text_to_speech.py +++ b/agentlego/tools/speech_text/text_to_speech.py @@ -1,62 +1,68 @@ from io import BytesIO -from typing import Callable, Union +from typing import Union import requests -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import AudioIO +from agentlego.types import Annotated, AudioIO, Info from agentlego.utils import is_package_available, require from ..base import BaseTool if is_package_available('torch'): import torch +LANG_CODES = { + 'zh-cn': 'Chinese', + 'en': 'English', + 'es': 'Spanish', + 'fr': 'French', + 'de': 'German', + 'it': 'Italian', + 'tr': 'Turkish', + 'ru': 'Russian', + 'ar': 'Arabic', + 'ja': 'Japanese', + 'ko': 'Korean', + # "pt": "Portuguese", + # "pl": "Polish", + # "nl": "Dutch", + # "cs": "Czech", + # "hu": "Hungarian", + # "hi": "Hindi", +} + class TextToSpeech(BaseTool): """A tool to convert input text to speech audio. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. model (str): The model name used to inference. Which can be found - in the ``HuggingFace`` model page. - Defaults to ``microsoft/speecht5_tts``. - post_processor (str): The post-processor of the output audio. - Defaults to ``microsoft/speecht5_hifigan``. + in https://github.com/coqui-ai/TTSHuggingFace . + Defaults to ``tts_models/multilingual/multi-dataset/xtts_v2``. speaker_embeddings (str | dict): The speaker embedding - of the Speech-T5 model. Defaults to an embedding from - ``Matthijs/speecht5-tts-demo``. + of the TTS model. Defaults to a default embedding. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - SAMPLING_RATE = 16000 - DEFAULT_TOOLMETA = ToolMeta( - name='TextReader', - description='This is a tool that can speak the input text into audio.', - inputs=['text'], - outputs=['audio'], - ) + + SPEAKER_EMBEDDING = ('http://download.openmmlab.com/agentlego/default_voice.pth') + default_desc = ('The tool can speak the input text into audio. The language code ' + 'should be one of ' + + ', '.join(f"'{k}' ({v})" for k, v in LANG_CODES.items()) + '.') @require('TTS', 'langid') def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, model: str = 'tts_models/multilingual/multi-dataset/xtts_v2', - speaker_embeddings: Union[str, dict] = ( - 'http://download.openmmlab.com/agentlego/' - 'default_voice.pth'), - sampling_rate=16000, - device='cuda'): - super().__init__(toolmeta=toolmeta, parser=parser) + speaker_embeddings: Union[str, dict] = SPEAKER_EMBEDDING, + device='cuda', + toolmeta=None): + super().__init__(toolmeta=toolmeta) self.model_name = model if isinstance(speaker_embeddings, str): with BytesIO(requests.get(speaker_embeddings).content) as f: speaker_embeddings = torch.load(f, map_location=device) self.speaker_embeddings = speaker_embeddings - self.sampling_rate = sampling_rate self.device = device def setup(self) -> None: @@ -65,16 +71,20 @@ def setup(self) -> None: self.model = TTS(self.model_name).to(self.device).synthesizer.tts_model self.model: Xtts - def apply(self, text: str) -> AudioIO: - import langid - langid.set_languages([ - lang if lang != 'zh-cn' else 'zh' - for lang in self.model.config.languages - ]) - lang = langid.classify(text)[0] - lang = 'zh-cn' if lang == 'zh' else lang - text = text.replace(',', ', ').replace('。', '. ').replace( - '?', '? ').replace('!', '! ').replace('、', ', ').strip() + def apply( + self, + text: str, + lang: Annotated[str, Info('The language code of text.')] = 'auto', + ) -> AudioIO: + if lang == 'auto': + import langid + langid.set_languages( + [lang if lang != 'zh-cn' else 'zh' for lang in LANG_CODES]) + lang = langid.classify(text)[0] + lang = 'zh-cn' if lang == 'zh' else lang + + text = text.replace(',', ', ').replace('。', '. ').replace('?', '? ').replace( + '!', '! ').replace('、', ', ').strip() out = self.model.inference( text, language=lang, @@ -83,5 +93,4 @@ def apply(self, text: str) -> AudioIO: **self.speaker_embeddings, ) - return AudioIO( - torch.tensor(out['wav']).unsqueeze(0), sampling_rate=24000) + return AudioIO(torch.tensor(out['wav']).unsqueeze(0), sampling_rate=24000) diff --git a/agentlego/tools/translation/translation.py b/agentlego/tools/translation/translation.py index a100bf65..5a980b74 100644 --- a/agentlego/tools/translation/translation.py +++ b/agentlego/tools/translation/translation.py @@ -1,14 +1,11 @@ -from typing import Callable, Union from urllib.parse import quote_plus import requests -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta +from agentlego.types import Annotated, Info from ..base import BaseTool LANG_CODES = { - 'auto': 'Detect source language', 'zh-CN': 'Chinese', 'en': 'English', 'fr': 'French', @@ -27,38 +24,31 @@ class Translation(BaseTool): - DEFAULT_TOOLMETA = ToolMeta( - name='Translation', - description='This tool can translate a text from source language to ' - 'the target language. The source_lang and target_lang can be one of ' + - ', '.join(f"'{k}' ({v})" for k, v in LANG_CODES.items()) + '.', - inputs=['text', 'text', 'text'], - outputs=['text'], - ) + default_desc = ('This tool can translate a text from source language to ' + 'the target language. The language code should be one of ' + + ', '.join(f"'{k}' ({v})" for k, v in LANG_CODES.items()) + '.') - PROMPT = ('translate {source_lang} to {target_lang}: {input}') - - def __init__(self, - toolmeta: Union[dict, ToolMeta] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, - backend: str = 'google'): - super().__init__(toolmeta=toolmeta, parser=parser) + def __init__(self, backend: str = 'google', toolmeta=None): + super().__init__(toolmeta=toolmeta) if backend == 'google': self._translate = self.google_translate else: - raise NotImplementedError( - f'The backend {backend} is not available.') + raise NotImplementedError(f'The backend {backend} is not available.') def apply(self, text: str, source_lang: str, target_lang: str) -> str: return self._translate(text, source_lang, target_lang) - def google_translate(self, text: str, source: str, target: str): + def google_translate( + self, + text: Annotated[str, Info('The text to translate.')], + target: Annotated[str, Info('The target language code.')], + source: Annotated[str, Info('The source language code.')] = 'auto', + ) -> str: text = quote_plus(text) url_tmpl = ('https://translate.googleapis.com/translate_a/' 'single?client=gtx&sl={}&tl={}&dt=at&dt=bd&dt=ex&' 'dt=ld&dt=md&dt=qca&dt=rw&dt=rm&dt=ss&dt=t&q={}') - response = requests.get( - url_tmpl.format(source, target, text), timeout=10).json() + response = requests.get(url_tmpl.format(source, target, text), timeout=10).json() try: result = ''.join(x[0] for x in response[0] if x[0] is not None) except Exception: diff --git a/agentlego/tools/utils/diffusers.py b/agentlego/tools/utils/diffusers.py index b268cab5..05bf81d3 100644 --- a/agentlego/tools/utils/diffusers.py +++ b/agentlego/tools/utils/diffusers.py @@ -12,8 +12,7 @@ def load_sd(model: str = 'runwayml/stable-diffusion-v1-5', device=None): import torch from diffusers import (AutoencoderKL, ControlNetModel, - StableDiffusionControlNetPipeline, - StableDiffusionPipeline) + StableDiffusionControlNetPipeline, StableDiffusionPipeline) dtype = torch.float16 if 'cuda' in str(device) else torch.float32 params = {'torch_dtype': dtype} @@ -29,8 +28,7 @@ def load_sd(model: str = 'runwayml/stable-diffusion-v1-5', ) params['vae'] = vae - t2i = load_or_build_object(StableDiffusionPipeline.from_pretrained, model, - **params) + t2i = load_or_build_object(StableDiffusionPipeline.from_pretrained, model, **params) if controlnet is None: return t2i.to(device) @@ -41,8 +39,7 @@ def load_sd(model: str = 'runwayml/stable-diffusion-v1-5', torch_dtype=dtype, variant=controlnet_variant, ) - pipe = StableDiffusionControlNetPipeline( - **t2i.components, controlnet=controlnet) + pipe = StableDiffusionControlNetPipeline(**t2i.components, controlnet=controlnet) return pipe.to(device) @@ -71,8 +68,8 @@ def load_sdxl(model: str = 'stabilityai/stable-diffusion-xl-base-1.0', ) params['vae'] = vae - t2i = load_or_build_object(StableDiffusionXLPipeline.from_pretrained, - model, **params) + t2i = load_or_build_object(StableDiffusionXLPipeline.from_pretrained, model, + **params) if controlnet is None: return t2i.to(device) diff --git a/agentlego/tools/utils/parameters.py b/agentlego/tools/utils/parameters.py new file mode 100644 index 00000000..7926dedf --- /dev/null +++ b/agentlego/tools/utils/parameters.py @@ -0,0 +1,111 @@ +import copy +import inspect +from typing import Callable, Optional, Tuple + +from typing_extensions import Annotated, get_args, get_origin + +from agentlego.schema import Parameter, ToolMeta +from agentlego.types import CatgoryToIO + + +def get_input_parameters(func: Callable) -> Tuple[Parameter, ...]: + inputs = [] + for p in inspect.signature(func).parameters.values(): + if p.name == 'self': + continue + + annotation = p.annotation + info = None + if get_origin(annotation) is Annotated: + for item in get_args(annotation): + if isinstance(item, Parameter): + info = item + annotation = get_args(annotation)[0] + + input_ = Parameter( + name=p.name, + type=annotation, + optional=p.default is not inspect._empty, + default=p.default if p.default is not inspect._empty else None, + ) + if info is not None: + input_.update(info) + inputs.append(input_) + return tuple(inputs) + + +def get_output_parameters(func: Callable) -> Optional[Tuple[Parameter, ...]]: + outputs = [] + return_ann = inspect.signature(func).return_annotation + if return_ann is inspect._empty: + return None + elif get_origin(return_ann) is tuple: + annotations = get_args(return_ann) + assert len(annotations) > 1 and Ellipsis not in annotations, ( + f'The number of outputs of `{func.__qualname__}` ' + 'is undefined. Please specify like `Tuple[int, int, str]`') + else: + annotations = (return_ann, ) + + for annotation in annotations: + info = None + if get_origin(annotation) is Annotated: + for item in get_args(annotation): + if isinstance(item, Parameter): + info = item + annotation = get_args(annotation)[0] + + output = Parameter(type=annotation) + if info is not None: + output.update(info) + outputs.append(output) + return tuple(outputs) + + +def extract_toolmeta(func: Callable, override: Optional[ToolMeta] = None) -> ToolMeta: + supported_types = set(CatgoryToIO.values()) + + inputs = get_input_parameters(func) + if override is not None and override.inputs is not None: + assert len(inputs) == len( + override.inputs), ('The length of `inputs` in toolmeta is different with ' + f'the number of arguments of `{func.__qualname__}`.') + for input_, new_input in zip(inputs, override.inputs): + input_.update(new_input) + for input_ in inputs: + assert input_.type is not inspect._empty, ( + f'The type of input `{input_.name}` of ' + f'`{func.__qualname__}` is not specified.') + assert input_.type in supported_types, ( + f'The type of input `{input_.name}` of {func.__qualname__}` ' + 'is not supported. Supported types are ' + + ', '.join(i.__name__ for i in supported_types)) + + outputs = get_output_parameters(func) + if outputs is None: + assert override is not None and override.outputs is not None, ( + f'The type of output of `{func.__qualname__}` is not specified.') + outputs = override.outputs + elif override is not None and override.outputs is not None: + assert len(outputs) == len(override.outputs), ( + 'The length of `outputs` in toolmeta is different with ' + f'the type hint of return value of `{func.__qualname__}`.') + for output, new_output in zip(outputs, override.outputs): + output.update(new_output) + for output in outputs: + assert output.type is not inspect._empty, ( + f'The type of output `{output.name}` of ' + f'`{func.__qualname__}` is not specified.') + assert output.type in supported_types, ( + f'The type of return value of {func.__qualname__}` ' + 'is not supported. Supported types are ' + + ', '.join(i.__name__ for i in supported_types)) + + if override: + toolmeta = copy.deepcopy(override) + toolmeta.inputs = tuple(inputs) + toolmeta.outputs = tuple(outputs) + else: + toolmeta = ToolMeta(inputs=inputs, outputs=outputs) + + return toolmeta diff --git a/agentlego/tools/vqa/README.md b/agentlego/tools/vqa/README.md index 24162161..a8c801dd 100644 --- a/agentlego/tools/vqa/README.md +++ b/agentlego/tools/vqa/README.md @@ -1,4 +1,4 @@ -# VisualQuestionAnswering +# VQA ## Examples @@ -8,7 +8,7 @@ from agentlego.apis import load_tool # load tool -tool = load_tool('VisualQuestionAnswering', device='cuda') +tool = load_tool('VQA', device='cuda') # apply tool answer = tool('examples/demo.png', 'What is the color of the cat?') @@ -23,7 +23,7 @@ from agentlego.apis import load_tool # load tools and build agent # please set `OPENAI_API_KEY` in your environment variable. -tool = load_tool('VisualQuestionAnswering', device='cuda').to_lagent() +tool = load_tool('VQA', device='cuda').to_lagent() agent = ReAct(GPTAPI(temperature=0.), action_executor=ActionExecutor([tool])) # agent running with the tool. diff --git a/agentlego/tools/vqa/__init__.py b/agentlego/tools/vqa/__init__.py index e75a258c..323de181 100644 --- a/agentlego/tools/vqa/__init__.py +++ b/agentlego/tools/vqa/__init__.py @@ -1,3 +1,3 @@ -from .visual_question_answering import VisualQuestionAnswering +from .visual_question_answering import VQA -__all__ = ['VisualQuestionAnswering'] +__all__ = ['VQA'] diff --git a/agentlego/tools/vqa/visual_question_answering.py b/agentlego/tools/vqa/visual_question_answering.py index 0adeed45..62090204 100644 --- a/agentlego/tools/vqa/visual_question_answering.py +++ b/agentlego/tools/vqa/visual_question_answering.py @@ -1,37 +1,26 @@ -from typing import Callable, Union - -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta -from agentlego.types import ImageIO +from agentlego.types import Annotated, ImageIO, Info from agentlego.utils import load_or_build_object, require from ..base import BaseTool -class VisualQuestionAnswering(BaseTool): +class VQA(BaseTool): """A tool to answer the question about an image. Args: - toolmeta (dict | ToolMeta): The meta info of the tool. Defaults to - the :attr:`DEFAULT_TOOLMETA`. - parser (Callable): The parser constructor, Defaults to - :class:`DefaultParser`. remote (bool): Whether to use the remote model. Defaults to False. device (str): The device to load the model. Defaults to 'cuda'. + toolmeta (None | dict | ToolMeta): The additional info of the tool. + Defaults to None. """ - DEFAULT_TOOLMETA = ToolMeta( - name='VQA', - description='This tool can answer the input question based on the ' - 'input image. The question should be in English.', - inputs=['image', 'text'], - outputs=['text']) + default_desc = ('This tool can answer the input question based on the ' + 'input image.') def __init__(self, - toolmeta: Union[ToolMeta, dict] = DEFAULT_TOOLMETA, - parser: Callable = DefaultParser, model: str = 'ofa-base_3rdparty-zeroshot_vqa', - device: str = 'cuda'): - super().__init__(toolmeta, parser) + device: str = 'cuda', + toolmeta=None): + super().__init__(toolmeta) self.device = device self.model = model @@ -42,10 +31,12 @@ def setup(self): with DefaultScope.overwrite_default_scope('mmpretrain'): self._inferencer = load_or_build_object( - VisualQuestionAnsweringInferencer, - model=self.model, - device=self.device) + VisualQuestionAnsweringInferencer, model=self.model, device=self.device) - def apply(self, image: ImageIO, text: str) -> str: + def apply( + self, + image: ImageIO, + question: Annotated[str, Info('The question should be in English.')], + ) -> str: image = image.to_array()[:, :, ::-1] - return self._inferencer(image, text)[0]['pred_answer'] + return self._inferencer(image, question)[0]['pred_answer'] diff --git a/agentlego/tools/wrappers/lagent.py b/agentlego/tools/wrappers/lagent.py index 3320f66a..e6695c5e 100644 --- a/agentlego/tools/wrappers/lagent.py +++ b/agentlego/tools/wrappers/lagent.py @@ -1,15 +1,25 @@ import copy -import json -import re -from collections import defaultdict from lagent.actions import BaseAction from lagent.schema import ActionReturn, ActionStatusCode from agentlego.parsers import DefaultParser +from agentlego.types import AudioIO, File, ImageIO from ..base import BaseTool +def convert_type(t): + if t in [str, ImageIO, AudioIO, File]: + return 'STRING' + elif t is int: + return 'NUMBER' + elif t is float: + return 'FLOAT' + elif t is bool: + return 'BOOLEAN' + return 'STRING' + + class LagentTool(BaseAction): """A wrapper to align with the interface of Lagent tools.""" @@ -18,48 +28,52 @@ def __init__(self, tool: BaseTool): tool.set_parser(DefaultParser) # Use string input & output self.tool = tool - example_args = ', '.join(f'"{name}": xxx' for name in tool.parameters) - description = (f'{tool.description} Combine all args to one json ' - f'string like {{{example_args}}}') + parameters = [] + required = [] + for p in tool.inputs: + parameters.append( + dict( + name=p.name, + description=p.description, + type=convert_type(p.type), + )) + if not p.optional: + required.append(p.name) + self._is_toolkit = False super().__init__( - name=tool.name.replace(' ', ''), - description=description, + description=dict( + name=tool.name, + description=tool.toolmeta.description, + parameters=parameters, + required=required, + ), enable=True, ) - def run(self, json_args: str): - # load json format arguments - try: - item = next( - re.finditer('{.*}', json_args, re.MULTILINE | re.DOTALL)) - kwargs = json.loads(item.group()) - except Exception: - error = ValueError( - 'All arguments should be combined into one json string.') - return ActionReturn( - type=self.name, - errmsg=repr(error), - state=ActionStatusCode.ARGS_ERROR, - args={'raw_input': json_args}, - ) + def run(self, **kwargs) -> ActionReturn: try: - result = self.tool(**kwargs) - result_dict = defaultdict(list) - result_dict['text'] = str(result) + outputs = self.tool(**kwargs) + results = [] - if not isinstance(result, tuple): - result = [result] + if not isinstance(outputs, tuple): + outputs = [outputs] - for res, out_type in zip(result, self.tool.toolmeta.outputs): - if out_type != 'text': - result_dict[out_type].append(res) + for out, p in zip(outputs, self.tool.outputs): + if p.type is ImageIO: + results.append(dict(type='image', content=out)) + elif p.type is AudioIO: + results.append(dict(type='audio', content=out)) + elif p.type is File: + results.append(dict(type='file', content=out)) + else: + results.append(dict(type='text', content=str(out))) return ActionReturn( type=self.name, args=kwargs, - result=result_dict, + result=results, ) except Exception as e: return ActionReturn( diff --git a/agentlego/tools/wrappers/langchain.py b/agentlego/tools/wrappers/langchain.py index 85ccc6dd..a7a0ea73 100644 --- a/agentlego/tools/wrappers/langchain.py +++ b/agentlego/tools/wrappers/langchain.py @@ -16,18 +16,20 @@ def call(*args, **kwargs): call_args = {} call_params = [] - for p in tool.parameters.values(): + for p in tool.inputs: call_args[p.name] = str call_params.append( inspect.Parameter( p.name, inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=str)) + annotation=str, + default=p.default if p.optional else inspect._empty, + )) call.__signature__ = inspect.Signature(call_params) call.__annotations__ = call_args return StructuredTool.from_function( func=call, name=tool.name, - description=tool.toolmeta.description, + description=tool.description, ) diff --git a/agentlego/tools/wrappers/transformers_agent.py b/agentlego/tools/wrappers/transformers_agent.py index 1ff45520..b1031056 100644 --- a/agentlego/tools/wrappers/transformers_agent.py +++ b/agentlego/tools/wrappers/transformers_agent.py @@ -1,12 +1,11 @@ import copy from transformers.tools import Tool -from transformers.tools.agent_types import (AgentAudio, AgentImage, AgentText, - AgentType) +from transformers.tools.agent_types import AgentAudio, AgentImage, AgentText, AgentType from agentlego.parsers import NaiveParser from agentlego.tools.base import BaseTool -from agentlego.types import AudioIO, CatgoryToIO, ImageIO +from agentlego.types import AudioIO, ImageIO def cast_lego_to_hf(value): @@ -32,29 +31,24 @@ def __init__(self, tool: BaseTool): # transformers agent system self.name: str = 'agentlego_' + tool.name.lower().replace(' ', '_') - inputs_desc = [] - for p in tool.parameters.values(): - default = f', Defaults to {p.default}' if p.optional else '' - inputs_desc.append(f'{p.name} ({p.category}{default})') - inputs_desc = 'Args: ' + ', '.join(inputs_desc) - self.description: str = f'{tool.toolmeta.description} {inputs_desc}' + self.description: str = self.refine_description(tool) self.inputs = list(tool.toolmeta.inputs) self.outputs = list(tool.toolmeta.outputs) def __call__(self, *args, **kwargs): - for k, v in zip(self.tool.parameters, args): - kwargs[k] = v + for arg, p in zip(args, self.tool.inputs): + kwargs[p.name] = arg parsed_kwargs = {} for k, v in kwargs.items(): - p = self.tool.parameters[k] - if p.category == 'audio': + p = self.tool.arguments[k] + if p.type is AudioIO: parsed_kwargs[k] = AudioIO(v) - elif p.category == 'image': + elif p.type is ImageIO: parsed_kwargs[k] = ImageIO(v) else: - parsed_kwargs[k] = CatgoryToIO[p.category](v) + parsed_kwargs[k] = p.type(v) outputs = self.tool(**parsed_kwargs) @@ -63,3 +57,43 @@ def __call__(self, *args, **kwargs): parsed_outs = [cast_lego_to_hf(out) for out in outputs] return parsed_outs[0] if len(parsed_outs) == 1 else parsed_outs + + @staticmethod + def refine_description(tool) -> str: + inputs_desc = [] + type2format = {ImageIO: 'image', AudioIO: 'audio'} + for p in tool.inputs: + desc = f'{p.name}' + format = type2format.get(p.type, p.type.__name__) + if p.description: + format += f', {p.description}' + if p.optional: + format += f'. Optional, Defaults to {p.default}' + desc += f' ({format})' + inputs_desc.append(desc) + if len(inputs_desc) > 0: + inputs_desc = 'Args: ' + '; '.join(inputs_desc) + else: + inputs_desc = 'No argument.' + + outputs_desc = [] + for p in tool.outputs: + format = type2format.get(p.type, p.type.__name__) + if p.name and p.description: + desc = f'{p.name} ({format}, {p.description})' + elif p.name: + desc = f'{p.name} ({format})' + elif p.description: + desc = f'{format} ({p.description})' + else: + desc = f'{format}' + outputs_desc.append(desc) + if len(outputs_desc) > 0: + outputs_desc = 'Returns: ' + '; '.join(outputs_desc) + else: + outputs_desc = 'No returns.' + + description = (f'{tool.toolmeta.description}\n' + f'{inputs_desc}\n{outputs_desc}') + + return description diff --git a/agentlego/types.py b/agentlego/types.py index 6f7ea5d5..5d369ef9 100644 --- a/agentlego/types.py +++ b/agentlego/types.py @@ -1,10 +1,13 @@ +from io import BytesIO, IOBase from pathlib import Path from typing import TYPE_CHECKING, Optional, Union import numpy as np from PIL import Image +from typing_extensions import Annotated -from .utils import temp_path +from .schema import Parameter +from .utils import is_package_available, temp_path if TYPE_CHECKING: import torch @@ -31,9 +34,8 @@ def __init__(self, value): self.value = value self.type = name if self.type is None: - raise NotImplementedError( - f'The value type `{type(value)}` is not ' - f'supported by `{self.__class__.__name__}`') + raise NotImplementedError(f'The value type `{type(value)}` is not ' + f'supported by `{self.__class__.__name__}`') def to(self, dst_type: str): if self.type == dst_type: @@ -49,6 +51,48 @@ def __str__(self) -> str: return f'{self.__class__.__name__}(value={self.value})' +class File(IOType): + support_types = {'path': str, 'bytes': bytes} + + def __init__(self, value: Union[str, bytes], filetype: Optional[str] = None): + super().__init__(value) + if self.type == 'path' and not Path(self.value).exists(): + raise FileNotFoundError(f"No such file: '{self.value}'") + self.filetype = filetype + + def to_path(self) -> str: + return self.to('path') + + def to_bytes(self) -> bytes: + return self.to('bytes') + + def to_file(self) -> IOBase: + if self.type == 'path': + return open(self.value, 'rb') + else: + return BytesIO(self.value) + + @classmethod + def from_file(cls, file: IOBase, filetype: Optional[str] = None) -> 'File': + return cls(file.read(), filetype=filetype) + + @staticmethod + def _path_to_bytes(path: str) -> bytes: + return open(path, 'rb').read() + + def _bytes_to_path(self, data: bytes) -> str: + if self.filetype: + category, _, suffix = self.filetype.partition('/') + suffix = '.' + suffix if suffix else '' + else: + category = 'file' + suffix = '' + path = temp_path(category, suffix) + with open(path, 'wb') as f: + f.write(data) + return path + + class ImageIO(IOType): support_types = {'path': str, 'pil': Image.Image, 'array': np.ndarray} @@ -66,6 +110,20 @@ def to_pil(self) -> Image.Image: def to_array(self) -> np.ndarray: return self.to('array') + def to_file(self) -> IOBase: + if self.type == 'path': + return open(self.value, 'rb') + else: + file = BytesIO() + self.to_pil().save(file, 'PNG') + file.seek(0) + return file + + @classmethod + def from_file(cls, file: IOBase) -> 'ImageIO': + from PIL import Image + return cls(Image.open(file)) + @staticmethod def _path_to_pil(path: str) -> Image.Image: return Image.open(path) @@ -129,6 +187,28 @@ def to_tensor(self) -> 'torch.Tensor': def to_path(self) -> str: return self.to('path') + def to_file(self) -> IOBase: + if self.type == 'path' or not is_package_available('torchaudio'): + return open(self.to_path(), 'rb') + else: + import torchaudio + file = BytesIO() + torchaudio.save(file, self.to_tensor(), self.sampling_rate) + file.seek(0) + return file + + @classmethod + def from_file(cls, file: IOBase) -> 'AudioIO': + try: + import torchaudio + audio, sr = torchaudio.load(file) + return cls(audio, sampling_rate=sr) + except ImportError: + filename = temp_path('audio', '.wav') + with open(filename, 'wb') as f: + f.write(file.read()) + return cls(filename) + def _path_to_tensor(self, path: str) -> 'torch.Tensor': import torchaudio audio, sampling_rate = torchaudio.load(path) @@ -142,6 +222,33 @@ def _tensor_to_path(self, tensor: 'torch.Tensor') -> str: return filename +def Info(description: Optional[str] = None, + *, + name: Optional[str] = None, + filetype: Optional[str] = None): + """Used to add additional information of arguments and outputs. + + Args: + description (str | None): Description for the parameter. Defaults to None. + name (str | None): tool name for agent to identify the tool. Defaults to None. + filetype (str | None): The file type for `File` inputs and outputs. + Defaults to None. + + Examples: + + .. code:: python + from agentlego.types import Annotated, Info, File + + class CustomTool(BaseTool): + ... + def apply( + self, arg1: Annotated[str, Info('Description of arg1')] + ) -> Annotated[File, Info('Description of output.', filetype='office/xlsx')]: + pass + """ + return Parameter(description=description, name=name, filetype=filetype) + + CatgoryToIO = { 'image': ImageIO, 'text': str, @@ -149,6 +256,7 @@ def _tensor_to_path(self, tensor: 'torch.Tensor') -> str: 'bool': bool, 'int': int, 'float': float, + 'file': File, } -__all__ = ['ImageIO', 'AudioIO', 'CatgoryToIO'] +__all__ = ['ImageIO', 'AudioIO', 'CatgoryToIO', 'Info', 'Annotated'] diff --git a/agentlego/utils/__init__.py b/agentlego/utils/__init__.py index 926cd586..f747cf9e 100644 --- a/agentlego/utils/__init__.py +++ b/agentlego/utils/__init__.py @@ -1,8 +1,13 @@ from .cache import load_or_build_object from .dependency import is_package_available, require from .file import download_checkpoint, download_url_to_file, temp_path +from .misc import apply_to +from .module import resolve_module +from .openapi import APIOperation, OpenAPISpec +from .parse import * # noqa: F401, F403 __all__ = [ 'temp_path', 'load_or_build_object', 'require', 'is_package_available', - 'download_checkpoint', 'download_url_to_file' + 'download_checkpoint', 'download_url_to_file', 'OpenAPISpec', 'APIOperation', + 'resolve_module', 'apply_to' ] diff --git a/agentlego/utils/dependency.py b/agentlego/utils/dependency.py index 1b1298d6..01ca838f 100644 --- a/agentlego/utils/dependency.py +++ b/agentlego/utils/dependency.py @@ -114,8 +114,7 @@ def ask_install(*args, **kwargs): msg = '{name} requires {dep}, please install by `{ins}`.'.format( name=fn.__qualname__.replace('.__init__', ''), dep=', '.join(dep), - ins=install - or 'pip install {}'.format(' '.join(repr(i) for i in dep))) + ins=install or 'pip install {}'.format(' '.join(repr(i) for i in dep))) raise ImportError(msg) if all(_check_dependency(item) for item in dep): diff --git a/agentlego/utils/file.py b/agentlego/utils/file.py index 0cdb780e..9b1d131c 100644 --- a/agentlego/utils/file.py +++ b/agentlego/utils/file.py @@ -7,7 +7,7 @@ from datetime import datetime from pathlib import Path from typing import Optional -from urllib.parse import urlparse # noqa: F401 +from urllib.parse import urlparse from urllib.request import Request, urlopen from tqdm import tqdm @@ -16,7 +16,9 @@ def temp_path(category: str, suffix: str, prefix: str = '', - root: str = 'generated') -> str: + root: Optional[str] = None) -> str: + if root is None: + root = os.getenv('AGENTLEGO_TMPDIR', 'generated') output_dir = Path(root) / category output_dir.mkdir(exist_ok=True, parents=True) timestamp = datetime.now().strftime('%Y%m%d') @@ -28,9 +30,8 @@ def temp_path(category: str, def _get_torchhub_dir(): torch_home = os.path.expanduser( - os.getenv( - 'TORCH_HOME', - os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) + os.getenv('TORCH_HOME', + os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) return os.path.join(torch_home, 'hub') @@ -87,9 +88,8 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): if hash_prefix is not None: digest = sha256.hexdigest() if digest[:len(hash_prefix)] != hash_prefix: - raise RuntimeError( - 'invalid hash value (expected "{}", got "{}")'.format( - hash_prefix, digest)) + raise RuntimeError('invalid hash value (expected "{}", got "{}")'.format( + hash_prefix, digest)) Path(f.name).rename(dst) finally: f.close() diff --git a/agentlego/utils/misc.py b/agentlego/utils/misc.py new file mode 100644 index 00000000..28d96410 --- /dev/null +++ b/agentlego/utils/misc.py @@ -0,0 +1,42 @@ +from typing import Any, Callable + + +def apply_to(data: Any, expr: Callable, apply_func: Callable): + """Apply function to each element in dict, list or tuple that matches with + the expression. + + For examples, if you want to convert each element in a list of dict from + `np.ndarray` to `Tensor`. You can use the following code: + + Examples: + >>> from agentlego.utils import apply_to + >>> import numpy as np + >>> import torch + >>> data = dict(array=[np.array(1)]) # {'array': [array(1)]} + >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x)) + >>> print(result) # {'array': [tensor(1)]} + + Args: + data (Any): Data to be applied. + expr (Callable): Expression to tell which data should be applied with + the function. It should return a boolean. + apply_func (Callable): Function applied to data. + + Returns: + Any: The data after applying. + """ # noqa: E501 + if isinstance(data, dict): + # Keep the original dict type + res = type(data)() + for key, value in data.items(): + res[key] = apply_to(value, expr, apply_func) + return res + elif isinstance(data, tuple) and hasattr(data, '_fields'): + # namedtuple + return type(data)(*(apply_to(sample, expr, apply_func) for sample in data)) # type: ignore # noqa: E501 # yapf:disable + elif isinstance(data, (tuple, list)): + return type(data)(apply_to(sample, expr, apply_func) for sample in data) # type: ignore # noqa: E501 # yapf:disable + elif expr(data): + return apply_func(data) + else: + return data diff --git a/agentlego/utils/module.py b/agentlego/utils/module.py new file mode 100644 index 00000000..1aea9fba --- /dev/null +++ b/agentlego/utils/module.py @@ -0,0 +1,32 @@ +import importlib +import sys +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from types import ModuleType +from typing import Union + + +def resolve_module(name_or_path: Union[Path, str]) -> ModuleType: + if isinstance(name_or_path, str): + try: + module = importlib.import_module(name_or_path) + return module + except Exception: + name_or_path = Path(name_or_path) + + name = '_ext_' + name_or_path.stem.replace(' ', '').replace('-', '_') + if name_or_path.is_dir(): + name_or_path = name_or_path / '__init__.py' + + if not name_or_path.exists(): + raise ImportError(f'Cannot import from `{name_or_path}` ' + 'since the path does not exist.') + + spec = spec_from_file_location(name, str(name_or_path)) + if spec is None: + raise ImportError(f'Failed to import from `{name_or_path}`.') + module = module_from_spec(spec) + sys.modules[name] = module + spec.loader.exec_module(module) + del sys.modules[name] + return module diff --git a/agentlego/utils/openapi/__init__.py b/agentlego/utils/openapi/__init__.py new file mode 100644 index 00000000..2c1f3437 --- /dev/null +++ b/agentlego/utils/openapi/__init__.py @@ -0,0 +1,11 @@ +from .api_model import (PRIMITIVE_TYPES, APIOperation, APIProperty, APIPropertyBase, + APIPropertyLocation, APIRequestBody, APIRequestBodyProperty, + APIResponse, APIResponseProperty) +from .extract import operation_toolmeta +from .spec import HTTPVerb, OpenAPISpec + +__all__ = [ + 'APIOperation', 'APIPropertyBase', 'APIPropertyLocation', 'APIProperty', + 'APIRequestBody', 'APIRequestBodyProperty', 'APIResponse', 'APIResponseProperty', + 'OpenAPISpec', 'HTTPVerb', 'PRIMITIVE_TYPES', 'operation_toolmeta' +] diff --git a/agentlego/utils/openapi/api_model.py b/agentlego/utils/openapi/api_model.py new file mode 100644 index 00000000..2b7a2a82 --- /dev/null +++ b/agentlego/utils/openapi/api_model.py @@ -0,0 +1,701 @@ +# Copied from https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/tools/openapi/utils/api_models.py # noqa: E501 +# Modified to support multipart/form-data style media-type +"""Pydantic models for parsing an OpenAPI spec.""" +from __future__ import annotations +from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union + +from openapi_pydantic import MediaType, Parameter, RequestBody, Response, Schema +from pydantic import BaseModel, Field + +from .spec import HTTPVerb, OpenAPISpec + +PRIMITIVE_TYPES = { + 'integer': int, + 'number': float, + 'string': str, + 'boolean': bool, + 'array': List, + 'object': Dict, + 'null': None, +} + + +# See https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#parameterIn # noqa: E501 +# for more info. +class APIPropertyLocation(Enum): + """The location of the property.""" + + QUERY = 'query' + PATH = 'path' + HEADER = 'header' + COOKIE = 'cookie' # Not yet supported + + @classmethod + def from_str(cls, location: str) -> APIPropertyLocation: + """Parse an APIPropertyLocation.""" + try: + return cls(location) + except ValueError: + raise ValueError('Invalid APIPropertyLocation. ' + f'Valid values are {cls.__members__}') + + +_SUPPORTED_REQUEST_MEDIA_TYPES = ('application/json', 'multipart/form-data', + 'application/x-www-form-urlencoded') +_SUPPORTED_RESPONSE_MEDIA_TYPES = ('application/json', ) + +SUPPORTED_LOCATIONS = { + APIPropertyLocation.QUERY, + APIPropertyLocation.PATH, +} +INVALID_LOCATION_TEMPL = ( + 'Unsupported APIPropertyLocation "{location}"' + ' for parameter {name}. ' + + f'Valid values are {[loc.value for loc in SUPPORTED_LOCATIONS]}') + +SCHEMA_TYPE = Union[str, Type, tuple, None, Enum] + + +class APIPropertyBase(BaseModel): + """Base model for an API property.""" + + # The name of the parameter is required and is case-sensitive. If "in" is + # "path", the "name" field must correspond to a template expression within + # the path field in the Paths Object. If "in" is "header" and the "name" + # field is "Accept", "Content-Type", or "Authorization", the parameter + # definition is ignored. For all other cases, the "name" corresponds to the + # parameter name used by the "in" property. + name: str = Field(alias='name') + """The name of the property.""" + + required: bool = Field(alias='required') + """Whether the property is required.""" + + type: SCHEMA_TYPE = Field(alias='type') + """The type of the property. + + Either a primitive type, a component/parameter type, or an array or + 'object' (dict) of the above. + """ + + format: Optional[str] = Field(alias='format', default=None) + """The modifier property to provide detail for primitive data types.""" + + default: Optional[Any] = Field(alias='default', default=None) + """The default value of the property.""" + + description: Optional[str] = Field(alias='description', default=None) + """The description of the property.""" + + +class APIProperty(APIPropertyBase): + """A model for a property in the query, path, header, or cookie params.""" + + location: APIPropertyLocation = Field(alias='location') + """The path/how it's being passed to the endpoint.""" + + @staticmethod + def _cast_schema_list_type(schema: Schema, + ) -> Optional[Union[str, Tuple[str, ...]]]: + type_ = schema.type + if not isinstance(type_, list): + return type_ + else: + return tuple(type_) + + @staticmethod + def _get_schema_type_for_enum(parameter: Parameter, schema: Schema) -> Enum: + """Get the schema type when the parameter is an enum.""" + param_name = f'{parameter.name}Enum' + return Enum(param_name, {str(v): v for v in schema.enum}) + + @staticmethod + def _get_schema_type_for_array(schema: Schema, + ) -> Optional[Union[str, Tuple[str, ...]]]: + from openapi_pydantic import Reference, Schema + + items = schema.items + if isinstance(items, Schema): + schema_type = APIProperty._cast_schema_list_type(items) + elif isinstance(items, Reference): + ref_name = items.ref.split('/')[-1] + # TODO: Add ref definitions to make his valid + schema_type = ref_name + else: + raise ValueError(f'Unsupported array items: {items}') + + if isinstance(schema_type, str): + # TODO: recurse + schema_type = (schema_type, ) + + return schema_type + + @staticmethod + def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE: + if schema is None: + return None + schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema) + if schema_type == 'array': + raise NotImplementedError('Array not yet supported') + elif schema_type == 'object': + raise NotImplementedError('Objects not yet supported') + elif schema_type in PRIMITIVE_TYPES: + if schema.enum: + schema_type = APIProperty._get_schema_type_for_enum(parameter, schema) + else: + # Directly use the primitive type + pass + elif schema_type is None: + return None + else: + raise NotImplementedError(f'Unsupported type: {schema_type}') + + return schema_type + + @staticmethod + def _validate_location(location: APIPropertyLocation, name: str) -> None: + if location not in SUPPORTED_LOCATIONS: + raise NotImplementedError( + INVALID_LOCATION_TEMPL.format(location=location, name=name)) + + @staticmethod + def _validate_content(content: Optional[Dict[str, MediaType]]) -> None: + if content: + raise ValueError( + 'API Properties with media content not supported. ' + "Media content only supported within APIRequestBodyProperty's") + + @staticmethod + def _get_schema(parameter: Parameter, spec: OpenAPISpec) -> Optional[Schema]: + from openapi_pydantic import Reference, Schema + + schema = parameter.param_schema + if isinstance(schema, Reference): + schema = spec.get_referenced_schema(schema) + elif schema is None: + return None + elif not isinstance(schema, Schema): + raise ValueError(f'Error dereferencing schema: {schema}') + + return schema + + @staticmethod + def is_supported_location(location: str) -> bool: + """Return whether the provided location is supported.""" + try: + return APIPropertyLocation.from_str(location) in SUPPORTED_LOCATIONS + except ValueError: + return False + + @classmethod + def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> 'APIProperty': + """Instantiate from an OpenAPI Parameter.""" + location = APIPropertyLocation.from_str(parameter.param_in) + cls._validate_location( + location, + parameter.name, + ) + cls._validate_content(parameter.content) + schema = cls._get_schema(parameter, spec) + schema_type = cls._get_schema_type(parameter, schema) + default_val = schema.default if schema is not None else None + return cls( + name=parameter.name, + location=location, + default=default_val, + description=parameter.description, + required=parameter.required, + type=schema_type, + ) + + +class APIRequestBodyProperty(APIPropertyBase): + """A model for a request body property.""" + + properties: List['APIRequestBodyProperty'] = Field(alias='properties') + """The sub-properties of the property.""" + + # This is useful for handling nested property cycles. + # We can define separate types in that case. + references_used: List[str] = Field(alias='references_used') + """The references used by the property.""" + + @classmethod + def _process_object_schema( + cls, schema: Schema, spec: OpenAPISpec, references_used: List[str] + ) -> Tuple[Union[str, List[str], None], List['APIRequestBodyProperty']]: + from openapi_pydantic import Reference + + properties = [] + required_props = schema.required or [] + if schema.properties is None: + raise ValueError( + f'No properties found when processing object schema: {schema}') + for prop_name, prop_schema in schema.properties.items(): + if isinstance(prop_schema, Reference): + ref_name = prop_schema.ref.split('/')[-1] + if ref_name not in references_used: + references_used.append(ref_name) + prop_schema = spec.get_referenced_schema(prop_schema) + else: + continue + + properties.append( + cls.from_schema( + schema=prop_schema, + name=prop_name, + required=prop_name in required_props, + spec=spec, + references_used=references_used, + )) + return schema.type, properties + + @classmethod + def _process_array_schema( + cls, + schema: Schema, + name: str, + spec: OpenAPISpec, + references_used: List[str], + ) -> str: + from openapi_pydantic import Reference, Schema + + items = schema.items + if items is not None: + if isinstance(items, Reference): + ref_name = items.ref.split('/')[-1] + if ref_name not in references_used: + references_used.append(ref_name) + items = spec.get_referenced_schema(items) + else: + pass + return f'Array<{ref_name}>' + else: + pass + + if isinstance(items, Schema): + array_type = cls.from_schema( + schema=items, + name=f'{name}Item', + required=True, # TODO: Add required + spec=spec, + references_used=references_used, + ) + return f'Array<{array_type.type}>' + + return 'array' + + @classmethod + def from_schema( + cls, + schema: Schema, + name: str, + required: bool, + spec: OpenAPISpec, + references_used: Optional[List[str]] = None, + ) -> 'APIRequestBodyProperty': + """Recursively populate from an OpenAPI Schema.""" + if references_used is None: + references_used = [] + + schema_type = schema.type + properties: List[APIRequestBodyProperty] = [] + if schema_type == 'object' and schema.properties: + raise NotImplementedError('Objects not yet supported') + elif schema_type == 'array': + raise NotImplementedError('Array not yet supported') + elif schema_type in PRIMITIVE_TYPES: + # Use the primitive type directly + pass + elif schema_type is None: + # No typing specified/parsed. WIll map to 'any' + pass + else: + raise ValueError(f'Unsupported type: {schema_type}') + + return cls( + name=name, + required=required, + type=schema_type, + format=schema.schema_format, + default=schema.default, + description=schema.description, + properties=properties, + references_used=references_used, + ) + + +class APIRequestBody(BaseModel): + """A model for a request body.""" + + description: Optional[str] = Field(alias='description') + """The description of the request body.""" + + properties: List[APIRequestBodyProperty] = Field(alias='properties') + + # E.g., application/json or multipart/form + media_type: str = Field(alias='media_type') + """The media type of the request body.""" + + @classmethod + def _process_supported_media_type( + cls, + media_type_obj: MediaType, + spec: OpenAPISpec, + ) -> List[APIRequestBodyProperty]: + """Process the media type of the request body.""" + from openapi_pydantic import Reference + + references_used = [] + schema = media_type_obj.media_type_schema + if isinstance(schema, Schema) and schema.allOf: + schema = schema.allOf[0] + if isinstance(schema, Reference): + references_used.append(schema.ref.split('/')[-1]) + schema = spec.get_referenced_schema(schema) + if schema is None: + raise ValueError( + f'Could not resolve schema for media type: {media_type_obj}') + api_request_body_properties = [] + required_properties = schema.required or [] + if schema.type == 'object' and schema.properties: + for prop_name, prop_schema in schema.properties.items(): + if isinstance(prop_schema, Reference): + prop_schema = spec.get_referenced_schema(prop_schema) + + api_request_body_properties.append( + APIRequestBodyProperty.from_schema( + schema=prop_schema, + name=prop_name, + required=prop_name in required_properties, + spec=spec, + )) + else: + api_request_body_properties.append( + APIRequestBodyProperty( + name='body', + required=True, + type=schema.type, + default=schema.default, + description=schema.description, + properties=[], + references_used=references_used, + )) + + return api_request_body_properties + + @classmethod + def from_request_body(cls, request_body: RequestBody, + spec: OpenAPISpec) -> 'APIRequestBody': + """Instantiate from an OpenAPI RequestBody.""" + properties = [] + for media_type, media_type_obj in request_body.content.items(): + if media_type not in _SUPPORTED_REQUEST_MEDIA_TYPES: + continue + api_request_body_properties = cls._process_supported_media_type( + media_type_obj, + spec, + ) + properties.extend(api_request_body_properties) + + return cls( + description=request_body.description, + properties=properties, + media_type=media_type, + ) + + +class APIResponseProperty(APIPropertyBase): + """A model for a response property.""" + + @classmethod + def from_schema(cls, schema: Schema, name: str, + required: bool) -> 'APIResponseProperty': + """Recursively populate from an OpenAPI Schema.""" + schema_type = schema.type + if schema_type in ['array', 'object']: + # Handle the complex type as simple string + schema_type = 'string' + elif schema_type in PRIMITIVE_TYPES: + # Use the primitive type directly + pass + elif schema_type is None: + # No typing specified/parsed. WIll map to 'any' + pass + else: + # Handle the unknown type as simple string + schema_type = 'string' + + return cls( + name=name, + required=required, + type=schema_type, + format=schema.schema_format, + default=schema.default, + description=schema.description, + ) + + +class APIResponse(BaseModel): + """A model for a response.""" + + description: Optional[str] = Field(alias='description') + """The description of the response.""" + + properties: Union[List[APIResponseProperty], Dict[str, APIResponseProperty], + APIResponseProperty] = Field(alias='properties') + + # E.g., application/json - we only support JSON at the moment. + media_type: str = Field(alias='media_type') + """The media type of the response.""" + + @classmethod + def _process_supported_media_type( + cls, + media_type_obj: MediaType, + spec: OpenAPISpec, + ) -> List[APIResponseProperty]: + """Process the media type of the response.""" + from openapi_pydantic import Reference + + references_used = [] + schema = media_type_obj.media_type_schema + if isinstance(schema, Reference): + references_used.append(schema.ref.split('/')[-1]) + schema = spec.get_referenced_schema(schema) + if schema is None: + raise ValueError( + f'Could not resolve schema for media type: {media_type_obj}') + required_properties = schema.required or [] + if schema.type == 'object' and schema.properties: + # Dict style output + properties = {} + for prop_name, prop_schema in schema.properties.items(): + if isinstance(prop_schema, Reference): + prop_schema = spec.get_referenced_schema(prop_schema) + + properties[prop_name] = APIResponseProperty.from_schema( + schema=prop_schema, + name=prop_name, + required=prop_name in required_properties, + ) + elif schema.type == 'array' and schema.prefixItems: + # Tuple style output + properties = [] + for i, prop_schema in enumerate(schema.prefixItems): + if isinstance(prop_schema, Reference): + prop_schema = spec.get_referenced_schema(prop_schema) + + properties.append( + APIResponseProperty.from_schema( + schema=prop_schema, name='_null', required=True)) + else: + # Simple style output + properties = APIResponseProperty.from_schema( + schema=schema, name='_null', required=True) + + return properties + + @classmethod + def from_response(cls, response: Response, spec: OpenAPISpec) -> 'APIResponse': + """Instantiate from an OpenAPI Response.""" + # Only handle one potential response payload style. + media_type = next( + (k for k in response.content if k in _SUPPORTED_RESPONSE_MEDIA_TYPES), None) + + if media_type is not None: + media_type_obj = response.content[media_type] + properties = cls._process_supported_media_type(media_type_obj, spec) + + return cls( + description=response.description, + properties=properties, + media_type=media_type, + ) + + +# class APIRequestBodyProperty(APIPropertyBase): +# class APIRequestBody(BaseModel): +class APIOperation(BaseModel): + """A model for a single API operation.""" + + operation_id: str = Field(alias='operation_id') + """The unique identifier of the operation.""" + + description: Optional[str] = Field(alias='description') + """The description of the operation.""" + + base_url: str = Field(alias='base_url') + """The base URL of the operation.""" + + path: str = Field(alias='path') + """The path of the operation.""" + + method: HTTPVerb = Field(alias='method') + """The HTTP method of the operation.""" + + properties: Sequence[APIProperty] = Field(alias='properties') + + # TODO: Add parse in used components to be able to specify what type of + # referenced object it is. + # """The properties of the operation.""" + # components: Dict[str, BaseModel] = Field(alias="components") + + request_body: Optional[APIRequestBody] = Field(alias='request_body') + """The request body of the operation.""" + + responses: Optional[Dict[str, APIResponse]] = Field(alias='responses') + + @staticmethod + def _get_properties_from_parameters(parameters: List[Parameter], + spec: OpenAPISpec) -> List[APIProperty]: + """Get the properties of the operation.""" + properties = [] + for param in parameters: + if APIProperty.is_supported_location(param.param_in): + properties.append(APIProperty.from_parameter(param, spec)) + elif param.required: + raise ValueError( + INVALID_LOCATION_TEMPL.format( + location=param.param_in, name=param.name)) + return properties + + @classmethod + def from_openapi_url( + cls, + spec_url: str, + path: str, + method: str, + ) -> 'APIOperation': + """Create an APIOperation from an OpenAPI URL.""" + spec = OpenAPISpec.from_url(spec_url) + return cls.from_openapi_spec(spec, path, method) + + @classmethod + def from_openapi_spec( + cls, + spec: OpenAPISpec, + path: str, + method: str, + ) -> 'APIOperation': + """Create an APIOperation from an OpenAPI spec.""" + operation = spec.get_operation(path, method) + parameters = spec.get_parameters_for_operation(operation) + properties = cls._get_properties_from_parameters(parameters, spec) + operation_id = OpenAPISpec.get_cleaned_operation_id(operation, path, method) + request_body = spec.get_request_body_for_operation(operation) + api_request_body = ( + APIRequestBody.from_request_body(request_body, spec) + if request_body is not None else None) + responses = spec.get_responses_for_operation(operation) + if responses is not None: + api_responses = { + k: APIResponse.from_response(response, spec) + for k, response in responses.items() + } + else: + api_responses = None + + description = operation.description or operation.summary + if not description and spec.paths is not None: + description = spec.paths[path].description or spec.paths[path].summary + return cls( + operation_id=operation_id, + description=description or '', + base_url=spec.base_url, + path=path, + method=method, + properties=properties, + request_body=api_request_body, + responses=api_responses, + ) + + @staticmethod + def ts_type_from_python(type_: SCHEMA_TYPE) -> str: + if type_ is None: + # TODO: Handle Nones better. These often result when + # parsing specs that are < v3 + return 'any' + elif isinstance(type_, str): + return { + 'str': 'string', + 'integer': 'number', + 'float': 'number', + 'date-time': 'string', + }.get(type_, type_) + elif isinstance(type_, tuple): + return f'Array<{APIOperation.ts_type_from_python(type_[0])}>' + elif isinstance(type_, type) and issubclass(type_, Enum): + return ' | '.join([f"'{e.value}'" for e in type_]) + else: + return str(type_) + + def _format_nested_properties(self, + properties: List[APIRequestBodyProperty], + indent: int = 2) -> str: + """Format nested properties.""" + formatted_props = [] + + for prop in properties: + prop_name = prop.name + prop_type = self.ts_type_from_python(prop.type) + prop_required = '' if prop.required else '?' + prop_desc = f'/* {prop.description} */' if prop.description else '' + + if prop.properties: + nested_props = self._format_nested_properties(prop.properties, + indent + 2) + prop_type = f"{{\n{nested_props}\n{' ' * indent}}}" + + formatted_props.append(f"{prop_desc}\n{' ' * indent}{prop_name}" + f'{prop_required}: {prop_type},') + + return '\n'.join(formatted_props) + + def to_typescript(self) -> str: + """Get typescript string representation of the operation.""" + operation_name = self.operation_id + params = [] + + if self.request_body: + formatted_request_body_props = self._format_nested_properties( + self.request_body.properties) + params.append(formatted_request_body_props) + + for prop in self.properties: + prop_name = prop.name + prop_type = self.ts_type_from_python(prop.type) + prop_required = '' if prop.required else '?' + prop_desc = f'/* {prop.description} */' if prop.description else '' + params.append(f'{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},') + + formatted_params = '\n'.join(params).strip() + description_str = (f'/* {self.description} */' if self.description else '') + typescript_definition = f""" +{description_str} +type {operation_name} = (_: {{ +{formatted_params} +}}) => any; +""" + return typescript_definition.strip() + + @property + def query_params(self) -> List[str]: + return [ + property.name for property in self.properties + if property.location == APIPropertyLocation.QUERY + ] + + @property + def path_params(self) -> List[str]: + return [ + property.name for property in self.properties + if property.location == APIPropertyLocation.PATH + ] + + @property + def body_params(self) -> List[str]: + if self.request_body is None: + return [] + return [prop.name for prop in self.request_body.properties] diff --git a/agentlego/utils/openapi/extract.py b/agentlego/utils/openapi/extract.py new file mode 100644 index 00000000..fd411902 --- /dev/null +++ b/agentlego/utils/openapi/extract.py @@ -0,0 +1,78 @@ +import warnings +from typing import Tuple + +from agentlego.schema import Parameter, ToolMeta +from .api_model import PRIMITIVE_TYPES, APIOperation, APIPropertyBase + + +def prop_to_parameter(prop: APIPropertyBase) -> Parameter: + from agentlego.types import AudioIO, File, ImageIO + p_type = PRIMITIVE_TYPES.get(prop.type, prop.type) # type: ignore + p = Parameter( + type=p_type, + name=prop.name if prop.name != '_null' else None, + description=prop.description, + optional=not prop.required, + default=prop.default, + ) + if p_type is str: + schema_format = prop.format or '' + if 'image' in schema_format: + p.type = ImageIO + elif 'audio' in schema_format: + p.type = AudioIO + elif 'binary' in schema_format or 'base64' in schema_format: + p.type = File + p.filetype, _, _ = schema_format.partition(';') + return p + + +def operation_inputs(op: APIOperation) -> Tuple[Parameter, ...]: + inputs = [] + properties = [] + if op.properties: + properties.extend(op.properties) + if op.request_body and op.request_body.properties: + properties.extend(op.request_body.properties) + for p in properties: + inputs.append(prop_to_parameter(p)) + return tuple(inputs) + + +def operation_outputs(op: APIOperation) -> Tuple[Parameter, ...]: + if op.responses is None or op.responses.get('200') is None: + # If not specify outputs, directly handle as a single text. + outputs = [Parameter(type=str)] + + response_schema = op.responses + if response_schema is None or response_schema.get('200') is None: + # Directly use string if the response schema is not specified + warnings.warn(f'The response of {op.operation_id} is not specified, ' + 'assume as a string response by default.') + return (Parameter(type=str), ) + else: + out_props = response_schema['200'].properties + + if isinstance(out_props, list): + outputs = [prop_to_parameter(out) for out in out_props] + elif isinstance(out_props, dict): + outputs = [prop_to_parameter(out) for out in out_props.values()] + else: + outputs = [prop_to_parameter(out_props)] + + return tuple(outputs) + + +def operation_toolmeta(operation: APIOperation) -> ToolMeta: + """Extract tool meta information from a HTTP operation.""" + name = operation.operation_id + inputs = operation_inputs(operation) + outputs = operation_outputs(operation) + toolmeta = ToolMeta( + name=name, + description=operation.description, + inputs=inputs, + outputs=outputs, + ) + + return toolmeta diff --git a/agentlego/utils/openapi/spec.py b/agentlego/utils/openapi/spec.py new file mode 100644 index 00000000..8884d12c --- /dev/null +++ b/agentlego/utils/openapi/spec.py @@ -0,0 +1,326 @@ +# Copied from https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/utilities/openapi.py # noqa: E501 +"""Utility functions for parsing an OpenAPI spec.""" +from __future__ import annotations +import json +import re +import warnings +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union +from urllib.parse import urljoin + +import requests +import yaml +from openapi_pydantic import OpenAPI, Response + +if TYPE_CHECKING: + from openapi_pydantic import (Components, Operation, Parameter, PathItem, Paths, + Reference, RequestBody, Schema) + + +class HTTPVerb(str, Enum): + """Enumerator of the HTTP verbs.""" + + GET = 'get' + PUT = 'put' + POST = 'post' + DELETE = 'delete' + OPTIONS = 'options' + HEAD = 'head' + PATCH = 'patch' + TRACE = 'trace' + + @classmethod + def from_str(cls, verb: str) -> HTTPVerb: + """Parse an HTTP verb.""" + try: + return cls(verb) + except ValueError: + raise ValueError(f'Invalid HTTP verb. Valid values are {cls.__members__}') + + +class OpenAPISpec(OpenAPI): + """OpenAPI Model that removes mis-formatted parts of the spec.""" + + # overriding overly restrictive type from parent class + openapi: str = '3.1.0' + + @property + def _paths_strict(self) -> Paths: + if not self.paths: + raise ValueError('No paths found in spec') + return self.paths + + def _get_path_strict(self, path: str) -> PathItem: + path_item = self._paths_strict.get(path) + if not path_item: + raise ValueError(f'No path found for {path}') + return path_item + + @property + def _components_strict(self) -> Components: + """Get components or err.""" + if self.components is None: + raise ValueError('No components found in spec. ') + return self.components + + @property + def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]: + """Get parameters or err.""" + parameters = self._components_strict.parameters + if parameters is None: + raise ValueError('No parameters found in spec. ') + return parameters + + @property + def _schemas_strict(self) -> Dict[str, Schema]: + """Get the dictionary of schemas or err.""" + schemas = self._components_strict.schemas + if schemas is None: + raise ValueError('No schemas found in spec. ') + return schemas + + @property + def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]: + """Get the request body or err.""" + request_bodies = self._components_strict.requestBodies + if request_bodies is None: + raise ValueError('No request body found in spec. ') + return request_bodies + + @property + def _responses_strict(self) -> Dict[str, Union[Response, Reference]]: + """Get the request body or err.""" + responses = self._components_strict.responses + if responses is None: + raise ValueError('No responses found in spec. ') + return responses + + def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]: + """Get a parameter (or nested reference) or err.""" + ref_name = ref.ref.split('/')[-1] + parameters = self._parameters_strict + if ref_name not in parameters: + raise ValueError(f'No parameter found for {ref_name}') + return parameters[ref_name] + + def _get_root_referenced_parameter(self, ref: Reference) -> Parameter: + """Get the root reference or err.""" + from openapi_pydantic import Reference + + parameter = self._get_referenced_parameter(ref) + while isinstance(parameter, Reference): + parameter = self._get_referenced_parameter(parameter) + return parameter + + def get_referenced_schema(self, ref: Reference) -> Schema: + """Get a schema (or nested reference) or err.""" + ref_name = ref.ref.split('/')[-1] + schemas = self._schemas_strict + if ref_name not in schemas: + raise ValueError(f'No schema found for {ref_name}') + return schemas[ref_name] + + def get_schema(self, schema: Union[Reference, Schema]) -> Schema: + from openapi_pydantic import Reference + + if isinstance(schema, Reference): + return self.get_referenced_schema(schema) + return schema + + def _get_root_referenced_schema(self, ref: Reference) -> Schema: + """Get the root reference or err.""" + from openapi_pydantic import Reference + + schema = self.get_referenced_schema(ref) + while isinstance(schema, Reference): + schema = self.get_referenced_schema(schema) + return schema + + def _get_referenced_request_body(self, ref: Reference + ) -> Optional[Union[Reference, RequestBody]]: + """Get a request body (or nested reference) or err.""" + ref_name = ref.ref.split('/')[-1] + request_bodies = self._request_bodies_strict + if ref_name not in request_bodies: + raise ValueError(f'No request body found for {ref_name}') + return request_bodies[ref_name] + + def _get_root_referenced_request_body(self, ref: Reference) -> Optional[RequestBody]: + """Get the root request Body or err.""" + from openapi_pydantic import Reference + + request_body = self._get_referenced_request_body(ref) + while isinstance(request_body, Reference): + request_body = self._get_referenced_request_body(request_body) + return request_body + + def _get_referenced_response(self, + ref: Reference) -> Optional[Union[Reference, Response]]: + """Get a response (or nested reference) or err.""" + ref_name = ref.ref.split('/')[-1] + responses = self._responses_strict + if ref_name not in responses: + raise ValueError(f'No responses found for {ref_name}') + return responses[ref_name] + + def _get_root_referenced_response(self, ref: Reference) -> Optional[Response]: + """Get the root response or err.""" + from openapi_pydantic import Reference + + response = self._get_referenced_response(ref) + while isinstance(response, Reference): + response = self._get_referenced_response(response) + return response + + @staticmethod + def _alert_unsupported_spec(obj: dict) -> None: + """Alert if the spec is not supported.""" + warning_message = (' This may result in degraded performance.' + + ' Convert your OpenAPI spec to 3.1.* spec' + + ' for better support.') + swagger_version = obj.get('swagger') + openapi_version = obj.get('openapi') + if isinstance(openapi_version, str): + if openapi_version != '3.1.0': + warnings.warn(f'Attempting to load an OpenAPI {openapi_version}' + f' spec. {warning_message}') + else: + pass + elif isinstance(swagger_version, str): + warnings.warn(f'Attempting to load a Swagger {swagger_version}' + f' spec. {warning_message}') + else: + raise ValueError('Attempting to load an unsupported spec:' + f'\n\n{obj}\n{warning_message}') + + @classmethod + def model_validate(cls, obj: dict) -> OpenAPISpec: + cls._alert_unsupported_spec(obj) + return super().model_validate(obj) + + @classmethod + def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec: + """Get an OpenAPI spec from a dict.""" + return cls.model_validate(spec_dict) + + @classmethod + def from_text(cls, text: str) -> OpenAPISpec: + """Get an OpenAPI spec from a text.""" + try: + spec_dict = json.loads(text) + except json.JSONDecodeError: + spec_dict = yaml.safe_load(text) + return cls.from_spec_dict(spec_dict) + + @classmethod + def from_file(cls, path: Union[str, Path]) -> OpenAPISpec: + """Get an OpenAPI spec from a file path.""" + path_ = path if isinstance(path, Path) else Path(path) + if not path_.exists(): + raise FileNotFoundError(f'{path} does not exist') + with path_.open('r') as f: + return cls.from_text(f.read()) + + @classmethod + def from_url(cls, url: str) -> OpenAPISpec: + """Get an OpenAPI spec from a URL.""" + response = requests.get(url) + spec = cls.from_text(response.text) + if spec.base_url == '/': + spec.servers[0].url = urljoin(url, '/') + return spec + + @property + def base_url(self) -> str: + """Get the base url.""" + return self.servers[0].url + + def get_methods_for_path(self, path: str) -> List[str]: + """Return a list of valid methods for the specified path.""" + from openapi_pydantic import Operation + + path_item = self._get_path_strict(path) + results = [] + for method in HTTPVerb: + operation = getattr(path_item, method.value, None) + if isinstance(operation, Operation): + results.append(method.value) + return results + + def get_parameters_for_path(self, path: str) -> List[Parameter]: + from openapi_pydantic import Reference + + path_item = self._get_path_strict(path) + parameters = [] + if not path_item.parameters: + return [] + for parameter in path_item.parameters: + if isinstance(parameter, Reference): + parameter = self._get_root_referenced_parameter(parameter) + parameters.append(parameter) + return parameters + + def get_operation(self, path: str, method: str) -> Operation: + """Get the operation object for a given path and HTTP method.""" + from openapi_pydantic import Operation + + path_item = self._get_path_strict(path) + operation_obj = getattr(path_item, method, None) + if not isinstance(operation_obj, Operation): + raise ValueError(f'No {method} method found for {path}') + return operation_obj + + def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]: + """Get the components for a given operation.""" + from openapi_pydantic import Reference + + parameters = [] + if operation.parameters: + for parameter in operation.parameters: + if isinstance(parameter, Reference): + parameter = self._get_root_referenced_parameter(parameter) + parameters.append(parameter) + return parameters + + def get_request_body_for_operation(self, + operation: Operation) -> Optional[RequestBody]: + """Get the request body for a given operation.""" + from openapi_pydantic import Reference + + request_body = operation.requestBody + if isinstance(request_body, Reference): + request_body = self._get_root_referenced_request_body(request_body) + return request_body + + def get_responses_for_operation(self, operation: Operation + ) -> Optional[Dict[str, Response]]: + """Get the responses for a given operation.""" + from openapi_pydantic import Reference + + responses = operation.responses + if responses is None: + return None + + results = {} + for k, response in responses.items(): + if isinstance(response, Reference): + response = self._get_root_referenced_response(response) + results[k] = response + + return results + + @staticmethod + def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str: + """Get a cleaned operation id from an operation id.""" + operation_id = operation.operationId + if operation_id is None: + # Replace all punctuation of any kind with underscore + path = re.sub(r'[^a-zA-Z0-9]', '_', path.lstrip('/')) + operation_id = f'{path}_{method}' + return operation_id.replace('-', '_').replace('.', '_').replace('/', '_') + + def iter_all_method(self) -> Iterator[Tuple[str, str]]: + for path in self._paths_strict: + for method in self.get_methods_for_path(path): + yield path, method diff --git a/agentlego/utils/parse.py b/agentlego/utils/parse.py new file mode 100644 index 00000000..39d3c70e --- /dev/null +++ b/agentlego/utils/parse.py @@ -0,0 +1,15 @@ +import re +from typing import Optional, Tuple + + +def parse_multi_float( + input_str: str, + number: Optional[int] = None, +) -> Tuple[float, ...]: + pattern = r'([-+]?\d*\.?\d+)' + matches = re.findall(pattern, input_str) + + if number is not None and len(matches) != number: + raise ValueError(f'Expected {number} numbers, got {input_str}.') + else: + return tuple(float(num) for num in matches) diff --git a/agentlego/version.py b/agentlego/version.py index 7ff53f71..504440c6 100644 --- a/agentlego/version.py +++ b/agentlego/version.py @@ -1,4 +1,4 @@ -__version__ = '0.1.2' +__version__ = '0.2.0' short_version = __version__ diff --git a/docs/en/collect_docs.py b/docs/en/collect_docs.py index 6ac2ec37..91a55ec3 100755 --- a/docs/en/collect_docs.py +++ b/docs/en/collect_docs.py @@ -20,8 +20,10 @@ - **name**: {name} - **description**: {description} -- **inputs**: {inputs} -- **outputs**: {outputs} +- **inputs**: +{inputs} +- **outputs**: +{outputs} ''' @@ -43,14 +45,27 @@ def format_tool_readme(path): content = contents[start:end] cls_name = content[0].strip('\n# ') from agentlego import tools - toolmeta = getattr(tools, cls_name).DEFAULT_TOOLMETA + from agentlego.schema import ToolMeta + toolmeta: ToolMeta = getattr(tools, cls_name).get_default_toolmeta() + inputs = [] + for p in toolmeta.inputs: + desc = f' - {p.name} ({p.type.__name__})' + if p.description: + desc += f': {p.description}' + inputs.append(desc) + outputs = [] + for p in toolmeta.outputs: + desc = f' - {p.type.__name__}' + if p.description: + desc += f': {p.description}' + outputs.append(desc) content.insert( 1, DEFAULT_TOOLMETA_TMPL.format( name=toolmeta.name, description=toolmeta.description, - inputs=', '.join(toolmeta.inputs), - outputs=', '.join(toolmeta.outputs), + inputs='\n'.join(inputs), + outputs='\n'.join(outputs), )) content.insert(1, AUTODOC_TMPL.format(cls_name=cls_name)) target = tmp_dir / 'tools' / (cls_name + '.md') diff --git a/docs/en/conf.py b/docs/en/conf.py index 36789346..c61e8798 100755 --- a/docs/en/conf.py +++ b/docs/en/conf.py @@ -110,8 +110,7 @@ def get_version(): 'css/readthedocs.css' ] html_js_files = [ - 'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.js', - 'js/custom.js' + 'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.js', 'js/custom.js' ] # -- Options for HTMLHelp output --------------------------------------------- @@ -155,8 +154,7 @@ def get_version(): # dir menu entry, description, category) texinfo_documents = [ (root_doc, 'agentlego', 'AgentLego Documentation', author, 'agentlego', - 'A versatile tool library for enhancing LLM-based agents', - 'Miscellaneous'), + 'A versatile tool library for enhancing LLM-based agents', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- @@ -195,8 +193,6 @@ def get_version(): intersphinx_mapping = { 'python': ('https://docs.python.org/3', None), 'numpy': ('https://numpy.org/doc/stable', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'mmengine': ('https://mmengine.readthedocs.io/en/latest/', None), } # Disable docstring inheritance diff --git a/docs/en/get_started.md b/docs/en/get_started.md index ded2491e..e0d5f46f 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -24,7 +24,7 @@ pip install agentlego[optional] openmim mim install mmpretrain mmdet mmpose easyocr # For image generation tools. -mim install transformers diffusers mmagic +pip install transformers diffusers ``` 3. Some tools requires extra dependencies, check the **Set up** section in `Tool APIs` before you want to use. @@ -60,7 +60,7 @@ print(calculator_tool('cos(pi / 6)')) # Image or Audio input supports multiple formats from PIL import Image -image_caption_tool = load_tool('ImageCaption', device='cuda') +image_caption_tool = load_tool('ImageDescription', device='cuda') img_path = './examples/demo.png' img_pil = Image.open(img_path) print(image_caption_tool(img_path)) @@ -77,15 +77,15 @@ efficiently build large language model(LLM) -based agents. Here is an example script to integrate agentlego tools to Lagent: ```python -from agentlego.apis import load_tool from lagent import ReAct, GPTAPI, ActionExecutor +from agentlego.tools import Calculator # Load the tools you want to use. -tool = load_tool('Calculator').to_lagent() +tools = [Calculator().to_lagent()] # Build Lagent Agent model = GPTAPI(temperature=0.) -agent = ReAct(llm=model, action_executor=ActionExecutor([tool])) +agent = ReAct(llm=model, action_executor=ActionExecutor(tools)) user_input = 'If the side lengths of a triangle are 3cm, 4cm and 5cm, please tell me the area of the triangle.' ret = agent.chat(user_input) @@ -106,27 +106,23 @@ users to start and customize applications. Here is an example script to integrate agentlego tools to LangChain: ```python -from agentlego.apis import load_tool -from langchain.agents import AgentType, initialize_agent -from langchain.chains.conversation.memory import ConversationBufferMemory -from langchain.chat_models import ChatOpenAI +from langchain import hub +from langchain.agents import create_structured_chat_agent, AgentExecutor +from langchain.memory import ConversationBufferMemory +from langchain_openai import ChatOpenAI +from agentlego.tools import Calculator # Load the tools you want to use. -tool = load_tool('Calculator').to_langchain() +tools = [Calculator().to_langchain()] # Build LangChain Agent -model = ChatOpenAI(temperature=0.) +llm = ChatOpenAI(temperature=0.) memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True) -agent = initialize_agent( - agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, - llm=model, - tools=[tool], - memory=memory, - verbose=True, -) +agent = create_structured_chat_agent(llm, tools, prompt=hub.pull("hwchase17/structured-chat-agent")) +agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True) user_input = 'If the side lengths of a triangle are 3cm, 4cm and 5cm, please tell me the area of the triangle.' -agent.run(input=user_input) +agent_executor.invoke(dict(input=user_input)) ``` ### Transformers Agent @@ -138,60 +134,20 @@ easy incorporation of additional community-developed tools. Here is an example script to integrate agentlego tools to Transformers agent: ```python -from agentlego.apis import load_tool from transformers import HfAgent +from agentlego.tools import Calculator # Load the tools you want to use. -tool = load_tool('Calculator').to_transformers_agent() +tools = [Calculator().to_transformers_agent()] # Build HuggingFace Transformers Agent prompt = open('examples/hf_agent/hf_demo_prompts.txt', 'r').read() agent = HfAgent( 'https://api-inference.huggingface.co/models/bigcode/starcoder', chat_prompt_template=prompt, - additional_tools=[tool], + additional_tools=tools, ) user_input = 'If the side lengths of a triangle are 3cm, 4cm and 5cm, please tell me the area of the triangle.' agent.chat(user_input) ``` - -# Tool Server - -AgentLego provides a suit of tool server utilities to help you deploy tools on a server and use it like local -tools on clients. - -## Start a server - -We provide a script `server.py` to start a tool server. You can specify the tool names you want to use. - -```bash -python server.py Calculator ImageCaption TextToImage -``` - -And then, the server will setup all tools and start. - -```bash -INFO: Started server process [1741344] -INFO: Waiting for application startup. -INFO: Application startup complete. -INFO: Uvicorn running on http://0.0.0.0:16180 (Press CTRL+C to quit) -``` - -## Use tools in client - -In the client, you can create a remote tool from the url of the tool server. - -```python -from agentlego.tools.remote import RemoteTool - -# Create all remote tools from a tool server root url. -tools = RemoteTool.from_server('http://127.0.0.1:16180') -for tool in tools: - print(tool.name, tool.url) - -# Create single remote tool from a tool server endpoint. -# All endpoint can be found in the docs of the tool server, like http://127.0.0.1:16180/docs -tool = RemoteTool('http://127.0.0.1:16180/ImageDescription') -print(tool.description) -``` diff --git a/docs/en/index.rst b/docs/en/index.rst index 3915f085..4e31e5c2 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -24,6 +24,7 @@ Welcome to the documentation of AgentLego! modules/apis.md modules/tool.md + modules/tool-server.md .. _Tool APIs: diff --git a/docs/en/modules/tool-server.md b/docs/en/modules/tool-server.md new file mode 100644 index 00000000..49eaec50 --- /dev/null +++ b/docs/en/modules/tool-server.md @@ -0,0 +1,88 @@ +# Tool Server + +AgentLego provides a suit of tool server utilities to help you deploy tools on a server and use it like local +tools on clients. + +## Start a server + +We provide a command-line tool `agentlego-server` to start a tool server. You can specify the tool names you want to use. + +```bash +agentlego-server start Calculator ImageDescription TextToImage +``` + +And then, the server will setup all tools and start. + +```bash +INFO: Started server process [1741344] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://127.0.0.1:16180 (Press CTRL+C to quit) +``` + +## Use tools in client + +In the client, you can create a remote tool from the url of the tool server. + +```python +from agentlego.tools.remote import RemoteTool + +# Create all remote tools from a tool server root url. +tools = RemoteTool.from_server('http://127.0.0.1:16180') +for tool in tools: + print(tool.name, tool.url) + +# Create single remote tool from a tool server endpoint. +# All endpoint can be found in the docs of the tool server, like http://127.0.0.1:16180/docs +tool = RemoteTool.from_url('http://127.0.0.1:16180/ImageDescription') +print(tool.description) +``` + +## How to Deploy Your Own Tools + +`agentlego-server` accepts additional tool modules, which means you don't need to modify the source code of `AgentLego`. You just need to write your tool source code in a Python file or module to deploy tools using `agentlego-server`. + +First, we create a Python file named `my_tool.py` + +```python +from agentlego.tools import BaseTool + +class Clock(BaseTool): + default_desc = 'Returns the current date and time.' + + def apply(self) -> str: + from datetime import datetime + return datetime.now().strftime('%Y/%m/%d %H:%M') + +class RandomNumber(BaseTool): + default_desc = 'Returns a random number not greater than `max`' + + def apply(self, max: int) -> int: + import random + return random.randint(0, max) +``` + +In this file, we defined two tools: `Clock` and `RandomNumber`. After saving the file, use the following command in the command line to check if `agentlego-server` can correctly read these two tools: + +```bash +# We use the --extra option to specify the additional tool source file +# Use the --no-official option to hide the built-in tools of AgentLego +agentlego-server list --extra ./my_tool.py --no-official +``` + +Getting the following output means that `agentlego-server` can read these tools + +``` +┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ +┃ Class ┃ source ┃ +┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ +│ Clock │ /home/my_tool.py │ +│ RandomNumber │ /home/my_tool.py │ +└──────────────┴──────────────────┘ +``` + +Start the tool server: + +```bash +agentlego-server start --extra ./my_tool.py Clock RandomNumber +``` diff --git a/docs/en/modules/tool.md b/docs/en/modules/tool.md index 0871b4bf..13c276a0 100644 --- a/docs/en/modules/tool.md +++ b/docs/en/modules/tool.md @@ -8,38 +8,22 @@ First, all tools should inherit the `BaseTool` class. As an example, assume we w ```python from agentlego.tools import BaseTool -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta class Clock(BaseTool): - def __init__(self): - toolmeta = ToolMeta( - name='Clock', - description='A clock that return the current date and time.', - inputs=[], - outputs=['text'], - ) - super().__init__(toolmeta=toolmeta, parser=DefaultParser) -``` + default_desc = 'A clock that return the current date and time.' -To initialize the tool, you need to construct a `ToolMeta` to specify the name, description, input arguments -categories and the output categories. The available categories are `text`, `image` and `audio` by now. + def apply(self) -> str: + from datetime import datetime + return datetime.now().strftime('%Y/%m/%d %H:%M') +``` -Then, you need also to specify a default parser, it's used to handle the input & output type. And usually, you -can directly use `DefaultParser` as the default parser. +In the class attribute, you need to specify a `default_desc` to specify the description of the tool. -Now, you can override the `setup` and `apply` method. The `setup` method will run when the tool is called at +Then, you can override the `setup` and `apply` method. The `setup` method will run when the tool is called at the first time, and it's usually used to lazy-load some heavy modules. And the `apply` method is the core method to perform when the tool is called. In this example, we only need to override the `apply` method. -```python -class Clock(BaseTool): - ... - - def apply(self): - from datetime import datetime - return datetime.now().strftime('%Y/%m/%d %H:%M') -``` +In the `apply` method, we need to use **Type hint** to specify the type of all inputs and outputs. We have already finished the tool, now you can instantiate it and use it in agent systems. @@ -48,15 +32,20 @@ We have already finished the tool, now you can instantiate it and use it in agen tool = Clock() # Use it in langchain -from langchain.agents import initialize_agent -from langchain.chat_models import ChatOpenAI - -agent = initialize_agent( - agent='structured-chat-zero-shot-react-description', - llm=ChatOpenAI(temperature=0.), +from langchain import hub +from langchain.agents import create_structured_chat_agent, AgentExecutor +from langchain_openai import ChatOpenAI + +# Be attention to specify `OPENAI_API_KEY` environment variable to call ChatGPT. +agent_executor = AgentExecutor( + agent=create_structured_chat_agent( + llm=ChatOpenAI(temperature=0.), + tools=[tool.to_langchain()], + prompt=hub.pull("hwchase17/structured-chat-agent") + ), tools=[tool.to_langchain()], verbose=True) -agent.invoke("What's the time?") +agent_executor.invoke(dict(input="What's the time?")) # Use it in lagent from lagent import ReAct, GPTAPI, ActionExecutor @@ -79,25 +68,16 @@ systems require the raw data to display at the front-end. Therefore, we use agent types as the input & output types of the tool, and use a `parser` to convert to the destination format automatically. -Assume we want a tool that can create an audio caption with the destination language on the input image. +Assume we want a tool that can create a caption audio with the destination language on the input image. ```python from agentlego.tools import BaseTool -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO, AudioIO class AudioCaption(BaseTool): - def __init__(self): - toolmeta = ToolMeta( - name='AudioCaption', - description='A tool that can create an audio caption on the input image with the specified language.', - inputs=['image', 'text'], - outputs=['audio'], - ) - super().__init__(toolmeta=toolmeta, parser=DefaultParser) - - def apply(self, image: ImageIO, language: str): + default_desc = 'A tool that can create an audio caption on the input image with the specified language.' + + def apply(self, image: ImageIO, language: str) -> AudioIO: # Convert the agent type to the format we need in the tool. image = image.to_pil() diff --git a/docs/zh_cn/collect_docs.py b/docs/zh_cn/collect_docs.py index 6ac2ec37..455ccc92 100755 --- a/docs/zh_cn/collect_docs.py +++ b/docs/zh_cn/collect_docs.py @@ -16,12 +16,14 @@ ''' DEFAULT_TOOLMETA_TMPL = ''' -## Default Tool Meta +## 默认工具信息 -- **name**: {name} -- **description**: {description} -- **inputs**: {inputs} -- **outputs**: {outputs} +- **名称**: {name} +- **描述**: {description} +- **输入**: +{inputs} +- **输出**: +{outputs} ''' @@ -43,14 +45,27 @@ def format_tool_readme(path): content = contents[start:end] cls_name = content[0].strip('\n# ') from agentlego import tools - toolmeta = getattr(tools, cls_name).DEFAULT_TOOLMETA + from agentlego.schema import ToolMeta + toolmeta: ToolMeta = getattr(tools, cls_name).get_default_toolmeta() + inputs = [] + for p in toolmeta.inputs: + desc = f' - {p.name} ({p.type.__name__})' + if p.description: + desc += f': {p.description}' + inputs.append(desc) + outputs = [] + for p in toolmeta.outputs: + desc = f' - {p.type.__name__}' + if p.description: + desc += f': {p.description}' + outputs.append(desc) content.insert( 1, DEFAULT_TOOLMETA_TMPL.format( name=toolmeta.name, description=toolmeta.description, - inputs=', '.join(toolmeta.inputs), - outputs=', '.join(toolmeta.outputs), + inputs='\n'.join(inputs), + outputs='\n'.join(outputs), )) content.insert(1, AUTODOC_TMPL.format(cls_name=cls_name)) target = tmp_dir / 'tools' / (cls_name + '.md') diff --git a/docs/zh_cn/conf.py b/docs/zh_cn/conf.py index 78d17d41..d87c9399 100755 --- a/docs/zh_cn/conf.py +++ b/docs/zh_cn/conf.py @@ -110,8 +110,7 @@ def get_version(): 'css/readthedocs.css' ] html_js_files = [ - 'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.js', - 'js/custom.js' + 'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.js', 'js/custom.js' ] # -- Options for HTMLHelp output --------------------------------------------- @@ -155,8 +154,7 @@ def get_version(): # dir menu entry, description, category) texinfo_documents = [ (root_doc, 'agentlego', 'AgentLego Documentation', author, 'agentlego', - 'A versatile tool library for enhancing LLM-based agents', - 'Miscellaneous'), + 'A versatile tool library for enhancing LLM-based agents', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- @@ -195,8 +193,6 @@ def get_version(): intersphinx_mapping = { 'python': ('https://docs.python.org/3', None), 'numpy': ('https://numpy.org/doc/stable', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'mmengine': ('https://mmengine.readthedocs.io/en/latest/', None), } # Disable docstring inheritance diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md index d387bd22..72e0800c 100644 --- a/docs/zh_cn/get_started.md +++ b/docs/zh_cn/get_started.md @@ -24,7 +24,7 @@ pip install agentlego[optional] openmim pip install mmpretrain mmdet mmpose easyocr # 用于图像生成工具。 -pip install transformers diffusers mmagic +pip install transformers diffusers ``` 3. 某些工具需要额外的依赖项,在使用之前请查看 `Tool APIs` 中的 **Set up** 部分。 @@ -59,7 +59,7 @@ print(calculator_tool('cos(pi / 6)')) # 图像或音频输入支持多种格式 from PIL import Image -image_caption_tool = load_tool('ImageCaption', device='cuda') +image_caption_tool = load_tool('ImageDescription', device='cuda') img_path = './examples/demo.png' img_pil = Image.open(img_path) print(image_caption_tool(img_path)) @@ -75,15 +75,15 @@ print(image_caption_tool(img_pil)) 以下是一个示例脚本,将 agentlego 工具集成到 Lagent 中: ```python -from agentlego.apis import load_tool from lagent import ReAct, GPTAPI, ActionExecutor +from agentlego.tools import Calculator # 加载您想要使用的工具 -tool = load_tool('Calculator').to_lagent() +tools = [Calculator().to_lagent()] # 构建 Lagent 智能体 model = GPTAPI(temperature=0.) -agent = ReAct(llm=model, action_executor=ActionExecutor([tool])) +agent = ReAct(llm=model, action_executor=ActionExecutor(tools)) user_input = '如果三角形的边长分别为 3cm、4cm 和 5cm,请告诉我三角形的面积。' ret = agent.chat(user_input) @@ -101,27 +101,23 @@ for step in ret.inner_steps[1:]: 以下是一个示例脚本,将 agentlego 工具集成到 LangChain 中: ```python -from agentlego.apis import load_tool -from langchain.agents import AgentType, initialize_agent -from langchain.chains.conversation.memory import ConversationBufferMemory -from langchain.chat_models import ChatOpenAI +from langchain import hub +from langchain.agents import create_structured_chat_agent, AgentExecutor +from langchain.memory import ConversationBufferMemory +from langchain_openai import ChatOpenAI +from agentlego.tools import Calculator # 加载要使用的工具。 -tool = load_tool('Calculator').to_langchain() +tools = [Calculator().to_langchain()] # 构建 LangChain 智能体链 -model = ChatOpenAI(temperature=0., model='gpt-4') +llm = ChatOpenAI(temperature=0.) memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True) -agent = initialize_agent( - agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, - llm=model, - tools=[tool], - memory=memory, - verbose=True, -) +agent = create_structured_chat_agent(llm, tools, prompt=hub.pull("hwchase17/structured-chat-agent")) +agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True) user_input = '如果三角形的边长分别为3cm、4cm和5cm,请用工具计算三角形的面积。' -agent.run(input=user_input) +agent_executor.invoke(dict(input=user_input)) ``` ### Transformers Agent @@ -131,59 +127,20 @@ agent.run(input=user_input) 以下是一个示例脚本,将 agentlego 工具集成到 Transformers agent 中: ```python -from agentlego.apis import load_tool from transformers import HfAgent +from agentlego.tools import Calculator # 加载要使用的工具 -tool = load_tool('Calculator').to_transformers_agent() +tools = [Calculator().to_transformers_agent()] # 构建 Transformers Agent prompt = open('examples/hf_agent/hf_demo_prompts.txt', 'r').read() agent = HfAgent( 'https://api-inference.huggingface.co/models/bigcode/starcoder', chat_prompt_template=prompt, - additional_tools=[tool], + additional_tools=tools, ) user_input = '如果三角形的边长分别为3厘米、4厘米和5厘米,请告诉我三角形的面积。' agent.chat(user_input) ``` - -# 工具服务器 - -AgentLego 提供了一套工具服务器辅助程序,帮助您在服务器上部署工具,并在客户端上像使用本地工具一样调用这些工具。 - -## 启动服务器 - -我们提供了一个 `server.py` 脚本来启动工具服务器。您可以指定要启动的工具类别。 - -```bash -python server.py Calculator ImageCaption TextToImage -``` - -然后,服务器将启动所有工具。 - -```bash -INFO: Started server process [1741344] -INFO: Waiting for application startup. -INFO: Application startup complete. -INFO: Uvicorn running on http://0.0.0.0:16180 (Press CTRL+C to quit) -``` - -## 在客户端使用工具 - -在客户端,您可以使用工具服务器的 URL 创建所有远程工具。 - -```python -from agentlego.tools.remote import RemoteTool - -# 从工具服务器 URL 创建所有远程工具。 -tools = RemoteTool.from_server('http://127.0.0.1:16180') -for tool in tools: - print(tool.name, tool.url) - -# 从工具服务器端点创建单个远程工具。 -# 所有端点都可以在工具服务器的文档中找到,例如 http://127.0.0.1:16180/docs -tool = RemoteTool('http://127.0.0.1:16180/ImageDescription') -print(tool.description) -``` diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 0a2e29e8..62eb6a40 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -23,6 +23,7 @@ modules/apis.md modules/tool.md + modules/tool-server.md .. _Tool APIs: diff --git a/docs/zh_cn/modules/tool-server.md b/docs/zh_cn/modules/tool-server.md new file mode 100644 index 00000000..bd5ea286 --- /dev/null +++ b/docs/zh_cn/modules/tool-server.md @@ -0,0 +1,89 @@ +# 工具服务器 + +AgentLego 提供了一套工具服务器辅助程序,帮助您在服务器上部署工具,并在客户端上像使用本地工具一样调用这些工具。 + +## 启动服务器 + +我们提供了一个命令行工具 `agentlego-server` 来启动工具服务器。您可以指定要启动的工具类别。 + +```bash +agentlego-server start Calculator ImageDescription TextToImage +``` + +然后,服务器将启动所有工具。 + +```bash +INFO: Started server process [1741344] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://127.0.0.1:16180 (Press CTRL+C to quit) +``` + +## 在客户端使用工具 + +在客户端,您可以使用工具服务器的 URL 创建所有远程工具。 + +```python +from agentlego.tools.remote import RemoteTool + +# 从工具服务器 URL 创建所有远程工具。 +tools = RemoteTool.from_server('http://127.0.0.1:16180') +for tool in tools: + print(tool.name, tool.url) + +# 从工具服务器端点创建单个远程工具。 +# 所有端点都可以在工具服务器的文档中找到,例如 http://127.0.0.1:16180/docs +tool = RemoteTool.from_url('http://127.0.0.1:16180/ImageDescription') +print(tool.description) +``` + +## 如何部署自己的工具 + +`agentlego-server` 接受额外的工具模块,这意味着你不需要修改 `AgentLego` 的源码,只需要在一个 Python 文件或者模 +块里编写你的工具源码,即可使用 `agentlego-server` 部署工具。 + +首先,我们新建一个 Python 文件,名称为 `my_tool.py` + +```python +from agentlego.tools import BaseTool + +class Clock(BaseTool): + default_desc = '返回当前日期和时间的时钟。' + + def apply(self) -> str: + from datetime import datetime + return datetime.now().strftime('%Y/%m/%d %H:%M') + +class RandomNumber(BaseTool): + default_desc = '返回一个不大于 `max` 的随机数' + + def apply(self, max: int) -> int: + import random + return random.randint(0, max) +``` + +在这个文件中,我们定义了两个工具 `Clock` 和 `RandomNumber`,保存文件之后,在命令行中,使用如下命令,检查 +`agentlego-server` 是否能够正确读取这两个工具: + +```bash +# 我们使用 --extra 选项,指定额外的工具源码文件 +# 使用 --no-official 选项隐藏 AgentLego 内置的工具 +agentlego-server list --extra ./my_tool.py --no-official +``` + +获得如下输出,说明 `agentlego-server` 能够读取这些工具 + +``` +┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ +┃ Class ┃ source ┃ +┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ +│ Clock │ /home/my_tool.py │ +│ RandomNumber │ /home/my_tool.py │ +└──────────────┴──────────────────┘ +``` + +启动工具服务器: + +```bash +agentlego-server start --extra ./my_tool.py Clock RandomNumber +``` diff --git a/docs/zh_cn/modules/tool.md b/docs/zh_cn/modules/tool.md index 58509c5d..a94a8d51 100644 --- a/docs/zh_cn/modules/tool.md +++ b/docs/zh_cn/modules/tool.md @@ -8,51 +8,42 @@ AgentLego 是可扩展的,您可以轻松地添加自定义工具并将其应 ```python from agentlego.tools import BaseTool -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta class Clock(BaseTool): - def __init__(self): - toolmeta = ToolMeta( - name='Clock', - description='返回当前日期和时间的时钟。', - inputs=[], - outputs=['text'], - ) - super().__init__(toolmeta=toolmeta, parser=DefaultParser) -``` + default_desc = '返回当前日期和时间的时钟。' -在初始化方法中,您需要构建一个 `ToolMeta` 来指定名称、描述、输入参数类别和输出类别。目前可用的类别有 `text`、`image` 和 `audio`。 + def apply(self) -> str: + from datetime import datetime + return datetime.now().strftime('%Y/%m/%d %H:%M') +``` -然后,还需要指定一个默认解析器 (parser),它用于处理输入和输出类型。通常情况下,您可以直接使用 `DefaultParser` 作为默认解析器。 +在类的属性中,您需要用一个 `default_desc` 来指定工具的默认描述。 之后,可以重载 `BaseTool` 的 `setup` 和 `apply` 方法。`setup` 方法将在第一次调用工具时运行,通常用于延迟加载一些重型模块。`apply` 方法是在调用工具时执行的核心方法。在这个示例中,我们只需重载 `apply` 方法。 -```python -class Clock(BaseTool): - ... - - def apply(self): - from datetime import datetime - return datetime.now().strftime('%Y/%m/%d %H:%M') -``` +在 `apply` 方法中,我们需要使用**类型注解** (Type hint) 的方式指定输入输出的类型。 我们已经完成了这个工具,现在您可以实例化它并在智能体系统中使用。 ```python -# 创建一个实例 +# 创建一个工具实例 tool = Clock() # 在 langchain 中使用 -from langchain.agents import initialize_agent -from langchain.chat_models import ChatOpenAI - -agent = initialize_agent( - agent='structured-chat-zero-shot-react-description', - llm=ChatOpenAI(temperature=0.), +from langchain import hub +from langchain.agents import create_structured_chat_agent, AgentExecutor +from langchain_openai import ChatOpenAI + +# 注意在环境变量中设定 OPENAI_API_KEY 以调用 ChatGPT +agent_executor = AgentExecutor( + agent=create_structured_chat_agent( + llm=ChatOpenAI(temperature=0.), + tools=[tool.to_langchain()], + prompt=hub.pull("hwchase17/structured-chat-agent") + ), tools=[tool.to_langchain()], verbose=True) -agent.invoke("现在几点了?") +agent_executor.invoke(dict(input="现在几点了?")) # 在 lagent 中使用 from lagent import ReAct, GPTAPI, ActionExecutor @@ -72,25 +63,16 @@ AgentLego 的一个核心特性是支持多模态工具,同时我们也需要 因此,我们使用代理类型作为工具的输入和输出类型,并使用一个`parser`自动将其转换为目标格式。 -假设我们要实现一个工具,它可以使用输入图像,用指定语言生成一个音频概述。 +假设我们要实现一个工具,它可以使用输入图像,用指定语言生成一段概述音频。 ```python from agentlego.tools import BaseTool -from agentlego.parsers import DefaultParser -from agentlego.schema import ToolMeta from agentlego.types import ImageIO, AudioIO class AudioCaption(BaseTool): - def __init__(self): - toolmeta = ToolMeta( - name='AudioCaption', - description='一个可以根据输入图像和指定语言,生成概要音频的工具。', - inputs=['image', 'text'], - outputs=['audio'], - ) - super().__init__(toolmeta=toolmeta, parser=DefaultParser) - - def apply(self, image: ImageIO, language: str): + default_desc = '一个可以根据输入图像和指定语言,生成概要音频的工具。' + + def apply(self, image: ImageIO, language: str) -> AudioIO: # 将代理类型转换为工具中所需的格式。 image = image.to_pil() diff --git a/examples/hf_agent/hf_agent_example.py b/examples/hf_agent/hf_agent_example.py index 27a3efba..137796db 100644 --- a/examples/hf_agent/hf_agent_example.py +++ b/examples/hf_agent/hf_agent_example.py @@ -9,14 +9,13 @@ tools = [ load_tool(tool_type).to_transformers_agent() for tool_type in [ - 'ImageCaption', + 'ImageDescription', 'TextToSpeech', ] ] agent = HfAgent( 'https://api-inference.huggingface.co/models/bigcode/starcoder', - chat_prompt_template=(Path(__file__).parent / - 'hf_demo_prompts.txt').read_text(), + chat_prompt_template=(Path(__file__).parent / 'hf_demo_prompts.txt').read_text(), additional_tools=tools, ) diff --git a/examples/hf_agent/hf_agent_notebook.ipynb b/examples/hf_agent/hf_agent_notebook.ipynb index f0160a61..ff253b41 100644 --- a/examples/hf_agent/hf_agent_notebook.ipynb +++ b/examples/hf_agent/hf_agent_notebook.ipynb @@ -21,7 +21,7 @@ "from agentlego.apis import load_tool\n", "tools = [\n", " load_tool(tool_type, device='cuda').to_transformers_agent()\n", - " for tool_type in ['ImageCaption', 'TextToSpeech']\n", + " for tool_type in ['ImageDescription', 'TextToSpeech']\n", "]" ] }, diff --git a/examples/lagent_example.py b/examples/lagent_example.py index a8572fdb..82d89c41 100644 --- a/examples/lagent_example.py +++ b/examples/lagent_example.py @@ -30,9 +30,6 @@ def main(): max_turn=3, action_executor=ActionExecutor(actions=tools), ) - system = chatbot._protocol.format([], [], - chatbot._action_executor)[0]['content'] - print(f'\033[92mSystem\033[0m:\n{system}') while True: try: @@ -43,14 +40,12 @@ def main(): if user == 'exit': exit(0) - try: - chatbot.chat(user) - finally: - for history in chatbot._inner_history[1:]: - if history['role'] == 'system': - print(f"\033[92mSystem\033[0m:{history['content']}") - elif history['role'] == 'assistant': - print(f"\033[92mBot\033[0m:\n{history['content']}") + result = chatbot.chat(user) + for history in result.inner_steps: + if history['role'] == 'system': + print(f"\033[92mSystem\033[0m:{history['content']}") + elif history['role'] == 'assistant': + print(f"\033[92mBot\033[0m:\n{history['content']}") if __name__ == '__main__': diff --git a/examples/langchain_example.py b/examples/langchain_example.py index 4711be7c..c10547b1 100644 --- a/examples/langchain_example.py +++ b/examples/langchain_example.py @@ -1,8 +1,9 @@ import argparse -from langchain.agents import AgentType, initialize_agent -from langchain.chains.conversation.memory import ConversationBufferMemory -from langchain.chat_models import ChatOpenAI +from langchain import hub +from langchain.agents import AgentExecutor, create_structured_chat_agent +from langchain.memory import ConversationBufferMemory +from langchain_openai import ChatOpenAI from prompt_toolkit import ANSI, prompt from agentlego.apis import load_tool @@ -27,16 +28,14 @@ def main(): tools = [load_tool(tool_type).to_langchain() for tool_type in args.tools] # set OPEN_API_KEY in your environment or directly pass it with key='' llm = ChatOpenAI(temperature=0, model=args.model) - memory = ConversationBufferMemory( - memory_key='chat_history', return_messages=True) + memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True) - agent = initialize_agent( - tools, + agent = create_structured_chat_agent( llm, - agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, - verbose=True, - memory=memory, + tools, + prompt=hub.pull('hwchase17/structured-chat-agent') ) + agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True) while True: try: @@ -46,7 +45,8 @@ def main(): continue if user == 'exit': exit(0) - print(f'\033[91m{args.model}\033[0m:', agent.run(input=user)) + res = agent_executor.invoke(dict(input=user)) + print(f'\033[91m{args.model}\033[0m: {res["output"]}') if __name__ == '__main__': diff --git a/examples/remote_example.py b/examples/remote_example.py index 8ab7fcb9..b0418666 100644 --- a/examples/remote_example.py +++ b/examples/remote_example.py @@ -30,8 +30,7 @@ def main(): max_turn=3, action_executor=ActionExecutor(actions=tools), ) - system = chatbot._protocol.format([], [], - chatbot._action_executor)[0]['content'] + system = chatbot._protocol.format([], [], chatbot._action_executor)[0]['content'] print(f'\033[92mSystem\033[0m:\n{system}') while True: diff --git a/examples/streamlit_demo.py b/examples/streamlit_demo.py index 32802457..5a1abfc6 100644 --- a/examples/streamlit_demo.py +++ b/examples/streamlit_demo.py @@ -144,7 +144,7 @@ def process_result_text(text: str) -> str: user=('Please describe the above image, ' 'and draw a similar image in anime style.'), files=[load_image(rootdir / 'examples/demo.png')], - tools=['VisualQuestionAnswering', 'TextToImage'], + tools=['VQA', 'TextToImage'], ) ] @@ -222,12 +222,10 @@ def get_model(model_name): def init_chatbot(model_name, tools): - logger.info(f'Init {model_name} with: ' + - ', '.join([tool.name for tool in tools])) + logger.info(f'Init {model_name} with: ' + ', '.join([tool.name for tool in tools])) chatbot = ReAct( llm=get_model(model_name), - action_executor=ActionExecutor( - actions=[tool.to_lagent() for tool in tools]), + action_executor=ActionExecutor(actions=[tool.to_lagent() for tool in tools]), protocol=ReActProtocol(), max_turn=10, ) @@ -267,8 +265,7 @@ def update_api_key(): return model.keys = [state.api_key] - st.sidebar.text_input( - 'API key', key='api_key', on_change=update_api_key) + st.sidebar.text_input('API key', key='api_key', on_change=update_api_key) st.sidebar.multiselect( 'Tools', @@ -305,8 +302,7 @@ def click_example(example: dict): clear_session() files = [] for file in example['files']: - from streamlit.runtime.uploaded_file_manager import \ - UploadedFile + from streamlit.runtime.uploaded_file_manager import UploadedFile files.append(UploadedFile(file, None)) state.files = files for tool in example['tools']: @@ -392,10 +388,7 @@ def render_action(action: ActionReturn, retry: Tuple[int, int] = None): st.markdown(process_result_text(action.result['text'])) for image in action.result.get('image', []): w, h = Image.open(image).size - st.image( - image, - caption='Generated Image', - width=int(350 / h * w)) + st.image(image, caption='Generated Image', width=int(350 / h * w)) for audio in action.result.get('audio', []): st.audio(audio) elif action.errmsg: @@ -494,8 +487,7 @@ def start_chat(): st.error(repr(e)) logger.info(f'Error: {repr(e)}') - state.history.append( - dict(role='assistant', content=copy.deepcopy(responses))) + state.history.append(dict(role='assistant', content=copy.deepcopy(responses))) state.disable_chat = False state.disable_clear = False st.rerun() diff --git a/examples/visual_chatgpt/visual_chatgpt.py b/examples/visual_chatgpt/visual_chatgpt.py index e9dbd536..7de06a3e 100644 --- a/examples/visual_chatgpt/visual_chatgpt.py +++ b/examples/visual_chatgpt/visual_chatgpt.py @@ -149,18 +149,17 @@ class ConversationBot: def __init__(self, load_dict): # load_dict = { # 'OCRTool':'cuda:0', - # 'ImageCaption':'cuda:1',...} + # 'ImageDescription':'cuda:1',...} print(f'Initializing VisualChatGPT, load_dict={load_dict}') - if 'ImageCaption' not in load_dict: - raise ValueError('You have to load ImageCaption as a ' + if 'ImageDescription' not in load_dict: + raise ValueError('You have to load ImageDescription as a ' 'basic function for VisualChatGPT') self.models = {} # Load tools for class_name, device in load_dict.items(): - self.models[class_name] = load_tool( - class_name, device=device).to_langchain() + self.models[class_name] = load_tool(class_name, device=device).to_langchain() print(f'All the Available Functions: {self.models}') @@ -225,9 +224,8 @@ def run_image(self, image, state, txt, lang): img = img.resize((width_new, height_new)) img = img.convert('RGB') img.save(image_filename, 'PNG') - print( - f'Resize image form {width}x{height} to {width_new}x{height_new}') - description = self.models['ImageCaption'](image_filename) + print(f'Resize image form {width}x{height} to {width_new}x{height_new}') + description = self.models['ImageDescription'](image_filename) if lang == 'Chinese': Human_prompt = (f'\nHuman: 提供一张名为 {image_filename}的图片。' f'它的描述是: {description}。 ' @@ -248,9 +246,7 @@ def run_image(self, image, state, txt, lang): AI_prompt = 'Received. ' self.agent.memory.buffer = self.agent.memory.buffer + \ Human_prompt + 'AI: ' + AI_prompt - state = state + [ - (f'![](file={image_filename})*{image_filename}*', AI_prompt) - ] + state = state + [(f'![](file={image_filename})*{image_filename}*', AI_prompt)] return state, state, f'{txt} {image_filename} ' @@ -258,7 +254,7 @@ def run_image(self, image, state, txt, lang): if not os.path.exists('checkpoints'): os.mkdir('checkpoints') parser = argparse.ArgumentParser() - parser.add_argument('--load', type=str, default='ImageCaption_cuda:0') + parser.add_argument('--load', type=str, default='ImageDescription_cuda:0') args = parser.parse_args() load_dict = { e.split('_')[0].strip(): e.split('_')[1].strip() @@ -266,8 +262,7 @@ def run_image(self, image, state, txt, lang): } bot = ConversationBot(load_dict=load_dict) with gr.Blocks(css='#chatbot .overflow-y-auto{height:500px}') as demo: - lang = gr.Radio( - choices=['Chinese', 'English'], value=None, label='Language') + lang = gr.Radio(choices=['Chinese', 'English'], value=None, label='Language') chatbot = gr.Chatbot(elem_id='chatbot', label='Visual ChatGPT') state = gr.State([]) with gr.Row(visible=False) as input_raws: @@ -284,8 +279,7 @@ def run_image(self, image, state, txt, lang): lang.change(bot.init_agent, [lang], [input_raws, lang, txt, clear]) txt.submit(bot.run_text, [txt, state], [chatbot, state]) txt.submit(lambda: '', None, txt) - btn.upload(bot.run_image, [btn, state, txt, lang], - [chatbot, state, txt]) + btn.upload(bot.run_image, [btn, state, txt, lang], [chatbot, state, txt]) clear.click(bot.memory.clear) clear.click(lambda: [], None, chatbot) clear.click(lambda: [], None, state) diff --git a/pyproject.toml b/pyproject.toml index c90543c1..a5ace3a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools >= 62.6"] +requires = ["setuptools >= 62.6, < 64.0"] build-backend = "setuptools.build_meta" [project] @@ -24,6 +24,9 @@ dynamic = ["version", "readme", "dependencies", "optional-dependencies"] Documentation = "https://agentlego.readthedocs.io" Repository = "https://github.com/InternLM/agentlego" +[project.scripts] +agentlego-server = "agentlego.server.server:cli" + [tool.setuptools.packages.find] where = ["."] include = ["agentlego*"] @@ -42,10 +45,11 @@ based_on_style = "pep8" blank_line_before_nested_class_or_def = true split_before_expression_after_opening_paren = true split_penalty_import_names = 0 -SPLIT_PENALTY_AFTER_OPENING_BRACKET = 800 +split_penalty_after_opening_bracket = 800 +column_limit = 89 [tool.isort] -line_length = 79 +line_length = 89 multi_line_output = 0 known_first_party = "agentlego" no_lines_before = ["STDLIB", "LOCALFOLDER"] diff --git a/requirements/runtime.txt b/requirements/runtime.txt index a1d5573f..cab0a984 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,10 +1,14 @@ +addict func_timeout -mmengine>=0.8 numpy -opencv-python +openapi-pydantic packaging Pillow prompt_toolkit +pydantic>=2.0 +pyyaml requests +rich thefuzz tqdm +typing-extensions diff --git a/requirements/server.txt b/requirements/server.txt index 8fede1a1..acd357ac 100644 --- a/requirements/server.txt +++ b/requirements/server.txt @@ -1,4 +1,5 @@ fastapi +makefun python-multipart -typing-extensions +typer>=0.9 uvicorn[standard] diff --git a/server.py b/server.py deleted file mode 100644 index 16586cca..00000000 --- a/server.py +++ /dev/null @@ -1,175 +0,0 @@ -import argparse -import base64 -import inspect -from io import BytesIO -from typing import Dict -from urllib.parse import quote_plus - -import uvicorn -from fastapi import APIRouter, FastAPI, File, Form, UploadFile -from typing_extensions import Annotated - -from agentlego.apis import load_tool -from agentlego.parsers import NaiveParser -from agentlego.tools.base import BaseTool -from agentlego.types import AudioIO, CatgoryToIO, ImageIO - -prog_description = """\ -Start a server for several tools. -""" - - -def parse_args(): - parser = argparse.ArgumentParser(description=prog_description) - parser.add_argument( - 'tools', - type=str, - nargs='+', - help='The tools to deploy', - ) - parser.add_argument( - '--port', - default=16180, - type=int, - help='The port number', - ) - parser.add_argument( - '--device', - default='cuda:0', - type=str, - help='The device to deploy the tools', - ) - parser.add_argument( - '--no-setup', - action='store_true', - help='Avoid setup tools during starting the server.', - ) - args = parser.parse_args() - return args - - -args = parse_args() -tools: Dict[str, BaseTool] = {} -for name in args.tools: - tool = load_tool(name, device=args.device, parser=NaiveParser) - if not args.no_setup: - tool.setup() - tool._is_setup = True - tools[quote_plus(tool.name.replace(' ', ''))] = tool - -app = FastAPI() -tool_router = APIRouter() - - -@app.get('/') -def index(): - response = [] - for tool_name, tool in tools.items(): - response.append( - dict( - domain=tool_name, - toolmeta=tool.toolmeta.__dict__, - parameters=[p.__dict__ for p in tool.parameters.values()], - )) - return response - - -def add_tool(tool_name: str): - tool: BaseTool = tools[tool_name] - - def _call(**kwargs): - args = {} - for p in tool.parameters.values(): - data = kwargs[p.name] - if p.category == 'image': - from PIL import Image - data = ImageIO(Image.open(data.file)) - elif p.category == 'audio': - import torchaudio - file_format = data.filename.rpartition('.')[-1] or None - raw, sr = torchaudio.load(data.file, format=file_format) - data = AudioIO(raw, sampling_rate=sr) - else: - data = CatgoryToIO[p.category](data) - args[p.name] = data - - outs = tool(**args) - if not isinstance(outs, tuple): - outs = [outs] - - res = [] - for out, out_category in zip(outs, tool.toolmeta.outputs): - if out_category == 'image': - file = BytesIO() - out.to_pil().save(file, format='png') - res.append( - dict( - type='image', - data=base64.encodebytes( - file.getvalue()).decode('ascii'), - )) - elif out_category == 'audio': - import torchaudio - file = BytesIO() - torchaudio.save( - file, out.to_tensor(), out.sampling_rate, format='wav') - res.append( - dict( - type='audio', - data=base64.encodebytes( - file.getvalue()).decode('ascii'), - )) - else: - res.append(out) - return res - - def call(**kwargs): - try: - return _call(**kwargs) - except Exception as e: - return dict(error=repr(e)) - - call_args = {} - call_params = [] - for p in tool.parameters.values(): - if p.category in ['image', 'audio']: - annotation = Annotated[UploadFile, File(media_type=p.category)] - else: - type_ = { - 'text': str, - 'int': int, - 'bool': bool, - 'float': float - }[p.category] - annotation = Annotated[type_, Form()] - - call_args[p.name] = annotation - call_params.append( - inspect.Parameter( - p.name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=p.default if p.optional else inspect._empty, - annotation=annotation, - )) - call.__signature__ = inspect.Signature(call_params) - call.__annotations__ = call_args - tool_router.add_api_route( - f'/{tool_name}/call', - endpoint=call, - methods=['POST'], - ) - tool_router.add_api_route( - f'/{tool_name}/meta', - endpoint=lambda: dict( - toolmeta=tool.toolmeta.__dict__, - parameters=[p.__dict__ for p in tool.parameters.values()]), - methods=['GET'], - ) - - -for tool_name in tools: - add_tool(tool_name) -app.include_router(tool_router) - -if __name__ == '__main__': - uvicorn.run(app, host='0.0.0.0', port=args.port) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..01e211aa --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[flake8] +exclude = webui +max-line-length = 89 diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..60684932 --- /dev/null +++ b/setup.py @@ -0,0 +1,3 @@ +from setuptools import setup + +setup() diff --git a/tests/test_apis.py b/tests/test_apis.py index 52151fa4..bb325854 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -8,7 +8,7 @@ def test_load_tool(): assert isinstance(tool, Calculator) # description will be overwrite - tool = load_tool('ImageCaption', description='custom') + tool = load_tool('ImageDescription', description='custom') assert 'custom' in tool.toolmeta.description # cached tool diff --git a/tests/test_tools/test_basetool.py b/tests/test_tools/test_basetool.py index 0efc51fa..f4ce46e8 100644 --- a/tests/test_tools/test_basetool.py +++ b/tests/test_tools/test_basetool.py @@ -1,21 +1,16 @@ -from agentlego.parsers import DefaultParser -from agentlego.tools.base import BaseTool -from agentlego.types import ImageIO +from agentlego.tools import BaseTool +from agentlego.types import Annotated, ImageIO, Info class DummyTool(BaseTool): - DEFAULT_TOOLMETA = dict( - name='Dummy Tool', - description='This is a dummy tool. It takes an image and a text ' - 'as the inputs, and returns an image.', - inputs=('image', 'text'), - outputs=('image', ), - ) - - def __init__(self): - super().__init__(toolmeta=self.DEFAULT_TOOLMETA, parser=DefaultParser) - - def apply(self, image: ImageIO, query: str) -> ImageIO: + default_desc = 'This is a dummy tool.' + + def apply( + self, + image: ImageIO, + query: Annotated[str, Info('The query')], + option: bool = True, + ) -> Annotated[ImageIO, Info('The result image')]: return image @@ -24,14 +19,10 @@ def test_lagent(): tool = DummyTool().to_lagent() assert isinstance(tool, BaseAction) - expected_description = ( - 'This is a dummy tool. It takes an image and a text as the inputs, ' - 'and returns an image. Args: image (image path), query (text string) ' - 'Combine all args to one json string like {"image": xxx, "query": xxx}' - ) - assert tool.name == 'DummyTool' - assert tool.description == expected_description + assert tool.description['description'] == 'This is a dummy tool.' + assert tool.description['required'] == ['image', 'query'] + assert tool.description['parameters'][1]['description'] == 'The query' def test_hf_agent(): @@ -39,11 +30,12 @@ def test_hf_agent(): tool = DummyTool().to_transformers_agent() assert isinstance(tool, Tool) - expected_description = ( - 'This is a dummy tool. It takes an image and a text as ' - 'the inputs, and returns an image. Args: image (image), query (text)') + expected_description = '''\ +This is a dummy tool. +Args: image (image); query (str, The query); option (bool. Optional, Defaults to True) +Returns: image (The result image)''' - assert tool.name == 'agentlego_dummy_tool' + assert tool.name == 'agentlego_dummytool' assert tool.description == expected_description @@ -52,9 +44,10 @@ def test_langchain(): tool = DummyTool().to_langchain() assert isinstance(tool, StructuredTool) - expected_description = ( - 'Dummy Tool(image: str, query: str) - This is a dummy tool. ' - 'It takes an image and a text as the inputs, and returns an image.') + expected_description = '''\ +DummyTool(image: str, query: str, option: str = True) - This is a dummy tool. +Args: image (path); query (str, The query); option (bool. Optional, Defaults to True) +Returns: path (The result image)''' - assert tool.name == 'Dummy Tool' + assert tool.name == 'DummyTool' assert tool.description == expected_description diff --git a/tests/test_tools/test_image_canny/test_canny_to_image.py b/tests/test_tools/test_image_canny/test_canny_to_image.py deleted file mode 100644 index 1f646f0d..00000000 --- a/tests/test_tools/test_image_canny/test_canny_to_image.py +++ /dev/null @@ -1,30 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('mmagic'), reason='requires mmagic') -class TestCannyTextToImage(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'CannyTextToImage', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path, 'prompt') - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'CannyTextToImage', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img, 'prompt') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_canny/test_image_to_canny.py b/tests/test_tools/test_image_canny/test_image_to_canny.py deleted file mode 100644 index 2361f6b3..00000000 --- a/tests/test_tools/test_image_canny/test_image_to_canny.py +++ /dev/null @@ -1,27 +0,0 @@ -import os.path as osp - -import cv2 -import numpy as np -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -class TestImageToCanny(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ImageToCanny', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'ImageToCanny', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img) - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_depth/test_depth_to_image.py b/tests/test_tools/test_image_depth/test_depth_to_image.py deleted file mode 100644 index 5c6dea04..00000000 --- a/tests/test_tools/test_image_depth/test_depth_to_image.py +++ /dev/null @@ -1,27 +0,0 @@ -import os.path as osp - -import cv2 -import numpy as np -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -class TestDepthTextToImage(ToolTestCase): - - def test_all(self): - tool = load_tool( - 'DepthTextToImage', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path, 'prompt') - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'DepthTextToImage', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img, 'prompt') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_depth/test_image_to_depth.py b/tests/test_tools/test_image_depth/test_image_to_depth.py deleted file mode 100644 index 59bf0d28..00000000 --- a/tests/test_tools/test_image_depth/test_image_to_depth.py +++ /dev/null @@ -1,27 +0,0 @@ -import os.path as osp - -import cv2 -import numpy as np -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -class TestImageToDepth(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ImageToDepth', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'ImageToDepth', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img) - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_editing/test_image_extension.py b/tests/test_tools/test_image_editing/test_image_extension.py deleted file mode 100644 index fb69294e..00000000 --- a/tests/test_tools/test_image_editing/test_image_extension.py +++ /dev/null @@ -1,21 +0,0 @@ -from PIL import Image - -from agentlego import load_tool -from agentlego.parsers import HuggingFaceAgentParser, LangChainParser -from agentlego.testing import ToolTestCase - - -class TestImageExpansion(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ImageExpansion', parser=LangChainParser(), device='cpu') - img_path = 'tests/data/images/dog2.jpg' - res = tool(img_path, '2000x1000') - assert isinstance(res, str) - - img = Image.open(img_path) - tool = load_tool( - 'ImageExpansion', parser=HuggingFaceAgentParser(), device='cpu') - res = tool(img, '2000x1000') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_editing/test_image_stylization.py b/tests/test_tools/test_image_editing/test_image_stylization.py deleted file mode 100644 index dd4249fa..00000000 --- a/tests/test_tools/test_image_editing/test_image_stylization.py +++ /dev/null @@ -1,25 +0,0 @@ -from unittest import skipIf - -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('diffusers'), reason='requires diffusers') -class TestInstructPix2Pix(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ImageStylization', parser=LangChainParser(), device='cuda') - img_path = 'tests/data/images/dog.jpg' - res = tool(f'{img_path}, watercolor painting') - assert isinstance(res, str) - - img = Image.open(img_path) - tool = load_tool( - 'ImageStylization', parser=HuggingFaceAgentParser(), device='cpu') - res = tool(img, 'watercolor painting') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_editing/test_object_remove.py b/tests/test_tools/test_image_editing/test_object_remove.py deleted file mode 100644 index a680e7c0..00000000 --- a/tests/test_tools/test_image_editing/test_object_remove.py +++ /dev/null @@ -1,26 +0,0 @@ -from unittest import skipIf - -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf( - not is_installed('segment_anything'), reason='requires segment_anything') -class TestObjectRemove(ToolTestCase): - - def test_call(self): - img_path = 'tests/data/images/dog2.jpg' - tool = load_tool( - 'ObjectRemove', parser=LangChainParser(), device='cpu') - res = tool(img_path, 'dog') - assert isinstance(res, str) - - img = Image.open(img_path) - tool = load_tool( - 'ObjectRemove', parser=HuggingFaceAgentParser(), device='cpu') - res = tool(img, 'dog') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_editing/test_object_replace.py b/tests/test_tools/test_image_editing/test_object_replace.py deleted file mode 100644 index a88750d4..00000000 --- a/tests/test_tools/test_image_editing/test_object_replace.py +++ /dev/null @@ -1,26 +0,0 @@ -from unittest import skipIf - -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf( - not is_installed('segment_anything'), reason='requires segment_anything') -class TestObjectReplace(ToolTestCase): - - def test_call(self): - img_path = 'tests/data/images/dog2.jpg' - tool = load_tool( - 'ObjectReplace', parser=LangChainParser(), device='cpu') - res = tool(img_path, 'dog', 'a cartoon dog') - assert isinstance(res, str) - - img = Image.open(img_path) - tool = load_tool( - 'ObjectReplace', parser=HuggingFaceAgentParser(), device='cpu') - res = tool(img, 'dog', 'a cartoon dog') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_pose/test_facelandmark.py b/tests/test_tools/test_image_pose/test_facelandmark.py deleted file mode 100644 index 9076742d..00000000 --- a/tests/test_tools/test_image_pose/test_facelandmark.py +++ /dev/null @@ -1,32 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('mmpose'), reason='requires mmpose') -class TestHumanFaceLandmark(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'HumanFaceLandmark', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'HumanFaceLandmark', - parser=HuggingFaceAgentParser(), - device='cuda') - res = tool(img) - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_pose/test_image_to_pose.py b/tests/test_tools/test_image_pose/test_image_to_pose.py deleted file mode 100644 index 8a4d7a45..00000000 --- a/tests/test_tools/test_image_pose/test_image_to_pose.py +++ /dev/null @@ -1,30 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('mmpose'), reason='requires mmpose') -class TestHumanBodyPose(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'HumanBodyPose', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'HumanBodyPose', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img) - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_pose/test_pose_to_image.py b/tests/test_tools/test_image_pose/test_pose_to_image.py deleted file mode 100644 index fb9232b6..00000000 --- a/tests/test_tools/test_image_pose/test_pose_to_image.py +++ /dev/null @@ -1,30 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('mmagic'), reason='requires mmagic') -class TestPoseToImage(ToolTestCase): - - def test_all(self): - tool = load_tool( - 'PoseToImage', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path, 'prompt') - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'PoseToImage', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img, 'prompt') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_scribble/test_image_to_scribble.py b/tests/test_tools/test_image_scribble/test_image_to_scribble.py deleted file mode 100644 index b854cd33..00000000 --- a/tests/test_tools/test_image_scribble/test_image_to_scribble.py +++ /dev/null @@ -1,27 +0,0 @@ -import os.path as osp - -import cv2 -import numpy as np -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -class TestImageToScribble(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ImageToScribble', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'ImageToScribble', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img) - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_scribble/test_scribble_to_image.py b/tests/test_tools/test_image_scribble/test_scribble_to_image.py deleted file mode 100644 index 83936f8f..00000000 --- a/tests/test_tools/test_image_scribble/test_scribble_to_image.py +++ /dev/null @@ -1,29 +0,0 @@ -import os.path as osp - -import cv2 -import numpy as np -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -class TestScribbleTextToImage(ToolTestCase): - - def test_all(self): - tool = load_tool( - 'ScribbleTextToImage', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path, 'prompt') - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'ScribbleTextToImage', - parser=HuggingFaceAgentParser(), - device='cuda') - res = tool(img, 'prompt') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_image_text/test_image_to_text.py b/tests/test_tools/test_image_text/test_image_to_text.py deleted file mode 100644 index 3c596d32..00000000 --- a/tests/test_tools/test_image_text/test_image_to_text.py +++ /dev/null @@ -1,38 +0,0 @@ -from pathlib import Path - -import pytest - -from agentlego.parsers import NaiveParser -from agentlego.testing import setup_tool - -data_dir = Path(__file__).parents[2] / 'data' -test_img = (data_dir / 'images/dog.jpg').absolute() - - -@pytest.fixture() -def tool(): - from agentlego.tools import ImageCaption - return setup_tool(ImageCaption, device='cuda') - - -def test_call(tool): - tool.set_parser(NaiveParser) - res = tool(str(test_img)) - assert isinstance(res, str) - - -def test_hf_agent(tool, hf_agent): - tool = tool.to_transformers_agent() - hf_agent.prepare_for_new_chat() - hf_agent._toolbox = {tool.name: tool} - - out = hf_agent.chat(f'Please describe the image `{test_img}`.') - assert out is not None - - -def test_lagent(tool, lagent_agent): - lagent_agent.new_session([tool.to_lagent()]) - - out = lagent_agent.chat(f'Please describe the image `{test_img}`.') - assert out.actions[-1].valid == 1 - assert 'dog' in out.response diff --git a/tests/test_tools/test_image_text/test_text_to_image.py b/tests/test_tools/test_image_text/test_text_to_image.py deleted file mode 100644 index 91cfa02a..00000000 --- a/tests/test_tools/test_image_text/test_text_to_image.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest - -import agentlego.types as types -from agentlego.parsers import NaiveParser -from agentlego.testing import setup_tool - - -@pytest.fixture() -def tool(): - from agentlego.tools import TextToImage - return setup_tool(TextToImage, device='cuda') - - -def test_call(tool): - tool.set_parser(NaiveParser) - res = tool('generate an image of a cat') - assert isinstance(res, types.ImageIO) - - -def test_hf_agent(tool, hf_agent): - tool = tool.to_transformers_agent() - hf_agent.prepare_for_new_chat() - hf_agent._toolbox = {tool.name: tool} - - out = hf_agent.chat('generate an image of a cat') - assert out is not None - - -def test_lagent(tool, lagent_agent): - lagent_agent.new_session([tool.to_lagent()]) - - out = lagent_agent.chat('generate an image of a cat') - assert out.actions[-1].valid == 1 - assert '.png' in out.response diff --git a/tests/test_tools/test_imagebind/test_anything_to_image.py b/tests/test_tools/test_imagebind/test_anything_to_image.py deleted file mode 100644 index 1a44fcb4..00000000 --- a/tests/test_tools/test_imagebind/test_anything_to_image.py +++ /dev/null @@ -1,37 +0,0 @@ -from pathlib import Path - -from agentlego.parsers import NaiveParser -from agentlego.testing import setup_tool -from agentlego.tools import (AudioImageToImage, AudioTextToImage, AudioToImage, - ThermalToImage) -from agentlego.types import ImageIO - -data_dir = Path(__file__).parents[2] / 'data' - - -def test_audio_to_image(): - tool = setup_tool(AudioToImage, device='cuda') - tool.set_parser(NaiveParser) - res = tool(data_dir / 'audio/cat.wav') - assert isinstance(res, ImageIO) - - -def test_thermal_to_image(): - tool = setup_tool(ThermalToImage, device='cuda') - tool.set_parser(NaiveParser) - res = tool(data_dir / 'audio/030419.jpg') - assert isinstance(res, ImageIO) - - -def test_audio_image_to_image(): - tool = setup_tool(AudioImageToImage, device='cuda') - tool.set_parser(NaiveParser) - res = tool(data_dir / 'images/dog.jpg', data_dir / 'audio/cat.wav') - assert isinstance(res, ImageIO) - - -def test_audio_text_to_image(): - tool = setup_tool(AudioTextToImage, device='cuda') - tool.set_parser(NaiveParser) - res = tool(data_dir / 'audio/cat.wav', 'generate a cat flying in the sky') - assert isinstance(res, ImageIO) diff --git a/tests/test_tools/test_object_detection/test_object_detection.py b/tests/test_tools/test_object_detection/test_object_detection.py deleted file mode 100644 index 261af816..00000000 --- a/tests/test_tools/test_object_detection/test_object_detection.py +++ /dev/null @@ -1,32 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('mmdet'), reason='requires mmdet') -class TestObjectDetectionTool(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ObjectDetectionTool', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - tool = load_tool( - 'ObjectDetectionTool', - parser=HuggingFaceAgentParser(), - device='cuda') - img = Image.fromarray(img) - res = tool(img) - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_object_detection/test_text_to_bbox.py b/tests/test_tools/test_object_detection/test_text_to_bbox.py deleted file mode 100644 index c5ca6784..00000000 --- a/tests/test_tools/test_object_detection/test_text_to_bbox.py +++ /dev/null @@ -1,29 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('mmdet'), reason='requires mmdet') -class TestTextToBbox(ToolTestCase): - - def test_call(self): - tool = load_tool('TextToBbox', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(f'{img_path}, man') - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'TextToBbox', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img, 'man') - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_ocr/test_ocr.py b/tests/test_tools/test_ocr/test_ocr.py deleted file mode 100644 index a8f74b55..00000000 --- a/tests/test_tools/test_ocr/test_ocr.py +++ /dev/null @@ -1,39 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import HuggingFaceAgentParser, LangChainParser - - -@skipIf(not is_installed('mmocr'), reason='requires mmocr') -class TestOCR(ToolTestCase): - - def test_call(self): - tool = load_tool('OCR', parser=LangChainParser(), device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool('OCR', parser=HuggingFaceAgentParser(), device='cuda') - res = tool(img) - assert isinstance(res, list) - - -@skipIf(not is_installed('mmocr'), reason='requires mmocr') -class TestImageMaskOCR(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ImageMaskOCR', parser=LangChainParser(), device='cuda') - res = tool('tests/data/images/cups.png, ' - 'tests/data/images/cups_mask.png') - assert isinstance(res, str) diff --git a/tests/test_tools/test_segmentation/test_segment_anything.py b/tests/test_tools/test_segmentation/test_segment_anything.py deleted file mode 100644 index 2a2cccd1..00000000 --- a/tests/test_tools/test_segmentation/test_segment_anything.py +++ /dev/null @@ -1,41 +0,0 @@ -from unittest import skipIf - -from mmengine import is_installed - -from agentlego import load_tool -from agentlego.testing import ToolTestCase -from agentlego.tools.parsers import LangChainParser - - -@skipIf( - not is_installed('segment_anything'), reason='requires segment_anything') -class TestSegmentAnything(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'SegmentAnything', parser=LangChainParser(), device='cpu') - res = tool('tests/data/images/cups.png') - assert isinstance(res, str) - - -@skipIf( - not is_installed('segment_anything'), reason='requires segment_anything') -class TestSegmentClicked(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'SegmentClicked', parser=LangChainParser(), device='cpu') - res = tool('tests/data/images/cups.png, ' - 'tests/data/images/cups_mask.png') - assert isinstance(res, str) - - -@skipIf( - not is_installed('segment_anything'), reason='requires segment_anything') -class TestObjectSegmenting(ToolTestCase): - - def test_call(self): - tool = load_tool( - 'ObjectSegmenting', parser=LangChainParser(), device='cpu') - res = tool('tests/data/images/cups.png, water cup') - assert isinstance(res, str) diff --git a/tests/test_tools/test_segmentation/test_semantic_segmentation.py b/tests/test_tools/test_segmentation/test_semantic_segmentation.py deleted file mode 100644 index 8fbde08f..00000000 --- a/tests/test_tools/test_segmentation/test_semantic_segmentation.py +++ /dev/null @@ -1,31 +0,0 @@ -import os.path as osp -from unittest import skipIf - -import cv2 -import numpy as np -from mmengine import is_installed -from PIL import Image - -from agentlego import load_tool -from agentlego.testing import ToolTestCase - - -@skipIf(not is_installed('mmsegmentation'), reason='mmsegmentation') -class TestSemanticSegmentation(ToolTestCase): - - def test_call(self): - tool = load_tool('SemanticSegmentation', device='cuda') - img = np.ones([224, 224, 3]).astype(np.uint8) - img_path = osp.join(self.tempdir.name, 'temp.jpg') - cv2.imwrite(img_path, img) - res = tool(img_path) - assert isinstance(res, str) - - img = Image.fromarray(img) - tool = load_tool( - 'SemanticSegmentation', - output_style='pil image', - input_style='pil image', - device='cuda') - res = tool(img) - assert isinstance(res, Image.Image) diff --git a/tests/test_tools/test_speech_test/test_speech_to_text.py b/tests/test_tools/test_speech_test/test_speech_to_text.py deleted file mode 100644 index 370efed3..00000000 --- a/tests/test_tools/test_speech_test/test_speech_to_text.py +++ /dev/null @@ -1,38 +0,0 @@ -from pathlib import Path - -import pytest - -from agentlego.parsers import NaiveParser -from agentlego.testing import setup_tool - -data_dir = Path(__file__).parents[2] / 'data' -test_audio = (data_dir / 'audio/speech_to_text.flac').absolute() - - -@pytest.fixture() -def tool(): - from agentlego.tools import SpeechToText - return setup_tool(SpeechToText, device='cuda') - - -def test_call(tool): - tool.set_parser(NaiveParser) - res = tool(str(test_audio)) - assert isinstance(res, str) - - -def test_hf_agent(tool, hf_agent): - tool = tool.to_transformers_agent() - hf_agent.prepare_for_new_chat() - hf_agent._toolbox = {tool.name: tool} - - out = hf_agent.chat(f'Convert the audio `{test_audio}` to text.') - assert 'going along slushy country' in out - - -def test_lagent(tool, lagent_agent): - lagent_agent.new_session([tool.to_lagent()]) - - out = lagent_agent.chat(f'Convert the audio `{test_audio}` to text.') - assert out.actions[-1].valid == 1 - assert 'going along slushy country' in out.response diff --git a/tests/test_tools/test_speech_test/test_text_to_speech.py b/tests/test_tools/test_speech_test/test_text_to_speech.py deleted file mode 100644 index e3416778..00000000 --- a/tests/test_tools/test_speech_test/test_text_to_speech.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest - -from agentlego.parsers import NaiveParser -from agentlego.testing import setup_tool -from agentlego.types import AudioIO - - -@pytest.fixture() -def tool(): - from agentlego.tools import TextToSpeech - return setup_tool(TextToSpeech, device='cuda') - - -def test_call(tool): - tool.set_parser(NaiveParser) - res = tool('Hello world') - assert isinstance(res, AudioIO) - - -def test_hf_agent(tool, hf_agent): - tool = tool.to_transformers_agent() - hf_agent.prepare_for_new_chat() - hf_agent._toolbox = {tool.name: tool} - - out = hf_agent.chat('Please speak out the text `Hello world`.') - assert out is not None - - -def test_lagent(tool, lagent_agent): - lagent_agent.new_session([tool.to_lagent()]) - - out = lagent_agent.chat('Please speak out the text `Hello world`.') - assert out.actions[-1].valid == 1 - assert 'wav' in out.response diff --git a/tests/test_tools/test_translation/test_translation.py b/tests/test_tools/test_translation/test_translation.py deleted file mode 100644 index b50f42dc..00000000 --- a/tests/test_tools/test_translation/test_translation.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest - -from agentlego.parsers import NaiveParser -from agentlego.testing import setup_tool - -text = 'Legumes share resources with nitrogen-fixing bacteria' -source_lang = 'English' -target_lang = 'French' - - -@pytest.fixture() -def tool(): - from agentlego.tools import Translation - return setup_tool(Translation, device='cuda') - - -def test_call(tool): - tool.set_parser(NaiveParser) - res = tool(text, source_lang, target_lang) - assert isinstance(res, str) - - -def test_hf_agent(tool, hf_agent): - tool = tool.to_transformers_agent() - hf_agent.prepare_for_new_chat() - hf_agent._toolbox = {tool.name: tool} - - out = hf_agent.chat(f'Please translate the `{text}` from {source_lang} ' - f'to {target_lang}.`') - assert out.startswith('Les légumes') - - -def test_lagent(tool, lagent_agent): - lagent_agent.new_session([tool.to_lagent()]) - - out = lagent_agent.chat( - f'Translate the `{text}` from {source_lang} to {target_lang}') - assert out.actions[-1].valid == 1 - assert out.response.startswith('Les légumes') diff --git a/tests/test_tools/test_vqa/test_vqa.py b/tests/test_tools/test_vqa/test_vqa.py deleted file mode 100644 index fc8689b2..00000000 --- a/tests/test_tools/test_vqa/test_vqa.py +++ /dev/null @@ -1,38 +0,0 @@ -from pathlib import Path - -import pytest - -from agentlego.parsers import NaiveParser -from agentlego.testing import setup_tool -from agentlego.types import ImageIO - -data_dir = Path(__file__).parents[2] / 'data' -test_image = (data_dir / 'images/dog.jpg').absolute() - - -@pytest.fixture() -def tool(): - from agentlego.tools import VisualQuestionAnswering - return setup_tool(VisualQuestionAnswering, device='cuda') - - -def test_call(tool): - tool.set_parser(NaiveParser) - assert isinstance(tool(ImageIO(str(test_image)), 'prompt'), str) - - -def test_hf_agent(tool, hf_agent): - tool = tool.to_transformers_agent() - hf_agent.prepare_for_new_chat() - hf_agent._toolbox = {tool.name: tool} - - out = hf_agent.chat(f'Please describe the `{test_image}`') - assert isinstance(out, str) - - -def test_lagent(tool, lagent_agent): - lagent_agent.new_session([tool.to_lagent()]) - - out = lagent_agent.chat(f'Please describe the `{test_image}`') - assert out.actions[-1].valid == 1 - assert isinstance(out.response, str) diff --git a/webui/README.md b/webui/README.md new file mode 100644 index 00000000..a4f6d993 --- /dev/null +++ b/webui/README.md @@ -0,0 +1,116 @@ +# AgentLego WebUI + +An easy-to-use Gradio App to setup agent and tools. + +## Setup your LLM + +AgentLego and this web ui doesn't aim to host LLM, therefore, you need to use other framework to host a LLM as +the backend of agents. + +### OpenAI models + +A simple choice is to use OpenAI's models, and we have provided the agent configs to use OpenAI models in the +preset configs. You need to set the OpenAI API key in the environment variables. + +```bash +export OPENAI_API_KEY="your openai key" +``` + +### LMDeploy + +LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by the **InternLM** teams. + +To host a LLM with LMDeploy, use the below command to install LMDeploy (See the [official tutorial](https://lmdeploy.readthedocs.io/en/latest/get_started.html) for more details): + +```bash +pip install 'lmdeploy>=0.2.1' +``` + +And then, host an OpenAI-style API server by the below command, here we use InternLM2 as example. + +```bash +lmdeploy serve api_server internlm/internlm2-chat-20b +``` + +And after startup, you will get the below output: + +``` +INFO: Started server process [853738] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) +``` + +### vLLM + +vLLM is a fast and easy-to-use library for LLM inference and serving. + +To host a LLM with vLLM, use the below command to install vLLM (See the [official tutorial](https://docs.vllm.ai/en/latest/getting_started/installation.html) for more details): + +```bash +# (Optional) Create a new conda environment. +conda create -n myenv python=3.9 -y +conda activate myenv + +# Install vLLM with CUDA 12.1. +pip install vllm +``` + +And then, host an OpenAI-style API server by the below command, here we use QWen as example. + +```bash +# Get the ChatML style chat template +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/template_chatml.jinja + +# Start the vLLM server +python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen-14B-Chat --trust-remote-code --chat-template ./template_chatml.jinja +``` + +And after startup, you will get the below output: + +``` +INFO: Started server process [3837676] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` + +## Setup WebUI + +You can use the `start_linux.sh` to create a standalone environment from scratch, or use `one_click.py` to +setup the environment on your own environment. + +```bash +bash startup_linux.sh +# OR +python one_click.py +``` + +Then you will get the below output and then, open the URL in your browser. + +``` +Running on local URL: http://127.0.0.1:7860 +``` + +## Chat with Agent + +After open the web ui, you need to choose the agent in the `Agent` tab. If you are hosting your own LLM, you +can use `langchain.StructuredChat.ChatOpenAI` agent and pass the URL of your LLM server in the `API base url` +field. + +And to setup available tools, you need to add them in the `Tools` tab. + +If the response of agent raise an parse error (it's common for the LLM with low instruction-following ability) +or you want to re-roll the response, you can click `Regenerate` to regenerate the last response. + +During chat, you can select tools from the all available tools in the `Select tools` checkboxes. + +If you want to save the current chat, click the `Save` button. And you can also resume the past chats from the +`Past chats` dropdown. + +You can upload files (images or audios) during chat. If you have provided related tools, agent may use these +tools to handle your file. + +## Acknowledge + +The WebUI app is modified from [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui). Thanks for the great work. diff --git a/webui/README_zh-CN.md b/webui/README_zh-CN.md new file mode 100644 index 00000000..d82c7275 --- /dev/null +++ b/webui/README_zh-CN.md @@ -0,0 +1,110 @@ +# AgentLego WebUI + +一个可以方便和 Agent 系统对话的 Gradio WebUI + +## 设置您的语言模型 + +AgentLego 和这个 WebUI 并不负责部署大语言模型。因此,您需要使用其他框架来托管大语言模型作为 Agent 系统的后端。 + +### OpenAI模型 + +一个简单的选择是使用 OpenAI 的 API,我们已经在预设配置中提供了使用 OpenAI 模型的 Agent 配置。 + +您需要在环境变量中设置 OpenAI API key。 + +```bash +export OPENAI_API_KEY="你的 OpenAI API key" +``` + +### LMDeploy + +LMDeploy 是一个用于压缩、部署和提供大型语言模型(LLM)服务的工具包,由 **InternLM** 团队开发。 + +要使用 LMDeploy 托管一个 LLM,请使用以下命令安装 LMDeploy(更多详情请参见[官方教程](https://lmdeploy.readthedocs.io/en/latest/get_started.html)): + +```bash +pip install 'lmdeploy>=0.2.1' +``` + +然后,通过以下命令托管一个类 OpenAI 风格的 API 服务器,这里我们以 InternLM2 为例。 + +```bash +lmdeploy serve api_server internlm/internlm2-chat-20b +``` + +启动后,你将得到以下输出: + +``` +INFO: Started server process [853738] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit) +``` + +### vLLM + +vLLM 是一个快速且易于使用的大型语言模型(LLM)推理和服务库。 + +要使用 vLLM 托管一个 LLM,请使用以下命令安装 vLLM(更多详情请参见[官方教程](https://docs.vllm.ai/en/latest/getting_started/installation.html)): + +```bash +# (可选) 创建一个新的 conda 环境。 +conda create -n myenv python=3.9 -y +conda activate myenv + +# 安装带 CUDA 12.1 的 vLLM。 +pip install vllm +``` + +然后,通过以下命令托管一个类 OpenAI 风格的 API 服务器,这里我们以 QWen 为例。 + +```bash +# 获取 ChatML 风格的聊天模板 +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/template_chatml.jinja + +# 启动 vLLM 服务器 +python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen-14B-Chat --trust-remote-code --chat-template ./template_chatml.jinja +``` + +启动后,你将得到以下输出: + +``` +INFO: Started server process [3837676] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` + +## 设置 WebUI + +你可以使用 `start_linux.sh` 从头开始创建一个独立的环境,或者使用 `one_click.py` 在你自己的环境上设置环境。 + +```bash +bash startup_linux.sh +# 或者 +python one_click.py +``` + +然后你将得到以下输出,之后,在浏览器中打开该 URL。 + +``` +Running on local URL: http://127.0.0.1:7860 +``` + +## 与 Agent 聊天 + +打开 WebUI 后,你需要在 `Agent` 标签中选择 Agent。如果你部署了自己的 LLM,可以使用 `langchain.StructuredChat.ChatOpenAI`,并在 `API base url` 字段中传递你的 LLM 服务器的 URL。 + +要设置可用的工具,你需要在 `Tools` 标签中添加它们并设置 `enable`。 + +如果 Agent 的响应引发了解析错误(对于执行指令能力较低的 LLM 来说很常见)或你想重新生成响应,你可以点击 `Regenerate` 以重新生成最后一个响应。 + +在聊天过程中,你可以从所有可用的工具中选择工具,勾选 `Select tools` 复选框。 + +如果你想保存当前聊天,点击 `Save` 按钮。你也可以从 `Past chats` 下拉菜单中恢复过去的聊天。 + +在聊天过程中,你可以上传文件(图片或音频)。如果你已提供相关工具,Agent 可能会使用这些工具来处理你的文件。 + +## 致谢 + +WebUI 应用是基于 [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) 修改的。感谢他们的出色工作。 diff --git a/webui/agent_config.yml.example b/webui/agent_config.yml.example new file mode 100644 index 00000000..6f8199c5 --- /dev/null +++ b/webui/agent_config.yml.example @@ -0,0 +1,14 @@ +gpt-3.5-turbo: + agent_class: langchain.StructuredChat.ChatOpenAI + model_name: gpt-3.5-turbo + openai_api_base: null + openai_api_key: null + max_tokens: 512 + extra_stop: null +gpt-4-turbo: + agent_class: langchain.StructuredChat.ChatOpenAI + model_name: gpt-4-1106-preview + openai_api_base: null + openai_api_key: null + max_tokens: 512 + extra_stop: null diff --git a/webui/app.py b/webui/app.py new file mode 100644 index 00000000..c290cfe6 --- /dev/null +++ b/webui/app.py @@ -0,0 +1,98 @@ +import time +from threading import Lock + +import gradio as gr +from modules import chat, shared, ui, ui_agent, ui_chat, ui_tools, utils +from modules.agents import load_llm +from modules.logging import logger +from modules.tools import load_tool +from modules.utils import gradio + + +def create_interface(): + + title = 'AgentLego Web UI' + + # Password authentication + auth = [] + if shared.args.gradio_auth: + auth.extend(x.strip() for x in shared.args.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()) + if shared.args.gradio_auth_path: + with open(shared.args.gradio_auth_path, 'r', encoding='utf8') as file: + auth.extend(x.strip() for line in file for x in line.split(',') if x.strip()) + auth = [tuple(cred.split(':')) for cred in auth] + + # Interface state elements + shared.input_elements = ui.list_interface_input_elements() + + with gr.Blocks( + css=ui.css, + analytics_enabled=False, + title=title, + theme=ui.theme, + ) as shared.gradio['interface']: + + # Interface state + shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) + + ui_chat.create_ui() # Chat Tab + ui_agent.create_ui() # Agent Tab + ui_tools.create_ui() # Tools Tab + + ui_chat.create_event_handlers() + ui_agent.create_event_handlers() + ui_tools.create_event_handlers() + + shared.gradio['interface'].load(lambda: None, None, None, js=f'() => {{{ui.js}}}') + shared.gradio['interface'].load(lambda: shared.agent_name, None, gradio('agent_menu'), show_progress=False) + shared.gradio['interface'].load(chat.redraw_html, gradio('history'), gradio('display')) + + # Launch the interface + shared.gradio['interface'].queue() + shared.gradio['interface'].launch( + prevent_thread_lock=True, + share=shared.args.share, + server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'), + server_port=shared.args.listen_port, + inbrowser=shared.args.auto_launch, + auth=auth or None, + ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True, + ssl_keyfile=shared.args.ssl_keyfile, + ssl_certfile=shared.args.ssl_certfile, + max_threads=64, + allowed_paths=['.'], + ) + + +if __name__ == '__main__': + + # Initialize agent + available_agents = utils.get_available_agents() + if shared.args.agent is not None: + assert shared.args.agent in available_agents + shared.agent_name = shared.args.agent + if shared.agent_name is not None: + # Load the agent + shared.llm = load_llm(shared.agent_name) + + # Initialize tools + for name in shared.tool_settings: + try: + logger.info(f'Loading tool `{name}`') + load_tool(name) + except Exception: + logger.exception('Traceback') + logger.warning(f'Failed to load tool `{name}`, auto disabled.') + + shared.generation_lock = Lock() + + # Launch the web UI + create_interface() + while True: + time.sleep(0.5) + if shared.need_restart: + shared.need_restart = False + time.sleep(0.5) + shared.gradio['interface'].close() + time.sleep(0.5) + create_interface() diff --git a/webui/css/chat.css b/webui/css/chat.css new file mode 100644 index 00000000..2cf20cde --- /dev/null +++ b/webui/css/chat.css @@ -0,0 +1,140 @@ +.message { + display: grid; + grid-template-columns: 60px minmax(0, 1fr); + padding-bottom: 25px; + font-size: 15px; + font-family: 'Noto Sans', Helvetica, Arial, sans-serif; + line-height: 22.5px !important; +} + +.message-body { + margin-top: 3px; +} + +.circle-you { + width: 40px; + height: 40px; + background-color: rgb(255, 108, 108); + border-radius: 20%; +} + +.circle-bot { + width: 40px; + height: 40px; + background-color: rgb(255, 189, 69); + border-radius: 20%; +} + +.circle-bot svg, +.circle-you svg { + border-radius: 50%; + width: 100%; + height: 100%; + object-fit: cover; +} + +.username { + font-weight: bold; +} + +.message-body p { + font-size: 15px !important; + line-height: 22.5px !important; +} + +.message-body p, .chat .message-body ul, .chat .message-body ol { + margin-bottom: 10px !important; +} + +.message-body p:last-child, .chat .message-body ul:last-child, .chat .message-body ol:last-child { + margin-bottom: 0 !important; +} + +.dark .message-body p em { + color: rgb(138 138 138) !important; +} + +.message-body p em { + color: rgb(110 110 110) !important; + font-weight: 500; +} + +/* -------------------- 工具展示相关样式 -------------------- */ +/*
的基础样式 */ +.message-body details.tool { + background-color: #182029; + margin: 15px 0; + padding: 10px; + border-radius: 8px; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.3); + transition: 0.3s; +} + +/* 鼠标悬停在
元素上时的样式 */ +.message-body details.tool:hover { + background-color: #1e2c3a; +} + +/* 元素的样式,这里使用了伪元素来自定义展开按钮的样式 */ +.message-body summary { + font-weight: bold; + cursor: pointer; + outline: none; +} + +/* 工具参数和回复的样式 */ +.message-body .tool-args, .tool-response, .tool-thought { + background-color: #111A24; + margin: 8px 0; + padding: 10px; + border-radius: 4px; + font-family: 'Courier New', monospace; +} + +/* 斜体强调的文本 */ +.message-body .tool em { + font-style: italic; +} + +/* 最终答案模块的样式 */ +.message-body .final-answer .thought { + margin: 20px 0; +} + +/* 图片的样式,确保它们在容器内完全显示 */ +.message-body img { + max-width: 500px; + max-height: 300px; + height: auto; + display: block; + margin: 10px 0; + border-radius: 4px; + border: 2px solid #162B3D; +} + +.message-body .tool img { + max-width: 300px; + max-height: 300px; +} + +.message-body .error-box { + background-color: #6a313194; + color: #fff; + border: 1px solid #e74c3c; + padding: 12px; + margin: 10px 0; + border-radius: 4px; + font-family: 'Courier New', Courier, monospace; +} + +.message-body .error-title { + font-size: 16px; + font-weight: bold; + margin-bottom: 8px; +} + +.message-body .error-reason { + font-size: 14px; + opacity: 0.7; + line-height: 1.6; +} diff --git a/webui/css/main.css b/webui/css/main.css new file mode 100644 index 00000000..a350ee71 --- /dev/null +++ b/webui/css/main.css @@ -0,0 +1,655 @@ +.tabs.svelte-710i53 { + margin-top: 0 +} + +.py-6 { + padding-top: 2.5rem +} + +.small-button { + min-width: 0 !important; + max-width: 171px; + height: 39.594px; + align-self: end; +} + +.refresh-button { + max-width: 4.4em; + min-width: 2.2em !important; + height: 39.594px; + align-self: end; + line-height: 1em; + border-radius: 0.5em; + flex: none; +} + +.refresh-button-small { + max-width: 2.2em; +} + +.button_nowrap { + white-space: nowrap; +} + +#slim-column { + flex: none !important; + min-width: 0 !important; +} + +.slim-dropdown { + background-color: transparent !important; + border: none !important; + padding: 0 !important; +} + +#download-label, #upload-label { + min-height: 0 +} + +.dark svg { + fill: white; +} + +.dark a { + color: white !important; +} + +ol li p, ul li p { + display: inline-block; +} + +#chat-tab, #default-tab, #notebook-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab { + border: 0; +} + +.gradio-container-3-18-0 .prose * h1, h2, h3, h4 { + color: white; +} + +.gradio-container { + max-width: 100% !important; + padding-top: 0 !important; +} + +#extensions { + margin-top: 5px; + margin-bottom: 35px; +} + +.extension-tab { + border: 0 !important; +} + +span.math.inline { + font-size: 27px; + vertical-align: baseline !important; +} + +div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { + flex-wrap: nowrap; +} + +.header_bar { + background-color: #f7f7f7; + margin-bottom: 19px; + overflow-x: scroll; + margin-left: calc(-1 * var(--size-4)); + margin-right: calc(-1 * var(--size-4)); + display: block !important; + text-wrap: nowrap; +} + +.dark .header_bar { + border: none !important; + background-color: #8080802b !important; +} + +.header_bar button.selected { + border-radius: 0; +} + +.textbox_default textarea { + height: calc(100dvh - 271px); +} + +.textbox_default_output textarea { + height: calc(100dvh - 185px); +} + +.textbox textarea { + height: calc(100dvh - 241px); +} + +.textbox_logits textarea { + height: calc(100dvh - 236px); +} + +.textbox_logits_notebook textarea { + height: calc(100dvh - 292px); +} + +.monospace textarea { + font-family: monospace; +} + +.textbox_default textarea, +.textbox_default_output textarea, +.textbox_logits textarea, +.textbox_logits_notebook textarea, +.textbox textarea { + font-size: 16px !important; + color: #46464A !important; +} + +.dark textarea { + color: #efefef !important; +} + +@media screen and (width <= 711px) { + .textbox_default textarea { + height: calc(100dvh - 259px); + } + + div .default-token-counter { + top: calc( 0.5 * (100dvh - 236px) ) !important; + } + + .transparent-substring { + display: none; + } + + .hover-menu { + min-width: 250px !important; + } +} + +/* Hide the gradio footer */ +footer { + display: none !important; +} + +button { + font-size: 14px !important; +} + +.file-saver { + position: fixed !important; + height: 100%; + z-index: 1000; + background-color: rgb(0 0 0 / 50%) !important; + margin-left: -20px; + margin-right: -20px; +} + +.file-saver > :first-child { + position: fixed !important; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); /* center horizontally */ + width: 100%; + max-width: 500px; + background-color: var(--input-background-fill); + border: var(--input-border-width) solid var(--input-border-color) !important; +} + +.file-saver > :first-child > :nth-child(2) { + background: var(--block-background-fill); +} + +.checkboxgroup-table label { + background: none !important; + padding: 0 !important; + border: 0 !important; +} + +.checkboxgroup-table div { + display: grid !important; +} + +.markdown ul ol { + font-size: 100% !important; +} + +.pretty_scrollbar::-webkit-scrollbar { + width: 5px; +} + +.pretty_scrollbar::-webkit-scrollbar-track { + background: transparent; +} + +.pretty_scrollbar::-webkit-scrollbar-thumb, +.pretty_scrollbar::-webkit-scrollbar-thumb:hover { + background: #c5c5d2; +} + +.dark .pretty_scrollbar::-webkit-scrollbar-thumb, +.dark .pretty_scrollbar::-webkit-scrollbar-thumb:hover { + background: #374151; +} + +.pretty_scrollbar::-webkit-resizer { + background: #c5c5d2; +} + +.dark .pretty_scrollbar::-webkit-resizer { + background: #374151; +} + +audio { + max-width: 100%; +} + +/* Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui */ +.token-counter { + position: absolute !important; + top: calc( 0.5 * (100dvh - 218px) ) !important; + right: 2px; + z-index: 100; + background: var(--input-background-fill) !important; + min-height: 0 !important; +} + +.default-token-counter { + top: calc( 0.5 * (100dvh - 248px) ) !important; +} + +.token-counter span { + padding: 1px; + box-shadow: 0 0 0 0.3em rgb(192 192 192 / 15%), inset 0 0 0.6em rgb(192 192 192 / 7.5%); + border: 2px solid rgb(192 192 192 / 40%) !important; + border-radius: 0.4em; +} + +.no-background { + background: var(--background-fill-primary) !important; + padding: 0 !important; +} + +/* ---------------------------------------------- + Chat tab +---------------------------------------------- */ +.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx { + height: 66.67vh +} + +.gradio-container { + margin-left: auto !important; + margin-right: auto !important; +} + +.w-screen { + width: unset +} + +div.svelte-362y77>*, div.svelte-362y77>.form>* { + flex-wrap: nowrap +} + +.pending.svelte-1ed2p3z { + opacity: 1; +} + +.wrap.svelte-6roggh.svelte-6roggh { + max-height: 92.5%; +} + +/* This is for the microphone button in the whisper extension */ +.sm.svelte-1ipelgc { + width: 100%; +} + +#chat-tab button#Generate, #chat-tab button#stop { + width: 89.3438px !important; +} + +#chat-tab button, #notebook-tab button, #default-tab button { + min-width: 0 !important; +} + +#chat-tab > :first-child, #extensions { + max-width: 880px; + margin-left: auto; + margin-right: auto; +} + +@media screen and (width <= 688px) { + #chat-tab { + padding-left: 0; + padding-right: 0; + } + + .chat-parent { + height: calc(100dvh - 179px) !important; + } + + .old-ui .chat-parent { + height: calc(100dvh - 310px) !important; + } +} + +.chat { + margin-left: auto; + margin-right: auto; + max-width: 880px; + height: 100%; + overflow-y: auto; + padding-right: 15px; + display: flex; + flex-direction: column; + word-break: break-word; + overflow-wrap: anywhere; +} + +.chat-parent { + height: calc(100dvh - 181px); + overflow: auto !important; +} + +.old-ui .chat-parent { + height: calc(100dvh - 270px); +} + +.chat-parent.bigchat { + height: calc(100dvh - 181px) !important; +} + +.chat > .messages { + display: flex; + flex-direction: column; +} + +.chat .message:last-child { + margin-bottom: 0 !important; + padding-bottom: 0 !important; +} + +.message-body li:not(:last-child) { + margin-top: 0 !important; + margin-bottom: 2px !important; +} + +.message-body li:last-child { + margin-bottom: 0 !important; +} + +.message-body li > p { + display: inline !important; +} + +.message-body ul, .message-body ol { + font-size: 15px !important; +} + +.message-body ul { + list-style-type: disc !important; +} + +.message-body pre:not(:last-child) { + margin-bottom: 35.625px !important; +} + +.message-body pre:last-child { + margin-bottom: 0 !important; +} + +.message-body code { + white-space: pre-wrap !important; + word-wrap: break-word !important; + border: 1px solid var(--border-color-primary); + border-radius: var(--radius-sm); + background: var(--background-fill-secondary); + font-size: 90%; + padding: 1px 3px; +} + +.message-body pre > code { + display: block; + padding: 15px; +} + +.message-body :not(pre) > code { + white-space: normal !important; +} + +#chat-input { + padding: 0; + padding-top: 18px; + background: transparent; + border: none; +} + +#chat-input textarea:focus { + box-shadow: none !important; +} + +#chat-input > :first-child { + background-color: transparent; +} + +#chat-input .progress-text { + display: none; +} + +@media print { + body { + visibility: hidden; + } + + .chat { + visibility: visible; + position: absolute; + left: 0; + top: 0; + max-width: unset; + max-height: unset; + width: 100%; + overflow-y: visible; + } + + .message { + break-inside: avoid; + } + + .gradio-container { + overflow: visible; + } + + .tab-nav { + display: none !important; + } + + #chat-tab > :first-child { + max-width: unset; + } +} + +#show-controls { + position: absolute; + height: 100%; + background-color: var(--background-fill-primary); + border: 0 !important; + border-radius: 0; +} + +#show-controls label { + z-index: 1000; + position: absolute; + left: calc(100% - 168px); +} + +#typing-container { + display: none; + position: absolute; + background-color: transparent; + left: -2px; + padding: var(--block-padding); +} + +.typing { + position: relative; +} + +.visible-dots #typing-container { + display: block; +} + +.typing span { + content: ''; + animation: blink 1.5s infinite; + animation-fill-mode: both; + height: 10px; + width: 10px; + background: #3b5998;; + position: absolute; + left:0; + top:0; + border-radius: 50%; +} + +.typing .dot1 { + animation-delay: .2s; + margin-left: calc(10px * 1.5); +} + +.typing .dot2 { + animation-delay: .4s; + margin-left: calc(10px * 3); +} + +@keyframes blink { + 0% { + opacity: .1; + } + + 20% { + opacity: 1; + } + + 100% { + opacity: .1; + } +} + +#chat-tab .generating { + display: none !important; +} + +.hover-element { + position: relative; + font-size: 24px; +} + +.hover-menu { + display: none; + position: absolute; + bottom: 80%; + left: 0; + background-color: var(--background-fill-secondary); + box-shadow: 0 0 10px rgb(0 0 0 / 50%); + z-index: 10000; + min-width: 330px; + flex-direction: column; +} + +.hover-menu button { + width: 100%; + background: transparent !important; + border-radius: 0 !important; + justify-content: space-between; + margin: 0 !important; + height: 36px; +} + +.hover-menu button:not(#clear-history-confirm) { + border-bottom: 0 !important; +} + +.hover-menu button:not(#clear-history-confirm):last-child { + border-bottom: var(--button-border-width) solid var(--button-secondary-border-color) !important; +} + +.hover-menu button:hover { + background: var(--button-secondary-background-fill-hover) !important; +} + +.transparent-substring { + opacity: 0.333; +} + +#chat-tab:not(.old-ui) #chat-buttons { + display: none !important; +} + +#gr-hover-container { + min-width: 0 !important; + display: flex; + flex-direction: column-reverse; + padding-right: 20px; + padding-bottom: 3px; + flex-grow: 0 !important; +} + +#upload-container, +#generate-stop-container { + min-width: 0 !important; + display: flex; + flex-direction: column-reverse; + padding-bottom: 3px; + flex: 0 auto !important; +} + +#chat-input-container { + min-width: 0 !important; +} + +#chat-input-container > .form { + background: transparent; + border: none; +} + +#chat-input-row { + padding-bottom: 20px; +} + +.old-ui #chat-input-row, #chat-input-row.bigchat { + padding-bottom: 0 !important; +} + +#chat-col { + padding-bottom: 5px; +} + +.chat ol, .chat ul { + margin-top: 6px !important; +} + +/* ---------------------------------------------- + Past chats menus +---------------------------------------------- */ +#past-chats-row { + margin-bottom: calc( -1 * var(--layout-gap) ); +} + +#rename-row label { + margin-top: var(--layout-gap); +} + +/* ---------------------------------------------- + Keep dropdown menus above errored components +---------------------------------------------- */ +.options { + z-index: 100 !important; +} + +/* ---------------------------------------------- + Current agent info +---------------------------------------------- */ +.current-agent, +.current-agent-warn { + text-align: right; +} + +.current-agent-warn { + color: red; +} diff --git a/webui/js/main.js b/webui/js/main.js new file mode 100644 index 00000000..4bfec28c --- /dev/null +++ b/webui/js/main.js @@ -0,0 +1,180 @@ +let main_parent = document.getElementById("chat-tab").parentNode; + +main_parent.childNodes[0].classList.add("header_bar"); +main_parent.style = "padding: 0; margin: 0"; +main_parent.parentNode.style = "gap: 0"; +main_parent.parentNode.parentNode.style = "padding: 0"; + +//------------------------------------------------ +// Position the chat typing dots +//------------------------------------------------ +typing = document.getElementById("typing-container"); +typingParent = typing.parentNode; +typingSibling = typing.previousElementSibling; +typingSibling.insertBefore(typing, typingSibling.childNodes[2]); + +//------------------------------------------------ +// Chat scrolling +//------------------------------------------------ +const targetElement = document.getElementById("chat").parentNode.parentNode.parentNode; +targetElement.classList.add("pretty_scrollbar"); +targetElement.classList.add("chat-parent"); +let isScrolled = false; + +targetElement.addEventListener("scroll", function() { + let diff = targetElement.scrollHeight - targetElement.clientHeight; + if(Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0) { + isScrolled = false; + } else { + isScrolled = true; + } +}); + +// Create a MutationObserver instance +const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if(!isScrolled) { + targetElement.scrollTop = targetElement.scrollHeight; + } + + const firstChild = targetElement.children[0]; + if (firstChild.classList.contains("generating")) { + typing.parentNode.classList.add("visible-dots"); + document.getElementById("stop").style.display = "flex"; + document.getElementById("Generate").style.display = "none"; + } else { + typing.parentNode.classList.remove("visible-dots"); + document.getElementById("stop").style.display = "none"; + document.getElementById("Generate").style.display = "flex"; + } + + }); +}); + +// Configure the observer to watch for changes in the subtree and attributes +const config = { + childList: true, + subtree: true, + characterData: true, + attributeOldValue: true, + characterDataOldValue: true +}; + +// Start observing the target element +observer.observe(targetElement, config); + +//------------------------------------------------ +// Add some scrollbars +//------------------------------------------------ +const textareaElements = document.querySelectorAll(".add_scrollbar textarea"); +for(i = 0; i < textareaElements.length; i++) { + textareaElements[i].classList.remove("scroll-hide"); + textareaElements[i].classList.add("pretty_scrollbar"); + textareaElements[i].style.resize = "none"; +} + +//------------------------------------------------ +// Remove some backgrounds +//------------------------------------------------ +const noBackgroundelements = document.querySelectorAll(".no-background"); +for(i = 0; i < noBackgroundelements.length; i++) { + noBackgroundelements[i].parentNode.style.border = "none"; + noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = "center"; +} + +const slimDropdownElements = document.querySelectorAll('.slim-dropdown'); +for (i = 0; i < slimDropdownElements.length; i++) { + const parentNode = slimDropdownElements[i].parentNode; + parentNode.style.background = 'transparent'; + parentNode.style.border = '0'; +} + +//------------------------------------------------ +// Create the hover menu in the chat tab +// The show/hide events were adapted from: +// https://github.com/SillyTavern/SillyTavern/blob/6c8bd06308c69d51e2eb174541792a870a83d2d6/public/script.js +//------------------------------------------------ +var buttonsInChat = document.querySelectorAll("#chat-tab:not(.old-ui) #chat-buttons button"); +var button = document.getElementById("hover-element-button"); +var menu = document.getElementById("hover-menu"); + +function showMenu() { + menu.style.display = "flex"; // Show the menu +} + +function hideMenu() { + menu.style.display = "none"; // Hide the menu + document.querySelector("#chat-input textarea").focus(); +} + +if (buttonsInChat.length > 0) { + for (let i = buttonsInChat.length - 1; i >= 0; i--) { + const thisButton = buttonsInChat[i]; + menu.appendChild(thisButton); + + thisButton.addEventListener("click", () => { + hideMenu(); + }); + + const buttonText = thisButton.textContent; + const matches = buttonText.match(/(\(.*?\))/); + + if (matches && matches.length > 1) { + // Apply the transparent-substring class to the matched substring + const substring = matches[1]; + const newText = buttonText.replace(substring, ` ${substring.slice(1, -1)}`); + thisButton.innerHTML = newText; + } + } +} else { + buttonsInChat = document.querySelectorAll("#chat-tab.old-ui #chat-buttons button"); + for (let i = 0; i < buttonsInChat.length; i++) { + buttonsInChat[i].textContent = buttonsInChat[i].textContent.replace(/ \(.*?\)/, ""); + } + document.getElementById("gr-hover-container").style.display = "none"; +} + +function isMouseOverButtonOrMenu() { + return menu.matches(":hover") || button.matches(":hover"); +} + +button.addEventListener("mouseenter", function () { + showMenu(); +}); + +button.addEventListener("click", function () { + showMenu(); +}); + +// Add event listener for mouseleave on the button +button.addEventListener("mouseleave", function () { + // Delay to prevent menu hiding when the mouse leaves the button into the menu + setTimeout(function () { + if (!isMouseOverButtonOrMenu()) { + hideMenu(); + } + }, 100); +}); + +// Add event listener for mouseleave on the menu +menu.addEventListener("mouseleave", function () { + // Delay to prevent menu hide when the mouse leaves the menu into the button + setTimeout(function () { + if (!isMouseOverButtonOrMenu()) { + hideMenu(); + } + }, 100); +}); + +// Add event listener for click anywhere in the document +document.addEventListener("click", function (event) { + // Check if the click is outside the button/menu and the menu is visible + if (!isMouseOverButtonOrMenu() && menu.style.display === "flex") { + hideMenu(); + } +}); + +//------------------------------------------------ +// Focus on the chat input +//------------------------------------------------ +document.querySelector("#chat-input textarea").focus(); diff --git a/webui/modules/__init__.py b/webui/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/webui/modules/agents/__init__.py b/webui/modules/agents/__init__.py new file mode 100644 index 00000000..5a4e3482 --- /dev/null +++ b/webui/modules/agents/__init__.py @@ -0,0 +1,83 @@ +import gc +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Mapping + +import yaml + +from agentlego.utils import is_package_available +from ..logging import logger +from . import lagent_agent as lagent +from . import langchain_agent as langchain + + +@dataclass +class AgentCallbacks: + load_llm: Callable + cfg_widget: Callable + generate: Callable + + +agent_func_map: Mapping[str, AgentCallbacks] = { + 'langchain.StructuredChat.ChatOpenAI': + AgentCallbacks( + load_llm=langchain.llm_chat_openai, + cfg_widget=langchain.cfg_chat_openai, + generate=langchain.generate_structured, + ), + 'lagent.InternLM2Agent': + AgentCallbacks( + load_llm=lagent.llm_internlm2_lmdeploy, + cfg_widget=lagent.cfg_internlm2, + generate=lagent.generate_internlm2, + ) +} + + +def load_llm(agent_name): + from ..settings import get_agent_settings + logger.info(f'Loading {agent_name}...') + agent_settings = get_agent_settings(agent_name) + loader = agent_settings.pop('agent_class') + output = agent_func_map[loader].load_llm(agent_settings) + return output + + +def clear_cache(): + gc.collect() + if is_package_available('torch'): + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def unload_agent(): + from .. import shared + shared.agent = None + clear_cache() + +def delete_agent(name): + from .. import shared + + name = name.strip() + + if name == '': + return + if name == shared.agent_name: + unload_agent() + shared.agent_name = None + + p = Path(shared.args.agent_config) + if p.exists(): + settings = yaml.safe_load(open(p, 'r').read()) + else: + settings = {} + + settings.pop(name, None) + shared.agent_settings = settings + + output = yaml.dump(settings, sort_keys=False, allow_unicode=True) + with open(p, 'w') as f: + f.write(output) + + return f'`{name}` is deleted from `{p}`.' diff --git a/webui/modules/agents/lagent_agent.py b/webui/modules/agents/lagent_agent.py new file mode 100644 index 00000000..fab11755 --- /dev/null +++ b/webui/modules/agents/lagent_agent.py @@ -0,0 +1,102 @@ +import copy +from typing import Iterator, List + +from lagent.actions import ActionExecutor +from lagent.agents import internlm2_agent +from lagent.llms.lmdepoly_wrapper import LMDeployClient +from lagent.llms.meta_template import INTERNLM2_META +from lagent.schema import AgentStatusCode + +from .. import message_schema as msg +from ..logging import logger +from ..utils import parse_inputs + + +def llm_internlm2_lmdeploy(cfg): + url = cfg['url'].strip() + llm = LMDeployClient( + path='internlm2-chat-20b', + url=url, + meta_template=INTERNLM2_META, + top_p=0.8, + top_k=100, + temperature=0, + repetition_penalty=1.0, + stop_words=['<|im_end|>']) + return llm + + +def cfg_internlm2(): + import gradio as gr + widgets = {} + widgets['url'] = gr.Textbox(label='URL', info='The internlm2 server url of LMDeploy, like `http://localhost:23333`') + widgets['meta_prompt'] = gr.Textbox(label='system prompt', value=internlm2_agent.META_CN) + widgets['plugin_prompt'] = gr.Textbox(label='plugin prompt', value=internlm2_agent.PLUGIN_CN) + return widgets + + +def lagent_style_history(history) -> List[dict]: + inner_steps = [] + for row in history['internal']: + inner_steps.append(dict(role='user', content=row[0])) + for step in row[1]: + if isinstance(step, msg.ToolInput): + args = {k: v['content'] for k, v in step.args.items()} + inner_steps.append(dict(role='tool', name='plugin', content=dict(name=step.name, parameters=args))) + elif isinstance(step, msg.ToolOutput) and step.outputs: + outputs = '\n'.join(item['content'] if item['type'] == + 'text' else f'[{item["type"]}]({item["content"]})' + for item in step.outputs) + inner_steps.append(dict(role='environment', content=outputs, name='plugin')) + elif isinstance(step, msg.Answer): + inner_steps.append(dict(role='language', content=step.text)) + return inner_steps + + +def create_internlm2_agent(llm, tools, cfg) -> internlm2_agent.Internlm2Agent: + + tools = [tool.to_lagent() for tool in tools] + + agent = internlm2_agent.Internlm2Agent( + llm=llm, + plugin_executor=ActionExecutor(actions=tools), + protocol=internlm2_agent.Interlm2Protocol( + plugin_prompt=cfg['plugin_prompt'].strip(), + tool=dict( + begin='{start_token}{name}\n', + start_token='<|action_start|>', + name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'), + belong='assistant', + end='<|action_end|>\n', + ), + ), + max_turn=6, + ) + return agent + + +def generate_internlm2(question, state, history) -> Iterator[List[msg.Message]]: + from .. import shared + + cfg = copy.deepcopy(shared.agents_settings[shared.agent_name]) + tools = [tool for k, tool in shared.toolkits.items() if k in state['selected_tools']] + agent = create_internlm2_agent(shared.llm, tools, cfg) + messages: List[msg.Message] = [] + history = lagent_style_history(history) + [dict(role='user', content=question)] + if shared.args.verbose: + for dialog in agent._protocol.format(inner_step=history, plugin_executor=agent._action_executor): + logger.info(f'[{dialog["role"].upper()}]: {dialog["content"]}') + for agent_return in agent.stream_chat(history): + if agent_return.state == AgentStatusCode.PLUGIN_RETURN: + action = agent_return.actions[-1] + tool = shared.toolkits[action.type] + args = parse_inputs(tool.toolmeta, action.args) + messages.append( + msg.ToolInput(name=action.type, args=args, thought=action.thought)) + messages.append(msg.ToolOutput(outputs=tuple(action.result))) + yield messages + elif agent_return.state == AgentStatusCode.END and isinstance(agent_return.response, str): + messages.append(msg.Answer(text=agent_return.response)) + yield messages + elif agent_return.state == AgentStatusCode.STREAM_ING: + yield messages + [msg.Answer(text=agent_return.response)] diff --git a/webui/modules/agents/langchain_agent.py b/webui/modules/agents/langchain_agent.py new file mode 100644 index 00000000..5d1c6331 --- /dev/null +++ b/webui/modules/agents/langchain_agent.py @@ -0,0 +1,207 @@ +import json +import time +from copy import deepcopy +from queue import Queue +from threading import Thread +from typing import Iterator, List, Optional, Union + +import langchain_core.messages as lc_msg +from langchain.agents import AgentExecutor, create_structured_chat_agent +from langchain.callbacks.base import BaseCallbackHandler +from langchain.memory import ChatMessageHistory +from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate, + MessagesPlaceholder, PromptTemplate, + SystemMessagePromptTemplate) +from langchain_core.agents import AgentAction, AgentFinish +from langchain_openai import ChatOpenAI as _ChatOpenAI +from pydantic import BaseModel + +from agentlego.tools import BaseTool +from .. import message_schema as msg +from ..logging import logger +from ..utils import parse_inputs, parse_outputs + +# modified form hub.pull("hwchase17/structured-chat-agent") +STRUCTURED_CHAT_PROMPT = ChatPromptTemplate( + input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools'], + input_types={ + 'chat_history': + List[Union[lc_msg.ai.AIMessage, lc_msg.human.HumanMessage, + lc_msg.chat.ChatMessage, lc_msg.system.SystemMessage, + lc_msg.function.FunctionMessage, lc_msg.tool.ToolMessage]] + }, + messages=[ + SystemMessagePromptTemplate( + prompt=PromptTemplate( + input_variables=['tool_names', 'tools'], + template= + 'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n{tools}\n\nUse a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\nValid "action" values: "Final Answer" or {tool_names}\n\nProvide only ONE action per $JSON_BLOB, as shown:\n\n```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\nFollow this format:\n\nQuestion: input question to answer\nThought: consider previous and subsequent steps\nAction:\n```\n$JSON_BLOB\n```\nObservation: action result\n... (repeat Thought/Action/Observation N times)\nThought: I know what to respond\nAction:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}```\n\nBegin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Please use the markdown style file path link to display images and audios in the final answer. The thought and final answer should use the same language with the question. Format is Action:```$JSON_BLOB```then Observation' + )), + MessagesPlaceholder(variable_name='chat_history', optional=True), + HumanMessagePromptTemplate( + prompt=PromptTemplate( + input_variables=['agent_scratchpad', 'input'], + template= + '{input}\n\n{agent_scratchpad}\n (reminder to respond in a JSON blob no matter what)' + )) + ]) + +class StopChainException(Exception): + """Stop the chain by user.""" + + +class GenerationCallback(BaseCallbackHandler): + raise_error: bool = True + + def __init__(self, mq: Queue, tools: List[BaseTool]): + self.mq = mq + self.tools = {tool.name: tool for tool in tools} + + def on_agent_action(self, action: AgentAction, **kwargs): + if 'Thought:' in action.log: + thought = action.log.partition('Thought:')[-1].partition('\n')[0].strip() + else: + thought = None + tool = self.tools[action.tool] + args = parse_inputs(tool.toolmeta, action.tool_input) + self.mq.put(msg.ToolInput(name=action.tool, args=args, thought=thought)) + + def on_tool_end(self, output: str, name: str, **kwargs): + if name in self.tools: + tool = self.tools[name] + # Try to parse the outputs + outputs = parse_outputs(tool.toolmeta, output) + self.mq.put(msg.ToolOutput(outputs=outputs)) + else: + self.mq.put(msg.ToolOutput(error=output)) + + def on_agent_finish(self, finish: AgentFinish, **kwargs): + self.mq.put(msg.Answer(text=finish.return_values['output'])) + + def on_tool_start(self, *args, **kwargs): + from .. import shared + if shared.stop_everything: + raise StopChainException('The chain is stopped by user.') + + def on_llm_start(self, serialized, prompts, *args, **kwargs): + from .. import shared + if shared.stop_everything: + raise StopChainException('The chain is stopped by user.') + if shared.args.verbose: + logger.info('LangChain prompt: \n' + '\n'.join(prompts)) + + def on_chain_error(self, error: BaseException, **kwargs): + + self.mq.put(msg.Error(type=type(error).__name__, reason=str(error))) + + +class ChatOpenAI(_ChatOpenAI): + """Support Extra stop words.""" + extra_stop: Optional[List[str]] = None + + def _create_message_dicts(self, messages, stop): + if stop is not None and self.extra_stop is not None: + stop = stop + self.extra_stop + elif stop is None and self.extra_stop is not None: + stop = self.extra_stop + return super()._create_message_dicts(messages, stop=stop) + +def llm_chat_openai(model_kwargs): + + model_kwargs = deepcopy(model_kwargs) + extra_stop = model_kwargs.pop('extra_stop', None) + if isinstance(extra_stop, str) and len(extra_stop) > 0: + extra_stop = extra_stop.split(',') + else: + extra_stop = None + + if model_kwargs.get('openai_api_base') and not model_kwargs.get('openai_api_key'): + model_kwargs['openai_api_key'] = 'DUMMY_API_KEY' + + llm = ChatOpenAI(**model_kwargs, extra_stop=extra_stop) + + return llm + +def cfg_chat_openai(): + import gradio as gr + widgets = {} + widgets['model_name'] = gr.Textbox(label='Model name') + widgets['openai_api_base'] = gr.Textbox(label='API base url', info='If empty, use the default OpenAI api url, if you have self-hosted openai-style API server, please specify the host address here, like `http://localhost:8099/v1`') + widgets['openai_api_key'] = gr.Textbox(label='API key', info="If set `ENV`, will use the `OPENAI_API_KEY` environment variable. Leave empty if you don't need pass key.") + widgets['max_tokens'] = gr.Slider(label='Max number of tokens', minimum=0, maximum=8192, step=256, info='The maximum number of tokens to generate for one response.') + widgets['extra_stop'] = gr.Textbox(label='Extra stop words', info='Comma-separated list of stop words. Example: <|im_end|>,Response') + return widgets + + +def langchain_style_history(history) -> ChatMessageHistory: + memory = ChatMessageHistory() + for row in history['internal']: + response = '' + for step in row[1]: + if isinstance(step, msg.ToolInput): + response += f'Thought: {step.thought or ""}\n' + args = json.dumps({k: v['content'] for k, v in step.args.items()}) + tool_str = f'{{\n "action": "{step.name}",\n "action_input": "{args}"\n}}' + response += 'Action:\n```\n' + tool_str + '\n```\n' + elif isinstance(step, msg.ToolOutput): + if step.outputs: + outputs = ', '.join(out['content'] for out in step.outputs) + response += f'Observation: {outputs}\n' + elif step.error: + response += f'Observation: {step.error}\n' + elif isinstance(step, msg.Answer): + response += f'Thought: {step.thought or ""}\n' + tool_str = f'{{\n "action": "Final Answer",\n "action_input": "{step.text}"\n}}' + response += 'Action:\n```\n' + tool_str + '\n```\n' + memory.add_user_message(row[0]) + memory.add_ai_message(response) + return memory + + +def create_langchain_structure(llm, tools): + from .. import shared + + tools = [tool.to_langchain() for tool in tools] + agent = create_structured_chat_agent( + llm=llm, + tools=tools, + prompt=STRUCTURED_CHAT_PROMPT, + ) + return AgentExecutor( + agent=agent, + tools=tools, + verbose=shared.args.verbose, + handle_parsing_errors=False, + ) + + +def generate_structured(question, state, history) -> Iterator[List[BaseModel]]: + from .. import shared + messages = [] + + mq = Queue() + tools = [ + tool for k, tool in shared.toolkits.items() + if k in state['selected_tools'] + ] + callback = GenerationCallback(mq, tools) + agent = create_langchain_structure(shared.llm, tools) + + history = langchain_style_history(history) + thread = Thread( + target=agent.invoke, + args=(dict(input=question, chat_history=history.messages), + dict(callbacks=[callback], ))) + thread.start() + while thread.is_alive() or mq.qsize() > 0: + if mq.qsize() > 0: + item = mq.get() + messages.append(item) + yield messages + if isinstance(item, msg.Error): + return + elif shared.stop_everything: + yield messages + return + else: + time.sleep(0.5) diff --git a/webui/modules/chat.py b/webui/modules/chat.py new file mode 100644 index 00000000..2ffebe45 --- /dev/null +++ b/webui/modules/chat.py @@ -0,0 +1,220 @@ +import copy +import hashlib +import html +import json +from datetime import datetime +from pathlib import Path + +import gradio as gr +from pydantic_core import to_jsonable_python + +from . import shared +from .html_generator import chat_html_wrapper, reply_to_html +from .logging import logger +from .text_generation import generate_reply +from .utils import delete_file + +IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp') +AUDIO_EXTENSIONS = ('.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a') + + +def persist_file(path): + path = Path(path) + with open(path, 'rb') as file: + file_bytes = file.read() + file_hash = hashlib.sha256(file_bytes[:512] + file_bytes[-512:]) + date = datetime.now().strftime('%Y%m%d-') + filename = date + file_hash.hexdigest()[:8] + path.suffix + new_path = Path('generated/upload/') + new_path.mkdir(parents=True, exist_ok=True) + new_path = new_path / filename + if not new_path.exists(): + path.rename(new_path) + return new_path + + +def add_file(path): + path = Path(path) + if not path.exists(): + return '', '' + + if path.suffix in IMAGE_EXTENSIONS: + internal = f'An image at `{path}`.\n' + visible = f'
Image
' + elif path.suffix in AUDIO_EXTENSIONS: + internal = f'An audio at `{path}`.\n' + visible = f'
' + else: + internal = f'A file at `{path}\n' + visible = f'
{path}
' + + return internal, visible + + +def chatbot_wrapper(text, state, files=None, regenerate=False, loading_message=True): + history = state['history'] + if not text and not regenerate: + yield state['history'] + return + output = copy.deepcopy(history) + just_started = True + visible_text = None + + # Prepare the input + if not regenerate: + visible_text = html.escape(text) + if files: + for file in reversed(files): + internal, visible = add_file(file) + visible_text = visible + visible_text + text = internal + text + else: + text, visible_text = output['internal'][-1][0], output['visible'][-1][0] + output['visible'].pop() + output['internal'].pop() + + # *Is typing...* + if loading_message: + yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} + + # Generate + reply = None + for j, reply in enumerate(generate_reply(text, state, history=output)): + + visible_reply = reply_to_html(reply) + + if shared.stop_everything: + yield output + return + + if just_started: + just_started = False + output['internal'].append(['', []]) + output['visible'].append(['', '']) + + if not (j == 0 and visible_reply.strip() == ''): + output['internal'][-1] = [text, reply] + output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] + yield output + + yield output + + +def generate_chat_reply(text, state, file=None, regenerate=False, loading_message=True): + history = state['history'] + if regenerate: + text = '' + if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0: + yield history + return + + for history in chatbot_wrapper(text, state, file, regenerate=regenerate, loading_message=loading_message): + yield history + + +def generate_chat_reply_wrapper(text, state, file=None, regenerate=False): + """Same as above but returns HTML for the UI.""" + for i, history in enumerate(generate_chat_reply(text, state, file, regenerate, loading_message=True)): + yield chat_html_wrapper(history), history + +def upload_file(file, uploaded, history): + if not file: + return gr.update() + file = persist_file(file) + uploaded.append(file) + + output = copy.deepcopy(history) + visible = '' + for file in uploaded: + visible += add_file(file)[1] + output['visible'] = output['visible'] + [[visible, '']] + return uploaded, chat_html_wrapper(output) + +def redraw_html(history): + return chat_html_wrapper(history) + + +def start_new_chat(): + history = {'internal': [], 'visible': []} + + unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') + save_history(history, unique_id) + + return history + + +def get_history_file_path(unique_id): + p = Path(f'logs/{unique_id}.json') + return p + + +def save_history(history, unique_id): + if not unique_id: + return + + p = get_history_file_path(unique_id) + if not p.parent.is_dir(): + p.parent.mkdir(parents=True) + + history = to_jsonable_python(history) + with open(p, 'w', encoding='utf-8') as f: + f.write(json.dumps(history, indent=4, ensure_ascii=False)) + + +def rename_history(old_id, new_id): + old_p = get_history_file_path(old_id) + new_p = get_history_file_path(new_id) + if new_p.parent != old_p.parent: + logger.error(f'The following path is not allowed: {new_p}.') + elif new_p == old_p: + logger.info('The provided path is identical to the old one.') + else: + logger.info(f'Renaming {old_p} to {new_p}') + old_p.rename(new_p) + + +def find_all_histories(): + paths = Path('logs/').glob('*.json') + histories = sorted(paths, key=lambda x: x.stat().st_mtime, reverse=True) + histories = [path.stem for path in histories] + + return histories + + +def load_latest_history(): + """Loads the latest history for the given character in chat or chat- + instruct mode, or the latest instruct history for instruct mode.""" + + histories = find_all_histories() + + if len(histories) > 0: + unique_id = Path(histories[0]).stem + history = load_history(unique_id) + else: + history = start_new_chat() + + return history + +def recover_message(data): + from . import message_schema as msg + + if isinstance(data, dict) and '_role' in data: + data_type = getattr(msg, data['_role']) + return data_type.model_validate(data) + elif isinstance(data, dict): + return {k: recover_message(v) for k, v in data.items()} + elif isinstance(data, (tuple, list)): + return type(data)([recover_message(item) for item in data]) + else: + return data + +def load_history(unique_id): + p = get_history_file_path(unique_id) + + f = json.loads(open(p, 'rb').read()) + return recover_message(f) + + +def delete_history(unique_id): + p = get_history_file_path(unique_id) + delete_file(p) diff --git a/webui/modules/html_generator.py b/webui/modules/html_generator.py new file mode 100644 index 00000000..4b50a9b6 --- /dev/null +++ b/webui/modules/html_generator.py @@ -0,0 +1,221 @@ +import html +import re +from copy import deepcopy +from pathlib import Path +from typing import List, Optional + +import markdown +from pydantic import BaseModel + +from . import message_schema as msg + +IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp') +AUDIO_EXTENSIONS = ('.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a') + +IMAGE_REGEX = r'(?P[/\.\w-]+\.(png|jpg))' +AUDIO_REGEX = r'(?P[/\.\w-]+\.(wav|mp3))' +# `sandbox:` and `file://` are common prefix of ChatGPT output. +LINK_REGEX = r'`*\!?\[(?P.+)\]\((sandbox:)?(file://)?{}\)`*' + +with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f: + chat_css = f.read() + +def fix_newlines(string): + string = string.replace('\n', '\n\n') + string = re.sub(r'\n{3,}', '\n\n', string) + string = string.strip() + return string + +# avatars come from streamlit chatbot. +USER_AVATAR = '''\ +''' +BOT_AVATAR = '''\ + +''' + +def replace_blockquote(m): + return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '') + + +def convert_to_markdown(string): + + # Blockquote + string = re.sub(r'(^|[\n])>', r'\1>', string) + pattern = re.compile(r'\\begin{blockquote}(.*?)\\end{blockquote}', re.DOTALL) + string = pattern.sub(replace_blockquote, string) + + # Code + string = string.replace('\\begin{code}', '```') + string = string.replace('\\end{code}', '```') + string = re.sub(r'(.)```', r'\1\n```', string) + + result = '' + is_code = False + for line in string.split('\n'): + if line.lstrip(' ').startswith('```'): + is_code = not is_code + + result += line + if is_code or line.startswith('|'): # Don't add an extra \n for tables or code + result += '\n' + else: + result += '\n\n' + + result = result.strip() + if is_code: + result += '\n```' # Unfinished code block + + # Unfinished list, like "\n1.". A |delete| string is added and then + # removed to force a
    or
      to be generated instead of a

      . + if re.search(r'(\n\d+\.?|\n\*\s*)$', result): + delete_str = '|delete|' + + if re.search(r'(\d+\.?)$', result) and not result.endswith('.'): + result += '.' + + result = re.sub(r'(\n\d+\.?|\n\*\s*)$', r'\g<1> ' + delete_str, result) + + html_output = markdown.markdown(result, extensions=['fenced_code', 'tables']) + pos = html_output.rfind(delete_str) + if pos > -1: + html_output = html_output[:pos] + html_output[pos + len(delete_str):] + else: + html_output = markdown.markdown(result, extensions=['fenced_code', 'tables']) + + # Unescape code blocks + pattern = re.compile(r']*>(.*?)', re.DOTALL) + html_output = pattern.sub(lambda x: html.unescape(x.group()), html_output) + + return html_output + + +def chat_html_wrapper(history): + history = history['visible'] + + output = f'

      ' + + for i, _row in enumerate(history): + row = [convert_to_markdown(entry) for entry in _row] + + if row[0]: # don't display empty user messages + output += f""" +
      +
      + {USER_AVATAR} +
      +
      +
      + {row[0]} +
      +
      +
      + """ + + if row[1]: + output += f""" +
      +
      + {BOT_AVATAR} +
      +
      +
      + {row[1]} +
      +
      +
      + """ + + output += '
      ' + return output + + +def sub_image_path(match_obj: re.Match): + path = Path(match_obj.groupdict()['path']) + path = path.absolute().relative_to(Path.cwd()) + return f'
      Image
      ' + +def sub_audio_path(match_obj: re.Match): + path = Path(match_obj.groupdict()['path']) + path = path.absolute().relative_to(Path.cwd()) + return f'
      ' + + + +def tool_to_html(input: msg.ToolInput, + output: Optional[msg.ToolOutput] = None): + tool = input.name + + if input.thought: + html = f'
      {convert_to_markdown(input.thought)}
      ' + else: + html = '' + + html += '
      ' + if output is None: + html += f'Executing {tool} ...' + elif output.outputs is not None: + html += f'{tool}' + else: + html += f'Failed to execute {tool}' + + if input.args: + args = deepcopy(input.args) + def replace_arg(data: dict): + if data['type'] == 'text': + return repr(data['content']) + else: + return 'path' + + args = ', '.join(f'{k}={replace_arg(v)}' for k, v in args.items()) + html += f'
      Args: {args}
      ' + + if output and output.outputs is not None: + text_output = ', '.join( + str(out['content']) for out in output.outputs if out['type'] == 'text') + html += f'
      Response: {text_output}' + for out in output.outputs: + if out['type'] == 'image': + html += re.sub(IMAGE_REGEX, sub_image_path, out['content']) + elif out['type'] == 'audio': + html += re.sub(AUDIO_REGEX, sub_audio_path, out['content']) + elif out['type'] == 'file': + html += f'
      {out["content"]}
      ' + html += '
      ' + elif output and output.error is not None: + html += f'
      Error: {output.error}
      ' + + html += '
      ' + return html + +def reply_to_html(steps: List[BaseModel]) -> str: + html = '' + loading = False + for i, step in enumerate(steps): + if isinstance(step, msg.ToolInput): + if len(steps) > i + 1 and isinstance(steps[i + 1], msg.ToolOutput): + loading = False + response = steps[i + 1] + else: + loading = True + response = None + html += tool_to_html(input=step, output=response) + if isinstance(step, msg.Answer): + loading = False + html += '
      ' + answer = re.sub(LINK_REGEX.format(IMAGE_REGEX), sub_image_path, step.text) + answer = re.sub(LINK_REGEX.format(AUDIO_REGEX), sub_audio_path, answer) + html += convert_to_markdown(answer) + html += '
      ' + if isinstance(step, msg.Error): + loading = False + html += '
      ' + html += f'
      {step.type}
      ' + if step.reason: + reason = step.reason + if len(reason) > 350: + reason = reason[:350] + '...' + html += f'
      {reason}
      ' + html += '
      ' + if loading: + html += '

      Is thinking...

      ' + return html diff --git a/webui/modules/logging.py b/webui/modules/logging.py new file mode 100644 index 00000000..8eaa00d2 --- /dev/null +++ b/webui/modules/logging.py @@ -0,0 +1,13 @@ +import logging + +from rich.console import Console +from rich.logging import RichHandler + +console = Console() + +logger = logging.getLogger('agentlego') + +handler = RichHandler(console=console, keywords=[], show_path=False) +handler.setFormatter(logging.Formatter('%(message)s', datefmt='[%X]')) +logger.addHandler(handler) +logger.setLevel(logging.INFO) diff --git a/webui/modules/message_schema.py b/webui/modules/message_schema.py new file mode 100644 index 00000000..e4819a74 --- /dev/null +++ b/webui/modules/message_schema.py @@ -0,0 +1,32 @@ +from typing import Literal, Mapping, Optional, Tuple + +from pydantic import BaseModel + + +class Message(BaseModel): + ... + + +class ToolInput(Message): + name: str + args: Mapping[str, dict] + thought: Optional[str] = None + _role: Literal['ToolInput'] = 'ToolInput' + + +class ToolOutput(Message): + outputs: Optional[Tuple[dict, ...]] = None + error: Optional[str] = None + _role: Literal['ToolOutput'] = 'ToolOutput' + + +class Answer(Message): + text: str + thought: Optional[str] = None + _role: Literal['Answer'] = 'Answer' + + +class Error(Message): + type: str + reason: Optional[str] = None + _role: Literal['Answer'] = 'Answer' diff --git a/webui/modules/settings.py b/webui/modules/settings.py new file mode 100644 index 00000000..90f7c0c3 --- /dev/null +++ b/webui/modules/settings.py @@ -0,0 +1,121 @@ +from copy import deepcopy +from pathlib import Path + +import gradio as gr +import yaml + +from . import shared, ui + + +def get_agent_settings(name): + return deepcopy(shared.agents_settings[name]) + + +def get_tool_settings(name): + return deepcopy(shared.tool_settings[name]) + + +def apply_agent_settings(agent): + ''' + UI: update the state variable with the agent settings + ''' + state = deepcopy(ui.agent_elements) + if agent is None or agent == 'New Agent': + return list(state.values()) + agent_settings = get_agent_settings(agent) + agent_class = agent_settings.pop('agent_class') + state['agent_class'] = agent_class + state.update({f'{agent_class}#{k}': v for k, v in agent_settings.items()}) + return [state[k] for k in ui.agent_elements] + + +def apply_tool_settings_to_state(tool, state): + ''' + UI: update the state variable with the tool settings + ''' + if tool is None or tool == 'New Tool': + state.update(dict(tool_class=None)) + return state + tool_settings = get_tool_settings(tool) + for k, v in tool_settings.items(): + state['tool_' + k] = v + return state + + +def save_agent_settings(name, *args): + """Save the settings for this agent to agent_config.yaml.""" + name = name.strip() + + if name == '': + raise gr.Error('Please specify the agent name to save.') + + p = Path(shared.args.agent_config) + if p.exists(): + settings = yaml.safe_load(open(p, 'r').read()) + else: + settings = {} + + settings.setdefault(name, {}) + + state = {k: v for k, v in zip(ui.agent_elements, args)} + for k in ui.agent_elements: + if k == 'agent_class' or k.startswith(state['agent_class'] + '#'): + save_k = k.rpartition('#')[-1] + settings[name][save_k] = state[k] or None + + shared.agents_settings = settings + + output = yaml.dump(settings, sort_keys=False, allow_unicode=True) + with open(p, 'w') as f: + f.write(output) + + return f'Settings for `{name}` saved to `{p}`.' + +def save_tool_settings(tool_class, name, desc, enable, device, args, old_name=None): + """Save the settings for this agent to agent_config.yaml.""" + name = name.strip() + + if name == '': + return 'Not saving the settings because no model is loaded.' + + p = Path(shared.args.tool_config) + if p.exists(): + settings = yaml.safe_load(open(p, 'r').read()) + else: + settings = {} + + if old_name is not None: + settings.pop(old_name, None) + elif name in settings: + return f'The name `{name}` is already used.' + + settings[name] = { + 'class': tool_class, + 'name': name, + 'description': desc, + 'enable': enable, + 'device': device, + 'args': args, + } + shared.tool_settings = settings + + output = yaml.dump(settings, sort_keys=False, allow_unicode=True) + with open(p, 'w') as f: + f.write(output) + + return f'Settings for `{name}` saved to `{p}`.' + + +def make_agent_params_visible(agent_class): + updates = [] + for name in ui.agent_elements: + if name == 'agent_class': + updates.append(gr.update()) + elif not agent_class: + updates.append(gr.update(visible=False, interactive=False)) + elif name.startswith(agent_class + '#'): + updates.append(gr.update(visible=True, interactive=True)) + else: + updates.append(gr.update(visible=False, interactive=False)) + + return updates diff --git a/webui/modules/shared.py b/webui/modules/shared.py new file mode 100644 index 00000000..c1379b8a --- /dev/null +++ b/webui/modules/shared.py @@ -0,0 +1,107 @@ +import argparse +from collections import OrderedDict +from pathlib import Path +from typing import Mapping + +import yaml +from modules.logging import logger + +from agentlego.apis.tool import NAMES2TOOLS, extract_all_tools +from agentlego.tools import BaseTool +from agentlego.tools.remote import RemoteTool +from agentlego.utils import resolve_module + +# Agent variables +agent_name = None +agent = None +llm = None +toolkits: Mapping[str, BaseTool] = {} + +# Generation variables +stop_everything = False +need_restart = False +generation_lock = None +processing_message = '*Is thinking...*' + +# UI variables +gradio = {} + +# UI defaults +settings = { + 'preset': 'simple-1', + 'max_new_tokens': 512, + 'max_new_tokens_min': 1, + 'max_new_tokens_max': 4096, + 'seed': -1, + 'truncation_length': 2048, + 'truncation_length_min': 0, + 'truncation_length_max': 200000, + 'max_tokens_second': 0, + 'custom_stopping_strings': '', + 'custom_token_bans': '', + 'add_bos_token': True, + 'skip_special_tokens': True, + 'stream': True, + 'autoload_model': False, +} + +parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) + +# Basic settings +parser.add_argument('--agent', type=str, help='Name of the agent to load by default.') +parser.add_argument('--agent-config', type=str, default='agent_config.yml', help='The agent config yaml file.') +parser.add_argument('--tool-config', type=str, default='tool_config.yml', help='The tools config yaml file.') +parser.add_argument('--lazy-setup', action='store_true', help='Avoid setup tools before the first run.') +parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') + +# Gradio +parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') +parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') +parser.add_argument('--listen-host', type=str, help='The hostname that the server will use.') +parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') +parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') +parser.add_argument('--gradio-auth', type=str, help='Set Gradio authentication password in the format "username:password". Multiple credentials can also be supplied with "u1:p1,u2:p2,u3:p3".', default=None) +parser.add_argument('--gradio-auth-path', type=str, help='Set the Gradio authentication file path. The file should contain one or more user:password pairs in the same format as above.', default=None) +parser.add_argument('--ssl-keyfile', type=str, help='The path to the SSL certificate key file.', default=None) +parser.add_argument('--ssl-certfile', type=str, help='The path to the SSL certificate cert file.', default=None) + +args = parser.parse_args() + +# Security warnings +if args.share: + logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") +if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): + logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") + + +# Load agent-specific settings +with Path(args.agent_config) as p: + if p.exists(): + agents_settings = yaml.safe_load(open(p, 'r').read()) + else: + agents_settings = {} + +agents_settings = OrderedDict(agents_settings) + +# Load tool-specific settings +with Path(args.tool_config) as p: + if p.exists(): + tool_settings = yaml.safe_load(open(p, 'r').read()) + else: + tool_settings = {} + +tool_classes: Mapping[str, type] = NAMES2TOOLS.copy() +tool_classes['RemoteTool'] = RemoteTool +custom_tools_dir = Path(__file__).absolute().parents[1] / 'custom_tools' +for source_file in custom_tools_dir.glob('*.py'): + try: + module = resolve_module(source_file) + toolkit = module.__name__ + if toolkit.startswith('_ext_'): + toolkit = toolkit[5:] + tool_classes.update({ + toolkit + '.' + k: v + for k, v in extract_all_tools(module).items() + }) + except Exception: + logger.exception('Traceback') diff --git a/webui/modules/text_generation.py b/webui/modules/text_generation.py new file mode 100644 index 00000000..84d5a1f7 --- /dev/null +++ b/webui/modules/text_generation.py @@ -0,0 +1,43 @@ +import time + +import gradio as gr +import modules.shared as shared + +from .agents import agent_func_map, clear_cache +from .logging import logger + + +def generate_reply(*args, **kwargs): + shared.generation_lock.acquire() + try: + yield from _generate_reply(*args, **kwargs) + finally: + shared.generation_lock.release() + + +def _generate_reply(question, state, history): + + # Find the appropriate generation function + if shared.agent_name is None or shared.llm is None: + logger.error('No agent is loaded! Select one in the Agent tab.') + raise gr.Error('No agent is loaded! Select one in the Agent tab.') + + agent_class = shared.agents_settings[shared.agent_name]['agent_class'] + generate_func = agent_func_map[agent_class].generate + + if shared.args.verbose: + logger.info(f'question:\n{question}\n--------------------\n') + + shared.stop_everything = False + clear_cache() + + try: + t0 = time.time() + yield from generate_func(question, state, history) + finally: + t1 = time.time() + logger.info(f'Output generated in {(t1-t0):.2f} seconds.') + + +def stop_everything_event(): + shared.stop_everything = True diff --git a/webui/modules/tools.py b/webui/modules/tools.py new file mode 100644 index 00000000..174e626e --- /dev/null +++ b/webui/modules/tools.py @@ -0,0 +1,87 @@ +import ast +import inspect +from copy import deepcopy +from pathlib import Path + +import gradio as gr +import yaml + +from . import shared +from .settings import get_tool_settings, save_tool_settings + + +def parse_args(args_str): + call = ast.parse(f'foo({args_str})').body[0].value + kwargs = {} + for keyword in call.keywords: + k = keyword.arg + v = ast.Expression(body=keyword.value) + ast.fix_missing_locations(v) + kwargs[k] = eval(compile(v, '', 'eval')) + return kwargs + + +def load_tool(name=None): + cfg = get_tool_settings(name) + if not cfg['enable']: + shared.toolkits.pop(name, None) + return None + + try: + tool = load_tool_from_cfg(cfg) + tool.setup() + tool._is_setup = True + shared.toolkits[name] = tool + return tool + except Exception as e: + save_tool_settings( + tool_class=cfg['class'], + name=cfg['name'], + desc=cfg['description'], + enable=False, + device=cfg.get('device', None), + args=cfg['args'], + old_name=cfg['name']) + raise gr.Error(f'Failed to load tool `{name}`, auto disabled.') from e + + +def load_tool_from_cfg(tool_cfg): + tool_cfg = deepcopy(tool_cfg) + tool_class = shared.tool_classes[tool_cfg.pop('class')] + device = tool_cfg.pop('device', 'cpu') + kwargs = parse_args(tool_cfg.pop('args')) + from agentlego.tools.remote import RemoteTool + + if 'device' in inspect.signature(tool_class).parameters: + tool = tool_class(device=device, **kwargs) + elif tool_class is RemoteTool: + tool = RemoteTool.from_url(**kwargs) + else: + tool = tool_class(**kwargs) + + tool.toolmeta.name = tool_cfg['name'] + tool.toolmeta.description = tool_cfg['description'] + + return tool + +def delete_tool(name): + name = name.strip() + + if name == '': + return + + p = Path(shared.args.tool_config) + if p.exists(): + settings = yaml.safe_load(open(p, 'r').read()) + else: + settings = {} + + settings.pop(name, None) + shared.tool_settings = settings + shared.toolkits.pop(name, None) + + output = yaml.dump(settings, sort_keys=False, allow_unicode=True) + with open(p, 'w') as f: + f.write(output) + + return f'`{name}` is deleted from `{p}`.' diff --git a/webui/modules/ui.py b/webui/modules/ui.py new file mode 100644 index 00000000..5198f653 --- /dev/null +++ b/webui/modules/ui.py @@ -0,0 +1,103 @@ +from pathlib import Path + +import gradio as gr + +with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: + css = f.read() +with open(Path(__file__).resolve().parent / '../js/main.js', 'r') as f: + js = f.read() + +refresh_symbol = '🔄' +delete_symbol = '🗑️' +save_symbol = '💾' + +theme = gr.themes.Default( + font=['Noto Sans', 'Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'], + font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'], +).set( + border_color_primary='#c5c5d2', + button_large_padding='6px 12px', + body_text_color_subdued='#484848', + background_fill_secondary='#eaeaea' +) + +agent_elements = {'agent_class': None} + + +def list_tool_elements(): + elements = [ + 'tool_class', + 'tool_name', + 'tool_description', + 'tool_enable', + 'tool_device', + 'tool_args', + ] + return elements + + +def list_interface_input_elements(): + elements = [] + + # Chat elements + elements += [ + 'textbox', + 'history', + 'selected_tools', + ] + + # Tool elements + elements += list_tool_elements() + + return elements + + +def gather_interface_values(*args): + output = {} + for i, element in enumerate(list_interface_input_elements()): + output[element] = args[i] + + return output + + +def apply_interface_values(state): + elements = list_interface_input_elements() + if len(state) == 0: + return [gr.update() for k in elements] # Dummy, do nothing + else: + return [state[k] if k in state else gr.update() for k in elements] + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class, interactive=True): + """ + Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui + """ + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = gr.Button(refresh_symbol, elem_classes=elem_class, interactive=interactive) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + + return refresh_button + +def create_confirm_cancel(value, **kwargs): + widget = gr.Button(value, **kwargs) + hidden = [] + confirm = gr.Button('Confirm', visible=False, **kwargs) + cancel = gr.Button('Cancel', visible=False, variant='stop', **kwargs) + hidden.extend([confirm, cancel]) + + widget.click(lambda: [gr.update(visible=False)] + [gr.update(visible=True)] * len(hidden), None, [widget, *hidden], show_progress=False) + cancel.click(lambda: [gr.update(visible=True)] + [gr.update(visible=False)] * len(hidden), None, [widget, *hidden], show_progress=False) + + return widget, confirm, cancel diff --git a/webui/modules/ui_agent.py b/webui/modules/ui_agent.py new file mode 100644 index 00000000..a5a76b56 --- /dev/null +++ b/webui/modules/ui_agent.py @@ -0,0 +1,113 @@ +from functools import partial + +import gradio as gr + +from . import shared, ui, utils +from .agents import agent_func_map, delete_agent, load_llm, unload_agent +from .logging import logger +from .settings import (apply_agent_settings, make_agent_params_visible, + save_agent_settings) +from .utils import gradio + + +def create_ui(): + + with gr.Tab('Agent', elem_id='agent-tab'): + with gr.Row(): + shared.gradio['agent_menu'] = gr.Dropdown(choices=utils.get_available_agents(), label='Agent', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['agent_menu'], lambda: None, lambda: {'choices': utils.get_available_agents()}, 'refresh-button') + delete, confirm, cancel = ui.create_confirm_cancel('🗑️', elem_classes='refresh-button') + shared.gradio['delete_agent'] = delete + shared.gradio['delete_agent-confirm'] = confirm + shared.gradio['delete_agent-cancel'] = cancel + shared.gradio['load_agent'] = gr.Button('Load', elem_classes='refresh-button') + + with gr.Row(): + with gr.Column(scale=2): + with gr.Row(): + shared.gradio['agent_class'] = gr.Dropdown(label='Agent class', choices=agent_func_map.keys(), value=None, elem_classes=['slim-dropdown'], scale=3, interactive=False) + shared.gradio['save_agent'] = gr.Button('Save', visible=False, elem_classes='refresh-button') + shared.gradio['save_agent_new'] = gr.Button('Save to', visible=False, elem_classes='refresh-button') + shared.gradio['new_agent_name'] = gr.Textbox(label='Agent name', visible=False, scale=1, elem_classes=['slim-dropdown']) + with gr.Group(): + # Agent initialization arguments + for agent_class, callbacks in agent_func_map.items(): + for name, widget in callbacks.cfg_widget().items(): + widget.visible = False + name = f'{agent_class}#{name}' + shared.gradio[name] = widget + ui.agent_elements[name] = widget.value + + with gr.Column(scale=1): + shared.gradio['agent_status'] = gr.Markdown('No agent is loaded' if shared.agent_name == 'None' else 'Ready') + + +def create_event_handlers(): + def update_widgets(agent_menu): + if agent_menu == 'New Agent': + return [gr.update(visible=True, interactive=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] + elif agent_menu is None: + return [gr.update(visible=False, interactive=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, interactive=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] + + def update_current_agent(): + if shared.agent_name is None: + return '
      No agent is loaded, please select in the Agent tab.
      ' + else: + return f'
      Current agent: {shared.agent_name}
      ' + + shared.gradio['agent_menu']\ + .change(apply_agent_settings, gradio('agent_menu'), gradio(*ui.agent_elements))\ + .then(update_widgets, gradio('agent_menu'), gradio('agent_class', 'save_agent', 'save_agent_new', 'new_agent_name'), show_progress=False)\ + .then(load_agent, gradio('agent_menu'), gradio('agent_status'), show_progress=False)\ + .success(update_current_agent, None, gradio('current-agent'), show_progress=False)\ + + shared.gradio['agent_class'].change(make_agent_params_visible, gradio('agent_class'), gradio(*ui.agent_elements), show_progress=False) + + shared.gradio['load_agent']\ + .click(partial(load_agent, autoload=True), gradio('agent_menu'), gradio('agent_status'), show_progress=False)\ + .success(update_current_agent, None, gradio('current-agent'), show_progress=False)\ + + shared.gradio['save_agent']\ + .click(save_agent_settings, gradio('agent_menu', *ui.agent_elements), gradio('agent_status'), show_progress=False) + + shared.gradio['save_agent_new']\ + .click(save_agent_settings, gradio('new_agent_name', *ui.agent_elements), gradio('agent_status'), show_progress=False)\ + .success(lambda name: gr.update(value=name, choices=utils.get_available_agents()), gradio('new_agent_name'), gradio('agent_menu'))\ + .then(lambda: gr.update(value=''), None, gradio('new_agent_name'))\ + + delete_agent_widgets = ('delete_agent', 'delete_agent-confirm', 'delete_agent-cancel') + shared.gradio['delete_agent-confirm']\ + .click(delete_agent, gradio('agent_menu'), gradio('agent_status'))\ + .success(lambda: gr.update(value=None, choices=utils.get_available_agents()), None, gradio('agent_menu'))\ + .then(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, gradio(delete_agent_widgets)) + + +def load_agent(selected_agent, autoload=False): + if selected_agent is None: + yield 'No agent selected' + elif selected_agent == 'New Agent': + yield 'Please save the new agent before load it.' + elif not autoload: + if shared.agent_name is None: + yield f"Click on \"Load\" to load `{selected_agent}`." + elif shared.agent_name != selected_agent: + yield f"The current agent is `{shared.agent_name}`.\n\nClick on \"Load\" to load `{selected_agent}`." + else: + yield f'The agent `{shared.agent_name}` is loaded.' + return + else: + try: + yield f'Loading `{selected_agent}`...' + unload_agent() + shared.agent_name = selected_agent + shared.llm = load_llm(selected_agent) + + if shared.llm is not None: + yield f'Successfully loaded `{selected_agent}`.' + else: + yield f'Failed to load `{selected_agent}`.' + except: + logger.exception('Traceback') + raise gr.Error(f'Failed to load the agent `{selected_agent}`.') diff --git a/webui/modules/ui_chat.py b/webui/modules/ui_chat.py new file mode 100644 index 00000000..7bdeaf35 --- /dev/null +++ b/webui/modules/ui_chat.py @@ -0,0 +1,118 @@ +from datetime import datetime +from functools import partial + +import gradio as gr +from modules import chat, shared, ui +from modules.html_generator import chat_html_wrapper +from modules.text_generation import stop_everything_event +from modules.utils import gradio + + +def create_ui(): + shared.gradio['Chat input'] = gr.State() + shared.gradio['uploaded-files'] = gr.State([]) + shared.gradio['history'] = gr.State({'internal': [], 'visible': []}) + + with gr.Tab('Chat', elem_id='chat-tab'): + with gr.Row(): + with gr.Column(elem_id='chat-col'): + # Display history + shared.gradio['current-agent'] = gr.HTML(value='
      No agent is loaded, please select in the Agent tab.
      ') + shared.gradio['display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': []})) + + # Chat input area + with gr.Row(elem_id='chat-input-row'): + with gr.Column(scale=1, elem_id='gr-hover-container'): + gr.HTML(value='
      ', elem_id='gr-hover') + + with gr.Column(scale=10, elem_id='chat-input-container'): + shared.gradio['textbox'] = gr.Textbox(label='', placeholder='Send a message', elem_id='chat-input', elem_classes=['add_scrollbar']) + shared.gradio['typing-dots'] = gr.HTML(value='
      ', label='typing', elem_id='typing-container') + + with gr.Column(scale=1, elem_id='upload-container'): + shared.gradio['upload-file'] = gr.UploadButton('📁', elem_id='upload', file_types=['image', 'audio']) + + with gr.Column(scale=1, elem_id='generate-stop-container'): + with gr.Row(): + shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop', visible=False) + shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary') + + # Hover menu buttons + with gr.Column(elem_id='chat-buttons'): + with gr.Row(): + shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_id='Regenerate') + shared.gradio['Start new chat'] = gr.Button('Start new chat') + + with gr.Group(): + gr.HTML('
      Select tools
      ') + with gr.Row(): + shared.gradio['select_all_tools'] = gr.Button('All', size='sm') + shared.gradio['select_no_tools'] = gr.Button('None', size='sm') + shared.gradio['selected_tools'] = gr.CheckboxGroup(show_label=False, choices=shared.toolkits.keys(), value=list(shared.toolkits.keys())) + + with gr.Row(elem_id='past-chats-row'): + shared.gradio['unique_id'] = gr.Dropdown(label='Past chats', elem_classes=['slim-dropdown'], choices=chat.find_all_histories(), value=None, allow_custom_value=True) + shared.gradio['save_chat'] = gr.Button('Save', elem_classes='refresh-button') + delete, confirm, cancel = ui.create_confirm_cancel('🗑️', elem_classes='refresh-button') + shared.gradio['delete_chat'] = delete + shared.gradio['delete_chat-confirm'] = confirm + shared.gradio['delete_chat-cancel'] = cancel + +def create_event_handlers(): + + inputs = ('Chat input', 'interface_state', 'uploaded-files') + + shared.gradio['Generate']\ + .click(ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state'))\ + .then(lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False)\ + .then(chat.generate_chat_reply_wrapper, gradio(inputs), gradio('display', 'history'), show_progress=False)\ + .then(ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state'))\ + .then(lambda: [], None, gradio('uploaded-files'), show_progress=False)\ + .then(chat.save_history, gradio('history', 'unique_id'), None) + + shared.gradio['textbox']\ + .submit(ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state'))\ + .then(lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False)\ + .then(chat.generate_chat_reply_wrapper, gradio(inputs), gradio('display', 'history'), show_progress=False)\ + .then(ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state'))\ + .then(lambda: [], None, gradio('uploaded-files'), show_progress=False)\ + .then(chat.save_history, gradio('history', 'unique_id'), None) + + shared.gradio['upload-file']\ + .upload(chat.upload_file, gradio('upload-file', 'uploaded-files', 'history'), gradio('uploaded-files', 'display'), show_progress=False)\ + .then(lambda: gr.update(value=None), None, gradio('upload-file'), show_progress=False) + + shared.gradio['Stop']\ + .click(stop_everything_event, None, None, queue=False)\ + .then(chat.redraw_html, gradio('history'), gradio('display')) + + shared.gradio['select_all_tools'].click(lambda: list(shared.toolkits), None, gradio('selected_tools'), show_progress=False) + shared.gradio['select_no_tools'].click(lambda: [], None, gradio('selected_tools'), show_progress=False) + + shared.gradio['unique_id']\ + .select(chat.load_history, gradio('unique_id'), gradio('history'))\ + .then(chat.redraw_html, gradio('history'), gradio('display')) + + shared.gradio['Start new chat']\ + .click(lambda: {'internal': [], 'visible': []}, None, gradio('history'))\ + .then(chat.redraw_html, gradio('history'), gradio('display'))\ + .then(lambda: gr.update(value=None), None, gradio('unique_id')) + + shared.gradio['Regenerate']\ + .click(ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state'))\ + .then(partial(chat.generate_chat_reply_wrapper, regenerate=True), gradio(inputs), gradio('display', 'history'), show_progress=False)\ + .then(ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state'))\ + .then(chat.save_history, gradio('history', 'unique_id'), None) + + shared.gradio['save_chat']\ + .click(lambda unique_id: unique_id or datetime.now().strftime('%Y%m%d-%H-%M-%S'), gradio('unique_id'), gradio('unique_id'))\ + .then(chat.save_history, gradio('history', 'unique_id'), None)\ + .then(lambda: gr.update(choices=chat.find_all_histories()), None, gradio('unique_id'), show_progress=False) + + delete_history_widgets = ('delete_chat', 'delete_chat-confirm', 'delete_chat-cancel') + shared.gradio['delete_chat-confirm']\ + .click(lambda: {'internal': [], 'visible': []}, None, gradio('history'))\ + .then(chat.redraw_html, gradio('history'), gradio('display'))\ + .then(chat.delete_history, gradio('unique_id'), None)\ + .then(lambda: gr.update(value=None, choices=chat.find_all_histories()), None, gradio('unique_id'), show_progress=False)\ + .then(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, gradio(delete_history_widgets)) diff --git a/webui/modules/ui_tools.py b/webui/modules/ui_tools.py new file mode 100644 index 00000000..cd4f4dbb --- /dev/null +++ b/webui/modules/ui_tools.py @@ -0,0 +1,172 @@ +import inspect +from functools import partial + +import gradio as gr + +from agentlego.tools.remote import RemoteTool +from . import shared, ui, utils +from .settings import apply_tool_settings_to_state, save_tool_settings +from .tools import delete_tool, load_tool +from .utils import gradio + + +def create_ui(): + + with gr.Tab('Tools', elem_id='tools-tab'): + with gr.Row(): + shared.gradio['tool_menu'] = gr.Dropdown(choices=utils.get_available_tools(), label='Tools', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['tool_menu'], lambda: None, lambda: {'choices': utils.get_available_tools()}, 'refresh-button') + delete, confirm, cancel = ui.create_confirm_cancel('🗑️', elem_classes='refresh-button') + shared.gradio['delete_tool'] = delete + shared.gradio['delete_tool-confirm'] = confirm + shared.gradio['delete_tool-cancel'] = cancel + shared.gradio['setup_all_tools'] = gr.Button('Setup All', elem_classes='refresh-button') + + with gr.Row(): + with gr.Column(scale=2): + with gr.Row(): + shared.gradio['tool_class'] = gr.Dropdown(label='Tool class', choices=list(shared.tool_classes), interactive=False, visible=False, elem_classes=['slim-dropdown']) + shared.gradio['save_tool'] = gr.Button('Save', visible=False, elem_classes='refresh-button') + with gr.Group(): + shared.gradio['tool_name'] = gr.Textbox(label='Name', visible=False) + shared.gradio['tool_description'] = gr.Textbox(label='Description', visible=False) + shared.gradio['tool_enable'] = gr.Checkbox(label='Enable', value=True, visible=False) + shared.gradio['tool_device'] = gr.Dropdown(label='Device', value='cpu', choices=utils.get_available_devices(), visible=False) + shared.gradio['tool_args'] = gr.Textbox(label='Initialize arguments', visible=False) + + with gr.Column(scale=1): + with gr.Row(): + shared.gradio['remote_server'] = gr.Textbox(label='Import from tool server', elem_classes=['slim-dropdown']) + shared.gradio['import_remote'] = gr.Button('Confirm', elem_classes='refresh-button') + + shared.gradio['tool_status'] = gr.Markdown('') + + +def make_tool_param_visiable(class_name, tool_name): + if not class_name: + return [gr.update(visible=False)] * 5 + + tool_class = shared.tool_classes[class_name] + params = inspect.signature(tool_class).parameters + if 'device' in params: + device = params['device'].default + if 'cuda' in str(device) and 'cuda:0' in utils.get_available_devices(): + device = 'cuda:0' + else: + device = 'cpu' + else: + device = None + + if tool_name == 'New Tool': + if tool_class is not RemoteTool: + toolmeta = tool_class.get_default_toolmeta() + module_name = tool_class.__module__ + if module_name.startswith('_ext_'): + name = module_name.removeprefix('_ext_') + '.' + toolmeta.name + else: + name = toolmeta.name + description = toolmeta.description + else: + name, description = '', '' + return ( + gr.update(value=name, visible=True), + gr.update(value=description, visible=True), + gr.update(value=True, visible=True), + gr.update(value=device, visible=(device is not None)), + gr.update(value='', visible=True), + ) + else: + return ( + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=(device is not None)), + gr.update(visible=True), + ) + + +def setup_tools(*args): + for name in args: + tool = shared.toolkits[name] + if not tool._is_setup: + yield f'Setup `{name}`...' + tool.setup() + tool._is_setup = True + yield 'Done' + + +def import_remote_tools(server: str): + if not server.startswith('http'): + server = 'http://' + server + tools = RemoteTool.from_server(server) + msg = '' + for tool in tools: + if tool.name in shared.tool_settings: + msg += f'- Skip `{tool.name}` since it is already in Tools.\n' + yield msg + continue + save_tool_settings( + tool_class='RemoteTool', + name=tool.name, + desc=tool.toolmeta.description, + enable=True, + device='cpu', + args=f'url="{tool.url}"', + ) + shared.toolkits[tool.name] = tool + msg += f'- Imported tool `{tool.name}`\n' + yield msg + + +def create_event_handlers(): + def update_widgets(tool_menu): + if tool_menu is None: + return [gr.update(visible=False, interactive=False), gr.update(visible=False)] + elif tool_menu == 'New Tool': + return [gr.update(visible=True, interactive=True), gr.update(visible=True)] + else: + return [gr.update(visible=True, interactive=False), gr.update(visible=True)] + + tool_cfg_widgets = gradio('tool_name', 'tool_description', 'tool_enable', 'tool_device', 'tool_args') + + shared.gradio['tool_menu']\ + .change(ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state'))\ + .then(apply_tool_settings_to_state, gradio('tool_menu', 'interface_state'), gradio('interface_state'))\ + .then(ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False)\ + .then(update_widgets, gradio('tool_menu'), gradio('tool_class', 'save_tool'), show_progress=False)\ + + shared.gradio['tool_class'].change(make_tool_param_visiable, gradio('tool_class', 'tool_menu'), tool_cfg_widgets, show_progress=False) + + shared.gradio['setup_all_tools'].click( + partial(setup_tools, *shared.toolkits), None, gradio('tool_status')) + + def load_tool_wrapper(name): + yield f'Loading tool `{name}`...' + try: + tool = load_tool(name) + except Exception as e: + yield f'Failed to load `{name}`' + raise e + else: + if tool is not None: + yield f'Loaded `{name}`.' + else: + yield f'Skipped disabled tool `{name}`.' + + shared.gradio['save_tool']\ + .click(save_tool_settings, gradio('tool_class') + tool_cfg_widgets + gradio('tool_menu'), gradio('tool_status'), show_progress=False)\ + .success(load_tool_wrapper, gradio('tool_name'), gradio('tool_status'))\ + .success(lambda name: gr.update(value=name, choices=utils.get_available_tools()), gradio('tool_name'), gradio('tool_menu'))\ + .then(lambda: gr.update(choices=list(shared.toolkits)), None, gradio('selected_tools')) + + shared.gradio['import_remote']\ + .click(import_remote_tools, gradio('remote_server'), gradio('tool_status'))\ + .then(lambda: gr.update(choices=utils.get_available_tools()), None, gradio('tool_menu'))\ + .then(lambda: gr.update(choices=list(shared.toolkits)), None, gradio('selected_tools')) + + delete_tool_widgets = ('delete_tool', 'delete_tool-confirm', 'delete_tool-cancel') + shared.gradio['delete_tool-confirm']\ + .click(delete_tool, gradio('tool_menu'), gradio('tool_status'))\ + .success(lambda: gr.update(value=None, choices=utils.get_available_tools()), None, gradio('tool_menu'))\ + .then(lambda: gr.update(choices=list(shared.toolkits)), None, gradio('selected_tools'))\ + .then(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, gradio(delete_tool_widgets)) diff --git a/webui/modules/utils.py b/webui/modules/utils.py new file mode 100644 index 00000000..f9d08ec7 --- /dev/null +++ b/webui/modules/utils.py @@ -0,0 +1,154 @@ +import functools +import re +from ast import literal_eval +from datetime import datetime +from pathlib import Path +from typing import Mapping, Tuple, Union + +from agentlego.schema import ToolMeta +from agentlego.types import AudioIO, File, ImageIO +from . import shared +from .logging import logger + + +def parse_inputs(toolmeta: ToolMeta, args: Union[str, tuple, dict]) -> Mapping[str, dict]: + if not args: + return {} + params = {p.name: p for p in toolmeta.inputs} + + if len(params) > 1 and isinstance(args, str): + try: + args = literal_eval(args) + except Exception: + pass + + if isinstance(args, str): + args = {toolmeta.inputs[0].name: args} + elif isinstance(args, tuple): + args = {name: args for name in params} + + parsed_args = {} + for k, v in args.items(): + p = params[k] + if p.type is ImageIO: + parsed_args[k] = dict(type='image', content=v) + elif p.type is AudioIO: + parsed_args[k] = dict(type='audio', content=v) + elif p.type is File: + parsed_args[k] = dict(type='file', content=v) + else: + parsed_args[k] = dict(type='text', content=v) + + return parsed_args + + +def parse_outputs(toolmeta: ToolMeta, outs: Union[str, tuple, dict]) -> Tuple[dict, ...]: + if not outs: + return () + + if len(toolmeta.outputs) > 1 and isinstance(outs, str): + try: + outs = literal_eval(outs) + except Exception: + pass + + if isinstance(outs, str): + outs = (outs, ) + elif isinstance(outs, dict): + outs = tuple(outs.values()) + + parsed_outs = [] + for p, out in zip(toolmeta.outputs, outs): + if p.type is ImageIO: + parsed_outs.append(dict(type='image', content=out)) + elif p.type is AudioIO: + parsed_outs.append(dict(type='audio', content=out)) + elif p.type is File: + parsed_outs.append(dict(type='file', content=out)) + else: + parsed_outs.append(dict(type='text', content=out)) + + return tuple(parsed_outs) + + +# Helper function to get multiple values from shared.gradio +def gradio(*keys): + if len(keys) == 1 and isinstance(keys[0], (list, tuple, set)): + keys = keys[0] + + return [shared.gradio[k] for k in keys] + + +def save_file(fname, contents): + if fname == '': + logger.error('File name is empty!') + return + + root_folder = Path(__file__).resolve().parent.parent + abs_path = Path(fname).resolve() + rel_path = abs_path.relative_to(root_folder) + if rel_path.parts[0] == '..': + logger.error(f'Invalid file path: {fname}') + return + + with open(abs_path, 'w', encoding='utf-8') as f: + f.write(contents) + + logger.info(f'Saved {abs_path}.') + + +def delete_file(fname): + if fname == '': + logger.error('File name is empty!') + return + + root_folder = Path(__file__).resolve().parent.parent + abs_path = Path(fname).resolve() + rel_path = abs_path.relative_to(root_folder) + if rel_path.parts[0] == '..': + logger.error(f'Invalid file path: {fname}') + return + + if abs_path.exists(): + abs_path.unlink() + logger.info(f'Deleted {fname}.') + + +def current_time(): + return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}" + + +def atoi(text): + return int(text) if text.isdigit() else text.lower() + + +# Replace multiple string pairs in a string +def replace_all(text, dic): + for i, j in dic.items(): + text = text.replace(i, j) + + return text + + +def natural_keys(text): + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_available_agents(): + return ['New Agent'] + sorted(shared.agents_settings.keys(), key=natural_keys) + + +def get_available_tools(): + return ['New Tool'] + sorted(shared.tool_settings.keys(), key=natural_keys) + + +@functools.lru_cache() +def get_available_devices(): + devices = ['cpu'] + try: + import torch + if torch.cuda.is_available(): + devices += [f'cuda:{i}' for i in range(torch.cuda.device_count())] + except ImportError: + pass + return devices diff --git a/webui/one_click.py b/webui/one_click.py new file mode 100644 index 00000000..46646f8e --- /dev/null +++ b/webui/one_click.py @@ -0,0 +1,170 @@ +import os +import platform +import re +import shutil +import signal +import subprocess +import sys +from importlib.metadata import PackageNotFoundError, distribution +from pathlib import Path + +script_dir = Path(__file__).absolute().parent +conda_env_path = os.path.join(script_dir, 'installer_files', 'env') +agentlego_root = script_dir.parent + +# Remove the '# ' from the following lines as needed for your AMD GPU on Linux +# os.environ["ROCM_PATH"] = '/opt/rocm' +# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0' +# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030' + +signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0)) + + +def is_linux(): + return sys.platform.startswith('linux') + + +def is_windows(): + return sys.platform.startswith('win') + + +def is_macos(): + return sys.platform.startswith('darwin') + + +def is_x86_64(): + return platform.machine() == 'x86_64' + + +def cpu_has_avx2(): + try: + import cpuinfo + + info = cpuinfo.get_cpu_info() + if 'avx2' in info['flags']: + return True + else: + return False + except Exception: + return True + + +def cpu_has_amx(): + try: + import cpuinfo + + info = cpuinfo.get_cpu_info() + if 'amx' in info['flags']: + return True + else: + return False + except Exception: + return True + + +def digit_version(version_str: str): + pattern = r'(?P\d+)\.?(?P\d+)?\.?(?P\d+)?' + version = re.match(pattern, version_str) + assert version is not None, f'failed to parse version {version_str}' + return tuple(int(i) if i is not None else 0 for i in version.groups()) + + +def get_version(package: str): + try: + dist = distribution(package) + return digit_version(dist.version) + except PackageNotFoundError: + return None + + +def check_env(): + # If we have access to conda, we are probably in an environment + conda_exist = run_cmd('conda', capture_output=True).returncode == 0 + if not conda_exist: + print('Conda is not installed. Exiting...') + sys.exit(1) + + # Ensure this is a new environment and not the base environment + if os.environ['CONDA_DEFAULT_ENV'] == 'base': + print('Create an environment for this project and activate it. Exiting...') + sys.exit(1) + + +def clear_cache(): + run_cmd('conda clean -a -y') + run_cmd('python -m pip cache purge') + + +def print_big_message(message): + message = message.strip() + lines = message.split('\n') + print('\n\n*******************************************************************') + for line in lines: + if line.strip() != '': + print('*', line) + + print('*******************************************************************\n\n') + + +def run_cmd(cmd, assert_success=False, capture_output=False, env=None): + # Run shell commands + result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env) + + # Assert the command ran successfully + if assert_success and result.returncode != 0: + print("Command '" + cmd + "' failed with exit status code '" + + str(result.returncode) + + "'.\n\nExiting now.\nTry running the script again.") + sys.exit(1) + + return result + + +def install_agentlego(): + agentlego = get_version('agentlego') + if agentlego is None or not Path( + distribution('agentlego').locate_file('.')).samefile(agentlego_root): + print('Installing AgentLego') + run_cmd(f'python -m pip install -e {agentlego_root}', assert_success=True) + + +def install_demo_dependencies(): + gradio = get_version('gradio') + if gradio is None or gradio < digit_version('4.13.0'): + run_cmd(f'python -m pip install "gradio>=4.13.0"', assert_success=True) + + if get_version('langchain') is None: + run_cmd(f'python -m pip install langchain', assert_success=True) + if get_version('langchain-openai') is None: + run_cmd(f'python -m pip install langchain-openai', assert_success=True) + + if get_version('markdown') is None: + run_cmd(f'python -m pip install markdown', assert_success=True) + + lagent = get_version('lagent') + if lagent is None or lagent < digit_version('0.2.0'): + run_cmd(f'python -m pip install "lagent>=0.2.0"', assert_success=True) + + +if __name__ == '__main__': + # Verifies we are in a conda environment + check_env() + os.chdir(script_dir) + if not (script_dir / 'tool_config.yml').exists(): + shutil.copy(script_dir / 'tool_config.yml.example', + script_dir / 'tool_config.yml') + if not (script_dir / 'agent_config.yml').exists(): + shutil.copy(script_dir / 'agent_config.yml.example', + script_dir / 'agent_config.yml') + + # Install the current version agentlego + install_agentlego() + + # Install gradio + install_demo_dependencies() + + # Install main tools dependencies + # install_tool_dependencies() + + # Launch the gradio + run_cmd(f'python app.py ' + ' '.join(sys.argv[1:])) diff --git a/webui/start_linux.sh b/webui/start_linux.sh new file mode 100755 index 00000000..948e40b8 --- /dev/null +++ b/webui/start_linux.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +cd "$(dirname "${BASH_SOURCE[0]}")" + +if [[ "$(pwd)" =~ " " ]]; then echo This script relies on Miniconda which can not be silently installed under a path with spaces. && exit; fi + +# deactivate existing conda envs as needed to avoid conflicts +{ conda deactivate && conda deactivate && conda deactivate; } 2> /dev/null + +OS_ARCH=$(uname -m) +case "${OS_ARCH}" in + x86_64*) OS_ARCH="x86_64";; + arm64*) OS_ARCH="aarch64";; + aarch64*) OS_ARCH="aarch64";; + *) echo "Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64" && exit +esac + +# config +INSTALL_DIR="$(pwd)/installer_files" +CONDA_ROOT_PREFIX="$(pwd)/installer_files/conda" +INSTALL_ENV_DIR="$(pwd)/installer_files/env" +MINICONDA_DOWNLOAD_URL="https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-${OS_ARCH}.sh" +conda_exists="F" + +# figure out whether git and conda needs to be installed +if "$CONDA_ROOT_PREFIX/bin/conda" --version &>/dev/null; then conda_exists="T"; fi + +# (if necessary) install git and conda into a contained environment +# download miniconda +if [ "$conda_exists" == "F" ]; then + echo "Downloading Miniconda from $MINICONDA_DOWNLOAD_URL to $INSTALL_DIR/miniconda_installer.sh" + + mkdir -p "$INSTALL_DIR" + curl -Lk "$MINICONDA_DOWNLOAD_URL" > "$INSTALL_DIR/miniconda_installer.sh" + + chmod u+x "$INSTALL_DIR/miniconda_installer.sh" + bash "$INSTALL_DIR/miniconda_installer.sh" -b -p $CONDA_ROOT_PREFIX + + # test the conda binary + echo "Miniconda version:" + "$CONDA_ROOT_PREFIX/bin/conda" --version +fi + +# create the installer env +if [ ! -e "$INSTALL_ENV_DIR" ]; then + "$CONDA_ROOT_PREFIX/bin/conda" create -y -k --prefix "$INSTALL_ENV_DIR" python=3.11 +fi + +# check if conda environment was actually created +if [ ! -e "$INSTALL_ENV_DIR/bin/python" ]; then + echo "Conda environment is empty." + exit +fi + +# environment isolation +export PYTHONNOUSERSITE=1 +unset PYTHONPATH +unset PYTHONHOME +export CUDA_PATH="$INSTALL_ENV_DIR" +export CUDA_HOME="$CUDA_PATH" + +# activate installer env +source "$CONDA_ROOT_PREFIX/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script) +conda activate "$INSTALL_ENV_DIR" + +# setup installer env +python one_click.py $@ diff --git a/webui/tool_config.yml.example b/webui/tool_config.yml.example new file mode 100644 index 00000000..c512f5a0 --- /dev/null +++ b/webui/tool_config.yml.example @@ -0,0 +1,9 @@ +Calculator: + class: Calculator + name: Calculator + description: A calculator tool. The input must be a single Python expression and + you cannot import packages. You can use functions in the `math` package without + import. + enable: true + device: null + args: ''