Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def get_bedrock_chat_completion_service_and_request_settings() -> tuple[
"""
from semantic_kernel.connectors.ai.bedrock import BedrockChatCompletion, BedrockChatPromptExecutionSettings

chat_service = BedrockChatCompletion(service_id=service_id, model_id="anthropic.claude-3-sonnet-20240229-v1:0")
chat_service = BedrockChatCompletion(service_id=service_id)
request_settings = BedrockChatPromptExecutionSettings(
# For model specific settings, specify them in the extension_data dictionary.
# For example, for Cohere Command specific settings, refer to:
Expand Down
8 changes: 8 additions & 0 deletions python/semantic_kernel/connectors/ai/bedrock/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ bedrock_chat_completion_service = BedrockChatCompletion(runtime_client=runtime_c

To find out which models are supported in each AWS region, refer to this [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html).

### Inference profiles

You can create inference profiles in AWS Bedrock to monitor and optimize the performance of your foundation models. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles.html) for more information.

When you are using an Application Inference Profile, you must set the `BEDROCK_MODEL_PROVIDER` environment variable to the model provider you are using. For example, if you are using Amazon Titan, you must set `BEDROCK_MODEL_PROVIDER=amazon`. This is because an Application Inference Profile doesn't contain the model provider information, and the Bedrock connector needs to know which model provider to use so that it can create the correct request body for the Bedrock API.

> An Application Inference Profile ARN is usually formatted as follows: `arn:aws:bedrock:<region>:<account-id>:application-inference-profile/<profile-id>`.

### Input & Output Modalities

Foundation models in Bedrock support multiple modalities, including text, image, and embedding. However, not all models support the same modalities. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for more information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import ClassVar

from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider
from semantic_kernel.kernel_pydantic import KernelBaseSettings
from semantic_kernel.utils.feature_stage_decorator import experimental

Expand All @@ -25,10 +26,16 @@ class BedrockSettings(KernelBaseSettings):
(Env var BEDROCK_TEXT_MODEL_ID)
- embedding_model_id: str | None - The Amazon Bedrock embedding model ID to use.
(Env var BEDROCK_EMBEDDING_MODEL_ID)
- model_provider: BedrockModelProvider | None - The Bedrock model provider to use.
If not provided, the model provider will be extracted from the model ID.
When using an Application Inference Profile where the model provider is not part
of the model ID, this setting must be provided.
(Env var BEDROCK_MODEL_PROVIDER)
"""

env_prefix: ClassVar[str] = "BEDROCK_"

chat_model_id: str | None = None
text_model_id: str | None = None
embedding_model_id: str | None = None
model_provider: BedrockModelProvider | None = None
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import boto3

from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider
from semantic_kernel.kernel_pydantic import KernelBaseModel


Expand All @@ -19,22 +20,30 @@ class BedrockBase(KernelBaseModel, ABC):
# Client: Use for model management
bedrock_client: Any

bedrock_model_provider: BedrockModelProvider | None = None

def __init__(
self,
*,
runtime_client: Any | None = None,
client: Any | None = None,
bedrock_model_provider: BedrockModelProvider | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Amazon Bedrock Base Class.

Args:
runtime_client: The Amazon Bedrock runtime client to use.
client: The Amazon Bedrock client to use.
bedrock_model_provider: The Bedrock model provider to use.
If not provided, the model provider will be extracted from the model ID.
When using an Application Inference Profile where the model provider is not part
of the model ID, this setting must be provided.
**kwargs: Additional keyword arguments.
"""
super().__init__(
bedrock_runtime_client=runtime_client or boto3.client("bedrock-runtime"),
bedrock_client=client or boto3.client("bedrock"),
bedrock_model_provider=bedrock_model_provider,
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings
from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import (
BedrockModelProvider,
get_chat_completion_additional_model_request_fields,
)
from semantic_kernel.connectors.ai.bedrock.services.model_provider.utils import (
Expand All @@ -36,10 +37,7 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidResponseError,
)
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidResponseError
from semantic_kernel.utils.async_utils import run_in_executor
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
trace_chat_completion,
Expand All @@ -60,6 +58,7 @@ class BedrockChatCompletion(BedrockBase, ChatCompletionClientBase):
def __init__(
self,
model_id: str | None = None,
model_provider: BedrockModelProvider | None = None,
service_id: str | None = None,
runtime_client: Any | None = None,
client: Any | None = None,
Expand All @@ -70,6 +69,7 @@ def __init__(

Args:
model_id: The Amazon Bedrock chat model ID to use.
model_provider: The Bedrock model provider to use.
service_id: The Service ID for the completion service.
runtime_client: The Amazon Bedrock runtime client to use.
client: The Amazon Bedrock client to use.
Expand All @@ -79,6 +79,7 @@ def __init__(
try:
bedrock_settings = BedrockSettings(
chat_model_id=model_id,
model_provider=model_provider,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
Expand All @@ -93,6 +94,7 @@ def __init__(
service_id=service_id or bedrock_settings.chat_model_id,
runtime_client=runtime_client,
client=client,
bedrock_model_provider=bedrock_settings.model_provider,
)

# region Overriding base class methods
Expand Down Expand Up @@ -212,7 +214,7 @@ def _prepare_settings_for_request(
"stopSequences": settings.stop,
}),
"additionalModelRequestFields": get_chat_completion_additional_model_request_fields(
self.ai_model_id, settings
self.ai_model_id, settings, model_provider=self.bedrock_model_provider
),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings
from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import (
BedrockModelProvider,
get_text_completion_request_body,
parse_streaming_text_completion_response,
parse_text_completion_response,
Expand All @@ -41,6 +42,7 @@ class BedrockTextCompletion(BedrockBase, TextCompletionClientBase):
def __init__(
self,
model_id: str | None = None,
model_provider: BedrockModelProvider | None = None,
service_id: str | None = None,
runtime_client: Any | None = None,
client: Any | None = None,
Expand All @@ -51,6 +53,7 @@ def __init__(

Args:
model_id: The Amazon Bedrock text model ID to use.
model_provider: The Bedrock model provider to use.
service_id: The Service ID for the text completion service.
runtime_client: The Amazon Bedrock runtime client to use.
client: The Amazon Bedrock client to use.
Expand All @@ -60,6 +63,7 @@ def __init__(
try:
bedrock_settings = BedrockSettings(
text_model_id=model_id,
model_provider=model_provider,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
Expand All @@ -74,6 +78,7 @@ def __init__(
service_id=service_id or bedrock_settings.text_model_id,
runtime_client=runtime_client,
client=client,
bedrock_model_provider=bedrock_settings.model_provider,
)

# region Overriding base class methods
Expand All @@ -94,11 +99,17 @@ async def _inner_get_text_contents(
settings = self.get_prompt_execution_settings_from_settings(settings)
assert isinstance(settings, BedrockTextPromptExecutionSettings) # nosec

request_body = get_text_completion_request_body(self.ai_model_id, prompt, settings)
request_body = get_text_completion_request_body(
self.ai_model_id,
prompt,
settings,
model_provider=self.bedrock_model_provider,
)
response_body = await self._async_invoke_model(request_body)
return parse_text_completion_response(
self.ai_model_id,
json.loads(response_body.get("body").read()),
model_provider=self.bedrock_model_provider,
)

@override
Expand All @@ -112,14 +123,20 @@ async def _inner_get_streaming_text_contents(
settings = self.get_prompt_execution_settings_from_settings(settings)
assert isinstance(settings, BedrockTextPromptExecutionSettings) # nosec

request_body = get_text_completion_request_body(self.ai_model_id, prompt, settings)
request_body = get_text_completion_request_body(
self.ai_model_id,
prompt,
settings,
model_provider=self.bedrock_model_provider,
)
response_stream = await self._async_invoke_model_stream(request_body)
for event in response_stream.get("body"):
chunk = event.get("chunk")
yield [
parse_streaming_text_completion_response(
self.ai_model_id,
json.loads(chunk.get("bytes").decode()),
model_provider=self.bedrock_model_provider,
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings
from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import (
BedrockModelProvider,
get_text_embedding_request_body,
parse_text_embedding_response,
)
Expand All @@ -38,6 +39,7 @@ class BedrockTextEmbedding(BedrockBase, EmbeddingGeneratorBase):
def __init__(
self,
model_id: str | None = None,
model_provider: BedrockModelProvider | None = None,
service_id: str | None = None,
runtime_client: Any | None = None,
client: Any | None = None,
Expand All @@ -48,6 +50,7 @@ def __init__(

Args:
model_id: The Amazon Bedrock text embedding model ID to use.
model_provider: The Bedrock model provider to use.
service_id: The Service ID for the text embedding service.
runtime_client: The Amazon Bedrock runtime client to use.
client: The Amazon Bedrock client to use.
Expand All @@ -57,6 +60,7 @@ def __init__(
try:
bedrock_settings = BedrockSettings(
embedding_model_id=model_id,
model_provider=model_provider,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
Expand All @@ -71,6 +75,7 @@ def __init__(
service_id=service_id or bedrock_settings.embedding_model_id,
runtime_client=runtime_client,
client=client,
bedrock_model_provider=bedrock_settings.model_provider,
)

@override
Expand All @@ -87,12 +92,25 @@ async def generate_embeddings(
assert isinstance(settings, BedrockEmbeddingPromptExecutionSettings) # nosec

results = await asyncio.gather(*[
self._async_invoke_model(get_text_embedding_request_body(self.ai_model_id, text, settings))
self._async_invoke_model(
get_text_embedding_request_body(
self.ai_model_id,
text,
settings,
model_provider=self.bedrock_model_provider,
)
)
for text in texts
])

return array([
array(parse_text_embedding_response(self.ai_model_id, json.loads(result.get("body").read())))
array(
parse_text_embedding_response(
self.ai_model_id,
json.loads(result.get("body").read()),
model_provider=self.bedrock_model_provider,
)
)
for result in results
])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,34 @@ def to_model_provider(cls, model_id: str) -> "BedrockModelProvider":
}


def get_text_completion_request_body(model_id: str, prompt: str, settings: BedrockTextPromptExecutionSettings) -> dict:
def get_text_completion_request_body(
model_id: str,
prompt: str,
settings: BedrockTextPromptExecutionSettings,
model_provider: BedrockModelProvider | None = None,
) -> dict:
"""Get the request body for text completion for Amazon Bedrock models."""
model_provider = BedrockModelProvider.to_model_provider(model_id)
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
return TEXT_COMPLETION_REQUEST_BODY_MAPPING[model_provider](prompt, settings)


def parse_text_completion_response(model_id: str, response: dict) -> list[TextContent]:
def parse_text_completion_response(
model_id: str,
response: dict,
model_provider: BedrockModelProvider | None = None,
) -> list[TextContent]:
"""Parse the response from text completion for Amazon Bedrock models."""
model_provider = BedrockModelProvider.to_model_provider(model_id)
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
return TEXT_COMPLETION_RESPONSE_MAPPING[model_provider](response, model_id)


def parse_streaming_text_completion_response(model_id: str, chunk: dict) -> StreamingTextContent:
def parse_streaming_text_completion_response(
model_id: str,
chunk: dict,
model_provider: BedrockModelProvider | None = None,
) -> StreamingTextContent:
"""Parse the response from streaming text completion for Amazon Bedrock models."""
model_provider = BedrockModelProvider.to_model_provider(model_id)
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
return STREAMING_TEXT_COMPLETION_RESPONSE_MAPPING[model_provider](chunk, model_id)


Expand All @@ -109,10 +122,12 @@ def parse_streaming_text_completion_response(model_id: str, chunk: dict) -> Stre


def get_chat_completion_additional_model_request_fields(
model_id: str, settings: BedrockChatPromptExecutionSettings
model_id: str,
settings: BedrockChatPromptExecutionSettings,
model_provider: BedrockModelProvider | None = None,
) -> dict[str, Any] | None:
"""Get the additional model request fields for chat completion for Amazon Bedrock models."""
model_provider = BedrockModelProvider.to_model_provider(model_id)
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
return CHAT_COMPLETION_ADDITIONAL_MODEL_REQUEST_FIELDS_MAPPING[model_provider](settings)


Expand All @@ -134,16 +149,23 @@ def get_chat_completion_additional_model_request_fields(


def get_text_embedding_request_body(
model_id: str, text: str, settings: BedrockEmbeddingPromptExecutionSettings
model_id: str,
text: str,
settings: BedrockEmbeddingPromptExecutionSettings,
model_provider: BedrockModelProvider | None = None,
) -> dict:
"""Get the request body for text embedding for Amazon Bedrock models."""
model_provider = BedrockModelProvider.to_model_provider(model_id)
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
return TEXT_EMBEDDING_REQUEST_BODY_MAPPING[model_provider](text, settings)


def parse_text_embedding_response(model_id: str, response: dict) -> list[float]:
def parse_text_embedding_response(
model_id: str,
response: dict,
model_provider: BedrockModelProvider | None = None,
) -> list[float]:
"""Parse the response from text embedding for Amazon Bedrock models."""
model_provider = BedrockModelProvider.to_model_provider(model_id)
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
return TEXT_EMBEDDING_RESPONSE_MAPPING[model_provider](response)


Expand Down
Loading
Loading