Skip to content

Commit

Permalink
Non-text document upload directly on chat window (#467)
Browse files Browse the repository at this point in the history
* add

* black format

* chore: lint

* chore: variable name consistency (AttachmentType)

* add checking not to excess 10MB which is api gateway response size limit

* merge v1

* chore: eslint deps ignore

* fix: size limit to 6MB

* change: image size

* chore: remove eslint-disabled-line

* fix: size check bug
  • Loading branch information
statefb authored Jul 26, 2024
1 parent 6d879b8 commit cc804d9
Show file tree
Hide file tree
Showing 18 changed files with 365 additions and 140 deletions.
18 changes: 15 additions & 3 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import re
from pathlib import Path
from typing import TypedDict, no_type_check

Expand Down Expand Up @@ -94,6 +95,17 @@ def _get_converse_supported_format(ext: str) -> str:
return supported_formats.get(ext, "txt")


def _convert_to_valid_file_name(file_name: str) -> str:
# Note: The document file name can only contain alphanumeric characters,
# whitespace characters, hyphens, parentheses, and square brackets.
# The name can't contain more than one consecutive whitespace character.
file_name = re.sub(r"[^a-zA-Z0-9\s\-\(\)\[\]]", "", file_name)
file_name = re.sub(r"\s+", " ", file_name)
file_name = file_name.strip()

return file_name


@no_type_check
def compose_args_for_converse_api(
messages: list[MessageModel],
Expand Down Expand Up @@ -124,7 +136,7 @@ def compose_args_for_converse_api(
}
}
)
elif c.content_type == "textAttachment":
elif c.content_type == "attachment":
content_blocks.append(
{
"document": {
Expand All @@ -134,10 +146,10 @@ def compose_args_for_converse_api(
], # e.g. "document.txt" -> "txt"
),
"name": Path(
c.file_name
_convert_to_valid_file_name(c.file_name)
).stem, # e.g. "document.txt" -> "document"
# encode text attachment body
"source": {"bytes": c.body.encode("utf-8")},
"source": {"bytes": base64.b64decode(c.body)},
}
}
)
Expand Down
2 changes: 1 addition & 1 deletion backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class ContentModel(BaseModel):
content_type: Literal["text", "image", "textAttachment"]
content_type: Literal["text", "image", "attachment"]
media_type: str | None
body: str = Field(
...,
Expand Down
22 changes: 11 additions & 11 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class Content(BaseSchema):
content_type: Literal["text", "image", "textAttachment"] = Field(
content_type: Literal["text", "image", "attachment"] = Field(
..., description="Content type. Note that image is only available for claude 3."
)
media_type: str | None = Field(
Expand All @@ -27,7 +27,7 @@ class Content(BaseSchema):
)
file_name: str | None = Field(
None,
description="File name of the attachment. Must be specified if `content_type` is `textAttachment`.",
description="File name of the attachment. Must be specified if `content_type` is `attachment`.",
)
body: str = Field(..., description="Content body.")

Expand All @@ -42,18 +42,18 @@ def check_media_type(cls, v, values):
def check_body(cls, v, values):
content_type = values.get("content_type")

# if content_type in ["image", "textAttachment"]:
# try:
# # Check if the body is a valid base64 string
# base64.b64decode(v, validate=True)
# except Exception:
# raise ValueError(
# "body must be a valid base64 string if `content_type` is `image` or `textAttachment`"
# )

if content_type == "text" and not isinstance(v, str):
raise ValueError("body must be str if `content_type` is `text`")

if content_type in ["image", "attachment"]:
try:
# Check if the body is a valid base64 string
base64.b64decode(v, validate=True)
except Exception:
raise ValueError(
"body must be a valid base64 string if `content_type` is `image` or `attachment`"
)

return v


Expand Down
9 changes: 5 additions & 4 deletions backend/tests/test_stream/test_stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import sys

sys.path.append(".")
Expand Down Expand Up @@ -151,16 +152,16 @@ def test_run_with_image(self):
self._run(message)

def test_run_with_attachment(self):
# _, aws_pdf_body = get_aws_overview()
# aws_pdf_filename = "aws_arch_overview.pdf"
body = get_test_markdown()
file_name, body = get_aws_overview()
body = base64.b64encode(body).decode("utf-8")
# body = get_test_markdown()
file_name = "test.md"

message = MessageModel(
role="user",
content=[
ContentModel(
content_type="textAttachment",
content_type="attachment",
media_type=None,
body=body,
file_name=file_name,
Expand Down
38 changes: 38 additions & 0 deletions backend/tests/test_usecases/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import sys

sys.path.insert(0, ".")
Expand Down Expand Up @@ -39,6 +40,7 @@
)
from app.vector_search import SearchResult
from tests.test_stream.get_aws_logo import get_aws_logo
from tests.test_stream.get_pdf import get_aws_overview
from tests.test_usecases.utils.bot_factory import (
create_test_instruction_template,
create_test_private_bot,
Expand Down Expand Up @@ -271,6 +273,42 @@ def tearDown(self) -> None:
delete_conversation_by_id("user1", self.output.conversation_id)


class TestAttachmentChat(unittest.TestCase):
def tearDown(self) -> None:
delete_conversation_by_id("user1", self.output.conversation_id)

def test_chat(self):
file_name, body = get_aws_overview()
chat_input = ChatInput(
conversation_id="test_conversation_id",
message=MessageInput(
role="user",
content=[
Content(
content_type="attachment",
body=base64.b64encode(body).decode("utf-8"),
media_type=None,
file_name=file_name,
),
Content(
content_type="text",
body="Summarize the document.",
media_type=None,
file_name=None,
),
],
model=MODEL,
parent_message_id=None,
message_id=None,
),
bot_id=None,
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
pprint(output.model_dump())
self.output = output


class TestMultimodalChat(unittest.TestCase):
def tearDown(self) -> None:
delete_conversation_by_id("user1", self.output.conversation_id)
Expand Down
1 change: 1 addition & 0 deletions cdk/lib/bedrock-chat-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ export class BedrockChatStack extends cdk.Stack {
bedrockRegion: props.bedrockRegion,
largeMessageBucket,
documentBucket,
enableMistral: props.enableMistral,
});
frontend.buildViteApp({
backendApiEndpoint: backendApi.api.apiEndpoint,
Expand Down
2 changes: 2 additions & 0 deletions cdk/lib/constructs/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export interface WebSocketProps {
readonly websocketSessionTable: ITable;
readonly largeMessageBucket: s3.IBucket;
readonly accessLogBucket?: s3.Bucket;
readonly enableMistral: boolean;
}

export class WebSocket extends Construct {
Expand Down Expand Up @@ -110,6 +111,7 @@ export class WebSocket extends Construct {
DB_SECRETS_ARN: props.dbSecrets.secretArn,
LARGE_PAYLOAD_SUPPORT_BUCKET: largePayloadSupportBucket.bucketName,
WEBSOCKET_SESSION_TABLE_NAME: props.websocketSessionTable.tableName,
ENABLE_MISTRAL: props.enableMistral.toString(),
},
role: handlerRole,
});
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/@types/conversation.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export type Model =
| 'mixtral-8x7b-instruct'
| 'mistral-large';
export type Content = {
contentType: 'text' | 'image' | 'textAttachment';
contentType: 'text' | 'image' | 'attachment';
mediaType?: string;
fileName?: string;
body: string;
Expand Down
34 changes: 25 additions & 9 deletions frontend/src/components/ChatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import ModalDialog from './ModalDialog';
import { useTranslation } from 'react-i18next';
import useChat from '../hooks/useChat';
import DialogFeedback from './DialogFeedback';
import UploadedFileText from './UploadedFileText';
import UploadedAttachedFile from './UploadedAttachedFile';
import { TEXT_FILE_EXTENSIONS } from '../constants/supportedAttachedFiles';

type Props = BaseProps & {
chatContent?: DisplayMessageContent;
Expand Down Expand Up @@ -177,20 +178,35 @@ const ChatMessage: React.FC<Props> = (props) => {
</div>
)}
{chatContent.content.some(
(content) => content.contentType === 'textAttachment'
(content) => content.contentType === 'attachment'
) && (
<div key="files" className="my-2 flex">
{chatContent.content.map((content, idx) => {
if (content.contentType === 'textAttachment') {
if (content.contentType === 'attachment') {
const isTextFile = TEXT_FILE_EXTENSIONS.some(
(ext) => content.fileName?.toLowerCase().endsWith(ext)
);
return (
<UploadedFileText
<UploadedAttachedFile
key={idx}
fileName={content.fileName ?? ''}
onClick={() => {
setDialogFileName(content.fileName ?? '');
setDialogFileContent(content.body);
setIsFileModalOpen(true);
}}
onClick={
// Only text file can be previewed
isTextFile
? () => {
const textContent = new TextDecoder(
'utf-8'
).decode(
Uint8Array.from(atob(content.body), (c) =>
c.charCodeAt(0)
)
); // base64 encoded text to be decoded string
setDialogFileName(content.fileName ?? '');
setDialogFileContent(textContent);
setIsFileModalOpen(true);
}
: undefined
}
/>
);
}
Expand Down
Loading

0 comments on commit cc804d9

Please sign in to comment.