|
| 1 | +import os |
| 2 | + |
| 3 | +import tos |
| 4 | + |
| 5 | +from httpx import Timeout |
| 6 | +from tos import HttpMethodType |
| 7 | +from volcenginesdkarkruntime import AsyncArk |
| 8 | +from volcenginesdkarkruntime.types.multimodal_embedding import ( |
| 9 | + MultimodalEmbeddingContentPartTextParam, |
| 10 | + MultimodalEmbeddingResponse, |
| 11 | + MultimodalEmbeddingContentPartImageParam, |
| 12 | +) |
| 13 | +from volcengine.viking_db import * |
| 14 | +from volcenginesdkarkruntime.types.multimodal_embedding.embedding_content_part_image_param import ( |
| 15 | + ImageURL, |
| 16 | +) |
| 17 | + |
| 18 | + |
| 19 | +COLLECTION_NAME = "shopping_demo" |
| 20 | +INDEX_NAME = "shopping_demo" |
| 21 | +MODEL_NAME = "doubao-embedding-vision-241215" |
| 22 | +LIMIT = 6 |
| 23 | +SCORE_THRESHOLD = 300 |
| 24 | + |
| 25 | +vikingdb_service = VikingDBService( |
| 26 | + host="api-vikingdb.volces.com", |
| 27 | + region="cn-beijing", |
| 28 | + scheme="https", |
| 29 | + connection_timeout=30, |
| 30 | + socket_timeout=30, |
| 31 | +) |
| 32 | +vikingdb_service.set_ak(os.environ.get("VOLC_ACCESSKEY")) |
| 33 | +vikingdb_service.set_sk(os.environ.get("VOLC_SECRETKEY")) |
| 34 | + |
| 35 | +tos_client = tos.TosClientV2( |
| 36 | + os.getenv("VOLC_ACCESSKEY"), |
| 37 | + os.getenv("VOLC_SECRETKEY"), |
| 38 | + "tos-cn-beijing.volces.com", |
| 39 | + "cn-beijing", |
| 40 | +) |
| 41 | + |
| 42 | + |
| 43 | +async def vector_search(text: str, image_url: str) -> str: |
| 44 | + """获取商品相关信息,当想要了解商品信息,比如价格,详细介绍,销量,评价时调用该工具 |
| 45 | +
|
| 46 | + Args: |
| 47 | + text: 商品的描述信息 |
| 48 | + image_url: 固定填写为<image_url> |
| 49 | + """ |
| 50 | + client = AsyncArk(timeout=Timeout(connect=1.0, timeout=60.0)) |
| 51 | + embedding_input = [MultimodalEmbeddingContentPartTextParam(type="text", text=text)] |
| 52 | + if image_url != "": |
| 53 | + embedding_input.append( |
| 54 | + MultimodalEmbeddingContentPartImageParam( |
| 55 | + type="image_url", image_url=ImageURL(url=image_url) |
| 56 | + ) |
| 57 | + ) |
| 58 | + resp: MultimodalEmbeddingResponse = await client.multimodal_embeddings.create( |
| 59 | + model=MODEL_NAME, |
| 60 | + input=embedding_input, |
| 61 | + ) |
| 62 | + embedding = resp.data.get("embedding", []) |
| 63 | + index = await vikingdb_service.async_get_index(COLLECTION_NAME, INDEX_NAME) |
| 64 | + retrieve = await index.async_search_by_vector(vector=embedding, limit=LIMIT) |
| 65 | + retrieve_fields = [ |
| 66 | + json.loads(result.fields.get("data")) |
| 67 | + for result in retrieve |
| 68 | + if result.score > SCORE_THRESHOLD |
| 69 | + ] |
| 70 | + mock_data = [ |
| 71 | + { |
| 72 | + "名称": item.get("Name", ""), |
| 73 | + "类别": item.get("category", ""), |
| 74 | + "子类别": item.get("sub_category", ""), |
| 75 | + "价格": item.get("price", "99"), |
| 76 | + "销量": item.get("sales", "999"), |
| 77 | + "商品链接": tos_client.pre_signed_url( |
| 78 | + http_method=HttpMethodType.Http_Method_Get, |
| 79 | + bucket="shopping", |
| 80 | + key=item.get("key", ""), |
| 81 | + expires=600, |
| 82 | + ).signed_url, |
| 83 | + } |
| 84 | + for item in retrieve_fields |
| 85 | + ] |
| 86 | + return json.dumps(mock_data, ensure_ascii=False) |
0 commit comments