From f72db23baba5973c681832a016a732774bc35a5b Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 29 Aug 2024 16:37:31 +0800 Subject: [PATCH] chore: Fix pylint error (#1915) --- .mypy.ini | 3 +++ dbgpt/model/adapter/embeddings_loader.py | 1 + dbgpt/model/parameter.py | 2 +- dbgpt/rag/embedding/__init__.py | 4 ++-- dbgpt/rag/embedding/embeddings.py | 28 ++++++++++++++---------- setup.py | 2 +- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index d9e3a7589..335f3fc7f 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -122,3 +122,6 @@ ignore_missing_imports = True [mypy-pypdf.*] ignore_missing_imports = True + +[mypy-qianfan.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/dbgpt/model/adapter/embeddings_loader.py b/dbgpt/model/adapter/embeddings_loader.py index cbc504fdf..892469962 100644 --- a/dbgpt/model/adapter/embeddings_loader.py +++ b/dbgpt/model/adapter/embeddings_loader.py @@ -60,6 +60,7 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddin return TongYiEmbeddings(**tongyi_param) elif model_name in ["proxy_qianfan"]: from dbgpt.rag.embedding import QianFanEmbeddings + proxy_param = cast(ProxyEmbeddingParameters, param) qianfan_param = {"api_key": proxy_param.proxy_api_key} if proxy_param.proxy_backend: diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py index 8debe18fa..ae8168dac 100644 --- a/dbgpt/model/parameter.py +++ b/dbgpt/model/parameter.py @@ -563,7 +563,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters): metadata={ "tags": "privacy", "help": "The api secret of the current embedding model(OPENAI_API_SECRET)", - } + }, ) proxy_api_version: Optional[str] = field( default=None, diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py index c14987de8..184da3daf 100644 --- a/dbgpt/rag/embedding/__init__.py +++ b/dbgpt/rag/embedding/__init__.py @@ -14,8 +14,8 @@ JinaEmbeddings, OllamaEmbeddings, OpenAPIEmbeddings, - TongYiEmbeddings, QianFanEmbeddings, + TongYiEmbeddings, ) from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401 @@ -34,5 +34,5 @@ "TongYiEmbeddings", "CrossEncoderRerankEmbeddings", "OpenAPIRerankEmbeddings", - "QianFanEmbeddings" + "QianFanEmbeddings", ] diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index 773fb9aa0..f81fd3c0e 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -926,16 +926,18 @@ def embed_query(self, text: str) -> List[float]: class QianFanEmbeddings(BaseModel, Embeddings): """Baidu Qianfan Embeddings embedding models. - Embed: - .. code-block:: python - # embed the documents - vectors = embeddings.embed_documents([text1, text2, ...]) + Embed: + .. code-block:: python - # embed the query - vectors = embeddings.embed_query(text) + # embed the documents + vectors = embeddings.embed_documents([text1, text2, ...]) + + # embed the query + vectors = embeddings.embed_query(text) + + """ # noqa: E501 - """ # noqa: E501 client: Any chunk_size: int = 16 endpoint: str = "" @@ -950,7 +952,7 @@ class QianFanEmbeddings(BaseModel, Embeddings): """Model name you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu - for now, we support Embedding-V1 and + for now, we support Embedding-V1 and - Embedding-V1 (默认模型) - bge-large-en - bge-large-zh @@ -962,7 +964,7 @@ class QianFanEmbeddings(BaseModel, Embeddings): default="text-embedding-v1", description="The name of the model to use." ) init_kwargs: Dict[str, Any] = Field(default_factory=dict) - """init kwargs for qianfan client init, such as `query_per_second` which is + """init kwargs for qianfan client init, such as `query_per_second` which is associated with qianfan resource object to limit QPS""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) @@ -983,7 +985,10 @@ def __init__(self, **kwargs): model_name = kwargs.get("model_name") if not qianfan_ak or not qianfan_sk or not model_name: - raise ValueError("API key, API secret, and model name are required to initialize QianFanEmbeddings.") + raise ValueError( + "API key, API secret, and model name are required to initialize " + "QianFanEmbeddings." + ) params = { "model": model_name, @@ -996,6 +1001,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a QianFan embedding model.""" resp = self.embed_documents([text]) return resp[0] @@ -1011,7 +1017,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: Each embedding is represented as a list of float values. """ text_in_chunks = [ - texts[i: i + self.chunk_size] + texts[i : i + self.chunk_size] for i in range(0, len(texts), self.chunk_size) ] lst = [] diff --git a/setup.py b/setup.py index 97b043120..214f82f70 100644 --- a/setup.py +++ b/setup.py @@ -687,7 +687,7 @@ def default_requires(): "chardet", "sentencepiece", "ollama", - "qianfan" + "qianfan", ] setup_spec.extras["default"] += setup_spec.extras["framework"] setup_spec.extras["default"] += setup_spec.extras["rag"]