Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/providers/inference/remote_bedrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |

## Sample Configuration
Expand Down
4 changes: 2 additions & 2 deletions docs/source/providers/safety/remote_bedrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |

## Sample Configuration
Expand Down
22 changes: 12 additions & 10 deletions llama_stack/providers/utils/bedrock/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,55 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

from pydantic import BaseModel, Field


class BedrockBaseConfig(BaseModel):
aws_access_key_id: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
)
aws_session_token: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
)
region_name: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION",
)
profile_name: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_PROFILE"),
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
)
total_max_attempts: int | None = Field(
default=None,
default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
)
retry_mode: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE",
)
connect_timeout: float | None = Field(
default=60,
default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.",
)
read_timeout: float | None = Field(
default=60,
default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)
session_ttl: int | None = Field(
default=3600,
default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
)

Expand Down
63 changes: 63 additions & 0 deletions tests/unit/providers/inference/bedrock/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from unittest.mock import patch

from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig


class TestBedrockBaseConfig:
def test_defaults_work_without_env_vars(self):
with patch.dict(os.environ, {}, clear=True):
config = BedrockBaseConfig()

# Basic creds should be None
assert config.aws_access_key_id is None
assert config.aws_secret_access_key is None
assert config.region_name is None

# Timeouts get defaults
assert config.connect_timeout == 60.0
assert config.read_timeout == 60.0
assert config.session_ttl == 3600

def test_env_vars_get_picked_up(self):
env_vars = {
"AWS_ACCESS_KEY_ID": "AKIATEST123",
"AWS_SECRET_ACCESS_KEY": "secret123",
"AWS_DEFAULT_REGION": "us-west-2",
"AWS_MAX_ATTEMPTS": "5",
"AWS_RETRY_MODE": "adaptive",
"AWS_CONNECT_TIMEOUT": "30",
}

with patch.dict(os.environ, env_vars, clear=True):
config = BedrockBaseConfig()

assert config.aws_access_key_id == "AKIATEST123"
assert config.aws_secret_access_key == "secret123"
assert config.region_name == "us-west-2"
assert config.total_max_attempts == 5
assert config.retry_mode == "adaptive"
assert config.connect_timeout == 30.0

def test_partial_env_setup(self):
# Just setting one timeout var
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
config = BedrockBaseConfig()

assert config.connect_timeout == 120.0
assert config.read_timeout == 60.0 # still default
assert config.aws_access_key_id is None

def test_bad_max_attempts_breaks(self):
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
try:
BedrockBaseConfig()
raise AssertionError("Should have failed on bad int conversion")
except ValueError:
pass # expected
Loading