Skip to content

Commit 921e3d0

Browse files
committed
fix: Load region from session when possible
Only use a default region if the session doesn't provide one or if the AWS_REGION environment variable is not set. Fixes #238
1 parent 6a1ccea commit 921e3d0

File tree

2 files changed

+67
-45
lines changed

2 files changed

+67
-45
lines changed

src/strands/models/bedrock.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
logger = logging.getLogger(__name__)
2626

2727
DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
28+
DEFAULT_BEDROCK_REGION = "us-west-2"
2829

2930
BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
3031
"Input is too long for requested model",
@@ -117,18 +118,7 @@ def __init__(
117118

118119
logger.debug("config=<%s> | initializing", self.config)
119120

120-
region_for_boto = region_name or os.getenv("AWS_REGION")
121-
if region_for_boto is None:
122-
region_for_boto = "us-west-2"
123-
logger.warning("defaulted to us-west-2 because no region was specified")
124-
logger.warning(
125-
"issue=<%s> | this behavior will change in an upcoming release",
126-
"https://github.com/strands-agents/sdk-python/issues/238",
127-
)
128-
129-
session = boto_session or boto3.Session(
130-
region_name=region_for_boto,
131-
)
121+
session = boto_session or boto3.Session()
132122

133123
# Add strands-agents to the request user agent
134124
if boto_client_config:
@@ -147,6 +137,7 @@ def __init__(
147137
self.client = session.client(
148138
service_name="bedrock-runtime",
149139
config=client_config,
140+
region_name=region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION,
150141
)
151142

152143
@override

tests/strands/models/test_bedrock.py

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
import unittest.mock
4+
from unittest.mock import ANY
45

56
import boto3
67
import pydantic
@@ -10,17 +11,30 @@
1011

1112
import strands
1213
from strands.models import BedrockModel
13-
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID
14+
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION
1415
from strands.types.exceptions import ModelThrottledException
1516

1617

1718
@pytest.fixture
18-
def bedrock_client():
19+
def session_cls():
20+
# Mock the creation of a Session so that we don't depend on environment variables or profiles
1921
with unittest.mock.patch.object(strands.models.bedrock.boto3, "Session") as mock_session_cls:
20-
mock_client = mock_session_cls.return_value.client.return_value
21-
mock_client.meta = unittest.mock.MagicMock()
22-
mock_client.meta.region_name = "us-west-2"
23-
yield mock_client
22+
mock_session_cls.return_value.region_name = None
23+
yield mock_session_cls
24+
25+
26+
@pytest.fixture
27+
def mock_client_method(session_cls):
28+
# the boto3.Session().client(...) method
29+
return session_cls.return_value.client
30+
31+
32+
@pytest.fixture
33+
def bedrock_client(session_cls):
34+
mock_client = session_cls.return_value.client.return_value
35+
mock_client.meta = unittest.mock.MagicMock()
36+
mock_client.meta.region_name = "us-west-2"
37+
yield mock_client
2438

2539

2640
@pytest.fixture
@@ -105,41 +119,58 @@ def test__init__default_model_id(bedrock_client):
105119
assert tru_model_id == exp_model_id
106120

107121

108-
def test__init__with_default_region(bedrock_client):
122+
def test__init__with_default_region(session_cls, mock_client_method):
109123
"""Test that BedrockModel uses the provided region."""
110-
_ = bedrock_client
111-
default_region = "us-west-2"
124+
BedrockModel()
112125

113-
with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
114-
with unittest.mock.patch("strands.models.bedrock.logger.warning") as mock_warning:
115-
_ = BedrockModel()
116-
mock_session_cls.assert_called_once_with(region_name=default_region)
117-
# Assert that warning logs are emitted
118-
mock_warning.assert_any_call("defaulted to us-west-2 because no region was specified")
119-
mock_warning.assert_any_call(
120-
"issue=<%s> | this behavior will change in an upcoming release",
121-
"https://github.com/strands-agents/sdk-python/issues/238",
122-
)
123-
124-
125-
def test__init__with_custom_region(bedrock_client):
126+
session_cls.return_value.client.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY)
127+
128+
129+
def test__init__with_session_region(session_cls, mock_client_method):
126130
"""Test that BedrockModel uses the provided region."""
127-
_ = bedrock_client
128-
custom_region = "us-east-1"
131+
session_cls.return_value.region_name = "eu-blah-1"
129132

130-
with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
131-
_ = BedrockModel(region_name=custom_region)
132-
mock_session_cls.assert_called_once_with(region_name=custom_region)
133+
BedrockModel()
133134

135+
mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY)
134136

135-
def test__init__with_environment_variable_region(bedrock_client):
137+
138+
def test__init__with_custom_region(mock_client_method):
136139
"""Test that BedrockModel uses the provided region."""
137-
_ = bedrock_client
138-
os.environ["AWS_REGION"] = "eu-west-1"
140+
custom_region = "us-east-1"
141+
BedrockModel(region_name=custom_region)
142+
mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY)
139143

140-
with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
141-
_ = BedrockModel()
142-
mock_session_cls.assert_called_once_with(region_name="eu-west-1")
144+
145+
def test__init__with_default_environment_variable_region(mock_client_method):
146+
"""Test that BedrockModel uses the AWS_REGION since we code that in."""
147+
with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}):
148+
BedrockModel()
149+
150+
mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY)
151+
152+
153+
def test__init__region_precedence(mock_client_method, session_cls):
154+
"""Test that BedrockModel uses the correct ordering of precedence when determining region."""
155+
with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "us-environment-1"}):
156+
session_cls.return_value.region_name = "us-session-1"
157+
158+
# specifying a region always wins out
159+
BedrockModel(region_name="us-specified-1")
160+
mock_client_method.assert_called_with(region_name="us-specified-1", config=ANY, service_name=ANY)
161+
162+
# other-wise uses the session's
163+
BedrockModel()
164+
mock_client_method.assert_called_with(region_name="us-session-1", config=ANY, service_name=ANY)
165+
166+
# environment variable next
167+
session_cls.return_value.region_name = None
168+
BedrockModel()
169+
mock_client_method.assert_called_with(region_name="us-environment-1", config=ANY, service_name=ANY)
170+
171+
# Finally default
172+
BedrockModel()
173+
mock_client_method.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY)
143174

144175

145176
def test__init__with_region_and_session_raises_value_error():

0 commit comments

Comments
 (0)