Minor updates to Baleen
okhat committed Dec 24, 2021
1 parent c0b180a commit 6eb7ff5
Showing 5 changed files with 16 additions and 32 deletions.
16 changes: 11 additions & 5 deletions baleen/condenser/condense.py
@@ -22,9 +22,9 @@ def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda'

     def condense(self, query, backs, ranking):
         stage1_preds = self._stage1(query, backs, ranking)
-        stage2_preds = self._stage2(query, stage1_preds)
+        stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds)
 
-        return stage1_preds, stage2_preds
+        return stage1_preds, stage2_preds, stage2_preds_L3x
 
     def _load_model(self, path, device):
         model = torch.load(path, map_location='cpu')
@@ -128,8 +128,14 @@ def _stage2(self, query, preds):

         preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
         preds = sorted(preds, reverse=True)[:5]
+
+        preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)] # Take at least 2!
         preds = [x for score, x in preds if score > 0]
 
-        # TODO: Apply L3x for final stage.
-
-        return preds
+        earliest_pids = f7([pid for pid, _ in preds_L3x])[:4] # Take at most 4 docs.
+        preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]
+
+        assert len(preds_L3x) >= 2
+        assert len(f7([pid for pid, _ in preds_L3x])) <= 4
+
+        return preds, preds_L3x
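
Below is a minimal sketch of the new stage-2 filtering, using made-up scores. It assumes f7 is an order-preserving de-duplication helper (as in the repository's utils); both the helper body and the sample preds are illustrative, not part of this commit.

def f7(seq):
    # De-duplicate while preserving first-seen order (illustrative helper).
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]

# Hypothetical top-5 predictions: (score, (pid, sid)), sorted by descending score.
preds = [(0.5, (10, 0)), (-0.3, (42, 1)), (-1.5, (7, 2)), (-2.0, (10, 3)), (-3.0, (7, 5))]

# Keep everything above min(0, second-best score - 1e-10): positive-scoring sentences,
# but never fewer than the top two, even when their scores are negative.
preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)]

# Restrict L3x to sentences from the first (at most) four distinct passages.
earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]
preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]

assert preds_L3x == [(10, 0), (42, 1)]  # at least 2 facts, drawn from at most 4 docs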
4 changes: 2 additions & 2 deletions baleen/engine.py
@@ -36,12 +36,12 @@ def search(self, query, num_hops, depth=100, verbose=False):
                 if len(pids_bag) < k * (hop_idx+1):
                     pids_bag.add(pid)
 
-            stage1_preds, facts = condenser.condense(query, backs=facts, ranking=ranking_)
+            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
             context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])
 
         assert len(pids_bag) == depth
 
-        return facts, pids_bag, stage1_preds
+        return stage2_L3x, pids_bag, stage1_preds



15 changes: 0 additions & 15 deletions colbert/indexing/collection_indexer.py
@@ -299,26 +299,11 @@ def _collect_embedding_id_offset(self):
         assert len(self.embedding_offsets) == self.num_chunks
 
     def _build_ivf(self):
-        # TODO: If this is slow or memory intensive, it can be done as a torch.sort, torch.add offset, torch.unique
-        # operations over the concatenated codes.
-        # On MS MARCO, this seems to take 10 minutes! I can imagine that's 40 minutes on Wikipedia.
-
 
-        """
-        codes = ResidualCodec.Embeddings.load_all_codes(index_path)
-        codes.sort().{values, indices}
-        values.unique_consecutive -> counts -> cumsum -> offsets
-        (indices.int32(), offsets)
-        # Maybe we should several small IVFs? Every 250M embeddings, so that's every 1 GB.
-        # It would save *memory* here and *disk space* regarding the int64.
-        # But we'd have to decide how many IVFs to use during retrieval: many (loop) or one?
-        # A loop seems nice if we can find a size that's large enough for speed yet small enough to fit on GPU!
-        # Then it would help nicely for batching later: 1GB.
-        """
 
         codes = torch.empty(self.num_embeddings,)
         print_memory_stats(f'RANK:{self.rank}')
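
The deleted comment block above sketches building the IVF with plain tensor operations (sort, consecutive-unique counts, cumulative offsets). A rough, self-contained illustration of that idea follows; the toy codes tensor and the variable names are invented for the example and are not the indexer's actual API.

import torch

# Toy centroid assignments: codes[i] is the centroid id assigned to embedding i.
codes = torch.tensor([3, 1, 3, 0, 1, 3])

values, indices = codes.sort()                                   # group embedding ids by centroid
_, counts = torch.unique_consecutive(values, return_counts=True)  # embeddings per non-empty centroid
offsets = torch.cumsum(counts, dim=0)                             # end offset of each centroid's slice

ivf = indices.int()  # embedding ids ordered by centroid; ivf[:offsets[0]] is the first slice
print(ivf.tolist(), offsets.tolist())                             # [3, 1, 4, 0, 2, 5] [1, 3, 6]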
8 changes: 3 additions & 5 deletions colbert/utilities/annotate_em.py
@@ -106,10 +106,8 @@ def save(self, new_path):


 if __name__ == '__main__':
-    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv'
-    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv'
-    r = sys.argv[1]
+    r = sys.argv[2]
 
-    a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
-                   qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
+    a = AnnotateEM(collection='/dfs/scratch0/okhattab/OpenQA/collection.tsv',
+                   qas=sys.argv[1])
     a.annotate(ranking=r)
5 changes: 0 additions & 5 deletions colbert/utils/amp.py
@@ -3,13 +3,9 @@
 from contextlib import contextmanager
 from colbert.utils.utils import NullContextManager
 
-PyTorch_over_1_6 = True # float('.'.join(torch.__version__.split('.')[0:2])) >= 1.6
-
 
 class MixedPrecisionManager():
     def __init__(self, activated):
-        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"
-
         self.activated = activated
 
         if self.activated:
@@ -22,7 +18,6 @@ def backward(self, loss):
         if self.activated:
             self.scaler.scale(loss).backward()
         else:
-            assert False, "for now"
             loss.backward()
 
     def step(self, colbert, optimizer, scheduler=None):
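
For reference, here is a minimal sketch of the mixed-precision pattern this manager wraps, using standard torch.cuda.amp. It is illustrative and simplified, not the file's full implementation.

import torch
from contextlib import nullcontext

class MixedPrecisionManager:
    def __init__(self, activated):
        self.activated = activated
        if self.activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        # autocast under AMP, otherwise a no-op context manager.
        return torch.cuda.amp.autocast() if self.activated else nullcontext()

    def backward(self, loss):
        if self.activated:
            self.scaler.scale(loss).backward()  # scale to avoid fp16 gradient underflow
        else:
            loss.backward()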
