Skip to content

Added AWS Bedrock support as a provider #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ ALIBABA_API_KEY=
MOONSHOT_ENDPOINT=https://api.moonshot.cn/v1
MOONSHOT_API_KEY=

# AWS only needs the region and the default aws credentials. TODO: Add support for AWS profiles
AWS_BEDROCK_REGION=us-west-2
BEDROCK_API_KEY=

# Set to false to disable anonymized telemetry
ANONYMIZED_TELEMETRY=true

Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ services:
- CHROME_DEBUGGING_HOST=localhost
volumes:
- /tmp/.X11-unix:/tmp/.X11-unix
- ~/.aws/credentials:/root/.aws/credentials
restart: unless-stopped
shm_size: '2gb'
cap_add:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pyperclip==1.9.0
gradio==5.10.0
json-repair
langchain-mistralai==0.2.4
langchain-aws
14 changes: 13 additions & 1 deletion src/agent/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,13 @@ def __init__(
tool_calling_method: Optional[str] = 'auto',
page_extraction_llm: Optional[BaseChatModel] = None,
planner_llm: Optional[BaseChatModel] = None,
planner_interval: int = 1, # Run planner every N steps
planner_interval: int = 1, # Run planner every N steps,
placeholders: Optional[str] = None,
):

# make placeholders available to the class
self.placeholders = placeholders

# Load sensitive data from environment variables
env_sensitive_data = {}
for key, value in os.environ.items():
Expand Down Expand Up @@ -239,6 +243,14 @@ async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutpu

ai_content = ai_content.replace("```json", "").replace("```", "")
ai_content = repair_json(ai_content)


# Replace placeholders in ai_content with values from self.placeholders
for key, value in self.placeholders.items():
print(key, value)
ai_content = ai_content.replace(key, value)


parsed_json = json.loads(ai_content)
parsed: AgentOutput = self.AgentOutput(**parsed_json)

Expand Down
377 changes: 0 additions & 377 deletions src/utils/deep_research.py

This file was deleted.

24 changes: 21 additions & 3 deletions src/utils/default_config_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,26 @@ def default_config():
"max_actions_per_step": 10,
"use_vision": True,
"tool_calling_method": "auto",
"llm_provider": "openai",
"llm_model_name": "gpt-4o",
"llm_provider": "bedrock",
"llm_model_name": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"prerequisite": """

import boto3

session = boto3.Session(region_name="us-west-2")
sagemaker_client = session.client("sagemaker")

response = sagemaker_client.create_presigned_domain_url(
DomainId="d-8aldpksok8tq",
UserProfileName="arkaprav-ssh-test"
)

PLACEHOLDERS={}
PLACEHOLDERS["PLACEHOLDER_URL"] = response["AuthorizedUrl"]



""",
"llm_num_ctx": 32000,
"llm_temperature": 1.0,
"llm_base_url": "",
Expand All @@ -28,7 +46,7 @@ def default_config():
"save_recording_path": "./tmp/record_videos",
"save_trace_path": "./tmp/traces",
"save_agent_history_path": "./tmp/agent_history",
"task": "go to google.com and type 'OpenAI' click search and give me the first url",
"task": "open PLACEHOLDER_URL and open space names test1234 else create it",
}


Expand Down
21 changes: 20 additions & 1 deletion src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from pathlib import Path
from typing import Dict, Optional
import requests
import boto3

from langchain_anthropic import ChatAnthropic
from langchain_mistralai import ChatMistralAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_aws import ChatBedrock
import gradio as gr

from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
Expand All @@ -21,7 +23,8 @@
"deepseek": "DeepSeek",
"google": "Google",
"alibaba": "Alibaba",
"moonshot": "MoonShot"
"moonshot": "MoonShot",
"bedrock": "AWS Bedrock"
}

def get_llm_model(provider: str, **kwargs):
Expand Down Expand Up @@ -158,6 +161,21 @@ def get_llm_model(provider: str, **kwargs):
base_url=os.getenv("MOONSHOT_ENDPOINT"),
api_key=os.getenv("MOONSHOT_API_KEY"),
)
elif provider == "bedrock":
region = kwargs.get("region", "") or os.getenv("AWS_BEDROCK_REGION", "us-west-2")

session = boto3.Session(region_name=region)
bedrock_runtime = session.client(
service_name="bedrock-runtime",
region_name=region,
)

model_id = kwargs.get("model_name", "anthropic.claude-3-5-sonnet-20241022-v2:0")

return ChatBedrock(
client=bedrock_runtime,
model_id=model_id,
)
else:
raise ValueError(f"Unsupported provider: {provider}")

Expand All @@ -172,6 +190,7 @@ def get_llm_model(provider: str, **kwargs):
"mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
"alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
"moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
"bedrock": ["anthropic.claude-3-5-sonnet-20241022-v2:0"]
}

# Callback to update the model name dropdown based on the selected provider
Expand Down
30 changes: 0 additions & 30 deletions tests/test_deep_research.py

This file was deleted.

Loading