Skip to content

Commit

Permalink
Feature: Use a Bedrock Knowledge Base as RAG source (Backend) (#428)
Browse files Browse the repository at this point in the history
* wip

* refactor: apply sfn to invoke ecs task

* wip

* wip

* fix

* fix fe

* fix poetry timeout

* fix

* s3 validation

* lint

* omit instruction from kb

* lint
  • Loading branch information
statefb authored Jul 12, 2024
1 parent 0f03317 commit 788a9b0
Show file tree
Hide file tree
Showing 46 changed files with 3,348 additions and 1,351 deletions.
8 changes: 6 additions & 2 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ COPY --from=public.ecr.aws/awsguru/aws-lambda-adapter:0.7.0 /lambda-adapter /opt
WORKDIR /backend

COPY ./pyproject.toml ./poetry.lock ./
RUN pip install poetry --no-cache-dir && \

ENV POETRY_REQUESTS_TIMEOUT=10800
RUN python -m pip install --upgrade pip && \
pip install poetry --no-cache-dir && \
poetry config virtualenvs.create false && \
poetry install --no-interaction --no-ansi
poetry install --no-interaction --no-ansi --only main && \
poetry cache clear --all pypi

COPY ./app ./app

Expand Down
3 changes: 1 addition & 2 deletions backend/app/agents/tools/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def _run(
search_results = dummy_search_results
else:
search_results = search_related_docs(
self.bot.id,
limit=self.bot.search_params.max_results,
self.bot,
query=query,
)

Expand Down
29 changes: 25 additions & 4 deletions backend/app/bot_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

import boto3
import pg8000
from app.repositories.apigateway import delete_api_key, find_usage_plan_by_id
from app.repositories.cloudformation import delete_stack_by_bot_id, find_stack_by_bot_id
from app.repositories.api_publication import delete_api_key, find_usage_plan_by_id
from app.repositories.api_publication import (
delete_stack_by_bot_id,
find_stack_by_bot_id,
)
from app.repositories.common import RecordNotFoundError, decompose_bot_id
from aws_lambda_powertools.utilities import parameters
from app.repositories.custom_bot import find_public_bot_by_id

DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "")
DOCUMENT_BUCKET = os.environ.get("DOCUMENT_BUCKET", "documents")
Expand Down Expand Up @@ -43,6 +47,16 @@ def delete_from_postgres(bot_id: str):
conn.close()


def delete_kb_stack_by_bot_id(bot_id: str):
    """Request deletion of the CloudFormation stack named after the bot.

    The stack name is built as ``BrChatKbStack{bot_id}``.

    Args:
        bot_id: Identifier of the bot whose knowledge base stack is removed.

    Returns:
        The raw response of the CloudFormation ``delete_stack`` call.

    Raises:
        RecordNotFoundError: If the ``delete_stack`` call fails with a
            ``ClientError``. The original error is attached as ``__cause__``.
    """
    client = boto3.client("cloudformation")
    stack_name = f"BrChatKbStack{bot_id}"
    try:
        return client.delete_stack(StackName=stack_name)
    except client.exceptions.ClientError as e:
        # Keep the exception type callers already handle, but chain the
        # underlying error so the real failure reason is not discarded.
        raise RecordNotFoundError(
            f"Failed to delete knowledge base stack: {stack_name}"
        ) from e


def delete_from_s3(user_id: str, bot_id: str):
"""Delete all files in S3 bucket for the specified `user_id` and `bot_id`."""
prefix = f"{user_id}/{bot_id}/"
Expand Down Expand Up @@ -90,14 +104,21 @@ def handler(event, context):
user_id = pk
bot_id = decompose_bot_id(sk)

delete_from_postgres(bot_id)
delete_from_s3(user_id, bot_id)

try:
print(f"Remove Bedrock Knowledge Base Stack.")
# Remove Knowledge Base Stack
delete_kb_stack_by_bot_id(bot_id)
except RecordNotFoundError:
print(f"Remove records from PostgreSQL.")
delete_from_postgres(bot_id)

# Check if cloudformation stack exists
try:
stack = find_stack_by_bot_id(bot_id)
except RecordNotFoundError:
print(f"Bot {bot_id} cloudformation stack not found. Skipping deletion.")
print(f"Bot {bot_id} api published stack not found. Skipping deletion.")
return

# Before delete cfn stack, delete all api keys
Expand Down
130 changes: 99 additions & 31 deletions backend/app/repositories/custom_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
BotMeta,
BotMetaWithStackInfo,
BotModel,
ConversationQuickStarterModel,
EmbeddingParamsModel,
GenerationParamsModel,
KnowledgeModel,
SearchParamsModel,
ConversationQuickStarterModel,
)
from app.repositories.models.custom_bot_kb import BedrockKnowledgeBaseModel
from app.routes.schemas.bot import type_sync_status
from app.utils import get_current_time
from boto3.dynamodb.conditions import Attr, Key
Expand Down Expand Up @@ -78,6 +79,8 @@ def store_bot(user_id: str, custom_bot: BotModel):
starter.model_dump() for starter in custom_bot.conversation_quick_starters
],
}
if custom_bot.bedrock_knowledge_base:
item["BedrockKnowledgeBase"] = custom_bot.bedrock_knowledge_base.model_dump()

