Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AToMiC demo page to support dense search #1721

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 155 additions & 44 deletions pyserini/demo/atomic.py
Original file line number Diff line number Diff line change
@@ -29,41 +29,134 @@
from typing import Callable, Optional, Tuple, Union

from flask import Flask, render_template, request, flash, jsonify
from pyserini.search import LuceneSearcher, FaissSearcher
from pyserini.search import LuceneSearcher, FaissSearcher, QueryEncoder


# Maps each retriever name offered by the demo to the prebuilt index names it
# can search. The first entry of each list is used as the default index when
# that retriever is selected. Names under 'BM25' are loaded with
# LuceneSearcher; every other name is loaded with FaissSearcher.
RETRIEVER_TO_INDEXES = {
'BM25': [
'atomic_image_v0.2_small_validation',
'atomic_image_v0.2_base',
'atomic_image_v0.2_large',
'atomic_text_v0.2.1_small_validation',
'atomic_text_v0.2.1_base',
'atomic_text_v0.2.1_large',
],
'ViT-L-14.laion2b_s32b_b82k': [
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large',
],
'ViT-H-14.laion2b_s32b_b79k': [
'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large',
'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large',
],
'ViT-bigG-14.laion2b_s39b_b160k': [
'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large',
'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large',
],
'ViT-B-32.laion2b_e16': [
'atomic-v0.2.ViT-B-32.laion2b_e16.image.large',
'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large',
],
'ViT-B-32.laion400m_e32': [
'atomic-v0.2.ViT-B-32.laion400m_e32.image.large',
'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large',
],
'openai.clip-vit-base-patch32': [
'atomic-v0.2.openai.clip-vit-base-patch32.image.large',
'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large',
],
'openai.clip-vit-large-patch14': [
'atomic-v0.2.openai.clip-vit-large-patch14.image.large',
'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large',
],
'Salesforce.blip-itm-base-coco': [
'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large',
'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large',
],
'Salesforce.blip-itm-large-coco': [
'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large',
'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large',
],
'facebook.flava-full': [
'atomic-v0.2.facebook.flava-full.image.large',
'atomic-v0.2.1.facebook.flava-full.text.large',
],
}

# Maps each non-BM25 index name to the name of the pre-encoded query set that is
# passed to QueryEncoder.load_encoded_queries for it. The pairing is
# cross-modal: every '.image.' index maps to a '-text-' query set and every
# '.text.' index maps to an '-image-' query set. All entries, including the
# base and large indexes, use a '-validation' query set.
INDEX_TO_ENCODED_QUERIES = {
# 'ViT-L-14.laion2b_s32b_b82k'
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
# ViT-H-14.laion2b_s32b_b79k
'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large': 'atomic-v0.2.1-text-ViT-H-14.laion2b_s32b_b79k-validation',
'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large': 'atomic-v0.2-image-ViT-H-14.laion2b_s32b_b79k-validation',
# ViT-bigG-14.laion2b_s39b_b160k
'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large': 'atomic-v0.2.1-text-ViT-bigG-14.laion2b_s39b_b160k-validation',
'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large': 'atomic-v0.2-image-ViT-bigG-14.laion2b_s39b_b160k-validation',
# ViT-B-32.laion2b_e16
'atomic-v0.2.ViT-B-32.laion2b_e16.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion2b_e16-validation',
'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large': 'atomic-v0.2-image-ViT-B-32.laion2b_e16-validation',
# ViT-B-32.laion400m_e32
'atomic-v0.2.ViT-B-32.laion400m_e32.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion400m_e32-validation',
'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large': 'atomic-v0.2-image-ViT-B-32.laion400m_e32-validation',
# openai.clip-vit-base-patch32
'atomic-v0.2.openai.clip-vit-base-patch32.image.large': 'atomic-v0.2.1-text-openai.clip-vit-base-patch32-validation',
'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large': 'atomic-v0.2-image-openai.clip-vit-base-patch32-validation',
# openai.clip-vit-large-patch14
'atomic-v0.2.openai.clip-vit-large-patch14.image.large': 'atomic-v0.2.1-text-openai.clip-vit-large-patch14-validation',
'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large': 'atomic-v0.2-image-openai.clip-vit-large-patch14-validation',
# Salesforce.blip-itm-base-coco
'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-base-coco-validation',
'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-base-coco-validation',
# Salesforce.blip-itm-large-coco
'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-large-coco-validation',
'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-large-coco-validation',
# facebook.flava-full
'atomic-v0.2.facebook.flava-full.image.large': 'atomic-v0.2.1-text-facebook.flava-full-validation',
'atomic-v0.2.1.facebook.flava-full.text.large': 'atomic-v0.2-image-facebook.flava-full-validation',
}


