
Commit

Resolve merge conflicts
VThejas committed May 30, 2024
2 parents b97e216 + 862edcf commit 40e6367
Showing 30 changed files with 22,982 additions and 187 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -20,4 +20,4 @@ __pycache__/

# Other
.vscode
*.tsv
*.tsv
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,2 @@
include colbert/indexing/codecs/*.cpp
include colbert/indexing/codecs/*.cu
41 changes: 36 additions & 5 deletions README.md
@@ -26,13 +26,16 @@ These rich interactions allow ColBERT to surpass the quality of _single-vector_
* [**Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval**](https://arxiv.org/abs/2101.00436) (NeurIPS'21).
* [**ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction**](https://arxiv.org/abs/2112.01488) (NAACL'22).
* [**PLAID: An Efficient Engine for Late Interaction Retrieval**](https://arxiv.org/abs/2205.09707) (CIKM'22).
* [**Moving Beyond Downstream Task Accuracy for Information Retrieval Benchmarking**](https://arxiv.org/abs/2212.01340) (ACL'23 Findings).
* [**UDAPDR: Unsupervised Domain Adaptation via LLM Prompting and Distillation of Rerankers**](https://arxiv.org/abs/2303.00807) (EMNLP'23).

----

## 🚨 **Announcements**

* (1/28/24) One of the easiest ways to use ColBERT in applications nowadays is the semi-official, fast-growing [RAGatouille](https://github.com/bclavie/ragatouille) library.
* (1/29/23) We have merged a new index updater feature and support for additional Hugging Face models! These are in beta so please give us feedback as you try them out.
* (1/24/23) If you're looking for the **DSP** framework for composing ColBERTv2 and LLMs, it's at: https://github.com/stanfordnlp/dsp
* (1/24/23) If you're looking for the **DSPy** framework for composing retrievers like ColBERTv2 and LLMs, it's at: https://github.com/stanfordnlp/dspy

----

@@ -43,6 +46,8 @@ The ColBERTv1 code from the SIGIR'20 paper is in the [`colbertv1` branch](https:

## Installation

(Update: nowadays you can typically do `pip install colbert-ai[torch,faiss-gpu]` to get things up and running, but if you face issues, conda is always more reliable for `faiss` and `torch`.)

ColBERT requires Python 3.7+ and PyTorch 1.9+ and uses the [Hugging Face Transformers](https://github.com/huggingface/transformers) library.

We strongly recommend creating a conda environment using the commands below. (If you don't have conda, follow the official [conda installation guide](https://docs.anaconda.com/anaconda/install/linux/#installation).)
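The environment commands themselves are collapsed in this diff. As a rough sketch of the pip route from the update note above (the environment name and Python version are arbitrary choices, not prescribed by the README):

```
conda create -n colbert python=3.9 -y    # any interpreter meeting the Python 3.7+ requirement should do
conda activate colbert
pip install colbert-ai[torch,faiss-gpu]  # command from the update note; conda remains the fallback for faiss/torch
```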
@@ -100,7 +105,7 @@ For fast retrieval, indexing precomputes the ColBERT representations of passages

Example usage:

```
```python
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Indexer

@@ -120,7 +125,7 @@ if __name__=='__main__':

We typically recommend that you use ColBERT for **end-to-end** retrieval, where it directly finds its top-k passages from the full collection:

```
```python
from colbert.data import Queries
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Searcher
@@ -145,7 +150,7 @@ We can evaluate the MSMARCO rankings using the following command:
python -m utility.evaluate.msmarco_passages --ranking "/path/to/msmarco.nbits=2.ranking.tsv" --qrels "/path/to/MSMARCO/qrels.dev.small.tsv"
```
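The body of the search example is collapsed in this diff. Based on the imports shown above and ColBERT's documented API, a minimal end-to-end search script looks roughly like this (the index name, experiment name, and paths are placeholders):

```python
from colbert.data import Queries
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Searcher

if __name__ == '__main__':
    with Run().context(RunConfig(nranks=1, experiment="msmarco")):
        config = ColBERTConfig(root="/path/to/experiments")
        searcher = Searcher(index="msmarco.nbits=2", config=config)

        queries = Queries("/path/to/MSMARCO/queries.dev.small.tsv")
        ranking = searcher.search_all(queries, k=100)  # top-k passages per query
        ranking.save("msmarco.nbits=2.ranking.tsv")    # the file evaluated by the command above
```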

## Training
## Basic Training (ColBERTv1-style)

We provide a [pre-trained model checkpoint](https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz), but we also detail how to train from scratch here.
Note that this example demonstrates the ColBERTv1 style of training, but the provided checkpoint was trained with ColBERTv2.
@@ -154,7 +159,7 @@ Training requires a JSONL triples file with a `[qid, pid+, pid-]` list per line.

Example usage (training on 4 GPUs):

```
```python
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Trainer

@@ -177,6 +182,32 @@ if __name__=='__main__':
print(f"Saved checkpoint to {checkpoint_path}...")
```
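For reference, each line of that triples file is a JSON array of the form `[qid, pid+, pid-]`; a toy illustration (the ids below are placeholders, not real MS MARCO ids):

```
[0, 533, 7713]
[0, 533, 4511]
[1, 8243, 212]
```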

## Advanced Training (ColBERTv2-style)

```python
from colbert.infra.run import Run
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert import Trainer


def train():
    # Use 4 GPUs (e.g., four A100s); you can use fewer by changing nway, accumsteps, and bsize.
    with Run().context(RunConfig(nranks=4)):
        triples = '/path/to/examples.64.json'  # `wget https://huggingface.co/colbert-ir/colbertv2.0_msmarco_64way/resolve/main/examples.json?download=true` (26GB)
        queries = '/path/to/MSMARCO/queries.train.tsv'
        collection = '/path/to/MSMARCO/collection.tsv'

        config = ColBERTConfig(bsize=32, lr=1e-05, warmup=20_000, doc_maxlen=180, dim=128, attend_to_mask_tokens=False, nway=64, accumsteps=1, similarity='cosine', use_ib_negatives=True)
        trainer = Trainer(triples=triples, queries=queries, collection=collection, config=config)

        trainer.train(checkpoint='colbert-ir/colbertv1.9')  # or start from scratch, like `bert-base-uncased`


if __name__ == '__main__':
    train()
```


## Running a lightweight ColBERTv2 server
We provide a script to run a lightweight server that serves the top k (up to 100) results in ranked order for a given search query, in JSON format. This script can be used to power DSP programs.
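The launch instructions sit in the collapsed portion of the README below. As a rough illustration only, a client for such a server might look like the following; the host, port, `/api/search` route, and response schema are assumptions to check against the actual server script:

```python
import requests

# Hypothetical endpoint and port; consult the server script for the actual route and response format.
resp = requests.get(
    "http://localhost:8893/api/search",
    params={"query": "what is late interaction in neural retrieval?", "k": 10},
)
print(resp.json())  # JSON payload with up to k results in ranked order
```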

31 changes: 22 additions & 9 deletions colbert/indexer.py
@@ -13,12 +13,13 @@


class Indexer:
def __init__(self, checkpoint, config=None):
def __init__(self, checkpoint, config=None, verbose: int = 3):
"""
Use Run().context() to choose the run's configuration. They are NOT extracted from `config`.
"""

self.index_path = None
self.verbose = verbose
self.checkpoint = checkpoint
self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)

@@ -31,7 +32,7 @@ def configure(self, **kw_args):
def get_index(self):
return self.index_path

def erase(self):
def erase(self, force_silent: bool = False):
assert self.index_path is not None
directory = self.index_path
deleted = []
@@ -47,27 +48,32 @@ def erase(self):
deleted.append(filename)

if len(deleted):
print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
time.sleep(20)
if not force_silent:
print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
time.sleep(20)

for filename in deleted:
os.remove(filename)

return deleted

def index(self, name, collection, overwrite=False):
assert overwrite in [True, False, 'reuse', 'resume']
assert overwrite in [True, False, 'reuse', 'resume', "force_silent_overwrite"]

self.configure(collection=collection, index_name=name, resume=overwrite=='resume')
# Note: The bsize value set here is ignored internally. Users are encouraged
# to supply their own batch size for indexing by using the index_bsize parameter in the ColBERTConfig.
self.configure(bsize=64, partitions=None)

self.index_path = self.config.index_path_
index_does_not_exist = (not os.path.exists(self.config.index_path_))

assert (overwrite in [True, 'reuse', 'resume']) or index_does_not_exist, self.config.index_path_
assert (overwrite in [True, 'reuse', 'resume', "force_silent_overwrite"]) or index_does_not_exist, self.config.index_path_
create_directory(self.config.index_path_)

if overwrite is True:
if overwrite == 'force_silent_overwrite':
self.erase(force_silent=True)
elif overwrite is True:
self.erase()

if index_does_not_exist or overwrite != 'reuse':
@@ -76,10 +82,17 @@ def index(self, name, collection, overwrite=False):
return self.index_path

def __launch(self, collection):
launcher = Launcher(encode)
if self.config.nranks == 1 and self.config.avoid_fork_if_possible:
shared_queues = []
shared_lists = []
launcher.launch_without_fork(self.config, collection, shared_lists, shared_queues, self.verbose)

return

manager = mp.Manager()
shared_lists = [manager.list() for _ in range(self.config.nranks)]
shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

# Encodes collection into index using the CollectionIndexer class
launcher = Launcher(encode)
launcher.launch(self.config, collection, shared_lists, shared_queues)
launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)
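Taken together, the new `verbose` argument and the `"force_silent_overwrite"` mode can be exercised roughly as follows (a sketch; the checkpoint path, collection path, experiment name, and config values are placeholders):

```python
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Indexer

if __name__ == '__main__':
    with Run().context(RunConfig(nranks=1, experiment="demo")):
        config = ColBERTConfig(nbits=2, index_bsize=64)  # index_bsize per the note in index() above
        indexer = Indexer(checkpoint="/path/to/colbertv2.0", config=config, verbose=1)

        # "force_silent_overwrite" erases any existing index immediately, skipping
        # the 20-second warning pause that a plain overwrite=True triggers in erase().
        indexer.index(name="demo.nbits=2",
                      collection="/path/to/collection.tsv",
                      overwrite="force_silent_overwrite")
```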
8 changes: 4 additions & 4 deletions colbert/indexing/collection_encoder.py
@@ -22,19 +22,19 @@ def encode_passages(self, passages):
# Batch here to avoid OOM from storing intermediate embeddings on GPU.
# Storing on the GPU helps with speed of masking, etc.
# But ideally this batching happens internally inside docFromText.
for passages_batch in batch(passages, self.config.bsize * 50):
embs_, doclens_ = self.checkpoint.docFromText(passages_batch, bsize=self.config.bsize,
for passages_batch in batch(passages, self.config.index_bsize * 50):
embs_, doclens_ = self.checkpoint.docFromText(passages_batch, bsize=self.config.index_bsize,
keep_dims='flatten', showprogress=(not self.use_gpu))
embs.append(embs_)
doclens.extend(doclens_)

embs = torch.cat(embs)

# embs, doclens = self.checkpoint.docFromText(passages, bsize=self.config.bsize,
# embs, doclens = self.checkpoint.docFromText(passages, bsize=self.config.index_bsize,
# keep_dims='flatten', showprogress=(self.config.rank < 1))

# with torch.inference_mode():
# embs = self.checkpoint.docFromText(passages, bsize=self.config.bsize,
# embs = self.checkpoint.docFromText(passages, bsize=self.config.index_bsize,
# keep_dims=False, showprogress=(self.config.rank < 1))
# assert type(embs) is list
# assert len(embs) == len(passages)
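In short, the encoder now takes its batch size from `config.index_bsize` instead of `config.bsize`. The batching pattern itself is roughly the following (a simplified standalone sketch, not the actual class):

```python
def encode_in_chunks(passages, doc_from_text, index_bsize=64):
    """Encode passages in outer chunks of index_bsize * 50 to bound peak GPU memory,
    calling the encoder with bsize=index_bsize, mirroring encode_passages above."""
    embs, doclens = [], []
    chunk_size = index_bsize * 50
    for start in range(0, len(passages), chunk_size):
        chunk = passages[start:start + chunk_size]
        chunk_embs, chunk_doclens = doc_from_text(chunk, bsize=index_bsize)
        embs.append(chunk_embs)
        doclens.extend(chunk_doclens)
    return embs, doclens
```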
(The remaining 25 of the 30 changed files in this commit are not shown here.)
