Initial commit with the new API and residual compression
Showing 78 changed files with 6,583 additions and 0 deletions.
```
@@ -0,0 +1,4 @@
{
    "jupyter.jupyterServerType": "local",
    "python.formatting.autopep8Args": ["--max-line-length", "120"],
}
```
```
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019, 2020 Stanford Future Data Systems

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
@@ -0,0 +1,40 @@

# ColBERT

### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.

<p align="center">
  <img align="center" src="docs/images/ColBERT-Framework-MaxSim-W370px.png" />
</p>
<p align="center">
  <b>Figure 1:</b> ColBERT's late interaction, efficiently scoring the fine-grained similarity between a query and a passage.
</p>

As Figure 1 illustrates, ColBERT relies on fine-grained **contextual late interaction**: it encodes each passage into a **matrix** of token-level embeddings (shown above in blue). Then at search time, it embeds every query into another matrix (shown in green) and efficiently finds passages that contextually match the query using scalable vector-similarity (`MaxSim`) operators.
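To make the `MaxSim` operator concrete, here is a minimal PyTorch sketch of late-interaction scoring for a single query/passage pair; the shapes, dimensions, and random inputs are illustrative rather than the repository's implementation:

```
import torch

def maxsim_score(Q, D):
    # Q: (query_tokens, dim) and D: (doc_tokens, dim), both L2-normalized.
    sim = Q @ D.T                                  # all token-to-token similarities
    best_per_query_token = sim.max(dim=1).values   # MaxSim over passage tokens
    return best_per_query_token.sum()              # sum over query tokens = relevance score

Q = torch.nn.functional.normalize(torch.randn(32, 128), dim=-1)    # query matrix (green)
D = torch.nn.functional.normalize(torch.randn(180, 128), dim=-1)   # passage matrix (blue)
print(maxsim_score(Q, D))
```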
These rich interactions allow ColBERT to surpass the quality of _single-vector_ representation models, while scaling efficiently to large corpora. You can read more in our papers:

* [**ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT**](https://arxiv.org/abs/2004.12832) (SIGIR'20).
* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21).

----

## Installation

ColBERT (currently: [v0.4.6](#releases)) requires Python 3.7+ and PyTorch 1.9+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.

We strongly recommend creating a conda environment using the commands below. (If you don't have conda, follow the official [conda installation guide](https://docs.anaconda.com/anaconda/install/linux/#installation).)

```
conda env create -f conda_env.yml
conda activate colbert-v0.4.2
```

If you face any problems, please [open a new issue](https://github.com/stanford-futuredata/ColBERT/issues) and we'll help you promptly!

## NEW: API Usage Notebook

This Jupyter notebook, **[docs/intro.ipynb](docs/intro.ipynb)**, illustrates how to use the key features of ColBERT with the new Python API.
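For a taste of that API, here is a hypothetical end-to-end sketch built from the `Indexer` and `Searcher` classes this commit re-exports (visible in the package-level import diff below); the constructor arguments, method names, paths, and index name are assumptions for illustration, not excerpts from the notebook:

```
from colbert import Indexer, Searcher

# Hypothetical checkpoint path, collection path, and index name.
indexer = Indexer(checkpoint='checkpoints/colbert.dnn')
indexer.index(name='msmarco.index', collection='collection.tsv')

searcher = Searcher(index='msmarco.index')
results = searcher.search('what is late interaction?', k=10)
```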
```
@@ -0,0 +1,4 @@
from .trainer import Trainer
from .indexer import Indexer
from .searcher import Searcher
```
Empty file.
```
@@ -0,0 +1,28 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    colbert = ColBERT.from_pretrained('bert-base-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
```
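`load_model` reads only the six attributes below from `args`, so a bare `Namespace` suffices to drive it. A hypothetical call: the values mirror common ColBERT settings but are illustrative, as is the checkpoint path.

```
from argparse import Namespace
from colbert.evaluation.load_model import load_model

args = Namespace(query_maxlen=32, doc_maxlen=180, dim=128, similarity='cosine',
                 mask_punctuation=True,
                 checkpoint='checkpoints/colbert.dnn')   # hypothetical path
colbert, checkpoint = load_model(args)   # the model comes back in eval() mode
```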
```
@@ -0,0 +1,196 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run


def load_queries(queries_path):
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path) as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries
```
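`load_queries` expects a TSV with one query per line, `qid\tquery` (any extra columns are ignored by the `*_`). A quick sketch of a valid file and call; the path and contents are made up:

```
from colbert.evaluation.loaders import load_queries   # assumed module path for this file

with open('queries.tsv', 'w') as f:   # hypothetical file
    f.write('5\twhat is late interaction in retrieval\n')
    f.write('7\thow does residual compression work\n')

queries = load_queries('queries.tsv')
assert list(queries) == [5, 7]   # an OrderedDict keyed by integer qid
```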
```
def load_qrels(qrels_path):
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels[qid] = qrels.get(qid, [])
            qrels[qid].append(pid)

    assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)

    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels
```
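`load_qrels` reads the standard four-column relevance format, `qid\t0\tpid\t1`; the constant second and fourth columns are enforced by `assert x == 0 and y == 1`. A hypothetical example:

```
from colbert.evaluation.loaders import load_qrels   # assumed module path

with open('qrels.tsv', 'w') as f:   # hypothetical file
    f.write('5\t0\t1042\t1\n')
    f.write('5\t0\t2210\t1\n')
    f.write('7\t0\t88\t1\n')

qrels = load_qrels('qrels.tsv')
assert qrels[5] == [1042, 2210]   # all judged positives for qid 5
```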
```
def load_topK(topK_path):
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs[qid] = topK_docs.get(qid, [])
            topK_docs[qid].append(passage)
            topK_pids[qid] = topK_pids.get(qid, [])
            topK_pids[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids
```
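`load_topK` consumes a re-ranking file with one candidate per line, `qid\tpid\tquery\tpassage`, where the query text is repeated for each of its candidates (and checked for consistency). A hypothetical two-candidate file:

```
from colbert.evaluation.loaders import load_topK   # assumed module path

with open('top1000.tsv', 'w') as f:   # hypothetical file
    f.write('5\t1042\twhat is late interaction\tLate interaction scores token-level matches.\n')
    f.write('5\t88\twhat is late interaction\tAn unrelated passage.\n')

queries, topK_docs, topK_pids = load_topK('top1000.tsv')
assert topK_pids[5] == [1042, 88]   # candidates kept in file order
```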
```
def load_topK_pids(topK_path, qrels):
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives
```
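`load_topK_pids` tolerates one to three extra columns after `qid\tpid`; whenever more than one is present, the last must be a 0/1 label, and the labeled positives come back as sets. A hypothetical `qid\tpid\trank\tlabel` ranking:

```
from colbert.evaluation.loaders import load_topK_pids   # assumed module path

with open('ranking.tsv', 'w') as f:   # hypothetical file
    f.write('5\t1042\t1\t1\n')   # qid, pid, rank, label
    f.write('5\t88\t2\t0\n')

topK_pids, topK_positives = load_topK_pids('ranking.tsv', qrels=None)
assert topK_pids[5] == [1042, 88] and topK_positives[5] == {1042}
```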
```
def load_collection(collection_path):
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection
```
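`load_collection` expects `pid\tpassage` with an optional third title column, and insists that each pid equal its 0-based line index (or the literal `id` for a header row); titles are folded into the passage as `title | passage`. For instance, hypothetically:

```
from colbert.evaluation.loaders import load_collection   # assumed module path

with open('collection.tsv', 'w') as f:   # hypothetical file
    f.write('0\tLate interaction scores token-level matches.\tDoc A\n')
    f.write('1\tResidual compression shrinks the index.\tDoc B\n')

collection = load_collection('collection.tsv')
assert collection[0] == 'Doc A | Late interaction scores token-level matches.'
```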
```
def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
```
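`load_colbert` wraps `load_model` with a consistency check: if the checkpoint saved its training arguments, any mismatch with the current `args` only triggers `Run.warn`, not a failure. A hypothetical call; note the extra `rank` and `amp` attributes this function reads beyond what `load_model` needs:

```
from argparse import Namespace
from colbert.evaluation.loaders import load_colbert   # assumed module path

args = Namespace(query_maxlen=32, doc_maxlen=180, dim=128, similarity='cosine',
                 mask_punctuation=True, amp=False, rank=0,
                 checkpoint='checkpoints/colbert.dnn')   # hypothetical path
colbert, checkpoint = load_colbert(args)   # warns on checkpoint/args mismatches
```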