# Names of the sparse prebuilt indexes. The same six names are listed under
# the 'BM25' key of RETRIEVER_TO_INDEXES.
INDEX_NAMES = (
'atomic_image_v0.2_small_validation',
'atomic_image_v0.2_base',
'atomic_image_v0.2_large',
'atomic_text_v0.2.1_small_validation',
'atomic_text_v0.2.1_base',
'atomic_text_v0.2.1_large',
)
# Type alias for the two searcher classes the demo can hold.
Searcher = Union[FaissSearcher, LuceneSearcher]


def create_app(k: int, load_searcher_fn: Callable[[str], Tuple[Searcher, str]]):
def create_app(k: int, load_searcher_fn: Callable[[str], Searcher]):
app = Flask(__name__)

index_name = INDEX_NAMES[0]
searcher, retriever = load_searcher_fn(index_name=index_name)
# Use BM25 as default retriever upon page load
retriever = "BM25"
index_name = RETRIEVER_TO_INDEXES[retriever][0]
searcher = load_searcher_fn(index_name=index_name)
query_options = [] # for dense search only

@app.route('/')
def index():
    """Render the demo landing page for the currently selected retriever and index.

    Returns:
        The rendered ``atomic.html`` template.
    """
    # The enclosing state is only read here, so no ``nonlocal`` declaration is
    # needed. ``retriever_to_indexes`` must always be passed: the template
    # iterates over it to build the retriever and index drop-downs.
    return render_template(
        'atomic.html',
        index_name=index_name,
        retriever=retriever,
        retriever_to_indexes=RETRIEVER_TO_INDEXES,
    )