response = table.put_item(Item=item)
return response
Expand All @@ -98,44 +101,56 @@ def update_bot(
sync_status_reason: str,
display_retrieved_chunks: bool,
conversation_quick_starters: list[ConversationQuickStarterModel],
bedrock_knowledge_base: BedrockKnowledgeBaseModel | None = None,
):
"""Update bot title, description, and instruction.
NOTE: Use `update_bot_visibility` to update visibility.
"""
table = _get_table_client(user_id)
logger.info(f"Updating bot: {bot_id}")

update_expression = (
"SET Title = :title, "
"Description = :description, "
"Instruction = :instruction, "
"EmbeddingParams = :embedding_params, "
"AgentData = :agent_data, "
"Knowledge = :knowledge, "
"SyncStatus = :sync_status, "
"SyncStatusReason = :sync_status_reason, "
"GenerationParams = :generation_params, "
"SearchParams = :search_params, "
"DisplayRetrievedChunks = :display_retrieved_chunks, "
"ConversationQuickStarters = :conversation_quick_starters"
)

expression_attribute_values = {
":title": title,
":description": description,
":instruction": instruction,
":knowledge": knowledge.model_dump(),
":agent_data": agent.model_dump(),
":embedding_params": embedding_params.model_dump(),
":sync_status": sync_status,
":sync_status_reason": sync_status_reason,
":display_retrieved_chunks": display_retrieved_chunks,
":generation_params": generation_params.model_dump(),
":search_params": search_params.model_dump(),
":conversation_quick_starters": [
starter.model_dump() for starter in conversation_quick_starters
],
}
if bedrock_knowledge_base:
update_expression += ", BedrockKnowledgeBase = :bedrock_knowledge_base"
expression_attribute_values[":bedrock_knowledge_base"] = (
bedrock_knowledge_base.model_dump()
)

try:
response = table.update_item(
Key={"PK": user_id, "SK": compose_bot_id(user_id, bot_id)},
UpdateExpression="SET Title = :title, "
"Description = :description, "
"Instruction = :instruction, "
"EmbeddingParams = :embedding_params, "
"AgentData = :agent_data, " # Note: `Agent` is reserved keyword
"Knowledge = :knowledge, "
"SyncStatus = :sync_status, "
"SyncStatusReason = :sync_status_reason, "
"GenerationParams = :generation_params, "
"SearchParams = :search_params, "
"DisplayRetrievedChunks = :display_retrieved_chunks, "
"ConversationQuickStarters = :conversation_quick_starters",
ExpressionAttributeValues={
":title": title,
":description": description,
":instruction": instruction,
":knowledge": knowledge.model_dump(),
":agent_data": agent.model_dump(),
":embedding_params": embedding_params.model_dump(),
":sync_status": sync_status,
":sync_status_reason": sync_status_reason,
":display_retrieved_chunks": display_retrieved_chunks,
":generation_params": generation_params.model_dump(),
":search_params": search_params.model_dump(),
":conversation_quick_starters": [
starter.model_dump() for starter in conversation_quick_starters
],
},
UpdateExpression=update_expression,
ExpressionAttributeValues=expression_attribute_values,
ReturnValues="ALL_NEW",
ConditionExpression="attribute_exists(PK) AND attribute_exists(SK)",
)
Expand Down Expand Up @@ -249,6 +264,33 @@ def update_alias_pin_status(user_id: str, alias_id: str, pinned: bool):
return response


