
Commit 27c78c4

Bi, Poly and Cross encoder not dependent on Hugging Face code (facebookresearch#1791)
* first draft for cross encoder
* first actual draft of crossencoder
* biencoder 0.86 on convai2
* same performance for biencoder as in paper
* removing logs
* adding a mask to the basic attention
* stable version of crossencoder
* first draft of polyencoder
* polyencoder gives the exact same results as the HF version
* (merge from facebookresearch#1790) possibility to not share embeddings
* add or remove the residual in the basic attention
* a few corrections to the polyencoder
* operation-type was not working properly with data-parallel=true
* some warnings + deleting this modification of the dictionary
* flake8
* no mutable structure in argument
* remove this 'surround' function, generalizes _add_start_end_tokens
* interactive with polyencoder seems to work. It's not a great conversationalist though...
* remove pdb
* zoo files
* flake8
* remove debug line. Adding polyencoder to model_list
* syntax error
* unifying APIs of BasicAttention and MultiheadAttention
* some corrections from Emily
* some other nits
* factorizing some ifs
* highlighted the --init-model param a bit more
* removing useless arguments
* change basic_sqrt to sqrt
* ahem
* useless line
* blacked!
* black again
* ahem... adding the biencoder
* no need for that line
* right, I was a bit rude there
* black is beautiful
1 parent 9489d4b commit 27c78c4

12 files changed: +805 −23 lines
parlai/agents/transformer/biencoder.py (new file, +36 lines)

@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .transformer import TransformerRankerAgent
from parlai.core.torch_ranker_agent import TorchRankerAgent
import torch


class BiencoderAgent(TransformerRankerAgent):
    """ Equivalent of bert_ranker/biencoder but does not rely on an external
    library (hugging face).
    """

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        # favor average instead of sum for the loss.
        self.rank_loss = torch.nn.CrossEntropyLoss(reduce=True, size_average=True)
        if self.use_cuda:
            self.rank_loss.cuda()

    def vectorize(self, *args, **kwargs):
        """ Add the start and end token to the text.
        """
        kwargs['add_start'] = True
        kwargs['add_end'] = True
        obs = TorchRankerAgent.vectorize(self, *args, **kwargs)
        return obs

    def _set_text_vec(self, *args, **kwargs):
        """ Add the start and end token to the text.
        """
        obs = super()._set_text_vec(*args, **kwargs)
        if 'text_vec' in obs:
            obs['text_vec'] = self._add_start_end_tokens(obs['text_vec'], True, True)
        return obs
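
For illustration only (not part of the diff): the "favor average instead of sum" comment switches the ranking loss to an averaged cross-entropy; reduce=True, size_average=True is the older spelling of reduction='mean' in current PyTorch. A minimal sketch with made-up scores:

# Illustration only -- not part of the patch. size_average=True (reduction='mean')
# averages the candidate-ranking loss over the batch instead of summing it.
import torch

scores = torch.tensor([[9.0, 1.0, 0.5],
                       [0.2, 8.0, 1.0]])   # batch x n_candidates (hypothetical values)
targets = torch.tensor([0, 1])             # index of the gold candidate in each row

loss_mean = torch.nn.CrossEntropyLoss(reduction='mean')(scores, targets)
loss_sum = torch.nn.CrossEntropyLoss(reduction='sum')(scores, targets)
assert torch.allclose(loss_mean, loss_sum / scores.size(0))
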
parlai/agents/transformer/crossencoder.py (new file, +146 lines)

@@ -0,0 +1,146 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# hack to make sure -m transformer/generator works as expected
from .modules import TransformerEncoder
from .modules import get_n_positions_from_options
from parlai.core.torch_ranker_agent import TorchRankerAgent
from .transformer import TransformerRankerAgent
import torch


class CrossencoderAgent(TorchRankerAgent):
    """ Equivalent of bert_ranker/crossencoder but does not rely on an external
    library (hugging face).
    """

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        self.rank_loss = torch.nn.CrossEntropyLoss(reduce=True, size_average=True)
        if self.use_cuda:
            self.rank_loss.cuda()
        self.data_parallel = opt.get('data_parallel') and self.use_cuda
        if self.data_parallel:
            from parlai.core.distributed_utils import is_distributed

            if is_distributed():
                raise ValueError('Cannot combine --data-parallel and distributed mode')
            self.model = torch.nn.DataParallel(self.model)

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        TransformerRankerAgent.add_cmdline_args(argparser)
        return argparser

    def build_model(self, states=None):
        self.model = CrossEncoderModule(self.opt, self.dict, self.NULL_IDX)
        return self.model

    def vectorize(self, *args, **kwargs):
        """ Add the start and end token to the text.
        """
        kwargs['add_start'] = True
        kwargs['add_end'] = True
        obs = super().vectorize(*args, **kwargs)
        return obs

    def _set_text_vec(self, *args, **kwargs):
        """ Add the start and end token to the text.
        """
        obs = super()._set_text_vec(*args, **kwargs)
        if 'text_vec' in obs:
            obs['text_vec'] = self._add_start_end_tokens(obs['text_vec'], True, True)
        return obs

    def concat_without_padding(self, text_idx, cand_idx, null_idx=0):
        """ if text_idx = [[1, 2, 3, 4, 0, 0 ]]
            and cand_idx = [[5, 6, 7, 8, 0, 0 ]]
            then result = (tokens, segments) where
            tokens = [[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0]]
            segments = [[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]]
        """
        assert text_idx.size(0) == cand_idx.size(0)
        assert len(text_idx.size()) == 2
        assert len(cand_idx.size()) == 2
        segments_idx = [0, 1]
        text_idx = text_idx.cpu()
        cand_idx = cand_idx.cpu()
        cand_len = cand_idx.size(1)
        concat_len = text_idx.size(1) + cand_idx.size(1)
        tokens = text_idx.new_zeros(text_idx.size(0), concat_len) + null_idx
        segments = text_idx.new_zeros(text_idx.size(0), concat_len) + null_idx
        for i in range(len(tokens)):
            non_nuls = torch.sum(text_idx[i, :] != null_idx)
            tokens[i, 0:non_nuls] = text_idx[i, 0:non_nuls]
            segments[i, 0:non_nuls] = segments_idx[0]
            tokens[i, non_nuls : non_nuls + cand_len] = cand_idx[i, :]
            segments[i, non_nuls : non_nuls + cand_len] = segments_idx[1]
        if self.use_cuda:
            tokens = tokens.cuda()
            segments = segments.cuda()
        return tokens, segments

    def score_candidates(self, batch, cand_vecs, cand_encs=None):
        if cand_encs is not None:
            raise Exception(
                'Candidate pre-computation is impossible on the crossencoder'
            )
        num_cands_per_sample = cand_vecs.size(1)
        bsz = cand_vecs.size(0)
        text_idx = (
            batch.text_vec.unsqueeze(1)
            .expand(-1, num_cands_per_sample, -1)
            .contiguous()
            .view(num_cands_per_sample * bsz, -1)
        )
        cand_idx = cand_vecs.view(num_cands_per_sample * bsz, -1)
        tokens, segments = self.concat_without_padding(
            text_idx, cand_idx, self.NULL_IDX
        )
        scores = self.model(tokens, segments)
        scores = scores.view(bsz, num_cands_per_sample)
        return scores


class CrossEncoderModule(torch.nn.Module):
    """ A simple wrapper around the transformer encoder which adds a linear
    layer.
    """

    def __init__(self, opt, dict, null_idx):
        super(CrossEncoderModule, self).__init__()
        n_positions = get_n_positions_from_options(opt)
        embeddings = torch.nn.Embedding(
            len(dict), opt['embedding_size'], padding_idx=null_idx
        )
        torch.nn.init.normal_(embeddings.weight, 0, opt['embedding_size'] ** -0.5)
        self.encoder = TransformerEncoder(
            n_heads=opt['n_heads'],
            n_layers=opt['n_layers'],
            embedding_size=opt['embedding_size'],
            ffn_size=opt['ffn_size'],
            vocabulary_size=len(dict),
            embedding=embeddings,
            dropout=opt['dropout'],
            attention_dropout=opt['attention_dropout'],
            relu_dropout=opt['relu_dropout'],
            padding_idx=null_idx,
            learn_positional_embeddings=opt['learn_positional_embeddings'],
            embeddings_scale=opt['embeddings_scale'],
            reduction_type=opt.get('reduction_type', 'first'),
            n_positions=n_positions,
            n_segments=2,
            activation=opt['activation'],
            variant=opt['variant'],
            output_scaling=opt['output_scaling'],
        )
        self.linear_layer = torch.nn.Linear(opt['embedding_size'], 1)

    def forward(self, tokens, segments):
        """ Scores each concatenation text + candidate.
        """
        encoded = self.encoder(tokens, None, segments)
        res = self.linear_layer(encoded)
        return res
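
For illustration (not part of the diff), here is a standalone sketch of the packing that concat_without_padding performs: each candidate is appended right after the last non-padding context token, and segment id 1 marks the candidate tokens. The helper below is a simplified re-implementation for the docstring example, not the agent's method; the agent's version additionally moves tensors between CPU and GPU and marks segment 1 over the candidate's full padded width (those extra positions hold the null token and are treated as padding by the encoder).

# Illustration only -- simplified re-implementation of the packing idea.
import torch

def pack_pair(text_idx, cand_idx, null_idx=0):
    bsz, cand_len = cand_idx.size()
    concat_len = text_idx.size(1) + cand_len
    tokens = text_idx.new_full((bsz, concat_len), null_idx)
    segments = text_idx.new_zeros(bsz, concat_len)
    for i in range(bsz):
        n_text = int((text_idx[i] != null_idx).sum())   # non-padding context length
        n_cand = int((cand_idx[i] != null_idx).sum())   # non-padding candidate length
        tokens[i, :n_text] = text_idx[i, :n_text]
        tokens[i, n_text:n_text + cand_len] = cand_idx[i]
        segments[i, n_text:n_text + n_cand] = 1          # segment 1 = candidate tokens
    return tokens, segments

text = torch.tensor([[1, 2, 3, 4, 0, 0]])
cand = torch.tensor([[5, 6, 7, 8, 0, 0]])
tokens, segments = pack_pair(text, cand)
print(tokens.tolist())    # [[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0]]
print(segments.tolist())  # [[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]]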

parlai/agents/transformer/modules.py (+49 −20 lines)

@@ -76,6 +76,7 @@ def _build_encoder(
        n_segments=n_segments,
        activation=opt['activation'],
        variant=opt['variant'],
+        output_scaling=opt['output_scaling'],
    )


@@ -111,6 +112,22 @@ def gelu(tensor):
    return 0.5 * tensor * (1.0 + torch.erf(tensor / math.sqrt(2.0)))


+def get_n_positions_from_options(opt):
+    if opt.get('n_positions'):
+        # if the number of positions is explicitly provided, use that
+        n_positions = opt['n_positions']
+    else:
+        # else, use the worst case from truncate
+        n_positions = max(
+            opt.get('truncate') or 0,
+            opt.get('text_truncate') or 0,
+            opt.get('label_truncate') or 0,
+        )
+    if n_positions == 0:
+        n_positions = 1024
+    return n_positions
+
+
class TransformerMemNetModel(nn.Module):
    """Model which takes context, memories, candidates and encodes them."""

@@ -135,19 +152,7 @@ def __init__(self, opt, dictionary):
        if not self.share_word_embedding:
            self.cand_embeddings.weight.requires_grad = False

-        if opt.get('n_positions'):
-            # if the number of positions is explicitly provided, use that
-            n_positions = opt['n_positions']
-        else:
-            # else, use the worst case from truncate
-            n_positions = max(
-                opt.get('truncate') or 0,
-                opt.get('text_truncate') or 0,
-                opt.get('label_truncate') or 0,
-            )
-            if n_positions == 0:
-                # default to 1024
-                n_positions = 1024
+        n_positions = get_n_positions_from_options(opt)

        if n_positions < 0:
            raise ValueError('n_positions must be positive')
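
A quick illustration of the fallback order the new helper encodes, assuming it is imported from the patched parlai.agents.transformer.modules (as crossencoder.py above does): an explicit n_positions wins, otherwise the largest truncate value is used, and 1024 is the final default.

# Illustration only -- resolving n_positions from an opt dict.
from parlai.agents.transformer.modules import get_n_positions_from_options

assert get_n_positions_from_options({'n_positions': 2048}) == 2048   # explicit value wins
assert get_n_positions_from_options({'text_truncate': 360, 'label_truncate': 72}) == 360
assert get_n_positions_from_options({}) == 1024                      # nothing set: default
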
@@ -192,7 +197,9 @@ def __init__(self, opt, dictionary):
        else:
            self.memory_transformer = self.context_encoder

-        self.attender = BasicAttention(dim=2, attn=opt['memory_attention'])
+        self.attender = BasicAttention(
+            dim=2, attn=opt['memory_attention'], residual=True
+        )

    def encode_cand(self, words):
        """Encode the candidates."""

@@ -318,6 +325,8 @@ class TransformerEncoder(nn.Module):
    :param variant:
        Which transformer architecture to use. Could be AIAYN or XLM.
        Future versions may support things like GPT-2, ...
+    :param output_scaling:
+        Scale the outputs by a given scalar
    """

    def __init__(

@@ -339,6 +348,7 @@ def __init__(
        activation='relu',
        variant='aiayn',
        n_segments=0,
+        output_scaling=1.0,
    ):
        super(TransformerEncoder, self).__init__()

@@ -414,6 +424,7 @@ def __init__(
                    activation=activation,
                )
            )
+        self.output_scaling = output_scaling

    def forward(self, input, positions=None, segments=None):
        """

@@ -457,6 +468,7 @@ def forward(self, input, positions=None, segments=None):
        for i in range(self.n_layers):
            tensor = self.layers[i](tensor, mask)

+        tensor *= self.output_scaling
        if self.reduction_type == 'first':
            return tensor[:, 0, :]
        elif self.reduction_type == 'max':
@@ -805,29 +817,46 @@ def output(self, tensor):
class BasicAttention(nn.Module):
    """Implements simple/classical attention."""

-    def __init__(self, dim=1, attn='cosine'):
+    def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
        super().__init__()
        self.softmax = nn.Softmax(dim=dim)
        if attn == 'cosine':
            self.cosine = nn.CosineSimilarity(dim=dim)
        self.attn = attn
        self.dim = dim
+        self.get_weights = get_weights
+        self.residual = residual

-    def forward(self, xs, ys):
-        """Forward pass."""
+    def forward(self, xs, ys, mask_ys=None):
+        """ xs: B x query_len x dim
+            ys: B x key_len x dim
+            TODO: Document this
+        """
+        bsz = xs.size(0)
+        y_len = ys.size(1)
+        x_len = xs.size(1)
        if self.attn == 'cosine':
            l1 = self.cosine(xs, ys).unsqueeze(self.dim - 1)
        else:
            l1 = torch.bmm(xs, ys.transpose(1, 2))
            if self.attn == 'sqrt':
                d_k = ys.size(-1)
                l1 = l1 / math.sqrt(d_k)
+        if mask_ys is not None:
+            attn_mask = (mask_ys == 0).view(bsz, 1, y_len)
+            attn_mask = attn_mask.repeat(1, x_len, 1)
+            l1.masked_fill_(attn_mask, -float('inf'))
        l2 = self.softmax(l1)
        lhs_emb = torch.bmm(l2, ys)
-        # add back the query
-        lhs_emb = lhs_emb.add(xs)

-        return lhs_emb.squeeze(self.dim - 1), l2
+        # add back the query
+        if self.residual:
+            lhs_emb = lhs_emb.add(xs)
+
+        if self.get_weights:
+            return lhs_emb.squeeze(self.dim - 1), l2
+        else:
+            return lhs_emb.squeeze(self.dim - 1)


class MultiHeadAttention(nn.Module):
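
For illustration (not part of the diff): a minimal use of the reworked BasicAttention, assuming it is imported from the patched parlai.agents.transformer.modules. mask_ys zeroes out attention over padded key positions, while residual and get_weights make the old add-back-the-query and return-the-weights behaviours optional. Shapes follow the docstring (xs: B x query_len x dim, ys: B x key_len x dim).

# Illustration only -- exercising the new mask_ys / residual / get_weights options.
import torch
from parlai.agents.transformer.modules import BasicAttention

attn = BasicAttention(dim=2, attn='sqrt', residual=False, get_weights=True)
xs = torch.randn(2, 1, 8)                 # B x query_len x dim
ys = torch.randn(2, 5, 8)                 # B x key_len x dim
mask_ys = torch.tensor([[1, 1, 1, 0, 0],  # 0 marks padded key positions
                        [1, 1, 1, 1, 1]])

out, weights = attn(xs, ys, mask_ys=mask_ys)
assert out.shape == (2, 8)                # the query_len of 1 is squeezed out
assert torch.all(weights[0, :, 3:] == 0)  # padded keys get exactly zero weight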
