@@ -76,7 +76,35 @@ def test_rerank(self):
7676 )
7777 response = stub .LoadModel (backend_pb2 .ModelOptions (Model = "cross-encoder" ))
7878 self .assertTrue (response .success )
79-
79+
80+ rerank_response = stub .Rerank (request )
81+ print (rerank_response .results [0 ])
82+ self .assertIsNotNone (rerank_response .results )
83+ self .assertEqual (len (rerank_response .results ), 2 )
84+ self .assertEqual (rerank_response .results [0 ].text , "I really like you" )
85+ self .assertEqual (rerank_response .results [1 ].text , "I hate you" )
86+ except Exception as err :
87+ print (err )
88+ self .fail ("Reranker service failed" )
89+ finally :
90+ self .tearDown ()
91+
92+ def test_rerank_omit_top_n (self ):
93+ """
94+ This method tests if the embeddings are generated successfully even top_n is omitted
95+ """
96+ try :
97+ self .setUp ()
98+ with grpc .insecure_channel ("localhost:50051" ) as channel :
99+ stub = backend_pb2_grpc .BackendStub (channel )
100+ request = backend_pb2 .RerankRequest (
101+ query = "I love you" ,
102+ documents = ["I hate you" , "I really like you" ],
103+ top_n = 0 #
104+ )
105+ response = stub .LoadModel (backend_pb2 .ModelOptions (Model = "cross-encoder" ))
106+ self .assertTrue (response .success )
107+
80108 rerank_response = stub .Rerank (request )
81109 print (rerank_response .results [0 ])
82110 self .assertIsNotNone (rerank_response .results )
@@ -91,7 +119,7 @@ def test_rerank(self):
91119
92120 def test_rerank_crop (self ):
93121 """
94- This method tests if the embeddings are generated successfully
122+ This method tests top_n cropping
95123 """
96124 try :
97125 self .setUp ()
@@ -104,7 +132,7 @@ def test_rerank_crop(self):
104132 )
105133 response = stub .LoadModel (backend_pb2 .ModelOptions (Model = "cross-encoder" ))
106134 self .assertTrue (response .success )
107-
135+
108136 rerank_response = stub .Rerank (request )
109137 print (rerank_response .results [0 ])
110138 self .assertIsNotNone (rerank_response .results )
@@ -115,4 +143,4 @@ def test_rerank_crop(self):
115143 print (err )
116144 self .fail ("Reranker service failed" )
117145 finally :
118- self .tearDown ()
146+ self .tearDown ()
0 commit comments