Skip to content

Commit

Permalink
feat: Use existing Knowledge Base (#638)
Browse files Browse the repository at this point in the history
* wip

* wip

* wip

* disable

* fix ci and test code

* modify cdk test code

* has_exist_knowlednge_base_id -> has_exist_knowledge_base_id.

* knowledge_base_idとexist_knowledge_base_idが排他であることを強制する

* review comment対応

* コメント対応

* 不要な修正を削除, CIエラー解消

* Sfnの処理で使うデータ構造を修正

* UI変更

* UI変更

* content_model_from_contentの修正

* StackOutputの型エラー修正

* validate_knowledge_base_idsのpydanticのvalidator修正

* 誤った修正を戻す

* has_knowledgeに関する修正

* fix mypy --check-untyped-defs

* validate_knowledge_base_idsの削除

* インデント調整

* 修正誤り訂正

* fix CI

---------

Co-authored-by: statefb <[email protected]>
  • Loading branch information
fsatsuki and statefb authored Dec 18, 2024
1 parent 46fd095 commit 3b48782
Show file tree
Hide file tree
Showing 28 changed files with 768 additions and 426 deletions.
11 changes: 6 additions & 5 deletions backend/app/bot_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def delete_custom_bot_stack_by_bot_id(bot_id: str):
stack_name = f"BrChatKbStack{bot_id}"
try:
response = client.delete_stack(StackName=stack_name)
except client.exceptions.ClientError as e:
except client.exceptions.ClientError:
raise RecordNotFoundError()
return response

Expand Down Expand Up @@ -49,7 +49,7 @@ def delete_from_s3(user_id: str, bot_id: str):
print(e)


def handler(event, context):
def handler(event: dict, context: Any) -> None:
"""Bot removal handler.
This function is triggered by dynamodb stream when item is deleted.
Following resources are deleted asynchronously when bot is deleted:
Expand Down Expand Up @@ -84,9 +84,10 @@ def handler(event, context):
return

# Before delete cfn stack, delete all api keys
usage_plan = find_usage_plan_by_id(stack.api_usage_plan_id)
for key_id in usage_plan.key_ids:
delete_api_key(key_id)
if stack.api_usage_plan_id: # Add type check
usage_plan = find_usage_plan_by_id(stack.api_usage_plan_id)
for key_id in usage_plan.key_ids:
delete_api_key(key_id)

# Delete `ApiPublishmentStack` by CloudFormation
delete_stack_by_bot_id(bot_id)
16 changes: 9 additions & 7 deletions backend/app/repositories/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from typing import Dict, List, Optional, Sequence

import boto3

Expand Down Expand Up @@ -63,7 +64,7 @@ def decompose_related_document_source_id(composed_id: str):
return composed_id.split("#")[-1]


def _get_aws_resource(service_name, user_id=None):
def _get_aws_resource(service_name: str, user_id: Optional[str] = None):
"""Get AWS resource with optional row-level access control for DynamoDB.
Ref: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_examples_dynamodb_items.html
"""
Expand All @@ -75,11 +76,11 @@ def _get_aws_resource(service_name, user_id=None):
aws_access_key_id="key",
aws_secret_access_key="key",
region_name=REGION,
)
) # type: ignore[call-overload]
else:
return boto3.resource(service_name, region_name=REGION)
return boto3.resource(service_name, region_name=REGION) # type: ignore[call-overload]

policy_document = {
policy_document: Dict[str, List[Dict]] = {
"Statement": [
{
"Effect": "Allow",
Expand All @@ -103,6 +104,7 @@ def _get_aws_resource(service_name, user_id=None):
}
]
}

if user_id:
policy_document["Statement"][0]["Condition"] = {
# Allow access to items with the same partition key as the user id
Expand All @@ -121,15 +123,15 @@ def _get_aws_resource(service_name, user_id=None):
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
return session.resource(service_name, region_name=REGION)
return session.resource(service_name, region_name=REGION) # type: ignore[call-overload]


def _get_dynamodb_client(user_id=None):
def _get_dynamodb_client(user_id: Optional[str] = None):
"""Get a DynamoDB client, optionally with row-level access control."""
return _get_aws_resource("dynamodb", user_id=user_id).meta.client


def _get_table_client(user_id):
def _get_table_client(user_id: str):
"""Get a DynamoDB table client with row-level access."""
return _get_aws_resource("dynamodb", user_id=user_id).Table(TABLE_NAME)

Expand Down
2 changes: 1 addition & 1 deletion backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def to_contents_for_converse(self) -> list[ContentBlockTypeDef]:


def content_model_from_content(content: Content) -> ContentModel:

if isinstance(content, TextContent):
return TextContentModel.from_text_content(content=content)

Expand All @@ -513,7 +514,6 @@ def content_model_from_content(content: Content) -> ContentModel:

elif isinstance(content, ToolResultContent):
return ToolResultContentModel.from_tool_result_content(content=content)

else:
raise ValueError(f"Unknown content type")

Expand Down
1 change: 1 addition & 0 deletions backend/app/repositories/models/custom_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def has_knowledge(self) -> bool:
or len(self.knowledge.sitemap_urls) > 0
or len(self.knowledge.filenames) > 0
or len(self.knowledge.s3_urls) > 0
or self.bedrock_knowledge_base is not None
)

def is_agent_enabled(self) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion backend/app/repositories/models/custom_bot_kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
type_os_token_filter,
type_os_tokenizer,
)
from pydantic import BaseModel
from typing import Self
from pydantic import BaseModel, validator, model_validator


class SearchParamsModel(BaseModel):
Expand Down Expand Up @@ -72,6 +73,7 @@ class BedrockKnowledgeBaseModel(BaseModel):
)
search_params: SearchParamsModel
knowledge_base_id: str | None = None
exist_knowledge_base_id: str | None = None
data_source_ids: list[str] | None = None
parsing_model: type_kb_parsing_model = "disabled"
web_crawling_scope: type_kb_web_crawling_scope = "DEFAULT"
Expand Down
2 changes: 2 additions & 0 deletions backend/app/routes/schemas/bot_kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class BedrockKnowledgeBaseInput(BaseSchema):
)
search_params: SearchParams
knowledge_base_id: str | None = None
exist_knowledge_base_id: str | None = None
parsing_model: type_kb_parsing_model = "disabled"
web_crawling_scope: type_kb_web_crawling_scope = "DEFAULT"
web_crawling_filters: WebCrawlingFilters = WebCrawlingFilters(
Expand All @@ -113,6 +114,7 @@ class BedrockKnowledgeBaseOutput(BaseSchema):
)
search_params: SearchParams
knowledge_base_id: str | None = None
exist_knowledge_base_id: str | None = None
data_source_ids: list[str] | None = None
parsing_model: type_kb_parsing_model = "disabled"
web_crawling_scope: type_kb_web_crawling_scope = "DEFAULT"
Expand Down
6 changes: 6 additions & 0 deletions backend/app/usecases/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,18 @@ def create_new_bot(user_id: str, bot_input: BotInput) -> BotOutput:
or len(bot_input.knowledge.sitemap_urls) > 0
or len(bot_input.knowledge.filenames) > 0
or len(bot_input.knowledge.s3_urls) > 0
# This is a condition for running Sfn to register existing KB information in DynamoDB when an existing KB is specified.
or (
bot_input.bedrock_knowledge_base is not None
and bot_input.bedrock_knowledge_base.exist_knowledge_base_id is not None
)
)

has_guardrails = (
bot_input.bedrock_guardrails
and bot_input.bedrock_guardrails.is_guardrail_enabled == True
)

sync_status: type_sync_status = (
"QUEUED" if has_knowledge or has_guardrails else "SUCCEEDED"
)
Expand Down
8 changes: 7 additions & 1 deletion backend/app/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _bedrock_knowledge_base_search(bot: BotModel, query: str) -> list[SearchResu
bot.bedrock_knowledge_base is not None
and bot.bedrock_knowledge_base.knowledge_base_id is not None
)

if bot.bedrock_knowledge_base.search_params.search_type == "semantic":
search_type = "SEMANTIC"
elif bot.bedrock_knowledge_base.search_params.search_type == "hybrid":
Expand All @@ -72,7 +73,12 @@ def _bedrock_knowledge_base_search(bot: BotModel, query: str) -> list[SearchResu
raise ValueError("Invalid search type")

limit = bot.bedrock_knowledge_base.search_params.max_results
knowledge_base_id = bot.bedrock_knowledge_base.knowledge_base_id
# Use exist_knowledge_base_id if available, otherwise use knowledge_base_id
knowledge_base_id = (
bot.bedrock_knowledge_base.exist_knowledge_base_id
if bot.bedrock_knowledge_base.exist_knowledge_base_id is not None
else bot.bedrock_knowledge_base.knowledge_base_id
)

try:
response = agent_client.retrieve(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from typing import TypedDict, List, Dict, Any

import boto3
from app.repositories.custom_bot import (
decompose_bot_id,
Expand All @@ -9,7 +11,21 @@
cf_client = boto3.client("cloudformation", BEDROCK_REGION)


def handler(event, context):
class StackItem(TypedDict):
KnowledgeBaseId: str
DataSourceId: str
GuardrailArn: str
GuardrailVersion: str
PK: str
SK: str


class StackResult(TypedDict):
KnowledgeBaseId: str
items: List[StackItem]


def handler(event: Dict[str, str], context: Any) -> StackResult:
print(event)
pk = event["pk"]
sk = event["sk"]
Expand All @@ -24,29 +40,30 @@ def handler(event, context):
outputs = response["Stacks"][0]["Outputs"]

knowledge_base_id = None
data_source_ids = []
data_source_ids: List[str] = []
guardrail_arn = None
guardrail_version = None
result: StackResult = {"KnowledgeBaseId": "", "items": []}

for output in outputs:
if output["OutputKey"] == "KnowledgeBaseId":
knowledge_base_id = output["OutputValue"]
result["KnowledgeBaseId"] = knowledge_base_id
elif output["OutputKey"].startswith("DataSource"):
data_source_ids.append(output["OutputValue"])
elif output["OutputKey"] == "GuardrailArn":
guardrail_arn = output["OutputValue"]
elif output["OutputKey"] == "GuardrailVersion":
guardrail_version = output["OutputValue"]

result = []
for data_source_id in data_source_ids:
result.append(
result["items"].append(
{
"KnowledgeBaseId": knowledge_base_id,
"KnowledgeBaseId": knowledge_base_id or "",
"DataSourceId": data_source_id,
"GuardrailArn": guardrail_arn if guardrail_arn != None else "",
"GuardrailArn": guardrail_arn if guardrail_arn is not None else "",
"GuardrailVersion": (
guardrail_version if guardrail_version != None else ""
guardrail_version if guardrail_version is not None else ""
),
"PK": pk,
"SK": sk,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,45 @@
from app.repositories.custom_bot import decompose_bot_id, update_knowledge_base_id
from app.routes.schemas.bot import type_sync_status
from retry import retry
from typing import List
from typing_extensions import TypedDict

logger = logging.getLogger()
logger.setLevel(logging.INFO)


class StackOutput(TypedDict):
class Items(TypedDict):
KnowledgeBaseId: str
DataSourceId: str
GuardrailArn: str
GuardrailVersion: str
PK: str
SK: str


class StackOutput(TypedDict):
KnowledgeBaseId: str
items: List[Items]


def handler(event, context):
logger.info(f"Event: {event}")
pk = event["pk"]
sk = event["sk"]
stack_output: list[StackOutput] = event["stack_output"]

kb_id = stack_output[0]["KnowledgeBaseId"]
data_source_ids = [x["DataSourceId"] for x in stack_output]
stack_output: StackOutput = event["stack_output"]

kb_id = (
stack_output["KnowledgeBaseId"] if "KnowledgeBaseId" in stack_output else None
)
if not kb_id:
raise ValueError("KnowledgeBaseId not found in stack outputs")

# Filter out None values and ensure all elements are strings
data_source_ids: List[str] = [
item["DataSourceId"]
for item in stack_output.get("items", [])
if item.get("DataSourceId")
]

user_id = pk
bot_id = decompose_bot_id(sk)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
from typing import Literal

import boto3
from app.repositories.common import _get_table_client
Expand Down Expand Up @@ -70,6 +71,14 @@ def handler(event, context):
try:
cause = event.get("cause", None)
ingestion_job = event.get("ingestion_job", None)

# Initialize variables
pk: str
sk: str
sync_status: type_sync_status
sync_status_reason: str
last_exec_id: str

if cause:
# UpdateSyncStatusFailed
pk, sk, build_arn = extract_from_cause(cause)
Expand Down
20 changes: 13 additions & 7 deletions backend/embedding_statemachine/guardrails/store_guardrail_arn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
from app.routes.schemas.bot import type_sync_status
from app.repositories.common import _get_table_client
from app.repositories.custom_bot import decompose_bot_id, update_guardrails_params
from typing import TypedDict
from typing import TypedDict, List

logger = logging.getLogger()
logger.setLevel(logging.INFO)


class StackOutput(TypedDict):
KnowledgeBaseId: str
DataSourceId: str
GuardrailArn: str
GuardrailVersion: str

Expand All @@ -24,10 +22,18 @@ def handler(event, context):
logger.info(f"Event: {event}")
pk = event["pk"]
sk = event["sk"]
stack_output: list[StackOutput] = event["stack_output"]

guardrail_arn = stack_output[0]["GuardrailArn"]
guardrail_version = stack_output[0]["GuardrailVersion"]
stack_output: List[StackOutput] = event["stack_output"]

# Check if stack_output is valid and has at least one item
if not stack_output or not isinstance(stack_output, list) or len(stack_output) == 0:
logger.warning("Empty or invalid stack_output received")
guardrail_arn = ""
guardrail_version = ""
else:
# Access the first item directly since we know it exists
first_output = stack_output[0]
guardrail_arn = first_output.get("GuardrailArn", "")
guardrail_version = first_output.get("GuardrailVersion", "")

user_id = pk
bot_id = decompose_bot_id(sk)
Expand Down
3 changes: 2 additions & 1 deletion backend/tests/test_repositories/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,9 @@ def mock_query_side_effect(**kwargs):
self.assertEqual(content[0].content_type, "text")
self.assertEqual(content[0].body, "Hello")
self.assertEqual(content[1].content_type, "image")
# Convert the raw bytes to base64 for comparison
self.assertEqual(
content[1].body,
base64.b64encode(content[1].body).decode(),
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=",
)
self.assertEqual(message_map["a"].model, "claude-instant-v1")
Expand Down
Loading

1 comment on commit 3b48782

@axelpina
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! Thank you very much for your work. Just something minor, but there's a typo: advancedConfigration -> advancedConfiguration.

Please sign in to comment.