diff --git a/api/core/tools/provider/builtin/xinference/_assets/icon.png b/api/core/tools/provider/builtin/xinference/_assets/icon.png
new file mode 100644
index 00000000000000..e58cacbd123b58
Binary files /dev/null and b/api/core/tools/provider/builtin/xinference/_assets/icon.png differ
diff --git a/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py
new file mode 100644
index 00000000000000..847f2730f240c4
--- /dev/null
+++ b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py
@@ -0,0 +1,412 @@
+import io
+import json
+from base64 import b64decode, b64encode
+from copy import deepcopy
+from typing import Any, Union
+
+from httpx import get, post
+from PIL import Image
+from yarl import URL
+
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import (
+    ToolInvokeMessage,
+    ToolParameter,
+    ToolParameterOption,
+)
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.tool.builtin_tool import BuiltinTool
+
+# All commented out parameters default to null
+DRAW_TEXT_OPTIONS = {
+    # Prompts
+    "prompt": "",
+    "negative_prompt": "",
+    # "styles": [],
+    # Seeds
+    "seed": -1,
+    "subseed": -1,
+    "subseed_strength": 0,
+    "seed_resize_from_h": -1,
+    "seed_resize_from_w": -1,
+    # Samplers
+    "sampler_name": "DPM++ 2M",
+    # "scheduler": "",
+    # "sampler_index": "Automatic",
+    # Latent Space Options
+    "batch_size": 1,
+    "n_iter": 1,
+    "steps": 10,
+    "cfg_scale": 7,
+    "width": 512,
+    "height": 512,
+    # "restore_faces": True,
+    # "tiling": True,
+    "do_not_save_samples": False,
+    "do_not_save_grid": False,
+    # "eta": 0,
+    # "denoising_strength": 0.75,
+    # "s_min_uncond": 0,
+    # "s_churn": 0,
+    # "s_tmax": 0,
+    # "s_tmin": 0,
+    # "s_noise": 0,
+    "override_settings": {},
+    "override_settings_restore_afterwards": True,
+    # Refinement Options
+    "refiner_checkpoint": "",
+    "refiner_switch_at": 0,
+    "disable_extra_networks": False,
+    # "firstpass_image": "",
+    # "comments": "",
+    # High-Resolution Options
+    "enable_hr": False,
+    "firstphase_width": 0,
+    "firstphase_height": 0,
+    "hr_scale": 2,
+    # "hr_upscaler": "",
+    "hr_second_pass_steps": 0,
+    "hr_resize_x": 0,
+    "hr_resize_y": 0,
+    # "hr_checkpoint_name": "",
+    # "hr_sampler_name": "",
+    # "hr_scheduler": "",
+    "hr_prompt": "",
+    "hr_negative_prompt": "",
+    # Task Options
+    # "force_task_id": "",
+    # Script Options
+    # "script_name": "",
+    "script_args": [],
+    # Output Options
+    "send_images": True,
+    "save_images": False,
+    "alwayson_scripts": {},
+    # "infotext": "",
+}
+
+
+class StableDiffusionTool(BuiltinTool):
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+        # base url
+        base_url = self.runtime.credentials.get("base_url", None)
+        if not base_url:
+            return self.create_text_message("Please input base_url")
+
+        if tool_parameters.get("model"):
+            self.runtime.credentials["model"] = tool_parameters["model"]
+
+        model = self.runtime.credentials.get("model", None)
+        if not model:
+            return self.create_text_message("Please input model")
+
+        # set model
+        try:
+            url = str(URL(base_url) / "sdapi" / "v1" / "options")
+            response = post(
+                url,
+                json={"sd_model_checkpoint": model},
+                headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
+            )
+            if response.status_code != 200:
+                raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model")
+        except Exception as e:
+            raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model")
+
+        # get image id and image variable
+        image_id = tool_parameters.get("image_id", "")
+        image_variable = self.get_default_image_variable()
+        # Return text2img if there's no image ID or no image variable
+        if not image_id or not image_variable:
+            return self.text2img(base_url=base_url, tool_parameters=tool_parameters)
+
+        # Proceed with image-to-image generation
+        return self.img2img(base_url=base_url, tool_parameters=tool_parameters)
+
+    def validate_models(self):
+        """
+        validate models
+        """
+        try:
+            base_url = self.runtime.credentials.get("base_url", None)
+            if not base_url:
+                raise ToolProviderCredentialValidationError("Please input base_url")
+            model = self.runtime.credentials.get("model", None)
+            if not model:
+                raise ToolProviderCredentialValidationError("Please input model")
+
+            api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models")
+            response = get(url=api_url, timeout=10)
+            if response.status_code == 404:
+                # try draw a picture
+                self._invoke(
+                    user_id="test",
+                    tool_parameters={
+                        "prompt": "a cat",
+                        "width": 1024,
+                        "height": 1024,
+                        "steps": 1,
+                        "lora": "",
+                    },
+                )
+            elif response.status_code != 200:
+                raise ToolProviderCredentialValidationError("Failed to get models")
+            else:
+                models = [d["model_name"] for d in response.json()]
+                if len([d for d in models if d == model]) > 0:
+                    return self.create_text_message(json.dumps(models))
+                else:
+                    raise ToolProviderCredentialValidationError(f"model {model} does not exist")
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(f"Failed to get models, {e}")
+
+    def get_sd_models(self) -> list[str]:
+        """
+        get sd models
+        """
+        try:
+            base_url = self.runtime.credentials.get("base_url", None)
+            if not base_url:
+                return []
+            api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models")
+            response = get(url=api_url, timeout=120)
+            if response.status_code != 200:
+                return []
+            else:
+                return [d["model_name"] for d in response.json()]
+        except Exception as e:
+            return []
+
+    def get_sample_methods(self) -> list[str]:
+        """
+        get sample method
+        """
+        try:
+            base_url = self.runtime.credentials.get("base_url", None)
+            if not base_url:
+                return []
+            api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers")
+            response = get(url=api_url, timeout=120)
+            if response.status_code != 200:
+                return []
+            else:
+                return [d["name"] for d in response.json()]
+        except Exception as e:
+            return []
+
+    def img2img(
+        self, base_url: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        generate image
+        """
+
+        # Fetch the binary data of the image
+        image_variable = self.get_default_image_variable()
+        image_binary = self.get_variable_file(image_variable.name)
+        if not image_binary:
+            return self.create_text_message("Image not found, please request user to generate image firstly.")
+
+        # Convert image to RGB and save as PNG
+        try:
+            with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer:
+                image.convert("RGB").save(buffer, format="PNG")
+                image_binary = buffer.getvalue()
+        except Exception as e:
+            return self.create_text_message(f"Failed to process the image: {str(e)}")
+
+        # copy draw options
+        draw_options = deepcopy(DRAW_TEXT_OPTIONS)
+        # set image options
+        model = tool_parameters.get("model", "")
+        draw_options_image = {
+            "init_images": [b64encode(image_binary).decode("utf-8")],
+            "denoising_strength": 0.9,
+            "restore_faces": False,
+            "script_args": [],
+            "override_settings": {"sd_model_checkpoint": model},
+            "resize_mode": 0,
+            "image_cfg_scale": 0,
+            # "mask": None,
+            "mask_blur_x": 4,
+            "mask_blur_y": 4,
+            "mask_blur": 0,
+            "mask_round": True,
+            "inpainting_fill": 0,
+            "inpaint_full_res": True,
+            "inpaint_full_res_padding": 0,
+            "inpainting_mask_invert": 0,
+            "initial_noise_multiplier": 0,
+            # "latent_mask": None,
+            "include_init_images": True,
+        }
+        # update key and values
+        draw_options.update(draw_options_image)
+        draw_options.update(tool_parameters)
+
+        # get prompt lora model
+        prompt = tool_parameters.get("prompt", "")
+        lora = tool_parameters.get("lora", "")
+        model = tool_parameters.get("model", "")
+        if lora:
+            draw_options["prompt"] = f"{lora},{prompt}"
+        else:
+            draw_options["prompt"] = prompt
+
+        try:
+            url = str(URL(base_url) / "sdapi" / "v1" / "img2img")
+            response = post(
+                url,
+                json=draw_options,
+                timeout=120,
+                headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
+            )
+            if response.status_code != 200:
+                return self.create_text_message("Failed to generate image")
+
+            image = response.json()["images"][0]
+
+            return self.create_blob_message(
+                blob=b64decode(image),
+                meta={"mime_type": "image/png"},
+                save_as=self.VariableKey.IMAGE.value,
+            )
+
+        except Exception as e:
+            return self.create_text_message("Failed to generate image")
+
+    def text2img(
+        self, base_url: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        generate image
+        """
+        # copy draw options
+        draw_options = deepcopy(DRAW_TEXT_OPTIONS)
+        draw_options.update(tool_parameters)
+        # get prompt lora model
+        prompt = tool_parameters.get("prompt", "")
+        lora = tool_parameters.get("lora", "")
+        model = tool_parameters.get("model", "")
+        if lora:
+            draw_options["prompt"] = f"{lora},{prompt}"
+        else:
+            draw_options["prompt"] = prompt
+        draw_options["override_settings"]["sd_model_checkpoint"] = model
+
+        try:
+            url = str(URL(base_url) / "sdapi" / "v1" / "txt2img")
+            response = post(
+                url,
+                json=draw_options,
+                timeout=120,
+                headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
+            )
+            if response.status_code != 200:
+                return self.create_text_message("Failed to generate image")
+
+            image = response.json()["images"][0]
+
+            return self.create_blob_message(
+                blob=b64decode(image),
+                meta={"mime_type": "image/png"},
+                save_as=self.VariableKey.IMAGE.value,
+            )
+
+        except Exception as e:
+            return self.create_text_message("Failed to generate image")
+
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        parameters = [
+            ToolParameter(
+                name="prompt",
+                label=I18nObject(en_US="Prompt", zh_Hans="Prompt"),
+                human_description=I18nObject(
+                    en_US="Image prompt, you can check the official documentation of Stable Diffusion",
+                    zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档",
+                ),
+                type=ToolParameter.ToolParameterType.STRING,
+                form=ToolParameter.ToolParameterForm.LLM,
+                llm_description="Image prompt of Stable Diffusion, you should describe the image you want to generate"
+                " as a list of words as possible as detailed, the prompt must be written in English.",
+                required=True,
+            ),
+        ]
+        if len(self.list_default_image_variables()) != 0:
+            parameters.append(
+                ToolParameter(
+                    name="image_id",
+                    label=I18nObject(en_US="image_id", zh_Hans="image_id"),
+                    human_description=I18nObject(
+                        en_US="Image id of the image you want to generate based on, if you want to generate image based"
+                        " on the default image, you can leave this field empty.",
+                        zh_Hans="您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。",
+                    ),
+                    type=ToolParameter.ToolParameterType.STRING,
+                    form=ToolParameter.ToolParameterForm.LLM,
+                    llm_description="Image id of the original image, you can leave this field empty if you want to"
+                    " generate a new image.",
+                    required=True,
+                    options=[
+                        ToolParameterOption(value=i.name, label=I18nObject(en_US=i.name, zh_Hans=i.name))
+                        for i in self.list_default_image_variables()
+                    ],
+                )
+            )
+
+        if self.runtime.credentials:
+            try:
+                models = self.get_sd_models()
+                if len(models) != 0:
+                    parameters.append(
+                        ToolParameter(
+                            name="model",
+                            label=I18nObject(en_US="Model", zh_Hans="Model"),
+                            human_description=I18nObject(
+                                en_US="Model of Stable Diffusion, you can check the official documentation"
+                                " of Stable Diffusion",
+                                zh_Hans="Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档",
+                            ),
+                            type=ToolParameter.ToolParameterType.SELECT,
+                            form=ToolParameter.ToolParameterForm.FORM,
+                            llm_description="Model of Stable Diffusion, you can check the official documentation"
+                            " of Stable Diffusion",
+                            required=True,
+                            default=models[0],
+                            options=[
+                                ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models
+                            ],
+                        )
+                    )
+
+            except:
+                pass
+
+        sample_methods = self.get_sample_methods()
+        if len(sample_methods) != 0:
+            parameters.append(
+                ToolParameter(
+                    name="sampler_name",
+                    label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"),
+                    human_description=I18nObject(
+                        en_US="Sampling method of Stable Diffusion, you can check the official documentation"
+                        " of Stable Diffusion",
+                        zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档",
+                    ),
+                    type=ToolParameter.ToolParameterType.SELECT,
+                    form=ToolParameter.ToolParameterForm.FORM,
+                    llm_description="Sampling method of Stable Diffusion, you can check the official documentation"
+                    " of Stable Diffusion",
+                    required=True,
+                    default=sample_methods[0],
+                    options=[
+                        ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in sample_methods
+                    ],
+                )
+            )
+        return parameters
diff --git a/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.yaml
new file mode 100644
index 00000000000000..4f1d17f175c567
--- /dev/null
+++ b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.yaml
@@ -0,0 +1,87 @@
+identity:
+  name: stable_diffusion
+  author: xinference
+  label:
+    en_US: Stable Diffusion
+    zh_Hans: Stable Diffusion
+description:
+  human:
+    en_US: Generate images using Stable Diffusion models.
+    zh_Hans: 使用 Stable Diffusion 模型生成图片。
+  llm: draw the image you want based on your prompt.
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+    human_description:
+      en_US: Image prompt
+      zh_Hans: 图像提示词
+    llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
+    form: llm
+  - name: model
+    type: string
+    required: false
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    human_description:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    form: form
+  - name: lora
+    type: string
+    required: false
+    label:
+      en_US: Lora
+      zh_Hans: Lora
+    human_description:
+      en_US: Lora
+      zh_Hans: Lora
+    form: form
+  - name: steps
+    type: number
+    required: false
+    label:
+      en_US: Steps
+      zh_Hans: Steps
+    human_description:
+      en_US: Steps
+      zh_Hans: Steps
+    form: form
+    default: 10
+  - name: width
+    type: number
+    required: false
+    label:
+      en_US: Width
+      zh_Hans: Width
+    human_description:
+      en_US: Width
+      zh_Hans: Width
+    form: form
+    default: 1024
+  - name: height
+    type: number
+    required: false
+    label:
+      en_US: Height
+      zh_Hans: Height
+    human_description:
+      en_US: Height
+      zh_Hans: Height
+    form: form
+    default: 1024
+  - name: negative_prompt
+    type: string
+    required: false
+    label:
+      en_US: Negative prompt
+      zh_Hans: Negative prompt
+    human_description:
+      en_US: Negative prompt
+      zh_Hans: Negative prompt
+    form: form
+    default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
diff --git a/api/core/tools/provider/builtin/xinference/xinference.py b/api/core/tools/provider/builtin/xinference/xinference.py
new file mode 100644
index 00000000000000..7c2428cc00534e
--- /dev/null
+++ b/api/core/tools/provider/builtin/xinference/xinference.py
@@ -0,0 +1,18 @@
+import requests
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class XinferenceProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        base_url = credentials.get("base_url")
+        api_key = credentials.get("api_key")
+        model = credentials.get("model")
+        res = requests.post(
+            f"{base_url}/sdapi/v1/options",
+            headers={"Authorization": f"Bearer {api_key}"},
+            json={"sd_model_checkpoint": model},
+        )
+        if res.status_code != 200:
+            raise ToolProviderCredentialValidationError("Xinference API key is invalid")
diff --git a/api/core/tools/provider/builtin/xinference/xinference.yaml b/api/core/tools/provider/builtin/xinference/xinference.yaml
new file mode 100644
index 00000000000000..19aaf5cbd1398d
--- /dev/null
+++ b/api/core/tools/provider/builtin/xinference/xinference.yaml
@@ -0,0 +1,40 @@
+identity:
+  author: xinference
+  name: xinference
+  label:
+    en_US: Xinference
+    zh_Hans: Xinference
+  description:
+    zh_Hans: Xinference 提供的兼容 Stable Diffusion web ui 的图片生成 API。
+    en_US: Stable Diffusion web ui compatible API provided by Xinference.
+  icon: icon.png
+  tags:
+    - image
+credentials_for_provider:
+  base_url:
+    type: secret-input
+    required: true
+    label:
+      en_US: Base URL
+      zh_Hans: Xinference 服务器的 Base URL
+    placeholder:
+      en_US: Please input Xinference server's Base URL
+      zh_Hans: 请输入 Xinference 服务器的 Base URL
+  model:
+    type: text-input
+    required: true
+    label:
+      en_US: Model
+      zh_Hans: 模型
+    placeholder:
+      en_US: Please input your model name
+      zh_Hans: 请输入你的模型名称
+  api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API Key
+      zh_Hans: Xinference 服务器的 API Key
+    placeholder:
+      en_US: Please input Xinference server's API Key
+      zh_Hans: 请输入 Xinference 服务器的 API Key