test: Format + add test
charles-marion committed Sep 4, 2024
1 parent 1ff4251 commit c7c605e
Showing 11 changed files with 135 additions and 68 deletions.
3 changes: 1 addition & 2 deletions cli/magic-config.ts
@@ -831,8 +831,7 @@ async function processCreateOptions(options: any): Promise<void> {
     {
       type: "confirm",
       name: "advancedMonitoring",
-      message:
-        "Do you want to enable custom metrics and advanced monitoring?",
+      message: "Do you want to enable custom metrics and advanced monitoring?",
       initial: options.advancedMonitoring || false,
     },
     {
6 changes: 6 additions & 0 deletions integtests/chatbot-api/session_test.py
@@ -38,16 +38,22 @@ def test_create_session(client, default_model, default_provider, session_id):
             break
 
     assert found == True
+
+    assert sessionFound.get("title") == request.get("data").get("text")
 
 
 def test_get_session(client, session_id, default_model):
     session = client.get_session(session_id)
     assert session.get("id") == session_id
     assert session.get("title") == "test"
     assert len(session.get("history")) == 2
     assert session.get("history")[0].get("type") == "human"
     assert session.get("history")[1].get("type") == "ai"
+    assert session.get("history")[1].get("metadata") is not None
+    metadata = json.loads(session.get("history")[1].get("metadata"))
+    assert metadata.get("usage") is not None
+    assert metadata.get("usage").get("total_tokens") > 0
 
 
 def test_delete_session(client, session_id):
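The new assertions read the AI turn through the session API: each history entry carries a metadata field that the request handler stores as a JSON string, including the token usage gathered by the callback handler (see the adapter changes below). A minimal sketch of the shape the test expects, with illustrative values only:

    import json

    # Hypothetical history entry as returned by client.get_session(...);
    # field values are illustrative, not taken from the project.
    ai_turn = {
        "type": "ai",
        "metadata": json.dumps(
            {"usage": {"input_tokens": 12, "output_tokens": 5, "total_tokens": 17}}
        ),
    }

    usage = json.loads(ai_turn["metadata"])["usage"]
    assert usage["total_tokens"] > 0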
@@ -24,7 +24,6 @@
 from langchain_core.messages.ai import AIMessage, AIMessageChunk
 from langchain_core.messages.human import HumanMessage
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain import hub
 
 logger = Logger()
 
@@ -53,7 +52,7 @@ def on_llm_end(
             and isinstance(generation, ChatGeneration)
             and isinstance(generation.message, AIMessage)
         ):
-            ## In case of rag there could be 2 llm calls.
+            # In case of rag there could be 2 llm calls.
             if self.usage is None:
                 self.usage = {
                     "input_tokens": 0,
@@ -149,29 +148,30 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):
 
         if workspace_id:
             retriever = WorkspaceRetriever(workspace_id=workspace_id)
-            ## Only stream the last llm call (otherwise the internal llm response will be visible)
+            # Only stream the last llm call (otherwise the internal
+            # llm response will be visible)
             llm_without_streaming = self.get_llm({"streaming": False})
             history_aware_retriever = create_history_aware_retriever(
                 llm_without_streaming,
                 retriever,
                 self.get_condense_question_prompt(),
             )
             question_answer_chain = create_stuff_documents_chain(
-                self.llm, self.get_qa_prompt(),
+                self.llm,
+                self.get_qa_prompt(),
             )
             chain = create_retrieval_chain(
                 history_aware_retriever, question_answer_chain
             )
         else:
             chain = self.get_prompt() | self.llm
 
-
         conversation = RunnableWithMessageHistory(
             chain,
             lambda session_id: self.chat_history,
             history_messages_key="chat_history",
             input_messages_key="input",
-            output_messages_key="output"
+            output_messages_key="output",
         )
 
         config = {"configurable": {"session_id": self.session_id}}
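For context, RunnableWithMessageHistory (from langchain_core.runnables.history) wraps a chain so that prior turns are injected under history_messages_key and each new exchange is appended to the history object it is given. A self-contained sketch of the wiring above, assuming an in-memory history and a stand-in echo chain rather than the project's retrieval chain:

    from langchain_core.chat_history import InMemoryChatMessageHistory
    from langchain_core.runnables import RunnableLambda
    from langchain_core.runnables.history import RunnableWithMessageHistory

    history = InMemoryChatMessageHistory()

    # Stand-in chain; the real code uses prompt | llm or a retrieval chain.
    chain = RunnableLambda(lambda x: {"output": f"echo: {x['input']}"})

    conversation = RunnableWithMessageHistory(
        chain,
        lambda session_id: history,  # real code: one history per session_id
        history_messages_key="chat_history",
        input_messages_key="input",
        output_messages_key="output",
    )

    result = conversation.invoke(
        {"input": "hi"}, config={"configurable": {"session_id": "s1"}}
    )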
@@ -212,7 +212,7 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):
                 }
                 for doc in retriever.get_last_search_documents()
             ]
-
+
             metadata = {
                 "modelId": self.model_id,
                 "modelKwargs": self.model_kwargs,
@@ -233,7 +233,8 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):
             # Used by Cloudwatch filters to generate a metric of token usage.
             logger.info(
                 "Usage Metric",
-                # Each unique value of model id will create a new cloudwatch metric (each one has a cost)
+                # Each unique value of model id will create a
+                # new cloudwatch metric (each one has a cost)
                 model=self.model_id,
                 metric_type="token_usage",
                 value=self.callback_handler.usage.get("total_tokens"),
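This logger.info call is the producer side of the TokenUsage metric wired up in lib/monitoring/index.ts below: Lambda Powertools emits one JSON object per log line, so a CloudWatch metric filter can select events where $.metric_type equals "token_usage", read the numeric $.value, and split the series by $.model. Roughly what such a log event looks like (field values illustrative, not captured from the project):

    import json

    event = {
        "level": "INFO",
        "message": "Usage Metric",
        "model": "anthropic.claude-3-sonnet",  # illustrative model id
        "metric_type": "token_usage",          # matched by the filter pattern
        "value": 17,                           # extracted as the metric value
    }
    print(json.dumps(event))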
@@ -17,7 +17,8 @@
 from ..base import ModelAdapter
 import genai_core.clients
 from langchain_aws import ChatBedrockConverse
-from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+
 
 def get_guardrails() -> dict:
     if "BEDROCK_GUARDRAILS_ID" in os.environ:
@@ -33,9 +34,13 @@ def __init__(self, model_id, *args, **kwargs):
         self.model_id = model_id
 
         super().__init__(*args, **kwargs)
 
     def get_qa_prompt(self):
-        system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \n\n{context}"
+        system_prompt = (
+            "Use the following pieces of context to answer the question at the end."
+            " If you don't know the answer, just say that you don't know, "
+            "don't try to make up an answer. \n\n{context}"
+        )
         return ChatPromptTemplate.from_messages(
             [
                 ("system", system_prompt),
@@ -49,7 +54,12 @@ def get_prompt(self):
             [
                 (
                     "system",
-                    "The following is a friendly conversation between a human and an AI. If the AI does not know the answer to a question, it truthfully says it does not know.",
+                    (
+                        "The following is a friendly conversation between "
+                        "a human and an AI."
+                        "If the AI does not know the answer to a question, it "
+                        "truthfully says it does not know."
+                    ),
                 ),
                 MessagesPlaceholder(variable_name="chat_history"),
                 ("human", "{input}"),
@@ -60,7 +70,8 @@ def get_prompt(self):
 
     def get_condense_question_prompt(self):
         contextualize_q_system_prompt = (
-            "Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question."
+            "Given the following conversation and a follow up"
+            " question, rephrase the follow up question to be a standalone question."
         )
         return ChatPromptTemplate.from_messages(
             [
@@ -90,9 +101,10 @@ def get_llm(self, model_kwargs={}, extra={}):
             disable_streaming=model_kwargs.get("streaming", False) == False,
             callbacks=[self.callback_handler],
             **params,
-            **extra
+            **extra,
         )
 
+
 class LLMInputOutputAdapter:
     """Adapter class to prepare the inputs from Langchain to a format
     that LLM model expects.
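One thing to watch when reflowing long strings this way: Python joins adjacent string literals with no separator, so each fragment must carry its own spacing. In the reflowed get_prompt system message above, "...between a human and an AI." is immediately followed by "If the AI...", which concatenates to "an AI.If the AI"; the space has to live inside one of the fragments:

    # Adjacent literals concatenate exactly as written:
    s = "a human and an AI." "If the AI does not know"
    assert s == "a human and an AI.If the AI does not know"  # no space inserted

    # Keep an explicit space in one fragment instead:
    s = "a human and an AI. " "If the AI does not know"
    assert s == "a human and an AI. If the AI does not know"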
@@ -23,9 +23,12 @@
 
 sequence_number = 0
 
-def on_llm_new_token(user_id, session_id, self, token, run_id, chunk, parent_run_id, *args, **kwargs):
-
+
+def on_llm_new_token(
+    user_id, session_id, self, token, run_id, chunk, parent_run_id, *args, **kwargs
+):
     if isinstance(token, list):
-        # When using the newer Chat objects from Langchain.
+        # When using the newer Chat objects from Langchain.
+        # Token is not a string
         text = ""
         for t in token:
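The new comment reflects that, with LangChain's newer chat interfaces, a streamed token can arrive as a list of content blocks rather than a plain string. A hedged sketch of the normalization the truncated body presumably performs (the block shape is assumed from LangChain's content-block convention):

    def token_to_text(token):
        # Newer chat models stream a list of content blocks.
        if isinstance(token, list):
            return "".join(
                t.get("text", "") for t in token if isinstance(t, dict)
            )
        return token  # older providers stream plain strings

    assert token_to_text([{"type": "text", "text": "Hi"}]) == "Hi"
    assert token_to_text("Hi") == "Hi"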
37 changes: 26 additions & 11 deletions lib/monitoring/index.ts
@@ -1,6 +1,10 @@
 import { Stack } from "aws-cdk-lib";
 import { IGraphqlApi } from "aws-cdk-lib/aws-appsync";
-import { LogQueryWidget, MathExpression, Metric } from "aws-cdk-lib/aws-cloudwatch";
+import {
+  LogQueryWidget,
+  MathExpression,
+  Metric,
+} from "aws-cdk-lib/aws-cloudwatch";
 import { ITable } from "aws-cdk-lib/aws-dynamodb";
 import { IFunction as ILambdaFunction } from "aws-cdk-lib/aws-lambda";
 import { CfnCollection } from "aws-cdk-lib/aws-opensearchserverless";
@@ -74,7 +78,11 @@ export class Monitoring extends Construct {
     );
 
     if (props.advancedMonitoring) {
-      this.addMetricFilter(props.prefix + "GenAI", monitoring, props.llmRequestHandlersLogGroups);
+      this.addMetricFilter(
+        props.prefix + "GenAI",
+        monitoring,
+        props.llmRequestHandlersLogGroups
+      );
     }
 
     const link = `https://${region}.console.aws.amazon.com/cognito/v2/idp/user-pools/${props.cognito.userPoolId}/users?region=${region}`;
@@ -152,17 +160,25 @@ export class Monitoring extends Construct {
     }
   }
 
-  private addMetricFilter(namespace: string, monitoring: MonitoringFacade, logGroups: ILogGroup[]) {
+  private addMetricFilter(
+    namespace: string,
+    monitoring: MonitoringFacade,
+    logGroups: ILogGroup[]
+  ) {
     for (const logGroupKey in logGroups) {
-      new MetricFilter(this, 'UsageFilter' + logGroupKey, {
+      new MetricFilter(this, "UsageFilter" + logGroupKey, {
         logGroup: logGroups[logGroupKey],
         metricNamespace: namespace,
-        metricName: 'TokenUsage',
-        filterPattern: FilterPattern.stringValue('$.metric_type', "=", "token_usage"),
-        metricValue: '$.value',
+        metricName: "TokenUsage",
+        filterPattern: FilterPattern.stringValue(
+          "$.metric_type",
+          "=",
+          "token_usage"
+        ),
+        metricValue: "$.value",
         dimensions: {
-          "model": "$.model"
-        }
+          model: "$.model",
+        },
       });
     }
 
@@ -194,7 +210,6 @@ export class Monitoring extends Construct {
         },
       ],
     });
-
   }
 
   private addCognitoMetrics(
@@ -369,7 +384,7 @@ export class Monitoring extends Construct {
       */
       queryLines: [
         "fields @timestamp, message, level, location" +
-        (extraFields.length > 0 ? "," + extraFields.join(",") : ""),
+          (extraFields.length > 0 ? "," + extraFields.join(",") : ""),
         `filter ispresent(level)`, // only includes messages using the logger
         "sort @timestamp desc",
         `limit 200`,
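For reference, FilterPattern.stringValue("$.metric_type", "=", "token_usage") compiles to the CloudWatch Logs pattern { $.metric_type = "token_usage" }; every matching JSON log event contributes one data point whose value comes from $.value and whose model dimension comes from $.model, so each distinct model id becomes its own time series (the cost caveat in the adapter comment above). A sketch of the equivalent construct through the Python CDK bindings, assuming jsii's usual snake_case translation and placeholder names:

    from aws_cdk import App, Stack, aws_logs as logs

    app = App()
    stack = Stack(app, "MonitoringSketch")           # placeholder stack
    log_group = logs.LogGroup(stack, "HandlerLogs")  # placeholder log group

    logs.MetricFilter(
        stack, "UsageFilter",
        log_group=log_group,
        metric_namespace="GenAI",  # real code: props.prefix + "GenAI"
        metric_name="TokenUsage",
        # Matches JSON events like {"metric_type": "token_usage", "value": 17, ...}
        filter_pattern=logs.FilterPattern.string_value(
            "$.metric_type", "=", "token_usage"
        ),
        metric_value="$.value",
        dimensions={"model": "$.model"},  # one series per model id
    )
    app.synth()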
@@ -56,7 +56,7 @@ def add_message(self, message: BaseMessage) -> None:
         """Append the message to the record in DynamoDB"""
         messages = messages_to_dict(self.messages)
         if isinstance(message, AIMessageChunk):
-            # When streaming with RunnableWithMessageHistory,
+            # When streaming with RunnableWithMessageHistory,
             # it would add a chunk to the history but it expects a text as content.
             ai_message = ""
             for c in message.content:
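The comment concerns streaming: RunnableWithMessageHistory appends the accumulated AIMessageChunk to history, and a chunk's content may be a list of blocks while the DynamoDB record expects plain text. A sketch of the flattening the truncated loop presumably performs (block shape assumed):

    from langchain_core.messages.ai import AIMessageChunk

    chunk = AIMessageChunk(content=[{"type": "text", "text": "Hello"}])

    # Flatten list-of-blocks content into the plain string DynamoDB stores.
    ai_message = ""
    if isinstance(chunk.content, list):
        for c in chunk.content:
            if isinstance(c, dict):
                ai_message += c.get("text", "")
    else:
        ai_message = chunk.content

    assert ai_message == "Hello"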
@@ -6,13 +6,14 @@
 
 logger = Logger()
 
+
 class WorkspaceRetriever(BaseRetriever):
     workspace_id: str
     documents_found: List[Document] = []
 
-    def get_last_search_documents(self) -> List[Document]:
+    def get_last_search_documents(self) -> List[Document]:
         return self.documents_found
 
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
@@ -21,9 +22,11 @@ def _get_relevant_documents(
             self.workspace_id, query, limit=3, full_response=False
         )
 
-        self.documents_found = [self._get_document(item) for item in result.get("items", [])]
+        self.documents_found = [
+            self._get_document(item) for item in result.get("items", [])
+        ]
         return self.documents_found
 
     def _get_document(self, item):
         content = item["content"]
         content_complement = item.get("content_complement")
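documents_found keeps the last search results on the retriever so run_with_chain_v2 can attach them to the response metadata after the chain finishes. A minimal sketch of that pattern on LangChain's BaseRetriever, with a stand-in search instead of the workspace semantic search:

    from typing import List

    from langchain_core.callbacks import CallbackManagerForRetrieverRun
    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever


    class CachingRetriever(BaseRetriever):
        documents_found: List[Document] = []

        def get_last_search_documents(self) -> List[Document]:
            # Lets the caller read what the last query returned.
            return self.documents_found

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            # Stand-in search; the real code queries the workspace index.
            self.documents_found = [Document(page_content=f"match for {query!r}")]
            return self.documents_found


    r = CachingRetriever()
    docs = r.invoke("alpha")
    assert r.get_last_search_documents() == docs

Caching results on the instance makes the retriever stateful, so one instance should not be shared across concurrent requests.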
5 changes: 4 additions & 1 deletion package.json
@@ -8,9 +8,11 @@
     "build": "npx @aws-amplify/cli codegen && npx tsc",
     "watch": "npx tsc -w",
     "cdk": "cdk",
     "deploy": "npx cdk deploy",
     "hotswap": "cdk deploy --hotswap",
     "test": "jest",
+    "pytest": "pytest tests/",
+    "test-all": "npm run test && npm run pytest",
     "integtest": "pytest integtests/",
     "gen": "npx @aws-amplify/cli codegen",
     "create": "node ./dist/cli/magic.js config",
@@ -19,7 +21,8 @@
     "pylint": "flake8 .",
     "format": "npx prettier --ignore-path .gitignore --write \"**/*.+(js|ts|jsx|tsx|json|css)\"",
     "pyformat": "black .",
-    "deploy": "npm run format && npx cdk deploy",
+    "format-lint-all": "npm run format && npm run pyformat && npm run lint && npm run pylint",
+    "vet-all": "npm run format-lint-all && npm run test-all",
     "docs:dev": "vitepress dev docs",
     "docs:build": "vitepress build docs",
     "docs:preview": "vitepress preview docs"
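Taken together, the script changes add composite targets: test-all runs the Jest and pytest suites, format-lint-all chains both formatters and linters, and vet-all runs formatting, linting, and all tests in one go. The old deploy entry that ran format first is deleted in the second hunk, leaving deploy as a plain npx cdk deploy.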