From 2409e51ba71ac57a196beba131371bd467465410 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 4 Aug 2025 14:07:25 -0700 Subject: [PATCH 1/3] Add optional bedrock model provider param --- .../setup/chat_completion_services.py | 2 +- .../connectors/ai/bedrock/bedrock_settings.py | 7 ++ .../ai/bedrock/services/bedrock_base.py | 9 ++ .../services/bedrock_chat_completion.py | 12 +- .../services/bedrock_text_completion.py | 21 +++- .../services/bedrock_text_embedding.py | 22 +++- .../model_provider/bedrock_model_provider.py | 46 +++++-- .../unit/connectors/ai/bedrock/conftest.py | 44 ++++++- .../services/test_bedrock_chat_completion.py | 57 ++++++++- .../services/test_bedrock_text_completion.py | 112 +++++++++++++++++- .../test_bedrock_text_embedding_generation.py | 68 ++++++++++- 11 files changed, 362 insertions(+), 38 deletions(-) diff --git a/python/samples/concepts/setup/chat_completion_services.py b/python/samples/concepts/setup/chat_completion_services.py index b3b01bbe0802..a1479230d2ca 100644 --- a/python/samples/concepts/setup/chat_completion_services.py +++ b/python/samples/concepts/setup/chat_completion_services.py @@ -213,7 +213,7 @@ def get_bedrock_chat_completion_service_and_request_settings() -> tuple[ """ from semantic_kernel.connectors.ai.bedrock import BedrockChatCompletion, BedrockChatPromptExecutionSettings - chat_service = BedrockChatCompletion(service_id=service_id, model_id="anthropic.claude-3-sonnet-20240229-v1:0") + chat_service = BedrockChatCompletion(service_id=service_id) request_settings = BedrockChatPromptExecutionSettings( # For model specific settings, specify them in the extension_data dictionary. # For example, for Cohere Command specific settings, refer to: diff --git a/python/semantic_kernel/connectors/ai/bedrock/bedrock_settings.py b/python/semantic_kernel/connectors/ai/bedrock/bedrock_settings.py index 2c5348a14676..2db708d19f9c 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/bedrock_settings.py +++ b/python/semantic_kernel/connectors/ai/bedrock/bedrock_settings.py @@ -2,6 +2,7 @@ from typing import ClassVar +from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider from semantic_kernel.kernel_pydantic import KernelBaseSettings from semantic_kernel.utils.feature_stage_decorator import experimental @@ -25,6 +26,11 @@ class BedrockSettings(KernelBaseSettings): (Env var BEDROCK_TEXT_MODEL_ID) - embedding_model_id: str | None - The Amazon Bedrock embedding model ID to use. (Env var BEDROCK_EMBEDDING_MODEL_ID) + - model_provider: BedrockModelProvider | None - The Bedrock model provider to use. + If not provided, the model provider will be extracted from the model ID. + When using an Application Inference Profile where the model provider is not part + of the model ID, this setting must be provided. + (Env var BEDROCK_MODEL_PROVIDER) """ env_prefix: ClassVar[str] = "BEDROCK_" @@ -32,3 +38,4 @@ class BedrockSettings(KernelBaseSettings): chat_model_id: str | None = None text_model_id: str | None = None embedding_model_id: str | None = None + model_provider: BedrockModelProvider | None = None diff --git a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_base.py b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_base.py index 39c208656a3b..afb032e402f4 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_base.py +++ b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_base.py @@ -5,6 +5,7 @@ import boto3 +from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider from semantic_kernel.kernel_pydantic import KernelBaseModel @@ -19,11 +20,14 @@ class BedrockBase(KernelBaseModel, ABC): # Client: Use for model management bedrock_client: Any + bedrock_model_provider: BedrockModelProvider | None = None + def __init__( self, *, runtime_client: Any | None = None, client: Any | None = None, + bedrock_model_provider: BedrockModelProvider | None = None, **kwargs: Any, ) -> None: """Initialize the Amazon Bedrock Base Class. @@ -31,10 +35,15 @@ def __init__( Args: runtime_client: The Amazon Bedrock runtime client to use. client: The Amazon Bedrock client to use. + bedrock_model_provider: The Bedrock model provider to use. + If not provided, the model provider will be extracted from the model ID. + When using an Application Inference Profile where the model provider is not part + of the model ID, this setting must be provided. **kwargs: Additional keyword arguments. """ super().__init__( bedrock_runtime_client=runtime_client or boto3.client("bedrock-runtime"), bedrock_client=client or boto3.client("bedrock"), + bedrock_model_provider=bedrock_model_provider, **kwargs, ) diff --git a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_chat_completion.py b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_chat_completion.py index 46b660a3cc2f..77a73fcc800a 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_chat_completion.py @@ -16,6 +16,7 @@ from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import ( + BedrockModelProvider, get_chat_completion_additional_model_request_fields, ) from semantic_kernel.connectors.ai.bedrock.services.model_provider.utils import ( @@ -36,10 +37,7 @@ from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason -from semantic_kernel.exceptions.service_exceptions import ( - ServiceInitializationError, - ServiceInvalidResponseError, -) +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidResponseError from semantic_kernel.utils.async_utils import run_in_executor from semantic_kernel.utils.telemetry.model_diagnostics.decorators import ( trace_chat_completion, @@ -60,6 +58,7 @@ class BedrockChatCompletion(BedrockBase, ChatCompletionClientBase): def __init__( self, model_id: str | None = None, + model_provider: BedrockModelProvider | None = None, service_id: str | None = None, runtime_client: Any | None = None, client: Any | None = None, @@ -70,6 +69,7 @@ def __init__( Args: model_id: The Amazon Bedrock chat model ID to use. + model_provider: The Bedrock model provider to use. service_id: The Service ID for the completion service. runtime_client: The Amazon Bedrock runtime client to use. client: The Amazon Bedrock client to use. @@ -79,6 +79,7 @@ def __init__( try: bedrock_settings = BedrockSettings( chat_model_id=model_id, + model_provider=model_provider, env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) @@ -93,6 +94,7 @@ def __init__( service_id=service_id or bedrock_settings.chat_model_id, runtime_client=runtime_client, client=client, + bedrock_model_provider=bedrock_settings.model_provider, ) # region Overriding base class methods @@ -212,7 +214,7 @@ def _prepare_settings_for_request( "stopSequences": settings.stop, }), "additionalModelRequestFields": get_chat_completion_additional_model_request_fields( - self.ai_model_id, settings + self.ai_model_id, settings, model_provider=self.bedrock_model_provider ), } diff --git a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_completion.py b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_completion.py index d32982ffc2ff..27971361cb6e 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_completion.py +++ b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_completion.py @@ -17,6 +17,7 @@ from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import ( + BedrockModelProvider, get_text_completion_request_body, parse_streaming_text_completion_response, parse_text_completion_response, @@ -41,6 +42,7 @@ class BedrockTextCompletion(BedrockBase, TextCompletionClientBase): def __init__( self, model_id: str | None = None, + model_provider: BedrockModelProvider | None = None, service_id: str | None = None, runtime_client: Any | None = None, client: Any | None = None, @@ -51,6 +53,7 @@ def __init__( Args: model_id: The Amazon Bedrock text model ID to use. + model_provider: The Bedrock model provider to use. service_id: The Service ID for the text completion service. runtime_client: The Amazon Bedrock runtime client to use. client: The Amazon Bedrock client to use. @@ -60,6 +63,7 @@ def __init__( try: bedrock_settings = BedrockSettings( text_model_id=model_id, + model_provider=model_provider, env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) @@ -74,6 +78,7 @@ def __init__( service_id=service_id or bedrock_settings.text_model_id, runtime_client=runtime_client, client=client, + bedrock_model_provider=bedrock_settings.model_provider, ) # region Overriding base class methods @@ -94,11 +99,17 @@ async def _inner_get_text_contents( settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, BedrockTextPromptExecutionSettings) # nosec - request_body = get_text_completion_request_body(self.ai_model_id, prompt, settings) + request_body = get_text_completion_request_body( + self.ai_model_id, + prompt, + settings, + model_provider=self.bedrock_model_provider, + ) response_body = await self._async_invoke_model(request_body) return parse_text_completion_response( self.ai_model_id, json.loads(response_body.get("body").read()), + model_provider=self.bedrock_model_provider, ) @override @@ -112,7 +123,12 @@ async def _inner_get_streaming_text_contents( settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, BedrockTextPromptExecutionSettings) # nosec - request_body = get_text_completion_request_body(self.ai_model_id, prompt, settings) + request_body = get_text_completion_request_body( + self.ai_model_id, + prompt, + settings, + model_provider=self.bedrock_model_provider, + ) response_stream = await self._async_invoke_model_stream(request_body) for event in response_stream.get("body"): chunk = event.get("chunk") @@ -120,6 +136,7 @@ async def _inner_get_streaming_text_contents( parse_streaming_text_completion_response( self.ai_model_id, json.loads(chunk.get("bytes").decode()), + model_provider=self.bedrock_model_provider, ) ] diff --git a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_embedding.py b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_embedding.py index 672b697d2ca5..71bffca7e10e 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_embedding.py @@ -20,6 +20,7 @@ from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import ( + BedrockModelProvider, get_text_embedding_request_body, parse_text_embedding_response, ) @@ -38,6 +39,7 @@ class BedrockTextEmbedding(BedrockBase, EmbeddingGeneratorBase): def __init__( self, model_id: str | None = None, + model_provider: BedrockModelProvider | None = None, service_id: str | None = None, runtime_client: Any | None = None, client: Any | None = None, @@ -48,6 +50,7 @@ def __init__( Args: model_id: The Amazon Bedrock text embedding model ID to use. + model_provider: The Bedrock model provider to use. service_id: The Service ID for the text embedding service. runtime_client: The Amazon Bedrock runtime client to use. client: The Amazon Bedrock client to use. @@ -57,6 +60,7 @@ def __init__( try: bedrock_settings = BedrockSettings( embedding_model_id=model_id, + model_provider=model_provider, env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) @@ -71,6 +75,7 @@ def __init__( service_id=service_id or bedrock_settings.embedding_model_id, runtime_client=runtime_client, client=client, + bedrock_model_provider=bedrock_settings.model_provider, ) @override @@ -87,12 +92,25 @@ async def generate_embeddings( assert isinstance(settings, BedrockEmbeddingPromptExecutionSettings) # nosec results = await asyncio.gather(*[ - self._async_invoke_model(get_text_embedding_request_body(self.ai_model_id, text, settings)) + self._async_invoke_model( + get_text_embedding_request_body( + self.ai_model_id, + text, + settings, + model_provider=self.bedrock_model_provider, + ) + ) for text in texts ]) return array([ - array(parse_text_embedding_response(self.ai_model_id, json.loads(result.get("body").read()))) + array( + parse_text_embedding_response( + self.ai_model_id, + json.loads(result.get("body").read()), + model_provider=self.bedrock_model_provider, + ) + ) for result in results ]) diff --git a/python/semantic_kernel/connectors/ai/bedrock/services/model_provider/bedrock_model_provider.py b/python/semantic_kernel/connectors/ai/bedrock/services/model_provider/bedrock_model_provider.py index 8655361331e5..00a10febcf86 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/services/model_provider/bedrock_model_provider.py +++ b/python/semantic_kernel/connectors/ai/bedrock/services/model_provider/bedrock_model_provider.py @@ -73,21 +73,34 @@ def to_model_provider(cls, model_id: str) -> "BedrockModelProvider": } -def get_text_completion_request_body(model_id: str, prompt: str, settings: BedrockTextPromptExecutionSettings) -> dict: +def get_text_completion_request_body( + model_id: str, + prompt: str, + settings: BedrockTextPromptExecutionSettings, + model_provider: BedrockModelProvider | None = None, +) -> dict: """Get the request body for text completion for Amazon Bedrock models.""" - model_provider = BedrockModelProvider.to_model_provider(model_id) + model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id) return TEXT_COMPLETION_REQUEST_BODY_MAPPING[model_provider](prompt, settings) -def parse_text_completion_response(model_id: str, response: dict) -> list[TextContent]: +def parse_text_completion_response( + model_id: str, + response: dict, + model_provider: BedrockModelProvider | None = None, +) -> list[TextContent]: """Parse the response from text completion for Amazon Bedrock models.""" - model_provider = BedrockModelProvider.to_model_provider(model_id) + model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id) return TEXT_COMPLETION_RESPONSE_MAPPING[model_provider](response, model_id) -def parse_streaming_text_completion_response(model_id: str, chunk: dict) -> StreamingTextContent: +def parse_streaming_text_completion_response( + model_id: str, + chunk: dict, + model_provider: BedrockModelProvider | None = None, +) -> StreamingTextContent: """Parse the response from streaming text completion for Amazon Bedrock models.""" - model_provider = BedrockModelProvider.to_model_provider(model_id) + model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id) return STREAMING_TEXT_COMPLETION_RESPONSE_MAPPING[model_provider](chunk, model_id) @@ -109,10 +122,12 @@ def parse_streaming_text_completion_response(model_id: str, chunk: dict) -> Stre def get_chat_completion_additional_model_request_fields( - model_id: str, settings: BedrockChatPromptExecutionSettings + model_id: str, + settings: BedrockChatPromptExecutionSettings, + model_provider: BedrockModelProvider | None = None, ) -> dict[str, Any] | None: """Get the additional model request fields for chat completion for Amazon Bedrock models.""" - model_provider = BedrockModelProvider.to_model_provider(model_id) + model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id) return CHAT_COMPLETION_ADDITIONAL_MODEL_REQUEST_FIELDS_MAPPING[model_provider](settings) @@ -134,16 +149,23 @@ def get_chat_completion_additional_model_request_fields( def get_text_embedding_request_body( - model_id: str, text: str, settings: BedrockEmbeddingPromptExecutionSettings + model_id: str, + text: str, + settings: BedrockEmbeddingPromptExecutionSettings, + model_provider: BedrockModelProvider | None = None, ) -> dict: """Get the request body for text embedding for Amazon Bedrock models.""" - model_provider = BedrockModelProvider.to_model_provider(model_id) + model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id) return TEXT_EMBEDDING_REQUEST_BODY_MAPPING[model_provider](text, settings) -def parse_text_embedding_response(model_id: str, response: dict) -> list[float]: +def parse_text_embedding_response( + model_id: str, + response: dict, + model_provider: BedrockModelProvider | None = None, +) -> list[float]: """Parse the response from text embedding for Amazon Bedrock models.""" - model_provider = BedrockModelProvider.to_model_provider(model_id) + model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id) return TEXT_EMBEDDING_RESPONSE_MAPPING[model_provider](response) diff --git a/python/tests/unit/connectors/ai/bedrock/conftest.py b/python/tests/unit/connectors/ai/bedrock/conftest.py index d5f8a44f4344..d37450b36946 100644 --- a/python/tests/unit/connectors/ai/bedrock/conftest.py +++ b/python/tests/unit/connectors/ai/bedrock/conftest.py @@ -46,6 +46,7 @@ def bedrock_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): "BEDROCK_TEXT_MODEL_ID": "env_test_text_model_id", "BEDROCK_CHAT_MODEL_ID": "env_test_chat_model_id", "BEDROCK_EMBEDDING_MODEL_ID": "env_test_embedding_model_id", + "BEDROCK_MODEL_PROVIDER": "amazon", } env_vars.update(override_env_param_dict) @@ -164,8 +165,22 @@ def output_text(): @pytest.fixture() -def mock_bedrock_text_completion_response(model_id: str, output_text: str): - model_provider = BedrockModelProvider.to_model_provider(model_id) +def model_provider(): + return BedrockModelProvider.AMAZON + + +@pytest.fixture() +def mock_bedrock_text_completion_response( + model_id: str, + output_text: str, + request, +): + # Check if model_provider fixture is requested by the test + model_provider = None + if "model_provider" in request.fixturenames: + model_provider = request.getfixturevalue("model_provider") + else: + model_provider = BedrockModelProvider.to_model_provider(model_id) match model_provider: case BedrockModelProvider.AMAZON: @@ -219,8 +234,17 @@ def mock_bedrock_text_completion_response(model_id: str, output_text: str): @pytest.fixture() -def mock_bedrock_streaming_text_completion_response(model_id: str, output_text: str): - model_provider = BedrockModelProvider.to_model_provider(model_id) +def mock_bedrock_streaming_text_completion_response( + model_id: str, + output_text: str, + request, +): + # Check if model_provider fixture is requested by the test + model_provider = None + if "model_provider" in request.fixturenames: + model_provider = request.getfixturevalue("model_provider") + else: + model_provider = BedrockModelProvider.to_model_provider(model_id) match model_provider: case BedrockModelProvider.AMAZON: @@ -250,8 +274,16 @@ def event_stream(events): @pytest.fixture() -def mock_bedrock_text_embedding_response(model_id: str): - model_provider = BedrockModelProvider.to_model_provider(model_id) +def mock_bedrock_text_embedding_response( + model_id: str, + request, +): + # Check if model_provider fixture is requested by the test + model_provider = None + if "model_provider" in request.fixturenames: + model_provider = request.getfixturevalue("model_provider") + else: + model_provider = BedrockModelProvider.to_model_provider(model_id) match model_provider: case BedrockModelProvider.AMAZON: diff --git a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_chat_completion.py b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_chat_completion.py index bd6182af7b92..efa702a43813 100644 --- a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_chat_completion.py +++ b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_chat_completion.py @@ -9,6 +9,7 @@ from semantic_kernel.connectors.ai.bedrock.bedrock_prompt_execution_settings import BedrockChatPromptExecutionSettings from semantic_kernel.connectors.ai.bedrock.services.bedrock_chat_completion import BedrockChatCompletion +from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider from semantic_kernel.connectors.ai.completion_usage import CompletionUsage from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent @@ -16,10 +17,7 @@ from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason -from semantic_kernel.exceptions.service_exceptions import ( - ServiceInitializationError, - ServiceInvalidResponseError, -) +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidResponseError from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient # region init @@ -33,6 +31,9 @@ def test_bedrock_chat_completion_init(mock_client, bedrock_unit_test_env) -> Non assert bedrock_chat_completion.ai_model_id == bedrock_unit_test_env["BEDROCK_CHAT_MODEL_ID"] assert bedrock_chat_completion.service_id == bedrock_unit_test_env["BEDROCK_CHAT_MODEL_ID"] + assert bedrock_chat_completion.bedrock_model_provider == BedrockModelProvider( + bedrock_unit_test_env["BEDROCK_MODEL_PROVIDER"] + ) assert bedrock_chat_completion.bedrock_client is not None assert bedrock_chat_completion.bedrock_runtime_client is not None @@ -93,6 +94,16 @@ def test_bedrock_chat_completion_init_custom_runtime_client(mock_client, bedrock assert isinstance(bedrock_chat_completion.bedrock_runtime_client, MockBedrockRuntimeClient) +@patch.object(boto3, "client", return_value=Mock()) +def test_bedrock_chat_completion_init_custom_bedrock_model_provider(mock_client, bedrock_unit_test_env) -> None: + """Test initialization of Amazon Bedrock Chat Completion service""" + bedrock_chat_completion = BedrockChatCompletion( + model_provider=BedrockModelProvider.AMAZON, + ) + + assert bedrock_chat_completion.bedrock_model_provider == BedrockModelProvider.AMAZON + + @pytest.mark.parametrize("exclude_list", [["BEDROCK_CHAT_MODEL_ID"]], indirect=True) def test_bedrock_chat_completion_client_init_with_empty_model_id(bedrock_unit_test_env) -> None: """Test initialization of Amazon Bedrock Chat Completion service with empty model id""" @@ -108,6 +119,14 @@ def test_bedrock_chat_completion_client_init_invalid_settings(bedrock_unit_test_ BedrockChatCompletion(model_id=123) # Model ID must be a string +def test_bedrock_chat_completion_client_init_invalid_model_provider(bedrock_unit_test_env) -> None: + """Test initialization of Amazon Bedrock Chat Completion service with invalid settings""" + with pytest.raises( + ServiceInitializationError, match="Failed to initialize the Amazon Bedrock Chat Completion Service." + ): + BedrockChatCompletion(model_provider="invalid_provider") + + @patch.object(boto3, "client", return_value=Mock()) def test_prompt_execution_settings_class(mock_client, bedrock_unit_test_env) -> None: """Test getting prompt execution settings class""" @@ -167,6 +186,36 @@ def test_prepare_settings_for_request(mock_client, model_id, chat_history) -> No assert all([parsed_settings["inferenceConfig"].values()]) +@pytest.mark.parametrize( + "model_id", + [ + "arn:aws:bedrock:us-east-1:972143716085:application-inference-profile/123456", + ], +) +@patch.object(boto3, "client", return_value=Mock()) +def test_prepare_settings_for_request_with_application_inference_profile(mock_client, model_id, chat_history) -> None: + """Test preparing settings for request""" + # Without a valid model provider, it should raise an error + bedrock_chat_completion = BedrockChatCompletion(model_id=model_id) + settings = BedrockChatPromptExecutionSettings() + with pytest.raises( + ValueError, + match=f"Model ID {model_id} does not contain a valid model provider name.", + ): + bedrock_chat_completion._prepare_settings_for_request(chat_history, settings) + + # With a valid model provider, it should not raise an error + bedrock_chat_completion = BedrockChatCompletion(model_id=model_id, model_provider=BedrockModelProvider.AMAZON) + parsed_settings = bedrock_chat_completion._prepare_settings_for_request(chat_history, settings) + + assert isinstance(parsed_settings, dict) + assert parsed_settings["modelId"] == bedrock_chat_completion.ai_model_id + assert parsed_settings["messages"] == bedrock_chat_completion._prepare_chat_history_for_request(chat_history) + assert parsed_settings["system"] == bedrock_chat_completion._prepare_system_messages_for_request(chat_history) + assert isinstance(parsed_settings["inferenceConfig"], dict) + assert all([parsed_settings["inferenceConfig"].values()]) + + # endregion diff --git a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_completion.py b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_completion.py index b5827803811d..3179ea5b2a95 100644 --- a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_completion.py +++ b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_completion.py @@ -10,6 +10,7 @@ from semantic_kernel.connectors.ai.bedrock.bedrock_prompt_execution_settings import BedrockTextPromptExecutionSettings from semantic_kernel.connectors.ai.bedrock.services.bedrock_text_completion import BedrockTextCompletion from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import ( + BedrockModelProvider, get_text_completion_request_body, ) from semantic_kernel.contents.streaming_text_content import StreamingTextContent @@ -28,6 +29,9 @@ def test_bedrock_text_completion_init(mock_client, bedrock_unit_test_env) -> Non assert bedrock_text_completion.ai_model_id == bedrock_unit_test_env["BEDROCK_TEXT_MODEL_ID"] assert bedrock_text_completion.service_id == bedrock_unit_test_env["BEDROCK_TEXT_MODEL_ID"] + assert bedrock_text_completion.bedrock_model_provider == BedrockModelProvider( + bedrock_unit_test_env["BEDROCK_MODEL_PROVIDER"] + ) assert bedrock_text_completion.bedrock_client is not None assert bedrock_text_completion.bedrock_runtime_client is not None @@ -88,6 +92,16 @@ def test_bedrock_text_completion_init_custom_runtime_client(mock_client, bedrock assert isinstance(bedrock_text_completion.bedrock_runtime_client, MockBedrockRuntimeClient) +@patch.object(boto3, "client", return_value=Mock()) +def test_bedrock_text_completion_init_custom_bedrock_model_provider(mock_client, bedrock_unit_test_env) -> None: + """Test initialization of Amazon Bedrock Text Completion service""" + bedrock_text_completion = BedrockTextCompletion( + model_provider=BedrockModelProvider.AMAZON, + ) + + assert bedrock_text_completion.bedrock_model_provider == BedrockModelProvider.AMAZON + + @pytest.mark.parametrize("exclude_list", [["BEDROCK_TEXT_MODEL_ID"]], indirect=True) def test_bedrock_text_completion_client_init_with_empty_model_id(bedrock_unit_test_env) -> None: """Test initialization of Amazon Bedrock Text Completion service with empty model id""" @@ -103,6 +117,14 @@ def test_bedrock_text_completion_client_init_invalid_settings(bedrock_unit_test_ BedrockTextCompletion(model_id=123) # Model ID must be a string +def test_bedrock_text_completion_client_init_invalid_model_provider(bedrock_unit_test_env) -> None: + """Test initialization of Amazon Bedrock Text Completion service with invalid settings""" + with pytest.raises( + ServiceInitializationError, match="Failed to initialize the Amazon Bedrock Text Completion Service." + ): + BedrockTextCompletion(model_provider="invalid_provider") + + @patch.object(boto3, "client", return_value=Mock()) def test_prompt_execution_settings_class(mock_client, bedrock_unit_test_env) -> None: """Test getting prompt execution settings class""" @@ -134,7 +156,7 @@ async def test_bedrock_text_completion( mock_bedrock_text_completion_response, output_text, ) -> None: - """Test Amazon Bedrock Chat Completion complete method""" + """Test Amazon Bedrock Text Completion complete method""" with patch.object( MockBedrockRuntimeClient, "invoke_model", return_value=mock_bedrock_text_completion_response ) as mock_model_invoke: @@ -164,6 +186,47 @@ async def test_bedrock_text_completion( assert response[0].text == output_text +@pytest.mark.parametrize( + # These are fake model ids with the supported prefixes + "model_id", + [ + "arn:aws:bedrock:us-east-1:972143716085:application-inference-profile/123456", + ], + indirect=True, +) +async def test_bedrock_text_completion_with_application_inference_profile( + model_id, + mock_bedrock_text_completion_response, + output_text, + model_provider, +) -> None: + """Test Amazon Bedrock Text Completion complete method""" + with patch.object( + MockBedrockRuntimeClient, + "invoke_model", + return_value=mock_bedrock_text_completion_response, + ) as mock_model_invoke: + # Setup + bedrock_text_completion = BedrockTextCompletion( + model_id=model_id, + runtime_client=MockBedrockRuntimeClient(), + client=MockBedrockClient(), + model_provider=model_provider, + ) + + # Act + settings = BedrockTextPromptExecutionSettings() + await bedrock_text_completion.get_text_contents("Hello!", settings=settings) + + # Assert + mock_model_invoke.assert_called_once_with( + body=json.dumps(get_text_completion_request_body(model_id, "Hello!", settings, model_provider)), + modelId=model_id, + accept="application/json", + contentType="application/json", + ) + + @pytest.mark.parametrize( # These are fake model ids with the supported prefixes "model_id", @@ -177,7 +240,7 @@ async def test_bedrock_streaming_text_completion( mock_bedrock_streaming_text_completion_response, output_text, ) -> None: - """Test Amazon Bedrock Chat Completion complete method""" + """Test Amazon Bedrock Text Completion complete method""" with patch.object( MockBedrockRuntimeClient, "invoke_model_with_response_stream", @@ -213,4 +276,49 @@ async def test_bedrock_streaming_text_completion( assert isinstance(response.inner_content, list) +@pytest.mark.parametrize( + # These are fake model ids with the supported prefixes + "model_id", + [ + "arn:aws:bedrock:us-east-1:972143716085:application-inference-profile/123456", + ], + indirect=True, +) +async def test_bedrock_streaming_text_completion_with_application_inference_profile( + model_id, + mock_bedrock_streaming_text_completion_response, + output_text, + model_provider, +) -> None: + """Test Amazon Bedrock Chat Completion complete method""" + with patch.object( + MockBedrockRuntimeClient, + "invoke_model_with_response_stream", + return_value=mock_bedrock_streaming_text_completion_response, + ) as mock_invoke_model_with_response_stream: + # Setup + bedrock_text_completion = BedrockTextCompletion( + model_id=model_id, + runtime_client=MockBedrockRuntimeClient(), + client=MockBedrockClient(), + model_provider=model_provider, + ) + + # Act + settings = BedrockTextPromptExecutionSettings() + chunks: list[StreamingTextContent] = [] + async for streaming_responses in bedrock_text_completion.get_streaming_text_contents( + "Hello!", settings=settings + ): + chunks.extend(streaming_responses) + + # Assert + mock_invoke_model_with_response_stream.assert_called_once_with( + body=json.dumps(get_text_completion_request_body(model_id, "Hello!", settings, model_provider)), + modelId=model_id, + accept="application/json", + contentType="application/json", + ) + + # endregion diff --git a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_embedding_generation.py b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_embedding_generation.py index 4180f2a02896..a591f458c204 100644 --- a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_embedding_generation.py +++ b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_embedding_generation.py @@ -9,10 +9,8 @@ BedrockEmbeddingPromptExecutionSettings, ) from semantic_kernel.connectors.ai.bedrock.services.bedrock_text_embedding import BedrockTextEmbedding -from semantic_kernel.exceptions.service_exceptions import ( - ServiceInitializationError, - ServiceInvalidResponseError, -) +from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidResponseError from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient # region init @@ -26,6 +24,9 @@ def test_bedrock_text_embedding_init(mock_client, bedrock_unit_test_env) -> None assert bedrock_text_embedding.ai_model_id == bedrock_unit_test_env["BEDROCK_EMBEDDING_MODEL_ID"] assert bedrock_text_embedding.service_id == bedrock_unit_test_env["BEDROCK_EMBEDDING_MODEL_ID"] + assert bedrock_text_embedding.bedrock_model_provider == BedrockModelProvider( + bedrock_unit_test_env["BEDROCK_MODEL_PROVIDER"] + ) assert bedrock_text_embedding.bedrock_client is not None assert bedrock_text_embedding.bedrock_runtime_client is not None @@ -86,6 +87,16 @@ def test_bedrock_text_embedding_init_custom_runtime_client(mock_client, bedrock_ assert isinstance(bedrock_text_embedding.bedrock_runtime_client, MockBedrockRuntimeClient) +@patch.object(boto3, "client", return_value=Mock()) +def test_bedrock_text_embedding_init_custom_bedrock_model_provider(mock_client, bedrock_unit_test_env) -> None: + """Test initialization of Amazon Bedrock Text Embedding service""" + bedrock_text_embedding = BedrockTextEmbedding( + model_provider=BedrockModelProvider.AMAZON, + ) + + assert bedrock_text_embedding.bedrock_model_provider == BedrockModelProvider.AMAZON + + @pytest.mark.parametrize("exclude_list", [["BEDROCK_EMBEDDING_MODEL_ID"]], indirect=True) def test_bedrock_text_embedding_client_init_with_empty_model_id(bedrock_unit_test_env) -> None: """Test initialization of Amazon Bedrock Text Embedding service with empty model id""" @@ -101,6 +112,14 @@ def test_bedrock_text_embedding_client_init_invalid_settings(bedrock_unit_test_e BedrockTextEmbedding(model_id=123) # Model ID must be a string +def test_bedrock_text_embedding_client_init_invalid_model_provider(bedrock_unit_test_env) -> None: + """Test initialization of Amazon Bedrock Text Embedding service with invalid settings""" + with pytest.raises( + ServiceInitializationError, match="Failed to initialize the Amazon Bedrock Text Embedding Service." + ): + BedrockTextEmbedding(model_provider="invalid_provider") + + @patch.object(boto3, "client", return_value=Mock()) def test_prompt_execution_settings_class(mock_client, bedrock_unit_test_env) -> None: """Test getting prompt execution settings class""" @@ -148,6 +167,47 @@ async def test_bedrock_text_embedding(model_id, mock_bedrock_text_embedding_resp assert len(response) == 2 +@pytest.mark.parametrize( + # These are fake model ids with the supported prefixes + "model_id", + [ + "arn:aws:bedrock:us-east-1:972143716085:application-inference-profile/123456", + ], + indirect=True, +) +async def test_bedrock_text_embedding_with_application_inference_profile( + model_id, + mock_bedrock_text_embedding_response, + model_provider, +) -> None: + """Test Bedrock text embedding generation""" + with patch.object( + MockBedrockRuntimeClient, "invoke_model", return_value=mock_bedrock_text_embedding_response + ) as mock_model_invoke: + # Setup + bedrock_text_embedding = BedrockTextEmbedding( + model_id=model_id, + runtime_client=MockBedrockRuntimeClient(), + client=MockBedrockClient(), + model_provider=BedrockModelProvider.AMAZON, + ) + + # Act + settings = BedrockEmbeddingPromptExecutionSettings() + response = await bedrock_text_embedding.generate_embeddings(["hello", "world"], settings) + + # Assert + mock_model_invoke.assert_called_with( + body=ANY, + modelId=model_id, + accept="application/json", + contentType="application/json", + ) + assert mock_model_invoke.call_count == 2 + + assert len(response) == 2 + + @pytest.mark.parametrize( # These are fake model ids with the supported prefixes "model_id", From ed76005a2d8ce302662c621d39367ae36a27a1b1 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 4 Aug 2025 14:17:03 -0700 Subject: [PATCH 2/3] Add readme --- python/semantic_kernel/connectors/ai/bedrock/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/semantic_kernel/connectors/ai/bedrock/README.md b/python/semantic_kernel/connectors/ai/bedrock/README.md index e2678be9028f..d850c6ddfe9f 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/README.md +++ b/python/semantic_kernel/connectors/ai/bedrock/README.md @@ -38,6 +38,14 @@ bedrock_chat_completion_service = BedrockChatCompletion(runtime_client=runtime_c To find model supports by AWS regions, refer to this [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html). +### Inference profiles + +you can create inference profiles in AWS Bedrock to monitor and optimize the performance of your foundation models. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles.html) for more information. + +when you are using an Application Inference Profile, you must specify the `BEDROCK_MODEL_PROVIDER` environment variable to the model provider you are using. For example, if you are using Amazon Titan, you must set `BEDROCK_MODEL_PROVIDER=amazon`. This is because an Application Inference Profile doesn't contain the model provider information, and the Bedrock connector needs to know which model provider to use so that it can create the correct request body to the Bedrock API. + +> An Application Inference Profile ARN is usually formatted as followed: `arn:aws:bedrock:::application-inference-profile/`. + ### Input & Output Modalities Foundational models in Bedrock support the multiple modalities, including text, image, and embedding. However, not all models support the same modalities. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for more information. From b03c55b1f0ba3dd566a305181649c057ae5ab467 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 4 Aug 2025 17:37:19 -0700 Subject: [PATCH 3/3] Update python/semantic_kernel/connectors/ai/bedrock/README.md Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- python/semantic_kernel/connectors/ai/bedrock/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/bedrock/README.md b/python/semantic_kernel/connectors/ai/bedrock/README.md index d850c6ddfe9f..d5109d294306 100644 --- a/python/semantic_kernel/connectors/ai/bedrock/README.md +++ b/python/semantic_kernel/connectors/ai/bedrock/README.md @@ -40,9 +40,9 @@ To find model supports by AWS regions, refer to this [AWS documentation](https:/ ### Inference profiles -you can create inference profiles in AWS Bedrock to monitor and optimize the performance of your foundation models. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles.html) for more information. +You can create inference profiles in AWS Bedrock to monitor and optimize the performance of your foundation models. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles.html) for more information. -when you are using an Application Inference Profile, you must specify the `BEDROCK_MODEL_PROVIDER` environment variable to the model provider you are using. For example, if you are using Amazon Titan, you must set `BEDROCK_MODEL_PROVIDER=amazon`. This is because an Application Inference Profile doesn't contain the model provider information, and the Bedrock connector needs to know which model provider to use so that it can create the correct request body to the Bedrock API. +When you are using an Application Inference Profile, you must specify the `BEDROCK_MODEL_PROVIDER` environment variable to the model provider you are using. For example, if you are using Amazon Titan, you must set `BEDROCK_MODEL_PROVIDER=amazon`. This is because an Application Inference Profile doesn't contain the model provider information, and the Bedrock connector needs to know which model provider to use so that it can create the correct request body to the Bedrock API. > An Application Inference Profile ARN is usually formatted as followed: `arn:aws:bedrock:::application-inference-profile/`.