From 597ec09f4980ae5b9986e764bf8483e3598b9475 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E0=AE=AE=E0=AE=A9=E0=AF=8B=E0=AE=9C=E0=AF=8D=E0=AE=95?=
 =?UTF-8?q?=E0=AF=81=E0=AE=AE=E0=AE=BE=E0=AE=B0=E0=AF=8D=20=E0=AE=AA?=
 =?UTF-8?q?=E0=AE=B4=E0=AE=A9=E0=AE=BF=E0=AE=9A=E0=AF=8D=E0=AE=9A=E0=AE=BE?=
 =?UTF-8?q?=E0=AE=AE=E0=AE=BF?=
Date: Fri, 14 Feb 2025 02:36:54 +0000
Subject: [PATCH] add unit tests

---
 .../unit_tests/chat_models/test_bedrock.py    | 40 +++++++++++++++++++
 .../chat_models/test_bedrock_converse.py      | 40 +++++++++++++++++++
 2 files changed, 80 insertions(+)

diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py
index 4e43ff04..a7bf60e2 100644
--- a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py
@@ -472,3 +473,43 @@ def test__get_provider(model_id, provider, expected_provider, expectation) -> No
     llm = ChatBedrock(model_id=model_id, provider=provider, region_name="us-west-2")
     with expectation:
         assert llm._get_provider() == expected_provider
+
+
+def test_chat_bedrock_different_regions() -> None:
+    """An explicitly passed region_name is stored on the model."""
+    regions = ["us-east-1", "us-west-2", "ap-south-2"]
+    for region in regions:
+        llm = ChatBedrock(
+            model_id="anthropic.claude-3-sonnet-20240229-v1:0", region_name=region
+        )
+        assert llm.region_name == region
+
+
+def test_chat_bedrock_environment_variable(monkeypatch) -> None:
+    """region_name is picked up from the AWS_REGION environment variable."""
+    regions = ["us-east-1", "us-west-2", "ap-south-2"]
+    for region in regions:
+        # setenv is undone automatically at teardown, so no state leaks.
+        monkeypatch.setenv("AWS_REGION", region)
+        llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
+        assert llm.region_name == region
+
+
+def test_chat_bedrock_scenarios() -> None:
+    """Constructor keyword arguments are reflected on the instance."""
+    scenarios = [
+        {"model_id": "anthropic.claude-3-sonnet-20240229-v1:0", "temperature": 0.5},
+        {"model_id": "anthropic.claude-3-sonnet-20240229-v1:0", "max_tokens": 50},
+        {
+            "model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
+            "temperature": 0.5,
+            "max_tokens": 50,
+        },
+    ]
+    for scenario in scenarios:
+        llm = ChatBedrock(region_name="us-west-2", **scenario)
+        assert llm.model_id == scenario["model_id"]
+        if "temperature" in scenario:
+            assert llm.temperature == scenario["temperature"]
+        if "max_tokens" in scenario:
+            assert llm.max_tokens == scenario["max_tokens"]
diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
index d61830dc..70fa51f8 100644
--- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
+++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
@@ -502,3 +503,43 @@ def test__extract_response_metadata() -> None:
     }
     response_metadata = _extract_response_metadata(response)
     assert response_metadata["metrics"]["latencyMs"] == [191]
+
+
+def test_chat_bedrock_converse_different_regions() -> None:
+    """An explicitly passed region_name is stored on the model."""
+    regions = ["us-east-1", "us-west-2", "ap-south-2"]
+    for region in regions:
+        llm = ChatBedrockConverse(
+            model="anthropic.claude-3-sonnet-20240229-v1:0", region_name=region
+        )
+        assert llm.region_name == region
+
+
+def test_chat_bedrock_converse_environment_variable(monkeypatch) -> None:
+    """region_name is picked up from the AWS_REGION environment variable."""
+    regions = ["us-east-1", "us-west-2", "ap-south-2"]
+    for region in regions:
+        # setenv is undone automatically at teardown, so no state leaks.
+        monkeypatch.setenv("AWS_REGION", region)
+        llm = ChatBedrockConverse(model="anthropic.claude-3-sonnet-20240229-v1:0")
+        assert llm.region_name == region
+
+
+def test_chat_bedrock_converse_scenarios() -> None:
+    """Constructor keyword arguments are reflected on the instance."""
+    scenarios = [
+        {"model": "anthropic.claude-3-sonnet-20240229-v1:0", "temperature": 0.5},
+        {"model": "anthropic.claude-3-sonnet-20240229-v1:0", "max_tokens": 50},
+        {
+            "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+            "temperature": 0.5,
+            "max_tokens": 50,
+        },
+    ]
+    for scenario in scenarios:
+        llm = ChatBedrockConverse(region_name="us-west-2", **scenario)
+        assert llm.model_id == scenario["model"]
+        if "temperature" in scenario:
+            assert llm.temperature == scenario["temperature"]
+        if "max_tokens" in scenario:
+            assert llm.max_tokens == scenario["max_tokens"]