Initial commit with the new API and residual compression
okhat committed Oct 13, 2021
1 parent c4e79e8 commit 4120feb
Showing 78 changed files with 6,583 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,4 @@
{
"jupyter.jupyterServerType": "local",
"python.formatting.autopep8Args": ["--max-line-length", "120"],
}
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019, 2020 Stanford Future Data Systems

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
40 changes: 40 additions & 0 deletions README.md
@@ -0,0 +1,40 @@
# ColBERT

### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.


<p align="center">
<img align="center" src="docs/images/ColBERT-Framework-MaxSim-W370px.png" />
</p>
<p align="center">
<b>Figure 1:</b> ColBERT's late interaction, efficiently scoring the fine-grained similarity between a query and a passage.
</p>

As Figure 1 illustrates, ColBERT relies on fine-grained **contextual late interaction**: it encodes each passage into a **matrix** of token-level embeddings (shown above in blue). Then at search time, it embeds every query into another matrix (shown in green) and efficiently finds passages that contextually match the query using scalable vector-similarity (`MaxSim`) operators.
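For intuition, the `MaxSim` reduction is just a matrix product followed by a row-wise max and a sum. Below is a minimal PyTorch sketch (not the library's optimized implementation), assuming both matrices are L2-normalized so dot products equal cosine similarities:

```python
import torch

def maxsim_score(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    # Q: (num_query_tokens, dim) query embedding matrix.
    # D: (num_doc_tokens, dim) passage embedding matrix.
    # For each query token, take its maximum similarity over all passage
    # tokens, then sum those maxima across the query tokens.
    return (Q @ D.T).max(dim=1).values.sum()
```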

These rich interactions allow ColBERT to surpass the quality of _single-vector_ representation models, while scaling efficiently to large corpora. You can read more in our papers:

* [**ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT**](https://arxiv.org/abs/2004.12832) (SIGIR'20).
* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21).


----

## Installation

ColBERT (currently: [v0.4.6](#releases)) requires Python 3.7+ and PyTorch 1.9+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.

We strongly recommend creating a conda environment using the commands below. (If you don't have conda, follow the official [conda installation guide](https://docs.anaconda.com/anaconda/install/linux/#installation).)

```
conda env create -f conda_env.yml
conda activate colbert-v0.4.2
```

If you face any problems, please [open a new issue](https://github.com/stanford-futuredata/ColBERT/issues) and we'll help you promptly!


## NEW: API Usage Notebook

The Jupyter notebook **[docs/intro.ipynb](docs/intro.ipynb)** illustrates how to use the key features of ColBERT through the new Python API.
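As a taste of what the notebook covers, here is a hypothetical end-to-end sketch. The class names match the package's top-level exports (`Trainer`, `Indexer`, `Searcher`), but every argument below is an illustrative assumption; the notebook is the authoritative reference:

```python
from colbert import Indexer, Searcher

# Hypothetical arguments for illustration only; consult docs/intro.ipynb
# for the real constructor and method signatures.
indexer = Indexer(checkpoint='/path/to/checkpoint')
indexer.index(name='msmarco.index', collection='collection.tsv')

searcher = Searcher(index='msmarco.index')
results = searcher.search('what is contextual late interaction?', k=10)
```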

4 changes: 4 additions & 0 deletions colbert/__init__.py
@@ -0,0 +1,4 @@
from .trainer import Trainer
from .indexer import Indexer
from .searcher import Searcher

Empty file added colbert/evaluation/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions colbert/evaluation/load_model.py
@@ -0,0 +1,28 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    colbert = ColBERT.from_pretrained('bert-base-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
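As a usage note, `load_model` reads only a handful of attributes from `args`, so a lightweight invocation can pass a plain namespace. The values below are illustrative assumptions, not defaults taken from this commit:

```python
from types import SimpleNamespace

from colbert.evaluation.load_model import load_model

# Hypothetical settings; load_model reads exactly these attributes from args.
args = SimpleNamespace(query_maxlen=32, doc_maxlen=180, dim=128,
                       similarity='cosine', mask_punctuation=True,
                       checkpoint='/path/to/colbert-checkpoint.dnn')
colbert, checkpoint = load_model(args)
```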
196 changes: 196 additions & 0 deletions colbert/evaluation/loaders.py
@@ -0,0 +1,196 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run


def load_queries(queries_path):
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path) as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries
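Concretely, `load_queries` expects a tab-separated file with an integer query ID followed by the query text, one query per line. For example (made-up rows):

```
0	what is contextual late interaction?
1	how does colbert compare to single-vector retrieval?
```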


def load_qrels(qrels_path):
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels[qid] = qrels.get(qid, [])
            qrels[qid].append(pid)

    assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)

    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels
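The qrels file is the standard four-column TSV (query ID, 0, passage ID, 1), matching the `x == 0 and y == 1` assertion above. An illustrative line, marking passage 184 as a positive for query 3:

```
3	0	184	1
```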


def load_topK(topK_path):
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs[qid] = topK_docs.get(qid, [])
            topK_docs[qid].append(passage)
            topK_pids[qid] = topK_pids.get(qid, [])
            topK_pids[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids
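Each line of this top-k file carries both IDs and both texts, tab-separated in the order `qid, pid, query, passage`. An illustrative line:

```
3	184	what is contextual late interaction?	Late interaction encodes queries and passages independently as token-level matrices ...
```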


def load_topK_pids(topK_path, qrels):
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives
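This variant needs only the IDs, followed by one to three trailing columns; when more than one trailing column is present, the last must be a binary relevance label (0 or 1), as the assertions enforce. An illustrative annotated line follows; the middle column, shown here as a score, is an assumption about what the extra columns hold:

```
3	184	-1.27	1
```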


def load_collection(collection_path):
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection
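`load_collection` therefore expects line *i* of the TSV to carry PID *i* (a header row whose first field is `id` is also tolerated), with the passage text second and an optional title third; the title, when present, gets prepended to the passage. Illustrative rows:

```
0	Late interaction encodes queries and passages as token-level matrices.	ColBERT
1	The MaxSim operator takes a per-query-token maximum over passage tokens.	MaxSim
```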


def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line,
    # their *checkpoint* values should be used, not the training-time defaults.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint