feat(model): Support siliconflow models #2157

Merged · 1 commit · Nov 27, 2024
11 changes: 11 additions & 0 deletions dbgpt/_private/config.py
@@ -140,6 +140,17 @@ def __init__(self) -> None:
os.environ["claude_proxyllm_api_base"] = os.getenv(
"ANTHROPIC_BASE_URL", "https://api.anthropic.com"
)
self.silicon_flow_proxy_api_key = os.getenv("SILICON_FLOW_API_KEY")
if self.silicon_flow_proxy_api_key:
os.environ[
"silicon_flow_proxyllm_proxy_api_key"
] = self.silicon_flow_proxy_api_key
os.environ["silicon_flow_proxyllm_proxyllm_backend"] = os.getenv(
"SILICON_FLOW_MODEL_VERSION", "Qwen/Qwen2.5-Coder-32B-Instruct"
)
os.environ["silicon_flow_proxyllm_api_base"] = os.getenv(
"SILICON_FLOW_API_BASE", "https://api.siliconflow.cn/v1"
)

self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

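For reference, a minimal sketch of the environment configuration this hunk consumes. All values below are placeholders; only SILICON_FLOW_API_KEY is required, and the other two fall back to the defaults shown in the diff:

import os

# These must be set before Config() is instantiated, since __init__ reads them once.
os.environ["SILICON_FLOW_API_KEY"] = "sk-..."  # placeholder key
os.environ["SILICON_FLOW_MODEL_VERSION"] = "Qwen/Qwen2.5-Coder-32B-Instruct"  # optional; default shown
os.environ["SILICON_FLOW_API_BASE"] = "https://api.siliconflow.cn/v1"  # optional; default shown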
2 changes: 2 additions & 0 deletions dbgpt/configs/model_config.py
@@ -79,6 +79,8 @@ def get_device() -> str:
"ollama_proxyllm": "ollama_proxyllm",
# https://platform.deepseek.com/api-docs/
"deepseek_proxyllm": "deepseek_proxyllm",
# https://docs.siliconflow.cn/quickstart
"silicon_flow_proxyllm": "silicon_flow_proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
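A short sketch of what this entry means at lookup time (assuming the surrounding dict is LLM_MODEL_CONFIG, as in the rest of dbgpt/configs/model_config.py): proxy models map to their own identifier rather than a local path under MODEL_PATH.

from dbgpt.configs.model_config import LLM_MODEL_CONFIG  # assumed dict name

# Unlike the llama-2 entries, the proxy entry carries no on-disk model path.
assert LLM_MODEL_CONFIG["silicon_flow_proxyllm"] == "silicon_flow_proxyllm"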
26 changes: 26 additions & 0 deletions dbgpt/model/adapter/proxy_adapter.py
@@ -339,6 +339,31 @@ def get_async_generate_stream_function(self, model, model_path: str):
return deepseek_generate_stream


class SiliconFlowProxyLLMModelAdapter(ProxyLLMModelAdapter):
"""SiliconFlow proxy LLM model adapter.

See Also: `SiliconFlow Documentation <https://docs.siliconflow.cn/quickstart>`_
"""

def support_async(self) -> bool:
return True

def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "silicon_flow_proxyllm"

def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient

return SiliconFlowLLMClient

def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.siliconflow import silicon_flow_generate_stream

return silicon_flow_generate_stream


register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(ClaudeProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
@@ -352,3 +377,4 @@ def get_async_generate_stream_function(self, model, model_path: str):
register_model_adapter(YiProxyLLMModelAdapter)
register_model_adapter(MoonshotProxyLLMModelAdapter)
register_model_adapter(DeepseekProxyLLMModelAdapter)
register_model_adapter(SiliconFlowProxyLLMModelAdapter)
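A minimal driver sketch of the matching behavior the new adapter contributes, using only methods shown in the diff:

from dbgpt.model.adapter.proxy_adapter import SiliconFlowProxyLLMModelAdapter

adapter = SiliconFlowProxyLLMModelAdapter()
# do_match is an exact comparison against the lowercased model name,
# so only the alias matches, not the underlying backend model name.
assert adapter.do_match("silicon_flow_proxyllm")
assert not adapter.do_match("qwen/qwen2.5-coder-32b-instruct")
assert adapter.support_async()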
3 changes: 3 additions & 0 deletions dbgpt/model/proxy/__init__.py
@@ -9,6 +9,7 @@
from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
from dbgpt.model.proxy.llms.spark import SparkLLMClient
from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient
from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient
@@ -21,6 +22,7 @@ def __lazy_import(name):
"OpenAILLMClient": "dbgpt.model.proxy.llms.chatgpt",
"ClaudeLLMClient": "dbgpt.model.proxy.llms.claude",
"GeminiLLMClient": "dbgpt.model.proxy.llms.gemini",
"SiliconFlowLLMClient": "dbgpt.model.proxy.llms.siliconflow",
"SparkLLMClient": "dbgpt.model.proxy.llms.spark",
"TongyiLLMClient": "dbgpt.model.proxy.llms.tongyi",
"WenxinLLMClient": "dbgpt.model.proxy.llms.wenxin",
@@ -49,6 +51,7 @@ def __getattr__(name):
"TongyiLLMClient",
"ZhipuLLMClient",
"WenxinLLMClient",
"SiliconFlowLLMClient",
"SparkLLMClient",
"YiLLMClient",
"MoonshotLLMClient",
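Note that the top-level import is resolved lazily at runtime: assuming, as in the rest of this module, the eager imports are guarded by TYPE_CHECKING, the table above routes the attribute name to its module through __lazy_import/__getattr__, so dbgpt.model.proxy.llms.siliconflow is only loaded on first access. A one-line sketch:

# Triggers __getattr__ -> __lazy_import("SiliconFlowLLMClient") on first use.
from dbgpt.model.proxy import SiliconFlowLLMClient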
87 changes: 87 additions & 0 deletions dbgpt/model/proxy/llms/siliconflow.py
@@ -0,0 +1,87 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request

from .chatgpt import OpenAILLMClient

if TYPE_CHECKING:
from httpx._types import ProxiesTypes
from openai import AsyncAzureOpenAI, AsyncOpenAI

ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]


_SILICON_FLOW_DEFAULT_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"


async def silicon_flow_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
client: SiliconFlowLLMClient = model.proxy_llm_client
request = parse_model_request(params, client.default_model, stream=True)
async for r in client.generate_stream(request):
yield r


class SiliconFlowLLMClient(OpenAILLMClient):
"""SiliconFlow LLM Client.

SiliconFlow's API is compatible with OpenAI's API, so we inherit from OpenAILLMClient.
"""

def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
model: Optional[str] = None,
proxies: Optional["ProxiesTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "silicon_flow_proxyllm",
context_length: Optional[int] = None,
openai_client: Optional["ClientType"] = None,
openai_kwargs: Optional[Dict[str, Any]] = None,
**kwargs
):
api_base = (
api_base
or os.getenv("SILICON_FLOW_API_BASE")
or "https://api.siliconflow.cn/v1"
)
api_key = api_key or os.getenv("SILICON_FLOW_API_KEY")
model = model or _SILICON_FLOW_DEFAULT_MODEL
if not context_length:
if "200k" in model:
context_length = 200 * 1024
else:
context_length = 4096

if not api_key:
raise ValueError(
"SiliconFlow API key is required, please set 'SILICON_FLOW_API_KEY' in environment "
"or pass it as an argument."
)

super().__init__(
api_key=api_key,
api_base=api_base,
api_type=api_type,
api_version=api_version,
model=model,
proxies=proxies,
timeout=timeout,
model_alias=model_alias,
context_length=context_length,
openai_client=openai_client,
openai_kwargs=openai_kwargs,
**kwargs
)

@property
def default_model(self) -> str:
model = self._model
if not model:
model = _SILICON_FLOW_DEFAULT_MODEL
return model
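Putting the pieces together, a hedged usage sketch of the new client (assumes the openai package is installed; the key below is a placeholder, and the model name falls back to the default defined above):

import os

from dbgpt.model.proxy import SiliconFlowLLMClient

os.environ.setdefault("SILICON_FLOW_API_KEY", "sk-...")  # placeholder, not a real key

client = SiliconFlowLLMClient()  # defaults to Qwen/Qwen2.5-Coder-32B-Instruct
print(client.default_model)

# Streaming requests go through client.generate_stream(request); in the worker
# path the request is built by parse_model_request, as silicon_flow_generate_stream
# does above.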