@app.route('/search', methods=['GET', 'POST'])
def search():
nonlocal searcher, retriever
query = request.form['q']
if retriever != "BM25":
query = query_options[int(query)]
if not query:
search_results = []
flash('Question is required')
# NOTE: this throws an exception unless we set a secret session key
else:
hits = searcher.search(query, k=k)
try:
hits = searcher.search(query, k=k)
except KeyError:
hits = []
flash('Invalid query given')
docs = [json.loads(searcher.doc(hit.docid).raw()) for hit in hits]
search_results = [
{
@@ -76,56 +169,74 @@ def search():
for r, hit in enumerate(hits)
]
return render_template(
'atomic.html', index_name=index_name, search_results=search_results, query=query, retriever=retriever
'atomic.html', index_name=index_name, retriever=retriever,
retriever_to_indexes=RETRIEVER_TO_INDEXES, search_results=search_results, query=query,
)

def _change_index(new_index_name):
nonlocal index_name, searcher, query_options
index_name = new_index_name
searcher = load_searcher_fn(index_name=index_name)
if retriever != "BM25":
query_options = {i: option for i, option in enumerate(searcher.query_encoder.embedding.keys())}

@app.route('/retriever', methods=['GET'])
def change_retriever():
nonlocal retriever
new_retriever = request.args.get('new_retriever_name', '', type=str)
if not new_retriever or new_retriever not in list(RETRIEVER_TO_INDEXES.keys()):
return

retriever = new_retriever
_change_index(new_index_name=RETRIEVER_TO_INDEXES[retriever][0])
return jsonify(index_list=RETRIEVER_TO_INDEXES[retriever])

@app.route('/index', methods=['GET'])
def change_index_name():
    """Switch to another index of the currently selected retriever.

    Reads ``new_index_name`` from the query string.

    Returns:
        JSON ``{"index_name": ...}`` naming the index that is now active, or a
        JSON error with HTTP status 400 when the name is missing or does not
        belong to the current retriever.
    """
    new_index_name = request.args.get('new_index_name', '', type=str)
    if new_index_name not in RETRIEVER_TO_INDEXES[retriever]:
        # A Flask view must not return None, so reject bad input explicitly.
        return jsonify(error=f'Unknown index for retriever {retriever!r}: {new_index_name!r}'), 400

    # _change_index rebinds index_name/searcher in the enclosing scope; this
    # view only reads them afterwards, so it needs no ``nonlocal`` itself.
    _change_index(new_index_name)
    return jsonify(index_name=index_name)

@app.route('/search_options', methods=['GET'])
def search_options():
    """Return the pre-encoded queries that start with the given prefix.

    Reads ``query`` from the query string; matching is case-insensitive.

    Returns:
        JSON object mapping each matching query's number to its text.
    """
    prefix = request.args.get('query', '').lower()

    # ``query_options`` starts out as an empty list and only becomes a dict
    # once a non-BM25 index has been loaded, so accept both shapes instead of
    # failing on ``.items()``.
    if isinstance(query_options, dict):
        numbered_options = query_options.items()
    else:
        numbered_options = enumerate(query_options)

    matching_options = {
        i: option
        for i, option in numbered_options
        if option.lower().startswith(prefix)
    }
    return jsonify(matching_options)

return app


def _load_sparse_searcher(
    index_name, language: str, k1: Optional[float] = None, b: Optional[float] = None
) -> Tuple[Searcher, str]:
    """Load a prebuilt Lucene index and describe the BM25 setup in use.

    Kept under its original name and return shape for callers that still
    reference it.

    Args:
        index_name: Name passed to ``LuceneSearcher.from_prebuilt_index``.
        language: Accepted for interface compatibility; not read here.
        k1: BM25 k1 parameter. Applied only when ``b`` is also given.
        b: BM25 b parameter. Applied only when ``k1`` is also given.

    Returns:
        A ``(searcher, retriever_name)`` tuple.
    """
    searcher = LuceneSearcher.from_prebuilt_index(index_name)
    if k1 is not None and b is not None:
        searcher.set_bm25(k1, b)
        retriever_name = f'BM25 (k1={k1}, b={b})'
    else:
        retriever_name = 'BM25'

    return searcher, retriever_name


def _load_searcher(
    index_name: str, language: str, k1: Optional[float] = None, b: Optional[float] = None
) -> Searcher:
    """Load the searcher that matches ``index_name``.

    Index names listed under ``RETRIEVER_TO_INDEXES['BM25']`` get a
    ``LuceneSearcher``. Any other name gets a ``FaissSearcher`` built with the
    pre-encoded queries registered for it in ``INDEX_TO_ENCODED_QUERIES``.

    Args:
        index_name: Name of the prebuilt index to load.
        language: Accepted for interface compatibility; not read here.
        k1: BM25 k1 parameter (BM25 indexes only; needs ``b`` as well).
        b: BM25 b parameter (BM25 indexes only; needs ``k1`` as well).

    Returns:
        The loaded searcher.

    Raises:
        KeyError: If ``index_name`` is not a BM25 index and has no entry in
            ``INDEX_TO_ENCODED_QUERIES``.
    """
    if index_name in RETRIEVER_TO_INDEXES['BM25']:
        # Reuse the sparse loader; the descriptive name is not needed here.
        searcher, _ = _load_sparse_searcher(index_name, language, k1=k1, b=b)
        return searcher

    query_encoder = QueryEncoder.load_encoded_queries(INDEX_TO_ENCODED_QUERIES[index_name])
    return FaissSearcher.from_prebuilt_index(index_name, query_encoder)


def main():
    """Parse the command-line options and serve the demo web app."""
    parser = ArgumentParser()

    parser.add_argument('--k1', type=float, help='BM25 k1 parameter.')
    parser.add_argument('--b', type=float, help='BM25 b parameter.')
    parser.add_argument('--hits', type=int, default=10, help='Number of hits returned by the retriever')
    # NOTE(review): --device is parsed but never read below, so it currently
    # has no effect; kept so existing command lines keep working.
    parser.add_argument(
        '--device',
        type=str,
        default='cpu',
        help='Device to run query encoder, cpu or [cuda:0, cuda:1, ...] (used only when index is based on FAISS)',
    )
    # Each option string may be registered only once.
    parser.add_argument('--port', default=8080, type=int, help='Web server port')

    args = parser.parse_args()

    # create_app expects a loader that returns just the searcher.
    load_fn = partial(_load_searcher, language='en', k1=args.k1, b=args.b)
    app = create_app(args.hits, load_fn)
    app.run(host='0.0.0.0', port=args.port)

147 changes: 123 additions & 24 deletions pyserini/demo/templates/atomic.html
Original file line number Diff line number Diff line change
@@ -15,12 +15,91 @@
<script>
$SCRIPT_ROOT = {{ request.script_root | tojson }};

$(document).ready(function () {
$(function () {
$("#retriever_name").val("{{retriever}}");
$("#index_name").val("{{index_name}}");
$("#loading").hide();
$('#index_name').val("{{index_name}}");
});
$("#pre_encoded_query_table").hide();

// Search button handler.
// for sparse search, we can submit the input directly as a query
// for dense search, we only support searching from pre-encoded queries,
// so the search button will retrieve valid pre-encoded queries
$("#search_button").on("click", function () {
  if ($("#retriever_name").val() === "BM25") {
    $("#submit_query_form").submit();
  } else {
    const query = $("#query_input").val();
    $.getJSON(
      $SCRIPT_ROOT + "/search_options", {
        query: query
      }, function (data) {
        $("#results_table").hide();
        $("#pre_encoded_query_table").show();

        const validQueries = $("#pre_encoded_queries");
        validQueries.empty();
        // construct row for each valid pre-encoded query
        var rowNum = 0;
        $.each(data, function (queryIndex, result) {
          const tr = $("<tr>")
            .attr({ "class": rowNum % 2 ? "table-secondary" : "table-light" });
          tr.append($("<td>").text(queryIndex));
          tr.append(
            $("<td>")
              .attr({ "style": "word-wrap: break-word;min-width: 600px;max-width: 600px;", "class": "text-start" })
              .append($("<small></small>").attr({ "data-query-idx": queryIndex }).text(result))
          );
          tr.append(
            $("<td>")
              .append($("<button>")
                .attr({ "type": "button" })
                .click(function () {
                  // the form posts the query number, not its text
                  const selectedIndex = $(this).closest("tr").find("small").attr("data-query-idx");
                  $("#query_input").val(selectedIndex);
                  $("#submit_query_form").submit();
                })
                .text("Use this query"))
          );
          validQueries.append(tr);
          // must be the counter declared above; an undeclared name here
          // throws and aborts the loop after the first row
          rowNum += 1;
        });
      }).always(function () {
        // restore the controls whether the request succeeded or failed
        $("#loading").hide();
        $("#retriever_name").removeAttr('disabled');
        $("#index_name").removeAttr('disabled');
      });

    $("#retriever_name").attr('disabled', 'disabled');
    $("#index_name").attr('disabled', 'disabled');
    $("#loading").show();
    return false;
  }
});

// Retriever drop-down handler: ask the server to switch retrievers, then
// rebuild the index drop-down from the list it returns.
$('#retriever_name').on('change', function () {
  $.getJSON($SCRIPT_ROOT + '/retriever', {
    new_retriever_name: this.value,
  }, function (data) {
    var dropdownElem = $('#index_name');
    dropdownElem.empty();
    $.each(data.index_list, function (index, val) {
      var option = $('<option></option>');
      option.text(val);
      option.val(val);
      // `option` is a jQuery wrapper, so the selected state has to be set
      // through .prop() to reach the underlying <option> element
      if (index === 0) option.prop('selected', true);
      dropdownElem.append(option);
    });
  }).always(function () {
    // restore the controls whether the request succeeded or failed
    $("#retriever_name").removeAttr('disabled');
    $("#loading").hide();
  });

  $(this).attr('disabled', 'disabled');
  $("#pre_encoded_query_table").hide();
  $("#loading").show();

  return false;
});

$(function () {
$('#index_name').on('change', function () {
$.getJSON($SCRIPT_ROOT + '/index', {
new_index_name: this.value,
@@ -30,6 +109,7 @@
});

$(this).attr('disabled', 'disabled');
$("#pre_encoded_query_table").hide();
$("#loading").show();

return false;
@@ -50,28 +130,37 @@ <h4>Large-scale image/text retrieval test collection</h4>
collection designed to aid in multimedia content creation.
</p>

<p>
For sparse search (BM25), simply type in your query and submit.<br>
For dense search, we support searching from a set of pre-encoded queries. Submitting your query will
retrieve the pre-encoded queries that begin with your query (case-insensitive).
</p>

<div class="row g-3 align-items-center">
<label class="col-auto" for="retriever_name">You are using the following retriever</label>
<div class="col-auto">
<select class="form-select form-select-sm" aria-label=".form-select-sm" id="retriever_name">
{% for retriever in retriever_to_indexes.keys() %}
<option value="{{ retriever }}">{{ retriever }}</option>
{% endfor %}
</select>
</div>
</div>

<div class="row g-3 align-items-center">
<label class="col-auto" for="index">You are performing search on the following dataset</label>
<label class="col-auto" for="index_name">You are performing search on the following dataset</label>
<div class="col-auto">
<select class="form-select form-select-sm" aria-label=".form-select-sm" id="index_name">
<option value="atomic_image_v0.2_small_validation">Text to Image (Small)</option>
<option value="atomic_image_v0.2_base">Text to Image (Base)</option>
<option value="atomic_image_v0.2_large">Text to Image (Large)</option>
<option value="atomic_text_v0.2.1_small_validation">Image to Text (Small)</option>
<option value="atomic_text_v0.2.1_base">Image to Text (Base)</option>
<option value="atomic_text_v0.2.1_large">Image to Text (Large)</option>
{% for index in retriever_to_indexes[retriever] %}
<option value="{{ index }}">{{ index }}</option>
{% endfor %}
</select>
</div>
<div class="col-auto">
<div class="spinner-border text-secondary" role="status" id="loading">
<span class="visually-hidden">Loading...</span>
</div>
</div>
<div class="col-auto">
<span>
retrieves passages using <em>{{retriever}}</em>.
</span>
</div>
</div>

<br />
@@ -81,27 +170,37 @@ <h4>Large-scale image/text retrieval test collection</h4>
<div class="alert">{{ message }}</div>
{% endfor %}

<form action="/search" method="post">
<form action="/search" method="post" id="submit_query_form">
<div class="row-cols-3">
<div class="input-group mb-3">
<input type="text" class="form-control" placeholder="Enter a Question" aria-label="Question" name="q"
aria-describedby="button-addon2" value="{{query if query else ''}}">
<button class="btn btn-outline-secondary" type="submit" id="button-addon2"><i
class="bi bi-search"></i></button>
<input type="text" id="query_input" class="form-control" placeholder="Enter a Question" aria-label="Question"
name="q" aria-describedby="search_button" value="{{query if query else ''}}">
<button class="btn btn-outline-secondary" id="search_button"><i class="bi bi-search"></i></button>
</div>
<table class="table" id="pre_encoded_query_table">
<thead>
<tr>
<th scope="col">Query Number</th>
<th scope="col">Pre-encoded Query</th>
<th scope="col"></th>
</tr>
</thead>
<tbody class="table-group-divider" id="pre_encoded_queries">
</tbody>
</table>
</div>
</form>

{% if search_results %}
<div class="row">
<table class="table">
<table class="table" id="results_table">
<thead>
<tr>
<th scope="col">#</th>
<th scope="col">Score</th>
<th scope="col">Passage ID</th>
<th scope="col">Content</th>
{% if index_name.startswith("atomic_image") %}
{% if "image" in index_name %}
<th scope="col">Image</th>
{% endif %}
</tr>
@@ -115,7 +214,7 @@ <h4>Large-scale image/text retrieval test collection</h4>
<td style="word-wrap: break-word;min-width: 600px;max-width: 600px;" class="text-start">
<small>{{res["content"]}}</small>
</td>
{% if index_name.startswith("atomic_image") %}
{% if "image" in index_name %}
<td>
<img src="{{ res['image_url'] }}" width="500" height="auto">
</td>