From 961318c3a51b1f1714b5272107c4c5d4d1684126 Mon Sep 17 00:00:00 2001
From: Kyung Jae Lee <kj25lee@uwaterloo.ca>
Date: Sun, 19 Nov 2023 18:24:12 -0500
Subject: [PATCH] Update AToMiC demo page to support dense search

---
 pyserini/demo/atomic.py             | 199 ++++++++++++++++++++++------
 pyserini/demo/templates/atomic.html | 147 ++++++++++++++++----
 2 files changed, 278 insertions(+), 68 deletions(-)

diff --git a/pyserini/demo/atomic.py b/pyserini/demo/atomic.py
index b9cb53488..fafce8b3c 100644
--- a/pyserini/demo/atomic.py
+++ b/pyserini/demo/atomic.py
@@ -29,41 +29,134 @@
 from typing import Callable, Optional, Tuple, Union
 
 from flask import Flask, render_template, request, flash, jsonify
-from pyserini.search import LuceneSearcher, FaissSearcher
+from pyserini.search import LuceneSearcher, FaissSearcher, QueryEncoder
+
+
+RETRIEVER_TO_INDEXES = {
+    'BM25': [
+        'atomic_image_v0.2_small_validation',
+        'atomic_image_v0.2_base',
+        'atomic_image_v0.2_large',
+        'atomic_text_v0.2.1_small_validation',
+        'atomic_text_v0.2.1_base',
+        'atomic_text_v0.2.1_large',
+    ],
+    'ViT-L-14.laion2b_s32b_b82k': [
+        'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation',
+        'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base',
+        'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large',
+        'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation',
+        'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base',
+        'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large',
+    ],
+    'ViT-H-14.laion2b_s32b_b79k': [
+        'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large',
+        'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large',
+    ],
+    'ViT-bigG-14.laion2b_s39b_b160k': [
+        'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large',
+        'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large',
+    ],
+    'ViT-B-32.laion2b_e16': [
+        'atomic-v0.2.ViT-B-32.laion2b_e16.image.large',
+        'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large',
+    ],
+    'ViT-B-32.laion400m_e32': [
+        'atomic-v0.2.ViT-B-32.laion400m_e32.image.large',
+        'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large',
+    ],
+    'openai.clip-vit-base-patch32': [
+        'atomic-v0.2.openai.clip-vit-base-patch32.image.large',
+        'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large',
+    ],
+    'openai.clip-vit-large-patch14': [
+        'atomic-v0.2.openai.clip-vit-large-patch14.image.large',
+        'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large',
+    ],
+    'Salesforce.blip-itm-base-coco': [
+        'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large',
+        'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large',
+    ],
+    'Salesforce.blip-itm-large-coco': [
+        'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large',
+        'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large',
+    ],
+    'facebook.flava-full': [
+        'atomic-v0.2.facebook.flava-full.image.large',
+        'atomic-v0.2.1.facebook.flava-full.text.large',
+    ],
+}
+
+INDEX_TO_ENCODED_QUERIES = {
+    # 'ViT-L-14.laion2b_s32b_b82k'
+    'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
+    'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
+    'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
+    'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
+    'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
+    'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
+    # ViT-H-14.laion2b_s32b_b79k
+    'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large': 'atomic-v0.2.1-text-ViT-H-14.laion2b_s32b_b79k-validation',
+    'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large': 'atomic-v0.2-image-ViT-H-14.laion2b_s32b_b79k-validation',
+    # ViT-bigG-14.laion2b_s39b_b160k
+    'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large': 'atomic-v0.2.1-text-ViT-bigG-14.laion2b_s39b_b160k-validation',
+    'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large': 'atomic-v0.2-image-ViT-bigG-14.laion2b_s39b_b160k-validation',
+    # ViT-B-32.laion2b_e16
+    'atomic-v0.2.ViT-B-32.laion2b_e16.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion2b_e16-validation',
+    'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large': 'atomic-v0.2-image-ViT-B-32.laion2b_e16-validation',
+    # ViT-B-32.laion400m_e32
+    'atomic-v0.2.ViT-B-32.laion400m_e32.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion400m_e32-validation',
+    'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large': 'atomic-v0.2-image-ViT-B-32.laion400m_e32-validation',
+    # openai.clip-vit-base-patch32
+    'atomic-v0.2.openai.clip-vit-base-patch32.image.large': 'atomic-v0.2.1-text-openai.clip-vit-base-patch32-validation',
+    'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large': 'atomic-v0.2-image-openai.clip-vit-base-patch32-validation',
+    # openai.clip-vit-large-patch14
+    'atomic-v0.2.openai.clip-vit-large-patch14.image.large': 'atomic-v0.2.1-text-openai.clip-vit-large-patch14-validation',
+    'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large': 'atomic-v0.2-image-openai.clip-vit-large-patch14-validation',
+    # Salesforce.blip-itm-base-coco
+    'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-base-coco-validation',
+    'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-base-coco-validation',
+    # Salesforce.blip-itm-large-coco
+    'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-large-coco-validation',
+    'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-large-coco-validation',
+    # facebook.flava-full
+    'atomic-v0.2.facebook.flava-full.image.large': 'atomic-v0.2.1-text-facebook.flava-full-validation',
+    'atomic-v0.2.1.facebook.flava-full.text.large': 'atomic-v0.2-image-facebook.flava-full-validation',
+}
 
