Commit a16f2cc ("tested and working")
1 parent 7984e5a

File tree: 5 files changed, +470 -64 lines


pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -54,6 +54,9 @@ sagemaker = [
     "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
     "openai>=1.68.0,<2.0.0",  # SageMaker uses OpenAI-compatible interface
 ]
+sap_genai_hub = [
+    "sap-ai-sdk-gen[all]>=5.0.0,<6.0.0",
+]
 otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"]
 docs = [
     "sphinx>=5.0.0,<9.0.0",
@@ -69,7 +72,7 @@ a2a = [
     "fastapi>=0.115.12,<1.0.0",
     "starlette>=0.46.2,<1.0.0",
 ]
-all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]
+all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,sap_genai_hub,otel]"]
 
 dev = [
     "commitizen>=4.4.0,<5.0.0",

src/strands/models/sap_genai_hub.py

Lines changed: 106 additions & 57 deletions
@@ -1,7 +1,7 @@
 """SAP GenAI Hub model provider.
 
 - Docs: https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/consume-generative-ai-models-using-sap-ai-core#aws-bedrock
-- SDK Reference: https://help.sap.com/doc/generative-ai-hub-sdk/CLOUD/en-US/_reference/gen_ai_hub.html
+- SDK Reference: https://help.sap.com/doc/sap-ai-sdk-gen/CLOUD/en-US/_reference/gen_ai_hub.html
 """
 
 import asyncio
@@ -37,7 +37,7 @@
     from gen_ai_hub.proxy.native.amazon.clients import Session
 except ImportError as e:
     raise ImportError(
-        "SAP GenAI Hub SDK is not installed. Please install it with: pip install 'generative-ai-hub-sdk[all]'"
+        "SAP GenAI Hub SDK is not installed. Please install it with: pip install 'sap-ai-sdk-gen[all]'"
     ) from e
 
 logger = logging.getLogger(__name__)
@@ -94,7 +94,9 @@ def __init__(
         Args:
             **model_config: Configuration options for the SAP GenAI Hub model.
         """
-        self.config = SAPGenAIHubModel.SAPGenAIHubConfig(model_id=DEFAULT_SAP_GENAI_HUB_MODEL_ID)
+        self.config = SAPGenAIHubModel.SAPGenAIHubConfig(
+            model_id=DEFAULT_SAP_GENAI_HUB_MODEL_ID
+        )
         self.update_config(**model_config)
 
         logger.debug("config=<%s> | initializing", self.config)
@@ -166,13 +168,19 @@ def _format_request(
         """
         # Format request based on model type
         if self._is_nova_model():
-            return self._format_nova_request(messages, tool_specs, system_prompt_content, tool_choice)
+            return self._format_nova_request(
+                messages, tool_specs, system_prompt_content, tool_choice
+            )
         elif self._is_claude_model():
-            return self._format_claude_request(messages, tool_specs, system_prompt_content, tool_choice)
+            return self._format_claude_request(
+                messages, tool_specs, system_prompt_content, tool_choice
+            )
         elif self._is_titan_embed_model():
             return self._format_titan_embed_request(messages)
         else:
-            raise ValueError(f"model_id=<{self.config['model_id']}> | unsupported model")
+            raise ValueError(
+                f"model_id=<{self.config['model_id']}> | unsupported model"
+            )
 
     def _format_nova_request(
         self,
@@ -218,7 +226,10 @@ def _format_nova_request(
         }
 
         # Add additional arguments if provided
-        if "additional_args" in self.config and self.config["additional_args"] is not None:
+        if (
+            "additional_args" in self.config
+            and self.config["additional_args"] is not None
+        ):
             request.update(self.config["additional_args"])
 
         return request
@@ -243,7 +254,9 @@ def _format_claude_request(
         """
         # For Claude models, we use the same format as Nova models
         # since we're using the converse API for both
-        return self._format_nova_request(messages, tool_specs, system_prompt_content, tool_choice)
+        return self._format_nova_request(
+            messages, tool_specs, system_prompt_content, tool_choice
+        )
 
     def _format_titan_embed_request(self, messages: Messages) -> dict[str, Any]:
         """Format a request for Amazon Titan Embedding models.
@@ -272,7 +285,10 @@ def _format_titan_embed_request(self, messages: Messages) -> dict[str, Any]:
         }
 
         # Add additional arguments if provided
-        if "additional_args" in self.config and self.config["additional_args"] is not None:
+        if (
+            "additional_args" in self.config
+            and self.config["additional_args"] is not None
+        ):
             request.update(self.config["additional_args"])
 
         return request
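For illustration, the additional_args guard repeated above is a pass-through: whatever dictionary is configured gets merged into the outgoing request via request.update(...). A minimal sketch, assuming additional_args is accepted as a constructor keyword like the other config keys, and with a hypothetical inferenceConfig payload (not a key this diff establishes):

    # Hypothetical: extra request fields are merged verbatim into the payload.
    model = SAPGenAIHubModel(
        model_id="amazon--nova-lite",
        additional_args={"inferenceConfig": {"topP": 0.9}},
    )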
@@ -297,7 +313,9 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         # For any other type, convert to string and wrap
         return {"contentBlockDelta": {"delta": {"text": str(event)}}}
 
-    def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
+    def _convert_streaming_response(
+        self, stream_response: Any
+    ) -> Iterable[StreamEvent]:
         """Convert a streaming response to the standardized streaming format.
 
         Args:
@@ -388,7 +406,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
 
                     # If this is a messageStop event, we're done
                     if "messageStop" in event:
-                        logger.debug("received messageStop event from stream")
+                        logger.debug(
+                            "received messageStop event from stream"
+                        )
                         return
             else:
                 # Format unknown events
@@ -402,7 +422,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
                 event_count,
             )
         else:
-            logger.debug("stream response not iterable, treating as single response")
+            logger.debug(
+                "stream response not iterable, treating as single response"
+            )
             yield {"messageStart": {"role": "assistant"}}
             yield self._format_chunk(stream_response)
@@ -423,7 +445,9 @@ def _convert_streaming_response(self, stream_response: Any) -> Iterable[StreamEvent]:
             )
             raise e
 
-    async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
+    def _convert_non_streaming_to_streaming(
+        self, response: dict[str, Any]
+    ) -> Iterable[StreamEvent]:
         """Convert a non-streaming response to the streaming format.
 
         Args:
@@ -455,7 +479,11 @@ async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
                 # For tool use, we need to yield the input as a delta
                 input_value = json.dumps(content["toolUse"]["input"])
 
-                yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
+                yield {
+                    "contentBlockDelta": {
+                        "delta": {"toolUse": {"input": input_value}}
+                    }
+                }
             elif "text" in content:
                 # Then yield the text as a delta
                 yield {
@@ -492,7 +520,9 @@ async def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
         if "embedding" in response:
             yield {
                 "contentBlockDelta": {
-                    "delta": {"text": f"Embedding generated with {len(response['embedding'])} dimensions"},
+                    "delta": {
+                        "text": f"Embedding generated with {len(response['embedding'])} dimensions"
+                    },
                 }
             }
 
@@ -589,56 +619,60 @@ def _stream(
         """
         try:
             logger.debug("formatting request")
-            request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice)
+            request = self._format_request(
+                messages, tool_specs, system_prompt_content, tool_choice
+            )
 
             logger.debug("invoking model")
             streaming = self.config.get("streaming", True)
 
             if self._is_nova_model() or self._is_claude_model():
-                if streaming:
-                    # Use converse_stream for streaming responses
-                    try:
-                        logger.debug("using converse_stream api")
-                        stream_response = self.client.converse_stream(**request)
-
-                        # Process all streaming events
-                        event_count = 0
-                        has_content = False
+                # Try converse_stream first, fall back to converse if not supported
+                try:
+                    logger.debug("attempting converse_stream api")
+                    stream_response = self.client.converse_stream(**request)
 
-                        for event in self._convert_streaming_response(stream_response):
-                            event_count += 1
+                    # Process all streaming events
+                    event_count = 0
+                    has_content = False
 
-                            # Check if we have actual content
-                            if "contentBlockDelta" in event:
-                                has_content = True
+                    for event in self._convert_streaming_response(stream_response):
+                        event_count += 1
 
-                            callback(event)
+                        # Check if we have actual content
+                        if "contentBlockDelta" in event:
+                            has_content = True
 
-                        logger.debug(
-                            "event_count=<%d>, has_content=<%s> | processed streaming events",
-                            event_count,
-                            has_content,
-                        )
+                        callback(event)
 
-                        # If we didn't get any content, fallback to non-streaming
-                        if event_count == 0 or not has_content:
-                            logger.warning("no content received from streaming, falling back to non-streaming")
-                            response = self.client.converse(**request)
-                            for event in self._convert_non_streaming_to_streaming(response):
-                                callback(event)
+                    logger.debug(
+                        "event_count=<%d>, has_content=<%s> | processed streaming events",
+                        event_count,
+                        has_content,
+                    )
 
-                    except (AttributeError, Exception) as e:
-                        # Fallback to non-streaming if converse_stream fails
-                        logger.warning(
-                            "error=<%s> | converse_stream failed, falling back to non-streaming",
-                            e,
+                    # If we didn't get any content, fallback to non-streaming
+                    if event_count == 0 or not has_content:
+                        logger.debug(
+                            "no content received from streaming, falling back to converse"
                         )
                         response = self.client.converse(**request)
                         for event in self._convert_non_streaming_to_streaming(response):
                             callback(event)
-                else:
-                    # Non-streaming path
-                    logger.debug("using non-streaming converse api")
+
+                except NotImplementedError as nie:
+                    # converse_stream not supported by this model/deployment, use converse
+                    logger.debug("converse_stream not supported, using converse api")
+                    response = self.client.converse(**request)
+                    for event in self._convert_non_streaming_to_streaming(response):
+                        callback(event)
+
+                except Exception as e:
+                    # Other errors - log and fallback to converse
+                    logger.debug(
+                        "error=<%s> | converse_stream failed, falling back to converse",
+                        e,
+                    )
                     response = self.client.converse(**request)
                     for event in self._convert_non_streaming_to_streaming(response):
                         callback(event)
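The net effect of this hunk: the streaming flag no longer gates the first attempt for Nova and Claude models; converse_stream is always tried, with two explicit fallbacks to converse. A distilled sketch of that ladder (an illustrative helper, not code from this commit):

    # Sketch: attempt streaming first, then degrade to the one-shot converse API.
    def invoke_with_fallback(client, request, callback, convert_stream, convert_full):
        try:
            for event in convert_stream(client.converse_stream(**request)):
                callback(event)
        except NotImplementedError:
            # The deployment exposes no streaming endpoint.
            for event in convert_full(client.converse(**request)):
                callback(event)
        except Exception:
            # Any other streaming failure takes the same converse fallback.
            for event in convert_full(client.converse(**request)):
                callback(event)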
@@ -647,18 +681,26 @@ def _stream(
                 if streaming:
                     # Try streaming for Titan models
                     try:
-                        logger.debug("using invoke_model_with_response_stream for titan")
-                        stream_response = self.client.invoke_model_with_response_stream(**request)
+                        logger.debug(
+                            "using invoke_model_with_response_stream for titan"
+                        )
+                        stream_response = self.client.invoke_model_with_response_stream(
+                            **request
+                        )
 
                         event_count = 0
                         for event in self._convert_streaming_response(stream_response):
                             event_count += 1
                             callback(event)
 
                         if event_count == 0:
-                            logger.warning("no events from titan streaming, falling back to non-streaming")
+                            logger.warning(
+                                "no events from titan streaming, falling back to non-streaming"
+                            )
                             response = self.client.invoke_model(**request)
-                            for event in self._convert_non_streaming_to_streaming(response):
+                            for event in self._convert_non_streaming_to_streaming(
+                                response
+                            ):
                                 callback(event)
 
                     except (AttributeError, Exception) as e:
@@ -684,7 +726,10 @@ def _stream(
                 raise ModelThrottledException(error_message) from e
 
             # Handle context window overflow
-            if any(overflow_message in error_message for overflow_message in CONTEXT_WINDOW_OVERFLOW_MESSAGES):
+            if any(
+                overflow_message in error_message
+                for overflow_message in CONTEXT_WINDOW_OVERFLOW_MESSAGES
+            ):
                 logger.warning("sap genai hub threw context window overflow error")
                 raise ContextWindowOverflowException(e) from e
 
@@ -733,7 +778,9 @@ async def structured_output(
         stop_reason, messages, _, _ = event["stop"]
 
         if stop_reason != "tool_use":
-            raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
+            raise ValueError(
+                f'Model returned stop_reason: {stop_reason} instead of "tool_use".'
+            )
 
         content = messages["content"]
         output_response: dict[str, Any] | None = None
@@ -746,6 +793,8 @@ async def structured_output(
                 continue
 
         if output_response is None:
-            raise ValueError("No valid tool use or tool use input was found in the response.")
+            raise ValueError(
+                "No valid tool use or tool use input was found in the response."
+            )
 
         yield {"output": output_model(**output_response)}
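For context, a minimal usage sketch of the provider after this change. The import path mirrors the file location (src/strands/models/sap_genai_hub.py), and the keyword arguments are config keys exercised in this diff and its tests; treat both as assumptions rather than documented API:

    from strands.models.sap_genai_hub import SAPGenAIHubModel  # assumed import path

    # model_id defaults to DEFAULT_SAP_GENAI_HUB_MODEL_ID ("amazon--nova-lite").
    model = SAPGenAIHubModel(
        model_id="amazon--nova-lite",
        temperature=0.7,
        max_tokens=1000,
        streaming=True,  # consulted on the Titan path; Nova/Claude now always try converse_stream first
    )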

tests/strands/models/test_sap_genai_hub.py

Lines changed: 12 additions & 4 deletions
@@ -19,7 +19,9 @@ def test_initialization_with_default_config(self):
         model = SAPGenAIHubModel()
 
         assert model.config["model_id"] == "amazon--nova-lite"
-        mock_session.return_value.client.assert_called_once_with(model_name="amazon--nova-lite")
+        mock_session.return_value.client.assert_called_once_with(
+            model_name="amazon--nova-lite"
+        )
 
     def test_initialization_with_custom_config(self):
         """Test model initialization with custom configuration."""
@@ -106,12 +108,16 @@ def test_format_nova_request(self):
         mock_client = MagicMock()
         mock_session.return_value.client.return_value = mock_client
 
-        model = SAPGenAIHubModel(model_id="amazon--nova-lite", temperature=0.7, max_tokens=1000)
+        model = SAPGenAIHubModel(
+            model_id="amazon--nova-lite", temperature=0.7, max_tokens=1000
+        )
 
         messages = [{"role": "user", "content": [{"text": "Hello"}]}]
         system_prompt_content = [{"text": "You are a helpful assistant"}]
 
-        request = model._format_nova_request(messages=messages, system_prompt_content=system_prompt_content)
+        request = model._format_nova_request(
+            messages=messages, system_prompt_content=system_prompt_content
+        )
 
         assert request["messages"] == messages
         assert request["system"] == system_prompt_content
@@ -135,7 +141,9 @@ def test_format_nova_request_with_tools(self):
             }
         ]
 
-        request = model._format_nova_request(messages=messages, tool_specs=tool_specs)
+        request = model._format_nova_request(
+            messages=messages, tool_specs=tool_specs
+        )
 
         assert "toolConfig" in request
         assert len(request["toolConfig"]["tools"]) == 1
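A natural follow-on test would exercise the unsupported-model branch of the reworked dispatch, reusing the mock pattern above. A sketch only, not part of this commit: the patch target, the placeholder model id, defaults for the optional _format_request parameters, and the pytest/unittest.mock imports are all assumptions:

    def test_format_request_unsupported_model(self):
        """Unsupported model ids should raise the ValueError from _format_request."""
        with patch("strands.models.sap_genai_hub.Session") as mock_session:
            mock_session.return_value.client.return_value = MagicMock()
            model = SAPGenAIHubModel(model_id="unknown--model")

            with pytest.raises(ValueError, match="unsupported model"):
                model._format_request(
                    messages=[{"role": "user", "content": [{"text": "Hello"}]}],
                    tool_specs=None,
                    system_prompt_content=None,
                    tool_choice=None,
                )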
