Update to version 0.5.0, with Baleen and FLIPR
Omar Khattab committed Dec 7, 2021
1 parent c1905c6 commit c0b180a
Showing 23 changed files with 664 additions and 190 deletions.
135 changes: 135 additions & 0 deletions baleen/condenser/condense.py
@@ -0,0 +1,135 @@
import torch

from colbert.utils.utils import load_checkpoint
from colbert.utils.amp import MixedPrecisionManager
from colbert.utils.utils import flatten

from baleen.utils.loaders import *
from baleen.condenser.model import ElectraReader
from baleen.condenser.tokenization import AnswerAwareTokenizer



class Condenser:
    def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda', deviceL2='cuda'):
        self.modelL1, self.maxlenL1 = self._load_model(checkpointL1, deviceL1)
        self.modelL2, self.maxlenL2 = self._load_model(checkpointL2, deviceL2)

        assert self.maxlenL1 == self.maxlenL2, "Add support for different maxlens: use two tokenizers."

        self.amp, self.tokenizer = self._setup_inference(self.maxlenL2)
        self.CollectionX, self.CollectionY = self._load_collection(collectionX_path)

    def condense(self, query, backs, ranking):
        stage1_preds = self._stage1(query, backs, ranking)
        stage2_preds = self._stage2(query, stage1_preds)

        return stage1_preds, stage2_preds

    def _load_model(self, path, device):
        model = torch.load(path, map_location='cpu')
        ElectraModels = ['google/electra-base-discriminator', 'google/electra-large-discriminator']
        assert model['arguments']['model'] in ElectraModels, model['arguments']

        model = ElectraReader.from_pretrained(model['arguments']['model'])
        checkpoint = load_checkpoint(path, model)

        model = model.to(device)
        model.eval()

        maxlen = checkpoint['arguments']['maxlen']

        return model, maxlen

    def _setup_inference(self, maxlen):
        amp = MixedPrecisionManager(activated=True)
        tokenizer = AnswerAwareTokenizer(total_maxlen=maxlen)

        return amp, tokenizer

    def _load_collection(self, collectionX_path):
        CollectionX = {}
        CollectionY = {}

        with open(collectionX_path) as f:
            for line_idx, line in enumerate(f):
                line = ujson.loads(line)

                assert type(line['text']) is list
                assert line['pid'] == line_idx, (line_idx, line)

                passage = [line['title']] + line['text']
                CollectionX[line_idx] = passage

                passage = [line['title'] + ' | ' + sentence for sentence in line['text']]

                for idx, sentence in enumerate(passage):
                    CollectionY[(line_idx, idx)] = sentence

        return CollectionX, CollectionY

    def _stage1(self, query, BACKS, ranking, TOPK=9):
        model = self.modelL1

        with torch.inference_mode():
            backs = [self.CollectionY[(pid, sid)] for pid, sid in BACKS if (pid, sid) in self.CollectionY]
            backs = [query] + backs
            query = ' # '.join(backs)

            # print(query)
            # print(backs)
            passages = []
            actual_ranking = []

            for pid in ranking:
                actual_ranking.append(pid)
                psg = self.CollectionX[pid]
                psg = ' [MASK] '.join(psg)

                passages.append(psg)

            obj = self.tokenizer.process([query], passages, None)

            with self.amp.context():
                scores = model(obj.encoding.to(model.device)).float()

            pids = [[pid] * scores.size(1) for pid in actual_ranking]
            pids = flatten(pids)

            sids = [list(range(scores.size(1))) for pid in actual_ranking]
            sids = flatten(sids)

            scores = scores.view(-1)

            topk = scores.topk(min(TOPK, len(scores))).indices.tolist()
            topk_pids = [pids[idx] for idx in topk]
            topk_sids = [sids[idx] for idx in topk]

            preds = [(pid, sid) for pid, sid in zip(topk_pids, topk_sids)]

            pred_plus = BACKS + preds
            pred_plus = f7(list(map(tuple, pred_plus)))[:TOPK]

        return pred_plus

    def _stage2(self, query, preds):
        model = self.modelL2

        psgX = [self.CollectionY[(pid, sid)] for pid, sid in preds if (pid, sid) in self.CollectionY]
        psg = ' [MASK] '.join([''] + psgX)
        passages = [psg]
        # print(passages)

        obj = self.tokenizer.process([query], passages, None)

        with self.amp.context():
            scores = model(obj.encoding.to(model.device)).float()
            scores = scores.view(-1).tolist()

        preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
        preds = sorted(preds, reverse=True)[:5]
        preds = [x for score, x in preds if score > 0]

        # TODO: Apply L3x for final stage.

        return preds
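
For orientation, a minimal usage sketch of the Condenser added above (not part of this commit). The paths, example query, and pids below are placeholders; the two condenser checkpoints are assumed to be trained separately.

    from baleen.condenser.condense import Condenser

    # Hypothetical paths to the sentence collection and the two checkpoints.
    condenser = Condenser(collectionX_path='collection.json',
                          checkpointL1='condenserL1.dnn',
                          checkpointL2='condenserL2.dnn')

    # `ranking` is a list of passage ids from the retriever; `backs` carries the
    # (pid, sid) facts kept from earlier hops (empty on the first hop).
    stage1_preds, facts = condenser.condense(query="example question?",
                                             backs=[],
                                             ranking=[17, 42, 105])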
79 changes: 79 additions & 0 deletions baleen/condenser/model.py
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn

from transformers import ElectraPreTrainedModel, ElectraModel


class ElectraReader(ElectraPreTrainedModel):
    def __init__(self, config, learn_labels=False):
        super(ElectraReader, self).__init__(config)

        self.electra = ElectraModel(config)

        self.relevance = nn.Linear(config.hidden_size, 1)

        if learn_labels:
            self.linear = nn.Linear(config.hidden_size, 2)
        else:
            self.linear = nn.Linear(config.hidden_size, 1)

        self.init_weights()

        self.learn_labels = learn_labels

    def forward(self, encoding):
        outputs = self.electra(encoding.input_ids,
                               attention_mask=encoding.attention_mask,
                               token_type_ids=encoding.token_type_ids)[0]

        scores = self.linear(outputs)

        if self.learn_labels:
            scores = scores[:, 0].squeeze(1)
        else:
            scores = scores.squeeze(-1)
            candidates = (encoding.input_ids == 103)
            scores = self._mask_2d_index(scores, candidates)

        return scores

    def _mask_2d_index(self, scores, mask):
        bsize, maxlen = scores.size()
        bsize_, maxlen_ = mask.size()

        assert bsize == bsize_, (scores.size(), mask.size())
        assert maxlen == maxlen_, (scores.size(), mask.size())

        # Get flat scores corresponding to the True mask positions, with -inf at the end
        flat_scores = scores[mask]
        flat_scores = torch.cat((flat_scores, torch.ones(1, device=self.device) * float('-inf')))

        # Get 2D indexes
        rowidxs, nnzs = torch.unique(torch.nonzero(mask, as_tuple=False)[:, 0], return_counts=True)
        max_nnzs = nnzs.max().item()

        rows = [[-1] * max_nnzs for _ in range(bsize)]
        offset = 0
        for rowidx, nnz in zip(rowidxs.tolist(), nnzs.tolist()):
            rows[rowidx] = [offset + i for i in range(nnz)]
            rows[rowidx] += [-1] * (max_nnzs - len(rows[rowidx]))
            offset += nnz

        indexes = torch.tensor(rows).to(self.device)

        # Index with the 2D indexes
        scores_2d = flat_scores[indexes]

        return scores_2d

    def _2d_index(self, embeddings, positions):
        bsize, maxlen, hdim = embeddings.size()
        bsize_, max_out = positions.size()

        assert bsize == bsize_
        assert positions.max() < maxlen

        embeddings = embeddings.view(bsize * maxlen, hdim)
        positions = positions + torch.arange(bsize, device=positions.device).unsqueeze(-1) * maxlen

        return embeddings[positions]
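
The forward pass above scores every [MASK] position (token id 103 in the ELECTRA vocabulary) and `_mask_2d_index` gathers those scores into one right-padded row per example. A standalone toy illustration of that gather (not part of this commit):

    import torch

    scores = torch.tensor([[0.1, 0.9, 0.3, 0.7],
                           [0.2, 0.4, 0.6, 0.8]])
    mask = torch.tensor([[False, True, False, True],
                         [True, False, False, False]])

    # Flat scores at the True positions, with a trailing -inf used as padding.
    flat = torch.cat((scores[mask], torch.tensor([float('-inf')])))

    # Row 0 has two masked positions, row 1 has one; index -1 selects the -inf pad.
    indexes = torch.tensor([[0, 1],
                            [2, -1]])
    print(flat[indexes])
    # tensor([[0.9000, 0.7000],
    #         [0.2000,   -inf]])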
118 changes: 118 additions & 0 deletions baleen/condenser/tokenization.py
@@ -0,0 +1,118 @@
import torch

from transformers import ElectraTokenizerFast

class AnswerAwareTokenizer():
    def __init__(self, total_maxlen, bert_model='google/electra-base-discriminator'):
        self.total_maxlen = total_maxlen

        self.tok = ElectraTokenizerFast.from_pretrained(bert_model)

    def process(self, questions, passages, all_answers=None, mask=None):
        return TokenizationObject(self, questions, passages, all_answers, mask)

    def tensorize(self, questions, passages):
        query_lengths = self.tok(questions, padding='longest', return_tensors='pt').attention_mask.sum(-1)

        encoding = self.tok(questions, passages, padding='longest', truncation='longest_first',
                            return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True)

        return encoding, query_lengths

    def get_all_candidates(self, encoding, index):
        offsets, endpositions = self.all_word_positions(encoding, index)

        candidates = [(offset, endpos)
                      for idx, offset in enumerate(offsets)
                      for endpos in endpositions[idx:idx+10]]

        return candidates

    def all_word_positions(self, encoding, index):
        words = encoding.word_ids(index)
        offsets = [position
                   for position, (last_word_number, current_word_number) in enumerate(zip([-1] + words, words))
                   if last_word_number != current_word_number]

        endpositions = offsets[1:] + [len(words)]

        return offsets, endpositions

    def characters_to_tokens(self, text, answers, encoding, index, offset, endpos):
        # print(text, answers, encoding, index, offset, endpos)
        # endpos = endpos - 1

        for offset_ in range(offset, len(text)+1):
            tokens_offset = encoding.char_to_token(index, offset_)
            # print(f'tokens_offset = {tokens_offset}')
            if tokens_offset is not None:
                break

        for endpos_ in range(endpos, len(text)+1):
            tokens_endpos = encoding.char_to_token(index, endpos_)
            # print(f'tokens_endpos = {tokens_endpos}')
            if tokens_endpos is not None:
                break

        # None on whitespace!
        assert tokens_offset is not None, (text, answers, offset)
        # assert tokens_endpos is not None, (text, answers, endpos)
        tokens_endpos = tokens_endpos if tokens_endpos is not None else len(encoding.tokens(index))

        return tokens_offset, tokens_endpos

    def tokens_to_answer(self, encoding, index, text, tokens_offset, tokens_endpos):
        # print(encoding, index, text, tokens_offset, tokens_endpos, len(encoding.tokens(index)))

        char_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_offset)).start

        try:
            char_next_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos)).start
            char_endpos = char_next_offset
        except:
            char_endpos = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos-1)).end

        assert char_offset is not None
        assert char_endpos is not None

        return text[char_offset:char_endpos].strip()


class TokenizationObject():
    def __init__(self, tokenizer: AnswerAwareTokenizer, questions, passages, answers=None, mask=None):
        assert type(questions) is list and type(passages) is list
        assert len(questions) in [1, len(passages)]

        if mask is None:
            mask = [True for _ in passages]

        self.mask = mask

        self.tok = tokenizer
        self.questions = questions if len(questions) == len(passages) else questions * len(passages)
        self.passages = passages
        self.answers = answers

        self.encoding, self.query_lengths = self._encode()
        self.passages_only_encoding, self.candidates, self.candidates_list = self._candidize()

        if answers is not None:
            self.gold_candidates = self.answers  # self._answerize()

    def _encode(self):
        return self.tok.tensorize(self.questions, self.passages)

    def _candidize(self):
        encoding = self.tok.tok(self.passages, add_special_tokens=False)

        all_candidates = [self.tok.get_all_candidates(encoding, index) for index in range(len(self.passages))]

        bsize, maxcands = len(self.passages), max(map(len, all_candidates))
        all_candidates = [cands + [(-1, -1)] * (maxcands - len(cands)) for cands in all_candidates]

        candidates = torch.tensor(all_candidates)
        assert candidates.size() == (bsize, maxcands, 2), (candidates.size(), (bsize, maxcands, 2), (self.questions, self.passages))

        candidates = candidates + self.query_lengths.unsqueeze(-1).unsqueeze(-1)

        return encoding, candidates, all_candidates
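
A small usage sketch (not part of this commit); the question, passage, and maxlen of 512 are placeholders. The passage string mirrors how the Condenser joins a passage's title and sentences with ' [MASK] ', so each sentence boundary gets a scoreable token.

    from baleen.condenser.tokenization import AnswerAwareTokenizer

    tok = AnswerAwareTokenizer(total_maxlen=512)

    # One question is paired against each passage; obj.encoding holds the joint
    # question+passage tensors that ElectraReader.forward consumes.
    obj = tok.process(["example question?"],
                      ["Title [MASK] First sentence. [MASK] Second sentence."])
    print(obj.encoding.input_ids.shape)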
58 changes: 58 additions & 0 deletions baleen/engine.py
@@ -0,0 +1,58 @@
from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser


class Baleen:
    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []
        stage1_preds = None
        context = None

        pids_bag = set()

        for hop_idx in range(0, num_hops):
            ranking = list(zip(*searcher.search(query, context=context, k=depth)))
            ranking_ = []

            facts_pids = set([pid for pid, _ in facts])

            for pid, rank, score in ranking:
                # print(f'[{score}] \t\t {searcher.collection[pid]}')
                if len(ranking_) < k and pid not in facts_pids:
                    ranking_.append(pid)

                if len(pids_bag) < k * (hop_idx+1):
                    pids_bag.add(pid)

            stage1_preds, facts = condenser.condense(query, backs=facts, ranking=ranking_)
            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])

        assert len(pids_bag) == depth

        return facts, pids_bag, stage1_preds
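
An end-to-end sketch of how these pieces are meant to compose (not part of this commit). The retriever construction is an assumption: any searcher whose .search(query, context=..., k=...) returns parallel (pids, ranks, scores) sequences fits the loop above; the paths and query are placeholders.

    from baleen.engine import Baleen
    from baleen.condenser.condense import Condenser

    condenser = Condenser('collection.json', 'condenserL1.dnn', 'condenserL2.dnn')  # hypothetical paths
    searcher = ...  # e.g., a ColBERT-style searcher over the same collection

    baleen = Baleen('collection.json', searcher, condenser)
    facts, pids_bag, stage1_preds = baleen.search("example multi-hop question?", num_hops=4, depth=100)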













