Skip to content

Commit

Permalink
Add server.py, .env and dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
Thejas authored and Thejas committed Feb 19, 2023
1 parent be4271f commit 3fd36e3
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
INDEX_ROOT=""
INDEX_NAME=""
PORT="8893"
2 changes: 2 additions & 0 deletions conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ dependencies:
- tqdm
- transformers
- ujson
- flask
- python-dotenv
2 changes: 2 additions & 0 deletions conda_env_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ dependencies:
- tqdm
- transformers
- ujson
- flask
- python-dotenv
48 changes: 48 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from flask import Flask, render_template, request
from functools import lru_cache
import math
import os
from dotenv import load_dotenv

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Searcher

load_dotenv()

INDEX_NAME = os.getenv("INDEX_NAME")
INDEX_ROOT = os.getenv("INDEX_ROOT")
app = Flask(__name__)

searcher = Searcher(index=f"{INDEX_ROOT}/{INDEX_NAME}")
counter = {"api" : 0}

@lru_cache(maxsize=1000000)
def api_search_query(query, k):
print(f"Query={query}")
if k == None: k = 10
k = min(int(k), 100)
pids, ranks, scores = searcher.search(query, k=100)
pids, ranks, scores = pids[:k], ranks[:k], scores[:k]
passages = [searcher.collection[pid] for pid in pids]
probs = [math.exp(score) for score in scores]
probs = [prob / sum(probs) for prob in probs]
topk = []
for pid, rank, score, prob in zip(pids, ranks, scores, probs):
text = searcher.collection[pid]
d = {'text': text, 'pid': pid, 'rank': rank, 'score': score, 'prob': prob}
topk.append(d)
topk = list(sorted(topk, key=lambda p: (-1 * p['score'], p['pid'])))
return {"query" : query, "topk": topk}

@app.route("/api/search", methods=["GET"])
def api_search():
if request.method == "GET":
counter["api"] += 1
print("API request count:", counter["api"])
return api_search_query(request.args.get("query"), request.args.get("k"))
else:
return ('', 405)

if __name__ == "__main__":
app.run("0.0.0.0", int(os.getenv("PORT")))

0 comments on commit 3fd36e3

Please sign in to comment.