def update_knowledge_base_id(
    user_id: str, bot_id: str, knowledge_base_id: str, data_source_ids: list[str]
):
    """Write the knowledge base id and data source ids onto an existing bot item.

    Both values are set inside the item's ``BedrockKnowledgeBase`` attribute.
    The update is conditional on the item's PK and SK already existing.

    Returns:
        The ``update_item`` response (requested with ``ReturnValues="ALL_NEW"``).

    Raises:
        RecordNotFoundError: When the conditional check fails.
        ClientError: Any other client error is re-raised unchanged.
    """
    table = _get_table_client(user_id)
    logger.info(f"Updating knowledge base id for bot: {bot_id}")

    set_clause = (
        "SET BedrockKnowledgeBase.knowledge_base_id = :kb_id, "
        "BedrockKnowledgeBase.data_source_ids = :ds_ids"
    )
    new_values = {
        ":kb_id": knowledge_base_id,
        ":ds_ids": data_source_ids,
    }

    try:
        item_key = {"PK": user_id, "SK": compose_bot_id(user_id, bot_id)}
        result = table.update_item(
            Key=item_key,
            UpdateExpression=set_clause,
            ExpressionAttributeValues=new_values,
            ConditionExpression="attribute_exists(PK) AND attribute_exists(SK)",
            ReturnValues="ALL_NEW",
        )
        logger.info(f"Updated knowledge base id for bot: {bot_id} successfully")
    except ClientError as err:
        # Only a failed existence condition is translated; everything else
        # is passed through to the caller untouched.
        if err.response["Error"]["Code"] != "ConditionalCheckFailedException":
            raise err
        raise RecordNotFoundError(f"Bot with id {bot_id} not found")

    return result


def find_private_bots_by_user_id(
user_id: str, limit: int | None = None
) -> list[BotMeta]:
Expand Down Expand Up @@ -281,6 +323,9 @@ def find_private_bots_by_user_id(
description=item["Description"],
is_public="PublicBotId" in item,
sync_status=item["SyncStatus"],
has_bedrock_knowledge_base=(
True if item.get("BedrockKnowledgeBase", None) else False
),
)
for item in response["Items"]
]
Expand All @@ -303,6 +348,9 @@ def find_private_bots_by_user_id(
description=item["Description"],
is_public="PublicBotId" in item,
sync_status=item["SyncStatus"],
has_bedrock_knowledge_base=(
True if item.get("BedrockKnowledgeBase", None) else False
),
)
for item in response["Items"]
]
Expand Down Expand Up @@ -387,7 +435,9 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel:
if "AgentData" in item
else AgentModel(tools=[])
),
knowledge=KnowledgeModel(**item["Knowledge"]),
knowledge=KnowledgeModel(
**{**item["Knowledge"], "s3_urls": item["Knowledge"].get("s3_urls", [])}
),
sync_status=item["SyncStatus"],
sync_status_reason=item["SyncStatusReason"],
sync_last_exec_id=item["LastExecId"],
Expand All @@ -406,6 +456,11 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel:
),
display_retrieved_chunks=item.get("DisplayRetrievedChunks", False),
conversation_quick_starters=item.get("ConversationQuickStarters", []),
bedrock_knowledge_base=(
BedrockKnowledgeBaseModel(**item["BedrockKnowledgeBase"])
if "BedrockKnowledgeBase" in item
else None
),
)

logger.info(f"Found bot: {bot}")
Expand Down Expand Up @@ -473,7 +528,9 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
if "AgentData" in item
else AgentModel(tools=[])
),
knowledge=KnowledgeModel(**item["Knowledge"]),
knowledge=KnowledgeModel(
**{**item["Knowledge"], "s3_urls": item["Knowledge"].get("s3_urls", [])}
),
sync_status=item["SyncStatus"],
sync_status_reason=item["SyncStatusReason"],
sync_last_exec_id=item["LastExecId"],
Expand All @@ -492,6 +549,11 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
),
display_retrieved_chunks=item.get("DisplayRetrievedChunks", False),
conversation_quick_starters=item.get("ConversationQuickStarters", []),
bedrock_knowledge_base=(
BedrockKnowledgeBaseModel(**item["BedrockKnowledgeBase"])
if "BedrockKnowledgeBase" in item
else None
),
)
logger.info(f"Found public bot: {bot}")
return bot
Expand Down Expand Up @@ -683,6 +745,9 @@ def query_dynamodb(table, bot_id):
sync_status=item["SyncStatus"],
published_api_stack_name=item.get("ApiPublishmentStackName", None),
published_api_datetime=item.get("ApiPublishedDatetime", None),
has_bedrock_knowledge_base=(
True if item.get("BedrockKnowledgeBase", None) else False
),
)
)

