Skip to content

Commit e493786

Browse files
committed
feat(demohouse/shopping): mock vdb
1 parent bf2932c commit e493786

File tree

4 files changed

+3041
-0
lines changed

4 files changed

+3041
-0
lines changed
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import json
2+
import logging
3+
import os
4+
from typing import AsyncIterable
5+
from arkitect.launcher.local.serve import launch_serve
6+
from arkitect.telemetry.trace import task
7+
from arkitect.types.llm.model import ArkChatRequest, ArkChatParameters
8+
from volcenginesdkarkruntime.types.chat import ChatCompletionChunk
9+
10+
from arkitect.core.component.context.context import Context
11+
12+
from arkitect.core.runtime import Response
13+
from volcenginesdkarkruntime.types.chat.chat_completion_chunk import Choice, ChoiceDelta
14+
15+
from arkitect.core.component.context.model import State
16+
17+
from arkitect.types.llm.model import ChatCompletionMessageToolCallParam
18+
from vdb import vector_search
19+
20+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
21+
22+
# Model identifier passed as `model=` to Context in default_model_calling.
DOUBAO_VLM_ENDPOINT = "doubao-1-5-vision-pro-32k-250115"
23+
24+
25+
@task()
async def default_model_calling(
    request: ArkChatRequest,
) -> AsyncIterable[ChatCompletionChunk]:
    """Stream a chat completion that can call the `vector_search` tool.

    The URL of the last image part found in the request messages is
    remembered and forced into the `image_url` argument of every
    `vector_search` tool call through a tool hook.

    Args:
        request: Incoming chat request; its attributes seed the
            ArkChatParameters and its messages are forwarded to the model.

    Yields:
        Every chunk of the streamed completion. The first chunk with choices
        that follows a chunk whose first choice finished with "tool_calls"
        gets one extra choice appended, carrying the content of the context's
        latest message.
    """
    parameters = ArkChatParameters(**request.__dict__)

    # Walk all list-shaped message contents and keep the last non-empty URL.
    image_url = ""
    for message in request.messages:
        if not isinstance(message.content, list):
            continue
        for part in message.content:
            candidate = part.get("image_url", {}).get("url", "")
            if candidate:
                image_url = candidate

    async def modify_url_hook(
        state: State, param: ChatCompletionMessageToolCallParam
    ) -> ChatCompletionMessageToolCallParam:
        # Overwrite whatever the model put in `image_url` with the real URL.
        call_args = json.loads(param["function"]["arguments"])
        call_args["image_url"] = image_url
        param["function"]["arguments"] = json.dumps(call_args)
        return param

    async with Context(
        model=DOUBAO_VLM_ENDPOINT, tools=[vector_search], parameters=parameters
    ) as ctx:
        ctx.tool_hooks.update(vector_search=[modify_url_hook])
        payload = [m.model_dump() for m in request.messages]
        stream = await ctx.completions.create(messages=payload, stream=True)

        # True between a "tool_calls" finish and the next chunk with choices.
        awaiting_tool_result = False
        async for chunk in stream:
            if awaiting_tool_result and chunk.choices:
                latest_message = ctx.get_latest_message()
                tool_choice = Choice(
                    role="tool",
                    delta=ChoiceDelta(content=latest_message.get("content")),
                    index=len(chunk.choices),
                )
                chunk.choices.append(tool_choice)
                awaiting_tool_result = False
            yield chunk
            # Checked after the yield so the flag affects only later chunks.
            if chunk.choices and chunk.choices[0].finish_reason == "tool_calls":
                awaiting_tool_result = True
69+
70+
71+
@task()
async def main(request: ArkChatRequest) -> AsyncIterable[Response]:
    """Relay every item produced by default_model_calling for `request`."""
    produced = default_model_calling(request)
    async for item in produced:
        yield item
75+
76+
77+
if __name__ == "__main__":
    # Use _FAAS_RUNTIME_PORT when it is set and non-empty, otherwise 8888.
    faas_port = os.getenv("_FAAS_RUNTIME_PORT")
    listen_port = 8888 if not faas_port else int(faas_port)
    launch_serve(
        package_path="main",
        port=listen_port,
        health_check_path="/v1/ping",
        endpoint_path="/api/v3/bots/chat/completions",
    )
