Skip to content

Commit

Permalink
fix: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
csunny committed May 25, 2023
1 parent 310e1c4 commit b00a47a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pilot/server/vectordb_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def build_knowledge_prompt(query, docs, state):
prompt = state.get_prompt()
print("new prompt length:" + str(len(prompt)))

return prompt
return prompt
29 changes: 23 additions & 6 deletions pilot/server/webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,16 @@ def http_bot(
cfg.set_last_plugin_return(plugin_resp)
print(plugin_resp)
state.messages[-1][-1] = (
"Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
"Model推理信息:\n"
+ ai_response
+ "\n\nDB-GPT执行结果:\n"
+ plugin_resp
)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
except NotCommands as e:
print("命令执行:" + e.message)
state.messages[-1][-1] = (
"命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
"命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
Expand Down Expand Up @@ -422,18 +425,32 @@ def http_bot(

output = post_process_code(output)
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return

except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return

state.messages[-1][-1] = state.messages[-1][-1][:-1]
Expand Down
20 changes: 10 additions & 10 deletions pilot/vector_store/milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __init__(self, ctx: {}) -> None:
self.fields = []
self.alias = "default"


# use HNSW by default.
self.index_params = {
"metric_type": "L2",
Expand Down Expand Up @@ -105,17 +104,18 @@ def init_schema_and_load(self, vector_name, documents):
embeddings = self.embedding.embed_query(texts[0])

if utility.has_collection(self.collection_name):
self.col = Collection(
self.collection_name, using=self.alias
)
self.col = Collection(self.collection_name, using=self.alias)
self.fields = []
for x in self.col.schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
if (
x.dtype == DataType.FLOAT_VECTOR
or x.dtype == DataType.BINARY_VECTOR
):
self.vector_field = x.name
self._add_documents(texts, metadatas)
return self.collection_name
Expand All @@ -132,9 +132,7 @@ def init_schema_and_load(self, vector_name, documents):
for y in texts:
max_length = max(max_length, len(y))
# Create the text field
fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length= 65535)
)
fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
# primary key field
fields.append(
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
Expand Down Expand Up @@ -252,7 +250,9 @@ def load_document(self, documents) -> None:
"""load document in vector database."""
# self.init_schema_and_load(self.collection_name, documents)
batch_size = 500
batched_list = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
batched_list = [
documents[i : i + batch_size] for i in range(0, len(documents), batch_size)
]
# docs = []
for doc_batch in batched_list:
self.init_schema_and_load(self.collection_name, doc_batch)
Expand Down Expand Up @@ -320,4 +320,4 @@ def _search(
return data[0], ret

def close(self):
connections.disconnect()
connections.disconnect()

0 comments on commit b00a47a

Please sign in to comment.