diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py
index 29f4bbc6..40e654aa 100644
--- a/libs/aws/langchain_aws/__init__.py
+++ b/libs/aws/langchain_aws/__init__.py
@@ -3,6 +3,7 @@
     create_neptune_sparql_qa_chain,
 )
 from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse
+from langchain_aws.document_compressors.rerank import BedrockRerank
 from langchain_aws.embeddings import BedrockEmbeddings
 from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
 from langchain_aws.llms import BedrockLLM, SagemakerEndpoint
@@ -15,6 +16,25 @@
     InMemoryVectorStore,
 )
 
+
+def setup_logging():
+    import logging
+    import os
+
+    if os.environ.get("LANGCHAIN_AWS_DEBUG", "FALSE").lower() in ["true", "1"]:
+        DEFAULT_LOG_FORMAT = (
+            "%(asctime)s %(levelname)s | [%(filename)s:%(lineno)s]"
+            "| %(name)s - %(message)s"
+        )
+        log_formatter = logging.Formatter(DEFAULT_LOG_FORMAT)
+        log_handler = logging.StreamHandler()
+        log_handler.setFormatter(log_formatter)
+        logging.getLogger("langchain_aws").addHandler(log_handler)
+        logging.getLogger("langchain_aws").setLevel(logging.DEBUG)
+
+
+setup_logging()
+
 __all__ = [
     "BedrockEmbeddings",
     "BedrockLLM",
@@ -29,4 +49,5 @@
     "NeptuneGraph",
     "InMemoryVectorStore",
     "InMemorySemanticCache",
+    "BedrockRerank"
 ]
diff --git a/libs/aws/langchain_aws/chains/__init__.py b/libs/aws/langchain_aws/chains/__init__.py
index 55e7e845..1d05a7db 100644
--- a/libs/aws/langchain_aws/chains/__init__.py
+++ b/libs/aws/langchain_aws/chains/__init__.py
@@ -3,7 +3,4 @@
     create_neptune_sparql_qa_chain,
 )
 
-__all__ = [
-    "create_neptune_opencypher_qa_chain",
-    "create_neptune_sparql_qa_chain"
-]
+__all__ = ["create_neptune_opencypher_qa_chain", "create_neptune_sparql_qa_chain"]
diff --git a/libs/aws/langchain_aws/chains/graph_qa/__init__.py b/libs/aws/langchain_aws/chains/graph_qa/__init__.py
index 357d0360..572d74c1 100644
--- a/libs/aws/langchain_aws/chains/graph_qa/__init__.py
+++ b/libs/aws/langchain_aws/chains/graph_qa/__init__.py
@@ -1,7 +1,4 @@
 from .neptune_cypher import create_neptune_opencypher_qa_chain
 from .neptune_sparql import create_neptune_sparql_qa_chain
 
-__all__ = [
-    "create_neptune_opencypher_qa_chain",
-    "create_neptune_sparql_qa_chain"
-]
+__all__ = ["create_neptune_opencypher_qa_chain", "create_neptune_sparql_qa_chain"]
diff --git a/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
index 8c6a35a3..5af575ea 100644
--- a/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
+++ b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.base import BasePromptTemplate
diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 73e53202..b971a36e 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,4 +1,6 @@
+import logging
 import re