+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import os
2+
3+
import tos
4+
5+
from httpx import Timeout
6+
from tos import HttpMethodType
7+
from volcenginesdkarkruntime import AsyncArk
8+
from volcenginesdkarkruntime.types.multimodal_embedding import (
9+
MultimodalEmbeddingContentPartTextParam,
10+
MultimodalEmbeddingResponse,
11+
MultimodalEmbeddingContentPartImageParam,
12+
)
13+
from volcengine.viking_db import *
14+
from volcenginesdkarkruntime.types.multimodal_embedding.embedding_content_part_image_param import (
15+
ImageURL,
16+
)
17+
18+
19+
# Collection name passed to vikingdb_service.async_get_index in vector_search.
COLLECTION_NAME = "shopping_demo"
20+
# Index name passed to vikingdb_service.async_get_index in vector_search.
INDEX_NAME = "shopping_demo"
21+
# Model used for the multimodal embedding request in vector_search.
MODEL_NAME = "doubao-embedding-vision-241215"
22+
# Value passed as `limit=` to index.async_search_by_vector in vector_search.
LIMIT = 6
23+
# vector_search keeps a search result only when result.score is strictly
# greater than this value.
# NOTE(review): the score scale depends on the index configuration, which is
# not visible here — confirm 300 is the intended cut-off.
SCORE_THRESHOLD = 300
24+
25+
# Shared VikingDB service handle, created once at import time; credentials
# are read from the VOLC_ACCESSKEY / VOLC_SECRETKEY environment variables.
vikingdb_service = VikingDBService(
    scheme="https",
    host="api-vikingdb.volces.com",
    region="cn-beijing",
    socket_timeout=30,
    connection_timeout=30,
)
vikingdb_service.set_ak(os.getenv("VOLC_ACCESSKEY"))
vikingdb_service.set_sk(os.getenv("VOLC_SECRETKEY"))
34+
35+
# Shared TOS client, created once at import time and used by vector_search to
# pre-sign object URLs. It reads the same VOLC_ACCESSKEY / VOLC_SECRETKEY
# environment variables as the VikingDB service.
# NOTE(review): positional order is presumably (ak, sk, endpoint, region) —
# confirm against the TOS SDK.
tos_client = tos.TosClientV2(
    os.environ.get("VOLC_ACCESSKEY"),
    os.environ.get("VOLC_SECRETKEY"),
    "tos-cn-beijing.volces.com",
    "cn-beijing",
)
41+
42+
43+
# Tool: look up product information by text description and optional image.
#
# Flow: embed `text` (plus the image when `image_url` is non-empty) with
# MODEL_NAME, search the COLLECTION_NAME / INDEX_NAME vector index for up to
# LIMIT hits, keep hits whose score exceeds SCORE_THRESHOLD, and return a JSON
# array (non-ASCII preserved) with one object per hit. Each object carries the
# name, category, sub-category, price, sales and a pre-signed GET URL valid
# for 600 seconds.
#
# NOTE(review): the docstring below is kept verbatim in its original language
# because it is presumably consumed at runtime as the tool description shown
# to the model — confirm before translating it.
async def vector_search(text: str, image_url: str) -> str:
    """获取商品相关信息,当想要了解商品信息,比如价格,详细介绍,销量,评价时调用该工具

    Args:
        text: 商品的描述信息
        image_url: 固定填写为<image_url>
    """
    # `json` is never imported explicitly in this file; import it here so the
    # function does not depend on the wildcard import leaking the name.
    import json

    # NOTE(review): a new client is created on every call and never closed
    # explicitly — confirm whether it should be shared or closed.
    client = AsyncArk(timeout=Timeout(connect=1.0, timeout=60.0))
    embedding_input = [MultimodalEmbeddingContentPartTextParam(type="text", text=text)]
    # Truthiness check: skip the image part for both "" and None.
    if image_url:
        embedding_input.append(
            MultimodalEmbeddingContentPartImageParam(
                type="image_url", image_url=ImageURL(url=image_url)
            )
        )
    resp: MultimodalEmbeddingResponse = await client.multimodal_embeddings.create(
        model=MODEL_NAME,
        input=embedding_input,
    )
    # NOTE(review): assumes resp.data is dict-like — confirm against the SDK.
    embedding = resp.data.get("embedding", [])
    index = await vikingdb_service.async_get_index(COLLECTION_NAME, INDEX_NAME)
    retrieve = await index.async_search_by_vector(vector=embedding, limit=LIMIT)
    # Each hit stores its product record as a JSON string in the "data" field.
    retrieve_fields = [
        json.loads(result.fields.get("data"))
        for result in retrieve
        if result.score > SCORE_THRESHOLD
    ]
    mock_data = [
        {
            "名称": item.get("Name", ""),
            "类别": item.get("category", ""),
            "子类别": item.get("sub_category", ""),
            # Placeholder price/sales are used when the record has none.
            "价格": item.get("price", "99"),
            "销量": item.get("sales", "999"),
            "商品链接": tos_client.pre_signed_url(
                http_method=HttpMethodType.Http_Method_Get,
                bucket="shopping",
                key=item.get("key", ""),
                expires=600,
            ).signed_url,
        }
        for item in retrieve_fields
    ]
    return json.dumps(mock_data, ensure_ascii=False)

0 commit comments

Comments
 (0)