Skip to content

Commit 0b00c68

Browse files
authored
fix: use lambda pattern for bedrock config env vars (#3307)
# What does this PR do? Improved bedrock provider config to read from environment variables like AWS_ACCESS_KEY_ID. Updated all fields to use default_factory with lambda patterns like the nvidia provider does. Now the environment variables work as documented. Closes #3305 ## Test Plan Ran the new bedrock config tests: ```bash python -m pytest tests/unit/providers/inference/bedrock/test_config.py -v Verified existing provider tests still work: python -m pytest tests/unit/providers/test_configs.py -v
1 parent 3a7ac42 commit 0b00c68

File tree

4 files changed

+79
-14
lines changed

4 files changed

+79
-14
lines changed

docs/source/providers/inference/remote_bedrock.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
1515
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
1616
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
1717
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
18-
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
19-
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
18+
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
19+
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
2020
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
2121

2222
## Sample Configuration

docs/source/providers/safety/remote_bedrock.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
1515
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
1616
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
1717
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
18-
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
19-
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
18+
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
19+
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
2020
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
2121

2222
## Sample Configuration

llama_stack/providers/utils/bedrock/config.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,53 +4,55 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
import os
8+
79
from pydantic import BaseModel, Field
810

911

1012
class BedrockBaseConfig(BaseModel):
1113
aws_access_key_id: str | None = Field(
12-
default=None,
14+
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
1315
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
1416
)
1517
aws_secret_access_key: str | None = Field(
16-
default=None,
18+
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
1719
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
1820
)
1921
aws_session_token: str | None = Field(
20-
default=None,
22+
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
2123
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
2224
)
2325
region_name: str | None = Field(
24-
default=None,
26+
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
2527
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
2628
"Default use environment variable: AWS_DEFAULT_REGION",
2729
)
2830
profile_name: str | None = Field(
29-
default=None,
31+
default_factory=lambda: os.getenv("AWS_PROFILE"),
3032
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
3133
)
3234
total_max_attempts: int | None = Field(
33-
default=None,
35+
default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
3436
description="An integer representing the maximum number of attempts that will be made for a single request, "
3537
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
3638
)
3739
retry_mode: str | None = Field(
38-
default=None,
40+
default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
3941
description="A string representing the type of retries Boto3 will perform."
4042
"Default use environment variable: AWS_RETRY_MODE",
4143
)
4244
connect_timeout: float | None = Field(
43-
default=60,
45+
default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
4446
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
4547
"The default is 60 seconds.",
4648
)
4749
read_timeout: float | None = Field(
48-
default=60,
50+
default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
4951
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
5052
"The default is 60 seconds.",
5153
)
5254
session_ttl: int | None = Field(
53-
default=3600,
55+
default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
5456
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
5557
)
5658

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import os
8+
from unittest.mock import patch
9+
10+
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
11+
12+
13+
class TestBedrockBaseConfig:
14+
def test_defaults_work_without_env_vars(self):
15+
with patch.dict(os.environ, {}, clear=True):
16+
config = BedrockBaseConfig()
17+
18+
# Basic creds should be None
19+
assert config.aws_access_key_id is None
20+
assert config.aws_secret_access_key is None
21+
assert config.region_name is None
22+
23+
# Timeouts get defaults
24+
assert config.connect_timeout == 60.0
25+
assert config.read_timeout == 60.0
26+
assert config.session_ttl == 3600
27+
28+
def test_env_vars_get_picked_up(self):
29+
env_vars = {
30+
"AWS_ACCESS_KEY_ID": "AKIATEST123",
31+
"AWS_SECRET_ACCESS_KEY": "secret123",
32+
"AWS_DEFAULT_REGION": "us-west-2",
33+
"AWS_MAX_ATTEMPTS": "5",
34+
"AWS_RETRY_MODE": "adaptive",
35+
"AWS_CONNECT_TIMEOUT": "30",
36+
}
37+
38+
with patch.dict(os.environ, env_vars, clear=True):
39+
config = BedrockBaseConfig()
40+
41+
assert config.aws_access_key_id == "AKIATEST123"
42+
assert config.aws_secret_access_key == "secret123"
43+
assert config.region_name == "us-west-2"
44+
assert config.total_max_attempts == 5
45+
assert config.retry_mode == "adaptive"
46+
assert config.connect_timeout == 30.0
47+
48+
def test_partial_env_setup(self):
49+
# Just setting one timeout var
50+
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
51+
config = BedrockBaseConfig()
52+
53+
assert config.connect_timeout == 120.0
54+
assert config.read_timeout == 60.0 # still default
55+
assert config.aws_access_key_id is None
56+
57+
def test_bad_max_attempts_breaks(self):
58+
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
59+
try:
60+
BedrockBaseConfig()
61+
raise AssertionError("Should have failed on bad int conversion")
62+
except ValueError:
63+
pass # expected

0 commit comments

Comments
 (0)