-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathapi_utils.py
More file actions
70 lines (60 loc) · 2.67 KB
/
api_utils.py
File metadata and controls
70 lines (60 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# api_utils.py
from typing import List, Tuple
import requests
from openai import OpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from config import API_CONFIG, EMBEDDING_CONFIG
class LocalEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter for a local ``/embeddings`` HTTP endpoint.

    Every text is sent as its own POST with the JSON body
    ``{"input": text, "model": model}``; the vector is read from
    ``data[0]["embedding"]`` of the JSON reply.
    """

    def __init__(self, base_url: str, model: str, timeout: float = 60.0):
        """Store the endpoint settings.

        Args:
            base_url: Server root that ``/embeddings`` is appended to. A
                trailing slash is tolerated.
            model: Model name forwarded in every request body.
            timeout: Per-request timeout in seconds. Without a timeout,
                ``requests`` may wait indefinitely on an unresponsive server.
        """
        self.base_url = base_url
        self.model = model
        self.timeout = timeout

    def _embed_one(self, text: str) -> List[float]:
        """POST a single text and return its embedding vector.

        Raises:
            requests.HTTPError: if the server replies with a 4xx/5xx status.
            requests.Timeout: if the server does not answer within ``timeout``.
        """
        # rstrip avoids a "//embeddings" URL when base_url ends with "/".
        url = f"{self.base_url.rstrip('/')}/embeddings"
        response = requests.post(
            url,
            json={"input": text, "model": self.model},
            timeout=self.timeout,
        )
        # Report the HTTP failure itself rather than a KeyError on an error body.
        response.raise_for_status()
        return response.json()["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in order, one request per text."""
        return [self._embed_one(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        return self._embed_one(text)
def clean_api_response(response: str, api_type: str) -> str:
    """Normalize raw model output before it is shown to the user.

    Surrounding whitespace is always stripped. For the DeepSeek provider
    (matched case-insensitively, like the ``api_type.lower()`` config keys
    used elsewhere in this module) the ``<|end▁of▁sentence|>`` marker is
    removed as well.

    Args:
        response: Raw reply text. ``None`` is tolerated and treated as an
            empty reply, since chat message content may be ``None``.
        api_type: Provider name, e.g. ``"DeepSeek"``.

    Returns:
        The cleaned reply text (``""`` for a ``None`` response).
    """
    if response is None:
        return ""
    if api_type and api_type.lower() == "deepseek":
        response = response.replace("<|end▁of▁sentence|>", "")
    return response.strip()
def test_api_connection(api_type: str, api_key: str, model_name: str, timeout: float = 30.0) -> Tuple[bool, str]:
    """Check that a chat-completion API is reachable with the given credentials.

    Sends a tiny "Hi!" prompt (at most 10 tokens) and returns the cleaned reply.

    Args:
        api_type: Provider name; its lower-cased form must be a key of API_CONFIG.
        api_key: API key for the provider.
        model_name: Model to request.
        timeout: Per-request timeout in seconds. An unreachable endpoint then
            fails quickly instead of waiting on the client's default timeout.

    Returns:
        ``(True, reply_text)`` on success, otherwise ``(False, error_message)``.
        This function never raises.
    """
    try:
        client = get_api_client(api_type, api_key, model_name)
    except KeyError:
        # A bare KeyError would only show "'name'", which tells the user nothing.
        return False, f"未找到API配置: {api_type}"
    except Exception as e:
        return False, str(e)

    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": "Hi!"}],
            max_tokens=10,
            timeout=timeout,
        )
        # Message content may be None even though the request itself succeeded.
        raw_response = response.choices[0].message.content or ""
        return True, clean_api_response(raw_response, api_type)
    except Exception as e:
        return False, str(e)
def test_embeddings(embed_type: str, api_key: str = None, base_url: str = None, model: str = None) -> Tuple[bool, str]:
    """Check that the selected embedding backend can produce a vector.

    Args:
        embed_type: ``"本地"`` selects the local HTTP backend; any other value
            selects ``OpenAIEmbeddings``.
        api_key: Key passed to ``OpenAIEmbeddings`` (non-local backend only).
        base_url: Local server root; falls back to
            ``EMBEDDING_CONFIG['local']['base_url']`` when empty.
            Only used by the local backend.
        model: Local model name; falls back to
            ``EMBEDDING_CONFIG['local']['model']`` when empty.
            Only used by the local backend.

    Returns:
        ``(True, message_with_vector_dimension)`` on success, otherwise
        ``(False, error_message)``. This function never raises.
    """
    try:
        if embed_type != "本地":
            embedder = OpenAIEmbeddings(api_key=api_key)
        else:
            embedder = LocalEmbeddings(
                base_url=base_url or EMBEDDING_CONFIG['local']['base_url'],
                model=model or EMBEDDING_CONFIG['local']['model'],
            )
        vector = embedder.embed_query("test")
        dimension = len(vector)
        return True, f"成功生成嵌入向量,维度: {dimension}"
    except Exception as exc:
        return False, str(exc)
def get_api_client(api_type: str, api_key: str, model_name: str) -> OpenAI:
    """Build an OpenAI-compatible client for the configured provider.

    Args:
        api_type: Provider name; its lower-cased form selects the entry in
            API_CONFIG whose ``'base_url'`` the client points at.
        api_key: API key for the provider.
        model_name: Accepted for call-site compatibility; not used here.

    Raises:
        KeyError: if the provider or its ``'base_url'`` is missing from API_CONFIG.
    """
    provider_settings = API_CONFIG[api_type.lower()]
    endpoint = provider_settings['base_url']
    return OpenAI(api_key=api_key, base_url=endpoint)