Skip to content

Commit

Permalink
Minor bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
VThejas committed Nov 2, 2024
1 parent 4b23820 commit 09b9653
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def run(args):
tasks = []
t = time.time()

for i in range(min(len(qvals), 1000)):
for i in range(len(qvals)):
request = server_pb2.Query(query=qvals[i][1], qid=qvals[i][0], k=100)
tasks.append(asyncio.ensure_future(run_request(stub, request, args.experiment)))
await asyncio.sleep(inter_request_time[i % length])
Expand All @@ -80,7 +80,7 @@ async def run(args):
total_time = str(time.time()-t)

open(args.output, "w").write("\n".join([str(x) for x in ret[1]]) + f"\nTotal time: {total_time}")
print(f"Total time for {len(qvals)-100} requests:", total_time)
print(f"Total time for {len(qvals)} requests:", total_time)

await stub.DumpScores(server_pb2.Empty())

Expand Down
10 changes: 5 additions & 5 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, num_workers, index, mmap):
self.colbert_results = []
self.pisa_results = []

checkpoint_path = self.prefix + "/colbertv2.0/"
checkpoint_path = "colbert-ir/colbertv2.0"

self.colbert_search_config = ColBERTConfig(
index_root=os.path.join(os.environ["DATA_PATH"], "indexes"),
Expand Down Expand Up @@ -158,10 +158,10 @@ def api_serve_query(self, query, qid, k=100):
combined_scores = {}

for d, v in zip(pids_, scores_):
combined_scores[d] = 0.5 * v
combined_scores[d] = 0.7 * v

for d, v in zip(docs, pisa_score):
combined_scores[d] += 0.5 * v
combined_scores[d] += 0.3 * v

sorted_pids = sorted(combined_scores.items(), key=lambda x: -x[1])

Expand Down Expand Up @@ -241,7 +241,7 @@ def DumpScores(self, request, context):
def serve_ColBERT_server(args):
connection = None
if args.run_mode == "driver":
connection = Listener(('localhost', 50040), authkey=b'password')
connection = Listener(('localhost', 50040), authkey=b'password').accept()

server = grpc.server(futures.ThreadPoolExecutor())
server_pb2_grpc.add_ServerServicer_to_server(ColBERTServer(args.num_workers, args.index, args.mmap), server)
Expand All @@ -262,7 +262,7 @@ def serve_ColBERT_server(args):
parser = argparse.ArgumentParser(description='Server for ColBERT')
parser.add_argument('-w', '--num_workers', type=int, required=True,
help='Number of worker threads (torch.num_threads)')
parser.add_argument('-i', '--index', type=str, equired=True, help='Index to run (use "wiki", "msmarco", "lifestyle" to repro the paper, or specify your own index name)')
parser.add_argument('-i', '--index', type=str, required=True, help='Index to run (use "wiki", "msmarco", "lifestyle" to repro the paper, or specify your own index name)')
parser.add_argument('-m', '--mmap', action="store_true", help='If the index is memory mapped')
parser.add_argument("-r", "--run_mode", default="server", choices=["server", "driver"], help="Use -r driver while invoking from driver.py")

Expand Down

0 comments on commit 09b9653

Please sign in to comment.