
Commit 59b946a

Merge branch 'main' into fix-chatbedrock
SmartManoj committed Feb 12, 2025
2 parents 50d0f7f + 925cd56 commit 59b946a
Showing 21 changed files with 1,167 additions and 77 deletions.
21 changes: 21 additions & 0 deletions libs/aws/langchain_aws/__init__.py
@@ -3,6 +3,7 @@
create_neptune_sparql_qa_chain,
)
from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse
from langchain_aws.document_compressors.rerank import BedrockRerank
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from langchain_aws.llms import BedrockLLM, SagemakerEndpoint
@@ -15,6 +16,25 @@
InMemoryVectorStore,
)


def setup_logging():
import logging
import os

if os.environ.get("LANGCHAIN_AWS_DEBUG", "FALSE").lower() in ["true", "1"]:
DEFAULT_LOG_FORMAT = (
"%(asctime)s %(levelname)s | [%(filename)s:%(lineno)s]"
"| %(name)s - %(message)s"
)
log_formatter = logging.Formatter(DEFAULT_LOG_FORMAT)
log_handler = logging.StreamHandler()
log_handler.setFormatter(log_formatter)
logging.getLogger("langchain_aws").addHandler(log_handler)
logging.getLogger("langchain_aws").setLevel(logging.DEBUG)


setup_logging()

__all__ = [
"BedrockEmbeddings",
"BedrockLLM",
@@ -29,4 +49,5 @@
"NeptuneGraph",
"InMemoryVectorStore",
"InMemorySemanticCache",
"BedrockRerank"
]
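
For reference, a minimal usage sketch for the new LANGCHAIN_AWS_DEBUG hook above (the usage pattern is an assumption, but it follows directly from setup_logging): since setup_logging() runs at import time, the variable must be set before langchain_aws is first imported.

import logging
import os

# Must be set before the first `import langchain_aws`; setup_logging()
# lowercases the value, so "true", "True", and "1" all enable debug output.
os.environ["LANGCHAIN_AWS_DEBUG"] = "true"

import langchain_aws  # noqa: E402  (importing runs setup_logging())

assert logging.getLogger("langchain_aws").level == logging.DEBUG
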
5 changes: 1 addition & 4 deletions libs/aws/langchain_aws/chains/__init__.py
@@ -3,7 +3,4 @@
create_neptune_sparql_qa_chain,
)

__all__ = [
"create_neptune_opencypher_qa_chain",
"create_neptune_sparql_qa_chain"
]
__all__ = ["create_neptune_opencypher_qa_chain", "create_neptune_sparql_qa_chain"]
5 changes: 1 addition & 4 deletions libs/aws/langchain_aws/chains/graph_qa/__init__.py
@@ -1,7 +1,4 @@
from .neptune_cypher import create_neptune_opencypher_qa_chain
from .neptune_sparql import create_neptune_sparql_qa_chain

__all__ = [
"create_neptune_opencypher_qa_chain",
"create_neptune_sparql_qa_chain"
]
__all__ = ["create_neptune_opencypher_qa_chain", "create_neptune_sparql_qa_chain"]
2 changes: 1 addition & 1 deletion libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import Any, Optional, Union
from typing import Optional, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
37 changes: 27 additions & 10 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,4 +1,6 @@
import logging
import re
import warnings
from collections import defaultdict
from operator import itemgetter
from typing import (
@@ -50,10 +52,13 @@
_combine_generation_info_for_llm_result,
)
from langchain_aws.utils import (
anthropic_tokens_supported,
get_num_tokens_anthropic,
get_token_ids_anthropic,
)

logger = logging.getLogger(__name__)


def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
@@ -524,6 +529,7 @@ def _generate(
return self._as_converse._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
logger.info(f"The input message: {messages}")
completion = ""
llm_output: Dict[str, Any] = {}
tool_calls: List[ToolCall] = []
@@ -586,16 +592,14 @@ def _generate(
)
else:
usage_metadata = None

logger.info(f"The message received from Bedrock: {completion}")
llm_output["model_id"] = self.model_id

msg = AIMessage(
content=completion,
additional_kwargs=llm_output,
tool_calls=cast(List[ToolCall], tool_calls),
usage_metadata=usage_metadata,
)

return ChatResult(
generations=[
ChatGeneration(
@@ -618,16 +622,27 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return final_output

def get_num_tokens(self, text: str) -> int:
if self._model_is_anthropic:
if (
self._model_is_anthropic
and not self.custom_get_token_ids
and anthropic_tokens_supported()
):
return get_num_tokens_anthropic(text)
else:
return super().get_num_tokens(text)
return super().get_num_tokens(text)

def get_token_ids(self, text: str) -> List[int]:
if self._model_is_anthropic:
return get_token_ids_anthropic(text)
else:
return super().get_token_ids(text)
if self._model_is_anthropic and not self.custom_get_token_ids:
if anthropic_tokens_supported():
return get_token_ids_anthropic(text)
else:
warnings.warn(
f"Falling back to default token method due to missing or incompatible `anthropic` installation "
f"(needs <=0.38.0).\n\nIf using `anthropic>0.38.0`, it is recommended to provide the model "
f"class with a custom_get_token_ids method implementing a more accurate tokenizer for Anthropic. "
f"For get_num_tokens, as another alternative, you can implement your own token counter method "
f"using the ChatAnthropic or AnthropicLLM classes."
)
return super().get_token_ids(text)

def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
"""Workaround to bind. Sets the system prompt with tools"""
@@ -844,6 +859,8 @@ def _as_converse(self) -> ChatBedrockConverse:
"top_p",
"additional_model_request_fields",
"additional_model_response_field_paths",
"performance_config",
"request_metadata",
)
}
if self.max_tokens:
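
To illustrate the reworked token-counting path, a hedged sketch (the model id and tokenizer are placeholders, not from this commit). It assumes the standard langchain_core behavior where a custom_get_token_ids callable takes precedence in the base class, which is exactly why the branches above check `not self.custom_get_token_ids`: with anthropic<=0.38.0 installed, Anthropic models keep the Anthropic tokenizer; otherwise ChatBedrock warns and falls back; and a custom callable bypasses the Anthropic path entirely.

from typing import List

from langchain_aws import ChatBedrock

def whitespace_token_ids(text: str) -> List[int]:
    # Stand-in tokenizer: one fake token id per whitespace-separated word.
    return list(range(len(text.split())))

llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # illustrative model id
    region_name="us-east-1",  # assumes boto3 can build a Bedrock client
    custom_get_token_ids=whitespace_token_ids,
)

# Both calls skip the Anthropic branch above and route through the base class,
# which consults custom_get_token_ids.
print(llm.get_token_ids("Hello from Bedrock"))  # [0, 1, 2]
print(llm.get_num_tokens("Hello from Bedrock"))  # 3
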
27 changes: 27 additions & 0 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -1,5 +1,6 @@
import base64
import json
import logging
import os
import re
from operator import itemgetter
@@ -10,6 +11,7 @@
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
@@ -53,7 +55,9 @@

from langchain_aws.function_calling import ToolsOutputParser

logger = logging.getLogger(__name__)
_BM = TypeVar("_BM", bound=BaseModel)

_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]


@@ -393,6 +397,19 @@ class Joke(BaseModel):
('auto') if a 'nova' model is used, empty otherwise.
"""

performance_config: Optional[Mapping[str, Any]] = Field(
default=None,
description="""Performance configuration settings for latency optimization.
Example:
performance_config={'latency': 'optimized'}
If not provided, defaults to standard latency.
""",
)

request_metadata: Optional[Dict[str, str]] = None
"""Key-Value pairs that you can use to filter invocation logs."""

model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
@@ -495,13 +512,19 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
logger.info(f"The input message: {messages}")
bedrock_messages, system = _messages_to_bedrock(messages)
logger.debug(f"input message to bedrock: {bedrock_messages}")
logger.debug(f"System message to bedrock: {system}")
params = self._converse_params(
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"})
)
logger.debug(f"Input params: {params}")
logger.info("Using Bedrock Converse API to generate response")
response = self.client.converse(
messages=bedrock_messages, system=system, **params
)
logger.debug(f"Response from Bedrock: {response}")
response_message = _parse_response(response)
return ChatResult(generations=[ChatGeneration(message=response_message)])

@@ -624,6 +647,8 @@ def _converse_params(
additionalModelRequestFields: Optional[dict] = None,
additionalModelResponseFieldPaths: Optional[List[str]] = None,
guardrailConfig: Optional[dict] = None,
performanceConfig: Optional[Mapping[str, Any]] = None,
requestMetadata: Optional[dict] = None,
) -> Dict[str, Any]:
if not inferenceConfig:
inferenceConfig = {
Expand All @@ -646,6 +671,8 @@ def _converse_params(
"additionalModelResponseFieldPaths": additionalModelResponseFieldPaths
or self.additional_model_response_field_paths,
"guardrailConfig": guardrailConfig or self.guardrail_config,
"performanceConfig": performanceConfig or self.performance_config,
"requestMetadata": requestMetadata or self.request_metadata,
}
)

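
A short sketch of the two new ChatBedrockConverse fields added above (the model id and metadata values are placeholders, and a working AWS setup is assumed). Both fields are forwarded to client.converse() via _converse_params as performanceConfig and requestMetadata.

from langchain_aws import ChatBedrockConverse

llm = ChatBedrockConverse(
    model="us.amazon.nova-lite-v1:0",  # placeholder model id
    region_name="us-east-1",
    # Forwarded as performanceConfig; omit it to keep standard latency.
    performance_config={"latency": "optimized"},
    # Forwarded as requestMetadata; key-value pairs for filtering invocation logs.
    request_metadata={"project": "demo", "env": "dev"},
)

print(llm.invoke("Hello!").content)
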
Empty file.
134 changes: 134 additions & 0 deletions libs/aws/langchain_aws/document_compressors/rerank.py
@@ -0,0 +1,134 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

import boto3
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils import from_env
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self


class BedrockRerank(BaseDocumentCompressor):
"""Document compressor that uses AWS Bedrock Rerank API."""

model_arn: str
"""The ARN of the reranker model."""
client: Any = None
"""Bedrock client to use for compressing documents."""
top_n: Optional[int] = 3
"""Number of documents to return."""
    region_name: Optional[str] = Field(
        default_factory=from_env("AWS_DEFAULT_REGION", default=None)
    )
"""AWS region to initialize the Bedrock client."""
credentials_profile_name: Optional[str] = Field(
default_factory=from_env("AWS_PROFILE", default=None)
)
"""AWS profile for authentication, optional."""

model_config = ConfigDict(
extra="forbid",
arbitrary_types_allowed=True,
)

@model_validator(mode="before")
@classmethod
def initialize_client(cls, values: Dict[str, Any]) -> Any:
"""Initialize the AWS Bedrock client."""
if not values.get("client"):
session = (
boto3.Session(profile_name=values.get("credentials_profile_name"))
if values.get("credentials_profile_name", None)
else boto3.Session()
)
values["client"] = session.client(
"bedrock-agent-runtime",
region_name=values.get("region_name"),
)
return values

def rerank(
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
top_n: Optional[int] = None,
additional_model_request_fields: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Returns an ordered list of documents based on their relevance to the query.
Args:
query: The query to use for reranking.
documents: A sequence of documents to rerank.
top_n: The number of top-ranked results to return. Defaults to self.top_n.
additional_model_request_fields: A dictionary of additional fields to pass to the model.
Returns:
List[Dict[str, Any]]: A list of ranked documents with relevance scores.
"""
if len(documents) == 0:
return []

# Serialize documents for the Bedrock API
serialized_documents = [
{"textDocument": {"text": doc.page_content}, "type": "TEXT"}
if isinstance(doc, Document)
else {"textDocument": {"text": doc}, "type": "TEXT"}
if isinstance(doc, str)
else {"jsonDocument": doc, "type": "JSON"}
for doc in documents
]

request_body = {
"queries": [{"textQuery": {"text": query}, "type": "TEXT"}],
"rerankingConfiguration": {
"bedrockRerankingConfiguration": {
"modelConfiguration": {
"modelArn": self.model_arn,
"additionalModelRequestFields": additional_model_request_fields
or {},
},
"numberOfResults": top_n or self.top_n,
},
"type": "BEDROCK_RERANKING_MODEL",
},
"sources": [
{"inlineDocumentSource": doc, "type": "INLINE"}
for doc in serialized_documents
],
}

response = self.client.rerank(**request_body)
response_body = response.get("results", [])

results = [
{"index": result["index"], "relevance_score": result["relevanceScore"]}
for result in response_body
]

return results

def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
Compress documents using Bedrock's rerank API.
Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.
Returns:
A sequence of compressed documents.
"""
compressed = []
for res in self.rerank(documents, query):
doc = documents[res["index"]]
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
doc_copy.metadata["relevance_score"] = res["relevance_score"]
compressed.append(doc_copy)
return compressed
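
Finally, an end-to-end sketch of the new BedrockRerank compressor (the model ARN and region are placeholders, and a rerank-capable Bedrock model is assumed to be available). compress_documents() delegates to rerank() and copies each relevance_score into the returned document's metadata.

from langchain_core.documents import Document

from langchain_aws import BedrockRerank

reranker = BedrockRerank(
    model_arn="arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",  # placeholder
    region_name="us-west-2",
    top_n=2,
)

docs = [
    Document(page_content="Amazon Bedrock exposes a Rerank API."),
    Document(page_content="A recipe for sourdough bread."),
    Document(page_content="Bedrock rerankers score documents against a query."),
]

for doc in reranker.compress_documents(docs, query="How does Bedrock reranking work?"):
    print(doc.metadata["relevance_score"], doc.page_content)
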