|
1 | 1 | import os
|
2 | 2 | import sys
|
3 | 3 | import unittest.mock
|
| 4 | +from unittest.mock import ANY |
4 | 5 |
|
5 | 6 | import boto3
|
6 | 7 | import pydantic
|
|
10 | 11 |
|
11 | 12 | import strands
|
12 | 13 | from strands.models import BedrockModel
|
13 |
| -from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID |
| 14 | +from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION |
14 | 15 | from strands.types.exceptions import ModelThrottledException
|
15 | 16 |
|
16 | 17 |
|
17 | 18 | @pytest.fixture
|
18 |
| -def bedrock_client(): |
| 19 | +def session_cls(): |
| 20 | + # Mock the creation of a Session so that we don't depend on environment variables or profiles |
19 | 21 | with unittest.mock.patch.object(strands.models.bedrock.boto3, "Session") as mock_session_cls:
|
20 |
| - mock_client = mock_session_cls.return_value.client.return_value |
21 |
| - mock_client.meta = unittest.mock.MagicMock() |
22 |
| - mock_client.meta.region_name = "us-west-2" |
23 |
| - yield mock_client |
| 22 | + mock_session_cls.return_value.region_name = None |
| 23 | + yield mock_session_cls |
| 24 | + |
| 25 | + |
| 26 | +@pytest.fixture |
| 27 | +def mock_client_method(session_cls): |
| 28 | + # the boto3.Session().client(...) method |
| 29 | + return session_cls.return_value.client |
| 30 | + |
| 31 | + |
| 32 | +@pytest.fixture |
| 33 | +def bedrock_client(session_cls): |
| 34 | + mock_client = session_cls.return_value.client.return_value |
| 35 | + mock_client.meta = unittest.mock.MagicMock() |
| 36 | + mock_client.meta.region_name = "us-west-2" |
| 37 | + yield mock_client |
24 | 38 |
|
25 | 39 |
|
26 | 40 | @pytest.fixture
|
@@ -105,41 +119,58 @@ def test__init__default_model_id(bedrock_client):
|
105 | 119 | assert tru_model_id == exp_model_id
|
106 | 120 |
|
107 | 121 |
|
108 |
| -def test__init__with_default_region(bedrock_client): |
| 122 | +def test__init__with_default_region(session_cls, mock_client_method): |
109 | 123 | """Test that BedrockModel uses the provided region."""
|
110 |
| - _ = bedrock_client |
111 |
| - default_region = "us-west-2" |
| 124 | + BedrockModel() |
112 | 125 |
|
113 |
| - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: |
114 |
| - with unittest.mock.patch("strands.models.bedrock.logger.warning") as mock_warning: |
115 |
| - _ = BedrockModel() |
116 |
| - mock_session_cls.assert_called_once_with(region_name=default_region) |
117 |
| - # Assert that warning logs are emitted |
118 |
| - mock_warning.assert_any_call("defaulted to us-west-2 because no region was specified") |
119 |
| - mock_warning.assert_any_call( |
120 |
| - "issue=<%s> | this behavior will change in an upcoming release", |
121 |
| - "https://github.com/strands-agents/sdk-python/issues/238", |
122 |
| - ) |
123 |
| - |
124 |
| - |
125 |
| -def test__init__with_custom_region(bedrock_client): |
| 126 | + session_cls.return_value.client.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY) |
| 127 | + |
| 128 | + |
| 129 | +def test__init__with_session_region(session_cls, mock_client_method): |
126 | 130 | """Test that BedrockModel uses the provided region."""
|
127 |
| - _ = bedrock_client |
128 |
| - custom_region = "us-east-1" |
| 131 | + session_cls.return_value.region_name = "eu-blah-1" |
129 | 132 |
|
130 |
| - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: |
131 |
| - _ = BedrockModel(region_name=custom_region) |
132 |
| - mock_session_cls.assert_called_once_with(region_name=custom_region) |
| 133 | + BedrockModel() |
133 | 134 |
|
| 135 | + mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY) |
134 | 136 |
|
135 |
| -def test__init__with_environment_variable_region(bedrock_client): |
| 137 | + |
| 138 | +def test__init__with_custom_region(mock_client_method): |
136 | 139 | """Test that BedrockModel uses the provided region."""
|
137 |
| - _ = bedrock_client |
138 |
| - os.environ["AWS_REGION"] = "eu-west-1" |
| 140 | + custom_region = "us-east-1" |
| 141 | + BedrockModel(region_name=custom_region) |
| 142 | + mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY) |
139 | 143 |
|
140 |
| - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: |
141 |
| - _ = BedrockModel() |
142 |
| - mock_session_cls.assert_called_once_with(region_name="eu-west-1") |
| 144 | + |
| 145 | +def test__init__with_default_environment_variable_region(mock_client_method): |
| 146 | + """Test that BedrockModel uses the AWS_REGION since we code that in.""" |
| 147 | + with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}): |
| 148 | + BedrockModel() |
| 149 | + |
| 150 | + mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY) |
| 151 | + |
| 152 | + |
| 153 | +def test__init__region_precedence(mock_client_method, session_cls): |
| 154 | + """Test that BedrockModel uses the correct ordering of precedence when determining region.""" |
| 155 | + with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "us-environment-1"}): |
| 156 | + session_cls.return_value.region_name = "us-session-1" |
| 157 | + |
| 158 | + # specifying a region always wins out |
| 159 | + BedrockModel(region_name="us-specified-1") |
| 160 | + mock_client_method.assert_called_with(region_name="us-specified-1", config=ANY, service_name=ANY) |
| 161 | + |
| 162 | + # other-wise uses the session's |
| 163 | + BedrockModel() |
| 164 | + mock_client_method.assert_called_with(region_name="us-session-1", config=ANY, service_name=ANY) |
| 165 | + |
| 166 | + # environment variable next |
| 167 | + session_cls.return_value.region_name = None |
| 168 | + BedrockModel() |
| 169 | + mock_client_method.assert_called_with(region_name="us-environment-1", config=ANY, service_name=ANY) |
| 170 | + |
| 171 | + # Finally default |
| 172 | + BedrockModel() |
| 173 | + mock_client_method.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY) |
143 | 174 |
|
144 | 175 |
|
145 | 176 | def test__init__with_region_and_session_raises_value_error():
|
|
0 commit comments