-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
2个新特性 #653
base: main
Are you sure you want to change the base?
2个新特性 #653
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -377,23 +377,94 @@ def __init__( | |
model = self.envs["OPENAI_MODEL"] | ||
super().__init__(lang_in, lang_out, model) | ||
self.options = {"temperature": 0} # 随机采样可能会打断公式标记 | ||
self.client = openai.OpenAI( | ||
base_url=base_url or self.envs["OPENAI_BASE_URL"], | ||
api_key=api_key or self.envs["OPENAI_API_KEY"], | ||
) | ||
|
||
# 处理多个API key | ||
base_url = base_url or self.envs["OPENAI_BASE_URL"] | ||
api_keys = [] | ||
|
||
# 优先使用传入的api_key | ||
if api_key: | ||
api_keys = [k.strip() for k in api_key.split(",") if k.strip()] | ||
# 如果没有传入api_key,则使用环境变量中的api_key | ||
elif self.envs["OPENAI_API_KEY"]: | ||
api_keys = [ | ||
k.strip() for k in self.envs["OPENAI_API_KEY"].split(",") if k.strip() | ||
] | ||
|
||
if not api_keys: | ||
raise ValueError("No API key provided") | ||
|
||
# 为每个API key创建一个client | ||
self.clients = [] | ||
for key in api_keys: | ||
self.clients.append( | ||
openai.OpenAI( | ||
base_url=base_url, | ||
api_key=key, | ||
) | ||
) | ||
self.current_client_index = 0 | ||
|
||
self.prompttext = prompt | ||
self.add_cache_impact_parameters("temperature", self.options["temperature"]) | ||
|
||
def do_translate(self, text) -> str: | ||
response = self.client.chat.completions.create( | ||
model=self.model, | ||
**self.options, | ||
messages=self.prompt(text, self.prompttext), | ||
) | ||
if not response.choices: | ||
if hasattr(response, "error"): | ||
raise ValueError("Error response from Service", response.error) | ||
return response.choices[0].message.content.strip() | ||
# 获取当前client | ||
client = self.clients[self.current_client_index] | ||
|
||
# 更新index为下一个client | ||
self.current_client_index = (self.current_client_index + 1) % len(self.clients) | ||
|
||
try: | ||
response = client.chat.completions.create( | ||
model=self.model, | ||
**self.options, | ||
messages=self.prompt(text, self.prompttext), | ||
) | ||
|
||
content = response.choices[0].message.content.strip() | ||
|
||
# 过滤掉<think>标签内的内容 | ||
if "<think>" in content and "</think>" in content: | ||
content = re.sub(r"^<think>.+?</think>", "", content, flags=re.DOTALL) | ||
|
||
return content.strip() | ||
|
||
except Exception as e: | ||
# 如果当前client失败,尝试使用其他client | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是不是把调用某个client发请求并获取响应的代码抽成函数,然后再递归调用自己来实现这个自动重试的逻辑比较好? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个其实我也没想好。我觉得比较好的设计是提供一种default的重试策略,然后不同的translator去覆盖,比如有的是429,那么重试应该降低请求频率,有的是安全围栏,那么应该换translator(比如google),而不是一味的连续重试,毕竟网络质量问题不是我遇到的最多的问题。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 先去掉吧,再开个issue追踪。整一套translator重试策略,并且允许配置多个translator感觉比较好。可以放到pdf2zh 2.0里做 |
||
original_index = self.current_client_index | ||
for _ in range(len(self.clients) - 1): | ||
try: | ||
client = self.clients[self.current_client_index] | ||
response = client.chat.completions.create( | ||
model=self.model, | ||
**self.options, | ||
messages=self.prompt(text, self.prompttext), | ||
) | ||
if not response.choices: | ||
if hasattr(response, "error"): | ||
raise ValueError( | ||
"Error response from Service", response.error | ||
) | ||
|
||
content = response.choices[0].message.content.strip() | ||
|
||
# 过滤掉<think>标签内的内容 | ||
if "<think>" in content and "</think>" in content: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里没必要再多一个if。如果原文不包含这两个tag,re.sub应该是啥都不会干吧。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我把这个if去掉,然后改成预编译的吧 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. re内部有缓存。另外你可以在初始化translator的时候compile一下re。 另外,判断字符串是否为子串开销也不小 |
||
content = re.sub( | ||
r"^<think>.+?</think>", "", content, flags=re.DOTALL | ||
) | ||
|
||
return content.strip() | ||
except: | ||
self.current_client_index = (self.current_client_index + 1) % len( | ||
self.clients | ||
) | ||
continue | ||
|
||
# 如果所有client都失败,恢复原始index并抛出异常 | ||
self.current_client_index = original_index | ||
raise e | ||
|
||
def get_formular_placeholder(self, id: int): | ||
return "{{v" + str(id) + "}}" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.current_client_index = (self.current_client_index + 1) % len(self.clients)
不是线程安全的。这句话会被多个线程并行调用。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
了解