+import warnings
 from collections import defaultdict
 from operator import itemgetter
 from typing import (
@@ -50,10 +52,13 @@
     _combine_generation_info_for_llm_result,
 )
 from langchain_aws.utils import (
+    anthropic_tokens_supported,
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
     if isinstance(message, ChatMessage):
@@ -524,6 +529,7 @@ def _generate(
             return self._as_converse._generate(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             )
+        logger.info(f"The input message: {messages}")
         completion = ""
         llm_output: Dict[str, Any] = {}
         tool_calls: List[ToolCall] = []
@@ -586,16 +592,14 @@ def _generate(
             )
         else:
             usage_metadata = None
-
+        logger.info(f"The message received from Bedrock: {completion}")
         llm_output["model_id"] = self.model_id
-
         msg = AIMessage(
             content=completion,
             additional_kwargs=llm_output,
             tool_calls=cast(List[ToolCall], tool_calls),
             usage_metadata=usage_metadata,
         )
-
         return ChatResult(
             generations=[
                 ChatGeneration(
@@ -618,16 +622,27 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         return final_output
 
     def get_num_tokens(self, text: str) -> int:
-        if self._model_is_anthropic:
+        if (
+            self._model_is_anthropic
+            and not self.custom_get_token_ids
+            and anthropic_tokens_supported()
+        ):
             return get_num_tokens_anthropic(text)
-        else:
-            return super().get_num_tokens(text)
+        return super().get_num_tokens(text)
 
     def get_token_ids(self, text: str) -> List[int]:
-        if self._model_is_anthropic:
-            return get_token_ids_anthropic(text)
-        else:
-            return super().get_token_ids(text)
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            if anthropic_tokens_supported():
+                return get_token_ids_anthropic(text)
+            else:
+                warnings.warn(
+                    "Falling back to default token method due to missing or incompatible `anthropic` installation "
+                    "(needs <=0.38.0).\n\nIf using `anthropic>0.38.0`, it is recommended to provide the model "
+                    "class with a custom_get_token_ids method implementing a more accurate tokenizer for Anthropic. "
+                    "For get_num_tokens, as another alternative, you can implement your own token counter method "
+                    "using the ChatAnthropic or AnthropicLLM classes."
+                )
+        return super().get_token_ids(text)
 
     def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
         """Workaround to bind. Sets the system prompt with tools"""
@@ -844,6 +859,8 @@ def _as_converse(self) -> ChatBedrockConverse:
                 "top_p",
                 "additional_model_request_fields",
                 "additional_model_response_field_paths",
+                "performance_config",
+                "request_metadata",
             )
         }
         if self.max_tokens:
diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
index 1d328656..5a27763b 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -1,5 +1,6 @@
 import base64
 import json
+import logging
 import os
 import re
 from operator import itemgetter
@@ -10,6 +11,7 @@
     Iterator,
     List,
     Literal,
+    Mapping,
     Optional,
     Sequence,
     Tuple,
@@ -53,7 +55,9 @@
 
 from langchain_aws.function_calling import ToolsOutputParser
 
+logger = logging.getLogger(__name__)
 _BM = TypeVar("_BM", bound=BaseModel)
+
 _DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
 
 
@@ -393,6 +397,19 @@ class Joke(BaseModel):
         ('auto') if a 'nova' model is used, empty otherwise.
     """
 
+    performance_config: Optional[Mapping[str, Any]] = Field(
+        default=None,
+        description="""Performance configuration settings for latency optimization.
+
+        Example:
+            performance_config={'latency': 'optimized'}
+        If not provided, defaults to standard latency.
+        """,
+    )
+
+    request_metadata: Optional[Dict[str, str]] = None
+    """Key-value pairs that you can use to filter invocation logs."""
+
     model_config = ConfigDict(
         extra="forbid",
         populate_by_name=True,
@@ -495,13 +512,19 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
+        logger.info(f"The input message: {messages}")
         bedrock_messages, system = _messages_to_bedrock(messages)
+        logger.debug(f"Input message to Bedrock: {bedrock_messages}")
+        logger.debug(f"System message to Bedrock: {system}")
         params = self._converse_params(
             stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"})
         )
+        logger.debug(f"Input params: {params}")
+        logger.info("Using Bedrock Converse API to generate response")
         response = self.client.converse(
             messages=bedrock_messages, system=system, **params
         )
+        logger.debug(f"Response from Bedrock: {response}")
         response_message = _parse_response(response)
         return ChatResult(generations=[ChatGeneration(message=response_message)])
 
@@ -624,6 +647,8 @@ def _converse_params(
         additionalModelRequestFields: Optional[dict] = None,
         additionalModelResponseFieldPaths: Optional[List[str]] = None,
         guardrailConfig: Optional[dict] = None,
+        performanceConfig: Optional[Mapping[str, Any]] = None,
+        requestMetadata: Optional[dict] = None,
     ) -> Dict[str, Any]:
         if not inferenceConfig:
             inferenceConfig = {
@@ -646,6 +671,8 @@ def _converse_params(
                 "additionalModelResponseFieldPaths": additionalModelResponseFieldPaths
                 or self.additional_model_response_field_paths,
                 "guardrailConfig": guardrailConfig or self.guardrail_config,
+                "performanceConfig": performanceConfig or self.performance_config,
+                "requestMetadata": requestMetadata or self.request_metadata,
             }
         )
diff --git a/libs/aws/langchain_aws/document_compressors/__init__.py b/libs/aws/langchain_aws/document_compressors/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/aws/langchain_aws/document_compressors/rerank.py b/libs/aws/langchain_aws/document_compressors/rerank.py
new file mode 100644
index 00000000..bd17ee32
--- /dev/null
+++ b/libs/aws/langchain_aws/document_compressors/rerank.py
@@ -0,0 +1,134 @@
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+import boto3
+from langchain_core.callbacks.manager import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.utils import from_env
+from pydantic import ConfigDict, Field, model_validator
+from typing_extensions import Self
+
+
+class BedrockRerank(BaseDocumentCompressor):
+    """Document compressor that uses the AWS Bedrock Rerank API."""
+
+    model_arn: str
+    """The ARN of the reranker model."""
+    client: Any = None
+    """Bedrock client to use for compressing documents."""
+    top_n: Optional[int] = 3
+    """Number of documents to return."""
+    region_name: Optional[str] = Field(
+        default_factory=from_env("AWS_DEFAULT_REGION", default=None)
+    )
+    """AWS region to initialize the Bedrock client."""
+    credentials_profile_name: Optional[str] = Field(
+        default_factory=from_env("AWS_PROFILE", default=None)
+    )
+    """AWS profile for authentication, optional."""
+
+    model_config = ConfigDict(
+        extra="forbid",
+        arbitrary_types_allowed=True,
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def initialize_client(cls, values: Dict[str, Any]) -> Any:
+        """Initialize the AWS Bedrock client."""
+        if not values.get("client"):
+            session = (
+                boto3.Session(profile_name=values.get("credentials_profile_name"))
+                if values.get("credentials_profile_name", None)
+                else boto3.Session()
+            )
+            values["client"] = session.client(
+                "bedrock-agent-runtime",
+                region_name=values.get("region_name"),
+            )
+        return values
+
+    def rerank(
+        self,
+        documents: Sequence[Union[str, Document, dict]],
+        query: str,
+        top_n: Optional[int] = None,
+        additional_model_request_fields: Optional[Dict[str, Any]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Returns an ordered list of documents based on their relevance to the query.
+
+        Args:
+            query: The query to use for reranking.
+            documents: A sequence of documents to rerank.
+            top_n: The number of top-ranked results to return. Defaults to self.top_n.
+            additional_model_request_fields: A dictionary of additional fields to pass to the model.
+
+        Returns:
+            List[Dict[str, Any]]: A list of ranked documents with relevance scores.
+        """
+        if len(documents) == 0:
+            return []
+
+        # Serialize documents for the Bedrock API
+        serialized_documents = [
+            {"textDocument": {"text": doc.page_content}, "type": "TEXT"}
+            if isinstance(doc, Document)
+            else {"textDocument": {"text": doc}, "type": "TEXT"}
+            if isinstance(doc, str)
+            else {"jsonDocument": doc, "type": "JSON"}
+            for doc in documents
+        ]
+
+        request_body = {
+            "queries": [{"textQuery": {"text": query}, "type": "TEXT"}],
+            "rerankingConfiguration": {
+                "bedrockRerankingConfiguration": {
+                    "modelConfiguration": {
+                        "modelArn": self.model_arn,
+                        "additionalModelRequestFields": additional_model_request_fields
+                        or {},
+                    },
+                    "numberOfResults": top_n or self.top_n,
+                },
+                "type": "BEDROCK_RERANKING_MODEL",
+            },
+            "sources": [
+                {"inlineDocumentSource": doc, "type": "INLINE"}
+                for doc in serialized_documents
+            ],
+        }
+
+        response = self.client.rerank(**request_body)
+        response_body = response.get("results", [])
+
+        results = [
+            {"index": result["index"], "relevance_score": result["relevanceScore"]}
+            for result in response_body
+        ]
+
+        return results
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """
+        Compress documents using Bedrock's rerank API.
+
+        Args:
+            documents: A sequence of documents to compress.
+            query: The query to use for compressing the documents.
+            callbacks: Callbacks to run during the compression process.
+
+        Returns:
+            A sequence of compressed documents.
+        """
+        compressed = []
+        for res in self.rerank(documents, query):
+            doc = documents[res["index"]]
+            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
+            doc_copy.metadata["relevance_score"] = res["relevance_score"]
+            compressed.append(doc_copy)
+        return compressed
diff --git a/libs/aws/langchain_aws/embeddings/bedrock.py b/libs/aws/langchain_aws/embeddings/bedrock.py
index 34e57f5b..33bbd174 100644
--- a/libs/aws/langchain_aws/embeddings/bedrock.py
+++ b/libs/aws/langchain_aws/embeddings/bedrock.py
@@ -78,6 +78,11 @@ class BedrockEmbeddings(BaseModel, Embeddings):
         protected_namespaces=(),
     )
 
+    @property
+    def provider(self) -> str:
+        """Provider of the model."""
+        return self.model_id.split(".")[0]
+
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -121,20 +126,38 @@ def validate_environment(self) -> Self:
         return self
 
     def _embedding_func(self, text: str) -> List[float]:
-        """Call out to Bedrock embedding endpoint."""
+        """Call out to Bedrock embedding endpoint with a single text."""
         # replace newlines, which can negatively affect performance.
         text = text.replace(os.linesep, " ")
 
-        # format input body for provider
-        provider = self.model_id.split(".")[0]
-        input_body: Dict[str, Any] = {}
-        if provider == "cohere":
-            input_body["input_type"] = "search_document"
-            input_body["texts"] = [text]
+        if self.provider == "cohere":
+            response_body = self._invoke_model(
+                input_body={
+                    "input_type": "search_document",
+                    "texts": [text],
+                }
+            )
+            return response_body.get("embeddings")[0]
         else:
             # includes common provider == "amazon"
-            input_body["inputText"] = text
+            response_body = self._invoke_model(
+                input_body={"inputText": text},
+            )
+            return response_body.get("embedding")
+
+    def _cohere_multi_embedding(self, texts: List[str]) -> List[List[float]]:
+        """Call out to Cohere Bedrock embedding endpoint with multiple inputs."""
+        # replace newlines, which can negatively affect performance.
+        texts = [text.replace(os.linesep, " ") for text in texts]
+
+        return self._invoke_model(
+            input_body={
+                "input_type": "search_document",
+                "texts": texts,
+            }
+        ).get("embeddings")
+
+    def _invoke_model(self, input_body: Dict[str, Any] = {}) -> Dict[str, Any]:
         if self.model_kwargs:
             input_body = {**input_body, **self.model_kwargs}
 
@@ -149,11 +172,7 @@ def _embedding_func(self, text: str) -> List[float]:
             )
 
             response_body = json.loads(response.get("body").read())
-            if provider == "cohere":
-                return response_body.get("embeddings")[0]
-            else:
-                return response_body.get("embedding")
-
+            return response_body
         except Exception as e:
             logging.error(f"Error raised by inference endpoint: {e}")
             raise e
@@ -173,6 +192,22 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         Returns:
             List of embeddings, one for each text.
         """
+
+        # If we are able to make use of Cohere's batch embeddings, use that
+        if self.provider == "cohere":
+            return self._embed_cohere_documents(texts)
+        else:
+            return self._iteratively_embed_documents(texts)
+
+    def _embed_cohere_documents(self, texts: List[str]) -> List[List[float]]:
+        response = self._cohere_multi_embedding(texts)
+
+        if self.normalize:
+            response = [self._normalize_vector(embedding) for embedding in response]
+
+        return response
+
+    def _iteratively_embed_documents(self, texts: List[str]) -> List[List[float]]:
         results = []
         for text in texts:
             response = self._embedding_func(text)
diff --git a/libs/aws/langchain_aws/graphs/__init__.py b/libs/aws/langchain_aws/graphs/__init__.py
index f8d36c9e..a4e83e20 100644
--- a/libs/aws/langchain_aws/graphs/__init__.py
+++ b/libs/aws/langchain_aws/graphs/__init__.py
@@ -7,7 +7,7 @@
 
 __all__ = [
     "BaseNeptuneGraph",
-    "NeptuneAnalyticsGraph", 
+    "NeptuneAnalyticsGraph",
     "NeptuneGraph",
-    "NeptuneRdfGraph"
+    "NeptuneRdfGraph",
 ]
diff --git a/libs/aws/langchain_aws/graphs/neptune_graph.py b/libs/aws/langchain_aws/graphs/neptune_graph.py
index 16623879..df491cbd 100644
--- a/libs/aws/langchain_aws/graphs/neptune_graph.py
+++ b/libs/aws/langchain_aws/graphs/neptune_graph.py
@@ -7,9 +7,7 @@ def _format_triples(triples: List[dict]) -> List[str]:
     triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"
     triple_schema = []
     for t in triples:
-        triple = triple_template.format(
-            a=t["~from"], e=t["~type"], b=t["~to"]
-        )
+        triple = triple_template.format(a=t["~from"], e=t["~type"], b=t["~to"])
         triple_schema.append(triple)
 
     return triple_schema
@@ -21,7 +19,9 @@ def _format_node_properties(n_labels: dict) -> List:
     for label, props_item in n_labels.items():
         props = props_item["properties"]
         np = {
-            "properties": [{"property": k, "type": v["datatypes"][0]} for k, v in props.items()],
+            "properties": [
+                {"property": k, "type": v["datatypes"][0]} for k, v in props.items()
{"property": k, "type": v["datatypes"][0]} for k, v in props.items() + ], "labels": label, } node_properties.append(np) @@ -36,7 +36,9 @@ def _format_edge_properties(e_labels: dict) -> List: props = props_item["properties"] np = { "type": label, - "properties": [{"property": k, "type": v["datatypes"][0]} for k, v in props.items()], + "properties": [ + {"property": k, "type": v["datatypes"][0]} for k, v in props.items() + ], } edge_properties.append(np) @@ -210,7 +212,7 @@ def __init__( graph_identifier: str, client: Any = None, credentials_profile_name: Optional[str] = None, - region_name: Optional[str] = None + region_name: Optional[str] = None, ) -> None: """Create a new Neptune Analytics graph wrapper instance.""" @@ -331,6 +333,7 @@ def _refresh_schema(self) -> None: {triple_schema} """ + class NeptuneGraph(BaseNeptuneGraph): """Neptune wrapper for graph operations. diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index d336c67a..112a5c67 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -33,11 +33,14 @@ from langchain_aws.function_calling import _tools_in_params from langchain_aws.utils import ( + anthropic_tokens_supported, enforce_stop_tokens, get_num_tokens_anthropic, get_token_ids_anthropic, ) +logger = logging.getLogger(__name__) + AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace" GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAction" HUMAN_PROMPT = "\n\nHuman:" @@ -825,6 +828,8 @@ def _prepare_input_and_invoke( request_options["trace"] = "ENABLED" try: + logger.debug(f"Request body sent to bedrock: {request_options}") + logger.info("Using Bedrock Invoke API to generate response") response = self.client.invoke_model(**request_options) ( @@ -834,7 +839,7 @@ def _prepare_input_and_invoke( usage_info, stop_reason, ) = LLMInputOutputAdapter.prepare_output(provider, response).values() - + logger.debug(f"Response received from Bedrock: {response}") except Exception as e: logging.error(f"Error raised by bedrock service: {e}") if run_manager is not None: @@ -1298,13 +1303,21 @@ async def _acall( return "".join([chunk.text for chunk in chunks]) def get_num_tokens(self, text: str) -> int: - if self._model_is_anthropic: - return get_num_tokens_anthropic(text) - else: - return super().get_num_tokens(text) + if self._model_is_anthropic and not self.custom_get_token_ids: + if anthropic_tokens_supported(): + return get_num_tokens_anthropic(text) + return super().get_num_tokens(text) def get_token_ids(self, text: str) -> List[int]: - if self._model_is_anthropic: - return get_token_ids_anthropic(text) - else: - return super().get_token_ids(text) + if self._model_is_anthropic and not self.custom_get_token_ids: + if anthropic_tokens_supported(): + return get_token_ids_anthropic(text) + else: + warnings.warn( + f"Falling back to default token method due to missing or incompatible `anthropic` installation " + f"(needs <=0.38.0).\n\nFor `anthropic>0.38.0`, it is recommended to provide the model " + f"class with a custom_get_token_ids method implementing a more accurate tokenizer for Anthropic. " + f"For get_num_tokens, as another alternative, you can implement your own token counter method " + f"using the ChatAnthropic or AnthropicLLM classes." 
+                )
+        return super().get_token_ids(text)
diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py
index 55d32837..d0399d33 100644
--- a/libs/aws/langchain_aws/retrievers/bedrock.py
+++ b/libs/aws/langchain_aws/retrievers/bedrock.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import boto3
@@ -65,11 +66,10 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever):
             specified. If not specified, the default credential profile or, if on an
             EC2 instance, credentials from IMDS will be used.
         client: boto3 client for bedrock agent runtime.
-        retrieval_config: Configuration for retrieval.
-
+        retrieval_config: Optional configuration for retrieval, specified as a
+            Python object (RetrievalConfig) or as a dictionary.
     Example:
         .. code-block:: python
-
             from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
 
             retriever = AmazonKnowledgeBasesRetriever(
@@ -87,7 +87,7 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever):
     credentials_profile_name: Optional[str] = None
     endpoint_url: Optional[str] = None
     client: Any
-    retrieval_config: RetrievalConfig
+    retrieval_config: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None
     min_score_confidence: Annotated[
         Optional[float], Field(ge=0.0, le=1.0, default=None)
     ]
@@ -159,17 +159,53 @@ def _get_relevant_documents(
         *,
         run_manager: CallbackManagerForRetrieverRun,
     ) -> List[Document]:
-        response = self.client.retrieve(
-            retrievalQuery={"text": query.strip()},
-            knowledgeBaseId=self.knowledge_base_id,
-            retrievalConfiguration=self.retrieval_config.model_dump(
-                exclude_none=True, by_alias=True
-            ),
-        )
+        """
+        Get relevant documents from a KnowledgeBase
+
+        :param query: the user's query
+        :param run_manager: The callback handler to use
+        :return: List of relevant documents
+        """
+        retrieve_request: Dict[str, Any] = self._get_retrieve_request(query)
+        response = self.client.retrieve(**retrieve_request)
         results = response["retrievalResults"]
+        documents: List[
+            Document
+        ] = AmazonKnowledgeBasesRetriever._retrieval_results_to_documents(results)
+
+        return self._filter_by_score_confidence(docs=documents)
+
+    def _get_retrieve_request(self, query: str) -> Dict[str, Any]:
+        """
+        Build a Retrieve request
+
+        :param query: the user's query
+        :return: the Retrieve API request parameters
+        """
+        request: Dict[str, Any] = {
+            "retrievalQuery": {"text": query.strip()},
+            "knowledgeBaseId": self.knowledge_base_id,
+        }
+        if self.retrieval_config:
+            request["retrievalConfiguration"] = self.retrieval_config.model_dump(
+                exclude_none=True, by_alias=True
+            )
+        return request
+
+    @staticmethod
+    def _retrieval_results_to_documents(
+        results: List[Dict[str, Any]],
+    ) -> List[Document]:
+        """
+        Convert the Retrieve API results to LangChain Documents
+
+        :param results: Retrieve API results list
+        :return: List of LangChain Documents
+        """
         documents = []
         for result in results:
-            content = result["content"]["text"]
+            content = AmazonKnowledgeBasesRetriever._get_content_from_result(result)
+            result["type"] = result.get("content", {}).get("type", "TEXT")
             result.pop("content")
             if "score" not in result:
                 result["score"] = 0
@@ -181,5 +217,33 @@ def _get_relevant_documents(
                     metadata=result,
                 )
             )
+        return documents
 
-        return self._filter_by_score_confidence(docs=documents)
+    @staticmethod
+    def _get_content_from_result(result: Dict[str, Any]) -> Optional[str]:
+        """
+        Convert the content from one Retrieve API result to a string
+
+        :param result: Retrieve API search result
+        :return: string representation of the content attribute
+        """
+        if not result:
+            raise ValueError("Invalid search result")
+        content: dict = result.get("content")
+        if not content:
+            raise ValueError(
+                "Invalid search result, content is missing from the result"
+            )
+        if not content.get("type"):
+            return content.get("text")
+        if content["type"] == "TEXT":
+            return content.get("text")
+        elif content["type"] == "IMAGE":
+            return content.get("byteContent")
+        elif content["type"] == "ROW":
+            row: Optional[List[dict]] = content.get("row", [])
+            return json.dumps(row if row else [])
+        else:
+            # future proofing this class to prevent code breaks if new types
+            # are introduced
+            return None
diff --git a/libs/aws/langchain_aws/utils.py b/libs/aws/langchain_aws/utils.py
index ff9188a2..9426bb98 100644
--- a/libs/aws/langchain_aws/utils.py
+++ b/libs/aws/langchain_aws/utils.py
@@ -1,21 +1,38 @@
 import re
 from typing import Any, List
 
+from packaging import version
+
 
 def enforce_stop_tokens(text: str, stop: List[str]) -> str:
     """Cut off the text as soon as any stop words occur."""
     return re.split("|".join(stop), text, maxsplit=1)[0]
 
 
-def _get_anthropic_client() -> Any:
+def anthropic_tokens_supported() -> bool:
+    """Check if we have all requirements for Anthropic count_tokens() and get_tokenizer()."""
     try:
         import anthropic
     except ImportError:
-        raise ImportError(
-            "Could not import anthropic python package. "
-            "This is needed in order to accurately tokenize the text "
-            "for anthropic models. Please install it with `pip install anthropic`."
-        )
+        return False
+
+    if version.parse(anthropic.__version__) > version.parse("0.38.0"):
+        return False
+
+    try:
+        import httpx
+
+        if version.parse(httpx.__version__) > version.parse("0.27.2"):
+            raise ImportError()
+    except ImportError:
+        raise ImportError("httpx<=0.27.2 is required.")
+
+    return True
+
+
+def _get_anthropic_client() -> Any:
+    import anthropic
+
     return anthropic.Anthropic()
diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml
index 4d065592..40304fd4 100644
--- a/libs/aws/pyproject.toml
+++ b/libs/aws/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-aws"
-version = "0.2.11"
+version = "0.2.12"
 description = "An integration package connecting AWS and LangChain"
 authors = []
 readme = "README.md"
diff --git a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
index 624dc4d3..b95002b5 100644
--- a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
+++ b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
@@ -18,6 +18,13 @@ def bedrock_embeddings_v2() -> BedrockEmbeddings:
     )
 
 
+@pytest.fixture
+def cohere_embeddings_v3() -> BedrockEmbeddings:
+    return BedrockEmbeddings(
+        model_id="cohere.embed-english-v3",
+    )
+
+
 @pytest.mark.scheduled
 def test_bedrock_embedding_documents(bedrock_embeddings) -> None:
     documents = ["foo bar"]
@@ -101,3 +108,21 @@ def test_embed_query_with_size(bedrock_embeddings_v2) -> None:
     output = bedrock_embeddings_v2.embed_query(prompt_data)
     assert len(response[0]) == 256
     assert len(output) == 256
+
+
+@pytest.mark.scheduled
+def test_bedrock_cohere_embedding_documents(cohere_embeddings_v3) -> None:
+    documents = ["foo bar"]
+    output = cohere_embeddings_v3.embed_documents(documents)
+    assert len(output) == 1
+    assert len(output[0]) == 1024
+
+
+@pytest.mark.scheduled
+def test_bedrock_cohere_embedding_documents_multiple(cohere_embeddings_v3) -> None:
+    documents = ["foo bar", "bar foo", "foo"]
+    output = cohere_embeddings_v3.embed_documents(documents)
+    assert len(output) == 3
+    assert len(output[0]) == 1024
+    assert len(output[1]) == 1024
+    assert len(output[2]) == 1024
diff --git a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py
index 54eb8bf3..37d6040c 100644
--- a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py
+++ b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py
@@ -48,19 +48,20 @@ def test_get_relevant_documents(retriever, mock_client) -> None:  # type: ignore
     expected_documents = [
         Document(
             page_content="This is the first result.",
-            metadata={"location": "location1", "score": 0.9},
+            metadata={"location": "location1", "score": 0.9, "type": "TEXT"},
         ),
         Document(
             page_content="This is the second result.",
-            metadata={"location": "location2", "score": 0.8},
+            metadata={"location": "location2", "score": 0.8, "type": "TEXT"},
         ),
         Document(
             page_content="This is the third result.",
-            metadata={"location": "location3", "score": 0.0},
+            metadata={"location": "location3", "score": 0.0, "type": "TEXT"},
         ),
         Document(
             page_content="This is the fourth result.",
             metadata={
+                "type": "TEXT",
                 "score": 0.0,
                 "source_metadata": {
                     "key1": "value1",
@@ -108,11 +109,11 @@ def test_get_relevant_documents_with_score(retriever, mock_client) -> None:  # type: ignore
     expected_documents = [
         Document(
             page_content="This is the first result.",
-            metadata={"location": "location1", "score": 0.9},
+            metadata={"location": "location1", "score": 0.9, "type": "TEXT"},
         ),
         Document(
             page_content="This is the second result.",
-            metadata={"location": "location2", "score": 0.8},
+            metadata={"location": "location2", "score": 0.8, "type": "TEXT"},
         ),
     ]
diff --git a/libs/aws/tests/unit_tests/document_compressors/__init__.py b/libs/aws/tests/unit_tests/document_compressors/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/aws/tests/unit_tests/document_compressors/test_rerank.py b/libs/aws/tests/unit_tests/document_compressors/test_rerank.py
new file mode 100644
index 00000000..8acdda12
--- /dev/null
+++ b/libs/aws/tests/unit_tests/document_compressors/test_rerank.py
@@ -0,0 +1,55 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from langchain_core.documents import Document
+
+from langchain_aws.document_compressors.rerank import BedrockRerank
+
+
+@pytest.fixture
+def reranker() -> BedrockRerank:
+    reranker = BedrockRerank(
+        model_arn="arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
+        region_name="us-east-1",
+    )
+    reranker.client = MagicMock()
+    return reranker
+
+@patch("boto3.Session")
+def test_initialize_client(mock_boto_session: MagicMock, reranker: BedrockRerank) -> None:
+    session_instance = MagicMock()
+    mock_boto_session.return_value = session_instance
+    session_instance.client.return_value = MagicMock()
+    assert reranker.client is not None
+
+@patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank")
+def test_rerank(mock_rerank: MagicMock, reranker: BedrockRerank) -> None:
+    mock_rerank.return_value = [
+        {"index": 0, "relevance_score": 0.9},
+        {"index": 1, "relevance_score": 0.8},
+    ]
+
+    documents = [Document(page_content="Doc 1"), Document(page_content="Doc 2")]
+    query = "Example Query"
+    results = reranker.rerank(documents, query)
+
+    assert len(results) == 2
+    assert results[0]["index"] == 0
+    assert results[0]["relevance_score"] == 0.9
+    assert results[1]["index"] == 1
+    assert results[1]["relevance_score"] == 0.8
+
+@patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank")
+def test_compress_documents(mock_rerank: MagicMock, reranker: BedrockRerank) -> None:
+    mock_rerank.return_value = [
+        {"index": 0, "relevance_score": 0.95},
+        {"index": 1, "relevance_score": 0.85},
+    ]
+
+    documents = [Document(page_content="Content 1"), Document(page_content="Content 2")]
+    query = "Relevant query"
+    compressed_docs = reranker.compress_documents(documents, query)
+
+    assert len(compressed_docs) == 2
+    assert compressed_docs[0].metadata["relevance_score"] == 0.95
+    assert compressed_docs[1].metadata["relevance_score"] == 0.85
diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py
index 98357428..bc37c3e8 100644
--- a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py
+++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py
@@ -1,5 +1,5 @@
 # type: ignore
-
+from typing import Any, List
 from unittest.mock import MagicMock
 
 import pytest
@@ -28,6 +28,16 @@ def mock_retriever_config():
     )
 
 
+@pytest.fixture
+def mock_retriever_config_dict():
+    return {
+        "vectorSearchConfiguration": {
+            "numberOfResults": 5,
+            "filter": {"in": {"key": "key", "value": ["value1", "value2"]}},
+        }
+    }
+
+
 @pytest.fixture
 def amazon_retriever(mock_client, mock_retriever_config):
     return AmazonKnowledgeBasesRetriever(
@@ -37,6 +47,23 @@ def amazon_retriever(mock_client, mock_retriever_config):
     )
 
 
+@pytest.fixture
+def amazon_retriever_no_retrieval_config(mock_client, mock_retriever_config):
+    return AmazonKnowledgeBasesRetriever(
+        knowledge_base_id="test_kb_id",
+        client=mock_client,
+    )
+
+
+@pytest.fixture
+def amazon_retriever_retrieval_config_dict(mock_client, mock_retriever_config_dict):
+    return AmazonKnowledgeBasesRetriever(
+        knowledge_base_id="test_kb_id",
+        retrieval_config=mock_retriever_config_dict,
+        client=mock_client,
+    )
+
+
 def test_retriever_invoke(amazon_retriever, mock_client):
     query = "test query"
     mock_client.retrieve.return_value = {
@@ -67,15 +94,20 @@ def test_retriever_invoke(amazon_retriever, mock_client):
     assert len(documents) == 3
     assert isinstance(documents[0], Document)
     assert documents[0].page_content == "result1"
-    assert documents[0].metadata == {"score": 0, "source_metadata": {"key": "value1"}}
+    assert documents[0].metadata == {
+        "score": 0,
+        "source_metadata": {"key": "value1"},
+        "type": "TEXT",
+    }
     assert documents[1].page_content == "result2"
     assert documents[1].metadata == {
         "score": 1,
         "source_metadata": {"key": "value2"},
         "location": "testLocation",
+        "type": "TEXT",
     }
     assert documents[2].page_content == "result3"
-    assert documents[2].metadata == {"score": 0}
+    assert documents[2].metadata == {"score": 0, "type": "TEXT"}
 
 
 def test_retriever_invoke_with_score(amazon_retriever, mock_client):
@@ -88,6 +120,7 @@ def test_retriever_invoke_with_score(amazon_retriever, mock_client):
             "metadata": {"key": "value2"},
             "score": 1,
             "location": "testLocation",
+            "type": "TEXT",
         },
         {"content": {"text": "result3"}},
     ]
@@ -103,4 +136,435 @@ def test_retriever_invoke_with_score(amazon_retriever, mock_client):
         "score": 1,
         "source_metadata": {"key": "value2"},
         "location": "testLocation",
+        "type": "TEXT",
+    }
+
+
+def test_retriever_retrieval_config_dict_invoke(
+    amazon_retriever_retrieval_config_dict, mock_client
+):
+    documents = set_return_value_and_query(
+        mock_client, amazon_retriever_retrieval_config_dict
+    )
+    validate_query_response_no_cutoff(documents)
+    mock_client.retrieve.assert_called_once_with(
+        retrievalQuery={"text": "test query"},
+        knowledgeBaseId="test_kb_id",
+        retrievalConfiguration={
+            "vectorSearchConfiguration": {
+                "numberOfResults": 5,
+                # Expecting to be called with correct "in" operator instead of "in_"
+                "filter": {"in": {"key": "key", "value": ["value1", "value2"]}},
+            }
+        },
+    )
+
+
+def test_retriever_retrieval_config_dict_invoke_with_score(
+    amazon_retriever_retrieval_config_dict, mock_client
+):
+    amazon_retriever_retrieval_config_dict.min_score_confidence = 0.6
+    documents = set_return_value_and_query(
+        mock_client, amazon_retriever_retrieval_config_dict
+    )
+    validate_query_response_with_cutoff(documents)
+
+
+def test_retriever_no_retrieval_config_invoke(
+    amazon_retriever_no_retrieval_config, mock_client
+):
+    documents = set_return_value_and_query(
+        mock_client, amazon_retriever_no_retrieval_config
+    )
+    validate_query_response_no_cutoff(documents)
+    mock_client.retrieve.assert_called_once_with(
+        retrievalQuery={"text": "test query"}, knowledgeBaseId="test_kb_id"
+    )
+
+
+def test_retriever_no_retrieval_config_invoke_with_score(
+    amazon_retriever_no_retrieval_config, mock_client
+):
+    amazon_retriever_no_retrieval_config.min_score_confidence = 0.6
+    documents = set_return_value_and_query(
+        mock_client, amazon_retriever_no_retrieval_config
+    )
+    validate_query_response_with_cutoff(documents)
+
+
+@pytest.mark.parametrize(
+    "search_results,expected_documents",
+    [
+        (
+            [
+                {
+                    "content": {"text": "result"},
+                    "metadata": {"key": "value1"},
+                    "score": 1,
+                    "location": "testLocation",
+                },
+                {
+                    "content": {"text": "result"},
+                    "metadata": {"key": "value1"},
+                    "score": 0.5,
+                    "location": "testLocation",
+                },
+            ],
+            [
+                Document(
+                    page_content="result",
+                    metadata={
+                        "score": 1,
+                        "location": "testLocation",
+                        "source_metadata": {"key": "value1"},
+                        "type": "TEXT",
+                    },
+                ),
+                Document(
+                    page_content="result",
+                    metadata={
+                        "score": 0.5,
+                        "location": "testLocation",
+                        "source_metadata": {"key": "value1"},
+                        "type": "TEXT",
+                    },
+                ),
+            ],
+        ),
+        # text type
+        (
+            [
+                {
+                    "content": {"text": "result", "type": "TEXT"},
+                    "metadata": {"key": "value1"},
+                    "score": 1,
+                    "location": "testLocation",
+                },
+                {
+                    "content": {"text": "result", "type": "TEXT"},
+                    "metadata": {"key": "value1"},
+                    "score": 0.5,
+                    "location": "testLocation",
+                },
+            ],
+            [
+                Document(
+                    page_content="result",
+                    metadata={
+                        "score": 1,
+                        "location": "testLocation",
+                        "source_metadata": {"key": "value1"},
+                        "type": "TEXT",
+                    },
+                ),
+                Document(
+                    page_content="result",
+                    metadata={
+                        "score": 0.5,
+                        "location": "testLocation",
+                        "source_metadata": {"key": "value1"},
+                        "type": "TEXT",
+                    },
+                ),
+            ],
+        ),
+        # image type
+        (
+            [
+                {
+                    "content": {"byteContent": "bytecontent", "type": "IMAGE"},
+                    "metadata": {"key": "value1"},
+                    "score": 1,
+                    "location": "testLocation",
+                },
+                {
+                    "content": {"byteContent": "bytecontent", "type": "IMAGE"},
+                    "metadata": {"key": "value1"},
+                    "score": 0.5,
+                    "location": "testLocation",
+                },
+            ],
+            [
+                Document(
+                    page_content="bytecontent",
+                    metadata={
+                        "score": 1,
+                        "location": "testLocation",
+                        "source_metadata": {"key": "value1"},
+                        "type": "IMAGE",
+                    },
+                ),
+                Document(
+                    page_content="bytecontent",
+                    metadata={
+                        "score": 0.5,
+                        "location": "testLocation",
+                        "source_metadata": {"key": "value1"},
+                        "type": "IMAGE",
+                    },
+                ),
+            ],
+        ),
+        # row type
+        (
+            [
+                {
+                    "content": {
+                        "row": [
+                            {"columnName": "someName1", "columnValue": "someValue1"},
+                            {"columnName": "someName2", "columnValue": "someValue2"},
"someValue2"}, + ], + "type": "ROW", + }, + "score": 1, + "metadata": {"key": "value1"}, + "location": "testLocation", + }, + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "ROW", + }, + "score": 0.5, + "metadata": {"key": "value1"}, + "location": "testLocation", + }, + ], + [ + Document( + page_content='[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", "columnValue": "someValue2"}]', + metadata={ + "score": 1, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "ROW", + }, + ), + Document( + page_content='[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", "columnValue": "someValue2"}]', + metadata={ + "score": 0.5, + "location": "testLocation", + "source_metadata": {"key": "value1"}, + "type": "ROW", + }, + ), + ], + ), + ], +) +def test_retriever_with_multi_modal_types_then_get_valid_documents( + mock_client, amazon_retriever, search_results, expected_documents +): + query = "test query" + mock_client.retrieve.return_value = {"retrievalResults": search_results} + documents = amazon_retriever.invoke(query, run_manager=None) + assert documents == expected_documents + + +@pytest.mark.parametrize( + "search_result_input,expected_output", + [ + # VALID INPUTS + # no type + ({"content": {"text": "result"}}, "result"), + # text type + ({"content": {"text": "result", "type": "TEXT"}}, "result"), + # image type + ({"content": {"byteContent": "bytecontent", "type": "IMAGE"}}, "bytecontent"), + # row type + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "ROW", + } + }, + '[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", "columnValue": "someValue2"}]', + ), + # VALID INPUTS w/ metadata + # no type + ({"content": {"text": "result"}, "metadata": {"key": "value1"}}, "result"), + # text type + ( + { + "content": {"text": "result", "type": "TEXT"}, + "metadata": {"key": "value1"}, + }, + "result", + ), + # image type + ( + { + "content": {"byteContent": "bytecontent", "type": "IMAGE"}, + "metadata": {"key": "value1"}, + }, + "bytecontent", + ), + # row type + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "metadata": {"key": "value1"}, + "type": "ROW", + } + }, + '[{"columnName": "someName1", "columnValue": "someValue1"}, ' + '{"columnName": "someName2", "columnValue": "someValue2"}]', + ), + # invalid type + ({"content": {"invalid": "invalid", "type": "INVALID"}}, None), + # EMPTY VALUES + # no type + ({"content": {"text": ""}}, ""), + # text type + ({"content": {"text": "", "type": "TEXT"}}, ""), + # image type + ({"content": {"byteContent": "", "type": "IMAGE"}}, ""), + # row type + ({"content": {"row": [], "type": "ROW"}}, "[]"), + # NONE VALUES + ({"content": {"text": None}}, None), + # text type + ({"content": {"text": None, "type": "TEXT"}}, None), + # image type + ({"content": {"byteContent": None, "type": "IMAGE"}}, None), + # row type + ({"content": {"row": None, "type": "ROW"}}, "[]"), + # WRONG CONTENT + # text + ({"content": {"text": "result", "type": "IMAGE"}}, None), + ({"content": {"text": "result", "type": "ROW"}}, "[]"), + # byteContent + ({"content": {"byteContent": "result"}}, None), + 
({"content": {"byteContent": "result", "type": "TEXT"}}, None), + ({"content": {"byteContent": "result", "type": "ROW"}}, "[]"), + # row + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ] + } + }, + None, + ), + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "TEXT", + } + }, + None, + ), + ( + { + "content": { + "row": [ + {"columnName": "someName1", "columnValue": "someValue1"}, + {"columnName": "someName2", "columnValue": "someValue2"}, + ], + "type": "IMAGE", + } + }, + None, + ), + ], +) +def test_when_get_content_from_result_then_get_expected_content( + search_result_input, expected_output +): + assert ( + AmazonKnowledgeBasesRetriever._get_content_from_result(search_result_input) + == expected_output + ) + + +@pytest.mark.parametrize( + "search_result_input", + [ + # empty content + ({"content": {}}), + # None content + ({"content": None}), + # empty dict + ({}), + # None search result + None, + ], +) +def test_when_get_content_from_result_with_invalid_content_then_raise_error( + search_result_input, +): + with pytest.raises(ValueError): + AmazonKnowledgeBasesRetriever._get_content_from_result(search_result_input) + + +def set_return_value_and_query( + client: Any, retriever: AmazonKnowledgeBasesRetriever +) -> List[Document]: + query = "test query" + client.retrieve.return_value = { + "retrievalResults": [ + {"content": {"text": "result1"}, "metadata": {"key": "value1"}}, + { + "content": {"text": "result2"}, + "metadata": {"key": "value2"}, + "score": 1, + "location": "testLocation", + }, + {"content": {"text": "result3"}}, + ] + } + return retriever.invoke(query, run_manager=None) + + +def validate_query_response_no_cutoff(documents: List[Document]): + assert len(documents) == 3 + assert isinstance(documents[0], Document) + assert documents[0].page_content == "result1" + assert documents[0].metadata == { + "score": 0, + "source_metadata": {"key": "value1"}, + "type": "TEXT", + } + assert documents[1].page_content == "result2" + assert documents[1].metadata == { + "score": 1, + "source_metadata": {"key": "value2"}, + "location": "testLocation", + "type": "TEXT", + } + assert documents[2].page_content == "result3" + assert documents[2].metadata == {"score": 0, "type": "TEXT"} + + +def validate_query_response_with_cutoff(documents: List[Document]): + assert len(documents) == 1 + assert isinstance(documents[0], Document) + assert documents[0].page_content == "result2" + assert documents[0].metadata == { + "score": 1, + "source_metadata": {"key": "value2"}, + "location": "testLocation", + "type": "TEXT", } diff --git a/samples/document_compressors/rerank.ipynb b/samples/document_compressors/rerank.ipynb new file mode 100644 index 00000000..911ae12a --- /dev/null +++ b/samples/document_compressors/rerank.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rerank Document Compressor\n", + "\n", + "In this notebook we will go through how you can use a rerank document compressor with Bedrock.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "\n", + "session = boto3.Session()\n", + "client = session.client('bedrock')\n", + "foundation_model = client.get_foundation_model(modelIdentifier=\"amazon.rerank-v1:0\")\n", + "\n", 
+    "model_arn = foundation_model[\"modelDetails\"][\"modelArn\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The code below processes a list of documents to determine their relevance to a given query using AWS Bedrock's reranking capabilities. It initializes a BedrockRerank instance, providing a list of documents and a query. The `compress_documents` method then evaluates and ranks the documents based on relevance, ensuring that the most relevant ones are prioritized."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Content: AWS Bedrock enables access to AI models.\n",
+      "Score: 0.07081620395183563\n",
+      "Content: Artificial intelligence is transforming the world.\n",
+      "Score: 2.8350802949717036e-06\n",
+      "Content: LangChain is a powerful library for LLMs.\n",
+      "Score: 1.5903378880466335e-06\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain_core.documents import Document\n",
+    "from langchain_aws import BedrockRerank\n",
+    "\n",
+    "# Initialize the class\n",
+    "reranker = BedrockRerank(model_arn=model_arn)\n",
+    "\n",
+    "# List of documents to rerank\n",
+    "documents = [\n",
+    "    Document(page_content=\"LangChain is a powerful library for LLMs.\"),\n",
+    "    Document(page_content=\"AWS Bedrock enables access to AI models.\"),\n",
+    "    Document(page_content=\"Artificial intelligence is transforming the world.\"),\n",
+    "]\n",
+    "\n",
+    "# Query for reranking\n",
+    "query = \"What is AWS Bedrock?\"\n",
+    "\n",
+    "# Call the compress_documents method\n",
+    "results = reranker.compress_documents(documents, query)\n",
+    "\n",
+    "# Display the most relevant documents\n",
+    "for doc in results:\n",
+    "    print(f\"Content: {doc.page_content}\")\n",
+    "    print(f\"Score: {doc.metadata['relevance_score']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now let's enhance our base retriever by wrapping it with a `ContextualCompressionRetriever`. Here, we integrate `BedrockRerank`, which leverages AWS Bedrock's reranking capabilities to refine the retrieved results.\n",
+    "\n",
+    "When a query is executed, the retriever first retrieves relevant documents using FAISS and then reranks them based on relevance, providing more accurate and meaningful responses."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Content: AWS Bedrock provides cloud-based AI models.\n",
+      "Score: 0.07585818320512772\n",
+      "Content: Machine learning can be used for predictions.\n",
+      "Score: 2.8573158488143235e-06\n",
+      "Content: LangChain integrates LLM models.\n",
+      "Score: 1.640820528336917e-06\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain_aws import BedrockEmbeddings\n",
+    "from langchain.retrievers.contextual_compression import ContextualCompressionRetriever\n",
+    "from langchain.vectorstores import FAISS\n",
+    "from langchain_core.documents import Document\n",
+    "from langchain_aws import BedrockRerank\n",
+    "\n",
+    "# Create a vector store using FAISS with Bedrock embeddings\n",
+    "documents = [\n",
+    "    Document(page_content=\"LangChain integrates LLM models.\"),\n",
+    "    Document(page_content=\"AWS Bedrock provides cloud-based AI models.\"),\n",
+    "    Document(page_content=\"Machine learning can be used for predictions.\"),\n",
+    "]\n",
+    "embeddings = BedrockEmbeddings()\n",
+    "vectorstore = FAISS.from_documents(documents, embeddings)\n",
+    "\n",
+    "# Create the document compressor using BedrockRerank\n",
+    "reranker = BedrockRerank(model_arn=model_arn)\n",
+    "\n",
+    "# Create the retriever with contextual compression\n",
+    "retriever = ContextualCompressionRetriever(\n",
+    "    base_compressor=reranker,\n",
+    "    base_retriever=vectorstore.as_retriever(),\n",
+    ")\n",
+    "\n",
+    "# Execute a query\n",
+    "query = \"How does AWS Bedrock work?\"\n",
+    "retrieved_docs = retriever.invoke(query)\n",
+    "\n",
+    "# Display the most relevant documents\n",
+    "for doc in retrieved_docs:\n",
+    "    print(f\"Content: {doc.page_content}\")\n",
+    "    print(f\"Score: {doc.metadata.get('relevance_score', 'N/A')}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Unlike `compress_documents`, which works with structured Document objects, the rerank method allows passing plain text strings. This simplifies the process of evaluating and ranking text data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Index: 1, Score: 0.07159119844436646\n",
+      "Document: AWS Bedrock provides access to cloud-based models.\n",
+      "Index: 2, Score: 9.666109690442681e-06\n",
+      "Document: Machine learning is revolutionizing the world.\n",
+      "Index: 0, Score: 8.25057043130073e-07\n",
+      "Document: LangChain is used to integrate LLM models.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain_aws import BedrockRerank\n",
+    "\n",
+    "# Initialize BedrockRerank\n",
+    "reranker = BedrockRerank(model_arn=model_arn)\n",
+    "\n",
+    "# Unstructured documents\n",
+    "documents = [\n",
+    "    \"LangChain is used to integrate LLM models.\",\n",
+    "    \"AWS Bedrock provides access to cloud-based models.\",\n",
+    "    \"Machine learning is revolutionizing the world.\",\n",
+    "]\n",
+    "\n",
+    "# Query\n",
+    "query = \"What is the role of AWS Bedrock?\"\n",
+    "\n",
+    "# Rerank the documents\n",
+    "results = reranker.rerank(query=query, documents=documents)\n",
+    "\n",
+    "# Display the results\n",
+    "for res in results:\n",
+    "    print(f\"Index: {res['index']}, Score: {res['relevance_score']}\")\n",
+    "    print(f\"Document: {documents[res['index']]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}