@@ -5,7 +5,12 @@ import { generateUniqueId } from "../lib/generateUniqueId.js";
 import { InferenceClient } from "@huggingface/inference";
 import type { ChatCompletionInputMessage, ChatCompletionInputMessageChunkType } from "@huggingface/tasks";
 
-import { type Response as OpenAIResponse } from "openai/resources/responses/responses";
+import type {
+	Response,
+	ResponseStreamEvent,
+	ResponseOutputItem,
+	ResponseContentPartAddedEvent,
+} from "openai/resources/responses/responses";
 
 export const postCreateResponse = async (
 	req: ValidatedRequest<CreateResponseParams>,
@@ -33,27 +38,189 @@ export const postCreateResponse = async (
 				content:
 					typeof item.content === "string"
 						? item.content
-						: item.content.map((content) => {
-								if (content.type === "input_image") {
-									return {
-										type: "image_url" as ChatCompletionInputMessageChunkType,
-										image_url: {
-											url: content.image_url,
-										},
-									};
-								}
-								// content.type must be "input_text" at this point
-								return {
-									type: "text" as ChatCompletionInputMessageChunkType,
-									text: content.text,
-								};
-							}),
+						: item.content
+								.map((content) => {
+									switch (content.type) {
+										case "input_image":
+											return {
+												type: "image_url" as ChatCompletionInputMessageChunkType,
+												image_url: {
+													url: content.image_url,
+												},
+											};
+										case "output_text":
+											return {
+												type: "text" as ChatCompletionInputMessageChunkType,
+												text: content.text,
+											};
+										case "refusal":
+											return undefined;
+										case "input_text":
+											return {
+												type: "text" as ChatCompletionInputMessageChunkType,
+												text: content.text,
+											};
+									}
+								})
+								.filter((item) => item !== undefined),
 			}))
 		);
 	} else {
 		messages.push({ role: "user", content: req.body.input });
 	}
 
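+	// Request options forwarded to the upstream chat-completion call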
+	const payload = {
+		model: req.body.model,
+		messages: messages,
+		temperature: req.body.temperature,
+		top_p: req.body.top_p,
+		stream: req.body.stream,
+	};
+
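+	// Base Responses API object, shared by the streaming and non-streaming paths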
+	const responseObject: Omit<
+		Response,
+		"incomplete_details" | "metadata" | "output_text" | "parallel_tool_calls" | "tool_choice" | "tools"
+	> = {
+		object: "response",
+		id: generateUniqueId("resp"),
+		status: "in_progress",
+		error: null,
+		instructions: req.body.instructions,
+		model: req.body.model,
+		temperature: req.body.temperature,
+		top_p: req.body.top_p,
+		created_at: new Date().getTime(),
+		output: [],
+	};
+
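+	// Streaming path: emit the Responses API event sequence as Server-Sent Events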
+	if (req.body.stream) {
+		res.setHeader("Content-Type", "text/event-stream");
+		res.setHeader("Connection", "keep-alive");
+		let sequenceNumber = 0;
+
+		// Write each event as a single SSE "data:" frame
+		const emitEvent = (event: ResponseStreamEvent) => {
+			res.write(`data: ${JSON.stringify(event)}\n\n`);
+		};
+
+		try {
+			// Response created event
+			emitEvent({
+				type: "response.created",
+				response: responseObject as Response,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response in progress event
+			emitEvent({
+				type: "response.in_progress",
+				response: responseObject as Response,
+				sequence_number: sequenceNumber++,
+			});
+
+			const stream = client.chatCompletionStream(payload);
+
+			const outputObject: ResponseOutputItem = {
+				id: generateUniqueId("msg"),
+				type: "message",
+				role: "assistant",
+				status: "in_progress",
+				content: [],
+			};
+			responseObject.output = [outputObject];
+
+			// Response output item added event
+			emitEvent({
+				type: "response.output_item.added",
+				output_index: 0,
+				item: outputObject,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response content part added event
+			const contentPart: ResponseContentPartAddedEvent["part"] = {
+				type: "output_text",
+				text: "",
+				annotations: [],
+			};
+			outputObject.content.push(contentPart);
+
+			emitEvent({
+				type: "response.content_part.added",
+				item_id: outputObject.id,
+				output_index: 0,
+				content_index: 0,
+				part: contentPart,
+				sequence_number: sequenceNumber++,
+			});
+
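+			// Relay the upstream stream: accumulate text and emit one delta event per chunk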
+			for await (const chunk of stream) {
+				if (chunk.choices[0].delta.content) {
+					contentPart.text += chunk.choices[0].delta.content;
+
+					// Response output text delta event
+					emitEvent({
+						type: "response.output_text.delta",
+						item_id: outputObject.id,
+						output_index: 0,
+						content_index: 0,
+						delta: chunk.choices[0].delta.content,
+						sequence_number: sequenceNumber++,
+					});
+				}
+			}
+
+			// Response output text done event
+			emitEvent({
+				type: "response.output_text.done",
+				item_id: outputObject.id,
+				output_index: 0,
+				content_index: 0,
+				text: contentPart.text,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response content part done event
+			emitEvent({
+				type: "response.content_part.done",
+				item_id: outputObject.id,
+				output_index: 0,
+				content_index: 0,
+				part: contentPart,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response output item done event
+			outputObject.status = "completed";
+			emitEvent({
+				type: "response.output_item.done",
+				output_index: 0,
+				item: outputObject,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response completed event
+			responseObject.status = "completed";
+			emitEvent({
+				type: "response.completed",
+				response: responseObject as Response,
+				sequence_number: sequenceNumber++,
+			});
+		} catch (streamError: any) {
+			console.error("Error in streaming chat completion:", streamError);
+
+			emitEvent({
+				type: "error",
+				code: null,
+				message: streamError.message || "An error occurred while streaming from inference server.",
+				param: null,
+				sequence_number: sequenceNumber++,
+			});
+		}
+		res.end();
+		return;
+	}
+
 	try {
 		const chatCompletionResponse = await client.chatCompletion({
 			model: req.body.model,
@@ -62,37 +229,24 @@ export const postCreateResponse = async (
 			top_p: req.body.top_p,
 		});
 
-	const responseObject: Omit<
-		OpenAIResponse,
-		"incomplete_details" | "metadata" | "output_text" | "parallel_tool_calls" | "tool_choice" | "tools"
-	> = {
-		object: "response",
-		id: generateUniqueId("resp"),
-		status: "completed",
-		error: null,
-		instructions: req.body.instructions,
-		model: req.body.model,
-		temperature: req.body.temperature,
-		top_p: req.body.top_p,
-		created_at: chatCompletionResponse.created,
-		output: chatCompletionResponse.choices[0].message.content
-			? [
-					{
-						id: generateUniqueId("msg"),
-						type: "message",
-						role: "assistant",
-						status: "completed",
-						content: [
-							{
-								type: "output_text",
-								text: chatCompletionResponse.choices[0].message.content,
-								annotations: [],
-							},
-						],
-					},
-				]
-			: [],
-	};
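+		// Non-streaming path: mark the shared response object completed and return it whole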
+		responseObject.status = "completed";
+		responseObject.output = chatCompletionResponse.choices[0].message.content
+			? [
+					{
+						id: generateUniqueId("msg"),
+						type: "message",
+						role: "assistant",
+						status: "completed",
+						content: [
+							{
+								type: "output_text",
+								text: chatCompletionResponse.choices[0].message.content,
+								annotations: [],
+							},
+						],
+					},
+				]
+			: [];
 
 		res.json(responseObject);
 	} catch (error) {