11"""SAP GenAI Hub model provider.
22
33- Docs: https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/consume-generative-ai-models-using-sap-ai-core#aws-bedrock
4- - SDK Reference: https://help.sap.com/doc/generative -ai-hub- sdk/CLOUD/en-US/_reference/gen_ai_hub.html
4+ - SDK Reference: https://help.sap.com/doc/sap -ai-sdk-gen /CLOUD/en-US/_reference/gen_ai_hub.html
55"""
66
77import asyncio
3737 from gen_ai_hub .proxy .native .amazon .clients import Session
3838except ImportError as e :
3939 raise ImportError (
40- "SAP GenAI Hub SDK is not installed. Please install it with: pip install 'generative -ai-hub- sdk[all]'"
40+ "SAP GenAI Hub SDK is not installed. Please install it with: pip install 'sap -ai-sdk-gen [all]'"
4141 ) from e
4242
4343logger = logging .getLogger (__name__ )
@@ -94,7 +94,9 @@ def __init__(
         Args:
             **model_config: Configuration options for the SAP GenAI Hub model.
         """
-        self.config = SAPGenAIHubModel.SAPGenAIHubConfig(model_id=DEFAULT_SAP_GENAI_HUB_MODEL_ID)
+        self.config = SAPGenAIHubModel.SAPGenAIHubConfig(
+            model_id=DEFAULT_SAP_GENAI_HUB_MODEL_ID
+        )
         self.update_config(**model_config)
 
         logger.debug("config=<%s> | initializing", self.config)
@@ -166,13 +168,19 @@ def _format_request(
166168 """
167169 # Format request based on model type
168170 if self ._is_nova_model ():
169- return self ._format_nova_request (messages , tool_specs , system_prompt_content , tool_choice )
171+ return self ._format_nova_request (
172+ messages , tool_specs , system_prompt_content , tool_choice
173+ )
170174 elif self ._is_claude_model ():
171- return self ._format_claude_request (messages , tool_specs , system_prompt_content , tool_choice )
175+ return self ._format_claude_request (
176+ messages , tool_specs , system_prompt_content , tool_choice
177+ )
172178 elif self ._is_titan_embed_model ():
173179 return self ._format_titan_embed_request (messages )
174180 else :
175- raise ValueError (f"model_id=<{ self .config ['model_id' ]} > | unsupported model" )
181+ raise ValueError (
182+ f"model_id=<{ self .config ['model_id' ]} > | unsupported model"
183+ )
176184
177185 def _format_nova_request (
178186 self ,
@@ -218,7 +226,10 @@ def _format_nova_request(
         }
 
         # Add additional arguments if provided
-        if "additional_args" in self.config and self.config["additional_args"] is not None:
+        if (
+            "additional_args" in self.config
+            and self.config["additional_args"] is not None
+        ):
             request.update(self.config["additional_args"])
 
         return request
@@ -243,7 +254,9 @@ def _format_claude_request(
243254 """
244255 # For Claude models, we use the same format as Nova models
245256 # since we're using the converse API for both
246- return self ._format_nova_request (messages , tool_specs , system_prompt_content , tool_choice )
257+ return self ._format_nova_request (
258+ messages , tool_specs , system_prompt_content , tool_choice
259+ )
247260
248261 def _format_titan_embed_request (self , messages : Messages ) -> dict [str , Any ]:
249262 """Format a request for Amazon Titan Embedding models.
@@ -272,7 +285,10 @@ def _format_titan_embed_request(self, messages: Messages) -> dict[str, Any]:
         }
 
         # Add additional arguments if provided
-        if "additional_args" in self.config and self.config["additional_args"] is not None:
+        if (
+            "additional_args" in self.config
+            and self.config["additional_args"] is not None
+        ):
             request.update(self.config["additional_args"])
 
         return request
@@ -297,7 +313,9 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         # For any other type, convert to string and wrap
         return {"contentBlockDelta": {"delta": {"text": str(event)}}}
 
-    def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
+    def _convert_streaming_response(
+        self, stream_response: Any
+    ) -> Iterable[StreamEvent]:
         """Convert a streaming response to the standardized streaming format.
 
         Args:
@@ -388,7 +406,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
 
                     # If this is a messageStop event, we're done
                     if "messageStop" in event:
-                        logger.debug("received messageStop event from stream")
+                        logger.debug(
+                            "received messageStop event from stream"
+                        )
                         return
                 else:
                     # Format unknown events
@@ -402,7 +422,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
                 event_count,
             )
         else:
-            logger.debug("stream response not iterable, treating as single response")
+            logger.debug(
+                "stream response not iterable, treating as single response"
+            )
             yield {"messageStart": {"role": "assistant"}}
             yield self._format_chunk(stream_response)
 
@@ -423,7 +445,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
             )
             raise e
 
-    async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
+    def _convert_non_streaming_to_streaming(
+        self, response: dict[str, Any]
+    ) -> Iterable[StreamEvent]:
         """Convert a non-streaming response to the streaming format.
 
         Args:
@@ -455,7 +479,11 @@ async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
                     # For tool use, we need to yield the input as a delta
                     input_value = json.dumps(content["toolUse"]["input"])
 
-                    yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
+                    yield {
+                        "contentBlockDelta": {
+                            "delta": {"toolUse": {"input": input_value}}
+                        }
+                    }
                 elif "text" in content:
                     # Then yield the text as a delta
                     yield {
@@ -492,7 +520,9 @@ async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
         if "embedding" in response:
             yield {
                 "contentBlockDelta": {
-                    "delta": {"text": f"Embedding generated with {len(response['embedding'])} dimensions"},
+                    "delta": {
+                        "text": f"Embedding generated with {len(response['embedding'])} dimensions"
+                    },
                 }
             }
 
@@ -589,56 +619,60 @@ def _stream(
589619 """
590620 try :
591621 logger .debug ("formatting request" )
592- request = self ._format_request (messages , tool_specs , system_prompt_content , tool_choice )
622+ request = self ._format_request (
623+ messages , tool_specs , system_prompt_content , tool_choice
624+ )
593625
594626 logger .debug ("invoking model" )
595627 streaming = self .config .get ("streaming" , True )
596628
597629 if self ._is_nova_model () or self ._is_claude_model ():
598- if streaming :
599- # Use converse_stream for streaming responses
600- try :
601- logger .debug ("using converse_stream api" )
602- stream_response = self .client .converse_stream (** request )
603-
604- # Process all streaming events
605- event_count = 0
606- has_content = False
630+ # Try converse_stream first, fall back to converse if not supported
631+ try :
632+ logger .debug ("attempting converse_stream api" )
633+ stream_response = self .client .converse_stream (** request )
607634
608- for event in self ._convert_streaming_response (stream_response ):
609- event_count += 1
635+ # Process all streaming events
636+ event_count = 0
637+ has_content = False
610638
611- # Check if we have actual content
612- if "contentBlockDelta" in event :
613- has_content = True
639+ for event in self ._convert_streaming_response (stream_response ):
640+ event_count += 1
614641
615- callback (event )
642+ # Check if we have actual content
643+ if "contentBlockDelta" in event :
644+ has_content = True
616645
617- logger .debug (
618- "event_count=<%d>, has_content=<%s> | processed streaming events" ,
619- event_count ,
620- has_content ,
621- )
646+ callback (event )
622647
623- # If we didn't get any content, fallback to non-streaming
624- if event_count == 0 or not has_content :
625- logger .warning ("no content received from streaming, falling back to non-streaming" )
626- response = self .client .converse (** request )
627- for event in self ._convert_non_streaming_to_streaming (response ):
628- callback (event )
648+ logger .debug (
649+ "event_count=<%d>, has_content=<%s> | processed streaming events" ,
650+ event_count ,
651+ has_content ,
652+ )
629653
630- except (AttributeError , Exception ) as e :
631- # Fallback to non-streaming if converse_stream fails
632- logger .warning (
633- "error=<%s> | converse_stream failed, falling back to non-streaming" ,
634- e ,
654+ # If we didn't get any content, fallback to non-streaming
655+ if event_count == 0 or not has_content :
656+ logger .debug (
657+ "no content received from streaming, falling back to converse"
635658 )
636659 response = self .client .converse (** request )
637660 for event in self ._convert_non_streaming_to_streaming (response ):
638661 callback (event )
639- else :
640- # Non-streaming path
641- logger .debug ("using non-streaming converse api" )
662+
663+ except NotImplementedError as nie :
664+ # converse_stream not supported by this model/deployment, use converse
665+ logger .debug ("converse_stream not supported, using converse api" )
666+ response = self .client .converse (** request )
667+ for event in self ._convert_non_streaming_to_streaming (response ):
668+ callback (event )
669+
670+ except Exception as e :
671+ # Other errors - log and fallback to converse
672+ logger .debug (
673+ "error=<%s> | converse_stream failed, falling back to converse" ,
674+ e ,
675+ )
642676 response = self .client .converse (** request )
643677 for event in self ._convert_non_streaming_to_streaming (response ):
644678 callback (event )
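The restructured Nova/Claude branch above always attempts `converse_stream` first and degrades to `converse` on `NotImplementedError`, on any other failure, or when the stream yields no content. A minimal standalone sketch of that fallback ladder, assuming a duck-typed `client` exposing the two converse methods and an `emit` callback standing in for the provider's internals:

```python
from typing import Any, Callable

StreamEvent = dict[str, Any]


def stream_with_fallback(
    client: Any,
    request: dict[str, Any],
    emit: Callable[[StreamEvent], None],
) -> None:
    """Prefer the streaming API; fall back to the blocking call when needed."""
    try:
        has_content = False
        # A boto3-style converse_stream response exposes events under "stream".
        for event in client.converse_stream(**request)["stream"]:
            has_content = has_content or "contentBlockDelta" in event
            emit(event)
        if has_content:
            return  # streaming worked, nothing left to do
    except NotImplementedError:
        pass  # streaming unsupported for this model/deployment
    except Exception:
        pass  # any other streaming failure also degrades gracefully
    # Blocking fallback: a real implementation re-chunks the response into
    # stream events (see _convert_non_streaming_to_streaming above).
    response = client.converse(**request)
    emit({"messageStart": {"role": "assistant"}})
    emit({"contentBlockDelta": {"delta": {"text": str(response)}}})
    emit({"messageStop": {"stopReason": "end_turn"}})
```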
@@ -647,18 +681,26 @@ def _stream(
             if streaming:
                 # Try streaming for Titan models
                 try:
-                    logger.debug("using invoke_model_with_response_stream for titan")
-                    stream_response = self.client.invoke_model_with_response_stream(**request)
+                    logger.debug(
+                        "using invoke_model_with_response_stream for titan"
+                    )
+                    stream_response = self.client.invoke_model_with_response_stream(
+                        **request
+                    )
 
                     event_count = 0
                     for event in self._convert_streaming_response(stream_response):
                         event_count += 1
                         callback(event)
 
                     if event_count == 0:
-                        logger.warning("no events from titan streaming, falling back to non-streaming")
+                        logger.warning(
+                            "no events from titan streaming, falling back to non-streaming"
+                        )
                         response = self.client.invoke_model(**request)
-                        for event in self._convert_non_streaming_to_streaming(response):
+                        for event in self._convert_non_streaming_to_streaming(
+                            response
+                        ):
                             callback(event)
 
                 except (AttributeError, Exception) as e:
@@ -684,7 +726,10 @@ def _stream(
                 raise ModelThrottledException(error_message) from e
 
             # Handle context window overflow
-            if any(overflow_message in error_message for overflow_message in CONTEXT_WINDOW_OVERFLOW_MESSAGES):
+            if any(
+                overflow_message in error_message
+                for overflow_message in CONTEXT_WINDOW_OVERFLOW_MESSAGES
+            ):
                 logger.warning("sap genai hub threw context window overflow error")
                 raise ContextWindowOverflowException(e) from e
 
@@ -733,7 +778,9 @@ async def structured_output(
         stop_reason, messages, _, _ = event["stop"]
 
         if stop_reason != "tool_use":
-            raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
+            raise ValueError(
+                f'Model returned stop_reason: {stop_reason} instead of "tool_use".'
+            )
 
         content = messages["content"]
         output_response: dict[str, Any] | None = None
@@ -746,6 +793,8 @@ async def structured_output(
                 continue
 
         if output_response is None:
-            raise ValueError(
+            raise ValueError(
                 "No valid tool use or tool use input was found in the response."
-            )
+            )
 
         yield {"output": output_model(**output_response)}
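Finally, a hypothetical smoke test for this branch. The `model_id`/`streaming` config keys, the default-model fallback, and the `{"output": ...}` event shape come from the hunks above; the import path, the concrete model name, and the `structured_output(output_model, prompt)` call signature are assumptions:

```python
import asyncio

from pydantic import BaseModel


class Weather(BaseModel):
    city: str
    summary: str


async def main() -> None:
    # Import path is assumed; the class name is the one defined in this file.
    from strands.models.sap_genai_hub import SAPGenAIHubModel

    # model_id and streaming are the config keys read in _stream(); omitting
    # model_id falls back to DEFAULT_SAP_GENAI_HUB_MODEL_ID.
    model = SAPGenAIHubModel(model_id="anthropic--claude-3.5-sonnet", streaming=True)

    # structured_output drives a tool-use round trip and yields the parsed model.
    prompt = [{"role": "user", "content": [{"text": "Weather in Berlin?"}]}]
    async for event in model.structured_output(Weather, prompt):
        print(event["output"])


asyncio.run(main())
```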