From 961318c3a51b1f1714b5272107c4c5d4d1684126 Mon Sep 17 00:00:00 2001 From: Kyung Jae Lee <kj25lee@uwaterloo.ca> Date: Sun, 19 Nov 2023 18:24:12 -0500 Subject: [PATCH] Update AToMiC demo page to support dense search --- pyserini/demo/atomic.py | 199 ++++++++++++++++++++++------ pyserini/demo/templates/atomic.html | 147 ++++++++++++++++---- 2 files changed, 278 insertions(+), 68 deletions(-) diff --git a/pyserini/demo/atomic.py b/pyserini/demo/atomic.py index b9cb53488..fafce8b3c 100644 --- a/pyserini/demo/atomic.py +++ b/pyserini/demo/atomic.py @@ -29,41 +29,134 @@ from typing import Callable, Optional, Tuple, Union from flask import Flask, render_template, request, flash, jsonify -from pyserini.search import LuceneSearcher, FaissSearcher +from pyserini.search import LuceneSearcher, FaissSearcher, QueryEncoder + + +RETRIEVER_TO_INDEXES = { + 'BM25': [ + 'atomic_image_v0.2_small_validation', + 'atomic_image_v0.2_base', + 'atomic_image_v0.2_large', + 'atomic_text_v0.2.1_small_validation', + 'atomic_text_v0.2.1_base', + 'atomic_text_v0.2.1_large', + ], + 'ViT-L-14.laion2b_s32b_b82k': [ + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large', + ], + 'ViT-H-14.laion2b_s32b_b79k': [ + 'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large', + 'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large', + ], + 'ViT-bigG-14.laion2b_s39b_b160k': [ + 'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large', + 'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large', + ], + 'ViT-B-32.laion2b_e16': [ + 'atomic-v0.2.ViT-B-32.laion2b_e16.image.large', + 'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large', + ], + 'ViT-B-32.laion400m_e32': [ + 'atomic-v0.2.ViT-B-32.laion400m_e32.image.large', + 
'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large', + ], + 'openai.clip-vit-base-patch32': [ + 'atomic-v0.2.openai.clip-vit-base-patch32.image.large', + 'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large', + ], + 'openai.clip-vit-large-patch14': [ + 'atomic-v0.2.openai.clip-vit-large-patch14.image.large', + 'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large', + ], + 'Salesforce.blip-itm-base-coco': [ + 'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large', + 'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large', + ], + 'Salesforce.blip-itm-large-coco': [ + 'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large', + 'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large', + ], + 'facebook.flava-full': [ + 'atomic-v0.2.facebook.flava-full.image.large', + 'atomic-v0.2.1.facebook.flava-full.text.large', + ], +} + +INDEX_TO_ENCODED_QUERIES = { + # 'ViT-L-14.laion2b_s32b_b82k' + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation', + 'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation', + # ViT-H-14.laion2b_s32b_b79k + 'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large': 'atomic-v0.2.1-text-ViT-H-14.laion2b_s32b_b79k-validation', + 'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large': 'atomic-v0.2-image-ViT-H-14.laion2b_s32b_b79k-validation', + # ViT-bigG-14.laion2b_s39b_b160k + 'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large': 
'atomic-v0.2.1-text-ViT-bigG-14.laion2b_s39b_b160k-validation', + 'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large': 'atomic-v0.2-image-ViT-bigG-14.laion2b_s39b_b160k-validation', + # ViT-B-32.laion2b_e16 + 'atomic-v0.2.ViT-B-32.laion2b_e16.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion2b_e16-validation', + 'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large': 'atomic-v0.2-image-ViT-B-32.laion2b_e16-validation', + # ViT-B-32.laion400m_e32 + 'atomic-v0.2.ViT-B-32.laion400m_e32.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion400m_e32-validation', + 'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large': 'atomic-v0.2-image-ViT-B-32.laion400m_e32-validation', + # openai.clip-vit-base-patch32 + 'atomic-v0.2.openai.clip-vit-base-patch32.image.large': 'atomic-v0.2.1-text-openai.clip-vit-base-patch32-validation', + 'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large': 'atomic-v0.2-image-openai.clip-vit-base-patch32-validation', + # openai.clip-vit-large-patch14 + 'atomic-v0.2.openai.clip-vit-large-patch14.image.large': 'atomic-v0.2.1-text-openai.clip-vit-large-patch14-validation', + 'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large': 'atomic-v0.2-image-openai.clip-vit-large-patch14-validation', + # Salesforce.blip-itm-base-coco + 'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-base-coco-validation', + 'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-base-coco-validation', + # Salesforce.blip-itm-large-coco + 'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-large-coco-validation', + 'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-large-coco-validation', + # facebook.flava-full + 'atomic-v0.2.facebook.flava-full.image.large': 'atomic-v0.2.1-text-facebook.flava-full-validation', + 'atomic-v0.2.1.facebook.flava-full.text.large': 
'atomic-v0.2-image-facebook.flava-full-validation', +} - -INDEX_NAMES = ( - 'atomic_image_v0.2_small_validation', - 'atomic_image_v0.2_base', - 'atomic_image_v0.2_large', - 'atomic_text_v0.2.1_small_validation', - 'atomic_text_v0.2.1_base', - 'atomic_text_v0.2.1_large', -) Searcher = Union[FaissSearcher, LuceneSearcher] -def create_app(k: int, load_searcher_fn: Callable[[str], Tuple[Searcher, str]]): +def create_app(k: int, load_searcher_fn: Callable[[str], Searcher]): app = Flask(__name__) - index_name = INDEX_NAMES[0] - searcher, retriever = load_searcher_fn(index_name=index_name) + # Use BM25 as default retriever upon page load + retriever = "BM25" + index_name = RETRIEVER_TO_INDEXES[retriever][0] + searcher = load_searcher_fn(index_name=index_name) + query_options = {} # query number -> pre-encoded query text; filled only for dense retrievers (must be a dict: /search_options calls .items()) @app.route('/') def index(): - nonlocal searcher, retriever - return render_template('atomic.html', index_name=index_name, retriever=retriever) + return render_template( + 'atomic.html', index_name=index_name, retriever=retriever, retriever_to_indexes=RETRIEVER_TO_INDEXES + ) @app.route('/search', methods=['GET', 'POST']) def search(): - nonlocal searcher, retriever query = request.form['q'] + if retriever != "BM25": + query = query_options[int(query)] if not query: search_results = [] flash('Question is required') # NOTE: this throws an exception unless we set a secret session key else: - hits = searcher.search(query, k=k) + try: + hits = searcher.search(query, k=k) + except KeyError: + hits = [] + flash('Invalid query given') docs = [json.loads(searcher.doc(hit.docid).raw()) for hit in hits] search_results = [ { @@ -76,56 +169,74 @@ def search(): for r, hit in enumerate(hits) ] return render_template( - 'atomic.html', index_name=index_name, search_results=search_results, query=query, retriever=retriever + 'atomic.html', index_name=index_name, retriever=retriever, + retriever_to_indexes=RETRIEVER_TO_INDEXES, search_results=search_results, query=query, ) + def 
_change_index(new_index_name): + nonlocal index_name, searcher, query_options + index_name = new_index_name + searcher = load_searcher_fn(index_name=index_name) + if retriever != "BM25": + query_options = {i: option for i, option in enumerate(searcher.query_encoder.embedding.keys())} + + @app.route('/retriever', methods=['GET']) + def change_retriever(): + nonlocal retriever + new_retriever = request.args.get('new_retriever_name', '', type=str) + if not new_retriever or new_retriever not in list(RETRIEVER_TO_INDEXES.keys()): + return + + retriever = new_retriever + _change_index(new_index_name=RETRIEVER_TO_INDEXES[retriever][0]) + return jsonify(index_list=RETRIEVER_TO_INDEXES[retriever]) + @app.route('/index', methods=['GET']) def change_index_name(): - nonlocal index_name, searcher, retriever new_index_name = request.args.get('new_index_name', '', type=str) - if not new_index_name or new_index_name not in INDEX_NAMES: + if not new_index_name or new_index_name not in RETRIEVER_TO_INDEXES[retriever]: return - - index_name = new_index_name - searcher, retriever = load_searcher_fn(index_name=index_name) + _change_index(new_index_name) return jsonify(index_name=index_name) + @app.route('/search_options', methods=['GET']) + def search_options(): + query = request.args.get('query', '') + + matching_options = { + i: option + for i, option in query_options.items() + if option.lower().startswith(query.lower()) + } + return jsonify(matching_options) + return app -def _load_sparse_searcher(index_name, language: str, k1: Optional[float]=None, b: Optional[float]=None) -> (Searcher, str): - searcher = LuceneSearcher.from_prebuilt_index(index_name) - if k1 is not None and b is not None: - searcher.set_bm25(k1, b) - retriever_name = f'BM25 (k1={k1}, b={b})' +def _load_searcher(index_name: str, language: str, k1: Optional[float]=None, b: Optional[float]=None): + if index_name in RETRIEVER_TO_INDEXES['BM25']: + searcher = LuceneSearcher.from_prebuilt_index(index_name) + if k1 is 
not None and b is not None: + searcher.set_bm25(k1, b) else: - retriever_name = 'BM25' - - return searcher, retriever_name + query_encoder = QueryEncoder.load_encoded_queries(INDEX_TO_ENCODED_QUERIES[index_name]) + searcher = FaissSearcher.from_prebuilt_index( + index_name, query_encoder + ) + return searcher def main(): parser = ArgumentParser() - parser.add_argument('--k1', type=float, help='BM25 k1 parameter.') parser.add_argument('--b', type=float, help='BM25 b parameter.') parser.add_argument('--hits', type=int, default=10, help='Number of hits returned by the retriever') parser.add_argument( - '--device', - type=str, - default='cpu', - help='Device to run query encoder, cpu or [cuda:0, cuda:1, ...] (used only when index is based on FAISS)', - ) - parser.add_argument( - '--port', - default=8080, - type=int, - help='Web server port', + '--port', default=8080, type=int, help='Web server port', ) - args = parser.parse_args() - load_fn = partial(_load_sparse_searcher, language='en', k1=args.k1, b=args.b) + load_fn = partial(_load_searcher, language='en', k1=args.k1, b=args.b) app = create_app(args.hits, load_fn) app.run(host='0.0.0.0', port=args.port) diff --git a/pyserini/demo/templates/atomic.html b/pyserini/demo/templates/atomic.html index ff916be02..f278fdf60 100644 --- a/pyserini/demo/templates/atomic.html +++ b/pyserini/demo/templates/atomic.html @@ -15,12 +15,91 @@ <script> $SCRIPT_ROOT = {{ request.script_root | tojson }}; - $(document).ready(function () { + $(function () { + $("#retriever_name").val("{{retriever}}"); + $("#index_name").val("{{index_name}}"); $("#loading").hide(); - $('#index_name').val("{{index_name}}"); - }); + $("#pre_encoded_query_table").hide(); + + $("#search_button").on("click", function () { + // for sparse search, we can submit the input directly as a query + // for dense search, we only support searching from pre-encoded queries, + // so the search button will retrieve valid pre-encoded queries + if ($("#retriever_name").val() 
=== "BM25") { + $("#submit_query_form").submit(); + } else { + const query = $("#query_input").val(); + $.getJSON( + $SCRIPT_ROOT + "/search_options", { + query: query + }, function (data) { + $("#loading").hide(); + $("#retriever_name").removeAttr('disabled'); + $("#index_name").removeAttr('disabled'); + $("#results_table").hide(); + $("#pre_encoded_query_table").show(); + + const validQueries = $("#pre_encoded_queries"); + validQueries.empty(); + // construct row for each valid pre-encoded query + var rowNum = 0; + $.each(data, function (queryIndex, result) { + const tr = $("<tr>") + .attr({ "class": rowNum % 2 ? "table-secondary" : "table-light" }); + tr.append($("<td>").text(queryIndex)); + tr.append( + $("<td>") + .attr({ "style": "word-wrap: break-word;min-width: 600px;max-width: 600px;", "class": "text-start" }) + .append($("<small></small>").attr({ "data-query-idx": queryIndex }).text(result)) + ); + tr.append( + $("<td>") + .append($("<button>") + .attr({ "type": "button" }) + .click(function () { + const queryIndex = $(this).closest("tr").find("small").attr("data-query-idx"); + $("#query_input").val(queryIndex); + $("#submit_query_form").submit(); + }) + .text("Use this query")) + ); + validQueries.append(tr); + rowNum += 1; // advance the counter that alternates the row striping class + }); + }); + + $("#retriever_name").attr('disabled', 'disabled'); + $("#index_name").attr('disabled', 'disabled'); + $("#loading").show(); + return false; + } + }); + + $('#retriever_name').on('change', function () { + $.getJSON($SCRIPT_ROOT + '/retriever', { + new_retriever_name: this.value, + }, function (data) { + $("#retriever_name").removeAttr('disabled'); + $("#loading").hide(); + + var dropdownElem = $('#index_name'); + dropdownElem.empty(); + $.each(data.index_list, function (index, val) { + var option = $('<option></option>'); + option.text(val); + option.val(val); + if (index === 0) option.prop('selected', true); // option is a jQuery wrapper, so set the DOM property via .prop() + dropdownElem.append(option); + }); + }); + + $(this).attr('disabled', 'disabled'); + 
$("#pre_encoded_query_table").hide(); + $("#loading").show(); + + return false; + }); - $(function () { $('#index_name').on('change', function () { $.getJSON($SCRIPT_ROOT + '/index', { new_index_name: this.value, @@ -30,6 +109,7 @@ }); $(this).attr('disabled', 'disabled'); + $("#pre_encoded_query_table").hide(); $("#loading").show(); return false; @@ -50,16 +130,30 @@ <h4>Large-scale image/text retrieval test collection</h4> collection designed to aid in multimedia content creation. </p> + <p> + For sparse search (BM25), simply type in your query and submit.<br> + For dense search, we support searching from a set of pre-encoded queries. Submitting your query will + retrieve the pre-encoded queries that begins with your query (case-insensitive). + </p> + + <div class="row g-3 align-items-center"> + <label class="col-auto" for="retriever_name">You are using the following retriever</label> + <div class="col-auto"> + <select class="form-select form-select-sm" aria-label=".form-select-sm" id="retriever_name"> + {% for retriever in retriever_to_indexes.keys() %} + <option value="{{ retriever }}">{{ retriever }}</option> + {% endfor %} + </select> + </div> + </div> + <div class="row g-3 align-items-center"> - <label class="col-auto" for="index">You are perfoming search on the following dataset</label> + <label class="col-auto" for="index_name">You are perfoming search on the following dataset</label> <div class="col-auto"> <select class="form-select form-select-sm" aria-label=".form-select-sm" id="index_name"> - <option value="atomic_image_v0.2_small_validation">Text to Image (Small)</option> - <option value="atomic_image_v0.2_base">Text to Image (Base)</option> - <option value="atomic_image_v0.2_large">Text to Image (Large)</option> - <option value="atomic_text_v0.2.1_small_validation">Image to Text (Small)</option> - <option value="atomic_text_v0.2.1_base">Image to Text (Base)</option> - <option value="atomic_text_v0.2.1_large">Image to Text (Large)</option> + {% for 
index in retriever_to_indexes[retriever] %} + <option value="{{ index }}">{{ index }}</option> + {% endfor %} </select> </div> <div class="col-auto"> @@ -67,11 +161,6 @@ <h4>Large-scale image/text retrieval test collection</h4> <span class="visually-hidden">Loading...</span> </div> </div> - <div class="col-auto"> - <span> - retrieves passages using <em>{{retriever}}</em>. - </span> - </div> </div> <br /> @@ -81,27 +170,37 @@ <h4>Large-scale image/text retrieval test collection</h4> <div class="alert">{{ message }}</div> {% endfor %} - <form action="/search" method="post"> + <form action="/search" method="post" id="submit_query_form"> <div class="row-cols-3"> <div class="input-group mb-3"> - <input type="text" class="form-control" placeholder="Enter a Question" aria-label="Question" name="q" - aria-describedby="button-addon2" value="{{query if query else ''}}"> - <button class="btn btn-outline-secondary" type="submit" id="button-addon2"><i - class="bi bi-search"></i></button> + <input type="text" id="query_input" class="form-control" placeholder="Enter a Question" aria-label="Question" + name="q" aria-describedby="search_button" value="{{query if query else ''}}"> + <button class="btn btn-outline-secondary" id="search_button"><i class="bi bi-search"></i></button> </div> + <table class="table" id="pre_encoded_query_table"> + <thead> + <tr> + <th scope="col">Query Number</th> + <th scope="col">Pre-encoded Query</th> + <th scope="col"></th> + </tr> + </thead> + <tbody class="table-group-divider" id="pre_encoded_queries"> + </tbody> + </table> </div> </form> {% if search_results %} <div class="row"> - <table class="table"> + <table class="table" id="results_table"> <thead> <tr> <th scope="col">#</th> <th scope="col">Score</th> <th scope="col">Passage ID</th> <th scope="col">Content</th> - {% if index_name.startswith("atomic_image") %} + {% if "image" in index_name %} <th scope="col">Image</th> {% endif %} </tr> @@ -115,7 +214,7 @@ <h4>Large-scale image/text retrieval 
test collection</h4> <td style="word-wrap: break-word;min-width: 600px;max-width: 600px;" class="text-start"> <small>{{res["content"]}}</small> </td> - {% if index_name.startswith("atomic_image") %} + {% if "image" in index_name %} <td> <img src="{{ res['image_url'] }}" width="500" height="auto"> </td>