Skip to content

Commit 51baa1b

Browse files
langchain[patch]: fix-cohere-reranker-rerank-method with cohere v5 (langchain-ai#19486)
#### Description Fixed the following error with `rerank` method from `CohereRerank`: ``` ---> [79](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:79) results = self.client.rerank( [80](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:80) query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc [81](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:81) ) [82](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:82) result_dicts = [] [83](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:83) for res in results.results: TypeError: BaseCohere.rerank() takes 1 positional argument but 4 positional arguments (and 2 keyword-only arguments) were given ``` This was easily fixed going from this: ``` def rerank( self, documents: Sequence[Union[str, Document, dict]], query: str, *, model: Optional[str] = None, top_n: Optional[int] = -1, max_chunks_per_doc: Optional[int] = None, ) -> List[Dict[str, Any]]: ... if len(documents) == 0: # to avoid empty api call return [] docs = [ doc.page_content if isinstance(doc, Document) else doc for doc in documents ] model = model or self.model top_n = top_n if (top_n is None or top_n > 0) else self.top_n results = self.client.rerank( query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc ) result_dicts = [] for res in results: result_dicts.append( {"index": res.index, "relevance_score": res.relevance_score} ) return result_dicts ``` to this: ``` def rerank( self, documents: Sequence[Union[str, Document, dict]], query: str, *, model: Optional[str] = None, top_n: Optional[int] = -1, max_chunks_per_doc: Optional[int] = None, ) -> List[Dict[str, Any]]: ... if len(documents) == 0: # to avoid empty api call return [] docs = [ doc.page_content if isinstance(doc, Document) else doc for doc in documents ] model = model or self.model top_n = top_n if (top_n is None or top_n > 0) else self.top_n results = self.client.rerank( query=query, documents=docs, model=model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc <------------- ) result_dicts = [] for res in results.results: <------------- result_dicts.append( {"index": res.index, "relevance_score": res.relevance_score} ) return result_dicts ``` #### Unit & Integration tests I added a unit test to check the behaviour of `rerank`. Also fixed the original integration test which was failing. #### Format & Linting Everything worked properly with `make lint_diff`, `make format_diff` and `make format`. However I noticed an error coming from other part of the library when doing `make lint`: ``` (langchain-py3.9) ➜ langchain git:(master) make format [ "." = "" ] || poetry run ruff format . 1636 files left unchanged [ "." = "" ] || poetry run ruff --select I --fix . (langchain-py3.9) ➜ langchain git:(master) make lint ./scripts/check_pydantic.sh . ./scripts/lint_imports.sh poetry run ruff . [ "." = "" ] || poetry run ruff format . --diff 1636 files already formatted [ "." = "" ] || poetry run ruff --select I . [ "." = "" ] || mkdir -p .mypy_cache && poetry run mypy . --cache-dir .mypy_cache langchain/agents/openai_assistant/base.py:252: error: Argument "file_ids" to "create" of "Assistants" has incompatible type "Optional[Any]"; expected "Union[list[str], NotGiven]" [arg-type] langchain/agents/openai_assistant/base.py:374: error: Argument "file_ids" to "create" of "AsyncAssistants" has incompatible type "Optional[Any]"; expected "Union[list[str], NotGiven]" [arg-type] Found 2 errors in 1 file (checked 1634 source files) make: *** [Makefile:65: lint] Error 1 ``` --------- Co-authored-by: Bagatur <[email protected]> Co-authored-by: Bagatur <[email protected]>
1 parent 332996b commit 51baa1b

File tree

4 files changed

+47
-5
lines changed

4 files changed

+47
-5
lines changed

libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,14 @@ def rerank(
8181
model = model or self.model
8282
top_n = top_n if (top_n is None or top_n > 0) else self.top_n
8383
results = self.client.rerank(
84-
query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc
84+
query=query,
85+
documents=docs,
86+
model=model,
87+
top_n=top_n,
88+
max_chunks_per_doc=max_chunks_per_doc,
8589
)
90+
if hasattr(results, "results"):
91+
results = getattr(results, "results")
8692
result_dicts = []
8793
for res in results:
8894
result_dicts.append(

libs/langchain/poetry.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/langchain/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jinja2 = {version = "^3", optional = true}
3636
tiktoken = {version = ">=0.3.2,<0.6.0", optional = true, python=">=3.9"}
3737
qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
3838
dataclasses-json = ">= 0.5.7, < 0.7"
39-
cohere = {version = "^4", optional = true}
39+
cohere = {version = ">=4,<6", optional = true}
4040
openai = {version = "<2", optional = true}
4141
nlpcloud = {version = "^1", optional = true}
4242
huggingface_hub = {version = "^0", optional = true}

libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cohere_rerank.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22

33
import pytest
4+
from pytest_mock import MockerFixture
45

56
from langchain.retrievers.document_compressors import CohereRerank
7+
from langchain.schema import Document
68

79
os.environ["COHERE_API_KEY"] = "foo"
810

@@ -14,3 +16,37 @@ def test_init() -> None:
1416
CohereRerank(
1517
top_n=5, model="rerank-english_v2.0", cohere_api_key="foo", user_agent="bar"
1618
)
19+
20+
21+
@pytest.mark.requires("cohere")
22+
def test_rerank(mocker: MockerFixture) -> None:
23+
mock_client = mocker.MagicMock()
24+
mock_result = mocker.MagicMock()
25+
mock_result.results = [
26+
mocker.MagicMock(index=0, relevance_score=0.8),
27+
mocker.MagicMock(index=1, relevance_score=0.6),
28+
]
29+
mock_client.rerank.return_value = mock_result
30+
31+
test_documents = [
32+
Document(page_content="This is a test document."),
33+
Document(page_content="Another test document."),
34+
]
35+
test_query = "Test query"
36+
37+
mocker.patch("cohere.Client", return_value=mock_client)
38+
39+
reranker = CohereRerank(cohere_api_key="foo")
40+
results = reranker.rerank(test_documents, test_query)
41+
42+
mock_client.rerank.assert_called_once_with(
43+
query=test_query,
44+
documents=[doc.page_content for doc in test_documents],
45+
model="rerank-english-v2.0",
46+
top_n=3,
47+
max_chunks_per_doc=None,
48+
)
49+
assert results == [
50+
{"index": 0, "relevance_score": 0.8},
51+
{"index": 1, "relevance_score": 0.6},
52+
]

0 commit comments

Comments
 (0)