Add reasoning (extended thinking) for Claude 3.7 (#750)
* WIP

* Address CI issues

* Address CI issues

* Handle discriminated unions

* model_validate

* Fix CI build errors

* Address review comments

* Fix CI errors

* Adjust UI indentation

* reformat

* feat: add Pyright configuration to exclude specific directories

* chore: update boto3 to support reasoning

* feat: add core functionality for reasoning

* feat: bot feature

* feat: frontend

* add test suite

* chore: fix lint errors

* Change where delete_secret_manager is executed

* Raise an error when search_engine is firecrawl but firecrawl_config is not provided

* Validate the API key

* Refactor API key handling and update tool models in the bot framework

* Rename delete_secret_manager to delete_api_key_from_secret_manager

* fix: raise an error when the internet search tool fails

* feat: enhance Firecrawl integration with improved validation and legacy tool handling

* chore: lint

* change: update logging level from DEBUG to INFO

* chore: fix lint errors

---------

Co-authored-by: fsatsuki <[email protected]>
statefb and fsatsuki authored Mar 3, 2025
1 parent acfc697 commit 8d2e925
Showing 47 changed files with 2,029 additions and 1,470 deletions.
101 changes: 72 additions & 29 deletions backend/app/bedrock.py
@@ -19,7 +19,7 @@
from mypy_boto3_bedrock_runtime.type_defs import (
ContentBlockTypeDef,
ConverseResponseTypeDef,
-    ConverseStreamRequestRequestTypeDef,
+    ConverseStreamRequestTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
MessageTypeDef,
@@ -105,7 +105,8 @@ def compose_args_for_converse_api(
grounding_source: GuardrailConverseContentBlockTypeDef | None = None,
tools: dict[str, AgentTool] | None = None,
stream: bool = True,
-) -> ConverseStreamRequestRequestTypeDef:
+    enable_reasoning: bool = False,
+) -> ConverseStreamRequestTypeDef:
def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
if c.content_type == "text":
if (
@@ -142,6 +143,7 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
inference_config: InferenceConfigurationTypeDef
additional_model_request_fields: dict[str, Any]
system_prompts: list[SystemContentBlockTypeDef]

if is_nova_model(model):
# Special handling for Nova models
inference_config, additional_model_request_fields = _prepare_nova_model_params(
@@ -159,35 +161,76 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:

else:
# Standard handling for non-Nova models
-        inference_config = {
-            "maxTokens": (
-                generation_params.max_tokens
-                if generation_params
-                else DEFAULT_GENERATION_CONFIG["max_tokens"]
-            ),
-            "temperature": (
-                generation_params.temperature
-                if generation_params
-                else DEFAULT_GENERATION_CONFIG["temperature"]
-            ),
-            "topP": (
-                generation_params.top_p
-                if generation_params
-                else DEFAULT_GENERATION_CONFIG["top_p"]
-            ),
-            "stopSequences": (
-                generation_params.stop_sequences
-                if generation_params
-                else DEFAULT_GENERATION_CONFIG.get("stop_sequences", [])
-            ),
-        }
-        additional_model_request_fields = {
-            "top_k": (
-                generation_params.top_k
-                if generation_params
-                else DEFAULT_GENERATION_CONFIG["top_k"]
-            )
-        }
+        if enable_reasoning:
+            budget_tokens = (
+                generation_params.reasoning_params.budget_tokens
+                if generation_params and generation_params.reasoning_params
+                else DEFAULT_GENERATION_CONFIG["reasoning_params"]["budget_tokens"]  # type: ignore
+            )
+            max_tokens = (
+                generation_params.max_tokens
+                if generation_params
+                else DEFAULT_GENERATION_CONFIG["max_tokens"]
+            )
+
+            if max_tokens <= budget_tokens:
+                logger.warning(
+                    f"max_tokens ({max_tokens}) must be greater than budget_tokens ({budget_tokens}). "
+                    f"Setting max_tokens to {budget_tokens + 1024}"
+                )
+                max_tokens = budget_tokens + 1024
+
+            inference_config = {
+                "maxTokens": max_tokens,
+                "temperature": 1.0,  # Force temperature to 1.0 when reasoning is enabled
+                "topP": (
+                    generation_params.top_p
+                    if generation_params
+                    else DEFAULT_GENERATION_CONFIG["top_p"]
+                ),
+                "stopSequences": (
+                    generation_params.stop_sequences
+                    if generation_params
+                    else DEFAULT_GENERATION_CONFIG.get("stop_sequences", [])
+                ),
+            }
+            additional_model_request_fields = {
+                # top_k cannot be used with reasoning
+                "thinking": {
+                    "type": "enabled",
+                    "budget_tokens": budget_tokens,
+                },
+            }
+        else:
+            inference_config = {
+                "maxTokens": (
+                    generation_params.max_tokens
+                    if generation_params
+                    else DEFAULT_GENERATION_CONFIG["max_tokens"]
+                ),
+                "temperature": (
+                    generation_params.temperature
+                    if generation_params
+                    else DEFAULT_GENERATION_CONFIG["temperature"]
+                ),
+                "topP": (
+                    generation_params.top_p
+                    if generation_params
+                    else DEFAULT_GENERATION_CONFIG["top_p"]
+                ),
+                "stopSequences": (
+                    generation_params.stop_sequences
+                    if generation_params
+                    else DEFAULT_GENERATION_CONFIG.get("stop_sequences", [])
+                ),
+            }
+            additional_model_request_fields = {
+                "top_k": (
+                    generation_params.top_k
+                    if generation_params
+                    else DEFAULT_GENERATION_CONFIG["top_k"]
+                ),
+            }
system_prompts = [
{
"text": instruction,
Expand All @@ -197,7 +240,7 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
]

# Construct the base arguments
-    args: ConverseStreamRequestRequestTypeDef = {
+    args: ConverseStreamRequestTypeDef = {
"inferenceConfig": inference_config,
"modelId": get_model_id(model),
"messages": arg_messages,
@@ -230,7 +273,7 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:


def call_converse_api(
-    args: ConverseStreamRequestRequestTypeDef,
+    args: ConverseStreamRequestTypeDef,
) -> ConverseResponseTypeDef:
client = get_bedrock_runtime_client()

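For orientation, here is a minimal sketch (not part of this PR) of the request shape the new `enable_reasoning` branch composes and hands to `call_converse_api`; the model ID, region, and prompt are illustrative assumptions:

```python
import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")  # region is illustrative

budget_tokens = 1024
response = client.converse_stream(
    modelId="us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # illustrative model ID
    messages=[{"role": "user", "content": [{"text": "Why is the sky blue?"}]}],
    inferenceConfig={
        "maxTokens": budget_tokens + 1024,  # must stay above budget_tokens
        "temperature": 1.0,  # forced to 1.0 while reasoning is enabled
    },
    additionalModelRequestFields={
        # top_k cannot be combined with extended thinking
        "thinking": {"type": "enabled", "budget_tokens": budget_tokens},
    },
)
for event in response["stream"]:
    if "contentBlockDelta" in event:
        print(event["contentBlockDelta"]["delta"])
```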
11 changes: 8 additions & 3 deletions backend/app/config.py
@@ -1,4 +1,4 @@
-from typing_extensions import TypedDict
+from typing_extensions import NotRequired, TypedDict


class GenerationParams(TypedDict):
@@ -7,6 +7,7 @@ class GenerationParams(TypedDict):
top_p: float
temperature: float
stop_sequences: list[str]
reasoning_params: NotRequired[dict[str, int]]


class EmbeddingConfig(TypedDict):
@@ -20,11 +21,15 @@ class EmbeddingConfig(TypedDict):
# Adjust the values according to your application.
# See: https://docs.anthropic.com/claude/reference/complete_post
DEFAULT_GENERATION_CONFIG: GenerationParams = {
"max_tokens": 2000,
# Minimum (Haiku) is 4096
# Ref: https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison
"max_tokens": 4096,
"top_k": 250,
"top_p": 0.999,
"temperature": 0.6,
"temperature": 1.0,
"stop_sequences": ["Human: ", "Assistant: "],
# Budget tokens must NOT exceeds max_tokens
"reasoning_params": {"budget_tokens": 1024},
}

# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html#model-parameters-mistral-request-response
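As the new comment notes, `budget_tokens` must not exceed `max_tokens`; a quick sanity check against the shipped defaults (a sketch, assuming the module path `app.config`):

```python
from app.config import DEFAULT_GENERATION_CONFIG

budget = DEFAULT_GENERATION_CONFIG["reasoning_params"]["budget_tokens"]
assert budget < DEFAULT_GENERATION_CONFIG["max_tokens"]  # 1024 < 4096 holds
```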
11 changes: 5 additions & 6 deletions backend/app/repositories/conversation.py
@@ -4,17 +4,13 @@
from decimal import Decimal as decimal

import boto3
-from boto3.dynamodb.conditions import Key
-from botocore.exceptions import ClientError
-from pydantic import TypeAdapter

from app.repositories.common import (
TRANSACTION_BATCH_SIZE,
RecordNotFoundError,
_get_table_client,
compose_conv_id,
-    decompose_conv_id,
     compose_related_document_source_id,
+    decompose_conv_id,
decompose_related_document_source_id,
)
from app.repositories.models.conversation import (
@@ -25,9 +21,12 @@
RelatedDocumentModel,
ToolResultModel,
)
+from boto3.dynamodb.conditions import Key
+from botocore.exceptions import ClientError
+from pydantic import TypeAdapter

logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)

THRESHOLD_LARGE_MESSAGE = 300 * 1024 # 300KB
LARGE_MESSAGE_BUCKET = os.environ.get("LARGE_MESSAGE_BUCKET")
26 changes: 20 additions & 6 deletions backend/app/repositories/custom_bot.py
@@ -445,9 +445,16 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel:
public_bot_id=None if "PublicBotId" not in item else item["PublicBotId"],
owner_user_id=user_id,
generation_params=GenerationParamsModel.model_validate(
item["GenerationParams"]
if "GenerationParams" in item
else DEFAULT_GENERATION_CONFIG
{
**item.get("GenerationParams", DEFAULT_GENERATION_CONFIG),
# For backward compatibility
"reasoning_params": item.get("GenerationParams", {}).get(
"reasoning_params",
{
"budget_tokens": DEFAULT_GENERATION_CONFIG["reasoning_params"]["budget_tokens"], # type: ignore
},
),
}
),
agent=(
AgentModel.model_validate(item["AgentData"])
@@ -527,9 +534,16 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
public_bot_id=item["PublicBotId"],
owner_user_id=item["PK"],
generation_params=GenerationParamsModel.model_validate(
item["GenerationParams"]
if "GenerationParams" in item
else DEFAULT_GENERATION_CONFIG
{
**item.get("GenerationParams", DEFAULT_GENERATION_CONFIG),
# For backward compatibility
"reasoning_params": item.get("GenerationParams", {}).get(
"reasoning_params",
{
"budget_tokens": DEFAULT_GENERATION_CONFIG["reasoning_params"]["budget_tokens"], # type: ignore
},
),
}
),
agent=(
AgentModel.model_validate(item["AgentData"])
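A hedged sketch of the backward-compatibility path above: a bot record stored before this PR has no `reasoning_params`, so validation falls back to the default budget (the empty `item` is illustrative):

```python
from app.config import DEFAULT_GENERATION_CONFIG
from app.repositories.models.custom_bot import GenerationParamsModel

item: dict = {}  # e.g. a DynamoDB bot record written before this PR
params = GenerationParamsModel.model_validate(
    {
        **item.get("GenerationParams", DEFAULT_GENERATION_CONFIG),
        # For backward compatibility, fill in the default reasoning budget
        "reasoning_params": item.get("GenerationParams", {}).get(
            "reasoning_params",
            {"budget_tokens": DEFAULT_GENERATION_CONFIG["reasoning_params"]["budget_tokens"]},
        ),
    }
)
print(params.reasoning_params.budget_tokens)  # 1024
```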
45 changes: 43 additions & 2 deletions backend/app/repositories/models/conversation.py
@@ -3,7 +3,7 @@
import json
import re
from pathlib import Path
-from typing import Annotated, Any, Literal, Self, TypeGuard, TYPE_CHECKING
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Self, TypeGuard
from urllib.parse import urlparse

from app.repositories.models.common import Base64EncodedBytes
@@ -15,6 +15,7 @@
ImageToolResult,
JsonToolResult,
MessageInput,
ReasoningContent,
RelatedDocument,
SimpleMessage,
TextContent,
@@ -547,12 +548,52 @@ def to_contents_for_converse(self) -> list[ContentBlockTypeDef]:
]


class ReasoningContentModel(BaseModel):
content_type: Literal["reasoning"]
text: str
signature: str
redacted_content: Base64EncodedBytes

def to_content(self) -> Content:
return ReasoningContent(
content_type="reasoning",
text=self.text,
signature=self.signature,
redacted_content=self.redacted_content,
)

def to_contents_for_converse(self) -> list[ContentBlockTypeDef]:
# Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
# Ref: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking

if self.text:
return [
{
"reasoningContent": { # type: ignore
"reasoningText": {
"text": self.text,
"signature": self.signature,
},
}
}
]
else:
return [
{
"reasoningContent": { # type: ignore
"redactedContent": {"data": self.redacted_content},
}
}
]


ContentModel = Annotated[
TextContentModel
| ImageContentModel
| AttachmentContentModel
| ToolUseContentModel
-    | ToolResultContentModel,
+    | ToolResultContentModel
+    | ReasoningContentModel,
Discriminator("content_type"),
]

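A minimal round-trip sketch for the new `ReasoningContentModel`; the signature is an opaque value that the model returns alongside a thinking block, so the literal here is illustrative:

```python
from app.repositories.models.conversation import ReasoningContentModel

reasoning = ReasoningContentModel(
    content_type="reasoning",
    text="Let me work through the problem step by step...",
    signature="EswB...",  # illustrative; returned by the model in practice
    redacted_content=b"",
)

# Non-empty text maps to the reasoningText branch; empty text would
# fall through to the redactedContent branch instead.
print(reasoning.to_contents_for_converse())
```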
11 changes: 11 additions & 0 deletions backend/app/repositories/models/custom_bot.py
@@ -74,12 +74,23 @@ def __str_in_claude_format__(self) -> str:
return f"{_source_urls}{_sitemap_urls}{_filenames}{_s3_urls}"


class ReasoningParamsModel(BaseModel):
budget_tokens: int

@field_validator("budget_tokens")
def validate_budget_tokens(cls, v: int) -> int:
if v < 1024:
raise ValueError("budget_tokens must be greater than or equal to 1024")
return v


class GenerationParamsModel(BaseModel):
max_tokens: int
top_k: int
top_p: Float
temperature: Float
stop_sequences: list[str]
reasoning_params: ReasoningParamsModel


class FirecrawlConfigModel(BaseModel):
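The new validator can be exercised directly; a short sketch:

```python
from pydantic import ValidationError

from app.repositories.models.custom_bot import ReasoningParamsModel

print(ReasoningParamsModel(budget_tokens=2048))  # accepted

try:
    ReasoningParamsModel(budget_tokens=512)
except ValidationError as err:
    print(err)  # budget_tokens must be greater than or equal to 1024
```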
8 changes: 1 addition & 7 deletions backend/app/routes/bot.py
@@ -126,13 +126,7 @@ def get_private_bot(request: Request, bot_id: str):
filenames=bot.knowledge.filenames,
s3_urls=bot.knowledge.s3_urls,
),
-        generation_params=GenerationParams(
-            max_tokens=bot.generation_params.max_tokens,
-            top_k=bot.generation_params.top_k,
-            top_p=bot.generation_params.top_p,
-            temperature=bot.generation_params.temperature,
-            stop_sequences=bot.generation_params.stop_sequences,
-        ),
+        generation_params=GenerationParams(**bot.generation_params.model_dump()),
sync_status=bot.sync_status,
sync_status_reason=bot.sync_status_reason,
sync_last_exec_id=bot.sync_last_exec_id,
1 change: 1 addition & 0 deletions backend/app/routes/published_api.py
@@ -57,6 +57,7 @@ def post_message(request: Request, message_input: ChatInputWithoutBotId):
),
bot_id=bot_id,
continue_generate=message_input.continue_generate,
enable_reasoning=message_input.enable_reasoning,
)

try:
5 changes: 5 additions & 0 deletions backend/app/routes/schemas/bot.py
@@ -59,12 +59,17 @@ def create_model_activate_output(model_names: List[str]) -> Type[BaseSchema]:
ActiveModelsOutput = create_model_activate_output(list(get_args(type_model_name)))


class ReasoningParams(BaseSchema):
budget_tokens: int


class GenerationParams(BaseSchema):
max_tokens: int
top_k: int
top_p: float
temperature: float
stop_sequences: list[str]
reasoning_params: ReasoningParams


class FirecrawlConfig(BaseSchema):
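Constructing the updated schema with the new field; a sketch that mirrors the shipped defaults and assumes `BaseSchema` behaves like a Pydantic model:

```python
from app.routes.schemas.bot import GenerationParams, ReasoningParams

params = GenerationParams(
    max_tokens=4096,
    top_k=250,
    top_p=0.999,
    temperature=1.0,
    stop_sequences=["Human: ", "Assistant: "],
    reasoning_params=ReasoningParams(budget_tokens=1024),
)
print(params.model_dump())
```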
(Diffs for the remaining changed files are not shown.)
