1
1
import json
2
2
import logging
3
3
import os
4
+ import time
4
5
from typing import AsyncIterable
5
6
from arkitect .launcher .local .serve import launch_serve
6
7
from arkitect .telemetry .trace import task
7
- from arkitect .types .llm .model import ArkChatRequest , ArkChatParameters
8
+ from arkitect .types .llm .model import (
9
+ ArkChatRequest ,
10
+ ArkChatParameters ,
11
+ ArkChatCompletionChunk ,
12
+ BotUsage ,
13
+ ActionDetail ,
14
+ ToolDetail ,
15
+ )
8
16
from volcenginesdkarkruntime .types .chat import ChatCompletionChunk
9
17
10
18
from arkitect .core .component .context .context import Context
11
19
12
20
from arkitect .core .runtime import Response
13
- from volcenginesdkarkruntime .types .chat .chat_completion_chunk import Choice , ChoiceDelta
14
21
15
22
from arkitect .core .component .context .model import State
16
23
19
26
20
27
logger = logging .getLogger (__name__ )
21
28
22
- DOUBAO_VLM_ENDPOINT = "doubao-1-5-vision-pro-32k-250115"
29
+ DOUBAO_VLM_ENDPOINT = "doubao-1-5-pro-32k-250115"
30
+
31
+
32
@task()
def get_vector_search_result_chunk(result: str) -> ArkChatCompletionChunk:
    """Wrap a vector-search result in a choice-less streaming chunk.

    The returned chunk has an empty ``choices`` list; the search result is
    carried in ``bot_usage.action_details`` as a single ``vector_search``
    action holding one tool detail.

    Args:
        result: The vector-search output, expected to be a JSON document.
            Text that is not valid JSON is forwarded unchanged rather than
            raising, so one malformed tool message cannot abort the whole
            response stream.

    Returns:
        An ``ArkChatCompletionChunk`` stamped with the current time.
    """
    try:
        output = json.loads(result)
    except json.JSONDecodeError:
        # Callers may pass arbitrary tool-message text (including an empty
        # string); degrade to the raw text instead of failing mid-stream.
        # NOTE(review): assumes ToolDetail.output accepts a plain string --
        # confirm against the ToolDetail model definition.
        logging.getLogger(__name__).warning(
            "vector_search result is not valid JSON; forwarding raw text"
        )
        output = result

    tool_detail = ToolDetail(name="vector_search", input=None, output=output)
    action = ActionDetail(
        name="vector_search",
        count=1,
        tool_details=[tool_detail],
    )
    return ArkChatCompletionChunk(
        id="",
        created=int(time.time()),
        model="",
        object="chat.completion.chunk",
        choices=[],
        bot_usage=BotUsage(action_details=[action]),
    )
23
54
24
55
25
56
@task ()
26
57
async def default_model_calling (
27
58
request : ArkChatRequest ,
28
- ) -> AsyncIterable [ChatCompletionChunk ]:
59
+ ) -> AsyncIterable [ArkChatCompletionChunk | ChatCompletionChunk ]:
29
60
parameters = ArkChatParameters (** request .__dict__ )
30
61
image_urls = [
31
62
content .get ("image_url" , {}).get ("url" , "" )
@@ -36,6 +67,12 @@ async def default_model_calling(
36
67
]
37
68
image_url = image_urls [- 1 ] if len (image_urls ) > 0 else ""
38
69
70
+ # only search for relevant product
71
+ if request .metadata and request .metadata .get ("search" ):
72
+ result = await vector_search ("" , image_url )
73
+ yield get_vector_search_result_chunk (result )
74
+ return
75
+
39
76
async def modify_url_hook (
40
77
state : State , param : ChatCompletionMessageToolCallParam
41
78
) -> ChatCompletionMessageToolCallParam :
@@ -55,12 +92,8 @@ async def modify_url_hook(
55
92
async for chunk in stream :
56
93
if tool_call and chunk .choices :
57
94
tool_result = ctx .get_latest_message ()
58
- chunk .choices .append (
59
- Choice (
60
- role = "tool" ,
61
- delta = ChoiceDelta (content = tool_result .get ("content" )),
62
- index = len (chunk .choices ),
63
- )
95
+ yield get_vector_search_result_chunk (
96
+ str (tool_result .get ("content" , "" ))
64
97
)
65
98
tool_call = False
66
99
yield chunk
0 commit comments