From 3fd36e3a263bc1890ec94679a8adf5afbe614ddf Mon Sep 17 00:00:00 2001
From: Thejas
Date: Sat, 18 Feb 2023 23:57:44 -0800
Subject: [PATCH] Add server.py, .env and dependencies

---
Review notes (not part of the commit message):
- server.py below is a revised version of the originally committed file, so
  the blob id on its "index" line no longer matches the content.
- This patch was recovered from a whitespace-collapsed copy. The list
  indentation in the conda_env*.yml hunks is a guess and must be checked
  against the target files before applying.

 .env              |   3 ++
 conda_env.yml     |   2 +
 conda_env_cpu.yml |   2 +
 server.py         | 100 +++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 107 insertions(+)
 create mode 100644 .env
 create mode 100644 server.py

diff --git a/.env b/.env
new file mode 100644
index 00000000..125818ee
--- /dev/null
+++ b/.env
@@ -0,0 +1,3 @@
+INDEX_ROOT=""
+INDEX_NAME=""
+PORT="8893"
diff --git a/conda_env.yml b/conda_env.yml
index d897db62..1b96a89f 100644
--- a/conda_env.yml
+++ b/conda_env.yml
@@ -26,3 +26,5 @@ dependencies:
     - tqdm
     - transformers
     - ujson
+    - flask
+    - python-dotenv
diff --git a/conda_env_cpu.yml b/conda_env_cpu.yml
index 1e4702f1..28002edf 100644
--- a/conda_env_cpu.yml
+++ b/conda_env_cpu.yml
@@ -18,3 +18,5 @@ dependencies:
     - tqdm
     - transformers
     - ujson
+    - flask
+    - python-dotenv
diff --git a/server.py b/server.py
new file mode 100644
index 00000000..56222145
--- /dev/null
+++ b/server.py
@@ -0,0 +1,100 @@
+"""Minimal Flask front end for querying a ColBERT index over HTTP."""
+
+from flask import Flask, render_template, request
+from functools import lru_cache
+import math
+import os
+from dotenv import load_dotenv
+
+from colbert.infra import Run, RunConfig, ColBERTConfig
+from colbert import Searcher
+
+load_dotenv()
+
+INDEX_NAME = os.getenv("INDEX_NAME")
+INDEX_ROOT = os.getenv("INDEX_ROOT")
+DEFAULT_K = 10
+MAX_K = 100
+DEFAULT_PORT = "8893"
+
+app = Flask(__name__)
+
+# Loaded once at import time and shared by every request.
+searcher = Searcher(index=f"{INDEX_ROOT}/{INDEX_NAME}")
+counter = {"api": 0}
+
+
+def _parse_k(raw_k):
+    """Return the result count for *raw_k*, clamped to ``[0, MAX_K]``.
+
+    ``None`` selects ``DEFAULT_K``. Raises ``ValueError`` when *raw_k* is not
+    an integer literal, exactly as ``int()`` does.
+    """
+    if raw_k is None:
+        return DEFAULT_K
+    return max(0, min(int(raw_k), MAX_K))
+
+
+@lru_cache(maxsize=1000000)
+def api_search_query(query, k):
+    """Search the index for *query* and return its top *k* passages.
+
+    Results are memoised per ``(query, k)`` pair. *k* may be ``None``, an
+    int, or a numeric string; see ``_parse_k``. The returned dict has the
+    keys ``query`` and ``topk``; each ``topk`` entry carries ``text``,
+    ``pid``, ``rank``, ``score`` and ``prob``, where ``prob`` is the softmax
+    of the scores of the returned passages.
+    """
+    print(f"Query={query}")
+    k = _parse_k(k)
+    # Always retrieve MAX_K and slice afterwards, as the original code did.
+    # NOTE(review): presumably keeps rankings independent of k -- confirm.
+    pids, ranks, scores = searcher.search(query, k=MAX_K)
+    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]
+
+    # Shift by the maximum before exponentiating so large scores cannot
+    # overflow math.exp; the normalised values are mathematically unchanged.
+    peak = max(scores, default=0.0)
+    weights = [math.exp(score - peak) for score in scores]
+    total = sum(weights)
+
+    topk = [
+        {
+            "text": searcher.collection[pid],
+            "pid": pid,
+            "rank": rank,
+            "score": score,
+            "prob": weight / total,
+        }
+        for pid, rank, score, weight in zip(pids, ranks, scores, weights)
+    ]
+    topk.sort(key=lambda hit: (-hit["score"], hit["pid"]))
+    return {"query": query, "topk": topk}
+
+
+@app.route("/api/search", methods=["GET"])
+def api_search():
+    """Handle ``GET /api/search?query=...&k=...``.
+
+    Responds 400 when ``query`` is missing or ``k`` is not an integer.
+    """
+    # Flask also routes HEAD here when GET is allowed; keep rejecting it.
+    if request.method != "GET":
+        return ("", 405)
+
+    counter["api"] += 1
+    print("API request count:", counter["api"])
+
+    query = request.args.get("query")
+    if query is None:
+        return ({"error": "missing required parameter: query"}, 400)
+    try:
+        # Normalise k up front so equivalent requests share a cache entry.
+        k = _parse_k(request.args.get("k"))
+    except ValueError:
+        return ({"error": "parameter k must be an integer"}, 400)
+    return api_search_query(query, k)
+
+
+if __name__ == "__main__":
+    app.run("0.0.0.0", int(os.getenv("PORT", DEFAULT_PORT)))