Skip to content

Commit

Permalink
chore: Fix pylint error (eosphoros-ai#1915)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Aug 29, 2024
1 parent 51b4327 commit f72db23
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 15 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,6 @@ ignore_missing_imports = True

[mypy-pypdf.*]
ignore_missing_imports = True

[mypy-qianfan.*]
ignore_missing_imports = True
1 change: 1 addition & 0 deletions dbgpt/model/adapter/embeddings_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddin
return TongYiEmbeddings(**tongyi_param)
elif model_name in ["proxy_qianfan"]:
from dbgpt.rag.embedding import QianFanEmbeddings

proxy_param = cast(ProxyEmbeddingParameters, param)
qianfan_param = {"api_key": proxy_param.proxy_api_key}
if proxy_param.proxy_backend:
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/model/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
metadata={
"tags": "privacy",
"help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
}
},
)
proxy_api_version: Optional[str] = field(
default=None,
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/rag/embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
JinaEmbeddings,
OllamaEmbeddings,
OpenAPIEmbeddings,
TongYiEmbeddings,
QianFanEmbeddings,
TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401

Expand All @@ -34,5 +34,5 @@
"TongYiEmbeddings",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
"QianFanEmbeddings"
"QianFanEmbeddings",
]
28 changes: 17 additions & 11 deletions dbgpt/rag/embedding/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,16 +926,18 @@ def embed_query(self, text: str) -> List[float]:

class QianFanEmbeddings(BaseModel, Embeddings):
"""Baidu Qianfan Embeddings embedding models.
Embed:
.. code-block:: python
# embed the documents
vectors = embeddings.embed_documents([text1, text2, ...])
Embed:
.. code-block:: python
# embed the query
vectors = embeddings.embed_query(text)
# embed the documents
vectors = embeddings.embed_documents([text1, text2, ...])
# embed the query
vectors = embeddings.embed_query(text)
""" # noqa: E501

""" # noqa: E501
client: Any
chunk_size: int = 16
endpoint: str = ""
Expand All @@ -950,7 +952,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
"""Model name
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
for now, we support Embedding-V1 and
for now, we support Embedding-V1 and
- Embedding-V1 (默认模型)
- bge-large-en
- bge-large-zh
Expand All @@ -962,7 +964,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
default="text-embedding-v1", description="The name of the model to use."
)
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""init kwargs for qianfan client init, such as `query_per_second` which is
"""init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS"""

model_kwargs: Dict[str, Any] = Field(default_factory=dict)
Expand All @@ -983,7 +985,10 @@ def __init__(self, **kwargs):
model_name = kwargs.get("model_name")

if not qianfan_ak or not qianfan_sk or not model_name:
raise ValueError("API key, API secret, and model name are required to initialize QianFanEmbeddings.")
raise ValueError(
"API key, API secret, and model name are required to initialize "
"QianFanEmbeddings."
)

params = {
"model": model_name,
Expand All @@ -996,6 +1001,7 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a QianFan embedding model."""
resp = self.embed_documents([text])
return resp[0]

Expand All @@ -1011,7 +1017,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
Each embedding is represented as a list of float values.
"""
text_in_chunks = [
texts[i: i + self.chunk_size]
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]
lst = []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def default_requires():
"chardet",
"sentencepiece",
"ollama",
"qianfan"
"qianfan",
]
setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["rag"]
Expand Down

0 comments on commit f72db23

Please sign in to comment.