diff --git a/driver.py b/driver.py index 7f1aa18..d604abd 100644 --- a/driver.py +++ b/driver.py @@ -67,7 +67,7 @@ async def run(args): tasks = [] t = time.time() - for i in range(min(len(qvals), 1000)): + for i in range(len(qvals)): request = server_pb2.Query(query=qvals[i][1], qid=qvals[i][0], k=100) tasks.append(asyncio.ensure_future(run_request(stub, request, args.experiment))) await asyncio.sleep(inter_request_time[i % length]) @@ -80,7 +80,7 @@ async def run(args): total_time = str(time.time()-t) open(args.output, "w").write("\n".join([str(x) for x in ret[1]]) + f"\nTotal time: {total_time}") - print(f"Total time for {len(qvals)-100} requests:", total_time) + print(f"Total time for {len(qvals)} requests:", total_time) await stub.DumpScores(server_pb2.Empty()) diff --git a/server.py b/server.py index 4e14021..402b7bf 100644 --- a/server.py +++ b/server.py @@ -75,7 +75,7 @@ def __init__(self, num_workers, index, mmap): self.colbert_results = [] self.pisa_results = [] - checkpoint_path = self.prefix + "/colbertv2.0/" + checkpoint_path = "colbert-ir/colbertv2.0" self.colbert_search_config = ColBERTConfig( index_root=os.path.join(os.environ["DATA_PATH"], "indexes"), @@ -158,10 +158,10 @@ def api_serve_query(self, query, qid, k=100): combined_scores = {} for d, v in zip(pids_, scores_): - combined_scores[d] = 0.5 * v + combined_scores[d] = 0.7 * v for d, v in zip(docs, pisa_score): - combined_scores[d] += 0.5 * v + combined_scores[d] += 0.3 * v sorted_pids = sorted(combined_scores.items(), key=lambda x: -x[1]) @@ -241,7 +241,7 @@ def DumpScores(self, request, context): def serve_ColBERT_server(args): connection = None if args.run_mode == "driver": - connection = Listener(('localhost', 50040), authkey=b'password') + connection = Listener(('localhost', 50040), authkey=b'password').accept() server = grpc.server(futures.ThreadPoolExecutor()) server_pb2_grpc.add_ServerServicer_to_server(ColBERTServer(args.num_workers, args.index, args.mmap), server) @@ -262,7 +262,7 @@ def serve_ColBERT_server(args): parser = argparse.ArgumentParser(description='Server for ColBERT') parser.add_argument('-w', '--num_workers', type=int, required=True, help='Number of worker threads (torch.num_threads)') - parser.add_argument('-i', '--index', type=str, equired=True, help='Index to run (use "wiki", "msmarco", "lifestyle" to repro the paper, or specify your own index name)') + parser.add_argument('-i', '--index', type=str, required=True, help='Index to run (use "wiki", "msmarco", "lifestyle" to repro the paper, or specify your own index name)') parser.add_argument('-m', '--mmap', action="store_true", help='If the index is memory mapped') parser.add_argument("-r", "--run_mode", default="server", choices=["server", "driver"], help="Use -r driver while invoking from driver.py")