Commit c0b180a: Update to version 0.5.0, with Baleen and FLIPR
Omar Khattab committed Dec 7, 2021 (1 parent: c1905c6)
Showing 23 changed files with 664 additions and 190 deletions.
@@ -0,0 +1,135 @@
import torch

from colbert.utils.utils import load_checkpoint
from colbert.utils.amp import MixedPrecisionManager
from colbert.utils.utils import flatten

from baleen.utils.loaders import *
from baleen.condenser.model import ElectraReader
from baleen.condenser.tokenization import AnswerAwareTokenizer


class Condenser:
    def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda', deviceL2='cuda'):
        self.modelL1, self.maxlenL1 = self._load_model(checkpointL1, deviceL1)
        self.modelL2, self.maxlenL2 = self._load_model(checkpointL2, deviceL2)

        assert self.maxlenL1 == self.maxlenL2, "Add support for different maxlens: use two tokenizers."

        self.amp, self.tokenizer = self._setup_inference(self.maxlenL2)
        self.CollectionX, self.CollectionY = self._load_collection(collectionX_path)

    def condense(self, query, backs, ranking):
        # Two-stage condensation: stage 1 selects candidate sentences from the ranked
        # passages; stage 2 re-reads only those sentences and keeps the positively scored ones.
        stage1_preds = self._stage1(query, backs, ranking)
        stage2_preds = self._stage2(query, stage1_preds)

        return stage1_preds, stage2_preds

    def _load_model(self, path, device):
        model = torch.load(path, map_location='cpu')
        ElectraModels = ['google/electra-base-discriminator', 'google/electra-large-discriminator']
        assert model['arguments']['model'] in ElectraModels, model['arguments']

        model = ElectraReader.from_pretrained(model['arguments']['model'])
        checkpoint = load_checkpoint(path, model)

        model = model.to(device)
        model.eval()

        maxlen = checkpoint['arguments']['maxlen']

        return model, maxlen

    def _setup_inference(self, maxlen):
        amp = MixedPrecisionManager(activated=True)
        tokenizer = AnswerAwareTokenizer(total_maxlen=maxlen)

        return amp, tokenizer

    def _load_collection(self, collectionX_path):
        # CollectionX maps pid -> list of sentences (title first);
        # CollectionY maps (pid, sid) -> "title | sentence" strings.
        CollectionX = {}
        CollectionY = {}

        with open(collectionX_path) as f:
            for line_idx, line in enumerate(f):
                line = ujson.loads(line)

                assert type(line['text']) is list
                assert line['pid'] == line_idx, (line_idx, line)

                passage = [line['title']] + line['text']
                CollectionX[line_idx] = passage

                passage = [line['title'] + ' | ' + sentence for sentence in line['text']]

                for idx, sentence in enumerate(passage):
                    CollectionY[(line_idx, idx)] = sentence

        return CollectionX, CollectionY

    def _stage1(self, query, BACKS, ranking, TOPK=9):
        model = self.modelL1

        with torch.inference_mode():
            # Concatenate the query with the previously selected facts (the "backs").
            backs = [self.CollectionY[(pid, sid)] for pid, sid in BACKS if (pid, sid) in self.CollectionY]
            backs = [query] + backs
            query = ' # '.join(backs)

            # print(query)
            # print(backs)
            passages = []
            actual_ranking = []

            for pid in ranking:
                actual_ranking.append(pid)
                psg = self.CollectionX[pid]
                psg = ' [MASK] '.join(psg)

                passages.append(psg)

            obj = self.tokenizer.process([query], passages, None)

            with self.amp.context():
                scores = model(obj.encoding.to(model.device)).float()

            # Map every (passage, sentence) score back to its (pid, sid) pair.
            pids = [[pid] * scores.size(1) for pid in actual_ranking]
            pids = flatten(pids)

            sids = [list(range(scores.size(1))) for pid in actual_ranking]
            sids = flatten(sids)

            scores = scores.view(-1)

            topk = scores.topk(min(TOPK, len(scores))).indices.tolist()
            topk_pids = [pids[idx] for idx in topk]
            topk_sids = [sids[idx] for idx in topk]

            preds = [(pid, sid) for pid, sid in zip(topk_pids, topk_sids)]

            # Merge with the existing facts, deduplicate (order-preserving), and cap at TOPK.
            pred_plus = BACKS + preds
            pred_plus = f7(list(map(tuple, pred_plus)))[:TOPK]

        return pred_plus

    def _stage2(self, query, preds):
        model = self.modelL2

        psgX = [self.CollectionY[(pid, sid)] for pid, sid in preds if (pid, sid) in self.CollectionY]
        psg = ' [MASK] '.join([''] + psgX)
        passages = [psg]
        # print(passages)

        obj = self.tokenizer.process([query], passages, None)

        with self.amp.context():
            scores = model(obj.encoding.to(model.device)).float()
            scores = scores.view(-1).tolist()

        # Keep at most five sentences, and only those with positive scores.
        preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
        preds = sorted(preds, reverse=True)[:5]
        preds = [x for score, x in preds if score > 0]

        # TODO: Apply L3x for final stage.

        return preds
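The Condenser above is driven by two trained reader checkpoints and a sentence-level JSONL collection. A minimal usage sketch, not part of this diff; the file paths, checkpoint names, and query are placeholders:

from baleen.condenser.condense import Condenser

# Hypothetical paths: two trained condenser checkpoints and a JSONL collection whose
# lines carry 'pid', 'title', and 'text' (a list of sentences).
condenser = Condenser(collectionX_path='collection.jsonl',
                      checkpointL1='condenser-L1.dnn',
                      checkpointL2='condenser-L2.dnn',
                      deviceL1='cuda', deviceL2='cuda')

# `backs` are previously selected facts as (pid, sid) pairs; `ranking` is a list of passage ids.
stage1_facts, stage2_facts = condenser.condense(query='Who directed the film?',
                                                backs=[],
                                                ranking=[3, 17, 42])
print(stage2_facts)   # up to five (pid, sid) pairs with positive stage-2 scores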
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn

from transformers import ElectraPreTrainedModel, ElectraModel


class ElectraReader(ElectraPreTrainedModel):
    def __init__(self, config, learn_labels=False):
        super(ElectraReader, self).__init__(config)

        self.electra = ElectraModel(config)

        self.relevance = nn.Linear(config.hidden_size, 1)

        if learn_labels:
            self.linear = nn.Linear(config.hidden_size, 2)
        else:
            self.linear = nn.Linear(config.hidden_size, 1)

        self.init_weights()

        self.learn_labels = learn_labels

    def forward(self, encoding):
        outputs = self.electra(encoding.input_ids,
                               attention_mask=encoding.attention_mask,
                               token_type_ids=encoding.token_type_ids)[0]

        scores = self.linear(outputs)

        if self.learn_labels:
            scores = scores[:, 0].squeeze(1)
        else:
            scores = scores.squeeze(-1)
            # Keep only the scores at [MASK] positions (token id 103), one row per example.
            candidates = (encoding.input_ids == 103)
            scores = self._mask_2d_index(scores, candidates)

        return scores

    def _mask_2d_index(self, scores, mask):
        bsize, maxlen = scores.size()
        bsize_, maxlen_ = mask.size()

        assert bsize == bsize_, (scores.size(), mask.size())
        assert maxlen == maxlen_, (scores.size(), mask.size())

        # Get flat scores corresponding to the True mask positions, with -inf at the end
        flat_scores = scores[mask]
        flat_scores = torch.cat((flat_scores, torch.ones(1, device=self.device) * float('-inf')))

        # Get 2D indexes
        rowidxs, nnzs = torch.unique(torch.nonzero(mask, as_tuple=False)[:, 0], return_counts=True)
        max_nnzs = nnzs.max().item()

        rows = [[-1] * max_nnzs for _ in range(bsize)]
        offset = 0
        for rowidx, nnz in zip(rowidxs.tolist(), nnzs.tolist()):
            rows[rowidx] = [offset + i for i in range(nnz)]
            rows[rowidx] += [-1] * (max_nnzs - len(rows[rowidx]))
            offset += nnz

        indexes = torch.tensor(rows).to(self.device)

        # Index with the 2D indexes
        scores_2d = flat_scores[indexes]

        return scores_2d

    def _2d_index(self, embeddings, positions):
        bsize, maxlen, hdim = embeddings.size()
        bsize_, max_out = positions.size()

        assert bsize == bsize_
        assert positions.max() < maxlen

        embeddings = embeddings.view(bsize * maxlen, hdim)
        positions = positions + torch.arange(bsize, device=positions.device).unsqueeze(-1) * maxlen

        return embeddings[positions]
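In the non-learn_labels path, the reader scores every token and then keeps only the scores at the [MASK] separator positions via `_mask_2d_index`, which packs a ragged, row-major selection into a padded 2D tensor whose padding slots point at a trailing -inf sentinel. A small standalone sketch of that gather-and-pad pattern, with made-up scores and mask (the index rows are written by hand here; the method computes them with torch.nonzero/torch.unique):

import torch

scores = torch.tensor([[0.3, 1.2, -0.5, 0.9],
                       [2.0, 0.1,  0.4, -1.0]])
mask = torch.tensor([[False, True, False, True],
                     [True, False, False, False]])

flat = scores[mask]                                       # scores at masked positions, row-major
flat = torch.cat((flat, torch.tensor([float('-inf')])))   # trailing -inf sentinel for padding

# Rows of indices into `flat`, padded with -1 (which points at the -inf sentinel).
indexes = torch.tensor([[0, 1],
                        [2, -1]])
print(flat[indexes])
# tensor([[1.2000, 0.9000],
#         [2.0000,   -inf]])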
@@ -0,0 +1,118 @@
import torch

from transformers import ElectraTokenizerFast


class AnswerAwareTokenizer():
    def __init__(self, total_maxlen, bert_model='google/electra-base-discriminator'):
        self.total_maxlen = total_maxlen

        self.tok = ElectraTokenizerFast.from_pretrained(bert_model)

    def process(self, questions, passages, all_answers=None, mask=None):
        return TokenizationObject(self, questions, passages, all_answers, mask)

    def tensorize(self, questions, passages):
        query_lengths = self.tok(questions, padding='longest', return_tensors='pt').attention_mask.sum(-1)

        encoding = self.tok(questions, passages, padding='longest', truncation='longest_first',
                            return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True)

        return encoding, query_lengths

    def get_all_candidates(self, encoding, index):
        offsets, endpositions = self.all_word_positions(encoding, index)

        # Candidate spans: each word start paired with up to ten following word ends.
        candidates = [(offset, endpos)
                      for idx, offset in enumerate(offsets)
                      for endpos in endpositions[idx:idx+10]]

        return candidates

    def all_word_positions(self, encoding, index):
        words = encoding.word_ids(index)
        offsets = [position
                   for position, (last_word_number, current_word_number) in enumerate(zip([-1] + words, words))
                   if last_word_number != current_word_number]

        endpositions = offsets[1:] + [len(words)]

        return offsets, endpositions

    def characters_to_tokens(self, text, answers, encoding, index, offset, endpos):
        # print(text, answers, encoding, index, offset, endpos)
        # endpos = endpos - 1

        for offset_ in range(offset, len(text)+1):
            tokens_offset = encoding.char_to_token(index, offset_)
            # print(f'tokens_offset = {tokens_offset}')
            if tokens_offset is not None:
                break

        for endpos_ in range(endpos, len(text)+1):
            tokens_endpos = encoding.char_to_token(index, endpos_)
            # print(f'tokens_endpos = {tokens_endpos}')
            if tokens_endpos is not None:
                break

        # None on whitespace!
        assert tokens_offset is not None, (text, answers, offset)
        # assert tokens_endpos is not None, (text, answers, endpos)
        tokens_endpos = tokens_endpos if tokens_endpos is not None else len(encoding.tokens(index))

        return tokens_offset, tokens_endpos

    def tokens_to_answer(self, encoding, index, text, tokens_offset, tokens_endpos):
        # print(encoding, index, text, tokens_offset, tokens_endpos, len(encoding.tokens(index)))

        char_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_offset)).start

        try:
            char_next_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos)).start
            char_endpos = char_next_offset
        except:
            char_endpos = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos-1)).end

        assert char_offset is not None
        assert char_endpos is not None

        return text[char_offset:char_endpos].strip()


class TokenizationObject():
    def __init__(self, tokenizer: AnswerAwareTokenizer, questions, passages, answers=None, mask=None):
        assert type(questions) is list and type(passages) is list
        assert len(questions) in [1, len(passages)]

        if mask is None:
            mask = [True for _ in passages]

        self.mask = mask

        self.tok = tokenizer
        self.questions = questions if len(questions) == len(passages) else questions * len(passages)
        self.passages = passages
        self.answers = answers

        self.encoding, self.query_lengths = self._encode()
        self.passages_only_encoding, self.candidates, self.candidates_list = self._candidize()

        if answers is not None:
            self.gold_candidates = self.answers  # self._answerize()

    def _encode(self):
        return self.tok.tensorize(self.questions, self.passages)

    def _candidize(self):
        encoding = self.tok.tok(self.passages, add_special_tokens=False)

        all_candidates = [self.tok.get_all_candidates(encoding, index) for index in range(len(self.passages))]

        # Pad candidate lists to the same length, then shift positions past the question tokens
        # so they index into the joint question+passage encoding.
        bsize, maxcands = len(self.passages), max(map(len, all_candidates))
        all_candidates = [cands + [(-1, -1)] * (maxcands - len(cands)) for cands in all_candidates]

        candidates = torch.tensor(all_candidates)
        assert candidates.size() == (bsize, maxcands, 2), (candidates.size(), (bsize, maxcands, 2), (self.questions, self.passages))

        candidates = candidates + self.query_lengths.unsqueeze(-1).unsqueeze(-1)

        return encoding, candidates, all_candidates
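The tokenizer encodes question/passage pairs jointly and exposes the batch encoding, per-passage question lengths, and candidate spans through a TokenizationObject. A brief usage sketch, assuming the transformers package is installed; the strings are illustrative only:

from baleen.condenser.tokenization import AnswerAwareTokenizer

tok = AnswerAwareTokenizer(total_maxlen=512)

question = 'Who wrote the novel?'
passages = ['Some Title [MASK] First sentence of the passage. [MASK] Second sentence.']

obj = tok.process([question], passages)   # returns a TokenizationObject
print(obj.encoding.input_ids.shape)       # (num_passages, padded_length), question + passage jointly encoded
print(obj.query_lengths)                  # number of question tokens per passage
print(obj.candidates.shape)               # (num_passages, max_candidates, 2) start/end token positions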
@@ -0,0 +1,58 @@
from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser


class Baleen:
    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []
        stage1_preds = None
        context = None

        pids_bag = set()

        for hop_idx in range(0, num_hops):
            ranking = list(zip(*searcher.search(query, context=context, k=depth)))
            ranking_ = []

            facts_pids = set([pid for pid, _ in facts])

            for pid, rank, score in ranking:
                # print(f'[{score}] \t\t {searcher.collection[pid]}')
                # Take the top-k passages not already covered by the selected facts.
                if len(ranking_) < k and pid not in facts_pids:
                    ranking_.append(pid)

                if len(pids_bag) < k * (hop_idx+1):
                    pids_bag.add(pid)

            # Condense this hop's passages into facts and use them as context for the next hop.
            stage1_preds, facts = condenser.condense(query, backs=facts, ranking=ranking_)
            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])

        assert len(pids_bag) == depth

        return facts, pids_bag, stage1_preds
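Baleen ties retrieval and condensation together: each hop retrieves with the current context, keeps the top k = depth // num_hops new passages, condenses them into facts, and feeds those facts back as context for the next hop. A hypothetical wiring sketch follows; the searcher is assumed to be a ColBERT-style object whose search(query, context=..., k=...) returns parallel lists of pids, ranks, and scores (the stub below only documents that contract), the checkpoint and collection paths are placeholders, and the import path of the Baleen class is not shown in this diff:

from baleen.condenser.condense import Condenser

class StubSearcher:
    """Stand-in for a ColBERT-style searcher: returns parallel (pids, ranks, scores) lists."""
    def search(self, query, context=None, k=100):
        pids = list(range(k))
        ranks = list(range(1, k + 1))
        scores = [1.0 / r for r in ranks]
        return pids, ranks, scores

# Hypothetical paths; Baleen is assumed importable from wherever this file lives in the package.
condenser = Condenser('collection.jsonl', 'condenser-L1.dnn', 'condenser-L2.dnn')
baleen = Baleen('collection.jsonl', searcher=StubSearcher(), condenser=condenser)

facts, pids_bag, stage1_preds = baleen.search('Which author wrote the novel?', num_hops=4, depth=100)
# facts: (pid, sid) sentence ids selected across hops; pids_bag: exactly `depth` passage ids, up to k per hop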