Expand Down Expand Up @@ -723,6 +788,9 @@ def find_all_published_bots(
sync_status=item["SyncStatus"],
published_api_stack_name=item["ApiPublishmentStackName"],
published_api_datetime=item.get("ApiPublishedDatetime", None),
has_bedrock_knowledge_base=(
True if item.get("BedrockKnowledgeBase", None) else False
),
)
for item in response["Items"]
]
Expand Down
8 changes: 8 additions & 0 deletions backend/app/repositories/models/custom_bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from app.repositories.models.common import Float
from app.routes.schemas.bot import type_sync_status
from pydantic import BaseModel
from app.repositories.models.custom_bot_kb import BedrockKnowledgeBaseModel


class EmbeddingParamsModel(BaseModel):
Expand All @@ -13,6 +14,7 @@ class KnowledgeModel(BaseModel):
source_urls: list[str]
sitemap_urls: list[str]
filenames: list[str]
s3_urls: list[str]

def __str_in_claude_format__(self) -> str:
"""Description of the knowledge in Claude format."""
Expand Down Expand Up @@ -81,17 +83,22 @@ class BotModel(BaseModel):
published_api_codebuild_id: str | None
display_retrieved_chunks: bool
conversation_quick_starters: list[ConversationQuickStarterModel]
bedrock_knowledge_base: BedrockKnowledgeBaseModel | None

def has_knowledge(self) -> bool:
    """Return True when at least one of the knowledge source lists is non-empty."""
    # Fields are inspected lazily, in this order, so evaluation stops at the
    # first non-empty list.
    source_fields = ("source_urls", "sitemap_urls", "filenames", "s3_urls")
    return any(len(getattr(self.knowledge, field)) > 0 for field in source_fields)

def is_agent_enabled(self) -> bool:
    """Return True when the bot's agent has at least one tool configured."""
    tool_count = len(self.agent.tools)
    return tool_count > 0

def has_bedrock_knowledge_base(self) -> bool:
    """Return True when a Bedrock knowledge base configuration is attached."""
    kb_config = self.bedrock_knowledge_base
    return kb_config is not None


class BotAliasModel(BaseModel):
id: str
Expand Down Expand Up @@ -121,6 +128,7 @@ class BotMeta(BaseModel):
# This can be `False` if the bot is not owned by the user and original bot is removed.
available: bool
sync_status: type_sync_status
has_bedrock_knowledge_base: bool


class BotMetaWithStackInfo(BotMeta):
Expand Down
35 changes: 35 additions & 0 deletions backend/app/repositories/models/custom_bot_kb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from app.routes.schemas.bot_kb import (
type_kb_chunking_strategy,
type_kb_embeddings_model,
type_kb_search_type,
type_os_character_filter,
type_os_token_filter,
type_os_tokenizer,
)
from pydantic import BaseModel


class SearchParamsModel(BaseModel):
    # Search settings stored under BedrockKnowledgeBaseModel.search_params.
    # Both fields are required (no defaults).
    max_results: int
    # Must be one of the values allowed by `type_kb_search_type`
    # (declared in app.routes.schemas.bot_kb).
    search_type: type_kb_search_type


class AnalyzerParamsModel(BaseModel):
    # Analyzer settings referenced by OpenSearchParamsModel.analyzer.
    # NOTE(review): the `type_os_*` names suggest these mirror OpenSearch
    # analyzer components; confirm against app.routes.schemas.bot_kb.
    character_filters: list[type_os_character_filter]
    tokenizer: type_os_tokenizer
    token_filters: list[type_os_token_filter]


class OpenSearchParamsModel(BaseModel):
    # Wrapper for the optional analyzer configuration. `analyzer` may be None
    # but has no default, so it must be passed explicitly.
    analyzer: AnalyzerParamsModel | None


class BedrockKnowledgeBaseModel(BaseModel):
    # Knowledge base configuration attached to a bot.
    # The first four fields are required.
    embeddings_model: type_kb_embeddings_model
    open_search: OpenSearchParamsModel
    chunking_strategy: type_kb_chunking_strategy
    search_params: SearchParamsModel
    # Optional chunking parameters.
    # NOTE(review): presumably only meaningful for some chunking strategies; confirm.
    max_tokens: int | None = None
    overlap_percentage: int | None = None
    # Identifiers that default to None when the model is first built.
    # NOTE(review): presumably filled in once the knowledge base resources
    # have been created; confirm against the code that sets them.
    knowledge_base_id: str | None = None
    data_source_ids: list[str] | None = None
Loading

0 comments on commit 788a9b0

Please sign in to comment.