@@ -327,7 +327,7 @@ def delta(txt):
         self.set_output("content", answer)
 
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("LLM processing"):
             return
 
@@ -338,22 +338,25 @@ def clean_formated_answer(ans: str) -> str:
 
         prompt, msg, _ = self._prepare_prompt_variables()
         error: str = ""
-        output_structure = None
+        output_structure = None
         try:
-            output_structure = self._param.outputs['structured']
+            output_structure = self._param.outputs["structured"]
         except Exception:
             pass
         if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
-            schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
-            prompt += structured_output_prompt(schema)
-            for _ in range(self._param.max_retries + 1):
+            schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
+            prompt_with_schema = prompt + structured_output_prompt(schema)
+            for _ in range(self._param.max_retries + 1):
                 if self.check_if_canceled("LLM processing"):
                     return
 
-                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+                _, msg_fit = message_fit_in(
+                    [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
+                    int(self.chat_mdl.max_length * 0.97),
+                )
                 error = ""
-                ans = self._generate(msg)
-                msg.pop(0)
+                ans = await self._generate_async(msg_fit)
+                msg_fit.pop(0)
                 if ans.find("**ERROR**") >= 0:
                     logging.error(f"LLM response error: {ans}")
                     error = ans
@@ -362,26 +365,31 @@ def clean_formated_answer(ans: str) -> str:
                     self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                     return
                 except Exception:
-                    msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
+                    msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
                     error = "The answer can't not be parsed as JSON"
             if error:
                 self.set_output("_ERROR", error)
             return
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
-            self.set_output("content", partial(self._stream_output_async, prompt, msg))
+        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
+            ex and ex["goto"]
+        ):
+            self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
             return
 
-        for _ in range(self._param.max_retries + 1):
+        error = ""
+        for _ in range(self._param.max_retries + 1):
             if self.check_if_canceled("LLM processing"):
                 return
 
-            _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+            _, msg_fit = message_fit_in(
+                [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
+            )
             error = ""
-            ans = self._generate(msg)
-            msg.pop(0)
+            ans = await self._generate_async(msg_fit)
+            msg_fit.pop(0)
             if ans.find("**ERROR**") >= 0:
                 logging.error(f"LLM response error: {ans}")
                 error = ans
@@ -395,23 +403,9 @@ def clean_formated_answer(ans: str) -> str:
         else:
             self.set_output("_ERROR", error)
 
-    def _stream_output(self, prompt, msg):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer = ""
-        for ans in self._generate_streamly(msg):
-            if self.check_if_canceled("LLM streaming"):
-                return
-
-            if ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", ans)
-                return
-            yield ans
-            answer += ans
-        self.set_output("content", answer)
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
 
     def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict = {}):
         summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
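
Note: the hunks above move the component's work into `async def _invoke_async` and keep a thin synchronous `_invoke` that bridges into it with `asyncio.run`. Below is a minimal sketch of that bridge-plus-retry pattern, assuming a synchronous caller with no event loop already running; `LLMComponentSketch`, the `q` argument, and the canned response are hypothetical stand-ins for the real chat model and canvas plumbing, not the actual RAGFlow code.

import asyncio


class LLMComponentSketch:
    # Hypothetical stand-in: only the sync/async bridge and the retry loop are shown.
    max_retries = 2

    async def _generate_async(self, msg):
        # Placeholder for the real async chat-completion call.
        await asyncio.sleep(0)
        return '{"answer": "ok"}'

    async def _invoke_async(self, **kwargs):
        error = ""
        for _ in range(self.max_retries + 1):
            ans = await self._generate_async([{"role": "user", "content": kwargs.get("q", "")}])
            if "**ERROR**" in ans:
                error = ans
                continue
            return ans
        return error

    def _invoke(self, **kwargs):
        # Synchronous entry point: run the async path to completion on a fresh event loop.
        # asyncio.run() raises if a loop is already running in this thread, so the caller
        # is assumed to be plain synchronous code.
        return asyncio.run(self._invoke_async(**kwargs))


print(LLMComponentSketch()._invoke(q="hello"))  # -> {"answer": "ok"}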