-
-INDEX_NAMES = (
-    'atomic_image_v0.2_small_validation',
-    'atomic_image_v0.2_base',
-    'atomic_image_v0.2_large',
-    'atomic_text_v0.2.1_small_validation',
-    'atomic_text_v0.2.1_base',
-    'atomic_text_v0.2.1_large',
-)
 Searcher = Union[FaissSearcher, LuceneSearcher]
 
 
-def create_app(k: int, load_searcher_fn: Callable[[str], Tuple[Searcher, str]]):
+def create_app(k: int, load_searcher_fn: Callable[[str], Searcher]):
     app = Flask(__name__)
 
-    index_name = INDEX_NAMES[0]
-    searcher, retriever = load_searcher_fn(index_name=index_name)
+    # Use BM25 as default retriever upon page load
+    retriever = "BM25"
+    index_name = RETRIEVER_TO_INDEXES[retriever][0]
+    searcher = load_searcher_fn(index_name=index_name)
+    query_options = [] # for dense search only
 
     @app.route('/')
     def index():
-        nonlocal searcher, retriever
-        return render_template('atomic.html', index_name=index_name, retriever=retriever)
+        return render_template(
+            'atomic.html', index_name=index_name, retriever=retriever, retriever_to_indexes=RETRIEVER_TO_INDEXES
+        )
 
     @app.route('/search', methods=['GET', 'POST'])
     def search():
-        nonlocal searcher, retriever
         query = request.form['q']
+        if retriever != "BM25":
+            query = query_options[int(query)]
         if not query:
             search_results = []
             flash('Question is required')
             # NOTE: this throws an exception unless we set a secret session key
         else:
