Mistral models support (#253)
* feat: backend support for Mistral models

* feat: added test cases for mistral model support

* chore: removed the print for debugging

* chore: housekeeping of debugging records

* chore: removed the unnecessary langchain

* chore: fall back to trace_to_root to retrieve messages

* feat: support Mistral models

* fix: update CORS in case with custom domain associated

* fix: fixed the MISTRAL_GENERATION_CONFIG type and updated the mistral model config

* feat: ignore .pyc files under __pycache__/ for cdk

* Mistral models support debug (#1)

* chore: see if format affects the type checking for parent_id

* chore: ignore type checking

* chore: using black formatting

* chore: fix one more testing script formatting with black format

* fix: updated the mistral models stop_sequence

* refactor: moved the function get_bedrock_response to bedrock.py and updated input/output type

* refactor: removed is_mistral_model from utils, keeping is_anthropic_model only

* refactor: added invocation metrics model

* fix: cdk ignores changes to test files under backend/tests

* fix: fixed the typing

* refactor: integrated the Mistral model testing script into test_chat

* feat: added option enableMistral to cdk.json to support toggling Mistral models

* chore: fixed the cdk test broken by the introduced stack props, and added verification in the default stack test case

* fix: fixed the frontend lint for availableModels

* feat: added support for Mistral Large models when Mistral is enabled

* chore: removed some debugging log entries

* refactor: keep the enableMistral option as a boolean until it is passed to the frontend build stage

* refactor: moved the InvocationMetrics model from conversation.py to bedrock.py

* refactor: integrated the Mistral model tests into the existing ones

* chore: updated the pricing mapping variable to BEDROCK_PRICING.
hustshawn authored Apr 26, 2024
1 parent 1635328 commit 588c099
Showing 17 changed files with 496 additions and 123 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -200,6 +200,15 @@ BedrockChatStack.FrontendURL = https://xxxxx.cloudfront.net
```

## Others
### Configure Mistral model support
Update `enableMistral` to `true` in [cdk.json](./cdk/cdk.json), and run `cdk deploy`.
```json
...
"enableMistral": true,
```
> [!NOTE]
> - This project focuses on Anthropic Claude models; Mistral models are only partially supported. For example, the prompt examples are based on Claude models.
> - This is a Mistral-only option: once you enable Mistral models, all chat features use only Mistral models, not both Claude and Mistral.

### Configure text generation

135 changes: 129 additions & 6 deletions backend/app/bedrock.py
@@ -3,9 +3,15 @@
import os

from anthropic import AnthropicBedrock
from app.config import ANTHROPIC_PRICING, DEFAULT_EMBEDDING_CONFIG, GENERATION_CONFIG
from app.config import (
    BEDROCK_PRICING,
    DEFAULT_EMBEDDING_CONFIG,
    GENERATION_CONFIG,
    MISTRAL_GENERATION_CONFIG,
)
from app.repositories.models.conversation import MessageModel
from app.utils import get_bedrock_client
from app.utils import get_bedrock_client, is_anthropic_model
from pydantic import BaseModel

logger = logging.getLogger(__name__)

@@ -16,6 +22,58 @@
anthropic_client = AnthropicBedrock()


class InvocationMetrics(BaseModel):
    input_tokens: int
    output_tokens: int


def compose_args(
    messages: list[MessageModel],
    model: str,
    instruction: str | None = None,
    stream: bool = False,
) -> dict:
    # if model is from Anthropic, use AnthropicBedrock
    # otherwise, use bedrock client
    model_id = get_model_id(model)
    if is_anthropic_model(model_id):
        return compose_args_for_anthropic_client(messages, model, instruction, stream)
    else:
        return compose_args_for_other_client(messages, model, instruction, stream)


def compose_args_for_other_client(
    messages: list[MessageModel],
    model: str,
    instruction: str | None = None,
    stream: bool = False,
) -> dict:
    arg_messages = []
    for message in messages:
        if message.role not in ["system", "instruction"]:
            content: list[dict] = []
            for c in message.content:
                if c.content_type == "text":
                    content.append(
                        {
                            "type": "text",
                            "text": c.body,
                        }
                    )
            m = {"role": message.role, "content": content}
            arg_messages.append(m)

    args = {
        **MISTRAL_GENERATION_CONFIG,
        "model": get_model_id(model),
        "messages": arg_messages,
        "stream": stream,
    }
    if instruction:
        args["system"] = instruction
    return args


def compose_args_for_anthropic_client(
    messages: list[MessageModel],
    model: str,
@@ -66,14 +124,14 @@ def calculate_price(
    model: str, input_tokens: int, output_tokens: int, region: str = BEDROCK_REGION
) -> float:
    input_price = (
        ANTHROPIC_PRICING.get(region, {})
        BEDROCK_PRICING.get(region, {})
        .get(model, {})
        .get("input", ANTHROPIC_PRICING["default"][model]["input"])
        .get("input", BEDROCK_PRICING["default"][model]["input"])
    )
    output_price = (
        ANTHROPIC_PRICING.get(region, {})
        BEDROCK_PRICING.get(region, {})
        .get(model, {})
        .get("output", ANTHROPIC_PRICING["default"][model]["output"])
        .get("output", BEDROCK_PRICING["default"][model]["output"])
    )

    return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0
@@ -91,6 +149,12 @@ def get_model_id(model: str) -> str:
return "anthropic.claude-3-haiku-20240307-v1:0"
elif model == "claude-v3-opus":
return "anthropic.claude-3-opus-20240229-v1:0"
elif model == "mistral-7b-instruct":
return "mistral.mistral-7b-instruct-v0:2"
elif model == "mixtral-8x7b-instruct":
return "mistral.mixtral-8x7b-instruct-v0:1"
elif model == "mistral-large":
return "mistral.mistral-large-2402-v1:0"
else:
raise NotImplementedError()

@@ -141,3 +205,62 @@ def _calculate_document_embeddings(documents: list[str]) -> list[list[float]]:
        embeddings += _calculate_document_embeddings(batch)

    return embeddings


def get_bedrock_response(args: dict) -> dict:

    client = get_bedrock_client()
    messages = args["messages"]

    # Flatten all text content blocks into a single prompt string.
    prompt = "\n".join(
        [
            message["content"][0]["text"]
            for message in messages
            if message["content"][0]["type"] == "text"
        ]
    )

    model_id = args["model"]
    is_mistral_model = model_id.startswith("mistral")
    if is_mistral_model:
        # Mistral instruct models expect the [INST] ... [/INST] prompt template.
        prompt = f"<s>[INST] {prompt} [/INST]"

    logger.info(f"Final Prompt: {prompt}")
    body = json.dumps(
        {
            "prompt": prompt,
            "max_tokens": args["max_tokens"],
            "temperature": args["temperature"],
            "top_p": args["top_p"],
            "top_k": args["top_k"],
        }
    )

    logger.info(f"The args before invoke bedrock: {args}")
    if args["stream"]:
        try:
            response = client.invoke_model_with_response_stream(
                modelId=model_id,
                body=body,
            )
            # Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/invoke_model_with_response_stream.html
            # The streaming response is returned as-is for the caller to iterate.
            response_body = response
        except Exception as e:
            logger.error(e)
    else:
        response = client.invoke_model(
            modelId=model_id,
            body=body,
        )
        # Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/invoke_model.html
        response_body = json.loads(response.get("body").read())
    # Token counts are read from the x-amzn-bedrock-*-token-count response headers.
    invocation_metrics = InvocationMetrics(
        input_tokens=response["ResponseMetadata"]["HTTPHeaders"][
            "x-amzn-bedrock-input-token-count"
        ],
        output_tokens=response["ResponseMetadata"]["HTTPHeaders"][
            "x-amzn-bedrock-output-token-count"
        ],
    )
    response_body["amazon-bedrock-invocationMetrics"] = invocation_metrics
    return response_body
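
For orientation, here is a hedged sketch of how these pieces fit together for a Mistral model (hypothetical calling code, not part of this commit; the real call sites are in `backend/app/usecases/chat.py` below). The `args` dict mirrors what `compose_args_for_other_client` produces: `MISTRAL_GENERATION_CONFIG` spread in, the resolved model id, and the accumulated messages.

```python
# Hypothetical usage sketch (not from the diff): invoking a Mistral model
# through get_bedrock_response. The args dict matches the shape built by
# compose_args / compose_args_for_other_client above.
from app.bedrock import get_bedrock_response

args = {
    "model": "mistral.mistral-7b-instruct-v0:2",  # get_model_id("mistral-7b-instruct")
    "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
    "max_tokens": 4096,
    "temperature": 0.5,
    "top_p": 0.9,
    "top_k": 50,
    "stream": False,
}
response = get_bedrock_response(args)
print(response["outputs"][0]["text"])                # assistant reply
print(response["amazon-bedrock-invocationMetrics"])  # InvocationMetrics(input_tokens=..., output_tokens=...)
```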
20 changes: 19 additions & 1 deletion backend/app/config.py
@@ -26,6 +26,15 @@ class EmbeddingConfig(TypedDict):
"stop_sequences": ["Human: ", "Assistant: "],
}

# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html#model-parameters-mistral-request-response
MISTRAL_GENERATION_CONFIG: GenerationConfig = {
    "max_tokens": 4096,
    "top_k": 50,
    "top_p": 0.9,
    "temperature": 0.5,
    "stop_sequences": ["[INST]", "[/INST]"],
}

# Configure embedding parameter.
DEFAULT_EMBEDDING_CONFIG: EmbeddingConfig = {
    # DO NOT change `model_id` (currently other models are not supported)
@@ -43,7 +52,7 @@ class EmbeddingConfig(TypedDict):
# Used for price estimation.
# NOTE: The following is based on 2024-03-07
# See: https://aws.amazon.com/bedrock/pricing/
ANTHROPIC_PRICING = {
BEDROCK_PRICING = {
    "us-east-1": {
        "claude-instant-v1": {
            "input": 0.00080,
@@ -55,6 +64,9 @@ class EmbeddingConfig(TypedDict):
        },
        "claude-v3-haiku": {"input": 0.00025, "output": 0.00125},
        "claude-v3-sonnet": {"input": 0.00300, "output": 0.01500},
        "mistral-7b-instruct": {"input": 0.00015, "output": 0.0002},
        "mixtral-8x7b-instruct": {"input": 0.00045, "output": 0.0007},
        "mistral-large": {"input": 0.008, "output": 0.024},
    },
    "us-west-2": {
        "claude-instant-v1": {
@@ -67,6 +79,9 @@ class EmbeddingConfig(TypedDict):
        },
        "claude-v3-sonnet": {"input": 0.00300, "output": 0.01500},
        "claude-v3-opus": {"input": 0.01500, "output": 0.07500},
        "mistral-7b-instruct": {"input": 0.00015, "output": 0.0002},
        "mixtral-8x7b-instruct": {"input": 0.00045, "output": 0.0007},
        "mistral-large": {"input": 0.008, "output": 0.024},
    },
    "ap-northeast-1": {
        "claude-instant-v1": {
@@ -90,5 +105,8 @@ class EmbeddingConfig(TypedDict):
"claude-v3-haiku": {"input": 0.00025, "output": 0.00125},
"claude-v3-sonnet": {"input": 0.00300, "output": 0.01500},
"claude-v3-opus": {"input": 0.01500, "output": 0.07500},
"mistral-7b-instruct": {"input": 0.00015, "output": 0.0002},
"mixtral-8x7b-instruct": {"input": 0.00045, "output": 0.0007},
"mistral-large": {"input": 0.008, "output": 0.024},
},
}
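
As a sanity check on the table, `calculate_price` in `bedrock.py` treats these values as USD per 1,000 tokens. A quick worked example with illustrative token counts:

```python
# Sketch: estimated cost of one mistral-large call in us-east-1, following
# calculate_price's formula (rates above are USD per 1K tokens).
input_tokens, output_tokens = 1200, 300
input_rate, output_rate = 0.008, 0.024  # the "mistral-large" entry above
price = input_rate * input_tokens / 1000.0 + output_rate * output_tokens / 1000.0
print(f"${price:.4f}")  # $0.0168
```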
3 changes: 3 additions & 0 deletions backend/app/routes/schemas/conversation.py
@@ -9,6 +9,9 @@
"claude-v3-sonnet",
"claude-v3-haiku",
"claude-v3-opus",
"mistral-7b-instruct",
"mixtral-8x7b-instruct",
"mistral-large",
]


54 changes: 42 additions & 12 deletions backend/app/usecases/chat.py
@@ -1,11 +1,25 @@
from ulid import ULID
from app.vector_search import SearchResult, get_source_link, search_related_docs
from app.utils import (
    get_bedrock_client,
    get_anthropic_client,
    get_current_time,
    is_running_on_lambda,
    is_anthropic_model,
)
import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import Literal

from anthropic.types import Message as AnthropicMessage
from app.bedrock import calculate_price, compose_args_for_anthropic_client, get_model_id
from app.bedrock import (
    calculate_price,
    compose_args,
    get_bedrock_response,
    InvocationMetrics,
)
from app.config import GENERATION_CONFIG, SEARCH_CONFIG
from app.repositories.conversation import (
    RecordNotFoundError,
@@ -294,7 +308,7 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
    messages.append(chat_input.message)  # type: ignore

    # Create payload to invoke Bedrock
    args = compose_args_for_anthropic_client(
    args = compose_args(
        messages=messages,
        model=chat_input.message.model,
        instruction=(
@@ -303,8 +317,14 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
            else None
        ),
    )
    response: AnthropicMessage = client.messages.create(**args)
    reply_txt = response.content[0].text

    if is_anthropic_model(args["model"]):
        client = get_anthropic_client()
        response: AnthropicMessage = client.messages.create(**args)
        reply_txt = response.content[0].text
    else:
        response = get_bedrock_response(args)  # type: ignore
        reply_txt = response["outputs"][0]["text"]  # type: ignore

    # Used chunks for RAG generation
    used_chunks = None
@@ -333,11 +353,14 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
    conversation.message_map[user_msg_id].children.append(assistant_msg_id)
    conversation.last_message_id = assistant_msg_id

    # Update total pricing
    input_tokens = response.usage.input_tokens
    output_tokens = response.usage.output_tokens

    logger.debug(f"Input tokens: {input_tokens}, Output tokens: {output_tokens}")
    if is_anthropic_model(args["model"]):
        # Update total pricing
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
    else:
        metrics: InvocationMetrics = response["amazon-bedrock-invocationMetrics"]  # type: ignore
        input_tokens = metrics.input_tokens
        output_tokens = metrics.output_tokens

    price = calculate_price(chat_input.message.model, input_tokens, output_tokens)
    conversation.total_price += price
@@ -395,6 +418,9 @@ def propose_conversation_title(
"claude-v3-opus",
"claude-v3-sonnet",
"claude-v3-haiku",
"mistral-7b-instruct",
"mixtral-8x7b-instruct",
"mistral-large",
] = "claude-v3-haiku",
) -> str:
PROMPT = """Reading the conversation above, what is the appropriate title for the conversation? When answering the title, please follow the rules below:
@@ -433,12 +459,16 @@ def propose_conversation_title(
    messages.append(new_message)

    # Invoke Bedrock
    args = compose_args_for_anthropic_client(
    args = compose_args(
        messages=messages,
        model=model,
    )
    response = client.messages.create(**args)
    reply_txt = response.content[0].text
    if is_anthropic_model(args["model"]):
        response = client.messages.create(**args)
        reply_txt = response.content[0].text
    else:
        response: AnthropicMessage = get_bedrock_response(args)["outputs"][0]  # type: ignore[no-redef]
        reply_txt = response["text"]
    return reply_txt


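The dispatch pattern the diff adds to `chat()` and `propose_conversation_title()` can be summarized as follows (an illustrative sketch under the diff's names, not code from the commit): branch on `args["model"]`, and on the Mistral path read token counts from the attached `InvocationMetrics` rather than the Anthropic SDK's `usage` object.

```python
# Illustrative sketch of the provider dispatch introduced above.
from app.bedrock import InvocationMetrics, get_bedrock_response
from app.utils import get_anthropic_client, is_anthropic_model


def invoke_and_count(args: dict) -> tuple[str, int, int]:
    """Return (reply_text, input_tokens, output_tokens) for either provider."""
    if is_anthropic_model(args["model"]):
        response = get_anthropic_client().messages.create(**args)
        return (
            response.content[0].text,
            response.usage.input_tokens,
            response.usage.output_tokens,
        )
    response = get_bedrock_response(args)
    metrics: InvocationMetrics = response["amazon-bedrock-invocationMetrics"]
    return response["outputs"][0]["text"], metrics.input_tokens, metrics.output_tokens
```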
4 changes: 4 additions & 0 deletions backend/app/utils.py
@@ -19,6 +19,10 @@ def is_running_on_lambda():
return "AWS_EXECUTION_ENV" in os.environ


def is_anthropic_model(model_id: str) -> bool:
    return model_id.startswith("anthropic") or False


def get_bedrock_client(region=BEDROCK_REGION):
    client = boto3.client("bedrock-runtime", region)
    return client
Expand Down
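
The split between the two client paths ultimately rests on the Bedrock model-id prefix, so the helper composes naturally with `get_model_id` from `bedrock.py`:

```python
from app.bedrock import get_model_id
from app.utils import is_anthropic_model

assert is_anthropic_model(get_model_id("claude-v3-haiku"))    # "anthropic.claude-3-haiku-..."
assert not is_anthropic_model(get_model_id("mistral-large"))  # "mistral.mistral-large-2402-v1:0"
```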