|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +# This source code is licensed under the MIT license found in the |
| 3 | +# LICENSE file in the root directory of this source tree. |
| 4 | + |
| 5 | +# hack to make sure -m transformer/generator works as expected |
| 6 | +from .modules import TransformerEncoder |
| 7 | +from .modules import get_n_positions_from_options |
| 8 | +from parlai.core.torch_ranker_agent import TorchRankerAgent |
| 9 | +from .transformer import TransformerRankerAgent |
| 10 | +import torch |
| 11 | + |
| 12 | + |
| 13 | +class CrossencoderAgent(TorchRankerAgent): |
| 14 | + """ Equivalent of bert_ranker/crossencoder but does not rely on an external |
| 15 | + library (hugging face). |
| 16 | + """ |
| 17 | + |
| 18 | + def __init__(self, opt, shared=None): |
| 19 | + super().__init__(opt, shared) |
| 20 | + self.rank_loss = torch.nn.CrossEntropyLoss(reduce=True, size_average=True) |
| 21 | + if self.use_cuda: |
| 22 | + self.rank_loss.cuda() |
| 23 | + self.data_parallel = opt.get('data_parallel') and self.use_cuda |
| 24 | + if self.data_parallel: |
| 25 | + from parlai.core.distributed_utils import is_distributed |
| 26 | + |
| 27 | + if is_distributed(): |
| 28 | + raise ValueError('Cannot combine --data-parallel and distributed mode') |
| 29 | + self.model = torch.nn.DataParallel(self.model) |
| 30 | + |
| 31 | + @classmethod |
| 32 | + def add_cmdline_args(cls, argparser): |
| 33 | + """Add command-line arguments specifically for this agent.""" |
| 34 | + TransformerRankerAgent.add_cmdline_args(argparser) |
| 35 | + return argparser |
| 36 | + |
| 37 | + def build_model(self, states=None): |
| 38 | + self.model = CrossEncoderModule(self.opt, self.dict, self.NULL_IDX) |
| 39 | + return self.model |
| 40 | + |
| 41 | + def vectorize(self, *args, **kwargs): |
| 42 | + """ Add the start and end token to the text. |
| 43 | + """ |
| 44 | + kwargs['add_start'] = True |
| 45 | + kwargs['add_end'] = True |
| 46 | + obs = super().vectorize(*args, **kwargs) |
| 47 | + return obs |
| 48 | + |
| 49 | + def _set_text_vec(self, *args, **kwargs): |
| 50 | + """ Add the start and end token to the text. |
| 51 | + """ |
| 52 | + obs = super()._set_text_vec(*args, **kwargs) |
| 53 | + if 'text_vec' in obs: |
| 54 | + obs['text_vec'] = self._add_start_end_tokens(obs['text_vec'], True, True) |
| 55 | + return obs |
| 56 | + |
| 57 | + def concat_without_padding(self, text_idx, cand_idx, null_idx=0): |
| 58 | + """ if text_idx = [[1, 2, 3, 4, 0, 0 ]] |
| 59 | + and cand_idx = [[5, 6, 7, 8, 0, 0 ]] |
| 60 | + then result = (tokens, segments) where |
| 61 | + tokens = [[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0]] |
| 62 | + segments = [[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]] |
| 63 | + """ |
| 64 | + assert text_idx.size(0) == cand_idx.size(0) |
| 65 | + assert len(text_idx.size()) == 2 |
| 66 | + assert len(cand_idx.size()) == 2 |
| 67 | + segments_idx = [0, 1] |
| 68 | + text_idx = text_idx.cpu() |
| 69 | + cand_idx = cand_idx.cpu() |
| 70 | + cand_len = cand_idx.size(1) |
| 71 | + concat_len = text_idx.size(1) + cand_idx.size(1) |
| 72 | + tokens = text_idx.new_zeros(text_idx.size(0), concat_len) + null_idx |
| 73 | + segments = text_idx.new_zeros(text_idx.size(0), concat_len) + null_idx |
| 74 | + for i in range(len(tokens)): |
| 75 | + non_nuls = torch.sum(text_idx[i, :] != null_idx) |
| 76 | + tokens[i, 0:non_nuls] = text_idx[i, 0:non_nuls] |
| 77 | + segments[i, 0:non_nuls] = segments_idx[0] |
| 78 | + tokens[i, non_nuls : non_nuls + cand_len] = cand_idx[i, :] |
| 79 | + segments[i, non_nuls : non_nuls + cand_len] = segments_idx[1] |
| 80 | + if self.use_cuda: |
| 81 | + tokens = tokens.cuda() |
| 82 | + segments = segments.cuda() |
| 83 | + return tokens, segments |
| 84 | + |
| 85 | + def score_candidates(self, batch, cand_vecs, cand_encs=None): |
| 86 | + if cand_encs is not None: |
| 87 | + raise Exception( |
| 88 | + 'Candidate pre-computation is impossible on the ' 'crossencoder' |
| 89 | + ) |
| 90 | + num_cands_per_sample = cand_vecs.size(1) |
| 91 | + bsz = cand_vecs.size(0) |
| 92 | + text_idx = ( |
| 93 | + batch.text_vec.unsqueeze(1) |
| 94 | + .expand(-1, num_cands_per_sample, -1) |
| 95 | + .contiguous() |
| 96 | + .view(num_cands_per_sample * bsz, -1) |
| 97 | + ) |
| 98 | + cand_idx = cand_vecs.view(num_cands_per_sample * bsz, -1) |
| 99 | + tokens, segments = self.concat_without_padding( |
| 100 | + text_idx, cand_idx, self.NULL_IDX |
| 101 | + ) |
| 102 | + scores = self.model(tokens, segments) |
| 103 | + scores = scores.view(bsz, num_cands_per_sample) |
| 104 | + return scores |
| 105 | + |
| 106 | + |
| 107 | +class CrossEncoderModule(torch.nn.Module): |
| 108 | + """ A simple wrapper around the transformer encoder which adds a linear |
| 109 | + layer. |
| 110 | + """ |
| 111 | + |
| 112 | + def __init__(self, opt, dict, null_idx): |
| 113 | + super(CrossEncoderModule, self).__init__() |
| 114 | + n_positions = get_n_positions_from_options(opt) |
| 115 | + embeddings = torch.nn.Embedding( |
| 116 | + len(dict), opt['embedding_size'], padding_idx=null_idx |
| 117 | + ) |
| 118 | + torch.nn.init.normal_(embeddings.weight, 0, opt['embedding_size'] ** -0.5) |
| 119 | + self.encoder = TransformerEncoder( |
| 120 | + n_heads=opt['n_heads'], |
| 121 | + n_layers=opt['n_layers'], |
| 122 | + embedding_size=opt['embedding_size'], |
| 123 | + ffn_size=opt['ffn_size'], |
| 124 | + vocabulary_size=len(dict), |
| 125 | + embedding=embeddings, |
| 126 | + dropout=opt['dropout'], |
| 127 | + attention_dropout=opt['attention_dropout'], |
| 128 | + relu_dropout=opt['relu_dropout'], |
| 129 | + padding_idx=null_idx, |
| 130 | + learn_positional_embeddings=opt['learn_positional_embeddings'], |
| 131 | + embeddings_scale=opt['embeddings_scale'], |
| 132 | + reduction_type=opt.get('reduction_type', 'first'), |
| 133 | + n_positions=n_positions, |
| 134 | + n_segments=2, |
| 135 | + activation=opt['activation'], |
| 136 | + variant=opt['variant'], |
| 137 | + output_scaling=opt['output_scaling'], |
| 138 | + ) |
| 139 | + self.linear_layer = torch.nn.Linear(opt['embedding_size'], 1) |
| 140 | + |
| 141 | + def forward(self, tokens, segments): |
| 142 | + """ Scores each concatenation text + candidate. |
| 143 | + """ |
| 144 | + encoded = self.encoder(tokens, None, segments) |
| 145 | + res = self.linear_layer(encoded) |
| 146 | + return res |
0 commit comments