-            hits = searcher.search(query, k=k)
+            try:
+                hits = searcher.search(query, k=k)
+            except KeyError:
+                hits = []
+                flash('Invalid query given')
             docs = [json.loads(searcher.doc(hit.docid).raw()) for hit in hits]
             search_results = [
                 {
@@ -76,56 +169,74 @@ def search():
                 for r, hit in enumerate(hits)
             ]
         return render_template(
-            'atomic.html', index_name=index_name, search_results=search_results, query=query, retriever=retriever
+            'atomic.html', index_name=index_name, retriever=retriever,
+            retriever_to_indexes=RETRIEVER_TO_INDEXES, search_results=search_results, query=query,
         )
 
+    def _change_index(new_index_name):
+        nonlocal index_name, searcher, query_options
+        index_name = new_index_name
+        searcher = load_searcher_fn(index_name=index_name)
+        if retriever != "BM25":
+            query_options = {i: option for i, option in enumerate(searcher.query_encoder.embedding.keys())}
+
+    @app.route('/retriever', methods=['GET'])
+    def change_retriever():
+        nonlocal retriever
+        new_retriever = request.args.get('new_retriever_name', '', type=str)
+        if not new_retriever or new_retriever not in list(RETRIEVER_TO_INDEXES.keys()):
+            return
+
+        retriever = new_retriever
+        _change_index(new_index_name=RETRIEVER_TO_INDEXES[retriever][0])
+        return jsonify(index_list=RETRIEVER_TO_INDEXES[retriever])
+
     @app.route('/index', methods=['GET'])
     def change_index_name():
-        nonlocal index_name, searcher, retriever
         new_index_name = request.args.get('new_index_name', '', type=str)
-        if not new_index_name or new_index_name not in INDEX_NAMES:
+        if not new_index_name or new_index_name not in RETRIEVER_TO_INDEXES[retriever]:
             return
-
-        index_name = new_index_name
-        searcher, retriever = load_searcher_fn(index_name=index_name)
+        _change_index(new_index_name)
         return jsonify(index_name=index_name)
 
+    @app.route('/search_options', methods=['GET'])
+    def search_options():
+        query = request.args.get('query', '')
+
+        matching_options = {
+            i: option
+            for i, option in query_options.items()
+            if option.lower().startswith(query.lower())
+        }
+        return jsonify(matching_options)
+
     return app
 
 
-def _load_sparse_searcher(index_name, language: str, k1: Optional[float]=None, b: Optional[float]=None) -> (Searcher, str):
-    searcher = LuceneSearcher.from_prebuilt_index(index_name)
-    if k1 is not None and b is not None:
-        searcher.set_bm25(k1, b)
-        retriever_name = f'BM25 (k1={k1}, b={b})'
+def _load_searcher(index_name: str, language: str, k1: Optional[float]=None, b: Optional[float]=None):
+    if index_name in RETRIEVER_TO_INDEXES['BM25']:
+        searcher = LuceneSearcher.from_prebuilt_index(index_name)
+        if k1 is not None and b is not None:
+            searcher.set_bm25(k1, b)
     else:
-        retriever_name = 'BM25'
-
-    return searcher, retriever_name
+        query_encoder = QueryEncoder.load_encoded_queries(INDEX_TO_ENCODED_QUERIES[index_name])
+        searcher = FaissSearcher.from_prebuilt_index(
+            index_name, query_encoder
+        )
+    return searcher
 
 
 def main():
     parser = ArgumentParser()
-
     parser.add_argument('--k1', type=float, help='BM25 k1 parameter.')
     parser.add_argument('--b', type=float, help='BM25 b parameter.')
     parser.add_argument('--hits', type=int, default=10, help='Number of hits returned by the retriever')
     parser.add_argument(
-        '--device',
-        type=str,
-        default='cpu',
-        help='Device to run query encoder, cpu or [cuda:0, cuda:1, ...] (used only when index is based on FAISS)',
-    )
-    parser.add_argument(
-        '--port',
-        default=8080,
-        type=int,
-        help='Web server port',
+        '--port', default=8080, type=int, help='Web server port',
     )
-
     args = parser.parse_args()
-    load_fn = partial(_load_sparse_searcher, language='en', k1=args.k1, b=args.b)
 
+    load_fn = partial(_load_searcher, language='en', k1=args.k1, b=args.b)
     app = create_app(args.hits, load_fn)
     app.run(host='0.0.0.0', port=args.port)
 
diff --git a/pyserini/demo/templates/atomic.html b/pyserini/demo/templates/atomic.html
index ff916be02..f278fdf60 100644
--- a/pyserini/demo/templates/atomic.html
+++ b/pyserini/demo/templates/atomic.html
@@ -15,12 +15,91 @@
   <script>
     $SCRIPT_ROOT = {{ request.script_root | tojson }};
 
-    $(document).ready(function () {
+    $(function () {
+      $("#retriever_name").val("{{retriever}}");
+      $("#index_name").val("{{index_name}}");
       $("#loading").hide();
-      $('#index_name').val("{{index_name}}");
-    });
+      $("#pre_encoded_query_table").hide();
+
+      $("#search_button").on("click", function () {
+        // for sparse search, we can submit the input directly as a query
+        // for dense search, we only support searching from pre-encoded queries,
+        // so the search button will retrieve valid pre-encoded queries
+        if ($("#retriever_name").val() === "BM25") {
+          $("#submit_query_form").submit();
+        } else {
+          const query = $("#query_input").val();
+          $.getJSON(
+            $SCRIPT_ROOT + "/search_options", {
+            query: query
+          }, function (data) {
+            $("#loading").hide();
+            $("#retriever_name").removeAttr('disabled');
+            $("#index_name").removeAttr('disabled');
+            $("#results_table").hide();
+            $("#pre_encoded_query_table").show();
+
+            const validQueries = $("#pre_encoded_queries");
+            validQueries.empty();
+            // construct row for each valid pre-encoded query
+            var rowNum = 0;
+            $.each(data, function (queryIndex, result) {
+              const tr = $("<tr>")
+                .attr({ "class": rowNum % 2 ? "table-secondary" : "table-light" });
+              tr.append($("<td>").text(queryIndex));
+              tr.append(
+                $("<td>")
+                  .attr({ "style": "word-wrap: break-word;min-width: 600px;max-width: 600px;", "class": "text-start" })
+                  .append($("<small></small>").attr({ "data-query-idx": queryIndex }).text(result))
+              );
+              tr.append(
+                $("<td>")
+                  .append($("<button>")
+                    .attr({ "type": "button" })
+                    .click(function () {
+                      const queryIndex = $(this).closest("tr").find("small").attr("data-query-idx");
+                      $("#query_input").val(queryIndex);
+                      $("#submit_query_form").submit();
+                    })
+                    .text("Use this query"))
+              );
+              validQueries.append(tr);
+              rowIndex += 1;
+            });
+          });
+
+          $("#retriever_name").attr('disabled', 'disabled');
+          $("#index_name").attr('disabled', 'disabled');
+          $("#loading").show();
+          return false;
+        }
+      });
+
+      $('#retriever_name').on('change', function () {
+        $.getJSON($SCRIPT_ROOT + '/retriever', {
+          new_retriever_name: this.value,
+        }, function (data) {
+          $("#retriever_name").removeAttr('disabled');
+          $("#loading").hide();
+
+          var dropdownElem = $('#index_name');
+          dropdownElem.empty();
+          $.each(data.index_list, function (index, val) {
+            var option = $('<option></option>');
+            option.text(val);
+            option.val(val);
+            if (index === 0) option.selected = true;
+            dropdownElem.append(option);
+          });
+        });
+
+        $(this).attr('disabled', 'disabled');
+        $("#pre_encoded_query_table").hide();
+        $("#loading").show();
+
+        return false;
+      });
 
-    $(function () {
       $('#index_name').on('change', function () {
         $.getJSON($SCRIPT_ROOT + '/index', {
           new_index_name: this.value,
@@ -30,6 +109,7 @@
         });
 
         $(this).attr('disabled', 'disabled');
+        $("#pre_encoded_query_table").hide();
         $("#loading").show();
 
         return false;
@@ -50,16 +130,30 @@ <h4>Large-scale image/text retrieval test collection</h4>
     collection designed to aid in multimedia content creation.
   </p>
 
+  <p>
+    For sparse search (BM25), simply type in your query and submit.<br>
+    For dense search, we support searching from a set of pre-encoded queries. Submitting your query will
+    retrieve the pre-encoded queries that begins with your query (case-insensitive).
+  </p>
+
+  <div class="row g-3 align-items-center">
+    <label class="col-auto" for="retriever_name">You are using the following retriever</label>
+    <div class="col-auto">
+      <select class="form-select form-select-sm" aria-label=".form-select-sm" id="retriever_name">
+        {% for retriever in retriever_to_indexes.keys() %}
+        <option value="{{ retriever }}">{{ retriever }}</option>
+        {% endfor %}
+      </select>
+    </div>
+  </div>
+
   <div class="row g-3 align-items-center">
-    <label class="col-auto" for="index">You are perfoming search on the following dataset</label>
+    <label class="col-auto" for="index_name">You are perfoming search on the following dataset</label>
     <div class="col-auto">
       <select class="form-select form-select-sm" aria-label=".form-select-sm" id="index_name">
-        <option value="atomic_image_v0.2_small_validation">Text to Image (Small)</option>
-        <option value="atomic_image_v0.2_base">Text to Image (Base)</option>
-        <option value="atomic_image_v0.2_large">Text to Image (Large)</option>
-        <option value="atomic_text_v0.2.1_small_validation">Image to Text (Small)</option>
-        <option value="atomic_text_v0.2.1_base">Image to Text (Base)</option>
-        <option value="atomic_text_v0.2.1_large">Image to Text (Large)</option>
+        {% for index in retriever_to_indexes[retriever] %}
+        <option value="{{ index }}">{{ index }}</option>
+        {% endfor %}
       </select>
     </div>
     <div class="col-auto">
@@ -67,11 +161,6 @@ <h4>Large-scale image/text retrieval test collection</h4>
         <span class="visually-hidden">Loading...</span>
       </div>
     </div>
-    <div class="col-auto">
-      <span>
-        retrieves passages using <em>{{retriever}}</em>.
-      </span>
-    </div>
   </div>
 
   <br />
@@ -81,27 +170,37 @@ <h4>Large-scale image/text retrieval test collection</h4>
     <div class="alert">{{ message }}</div>
     {% endfor %}
 
-    <form action="/search" method="post">
+    <form action="/search" method="post" id="submit_query_form">
       <div class="row-cols-3">
         <div class="input-group mb-3">
-          <input type="text" class="form-control" placeholder="Enter a Question" aria-label="Question" name="q"
-            aria-describedby="button-addon2" value="{{query if query else ''}}">
-          <button class="btn btn-outline-secondary" type="submit" id="button-addon2"><i
-              class="bi bi-search"></i></button>
+          <input type="text" id="query_input" class="form-control" placeholder="Enter a Question" aria-label="Question"
+            name="q" aria-describedby="search_button" value="{{query if query else ''}}">
+          <button class="btn btn-outline-secondary" id="search_button"><i class="bi bi-search"></i></button>
         </div>
+        <table class="table" id="pre_encoded_query_table">
+          <thead>
+            <tr>
+              <th scope="col">Query Number</th>
+              <th scope="col">Pre-encoded Query</th>
+              <th scope="col"></th>
+            </tr>
+          </thead>
+          <tbody class="table-group-divider" id="pre_encoded_queries">
+          </tbody>
+        </table>
       </div>
     </form>
 
     {% if search_results %}
     <div class="row">
-      <table class="table">
+      <table class="table" id="results_table">
         <thead>
           <tr>
             <th scope="col">#</th>
             <th scope="col">Score</th>
             <th scope="col">Passage ID</th>
             <th scope="col">Content</th>
-            {% if index_name.startswith("atomic_image") %}
+            {% if "image" in index_name %}
             <th scope="col">Image</th>
             {% endif %}
           </tr>
@@ -115,7 +214,7 @@ <h4>Large-scale image/text retrieval test collection</h4>
             <td style="word-wrap: break-word;min-width: 600px;max-width: 600px;" class="text-start">
               <small>{{res["content"]}}</small>
             </td>
-            {% if index_name.startswith("atomic_image") %}
+            {% if "image" in index_name %}
             <td>
               <img src="{{ res['image_url'] }}" width="500" height="auto">
             </td>