
Shopping (multimodal & vdb) #51

Open
wants to merge 3 commits into base: main
128 changes: 128 additions & 0 deletions demohouse/shopping/backend/code/main.py
@@ -0,0 +1,128 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the 【火山方舟】原型应用软件自用许可协议
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.volcengine.com/docs/82379/1433703
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import time
from typing import AsyncIterable

from volcenginesdkarkruntime.types.chat import ChatCompletionChunk

from arkitect.core.component.context.context import Context
from arkitect.core.component.context.model import State
from arkitect.core.runtime import Response
from arkitect.launcher.local.serve import launch_serve
from arkitect.telemetry.trace import task
from arkitect.types.llm.model import (
    ActionDetail,
    ArkChatCompletionChunk,
    ArkChatParameters,
    ArkChatRequest,
    BotUsage,
    ChatCompletionMessageToolCallParam,
    ToolDetail,
)

from vdb import vector_search

logger = logging.getLogger(__name__)

DOUBAO_VLM_ENDPOINT = "doubao-1-5-pro-32k-250115"


@task()
def get_vector_search_result_chunk(result: str) -> ArkChatCompletionChunk:
    """Wrap a vector_search result (a JSON string) into a chunk whose
    bot_usage.action_details carries the tool output for the client."""
    return ArkChatCompletionChunk(
        id="",
        created=int(time.time()),
        model="",
        object="chat.completion.chunk",
        choices=[],
        bot_usage=BotUsage(
            action_details=[
                ActionDetail(
                    name="vector_search",
                    count=1,
                    tool_details=[
                        ToolDetail(
                            name="vector_search", input=None, output=json.loads(result)
                        )
                    ],
                )
            ]
        ),
    )


@task()
async def default_model_calling(
    request: ArkChatRequest,
) -> AsyncIterable[ArkChatCompletionChunk | ChatCompletionChunk]:
    parameters = ArkChatParameters(**request.__dict__)
    # Collect image URLs from multimodal message parts; the last one is the
    # image the user is currently asking about.
    image_urls = [
        content.get("image_url", {}).get("url", "")
        for message in request.messages
        if isinstance(message.content, list)
        for content in message.content
        if content.get("image_url", {}).get("url", "")
    ]
    image_url = image_urls[-1] if len(image_urls) > 0 else ""

    # Search-only mode: return relevant products without invoking the model.
    if request.metadata and request.metadata.get("search"):
        result = await vector_search("", image_url)
        yield get_vector_search_result_chunk(result)
        return

    async def modify_url_hook(
        state: State, param: ChatCompletionMessageToolCallParam
    ) -> ChatCompletionMessageToolCallParam:
        # Replace the <image_url> placeholder emitted by the model with the
        # actual image URL from the request before the tool executes.
        arguments = json.loads(param["function"]["arguments"])
        arguments["image_url"] = image_url
        param["function"]["arguments"] = json.dumps(arguments)
        return param

    async with Context(
        model=DOUBAO_VLM_ENDPOINT, tools=[vector_search], parameters=parameters
    ) as ctx:
        ctx.tool_hooks.update(vector_search=[modify_url_hook])
        stream = await ctx.completions.create(
            messages=[m.model_dump() for m in request.messages], stream=True
        )
        tool_call = False
        async for chunk in stream:
            # After a tool call finishes, surface its result to the client as a
            # bot_usage chunk before streaming the model's follow-up output.
            if tool_call and chunk.choices:
                tool_result = ctx.get_latest_message()
                yield get_vector_search_result_chunk(
                    str(tool_result.get("content", ""))
                )
                tool_call = False
            yield chunk
            if chunk.choices and chunk.choices[0].finish_reason == "tool_calls":
                tool_call = True


@task()
async def main(request: ArkChatRequest) -> AsyncIterable[Response]:
    async for resp in default_model_calling(request):
        yield resp


if __name__ == "__main__":
    port = os.getenv("_FAAS_RUNTIME_PORT")
    launch_serve(
        package_path="main",
        port=int(port) if port else 8888,
        health_check_path="/v1/ping",
        endpoint_path="/api/v3/bots/chat/completions",
    )
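
For reviewers who want to exercise this endpoint locally, the sketch below is illustrative only: it assumes the service above is running on the default port 8888, and that the request body mirrors the OpenAI-style ArkChatRequest fields consumed by default_model_calling (messages, metadata, stream). The image URL and question text are placeholders.

# Minimal local smoke test for /api/v3/bots/chat/completions (illustrative).
import requests

payload = {
    "model": "doubao-1-5-pro-32k-250115",
    "stream": True,
    "metadata": {"search": True},  # takes the vector-search-only branch
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "How much is this jacket?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/item.jpg"}},
            ],
        }
    ],
}

with requests.post(
    "http://localhost:8888/api/v3/bots/chat/completions",
    json=payload,
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode("utf-8"))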
99 changes: 99 additions & 0 deletions demohouse/shopping/backend/code/vdb.py
@@ -0,0 +1,99 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the 【火山方舟】原型应用软件自用许可协议
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.volcengine.com/docs/82379/1433703
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import tos
from httpx import Timeout
from tos import HttpMethodType
from volcengine.viking_db import VikingDBService
from volcenginesdkarkruntime import AsyncArk
from volcenginesdkarkruntime.types.multimodal_embedding import (
    MultimodalEmbeddingContentPartImageParam,
    MultimodalEmbeddingContentPartTextParam,
    MultimodalEmbeddingResponse,
)
from volcenginesdkarkruntime.types.multimodal_embedding.embedding_content_part_image_param import (
    ImageURL,
)


COLLECTION_NAME = "shopping_demo"
INDEX_NAME = "shopping_demo"
MODEL_NAME = "doubao-embedding-vision-241215"
LIMIT = 6
SCORE_THRESHOLD = 300

vikingdb_service = VikingDBService(
    host="api-vikingdb.volces.com",
    region="cn-beijing",
    scheme="https",
    connection_timeout=30,
    socket_timeout=30,
)
vikingdb_service.set_ak(os.environ.get("VOLC_ACCESSKEY"))
vikingdb_service.set_sk(os.environ.get("VOLC_SECRETKEY"))

tos_client = tos.TosClientV2(
    os.getenv("VOLC_ACCESSKEY"),
    os.getenv("VOLC_SECRETKEY"),
    "tos-cn-beijing.volces.com",
    "cn-beijing",
)


async def vector_search(text: str, image_url: str) -> str:
    """Retrieve product information. Call this tool when details about a
    product are needed, such as its price, description, sales volume, or reviews.

    Args:
        text: a text description of the product
        image_url: always pass the literal placeholder <image_url>
    """
    client = AsyncArk(timeout=Timeout(connect=1.0, timeout=60.0))
    # Build the multimodal embedding input from whichever parts are present.
    embedding_input = []
    if text != "":
        embedding_input = [
            MultimodalEmbeddingContentPartTextParam(type="text", text=text)
        ]
    if image_url != "":
        embedding_input.append(
            MultimodalEmbeddingContentPartImageParam(
                type="image_url", image_url=ImageURL(url=image_url)
            )
        )
    resp: MultimodalEmbeddingResponse = await client.multimodal_embeddings.create(
        model=MODEL_NAME,
        input=embedding_input,
    )
    embedding = resp.data.get("embedding", [])
    # Query the VikingDB index and keep only sufficiently similar results.
    index = await vikingdb_service.async_get_index(COLLECTION_NAME, INDEX_NAME)
    retrieve = await index.async_search_by_vector(vector=embedding, limit=LIMIT)
    retrieve_fields = [
        json.loads(result.fields.get("data"))
        for result in retrieve
        if result.score > SCORE_THRESHOLD
    ]
    # Shape the results and attach a short-lived pre-signed TOS URL per image.
    mock_data = [
        {
            "名称": item.get("Name", ""),
            "类别": item.get("category", ""),
            "子类别": item.get("sub_category", ""),
            "价格": item.get("price", "99"),
            "销量": item.get("sales", "999"),
            "图片链接": tos_client.pre_signed_url(
                http_method=HttpMethodType.Http_Method_Get,
                bucket="shopping",
                key=item.get("key", ""),
                expires=600,
            ).signed_url,
        }
        for item in retrieve_fields
    ]
    return json.dumps(mock_data, ensure_ascii=False)
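
vector_search can also be checked in isolation. The sketch below is an assumption-laden example, not part of the change: it presumes VOLC_ACCESSKEY / VOLC_SECRETKEY (and the Ark API key picked up by AsyncArk) are exported, and that the shopping_demo collection and index already contain data.

# Illustrative standalone check of vdb.vector_search.
import asyncio

from vdb import vector_search


async def demo() -> None:
    # Text-only query; an image URL could be supplied as the second argument.
    result = await vector_search("red running shoes", "")
    print(result)  # JSON string of matching products


if __name__ == "__main__":
    asyncio.run(demo())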