From 6eb7ff53982de71c18ea8e4762faa9d646f6d727 Mon Sep 17 00:00:00 2001
From: Omar Khattab
Date: Fri, 24 Dec 2021 12:17:41 -0800
Subject: [PATCH] Minor updates to Baleen

---
 baleen/condenser/condense.py           | 16 +++++++++++-----
 baleen/engine.py                       |  4 ++--
 colbert/indexing/collection_indexer.py | 15 ---------------
 colbert/utilities/annotate_em.py       |  8 +++-----
 colbert/utils/amp.py                   |  5 -----
 5 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/baleen/condenser/condense.py b/baleen/condenser/condense.py
index 61115ce0..27ac302f 100644
--- a/baleen/condenser/condense.py
+++ b/baleen/condenser/condense.py
@@ -22,9 +22,9 @@ def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda'
 
     def condense(self, query, backs, ranking):
         stage1_preds = self._stage1(query, backs, ranking)
-        stage2_preds = self._stage2(query, stage1_preds)
+        stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds)
 
-        return stage1_preds, stage2_preds
+        return stage1_preds, stage2_preds, stage2_preds_L3x
 
     def _load_model(self, path, device):
         model = torch.load(path, map_location='cpu')
@@ -128,8 +128,14 @@ def _stage2(self, query, preds):
 
         preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
         preds = sorted(preds, reverse=True)[:5]
+
+        preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)]  # Take at least 2!
         preds = [x for score, x in preds if score > 0]
 
-        # TODO: Apply L3x for final stage.
-
-        return preds
+        earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]  # Take at most 4 docs.
+        preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]
+
+        assert len(preds_L3x) >= 2
+        assert len(f7([pid for pid, _ in preds_L3x])) <= 4
+
+        return preds, preds_L3x
diff --git a/baleen/engine.py b/baleen/engine.py
index d8031254..f215258c 100644
--- a/baleen/engine.py
+++ b/baleen/engine.py
@@ -36,12 +36,12 @@ def search(self, query, num_hops, depth=100, verbose=False):
                 if len(pids_bag) < k * (hop_idx+1):
                     pids_bag.add(pid)
 
-            stage1_preds, facts = condenser.condense(query, backs=facts, ranking=ranking_)
+            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
             context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])
 
         assert len(pids_bag) == depth
 
-        return facts, pids_bag, stage1_preds
+        return stage2_L3x, pids_bag, stage1_preds
diff --git a/colbert/indexing/collection_indexer.py b/colbert/indexing/collection_indexer.py
index ee3d4e76..e1a3111c 100644
--- a/colbert/indexing/collection_indexer.py
+++ b/colbert/indexing/collection_indexer.py
@@ -299,26 +299,11 @@ def _collect_embedding_id_offset(self):
         assert len(self.embedding_offsets) == self.num_chunks
 
     def _build_ivf(self):
-        # TODO: If this is slow or memory intensive, it can be done as a torch.sort, torch.add offset, torch.unique
-        # operations over the concatenated codes.
-        # On MS MARCO, this seems to take 10 minutes! I can imagine that's 40 minutes on Wikipedia.
-
-
-        """
-        codes = ResidualCodec.Embeddings.load_all_codes(index_path)
-
-        codes.sort().{values, indices}
-
-        values.unique_consecutive -> counts -> cumsum -> offsets
-        (indices.int32(), offsets)
-
-
         # Maybe we should several small IVFs? Every 250M embeddings, so that's every 1 GB.
         # It would save *memory* here and *disk space* regarding the int64.
         # But we'd have to decide how many IVFs to use during retrieval: many (loop) or one?
         # A loop seems nice if we can find a size that's large enough for speed yet small enough to fit on GPU!
         # Then it would help nicely for batching later: 1GB.
-        """
 
         codes = torch.empty(self.num_embeddings,)
         print_memory_stats(f'RANK:{self.rank}')
diff --git a/colbert/utilities/annotate_em.py b/colbert/utilities/annotate_em.py
index d9e783e5..e2c4cea6 100644
--- a/colbert/utilities/annotate_em.py
+++ b/colbert/utilities/annotate_em.py
@@ -106,10 +106,8 @@ def save(self, new_path):
 
 
 if __name__ == '__main__':
-    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv'
-    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv'
-    r = sys.argv[1]
+    r = sys.argv[2]
 
-    a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
-                   qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
+    a = AnnotateEM(collection='/dfs/scratch0/okhattab/OpenQA/collection.tsv',
+                   qas=sys.argv[1])
     a.annotate(ranking=r)
diff --git a/colbert/utils/amp.py b/colbert/utils/amp.py
index dc492f36..2a0dadc1 100644
--- a/colbert/utils/amp.py
+++ b/colbert/utils/amp.py
@@ -3,13 +3,9 @@
 from contextlib import contextmanager
 from colbert.utils.utils import NullContextManager
 
-PyTorch_over_1_6 = True  # float('.'.join(torch.__version__.split('.')[0:2])) >= 1.6
-
 
 class MixedPrecisionManager():
     def __init__(self, activated):
-        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"
-
         self.activated = activated
 
         if self.activated:
@@ -22,7 +18,6 @@ def backward(self, loss):
         if self.activated:
             self.scaler.scale(loss).backward()
         else:
-            assert False, "for now"
             loss.backward()
 
     def step(self, colbert, optimizer, scheduler=None):
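A reading note on the `_stage2` hunk above: the new L3x step keeps the top-scored sentence candidates, never fewer than two of them, drawn from at most four distinct passages. Below is a minimal, self-contained sketch of that rule. It assumes `f7` is the usual order-preserving dedup helper and that at least two scored candidates are available; the `select_L3x` wrapper is hypothetical and only for illustration.

def f7(seq):
    # Order-preserving dedup; assumed to match the f7 helper condense.py imports.
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]


def select_L3x(candidates, scores):
    # candidates: list of (pid, sid) sentence IDs; scores: matching list of floats.
    preds = sorted(zip(scores, candidates), reverse=True)[:5]

    # Keep positively scored sentences, but never fewer than the top two:
    # if the second-best score is <= 0, the threshold drops just below it.
    threshold = min(0, preds[1][0] - 1e-10)
    preds_L3x = [x for score, x in preds if score > threshold]

    # Restrict to sentences from the (at most) four highest-ranked distinct passages.
    earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]
    preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]

    assert len(preds_L3x) >= 2
    assert len(f7([pid for pid, _ in preds_L3x])) <= 4
    return preds_L3x

For instance, with scores [1.2, -0.3, -0.5] the top two candidates survive even though only one score is positive.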
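Similarly, for the TODO deleted from `_build_ivf`: the idea it sketches (sort the per-embedding centroid codes, then derive per-partition offsets) can be written roughly as below. This only illustrates that comment, not what the indexer actually runs; the function name is made up, and `torch.bincount` is used instead of the `unique_consecutive` route so that empty partitions still get an (empty) slot.

import torch


def build_ivf_from_codes(codes, num_partitions):
    # codes[i] = centroid/partition ID assigned to embedding i (1-D integer tensor).
    values, indices = codes.sort()

    # How many embeddings landed in each partition (empty partitions count as 0).
    counts = torch.bincount(values, minlength=num_partitions)

    # offsets[p] marks where partition p's embedding IDs start inside `indices`.
    offsets = torch.zeros(num_partitions + 1, dtype=torch.long)
    offsets[1:] = torch.cumsum(counts, dim=0)

    # The IVF: embedding IDs grouped by partition, plus the per-partition offsets.
    return indices.to(torch.int32), offsets

Partition p's embedding IDs are then indices[offsets[p]:offsets[p+1]].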