Skip to content

Commit 01cd58a

Browse files
authored
fix(reranker): support omitting top_n (#7199)
* fix(reranker): support omitting top_n Signed-off-by: Mikhail Khludnev <[email protected]> * fix(reranker): support omitting top_n Signed-off-by: Mikhail Khludnev <[email protected]> * pass 0 explicitly Signed-off-by: Mikhail Khludnev <[email protected]> --------- Signed-off-by: Mikhail Khludnev <[email protected]> Signed-off-by: Mikhail Khludnev <[email protected]>
1 parent 679d43c commit 01cd58a

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

backend/python/rerankers/backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,13 @@ def Rerank(self, request, context):
7575
documents.append(doc)
7676
ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
7777
# Prepare results to return
78+
cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
7879
results = [
7980
backend_pb2.DocumentResult(
8081
index=res.doc_id,
8182
text=res.text,
8283
relevance_score=res.score
83-
) for res in ranked_results.top_k(request.top_n)
84+
) for res in (cropped_results)
8485
]
8586

8687
# Calculate the usage and total tokens

backend/python/rerankers/test.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,35 @@ def test_rerank(self):
7676
)
7777
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
7878
self.assertTrue(response.success)
79-
79+
80+
rerank_response = stub.Rerank(request)
81+
print(rerank_response.results[0])
82+
self.assertIsNotNone(rerank_response.results)
83+
self.assertEqual(len(rerank_response.results), 2)
84+
self.assertEqual(rerank_response.results[0].text, "I really like you")
85+
self.assertEqual(rerank_response.results[1].text, "I hate you")
86+
except Exception as err:
87+
print(err)
88+
self.fail("Reranker service failed")
89+
finally:
90+
self.tearDown()
91+
92+
def test_rerank_omit_top_n(self):
93+
"""
94+
This method tests if the embeddings are generated successfully even top_n is omitted
95+
"""
96+
try:
97+
self.setUp()
98+
with grpc.insecure_channel("localhost:50051") as channel:
99+
stub = backend_pb2_grpc.BackendStub(channel)
100+
request = backend_pb2.RerankRequest(
101+
query="I love you",
102+
documents=["I hate you", "I really like you"],
103+
top_n=0 #
104+
)
105+
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
106+
self.assertTrue(response.success)
107+
80108
rerank_response = stub.Rerank(request)
81109
print(rerank_response.results[0])
82110
self.assertIsNotNone(rerank_response.results)
@@ -91,7 +119,7 @@ def test_rerank(self):
91119

92120
def test_rerank_crop(self):
93121
"""
94-
This method tests if the embeddings are generated successfully
122+
This method tests top_n cropping
95123
"""
96124
try:
97125
self.setUp()
@@ -104,7 +132,7 @@ def test_rerank_crop(self):
104132
)
105133
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
106134
self.assertTrue(response.success)
107-
135+
108136
rerank_response = stub.Rerank(request)
109137
print(rerank_response.results[0])
110138
self.assertIsNotNone(rerank_response.results)
@@ -115,4 +143,4 @@ def test_rerank_crop(self):
115143
print(err)
116144
self.fail("Reranker service failed")
117145
finally:
118-
self.tearDown()
146+
self.tearDown()

0 commit comments

Comments
 (0)