From fd8ac2e676f8e0b99dbbb70fa65d6af2ae5beecd Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Thu, 18 Jul 2019 11:54:11 +0200 Subject: [PATCH 01/28] Bert init commit --- onmt/__init__.py | 2 +- onmt/encoders/transformer.py | 14 ++- onmt/inputters/__init__.py | 12 +- onmt/inputters/inputter.py | 89 ++++++++++++-- onmt/model_builder.py | 80 +++++++++++++ onmt/models/__init__.py | 6 +- onmt/models/bert.py | 148 +++++++++++++++++++++++ onmt/models/language_model.py | 151 ++++++++++++++++++++++++ onmt/models/model_saver.py | 75 +++++++++++- onmt/modules/__init__.py | 4 +- onmt/modules/bert_embed.py | 79 +++++++++++++ onmt/modules/position_ffn.py | 19 ++- onmt/train_single.py | 65 ++++++---- onmt/trainer.py | 216 ++++++++++++++++++++++++++++------ onmt/utils/__init__.py | 12 +- onmt/utils/loss.py | 105 +++++++++++++++++ onmt/utils/optimizers.py | 216 +++++++++++++++++++++++++++++++++- onmt/utils/report_manager.py | 2 + onmt/utils/statistics.py | 81 +++++++++++++ train.py | 52 +++++--- 20 files changed, 1307 insertions(+), 121 deletions(-) create mode 100644 onmt/models/bert.py create mode 100644 onmt/models/language_model.py create mode 100644 onmt/modules/bert_embed.py diff --git a/onmt/__init__.py b/onmt/__init__.py index 840e0aefd5..597e35a17f 100644 --- a/onmt/__init__.py +++ b/onmt/__init__.py @@ -2,9 +2,9 @@ from __future__ import division, print_function import onmt.inputters +import onmt.models import onmt.encoders import onmt.decoders -import onmt.models import onmt.utils import onmt.modules from onmt.trainer import Trainer diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 4ebc9eae98..26f3dd7e89 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -4,6 +4,7 @@ import torch.nn as nn +import onmt from onmt.encoders.encoder import EncoderBase from onmt.modules import MultiHeadedAttention from onmt.modules.position_ffn import PositionwiseFeedForward @@ -23,15 +24,20 @@ class TransformerEncoderLayer(nn.Module): """ def __init__(self, d_model, heads, d_ff, dropout, - max_relative_positions=0): + max_relative_positions=0, activation='ReLU', is_bert=False): super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiHeadedAttention( heads, d_model, dropout=dropout, max_relative_positions=max_relative_positions) - self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) - self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, activation, is_bert) + self.layer_norm = onmt.models.BertLayerNorm(d_model, eps=1e-12) if is_bert else nn.LayerNorm(d_model, eps=1e-6) self.dropout = nn.Dropout(dropout) + self.is_bert = is_bert + + def residual(self, output, x): + maybe_norm = self.layer_norm(x) if self.is_bert else x + return output + maybe_norm def forward(self, inputs, mask): """ @@ -47,7 +53,7 @@ def forward(self, inputs, mask): input_norm = self.layer_norm(inputs) context, _ = self.self_attn(input_norm, input_norm, input_norm, mask=mask, type="self") - out = self.dropout(context) + inputs + out = self.residual(self.dropout(context), inputs) return self.feed_forward(out) def update_dropout(self, dropout): diff --git a/onmt/inputters/__init__.py b/onmt/inputters/__init__.py index b10fb85058..10990102df 100644 --- a/onmt/inputters/__init__.py +++ b/onmt/inputters/__init__.py @@ -4,14 +4,14 @@ e.g., from a line of text to a sequence of embeddings. 
""" from onmt.inputters.inputter import \ - load_old_vocab, get_fields, OrderedIterator, \ + load_old_vocab, get_fields, get_bert_fields, OrderedIterator, \ build_vocab, old_style_vocab, filter_example from onmt.inputters.dataset_base import Dataset from onmt.inputters.text_dataset import text_sort_key, TextDataReader from onmt.inputters.image_dataset import img_sort_key, ImageDataReader from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader from onmt.inputters.datareader_base import DataReaderBase - +from onmt.inputters.dataset_bert import BertDataset str2reader = { "text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader} @@ -19,8 +19,8 @@ 'text': text_sort_key, 'img': img_sort_key, 'audio': audio_sort_key} -__all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'DataReaderBase', - 'filter_example', 'old_style_vocab', - 'build_vocab', 'OrderedIterator', - 'text_sort_key', 'img_sort_key', 'audio_sort_key', +__all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'get_bert_fields', + 'DataReaderBase', 'filter_example', 'old_style_vocab', + 'build_vocab', 'OrderedIterator', 'text_sort_key', + 'img_sort_key', 'audio_sort_key', 'BertDataset', 'TextDataReader', 'ImageDataReader', 'AudioDataReader'] diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index f7437810dc..ee7ec86e55 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -137,6 +137,39 @@ def get_fields( return fields +def get_bert_fields(pad='[PAD]', bos='[CLS]', eos='[SEP]', unk='[UNK]'): + fields = {} + # tokens_kwargs = {"n_feats": 0, + # "include_lengths": True, + # "pad": "[PAD]", "bos": "[CLS]", "eos": "[SEP]", + # "truncate": src_truncate, + # "base_name": "tokens"} + # fields["tokens"] = text_fields(**tokens_kwargs) + tokens = Field(sequential=True, use_vocab=True, pad_token=pad, + unk_token=unk, include_lengths=True, batch_first=True) + fields["tokens"] = tokens + + segment_ids = Field(use_vocab=False, dtype=torch.long, + sequential=True, pad_token=0, batch_first=True) + fields["segment_ids"] = segment_ids + + is_next = Field(use_vocab=False, dtype=torch.long, + sequential=False, batch_first=True) # 0/1 + fields["is_next"] = is_next + + # masked_lm_positions = Field(use_vocab=False, dtype=torch.int, + # sequential=False) # indices that masked: [int] + # fields["masked_lm_positions"] = masked_lm_positions + + # masked_lm_labels = Field(use_vocab=True, sequential=False)# tokens masked + # fields["masked_lm_labels"] = masked_lm_labels + + lm_labels_ids = Field(sequential=True, use_vocab=False, + pad_token=-1, batch_first=True) + fields["lm_labels_ids"] = lm_labels_ids + return fields + + def load_old_vocab(vocab, data_type="text", dynamic_dict=False): """Update a legacy vocab/field format. 
@@ -348,6 +381,23 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab, return fields +def _build_bert_fields_vocab(fields, counters, vocab_size, + tokens_min_frequency=0, vocab_size_multiple=1): + tokens_field = fields["tokens"] + tokens_counter = counters["tokens"] + # NOTE: Do not use _build_field_vocab + # as the special tokens is fixed in origin bert vocab file + # _build_field_vocab(tokens_field, tokens_counter, + # size_multiple=vocab_size_multiple, + # max_size=vocab_size, min_freq=tokens_min_frequency) + tokens_field.vocab = tokens_field.vocab_cls(tokens_counter, specials=[], + max_size=vocab_size, min_freq=tokens_min_frequency) + if vocab_size_multiple > 1: + _pad_vocab_to_multiple(tokens_field.vocab, vocab_size_multiple) + + return fields + + def build_vocab(train_dataset_files, fields, data_type, share_vocab, src_vocab_path, src_vocab_size, src_words_min_frequency, tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency, @@ -748,19 +798,32 @@ def max_tok_len(new, count, sofar): such that the total number of src/tgt tokens (including padding) in a batch <= batch_size """ - # Maintains the longest src and tgt length in the current batch - global max_src_in_batch, max_tgt_in_batch # this is a hack - # Reset current longest length at a new batch (count=1) - if count == 1: - max_src_in_batch = 0 - max_tgt_in_batch = 0 - # Src: [ w1 ... wN ] - max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2) - # Tgt: [w1 ... wM ] - max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt[0]) + 1) - src_elements = count * max_src_in_batch - tgt_elements = count * max_tgt_in_batch - return max(src_elements, tgt_elements) + if hasattr(new, 'is_next'): + # when a example has the attr 'is_next', + # this means we are loading Bert Data + # Maintains the longest token length in the current batch + global max_tokens_in_batch + # Reset current longest length at a new batch (count=1) + if count == 1: + max_tokens_in_batch = 0 + # tokens: ['[CLS]', '[MASK]', ..., '[SEP]','This',...,'B','[SEP]'] + max_tokens_in_batch = max(max_tokens_in_batch, len(new.tokens)) + tokens_nelem = count * max_tokens_in_batch + return tokens_nelem + else: + # Maintains the longest src and tgt length in the current batch + global max_src_in_batch, max_tgt_in_batch # this is a hack + # Reset current longest length at a new batch (count=1) + if count == 1: + max_src_in_batch = 0 + max_tgt_in_batch = 0 + # Src: [ w1 ... wN ] + max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2) + # Tgt: [w1 ... 
wM ] + max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt[0]) + 1) + src_elements = count * max_src_in_batch + tgt_elements = count * max_tgt_in_batch + return max(src_elements, tgt_elements) def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False): diff --git a/onmt/model_builder.py b/onmt/model_builder.py index a4cfbe641f..0aa063fa50 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -19,6 +19,9 @@ from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser +from onmt.models import BertLM, BERT +# from onmt.modules.bert_embed import BertEmbeddings + def build_embeddings(opt, text_field, for_encoder=True): """ @@ -223,3 +226,80 @@ def build_model(model_opt, opt, fields, checkpoint): model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint) logger.info(model) return model + + +def build_bert(model_opt, opt, fields, checkpoint): + logger.info('Building BERT model...') + model = build_bertLM(model_opt, fields, use_gpu(opt), checkpoint) + logger.info(model) + return model + + +def build_bertLM(model_opt, fields, gpu, checkpoint=None, gpu_id=None): + """Build a model from opts. + + Args: + model_opt: the option loaded from checkpoint. It's important that + the opts have been updated and validated. See + :class:`onmt.utils.parse.ArgumentParser`. + fields (dict[str, torchtext.data.Field]): + `Field` objects for the model. + gpu (bool): whether to use gpu. + checkpoint: the model generated by train phase, or a resumed snapshot + model from a stopped training. + gpu_id (int or NoneType): Which GPU to use. + + Returns: + the BertLM, composed of Bert with 2 generator heads for 2 task. + """ + # TODO: compability of opt.vocab_size + # Build BertEmbeddings + # tokens_fields = fields['tokens'] + # vocab_size = len(tokens_fields.vocab) + + # Build BertModel(= encoder), BertEmbeddings also built inside Bert. + if gpu and gpu_id is not None: + device = torch.device("cuda", gpu_id) + elif gpu and not gpu_id: + device = torch.device("cuda") + elif not gpu: + device = torch.device("cpu") + bert = build_bert_encoder(model_opt, fields, gpu, checkpoint) + # BertEmbeddings is built inside Bert + # tokens_emb = bert.embeddings + model = BertLM(bert) + + # load states from checkpoints + if checkpoint is not None: + logger.info("load states from checkpoints...") + # TODO: check model.load_state_dict(...) 
+ model.load_state_dict(checkpoint['model'], strict=False) + else: + logger.info("No checkpoint, Initialize Parameters...") + if model_opt.param_init_normal != 0.0: + normal_std = model_opt.param_init_normal + for p in model.parameters(): + p.data.normal_(mean=0, std=normal_std) + elif model_opt.param_init != 0.0: + for p in model.parameters(): + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + elif model_opt.param_init_glorot: + for p in model.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + else: + raise AttributeError("Initialization method haven't be used!") + + model.to(device) + return model + + +def build_bert_encoder(model_opt, fields, gpu, checkpoint=None, gpu_id=None): + # TODO: need to be more elegent + token_fields_vocab = fields['tokens'].vocab + vocab_size = len(token_fields_vocab) + bert = BERT(vocab_size, num_layers=model_opt.layers, + d_model=model_opt.word_vec_size, heads=model_opt.heads, + dropout=model_opt.dropout[0], + max_relative_positions=model_opt.max_relative_positions) + return bert diff --git a/onmt/models/__init__.py b/onmt/models/__init__.py index 20fe4c83b6..8cd5aa41d7 100644 --- a/onmt/models/__init__.py +++ b/onmt/models/__init__.py @@ -1,6 +1,8 @@ """Module defining models.""" from onmt.models.model_saver import build_model_saver, ModelSaver from onmt.models.model import NMTModel +from onmt.models.language_model import BertLM +from onmt.models.bert import BERT, BertLayerNorm -__all__ = ["build_model_saver", "ModelSaver", - "NMTModel", "check_sru_requirement"] +__all__ = ["build_model_saver", "ModelSaver", "NMTModel", "BERT", + "BertLM", "BertLayerNorm", "check_sru_requirement"] diff --git a/onmt/models/bert.py b/onmt/models/bert.py new file mode 100644 index 0000000000..fa56397789 --- /dev/null +++ b/onmt/models/bert.py @@ -0,0 +1,148 @@ +import torch +import torch.nn as nn +from onmt.modules.bert_embed import BertEmbeddings +from onmt.encoders.transformer import TransformerEncoderLayer + + +class BERT(nn.Module): + """ + BERT Implementation: https://arxiv.org/abs/1810.04805 + Use a Transformer Encoder as Language modeling. + """ + def __init__(self, vocab_size, num_layers=12, d_model=768, heads=12, + dropout=0.1, max_relative_positions=0): + super(BERT, self).__init__() + self.vocab_size = vocab_size + self.num_layers = num_layers + self.d_model = d_model # = hidden_size = embed_size + self.heads = heads + self.dropout = dropout + # Feed-Forward size is set to be 4H as in paper + self.d_ff = 4 * d_model + + # Build Embeddings according to vocab_size and d_model + # --DONE--: BERTEmbeddings() + # ref. build_embeddings in onmt.model_builder.py + # BERT input embeddings is sum of: + # 1. Token embeddings + # 2. Segmentation embeddings + # 3. 
Position embeddings + self.embeddings = BertEmbeddings(vocab_size=vocab_size, + embed_size=d_model, dropout=dropout) + + # Transformer Encoder Block + self.transformer_encoder = nn.ModuleList( + [TransformerEncoderLayer(d_model, heads, self.d_ff, dropout, + max_relative_positions=max_relative_positions, + activation='GeLU', is_bert=True) for _ in range(num_layers)]) + + self.layer_norm = BertLayerNorm(d_model, eps=1e-12) + self.pooler = BertPooler(d_model) + # TODO: self.apply(self.init_bert_weight) + + def forward(self, input_ids, token_type_ids=None, input_mask=None, + output_all_encoded_layers=False): + """ + Args: + input_ids: shape [batch, seq] padding ids=0 + token_type_ids: shape [batch, seq], A(0), B(1), pad(0) + input_mask: shape [batch, seq], 1 for masked position(that padding) + output_all_encoded_layers: if out contain all hidden layer + Returns: + all_encoder_layers: list of out in shape (batch, src, d_model) + """ + # # version 1: coder timo waiting for mask of size [B,1,T,T] + # [batch, seq] -> [batch, 1, seq] + # -> [batch, seq, seq] -> [batch, 1, seq, seq] + # attention masking for padded token + # mask: torch.ByteTensor([batch, 1, seq, seq]) + # mask = (input_ids > 0).unsqueeze(1) + # .repeat(1, input_ids.size(1), 1).unsqueeze(1) + # # This version mask 0, different masked_fill in Attention + + # # version 2: hugging face waiting for mask of size [B,1,1,T] + # if attention_mask is None: + # attention_mask = torch.ones_like(input_ids) + # if token_type_ids is None: + # token_type_ids = torch.zeros_like(input_ids) + # # extended_attention_mask.shape = [batch_size, 1, 1, seq_length] + # -> broadcast to [batch, num_heads, from_seq_length, to_seq_length] + # extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + # # for fp16 compatibility + # extended_attention_mask = extended_attention_mask + # .to(dtype=next(self.parameters()).dtype) + # -10000.0 for mask, 0 otherwise + # extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # # version 3: OpenNMT waiting for mask of size [B, 1, T], + # while in MultiHeadAttention part2 -> [B, 1, 1, T] + # TODO: create_attention_mask_from_input_mask + # padding_idx = self.embeddings.word_padding_idx + # mask = input_ids.data.eq(padding_idx).unsqueeze(1) + if input_mask is None: + # input_mask = torch.ones_like(input_ids) + # shape: 2D tensor [batch, seq] + padding_idx = self.embeddings.word_padding_idx + # input_mask = input_ids.data.ne(padding_idx) + # shape: 2D tensor [batch, seq]: 1 for tokens, 0 for paddings + input_mask = input_ids.data.eq(padding_idx) + # if token_type_ids is None: + # NOTE: not needed! 
already done in bert_embed.py + # token_type_ids = torch.zeros_like(input_ids) + # [batch, seq] -> [batch, 1, seq] + attention_mask = input_mask.unsqueeze(1) + + # embedding vectors: [batch, seq, hidden_size] + out = self.embeddings(input_ids, token_type_ids) + + all_encoder_layers = [] + for layer in self.transformer_encoder: + out = layer(out, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(self.layer_norm(out)) + out = self.layer_norm(out) + if not output_all_encoded_layers: + all_encoder_layers.append(out) + pooled_out = self.pooler(out) + return all_encoder_layers, pooled_out + + def update_dropout(self, dropout): + self.dropout = dropout + self.embeddings.update_dropout(dropout) + for layer in self.transformer_encoder: + layer.update_dropout(dropout) + + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation_fn = nn.Tanh() + + def forward(self, hidden_states): + """ + Args: + hidden_states: last layer's hidden_states,(batch, src, d_model) + Returns: + pooled_output: transformed output of last layer's hidden_states + """ + first_token_tensor = hidden_states[:, 0, :] # [batch, d_model] + pooled_output = self.activation_fn(self.dense(first_token_tensor)) + return pooled_output + + +class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style + (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias diff --git a/onmt/models/language_model.py b/onmt/models/language_model.py new file mode 100644 index 0000000000..51575da00d --- /dev/null +++ b/onmt/models/language_model.py @@ -0,0 +1,151 @@ +import torch + +import torch.nn as nn +import onmt +from onmt.utils.fn_activation import GELU + + +class BertLM(nn.Module): + """ + BERT Language Model for pretraining, trained with 2 task : + Next Sentence Prediction Model + Masked Language Model + """ + def __init__(self, bert: onmt.models.BERT): + """ + Args: + bert: BERT model which should be trained + """ + super(BertLM, self).__init__() + self.bert = bert + self.vocab_size = bert.vocab_size + self.cls = BertPreTrainingHeads(self.bert.d_model, self.vocab_size, + self.bert.embeddings.word_embeddings.weight) + + def forward(self, input_ids, token_type_ids, input_mask=None, + output_all_encoded_layers=False): + """ + Args: + input_ids: shape [batch, seq] padding ids=0 + token_type_ids: shape [batch, seq], A(0), B(1), pad(0) + input_mask: shape [batch, seq], 1 for masked position(that padding) + Returns: + seq_class_log_prob: next sentence predi, (batch, 2) + prediction_log_prob: masked lm predi, (batch, seq, vocab) + """ + x, pooled_out = self.bert(input_ids, token_type_ids, input_mask, + output_all_encoded_layers) + seq_class_log_prob, prediction_log_prob = self.cls(x, pooled_out) + return seq_class_log_prob, prediction_log_prob + + +class BertPreTrainingHeads(nn.Module): + """ + Bert Pretraining Heads: Masked Language Models, Next Sentence Prediction + """ + def __init__(self, hidden_size, vocab_size, embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.next_sentence = 
NextSentencePrediction(hidden_size) + self.mask_lm = MaskedLanguageModel(hidden_size, vocab_size, + embedding_weights) + + def forward(self, x, pooled_out): + """ + Args: + x: list of out of all_encoder_layers, shape (batch, seq, d_model) + pooled_output: transformed output of last layer's hidden_states + Returns: + seq_class_log_prob: next sentence prediction, (batch, 2) + prediction_log_prob: masked lm prediction, (batch, seq, vocab) + """ + seq_class_log_prob = self.next_sentence(pooled_out) + prediction_log_prob = self.mask_lm(x[-1]) + return seq_class_log_prob, prediction_log_prob + + +class MaskedLanguageModel(nn.Module): + """ + predicting origin token from masked input sequence + n-class classification problem, n-class = vocab_size + """ + + def __init__(self, hidden_size, vocab_size, + bert_word_embedding_weights=None): + """ + Args: + hidden_size: output size of BERT model + vocab_size: total vocab size + bert_word_embedding_weights: reuse embedding weights if set + """ + super(MaskedLanguageModel, self).__init__() + self.transform = BertPredictionTransform(hidden_size) + self.reuse_emb = (True + if bert_word_embedding_weights is not None + else False) + if self.reuse_emb: # NOTE: reinit ? + assert hidden_size == bert_word_embedding_weights.size(1) + assert vocab_size == bert_word_embedding_weights.size(0) + self.decode = nn.Linear(bert_word_embedding_weights.size(1), + bert_word_embedding_weights.size(0), + bias=False) + self.decode.weight = bert_word_embedding_weights + self.bias = nn.Parameter(torch.zeros(vocab_size)) + else: + self.decode = nn.Linear(hidden_size, vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(vocab_size)) + + self.softmax = nn.LogSoftmax(dim=-1) + + def forward(self, x): + """ + Args: + x: last layer output of bert, shape (batch, seq, d_model) + Returns: + prediction_log_prob: shape (batch, seq, vocab) + """ + x = self.transform(x) # (batch, seq, d_model) + prediction_scores = self.decode(x) + self.bias # (batch, seq, vocab) + prediction_log_prob = self.softmax(prediction_scores) + return prediction_log_prob + + +class NextSentencePrediction(nn.Module): + """ + 2-class classification model : is_next, is_random_next + """ + + def __init__(self, hidden_size): + """ + Args: + hidden_size: BERT model output size + """ + super(NextSentencePrediction, self).__init__() + self.linear = nn.Linear(hidden_size, 2) + self.softmax = nn.LogSoftmax(dim=-1) + + def forward(self, x): + """ + Args: + x: last layer's output of bert encoder, shape (batch, src, d_model) + Returns: + seq_class_prob: shape (batch_size, 2) + """ + seq_relationship_score = self.linear(x) # (batch, 2) + seq_class_log_prob = self.softmax(seq_relationship_score) # (batch, 2) + return seq_class_log_prob + + +class BertPredictionTransform(nn.Module): + def __init__(self, hidden_size): + super(BertPredictionTransform, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = GELU() # get_activation fn + self.layer_norm = onmt.models.BertLayerNorm(hidden_size, eps=1e-12) + + def forward(self, hidden_states): + """ + Args: + hidden_states: BERT model output size (batch, seq, d_model) + """ + hidden_states = self.layer_norm(self.activation( + self.dense(hidden_states))) + return hidden_states diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py index c67b839384..e4a8d10768 100644 --- a/onmt/models/model_saver.py +++ b/onmt/models/model_saver.py @@ -9,12 +9,20 @@ def build_model_saver(model_opt, opt, model, fields, optim): - model_saver = 
ModelSaver(opt.save_model, - model, - model_opt, - fields, - optim, - opt.keep_checkpoint) + if opt.is_bert: + model_saver = BertModelSaver(opt.save_model, + model, + model_opt, + fields, + optim, + opt.keep_checkpoint) + else: + model_saver = ModelSaver(opt.save_model, + model, + model_opt, + fields, + optim, + opt.keep_checkpoint) return model_saver @@ -138,3 +146,58 @@ def _save(self, step, model): def _rm_checkpoint(self, name): os.remove(name) + + +class BertModelSaver(ModelSaverBase): + """Simple model saver to filesystem""" + + def _save(self, step, model): + real_model = (model.module + if isinstance(model, nn.DataParallel) + else model) + # real_generator = (real_model.generator.module + # if isinstance(real_model.generator, nn.DataParallel) + # else real_model.generator) + + model_state_dict = real_model.state_dict() + model_state_dict = {k: v for k, v in model_state_dict.items() + if 'generator' not in k} + # generator_state_dict = real_generator.state_dict() + + # NOTE: We need to trim the vocab to remove any unk tokens that + # were not originally here. + + vocab = deepcopy(self.fields) + for side in ["tokens"]: + keys_to_pop = [] + # if hasattr(vocab[side], "fields"): + # unk_token = vocab[side].fields[0][1].vocab.itos[0] + # for key, value in vocab[side].fields[0][1] + # .vocab.stoi.items(): + # if value == 0 and key != unk_token: + # keys_to_pop.append(key) + # for key in keys_to_pop: + # vocab[side].fields[0][1].vocab.stoi.pop(key, None) + unk_token = vocab[side].unk_token + unk_id = vocab[side].vocab.stoi[unk_token] + for key, value in vocab[side].vocab.stoi.items(): + if value == unk_id and key != unk_token: + keys_to_pop.append(key) + for key in keys_to_pop: + vocab[side].vocab.stoi.pop(key, None) + + checkpoint = { + 'model': model_state_dict, + # 'generator': generator_state_dict, + 'vocab': vocab, + 'opt': self.model_opt, + 'optim': self.optim.state_dict(), + } + + logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) + checkpoint_path = '%s_step_%d.pt' % (self.base_path, step) + torch.save(checkpoint, checkpoint_path) + return checkpoint, checkpoint_path + + def _rm_checkpoint(self, name): + os.remove(name) diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py index 38ff142b47..0bed435ff7 100644 --- a/onmt/modules/__init__.py +++ b/onmt/modules/__init__.py @@ -3,12 +3,12 @@ from onmt.modules.gate import context_gate_factory, ContextGate from onmt.modules.global_attention import GlobalAttention from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention -from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ - CopyGeneratorLossCompute from onmt.modules.multi_headed_attn import MultiHeadedAttention from onmt.modules.embeddings import Embeddings, PositionalEncoding from onmt.modules.weight_norm import WeightNormConv2d from onmt.modules.average_attn import AverageAttention +from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ + CopyGeneratorLossCompute __all__ = ["Elementwise", "context_gate_factory", "ContextGate", "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", diff --git a/onmt/modules/bert_embed.py b/onmt/modules/bert_embed.py new file mode 100644 index 0000000000..2f3ebfa9d4 --- /dev/null +++ b/onmt/modules/bert_embed.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn + + +class TokenEmb(nn.Embedding): + """ Embeddings for tokens. 
+ """ + def __init__(self, vocab_size, hidden_size=768, padding_idx=0): + super(TokenEmb, self).__init__(vocab_size, hidden_size, + padding_idx=padding_idx) + + +class SegmentEmb(nn.Embedding): + """ Embeddings for token's type: sentence A(0), sentence B(1). Padding with 0. + """ + def __init__(self, type_vocab_size=2, hidden_size=768, padding_idx=0): + super(SegmentEmb, self).__init__(type_vocab_size, hidden_size, + padding_idx=padding_idx) + + +class PositionEmb(nn.Embedding): + """ Embeddings for token's position. + """ + def __init__(self, max_position=512, hidden_size=768): + super(PositionEmb, self).__init__(max_position, hidden_size) + + +class BertEmbeddings(nn.Module): + """ BERT input embeddings is sum of: + 1. Token embeddings: called word_embeddings + 2. Segmentation embeddings: called token_type_embeddings + 3. Position embeddings: called position_embeddings + """ + def __init__(self, vocab_size, embed_size, pad_idx=0, dropout=0.1): + """ + Args: + vocab_size: int. Size of the embedding vocabulary. + embed_size: int. Width of the word embeddings. + dropout: dropout rate + """ + super(BertEmbeddings, self).__init__() + self.word_padding_idx = pad_idx + self.word_embeddings = TokenEmb(vocab_size, hidden_size=embed_size, + padding_idx=pad_idx) + self.position_embeddings = PositionEmb(512, hidden_size=embed_size) + self.token_type_embeddings = SegmentEmb(2, hidden_size=embed_size, + padding_idx=pad_idx) + + self.dropout = nn.Dropout(dropout) + + def forward(self, input_ids, token_type_ids=None): + """ + Args: + input_ids: word ids in shape [batch, seq, hidden_size]. + token_type_ids: token type ids in shape [batch, seq]. + Output: + embeddings: word embeds in shape [batch, seq, hidden_size]. + """ + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=input_ids.device) # [0, 1,..., seq_length-1] + # [[0,1,...,seq_length-1]] -> [[0,1,...,seq_length-1] *batch_size] + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + word_embeds = self.word_embeddings(input_ids) + position_embeds = self.position_embeddings(position_ids) + token_type_embeds = self.token_type_embeddings(token_type_ids) + + embeddings = word_embeds + position_embeds + token_type_embeds + # in our version, LN is done in EncoderLayer before fed into Attention + # embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def update_dropout(self, dropout): + self.dropout.p = dropout diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index fb8df80aa7..3a095d31ba 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -2,6 +2,9 @@ import torch.nn as nn +import onmt +from onmt.utils.fn_activation import GELU + class PositionwiseFeedForward(nn.Module): """ A two-layer Feed-Forward-Network with residual layer norm. @@ -13,14 +16,22 @@ class PositionwiseFeedForward(nn.Module): dropout (float): dropout probability in :math:`[0, 1)`. 
""" - def __init__(self, d_model, d_ff, dropout=0.1): + def __init__(self, d_model, d_ff, dropout=0.1, + activation='ReLU', is_bert=False): super(PositionwiseFeedForward, self).__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) - self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.layer_norm = (onmt.models.BertLayerNorm(d_model, eps=1e-12) + if is_bert + else nn.LayerNorm(d_model, eps=1e-6)) self.dropout_1 = nn.Dropout(dropout) - self.relu = nn.ReLU() + self.relu = GELU() if activation == 'GeLU' else nn.ReLU() self.dropout_2 = nn.Dropout(dropout) + self.is_bert = is_bert + + def residual(self, output, x): + maybe_norm = self.layer_norm(x) if self.is_bert else x + return output + maybe_norm def forward(self, x): """Layer definition. @@ -34,7 +45,7 @@ def forward(self, x): inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) output = self.dropout_2(self.w_2(inter)) - return output + x + return self.residual(output, x) def update_dropout(self, dropout): self.dropout_1.p = dropout diff --git a/onmt/train_single.py b/onmt/train_single.py index e65002b9e9..26bb8193d9 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -6,7 +6,7 @@ from onmt.inputters.inputter import build_dataset_iter, \ load_old_vocab, old_style_vocab, build_dataset_iter_multiple -from onmt.model_builder import build_model +from onmt.model_builder import build_model, build_bert from onmt.utils.optimizers import Optimizer from onmt.utils.misc import set_random_seed from onmt.trainer import build_trainer @@ -51,11 +51,13 @@ def main(opt, device_id, batch_queue=None, semaphore=None): logger.info('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) - model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) - ArgumentParser.update_model_opts(model_opt) - ArgumentParser.validate_model_opts(model_opt) + # model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) + model_opt = opt # TODO: test + # ArgumentParser.update_model_opts(model_opt) # TODO + # ArgumentParser.validate_model_opts(model_opt) # TODO logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) - vocab = checkpoint['vocab'] + # vocab = checkpoint['vocab'] + vocab = torch.load(opt.data + '.vocab.pt') # TODO else: checkpoint = None model_opt = opt @@ -63,39 +65,54 @@ def main(opt, device_id, batch_queue=None, semaphore=None): # check for code where vocab is saved instead of fields # (in the future this will be done in a smarter way) - if old_style_vocab(vocab): + if opt.is_bert: # TODO: test amelioration + fields = vocab + elif old_style_vocab(vocab): fields = load_old_vocab( vocab, opt.model_type, dynamic_dict=opt.copy_attn) else: fields = vocab - # Report src and tgt vocab sizes, including for features - for side in ['src', 'tgt']: - f = fields[side] - try: - f_iter = iter(f) - except TypeError: - f_iter = [(side, f)] - for sn, sf in f_iter: - if sf.use_vocab: - logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab))) + if opt.is_bert: + # Report bert tokens vocab sizes, including for features + f = fields['tokens'] + if f.use_vocab: # NOTE: useless! + logger.info(' * %s vocab size = %d' % ("BERT", len(f.vocab))) + else: + # Report src and tgt vocab sizes, including for features + for side in ['src', 'tgt']: + f = fields[side] + try: + f_iter = iter(f) + except TypeError: + f_iter = [(side, f)] + for sn, sf in f_iter: + if sf.use_vocab: + logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab))) # Build model. 
- model = build_model(model_opt, opt, fields, checkpoint) - n_params, enc, dec = _tally_parameters(model) - logger.info('encoder: %d' % enc) - logger.info('decoder: %d' % dec) - logger.info('* number of parameters: %d' % n_params) + if opt.is_bert: + model = build_bert(model_opt, opt, fields, checkpoint) + n_params = 0 + for param in model.parameters(): + n_params += param.nelement() + logger.info('* number of parameters: %d' % n_params) + else: + model = build_model(model_opt, opt, fields, checkpoint) + n_params, enc, dec = _tally_parameters(model) + logger.info('encoder: %d' % enc) + logger.info('decoder: %d' % dec) + logger.info('* number of parameters: %d' % n_params) _check_save_model_path(opt) - # Build optimizer. - optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint) + # Build optimizer. # TODO: checkpoint=checkpoint # DEBUG + optim = Optimizer.from_opt(model, opt, checkpoint=None) # Build model saver model_saver = build_model_saver(model_opt, opt, model, fields, optim) trainer = build_trainer( - opt, device_id, model, fields, optim, model_saver=model_saver) + opt, device_id, model, fields, optim, model_saver=model_saver) if batch_queue is None: if len(opt.data_ids) > 1: diff --git a/onmt/trainer.py b/onmt/trainer.py index 7200adf12e..caf1886ef1 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -31,11 +31,14 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object used to save the model """ - - tgt_field = dict(fields)["tgt"].base_field - train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) - valid_loss = onmt.utils.loss.build_loss_compute( - model, tgt_field, opt, train=False) + if opt.is_bert: + train_loss = onmt.utils.loss.build_bert_loss_compute(opt) + valid_loss = onmt.utils.loss.build_bert_loss_compute(opt, train=False) + else: + tgt_field = dict(fields)["tgt"].base_field + train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) + valid_loss = onmt.utils.loss.build_loss_compute( + model, tgt_field, opt, train=False) trunc_size = opt.truncated_decoder # Badly named... shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0 @@ -131,13 +134,19 @@ def __init__(self, model, train_loss, valid_loss, optim, self.earlystopper = earlystopper self.dropout = dropout self.dropout_steps = dropout_steps - + self.is_bert = True if isinstance(self.model, + onmt.models.language_model.BertLM) else False # NOTE: NEW parameter for bert training + for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 if self.accum_count_l[i] > 1: assert self.trunc_size == 0, \ """To enable accumulated gradients, you must disable target sequence truncating.""" + + if self.is_bert: + assert self.trunc_size == 0 + """ Bert currently not support target sequence truncating""" # TODO # Set model in training mode. 
self.model.train() @@ -148,7 +157,7 @@ def _accum_count(self, step): _accum = self.accum_count_l[i] return _accum - def _maybe_update_dropout(self, step): + def _maybe_update_dropout(self, step): # TODO: to be test with Bert for i in range(len(self.dropout_steps)): if step > 1 and step == self.dropout_steps[i] + 1: self.model.update_dropout(self.dropout[i]) @@ -161,12 +170,15 @@ def _accum_batches(self, iterator): self.accum_count = self._accum_count(self.optim.training_step) for batch in iterator: batches.append(batch) - if self.norm_method == "tokens": - num_tokens = batch.tgt[1:, :, 0].ne( - self.train_loss.padding_idx).sum() - normalization += num_tokens.item() + if self.is_bert: + normalization += 1 else: - normalization += batch.batch_size + if self.norm_method == "tokens": + num_tokens = batch.tgt[1:, :, 0].ne( + self.train_loss.padding_idx).sum() + normalization += num_tokens.item() + else: + normalization += batch.batch_size if len(batches) == self.accum_count: yield batches, normalization self.accum_count = self._accum_count(self.optim.training_step) @@ -215,9 +227,12 @@ def train(self, else: logger.info('Start training loop and validate every %d steps...', valid_steps) - - total_stats = onmt.utils.Statistics() - report_stats = onmt.utils.Statistics() + if self.is_bert: + total_stats = onmt.utils.BertStatistics() + report_stats = onmt.utils.BertStatistics() + else: + total_stats = onmt.utils.Statistics() + report_stats = onmt.utils.Statistics() self._start_report_manager(start_time=total_stats.start_time) for i, (batches, normalization) in enumerate( @@ -233,15 +248,22 @@ def train(self, n_minibatch %d" % (self.gpu_rank, i + 1, len(batches))) - if self.n_gpu > 1: - normalization = sum(onmt.utils.distributed - .all_gather_list - (normalization)) + if self.n_gpu > 1: # NOTE: DEBUG + list_norm = onmt.utils.distributed.all_gather_list(normalization) + # current_rank = torch.distributed.get_rank() + # print("-> RANK: {}".format(current_rank)) + # print(list_norm) + normalization = sum(list_norm) - self._gradient_accumulation( + # Training Step: Forward -> compute Loss -> optimize + if self.is_bert: + self._bert_gradient_accumulation(batches, normalization, total_stats, report_stats) + else: + self._gradient_accumulation( batches, normalization, total_stats, report_stats) - + + # Moving average if self.average_decay > 0 and i % self.average_every == 0: self._update_average(step) @@ -249,7 +271,10 @@ def train(self, step, train_steps, self.optim.learning_rate(), report_stats) - + # NOTE: DEBUG + # exit() + + # Part: validation if valid_iter is not None and step % valid_steps == 0: if self.gpu_verbose_level > 0: logger.info('GpuRank %d: validate step %d' @@ -302,22 +327,41 @@ def validate(self, valid_iter, moving_average=None): valid_model.eval() with torch.no_grad(): - stats = onmt.utils.Statistics() - - for batch in valid_iter: - src, src_lengths = batch.src if isinstance(batch.src, tuple) \ - else (batch.src, None) - tgt = batch.tgt - - # F-prop through the model. - outputs, attns = valid_model(src, tgt, src_lengths) + # TODO:if not Bert + if self.is_bert: + stats = onmt.utils.BertStatistics() + for batch in valid_iter: + # input_ids: Size([batch_size, max_seq_length_in_batch]), seq_lengths: Size([batch_size]) + input_ids, seq_lengths = batch.tokens if isinstance(batch.tokens, tuple) \ + else (batch.tokens, None) + # segment_ids, lm_labels_ids: Size([batch_size, max_seq_length_in_batch]), is_next: Size([batch_size]) + token_type_ids = batch.segment_ids # 0 for sens A, 1 for sens B. 
0 padding + is_next = batch.is_next + lm_labels_ids = batch.lm_labels_ids # -1 padding, others for predict in lm task + # F-prop through the model. # NOTE: keyword args: input_mask, output_all_encoded_layers + seq_class_log_prob, prediction_log_prob = valid_model(input_ids, token_type_ids) + outputs = (seq_class_log_prob, prediction_log_prob) + # Compute loss. + _, batch_stats = self.valid_loss(batch, outputs) + + # Update statistics. + stats.update(batch_stats) + else: + stats = onmt.utils.Statistics() + for batch in valid_iter: + src, src_lengths = batch.src if isinstance(batch.src, tuple) \ + else (batch.src, None) + tgt = batch.tgt - # Compute loss. - _, batch_stats = self.valid_loss(batch, outputs, attns) + # F-prop through the model. + outputs, attns = valid_model(src, tgt, src_lengths) - # Update statistics. - stats.update(batch_stats) + # Compute loss. + _, batch_stats = self.valid_loss(batch, outputs, attns) + # Update statistics. + stats.update(batch_stats) + if moving_average: del valid_model else: @@ -454,3 +498,105 @@ def _report_step(self, learning_rate, step, train_stats=None, return self.report_manager.report_step( learning_rate, step, train_stats=train_stats, valid_stats=valid_stats) + + + def _bert_gradient_accumulation(self, true_batches, normalization, total_stats, + report_stats): + if self.accum_count > 1: + self.optim.zero_grad() + + for k, batch in enumerate(true_batches): + # target_size = batch.tgt.size(0) + # NOTE: for batch in BERT : batch_first is True -> [batch, seq, vocab] + # # Truncated BPTT: reminder not compatible with accum > 1 + # if self.trunc_size: # TODO + # trunc_size = self.trunc_size + # else: + # trunc_size = target_size + + input_ids, seq_lengths = batch.tokens if isinstance(batch.tokens, tuple) \ + else (batch.tokens, None) + if seq_lengths is not None: + report_stats.n_src_words += seq_lengths.sum().item() + + # tgt_outer = batch.tgt + token_type_ids = batch.segment_ids + is_next = batch.is_next + lm_labels_ids = batch.lm_labels_ids + # bptt = False + # TODO: to be removed, as not support bptt yet! + # for j in range(0, target_size-1, trunc_size): + # 1. Create truncated target. + # tgt = tgt_outer[j: j + trunc_size] + + # 2. F-prop all to get log likelihood of two task. + if self.accum_count == 1: + self.optim.zero_grad() + seq_class_log_prob, prediction_log_prob = self.model(input_ids, token_type_ids) + # NOTE: (batch_size, 2), (batch_size, seq_size, vocab_size) + outputs = (seq_class_log_prob, prediction_log_prob) + # outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt) + # bptt = True + + # 3. Compute loss. + try: # NOTE: unuse normalisation + loss, batch_stats = self.train_loss(batch, outputs) + # NOTE: DEBUG + # loss_list = onmt.utils.distributed.all_gather_list(loss) + # current_rank = torch.distributed.get_rank() + # print("{}-> RANK: {}, loss:{} in {}".format( + # k, current_rank, loss, loss_list)) + # print("{}-> RANK: {}, stat:{}".format( + # k, current_rank, batch_stats.loss)) + # print(str(loss) + " ~ " +str(loss_list)) + + if loss is not None: + self.optim.backward(loss) + + total_stats.update(batch_stats) + report_stats.update(batch_stats) + # print(str(loss.item())+ " - " + str(report_stats.loss)) + # exit() + except Exception: + traceback.print_exc() + logger.info("At step %d, we removed a batch - accum %d", + self.optim.training_step, k) + + # 4. Update the parameters and statistics. 
+ if self.accum_count == 1: + # Multi GPU gradient gather + if self.n_gpu > 1: + grads = [p.grad.data for p in self.model.parameters() + if p.requires_grad + and p.grad is not None] + # current_rank = torch.distributed.get_rank() + # print("{}-> RANK: {}, grads BEFORE:{}".format( + # k, current_rank, grads[0])) + # NOTE: average the gradient across the GPU + onmt.utils.distributed.all_reduce_and_rescale_tensors( + grads, float(self.n_gpu)) + # reduced_grads = [p.grad.data for p in self.model.parameters() + # if p.requires_grad + # and p.grad is not None] + # print("{}-> RANK: {}, grads AFTER:{}".format( + # k, current_rank, reduced_grads[0])) + self.optim.step() + + # If truncated, don't backprop fully. + # TO CHECK + # if dec_state is not None: + # dec_state.detach() + # if self.model.decoder.state is not None: # TODO: ?? + # self.model.decoder.detach_state() + + # in case of multi step gradient accumulation, + # update only after accum batches + if self.accum_count > 1: + if self.n_gpu > 1: + grads = [p.grad.data for p in self.model.parameters() + if p.requires_grad + and p.grad is not None] + # NOTE: average the gradient across the GPU + onmt.utils.distributed.all_reduce_and_rescale_tensors( + grads, float(self.n_gpu)) + self.optim.step() diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py index 55dae40872..a1f422333a 100644 --- a/onmt/utils/__init__.py +++ b/onmt/utils/__init__.py @@ -1,12 +1,14 @@ """Module defining various utilities.""" + from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed from onmt.utils.report_manager import ReportMgr, build_report_manager -from onmt.utils.statistics import Statistics +from onmt.utils.statistics import Statistics, BertStatistics from onmt.utils.optimizers import MultipleOptimizer, \ - Optimizer, AdaFactor + Optimizer, AdaFactor, BertAdam from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts + __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", - "build_report_manager", "Statistics", - "MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping", - "scorers_from_opts"] + "build_report_manager", "Statistics", "BertStatistics", + "MultipleOptimizer", "Optimizer", "AdaFactor", "BertAdam", + "EarlyStopping", "scorers_from_opts"] diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index f7f8cf6586..f5abd9034d 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -56,6 +56,111 @@ def build_loss_compute(model, tgt_field, opt, train=True): return compute +def build_bert_loss_compute(opt, train=True): + """FOR BERT PRETRAINING. + Returns a LossCompute subclass which wraps around an nn.Module subclass + (such as nn.NLLLoss) which defines the loss criterion. + """ + device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") + # BERT use -1 for unmasked token in lm_label_ids + criterion = nn.NLLLoss(ignore_index=-1, reduction='mean') + compute = BertLoss(criterion).to(device) + return compute + + +class BertLoss(nn.Module): + def __init__(self, criterion): + super(BertLoss, self).__init__() + self.criterion = criterion + + @property + def padding_idx(self): + return self.criterion.ignore_index + + def _bottle(self, _v): + return _v.view(-1, _v.size(2)) + + def _stats(self, loss, mlm_scores, mlm_target, + nx_sent_scores, nx_sent_target): + """ + Args: + loss (:obj:`FloatTensor`): the loss computed by the loss criterion. 
+ scores (:obj:`FloatTensor`): a score for each possible output + target (:obj:`FloatTensor`): true targets + + Returns: + :obj:`onmt.utils.Statistics` : statistics for this batch. + """ + # masked lm task + pred_mlm = mlm_scores.argmax(1) # (batch*seq, vocab) -> (batch*seq) + non_padding = mlm_target.ne(self.padding_idx) # mask: (batch*seq) + mlm_match = pred_mlm.eq(mlm_target).masked_select(non_padding) + num_correct = mlm_match.sum().item() + num_non_padding = non_padding.sum().item() + + # next sentence prediction task + pred_nx_sent = nx_sent_scores.argmax(-1) # (batch_size, 2) -> (2) + num_correct_nx_sent = nx_sent_target.eq(pred_nx_sent).sum().item() + num_sentence = len(nx_sent_target) + # print("lm: {}/{}".format(num_correct, num_non_padding)) + # print("nx: {}/{}".format(num_correct_nx_sent, num_sentence)) + return onmt.utils.BertStatistics(loss.item(), num_non_padding, + num_correct, num_sentence, + num_correct_nx_sent) + + # TODO: currently not support trunc_size & shard_size + # def _make_shard_state(self, batch, output): + # return { + # "output": output, + # "target": batch.tgt[range_[0] + 1: range_[1], :, 0], + # } + + # def _compute_loss(self, batch, output, target): + # bottled_output = self._bottle(output) + + # scores = self.generator(bottled_output) + # gtruth = target.view(-1) + + # loss = self.criterion(scores, gtruth) + # stats = self._stats(loss.clone(), scores, gtruth) + + # return loss, stats + + def forward(self, batch, outputs, normalization=1.0): # TODO: shard=0 + """ + Args: + batch: batch of examples + outputs: tuple of log proba for next sentense & lm + (seq_class_log_prob:(batch, 2), + prediction_log_prob:(batch, seq, vocab)) + """ + assert isinstance(outputs, tuple) + seq_class_log_prob, prediction_log_prob = outputs + assert list(seq_class_log_prob.size()) == [len(batch), 2] + + gtruth_next_sentence = batch.is_next # (batch,) + gtruth_masked_lm = batch.lm_labels_ids # (batch, seq) + # (batch, seq, vocab) -> (batch * seq, vocab) + bottled_prediction_log_prob = self._bottle(prediction_log_prob) + bottled_gtruth_masked_lm = gtruth_masked_lm.view(-1) # (batch * seq) + # loss mean by number of sentence + next_loss = self.criterion(seq_class_log_prob, gtruth_next_sentence) + # loss mean by number of masked token + mask_loss = self.criterion(bottled_prediction_log_prob, + bottled_gtruth_masked_lm) + total_loss = next_loss + mask_loss # total_loss reduced by mean + # loss_accum_normalized = total_loss #/ float(normalization) + # print("loss: ({} + {})/{} = {}".format(next_loss, mask_loss, + # float(normalization), loss_accum_normalized)) + # print("nx: {}/{}".format(num_correct_nx_sent, num_sentence)) + stats = self._stats(total_loss.clone(), + bottled_prediction_log_prob, + bottled_gtruth_masked_lm, + seq_class_log_prob, + gtruth_next_sentence) + return total_loss, stats + + class LossComputeBase(nn.Module): """ Class for managing efficient loss computation. 
Handles diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index 59d9ff4c0c..3d2ed3bd7b 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -5,7 +5,7 @@ import operator import functools from copy import copy -from math import sqrt +from math import sqrt, cos, pi def build_torch_optimizer(model, opt): @@ -52,6 +52,12 @@ def build_torch_optimizer(model, opt): lr=opt.learning_rate, betas=betas, eps=1e-9) + elif opt.optim == 'bertadam': # TODO:to be verified + optimizer = BertAdam( + params, + lr=opt.learning_rate, + betas=betas, + eps=1e-9) elif opt.optim == 'sparseadam': dense = [] sparse = [] @@ -111,6 +117,33 @@ def make_learning_rate_decay_fn(opt): rate=opt.learning_rate_decay, decay_steps=opt.decay_steps, start_step=opt.start_decay_steps) + elif opt.decay_method == 'linear': + return functools.partial( + linear_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps) + elif opt.decay_method == 'linearconst': + return functools.partial( + linear_decay, + warmup_steps=opt.warmup_steps) + elif opt.decay_method == 'cosine': + return functools.partial( + cosine_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps, + cycles=opt.cycles if opt.cycles is not None else 0.5) + elif opt.decay_method == 'cosine_hard_restart': + return functools.partial( + cosine_hard_restart_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps, + cycles=opt.cycles if opt.cycles is not None else 1.0) + elif opt.decay_method == 'cosine_warmup_restart': + return functools.partial( + cosine_warmup_restart_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps, + cycles=opt.cycles if opt.cycles is not None else 1.0) elif opt.decay_method == 'rsqrt': return functools.partial( rsqrt_decay, warmup_steps=opt.warmup_steps) @@ -153,6 +186,85 @@ def rsqrt_decay(step, warmup_steps): return 1.0 / sqrt(max(step, warmup_steps)) +def linear_decay(step, warmup_steps, total_steps): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, linearly decrease the lr from 1 to 0 over (warmup_steps, train_step) + """ + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if step > total_steps: + raise ValueError("Invalid step: step surpass train_steps!") + if step < warmup_steps: + return step / warmup_steps * 1.0 + else: + return max((total_steps - step) / (total_steps - warmup_steps), 0) + + +def linear_constant_decay(step, warmup_steps): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, keep constant. + """ + if step < warmup_steps: + return step / warmup_steps * 1.0 + return 1.0 + + +def cosine_decay(step, warmup_steps, total_steps, cycles=0.5): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, decrease lr from 1 to 0 over (warmup_steps, train_step) + following cosine curve. + """ + + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if step > total_steps: + raise ValueError("Invalid step: step surpass train_steps!") + if step < warmup_steps: + return step / warmup_steps * 1.0 + else: + progress = (step - warmup_steps) / (total_steps - warmup_steps) + return 0.5 * (1 + cos(pi * cycles * 2 * progress)) + + +def cosine_hard_restart_decay(step, warmup_steps, total_steps, cycles=1.0): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, decrease the lr from 1 over (warmup_steps, train_step) + following cosine curve. 
+ If `cycles` is different from default(1.0), learning rate follows + `cycles` times a cosine decaying learning rate (with hard restarts). + """ + assert(cycles >= 1.0) + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if step > total_steps: + raise ValueError("Invalid step: step surpass train_steps!") + if step < warmup_steps: + return step / warmup_steps * 1.0 + else: + progress = (step - warmup_steps) / (total_steps - warmup_steps) + return 0.5 * (1 + cos(pi * ((cycles * progress) % 1))) + + +def cosine_warmup_restart_decay(step, warmup_steps, total_steps, cycles=1.0): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, decrease the lr from 1 to 0 over (warmup_steps, train_step) + following cosine curve. + """ + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if step > total_steps: + raise ValueError("Invalid step: step surpass train_steps!") + if not cycles * warmup_steps / total_steps < 1.0: + raise ValueError("Invalid decay: Error for decay! Check cycles!") + warmup_ratio = warmup_steps * cycles / total_steps + progress = (step * cycles / total_steps) % 1 + if progress < warmup_ratio: + return progress / warmup_ratio + else: + progress = (progress - warmup_ratio) / (1 - warmup_ratio) + return 0.5 * (1 + cos(pi * progress)) + + class MultipleOptimizer(object): """ Implement multiple optimizers needed for sparse adam """ @@ -513,3 +625,105 @@ def step(self, closure=None): p.data.add_(-group['weight_decay'] * lr_t, p.data) return loss + + +class BertAdam(torch.optim.Optimizer): + """Implements BERT version of Adam algorithm with weight decay fix + (while doesn't compensate for bias). + Ref: https://arxiv.org/abs/1711.05101 + Params: + lr: learning rate + betas: Adam betas(beta1, beta2). Default: (0.9, 0.999) + eps: Adams epsilon. Default: 1e-6 + weight_decay: Weight decay. Default: 0.01 + # TODO: exclude LayerNorm from weight decay? + max_grad_norm: Maximum norm for the gradients (-1 means no clipping). + """ + def __init__(self, params, lr=None, betas=(0.9, 0.999), + eps=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr) + + " - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid betas[0] parameter: {}".format( + betas[0]) + " - should be in [0.0, 1.0)") + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid betas[1] parameter: {}".format( + betas[1]) + " - should be in [0.0, 1.0)") + if not eps >= 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps) + + " - should be >= 0.0") + defaults = dict(lr=lr, betas=betas, + eps=eps, weight_decay=weight_decay, + max_grad_norm=max_grad_norm) + super(BertAdam, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam : not support sparse gradients,' + + 'please consider SparseAdam instead') + + state = self.state[p] + + # State initialization + if len(state) == 0: + # state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # NOTE: Add grad clipping, DONE before step function + # if group['max_grad_norm'] > 0: + # clip_grad_norm_(p, group['max_grad_norm']) + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + # exp_avg = exp_avg * beta1 + (1-beta1)*grad + exp_avg.mul_(beta1).add_(1 - beta1, grad) + # exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2)*grad**2 + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + update = exp_avg / (exp_avg_sq.sqrt() + group['eps']) + + # ref: https://arxiv.org/abs/1711.05101 + # Just adding the square of the weights to the loss function + # is *not* the correct way of using L2/weight decay with Adam, + # since it will interact with m/v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't + # interact with the m/v. This is equivalent to add the square + # of the weights to the loss with plain (non-momentum) SGD. + if group['weight_decay'] > 0.0: + update += group['weight_decay'] * p.data + + lr_scheduled = group['lr'] + + update_with_lr = lr_scheduled * update + p.data.add_(-update_with_lr) + + # state['step'] += 1 + + # NOTE: BertAdam "No bias correction" comparing to standard + # bias_correction1 = 1 - betas[0] ** state['step'] + # bias_correction2 = 1 - betas[1] ** state['step'] + # step_size = lr_scheduled * math.sqrt(bias_correction2) + # / bias_correction1 + + return loss diff --git a/onmt/utils/report_manager.py b/onmt/utils/report_manager.py index 8fea78c444..06a46e1e9b 100644 --- a/onmt/utils/report_manager.py +++ b/onmt/utils/report_manager.py @@ -76,6 +76,8 @@ def report_training(self, step, num_steps, learning_rate, self._report_training( step, num_steps, learning_rate, report_stats) self.progress_step += 1 + if isinstance(report_stats, onmt.utils.BertStatistics): + return onmt.utils.BertStatistics() return onmt.utils.Statistics() else: return report_stats diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 896d98c74d..a79bb56ad7 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -134,3 +134,84 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) writer.add_scalar(prefix + "/lr", learning_rate, step) + + +class BertStatistics(Statistics): + """ Bert Statistics as the loss is reduced by mean """ + def __init__(self, loss=0, n_words=0, n_correct=0, + n_sentence=0, n_correct_nx_sentence=0): + super(BertStatistics, self).__init__(loss, n_words, n_correct) + self.n_update = 0 if n_words == 0 else 1 + self.n_sentence = n_sentence + self.n_correct_nx_sentence = n_correct_nx_sentence + + def next_sentence_accuracy(self): + """ compute accuracy """ + return 100 * 
(self.n_correct_nx_sentence / self.n_sentence) + + def xent(self): + """ compute cross entropy """ + return self.loss + + def ppl(self): + """ compute perplexity """ + return math.exp(min(self.loss, 100)) + + def update(self, stat, update_n_src_words=False): + """ + Update statistics by suming values with another `Statistics` object + + Args: + stat: another statistic object + update_n_src_words(bool): whether to update (sum) `n_src_words` + or not + + """ + assert isinstance(stat, BertStatistics) + self.loss = (self.loss * self.n_update + stat.loss * + stat.n_update) / (self.n_update + stat.n_update) + self.n_update += 1 + self.n_words += stat.n_words + self.n_correct += stat.n_correct + self.n_sentence += stat.n_sentence + self.n_correct_nx_sentence += stat.n_correct_nx_sentence + + if update_n_src_words: + self.n_src_words += stat.n_src_words + + def output(self, step, num_steps, learning_rate, start): + """Write out statistics to stdout. + + Args: + step (int): current step + n_batch (int): total batches + start (int): start time of step. + """ + t = self.elapsed_time() + step_fmt = "%2d" % step + if num_steps > 0: + step_fmt = "%s/%5d" % (step_fmt, num_steps) + logger.info( + ("Step %s; acc(mlm/nx):%6.2f/%6.2f; total ppl: %5.2f; " + + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + % (step_fmt, + self.accuracy(), + self.next_sentence_accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) + sys.stdout.flush() + + def log_tensorboard(self, prefix, writer, learning_rate, step): + """ display statistics to tensorboard """ + t = self.elapsed_time() + writer.add_scalar(prefix + "/xent", self.xent(), step) + writer.add_scalar(prefix + "/ppl", self.ppl(), step) + writer.add_scalar(prefix + "/accuracy(mlm)", self.accuracy(), step) + writer.add_scalar(prefix + "/accuracy(nx)", + self.next_sentence_accuracy(), step) + writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) + writer.add_scalar(prefix + "/lr", learning_rate, step) diff --git a/train.py b/train.py index d00f161a91..adfb1e8dcc 100755 --- a/train.py +++ b/train.py @@ -4,7 +4,7 @@ import signal import torch -import onmt.opts as opts +import onmt.opts_bert as opts import onmt.utils.distributed from onmt.utils.misc import set_random_seed @@ -18,9 +18,10 @@ def main(opt): - ArgumentParser.validate_train_opts(opt) - ArgumentParser.update_model_opts(opt) - ArgumentParser.validate_model_opts(opt) + # JUST FOR verify the options + # ArgumentParser.validate_train_opts(opt) + # ArgumentParser.update_model_opts(opt) + # ArgumentParser.validate_model_opts(opt) # Load checkpoint if we resume from a previous training. if opt.train_from: @@ -28,13 +29,16 @@ def main(opt): checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) logger.info('Loading vocab from checkpoint at %s.' 
% opt.train_from) - vocab = checkpoint['vocab'] + # vocab = checkpoint['vocab'] TODO:test + vocab = torch.load(opt.data + '.vocab.pt') else: vocab = torch.load(opt.data + '.vocab.pt') # check for code where vocab is saved instead of fields # (in the future this will be done in a smarter way) - if old_style_vocab(vocab): + if opt.is_bert: + fields = vocab + elif old_style_vocab(vocab): fields = load_old_vocab( vocab, opt.model_type, dynamic_dict=opt.copy_attn) else: @@ -114,17 +118,27 @@ def next_batch(device_id): for device_id, q in cycle(enumerate(queues)): b.dataset = None - if isinstance(b.src, tuple): - b.src = tuple([_.to(torch.device(device_id)) - for _ in b.src]) + if opt.is_bert: + if isinstance(b.tokens, tuple): + b.tokens = tuple([_.to(torch.device(device_id)) + for _ in b.tokens]) + else: + b.tokens = b.tokens.to(torch.device(device_id)) + b.segment_ids = b.segment_ids.to(torch.device(device_id)) + b.is_next = b.is_next.to(torch.device(device_id)) + b.lm_labels_ids = b.lm_labels_ids.to(torch.device(device_id)) else: - b.src = b.src.to(torch.device(device_id)) - b.tgt = b.tgt.to(torch.device(device_id)) - b.indices = b.indices.to(torch.device(device_id)) - b.alignment = b.alignment.to(torch.device(device_id)) \ - if hasattr(b, 'alignment') else None - b.src_map = b.src_map.to(torch.device(device_id)) \ - if hasattr(b, 'src_map') else None + if isinstance(b.src, tuple): + b.src = tuple([_.to(torch.device(device_id)) + for _ in b.src]) + else: + b.src = b.src.to(torch.device(device_id)) + b.tgt = b.tgt.to(torch.device(device_id)) + b.indices = b.indices.to(torch.device(device_id)) + b.alignment = b.alignment.to(torch.device(device_id)) \ + if hasattr(b, 'alignment') else None + b.src_map = b.src_map.to(torch.device(device_id)) \ + if hasattr(b, 'src_map') else None # hack to dodge unpicklable `dict_keys` b.fields = list(b.fields) @@ -188,8 +202,10 @@ def _get_parser(): parser = ArgumentParser(description='train.py') opts.config_opts(parser) - opts.model_opts(parser) - opts.train_opts(parser) + # opts.model_opts(parser) + # opts.train_opts(parser) + opts.bert_model_opts(parser) + opts.bert_pretrainning(parser) return parser From 7601a887eacd3c4f9ea9a6a2d44bccbe7e5e2cb2 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Thu, 18 Jul 2019 11:55:56 +0200 Subject: [PATCH 02/28] support file --- bert_ckp_convert.py | 120 +++++++++ onmt/utils/bert_tokenization.py | 415 ++++++++++++++++++++++++++++++ onmt/utils/file_utils.py | 270 +++++++++++++++++++ pregenerate_bert_training_data.py | 352 +++++++++++++++++++++++++ 4 files changed, 1157 insertions(+) create mode 100644 bert_ckp_convert.py create mode 100644 onmt/utils/bert_tokenization.py create mode 100644 onmt/utils/file_utils.py create mode 100755 pregenerate_bert_training_data.py diff --git a/bert_ckp_convert.py b/bert_ckp_convert.py new file mode 100644 index 0000000000..73d7286f82 --- /dev/null +++ b/bert_ckp_convert.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +""" Convert ckp of huggingface to onmt version""" +from argparse import ArgumentParser +from pathlib import Path +# import pytorch_pretrained_bert +# from pytorch_pretrained_bert.modeling import BertForPreTraining +import torch +import onmt +from collections import OrderedDict +import re + +# -1 +def decrement(matched): + value = int(matched.group(1)) + if value < 1: + raise ValueError('Value Error when converting string') + string = "bert.encoder.layer.{}.output.LayerNorm".format(value-1) + return string + +def convert_key(key, max_layers): + if 'bert.embeddings' in key: + key 
= key + + elif 'bert.transformer_encoder' in key: + # convert layer_norm weights + key = re.sub(r'bert.transformer_encoder.0.layer_norm\.(.*)', + r'bert.embeddings.LayerNorm.\1', key) + key = re.sub(r'bert.transformer_encoder\.(\d+)\.layer_norm', + decrement, key) # TODO + # convert attention weights + key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.linear_keys\.(.*)', + r'bert.encoder.layer.\1.attention.self.key.\2', key) + key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.linear_values\.(.*)', + r'bert.encoder.layer.\1.attention.self.value.\2', key) + key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.linear_query\.(.*)', + r'bert.encoder.layer.\1.attention.self.query.\2', key) + key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.final_linear\.(.*)', + r'bert.encoder.layer.\1.attention.output.dense.\2', key) + # convert feed forward weights + key = re.sub(r'bert.transformer_encoder\.(\d+)\.feed_forward.layer_norm\.(.*)', + r'bert.encoder.layer.\1.attention.output.LayerNorm.\2', key) + key = re.sub(r'bert.transformer_encoder\.(\d+)\.feed_forward.w_1\.(.*)', + r'bert.encoder.layer.\1.intermediate.dense.\2', key) + key = re.sub(r'bert.transformer_encoder\.(\d+)\.feed_forward.w_2\.(.*)', + r'bert.encoder.layer.\1.output.dense.\2', key) + + elif 'bert.layer_norm' in key: + key = re.sub(r'bert.layer_norm', + r'bert.encoder.layer.'+str(max_layers-1)+'.output.LayerNorm', key) + elif 'bert.pooler' in key: + key = key + elif 'cls.next_sentence' in key: + key = re.sub(r'cls.next_sentence.linear\.(.*)', + r'cls.seq_relationship.\1', key) + elif 'cls.mask_lm' in key: + key = re.sub(r'cls.mask_lm.bias', + r'cls.predictions.bias', key) + key = re.sub(r'cls.mask_lm.decode.weight', + r'cls.predictions.decoder.weight', key) + key = re.sub(r'cls.mask_lm.transform.dense\.(.*)', + r'cls.predictions.transform.dense.\1', key) + key = re.sub(r'cls.mask_lm.transform.layer_norm\.(.*)', + r'cls.predictions.transform.LayerNorm.\1', key) + else: + raise ValueError("Unexpected keys!") + return key + + +def load_bert_weights(bert_model, weights_dict, n_layers=12): + bert_model_keys = bert_model.state_dict().keys() + weights_keys = weights_dict.keys() + model_weights = OrderedDict() + + try: + for key in bert_model_keys: + key_huggingface = convert_key(key, n_layers) + # model_weights[key] = converted_key + model_weights[key] = weights_dict[key_huggingface] + except ValueError: + print("Unsuccessful convert!") + exit() + return model_weights + + +def main(): + parser = ArgumentParser() + parser.add_argument("--layers", type=int, default=None) + parser.add_argument("--bert_model", type=str, default="bert-base-multilingual-uncased")#, # required=True, + # choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", + # "bert-base-multilingual-uncased", "bert-base-chinese"]) + parser.add_argument("--bert_model_weights_path", type=str, default="PreTrainedBertckp/") + parser.add_argument("--output_dir", type=Path, default="PreTrainedBertckp/") + parser.add_argument("--output_name", type=str, default="onmt-bert-base-multilingual-uncased.pt") + args = parser.parse_args() + bert_model_weights = args.bert_model_weights_path + args.bert_model +".pt" + print(bert_model_weights) + args.output_dir.mkdir(exist_ok=True) + outfile = args.output_dir.joinpath(args.output_name) + + # pretrained_model_name_or_path = args.bert_model + # bert_pretrained = BertForPreTraining.from_pretrained(pretrained_model_name_or_path, cache=args.output_dir) + + if args.layers is None: + if 'large' in 
args.bert_model: + n_layers = 24 + else: + n_layers = 12 + else: + n_layers = args.layers + + bert_weights = torch.load(bert_model_weights) + bert = onmt.models.BERT(105879) + bertlm = onmt.models.BertLM(bert) + model_weights = load_bert_weights(bertlm, bert_weights, n_layers) + ckp={'model': model_weights} + torch.save(ckp, outfile) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/onmt/utils/bert_tokenization.py b/onmt/utils/bert_tokenization.py new file mode 100644 index 0000000000..3937d6e011 --- /dev/null +++ b/onmt/utils/bert_tokenization.py @@ -0,0 +1,415 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import collections +import logging +import os +import unicodedata +from io import open + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. 
+ + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format(vocab_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + return vocab_file + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): + logger.warning("The pre-trained model you are loading is a cased model but you have not set " + "`do_lower_case` to False. 
We are setting `do_lower_case=False` for you but " + "you may want to check this behavior.") + kwargs['do_lower_case'] = False + elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): + logger.warning("The pre-trained model you are loading is an uncased model but you have set " + "`do_lower_case` to False. We are setting `do_lower_case=True` for you " + "but you may want to check this behavior.") + kwargs['do_lower_case'] = True + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
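# For illustration only (not part of the patch): a minimal end-to-end use of the
# tokenizer defined above. "bert-base-uncased" is one of the entries in
# PRETRAINED_VOCAB_ARCHIVE_MAP; the example sentence is made up, and the first
# call needs either network access or an already cached vocabulary file.
#
#     >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     >>> tokens = tokenizer.tokenize("Machine translation is fun!")
#     >>> ids = tokenizer.convert_tokens_to_ids(tokens)
#     >>> tokenizer.convert_ids_to_tokens(ids) == tokens
#     True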
+ text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. 
+ """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/onmt/utils/file_utils.py b/onmt/utils/file_utils.py new file mode 100644 index 0000000000..17bdd258ea --- /dev/null +++ b/onmt/utils/file_utils.py @@ -0,0 +1,270 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import sys +import json +import logging +import os +import shutil +import tempfile +import fnmatch +from functools import wraps +from hashlib import sha256 +import sys +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. 
+ """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. 
+ """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + try: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get("ETag") + except EnvironmentError: + etag = None + + if sys.version_info[0] == 2 and etag is not None: + etag = etag.decode('utf-8') + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') + matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
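# For illustration only (not part of the patch): typical use of cached_path()
# defined above. The URL is one of the vocabulary files listed in
# bert_tokenization.PRETRAINED_VOCAB_ARCHIVE_MAP; the first call downloads it,
# later calls return the local copy kept under PYTORCH_PRETRAINED_BERT_CACHE.
#
#     >>> from onmt.utils.file_utils import cached_path
#     >>> vocab_url = ("https://s3.amazonaws.com/models.huggingface.co/bert/"
#     ...              "bert-base-uncased-vocab.txt")
#     >>> local_vocab = cached_path(vocab_url)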
+ with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w') as meta_file: + output_string = json.dumps(meta) + if sys.version_info[0] == 2 and isinstance(output_string, str): + output_string = unicode(output_string, 'utf-8') # The beauty of python 2 + meta_file.write(output_string) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py new file mode 100755 index 0000000000..fe1ac44221 --- /dev/null +++ b/pregenerate_bert_training_data.py @@ -0,0 +1,352 @@ +from argparse import ArgumentParser +from pathlib import Path +from tqdm import tqdm, trange +from tempfile import TemporaryDirectory +import shelve + +from random import random, randrange, randint, shuffle, choice, sample +from onmt.utils.bert_tokenization import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP +from onmt.utils.file_utils import cached_path +import numpy as np +import json +from onmt.inputters.inputter import get_bert_fields, _build_bert_fields_vocab +from onmt.inputters.dataset_bert import BertDataset +import os +from collections import Counter, defaultdict +import torch + +class DocumentDatabase: + def __init__(self, reduce_memory=False): + if reduce_memory: + self.temp_dir = TemporaryDirectory() + self.working_dir = Path(self.temp_dir.name) + self.document_shelf_filepath = self.working_dir / 'shelf.db' + self.document_shelf = shelve.open(str(self.document_shelf_filepath), + flag='n', protocol=-1) + self.documents = None + else: + self.documents = [] + self.document_shelf = None + self.document_shelf_filepath = None + self.temp_dir = None + self.doc_lengths = [] + self.doc_cumsum = None + self.cumsum_max = None + self.reduce_memory = reduce_memory + + def add_document(self, document): + if not document: + return + if self.reduce_memory: + current_idx = len(self.doc_lengths) + self.document_shelf[str(current_idx)] = document + else: + self.documents.append(document) + self.doc_lengths.append(len(document)) + + def _precalculate_doc_weights(self): + self.doc_cumsum = np.cumsum(self.doc_lengths) + self.cumsum_max = self.doc_cumsum[-1] + + def sample_doc(self, current_idx, sentence_weighted=True): + # Uses the current iteration counter to ensure we don't sample the same doc twice + if sentence_weighted: + # With sentence weighting, we sample docs proportionally to their sentence 
length + if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths): + self._precalculate_doc_weights() + rand_start = self.doc_cumsum[current_idx] + rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] + sentence_index = randrange(rand_start, rand_end) % self.cumsum_max + sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') + else: + # If we don't use sentence weighting, then every doc has an equal chance to be chosen + sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths) + assert sampled_doc_index != current_idx + if self.reduce_memory: + return self.document_shelf[str(sampled_doc_index)] + else: + return self.documents[sampled_doc_index] + + def __len__(self): + return len(self.doc_lengths) + + def __getitem__(self, item): + if self.reduce_memory: + return self.document_shelf[str(item)] + else: + return self.documents[item] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + if self.document_shelf is not None: + self.document_shelf.close() + if self.temp_dir is not None: + self.temp_dir.cleanup() + + +def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): + """Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo.""" + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break + + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 + + # We want to sometimes truncate from the front and sometimes from the + # back to add more randomness and avoid biases. + if random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() + + +def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, tokenizer): + """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but + with several refactors to clean it up and remove a lot of unnecessary variables.""" + vocab_dict = tokenizer.vocab + vocab_list = list(vocab_dict.keys()) + cand_indices = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + cand_indices.append(i) + + num_to_mask = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + shuffle(cand_indices) + mask_indices = sorted(sample(cand_indices, num_to_mask)) + masked_token_labels = [] + for index in mask_indices: + # 80% of the time, replace with [MASK] + if random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = choice(vocab_list) + masked_token_labels.append(tokens[index]) + # Once we've saved the true label for that token, we can overwrite it with the masked version + tokens[index] = masked_token + lm_labels_ids = [-1 for _ in tokens] + for (i, token) in zip(mask_indices, masked_token_labels): + lm_labels_ids[i] = vocab_dict[token] + assert len(lm_labels_ids) == len(tokens) + return tokens, mask_indices, masked_token_labels, lm_labels_ids + + +def create_instances_from_document( + doc_database, doc_idx, max_seq_length, short_seq_prob, + masked_lm_prob, max_predictions_per_seq, tokenizer): + """This code is mostly a duplicate of the equivalent function from Google BERT's repo. + However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function. 
+ Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence + (rather than each document) has an equal chance of being sampled as a false example for the NextSentence task.""" + document = doc_database[doc_idx] + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + + # We *usually* want to fill up the entire sequence since we are padding + # to `max_seq_length` anyways, so short sequences are generally wasted + # computation. However, we *sometimes* + # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter + # sequences to minimize the mismatch between pre-training and fine-tuning. + # The `target_seq_length` is just a rough target however, whereas + # `max_seq_length` is a hard limit. + target_seq_length = max_num_tokens + if random() < short_seq_prob: + target_seq_length = randint(2, max_num_tokens) + + # We DON'T just concatenate all of the tokens from a document into a long + # sequence and choose an arbitrary split point because this would make the + # next sentence prediction task too easy. Instead, we split the input into + # segments "A" and "B" based on the actual "sentences" provided by the user + # input. + instances = [] + current_chunk = [] + current_length = 0 + i = 0 + while i < len(document): + segment = document[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(document) - 1 or current_length >= target_seq_length: + if current_chunk: + # `a_end` is how many segments from `current_chunk` go into the `A` + # (first) sentence. + a_end = 1 + if len(current_chunk) >= 2: + a_end = randrange(1, len(current_chunk)) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + + # Random next + if len(current_chunk) == 1 or random() < 0.5: + is_next = False + target_b_length = target_seq_length - len(tokens_a) + + # Sample a random document, with longer docs being sampled more frequently + random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True) + + random_start = randrange(0, len(random_document)) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + # We didn't actually use these segments so we "put them back" so + # they don't go to waste. 
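# For intuition only (not part of the patch): create_masked_lm_predictions above
# selects roughly masked_lm_prob of the non-[CLS]/[SEP] positions (at least one,
# at most max_predictions_per_seq), replaces each selected token by [MASK] 80%
# of the time (10% kept, 10% swapped for a random vocab token), and returns
# lm_labels_ids holding the original vocab id at the selected positions and -1
# elsewhere. The sketch below assumes `tokenizer` is a BertTokenizer instance;
# the exact outcome is random.
#
#     >>> toks = ["[CLS]", "my", "dog", "is", "cute", "[SEP]",
#     ...         "he", "likes", "play", "##ing", "[SEP]"]
#     >>> toks, positions, labels, lm_ids = create_masked_lm_predictions(
#     ...     toks, masked_lm_prob=0.15, max_predictions_per_seq=20,
#     ...     tokenizer=tokenizer)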
+ num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + # Actual next + else: + is_next = True + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] + # The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP] + # They are 1 for the B tokens and the final [SEP] + segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)] + + tokens, masked_lm_positions, masked_lm_labels, lm_labels_ids = create_masked_lm_predictions( + tokens, masked_lm_prob, max_predictions_per_seq, tokenizer) + + instance = { + "tokens": tokens, + "segment_ids": segment_ids, + "is_next": is_next, + # "masked_lm_positions": masked_lm_positions, + # "masked_lm_labels": masked_lm_labels, + "lm_labels_ids": lm_labels_ids} + instances.append(instance) + current_chunk = [] + current_length = 0 + i += 1 + + return instances + + +def _build_bert_vocab(vocab, name, counters, min_freq=0): + """ similar to _load_vocab in inputter.py, but build from a vocab list. + in place change counters + """ + vocab_size = len(vocab) + for i, token in enumerate(vocab): + counters[name][token] = vocab_size - i + min_freq + return vocab, vocab_size + + +def main(): + parser = ArgumentParser() + parser.add_argument('--train_corpus', type=Path, default="/home/lzeng/Documents/OpenNMT-py/onmt/inputters/small_wiki_sentence_corpus.txt") # required=True) + parser.add_argument("--corpus_type", type=str, default="train") # required=True) + parser.add_argument("--output_dir", type=Path, default="/home/lzeng/Documents/OpenNMT-py/onmt/inputters/test_opennmt/") # required=True) + parser.add_argument("--output_name", type=str, default="dataset") + parser.add_argument("--bert_model", type=str, default="bert-base-uncased")#, # required=True, + # choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", + # "bert-base-multilingual", "bert-base-chinese"]) + # parser.add_argument("--vocab_pathname", type=Path, required=True) # vocab file correspand to bert_model + + parser.add_argument("--do_lower_case", default=True) # action="store_true") + + parser.add_argument("--reduce_memory", action="store_true", + help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") + + parser.add_argument("--epochs_to_generate", type=int, default=20, + help="Number of epochs of data to pregenerate") + parser.add_argument("--max_seq_len", type=int, default=256) # 128 + parser.add_argument("--short_seq_prob", type=float, default=0.1, + help="Probability of making a short sentence as a training example") + parser.add_argument("--masked_lm_prob", type=float, default=0.15, + help="Probability of masking each token for the LM task") + parser.add_argument("--max_predictions_per_seq", type=int, default=20, + help="Maximum number of tokens to mask in each sequence") + parser.add_argument("--tokens_min_frequency", type=int, default=0) # not tested + parser.add_argument("--vocab_size_multiple", type=int, default=1) # not tested + + args = parser.parse_args() + fields = get_bert_fields() # + tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) + # save Vocab file + vocab_file_url = PRETRAINED_VOCAB_ARCHIVE_MAP[args.bert_model] + vocab_dir = Path.joinpath(args.output_dir, f"{args.bert_model}-vocab.txt") + vocab_file = 
cached_path(vocab_file_url, cache_dir=vocab_dir) + print("Donwload ") + with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: + with args.train_corpus.open() as f: + doc = [] + for line in tqdm(f, desc="Loading Dataset", unit=" lines"): + line = line.strip() + if line == "": + docs.add_document(doc) + doc = [] + else: + tokens = tokenizer.tokenize(line) + doc.append(tokens) + if doc: + docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added + if len(docs) <= 1: + exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to " + "ensure that random NextSentences are not sampled from the same document. Please add blank lines to " + "indicate breaks between documents in your input file. If your dataset does not contain multiple " + "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, " + "sections or paragraphs.") + + args.output_dir.mkdir(exist_ok=True) + for epoch in trange(args.epochs_to_generate, desc="Epoch"): + # epoch_filename = args.output_dir / f"epoch_{epoch}.json" + epoch_filename = args.output_dir / f"{args.output_name}.{args.corpus_type}.{epoch}.pt" + json_name = args.output_dir / f"{args.output_name}.{args.corpus_type}.{epoch}.json" + num_instances = 0 + with json_name.open('w') as epoch_file: + docs_instances = [] + for doc_idx in trange(len(docs), desc="Document"): + doc_instances = create_instances_from_document( + docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, + masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, + tokenizer=tokenizer) # return a list of dict [{}] + docs_instances.extend(doc_instances) + doc_instances_json = [json.dumps(instance) for instance in doc_instances] + for instance in doc_instances_json: + epoch_file.write(instance + '\n') + num_instances += 1 + # build BertDataset from instances collected from different document + dataset = BertDataset(fields, docs_instances) + dataset.save(epoch_filename) + num_doc_instances = len(docs_instances) + print("output file {}, num_example {}, max_seq_len {}".format(epoch_filename,num_doc_instances,args.max_seq_len)) + + metrics_file = args.output_dir / f"{args.output_name}.metrics.{args.corpus_type}.{epoch}.json" + with metrics_file.open('w') as metrics_file: + metrics = { + "num_training_examples": num_instances, + "max_seq_len": args.max_seq_len + } + metrics_file.write(json.dumps(metrics)) + # Build file Vocab.pt + if args.corpus_type == "train": + print("Building vocab from text file...") + vocab_list = list(tokenizer.vocab.keys()) + counters = defaultdict(Counter) + _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) + fields = _build_bert_fields_vocab(fields, counters, vocab_size, args.tokens_min_frequency, args.vocab_size_multiple) # + bert_vocab_file = args.output_dir / f"{args.output_name}.vocab.pt" + torch.save(fields, bert_vocab_file) + + +if __name__ == '__main__': + main() From 1c0498e7f5e4542d27f132cf806ce7a46f6c2671 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Thu, 18 Jul 2019 11:56:21 +0200 Subject: [PATCH 03/28] activation function --- onmt/utils/fn_activation.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 onmt/utils/fn_activation.py diff --git a/onmt/utils/fn_activation.py b/onmt/utils/fn_activation.py new file mode 100644 index 0000000000..8fd0e0c027 --- /dev/null +++ b/onmt/utils/fn_activation.py @@ -0,0 +1,22 @@ +import torch +import 
torch.nn as nn +import math + + +class GELU(nn.Module): + """ Implementation of the gelu activation function + + For information: OpenAI GPT's gelu is slightly different + (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) + * (x + 0.044715 * torch.pow(x, 3)))) + see https://arxiv.org/abs/1606.08415 + + Examples:: + >>> m = GELU() + >>> inputs = torch.randn(2) + >>> outputs = m(inputs) + """ + def forward(self, x): + gelu = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + return gelu From de3ca851641ae2eddca6ff7add957921a070c5f4 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Thu, 18 Jul 2019 11:58:01 +0200 Subject: [PATCH 04/28] bert dataset --- onmt/inputters/dataset_bert.py | 50 ++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 onmt/inputters/dataset_bert.py diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py new file mode 100644 index 0000000000..6deebe09cf --- /dev/null +++ b/onmt/inputters/dataset_bert.py @@ -0,0 +1,50 @@ +import torch +from torchtext.data import Dataset as TorchtextDataset +from torchtext.data import Example + + +def bert_sort_key(ex): + """Sort using the number of tokens in the sequence.""" + return len(ex.tokens) + + +class BertDataset(TorchtextDataset): + """Defines a BERT dataset composed of Examples along with its Fields. + Args: + fields_dict (dict[str, Field]): a dict with the structure + returned by :func:`onmt.inputters.get_bert_fields()`. + instances (Iterable[dict[]]): a list of document instance that + are going to be transfored into Examples + """ + + def __init__(self, fields_dict, instances, sort_key=bert_sort_key, filter_pred=None): + self.sort_key = sort_key + examples = [] + # NOTE: need to adapt ? + ex_fields = {k: [(k, v)] for k, v in fields_dict.items()} + # print(ex_fields) + for instance in instances: + # print("###################") + # print(instance) + # print("###################") + ex = Example.fromdict(instance, ex_fields) + # print(ex) + examples.append(ex) + # exit(1) + fields_list = list(fields_dict.items()) + + super(BertDataset, self).__init__(examples, fields_list, filter_pred) + + def __getattr__(self, attr): + # avoid infinite recursion when fields isn't defined + if 'fields' not in vars(self): + raise AttributeError + if attr in self.fields: + return (getattr(x, attr) for x in self.examples) + else: + raise AttributeError + + def save(self, path, remove_fields=True): + if remove_fields: + self.fields = [] + torch.save(self, path) From ede0250ee37e89243c9a9e3a2773829ac8abebe1 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Fri, 19 Jul 2019 12:29:12 +0200 Subject: [PATCH 05/28] add a new way of using bert --- onmt/inputters/dataset_bert.py | 4 +- onmt/model_builder.py | 107 ++++++++++++++++++++++++++++++++- onmt/models/__init__.py | 5 +- onmt/models/bert.py | 9 ++- onmt/models/language_model.py | 2 +- onmt/modules/bert_embed.py | 2 + onmt/train_single.py | 5 +- onmt/trainer.py | 11 +++- 8 files changed, 130 insertions(+), 15 deletions(-) diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py index 6deebe09cf..92f7ea22b7 100644 --- a/onmt/inputters/dataset_bert.py +++ b/onmt/inputters/dataset_bert.py @@ -3,7 +3,7 @@ from torchtext.data import Example -def bert_sort_key(ex): +def bert_text_sort_key(ex): """Sort using the number of tokens in the sequence.""" return len(ex.tokens) @@ -17,7 +17,7 @@ class BertDataset(TorchtextDataset): are going to be transfored into Examples """ - def __init__(self, 
fields_dict, instances, sort_key=bert_sort_key, filter_pred=None): + def __init__(self, fields_dict, instances, sort_key=bert_text_sort_key, filter_pred=None): self.sort_key = sort_key examples = [] # NOTE: need to adapt ? diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 0aa063fa50..956a0b5b88 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -19,8 +19,9 @@ from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser -from onmt.models import BertLM, BERT -# from onmt.modules.bert_embed import BertEmbeddings +from onmt.models import BertLM, BERT, BertPreTrainingHeads +from onmt.modules.bert_embed import BertEmbeddings +from collections import OrderedDict def build_embeddings(opt, text_field, for_encoder=True): @@ -269,6 +270,13 @@ def build_bertLM(model_opt, fields, gpu, checkpoint=None, gpu_id=None): # tokens_emb = bert.embeddings model = BertLM(bert) + # # if model_opt.task == 'classification': + # generator = nn.Sequential(nn.Linear(bert_encoder.d_model, opt.num_labels), + # nn.LogSoftmax(dim=-1)) + # # if model_opt.task == 'generation': + # generator = MaskedLanguageModel(bert_encoder.d_model, bert_encoder.vocab_size, + # bert_encoder.embeddings.word_embeddings.weight) + # model.cls = generator # load states from checkpoints if checkpoint is not None: logger.info("load states from checkpoints...") @@ -303,3 +311,98 @@ def build_bert_encoder(model_opt, fields, gpu, checkpoint=None, gpu_id=None): dropout=model_opt.dropout[0], max_relative_positions=model_opt.max_relative_positions) return bert + + +def build_bert_embeddings(opt, fields): + token_fields_vocab = fields['tokens'].vocab + vocab_size = len(token_fields_vocab) + emb_size = opt.word_vec_size + bert_emb = BertEmbeddings(vocab_size, emb_size, + dropout=opt.dropout[0]) + return bert_emb + + +def build_bert_encoder_v2(model_opt, fields, embs): + # TODO: need to be more elegent + vocab_size = embs.vocab_size + bert = BERT(vocab_size, num_layers=model_opt.layers, + d_model=model_opt.word_vec_size, heads=model_opt.heads, + dropout=model_opt.dropout[0], embeds=embs, + max_relative_positions=model_opt.max_relative_positions) + return bert + + +def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): + """Build a model from opts. + + Args: + model_opt: the option loaded from checkpoint. It's important that + the opts have been updated and validated. See + :class:`onmt.utils.parse.ArgumentParser`. + fields (dict[str, torchtext.data.Field]): + `Field` objects for the model. + gpu (bool): whether to use gpu. + checkpoint: the model gnerated by train phase, or a resumed snapshot + model from a stopped training. + gpu_id (int or NoneType): Which GPU to use. + + Returns: + the NMTModel. + """ + logger.info('Building BERT model...') + # Build embeddings. + bert_emb = build_bert_embeddings(model_opt, fields) + + # Build encoder. + bert_encoder = build_bert_encoder_v2(model_opt, fields, bert_emb) + + + # Build NMTModel(= encoder + decoder). + gpu = use_gpu(opt) + if gpu and gpu_id is not None: + device = torch.device("cuda", gpu_id) + elif gpu and not gpu_id: + device = torch.device("cuda") + elif not gpu: + device = torch.device("cpu") + + + # Build Generator. 
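# For illustration only (not part of the patch): a rough sketch of how the
# helper functions above are meant to chain; `model_opt` and `fields` are
# placeholders for the objects passed into build_bert_model. Passing the
# token-embedding weight into the heads presumably ties the masked-LM decoder
# to the input embeddings, as the generator construction below suggests.
#
#     >>> emb = build_bert_embeddings(model_opt, fields)
#     >>> bert_encoder = build_bert_encoder_v2(model_opt, fields, emb)
#     >>> heads = BertPreTrainingHeads(bert_encoder.d_model,
#     ...                              bert_encoder.vocab_size,
#     ...                              bert_encoder.embeddings.word_embeddings.weight)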
+ # if model_opt.task == 'pretraining': + generator = BertPreTrainingHeads(bert_encoder.d_model, bert_encoder.vocab_size, + bert_encoder.embeddings.word_embeddings.weight) + # # if model_opt.task == 'classification': + # generator = nn.Sequential(nn.Linear(bert_encoder.d_model, opt.num_labels), + # nn.LogSoftmax(dim=-1)) + # # if model_opt.task == 'generation': + # generator = MaskedLanguageModel(bert_encoder.d_model, bert_encoder.vocab_size, + # bert_encoder.embeddings.word_embeddings.weight) + # if model_opt.share_embeddings: + # generator.mask_lm.decode.weight = bert_emb.word_embeddings.weight + + model = nn.Sequential(OrderedDict([ + ('bert', bert_encoder), + ('cls', generator)])) + # Load the model states from checkpoint or initialize them. + if checkpoint is not None: + model.load_state_dict(checkpoint['model'], strict=False) + # generator.load_state_dict(checkpoint['generator'], strict=False) + else: + logger.info("No checkpoint, Initialize Parameters...") + if model_opt.param_init_normal != 0.0: + normal_std = model_opt.param_init_normal + for p in model.parameters(): + p.data.normal_(mean=0, std=normal_std) + elif model_opt.param_init != 0.0: + for p in model.parameters(): + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + elif model_opt.param_init_glorot: + for p in model.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + else: + raise AttributeError("Initialization method haven't be used!") + + model.to(device) + logger.info(model) + return model diff --git a/onmt/models/__init__.py b/onmt/models/__init__.py index 8cd5aa41d7..09a54a1e33 100644 --- a/onmt/models/__init__.py +++ b/onmt/models/__init__.py @@ -1,8 +1,9 @@ """Module defining models.""" from onmt.models.model_saver import build_model_saver, ModelSaver from onmt.models.model import NMTModel -from onmt.models.language_model import BertLM from onmt.models.bert import BERT, BertLayerNorm +from onmt.models.language_model import BertLM, BertPreTrainingHeads __all__ = ["build_model_saver", "ModelSaver", "NMTModel", "BERT", - "BertLM", "BertLayerNorm", "check_sru_requirement"] + "BertLM", "BertLayerNorm", "BertPreTrainingHeads", + "check_sru_requirement"] diff --git a/onmt/models/bert.py b/onmt/models/bert.py index fa56397789..744ffc94d4 100644 --- a/onmt/models/bert.py +++ b/onmt/models/bert.py @@ -10,7 +10,7 @@ class BERT(nn.Module): Use a Transformer Encoder as Language modeling. """ def __init__(self, vocab_size, num_layers=12, d_model=768, heads=12, - dropout=0.1, max_relative_positions=0): + dropout=0.1, max_relative_positions=0, embeds=None): super(BERT, self).__init__() self.vocab_size = vocab_size self.num_layers = num_layers @@ -27,8 +27,11 @@ def __init__(self, vocab_size, num_layers=12, d_model=768, heads=12, # 1. Token embeddings # 2. Segmentation embeddings # 3. 
Position embeddings - self.embeddings = BertEmbeddings(vocab_size=vocab_size, - embed_size=d_model, dropout=dropout) + if embeds is not None: + self.embeddings = embeds + else: + self.embeddings = BertEmbeddings(vocab_size=vocab_size, + embed_size=d_model, dropout=dropout) # Transformer Encoder Block self.transformer_encoder = nn.ModuleList( diff --git a/onmt/models/language_model.py b/onmt/models/language_model.py index 51575da00d..20b1175791 100644 --- a/onmt/models/language_model.py +++ b/onmt/models/language_model.py @@ -10,7 +10,7 @@ class BertLM(nn.Module): BERT Language Model for pretraining, trained with 2 task : Next Sentence Prediction Model + Masked Language Model """ - def __init__(self, bert: onmt.models.BERT): + def __init__(self, bert): """ Args: bert: BERT model which should be trained diff --git a/onmt/modules/bert_embed.py b/onmt/modules/bert_embed.py index 2f3ebfa9d4..92b9961cac 100644 --- a/onmt/modules/bert_embed.py +++ b/onmt/modules/bert_embed.py @@ -39,6 +39,8 @@ def __init__(self, vocab_size, embed_size, pad_idx=0, dropout=0.1): dropout: dropout rate """ super(BertEmbeddings, self).__init__() + self.vocab_size = vocab_size + self.embed_size = embed_size self.word_padding_idx = pad_idx self.word_embeddings = TokenEmb(vocab_size, hidden_size=embed_size, padding_idx=pad_idx) diff --git a/onmt/train_single.py b/onmt/train_single.py index 26bb8193d9..a9413a963a 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -6,7 +6,7 @@ from onmt.inputters.inputter import build_dataset_iter, \ load_old_vocab, old_style_vocab, build_dataset_iter_multiple -from onmt.model_builder import build_model, build_bert +from onmt.model_builder import build_model, build_bert, build_bert_model from onmt.utils.optimizers import Optimizer from onmt.utils.misc import set_random_seed from onmt.trainer import build_trainer @@ -92,7 +92,8 @@ def main(opt, device_id, batch_queue=None, semaphore=None): # Build model. if opt.is_bert: - model = build_bert(model_opt, opt, fields, checkpoint) + # model = build_bert(model_opt, opt, fields, checkpoint) # V1 + model = build_bert_model(model_opt, opt, fields, checkpoint) # V2 n_params = 0 for param in model.parameters(): n_params += param.nelement() diff --git a/onmt/trainer.py b/onmt/trainer.py index caf1886ef1..d3021090e6 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -134,8 +134,9 @@ def __init__(self, model, train_loss, valid_loss, optim, self.earlystopper = earlystopper self.dropout = dropout self.dropout_steps = dropout_steps - self.is_bert = True if isinstance(self.model, - onmt.models.language_model.BertLM) else False # NOTE: NEW parameter for bert training + self.is_bert = True if hasattr(self.model, 'bert') else False + # self.is_bert = True if isinstance(self.model, + # onmt.models.language_model.BertLM) else False # NOTE: NEW parameter for bert training for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 @@ -532,7 +533,11 @@ def _bert_gradient_accumulation(self, true_batches, normalization, total_stats, # 2. F-prop all to get log likelihood of two task. 
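For reference, the shapes involved in this two-task forward pass; this is a reading aid rather than runnable code, since it assumes a built `model` and batch tensors as produced by the BERT fields:

# input_ids      : (batch, seq)  token ids from the 'tokens' field, [PAD] = 0
# token_type_ids : (batch, seq)  0 for sentence A, 1 for sentence B
all_encoder_layers, pooled_out = model.bert(input_ids, token_type_ids)
# pooled_out     : (batch, d_model)  pooled [CLS] state used for next-sentence prediction
seq_class_log_prob, prediction_log_prob = model.cls(all_encoder_layers, pooled_out)
# seq_class_log_prob  : (batch, 2)                log-probabilities for is_next
# prediction_log_prob : (batch, seq, vocab_size)  log-probabilities for the masked LM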
if self.accum_count == 1: self.optim.zero_grad() - seq_class_log_prob, prediction_log_prob = self.model(input_ids, token_type_ids) + # Version 2: + all_encoder_layers, pooled_out = self.model.bert(input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = self.model.cls(all_encoder_layers, pooled_out) + # Version 1: + # seq_class_log_prob, prediction_log_prob = self.model(input_ids, token_type_ids) # NOTE: (batch_size, 2), (batch_size, seq_size, vocab_size) outputs = (seq_class_log_prob, prediction_log_prob) # outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt) From 1dfa50a316a4b084c253819cf9c325de61621bf2 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Fri, 19 Jul 2019 18:41:49 +0200 Subject: [PATCH 06/28] merge some function --- onmt/encoders/transformer.py | 5 ++ onmt/inputters/__init__.py | 5 +- onmt/model_builder.py | 28 ++++--- onmt/models/bert.py | 9 +- onmt/models/model_saver.py | 84 +++++++++++++------ .../{bert_embed.py => bert_embeddings.py} | 41 +++------ onmt/modules/position_ffn.py | 4 + onmt/trainer.py | 32 +++---- onmt/utils/fn_activation.py | 4 + onmt/utils/loss.py | 84 ++++++++++--------- onmt/utils/statistics.py | 2 + pregenerate_bert_training_data.py | 4 + train.py | 2 +- 13 files changed, 178 insertions(+), 126 deletions(-) rename onmt/modules/{bert_embed.py => bert_embeddings.py} (63%) diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 26f3dd7e89..936d6edc57 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -21,6 +21,11 @@ class TransformerEncoderLayer(nn.Module): heads (int): the number of head for MultiHeadedAttention. d_ff (int): the second-layer of the PositionwiseFeedForward. dropout (float): dropout probability(0-1.0). + activation (str): activation function to chose from + ['ReLU', 'GeLU'] + is_bert (bool): default False. When set True, + layer_norm will be performed on the + direct connection of residual block. """ def __init__(self, d_model, heads, d_ff, dropout, diff --git a/onmt/inputters/__init__.py b/onmt/inputters/__init__.py index 10990102df..a3af38e144 100644 --- a/onmt/inputters/__init__.py +++ b/onmt/inputters/__init__.py @@ -11,7 +11,7 @@ from onmt.inputters.image_dataset import img_sort_key, ImageDataReader from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader from onmt.inputters.datareader_base import DataReaderBase -from onmt.inputters.dataset_bert import BertDataset +from onmt.inputters.dataset_bert import BertDataset, bert_text_sort_key str2reader = { "text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader} @@ -22,5 +22,6 @@ __all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'get_bert_fields', 'DataReaderBase', 'filter_example', 'old_style_vocab', 'build_vocab', 'OrderedIterator', 'text_sort_key', - 'img_sort_key', 'audio_sort_key', 'BertDataset', + 'img_sort_key', 'audio_sort_key', + 'BertDataset', 'bert_text_sort_key', 'TextDataReader', 'ImageDataReader', 'AudioDataReader'] diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 956a0b5b88..e5b2fa733b 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -20,7 +20,7 @@ from onmt.utils.parse import ArgumentParser from onmt.models import BertLM, BERT, BertPreTrainingHeads -from onmt.modules.bert_embed import BertEmbeddings +from onmt.modules.bert_embeddings import BertEmbeddings from collections import OrderedDict @@ -347,7 +347,7 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): gpu_id (int or NoneType): Which GPU to use. 
Returns: - the NMTModel. + the BERT model. """ logger.info('Building BERT model...') # Build embeddings. @@ -366,19 +366,27 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): elif not gpu: device = torch.device("cpu") - + """Main part for transfer learning: + set opt.task to `pretraining` if want finetuning; + set opt.task to `classification` if want use Bert to classification task; + set opt.task to `generation` if want use Bert to generate a sequence. + The pooled_output from bert encoder will be feed to classification generator; + The all_encoder_layers from bert encoder will be feed to generation generator; + """ # Build Generator. # if model_opt.task == 'pretraining': generator = BertPreTrainingHeads(bert_encoder.d_model, bert_encoder.vocab_size, bert_encoder.embeddings.word_embeddings.weight) - # # if model_opt.task == 'classification': - # generator = nn.Sequential(nn.Linear(bert_encoder.d_model, opt.num_labels), - # nn.LogSoftmax(dim=-1)) - # # if model_opt.task == 'generation': - # generator = MaskedLanguageModel(bert_encoder.d_model, bert_encoder.vocab_size, + # if model_opt.share_embeddings: + # generator.mask_lm.decode.weight = bert_emb.word_embeddings.weight + # if model_opt.task == 'classification': + # generator = nn.Sequential(nn.Linear(bert_encoder.d_model, opt.num_labels), + # nn.LogSoftmax(dim=-1)) + # if model_opt.task == 'generation': + # generator = MaskedLanguageModel(bert_encoder.d_model, bert_encoder.vocab_size, # bert_encoder.embeddings.word_embeddings.weight) - # if model_opt.share_embeddings: - # generator.mask_lm.decode.weight = bert_emb.word_embeddings.weight + # if model_opt.share_embeddings: + # generator.mask_lm.decode.weight = bert_emb.word_embeddings.weight model = nn.Sequential(OrderedDict([ ('bert', bert_encoder), diff --git a/onmt/models/bert.py b/onmt/models/bert.py index 744ffc94d4..1d7b481241 100644 --- a/onmt/models/bert.py +++ b/onmt/models/bert.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from onmt.modules.bert_embed import BertEmbeddings +from onmt.modules.bert_embeddings import BertEmbeddings from onmt.encoders.transformer import TransformerEncoderLayer @@ -52,7 +52,10 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, input_mask: shape [batch, seq], 1 for masked position(that padding) output_all_encoded_layers: if out contain all hidden layer Returns: - all_encoder_layers: list of out in shape (batch, src, d_model) + all_encoder_layers: list of out in shape (batch, src, d_model), + to be used for generation task + pooled_output: shape (batch, d_model), + to be used for classification task """ # # version 1: coder timo waiting for mask of size [B,1,T,T] # [batch, seq] -> [batch, 1, seq] @@ -90,7 +93,7 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, # shape: 2D tensor [batch, seq]: 1 for tokens, 0 for paddings input_mask = input_ids.data.eq(padding_idx) # if token_type_ids is None: - # NOTE: not needed! already done in bert_embed.py + # NOTE: not needed! 
already done in bert_embeddings.py # token_type_ids = torch.zeros_like(input_ids) # [batch, seq] -> [batch, 1, seq] attention_mask = input_mask.unsqueeze(1) diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py index e4a8d10768..b76e0a010e 100644 --- a/onmt/models/model_saver.py +++ b/onmt/models/model_saver.py @@ -1,7 +1,7 @@ import os import torch import torch.nn as nn - +from torchtext.data import Field from collections import deque from onmt.utils.logging import logger @@ -9,20 +9,20 @@ def build_model_saver(model_opt, opt, model, fields, optim): - if opt.is_bert: - model_saver = BertModelSaver(opt.save_model, - model, - model_opt, - fields, - optim, - opt.keep_checkpoint) - else: - model_saver = ModelSaver(opt.save_model, - model, - model_opt, - fields, - optim, - opt.keep_checkpoint) + # if opt.is_bert: + # model_saver = BertModelSaver(opt.save_model, + # model, + # model_opt, + # fields, + # optim, + # opt.keep_checkpoint) + # else: + model_saver = ModelSaver(opt.save_model, + model, + model_opt, + fields, + optim, + opt.keep_checkpoint) return model_saver @@ -108,9 +108,16 @@ def _save(self, step, model): real_model = (model.module if isinstance(model, nn.DataParallel) else model) - real_generator = (real_model.generator.module - if isinstance(real_model.generator, nn.DataParallel) - else real_model.generator) + if hasattr(real_model, "generator"): + print('NMT generator saving') + real_generator = (real_model.generator.module + if isinstance(real_model.generator, nn.DataParallel) + else real_model.generator) + if hasattr(real_model, "cls"): + print('BERT generator saving') + real_generator = (real_model.cls.module + if isinstance(real_model.cls, nn.DataParallel) + else real_model.cls) model_state_dict = real_model.state_dict() model_state_dict = {k: v for k, v in model_state_dict.items() @@ -121,15 +128,38 @@ def _save(self, step, model): # were not originally here. 
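The loop below prunes look-up artefacts from the torchtext vocab before it is serialized; a small self-contained illustration of why that is needed, assuming the unk index is 0 as in the legacy fields:

from collections import defaultdict

# torchtext's Vocab.stoi behaves like this defaultdict: unknown keys map to the unk index.
stoi = defaultdict(lambda: 0, {'<unk>': 0, '[PAD]': 1, 'hello': 2})
_ = stoi['never-seen-token']              # a mere look-up inserts a spurious entry mapped to 0
spurious = [k for k, v in stoi.items() if v == 0 and k != '<unk>']
for k in spurious:
    stoi.pop(k, None)                     # same pruning the saver performs before torch.save
assert 'never-seen-token' not in stoi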
vocab = deepcopy(self.fields) - for side in ["src", "tgt"]: - keys_to_pop = [] - if hasattr(vocab[side], "fields"): - unk_token = vocab[side].fields[0][1].vocab.itos[0] - for key, value in vocab[side].fields[0][1].vocab.stoi.items(): - if value == 0 and key != unk_token: - keys_to_pop.append(key) - for key in keys_to_pop: - vocab[side].fields[0][1].vocab.stoi.pop(key, None) + for name, field in vocab.items(): + if isinstance(field, Field): + if hasattr(field, "vocab"): + assert name == 'tokens' + keys_to_pop = [] + unk_token = field.unk_token + unk_id = field.vocab.stoi[unk_token] + for key, value in field.vocab.stoi.items(): + if value == unk_id and key != unk_token: + keys_to_pop.append(key) + for key in keys_to_pop: + field.vocab.stoi.pop(key, None) + else: + if hasattr(field, "fields"): + assert name in ["src", "tgt"] + keys_to_pop = [] + unk_token = field.fields[0][1].vocab.itos[0] + for key, value in field.fields[0][1].vocab.stoi.items(): + if value == 0 and key != unk_token: + keys_to_pop.append(key) + for key in keys_to_pop: + field.fields[0][1].vocab.stoi.pop(key, None) + + # for side in ["src", "tgt"]: + # keys_to_pop = [] + # if hasattr(vocab[side], "fields"): + # unk_token = vocab[side].fields[0][1].vocab.itos[0] + # for key, value in vocab[side].fields[0][1].vocab.stoi.items(): + # if value == 0 and key != unk_token: + # keys_to_pop.append(key) + # for key in keys_to_pop: + # vocab[side].fields[0][1].vocab.stoi.pop(key, None) checkpoint = { 'model': model_state_dict, diff --git a/onmt/modules/bert_embed.py b/onmt/modules/bert_embeddings.py similarity index 63% rename from onmt/modules/bert_embed.py rename to onmt/modules/bert_embeddings.py index 92b9961cac..cc3bda783e 100644 --- a/onmt/modules/bert_embed.py +++ b/onmt/modules/bert_embeddings.py @@ -2,50 +2,34 @@ import torch.nn as nn -class TokenEmb(nn.Embedding): - """ Embeddings for tokens. - """ - def __init__(self, vocab_size, hidden_size=768, padding_idx=0): - super(TokenEmb, self).__init__(vocab_size, hidden_size, - padding_idx=padding_idx) - - -class SegmentEmb(nn.Embedding): - """ Embeddings for token's type: sentence A(0), sentence B(1). Padding with 0. - """ - def __init__(self, type_vocab_size=2, hidden_size=768, padding_idx=0): - super(SegmentEmb, self).__init__(type_vocab_size, hidden_size, - padding_idx=padding_idx) - - -class PositionEmb(nn.Embedding): - """ Embeddings for token's position. - """ - def __init__(self, max_position=512, hidden_size=768): - super(PositionEmb, self).__init__(max_position, hidden_size) - - class BertEmbeddings(nn.Module): """ BERT input embeddings is sum of: 1. Token embeddings: called word_embeddings 2. Segmentation embeddings: called token_type_embeddings 3. Position embeddings: called position_embeddings + Ref: https://arxiv.org/abs/1810.04805 section 3.2 """ - def __init__(self, vocab_size, embed_size, pad_idx=0, dropout=0.1): + def __init__(self, vocab_size, embed_size=768, pad_idx=0, + dropout=0.1, max_position=512, num_sentence=2): """ Args: vocab_size: int. Size of the embedding vocabulary. embed_size: int. Width of the word embeddings. 
dropout: dropout rate + pad_idx: padding index + max_position: max sentence length in input """ super(BertEmbeddings, self).__init__() self.vocab_size = vocab_size self.embed_size = embed_size self.word_padding_idx = pad_idx - self.word_embeddings = TokenEmb(vocab_size, hidden_size=embed_size, + # Token embeddings: for input tokens + self.word_embeddings = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx) - self.position_embeddings = PositionEmb(512, hidden_size=embed_size) - self.token_type_embeddings = SegmentEmb(2, hidden_size=embed_size, + # Position embeddings: for Position Encoding + self.position_embeddings = nn.Embedding(max_position, embed_size) + # Segmentation embeddings: for distinguish sentences A/B + self.token_type_embeddings = nn.Embedding(num_sentence, embed_size, padding_idx=pad_idx) self.dropout = nn.Dropout(dropout) @@ -72,7 +56,8 @@ def forward(self, input_ids, token_type_ids=None): token_type_embeds = self.token_type_embeddings(token_type_ids) embeddings = word_embeds + position_embeds + token_type_embeds - # in our version, LN is done in EncoderLayer before fed into Attention + # NOTE: in our version, LayerNorm is done in EncoderLayer + # before fed into Attention comparing to original implementation # embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index 3a095d31ba..36f5a9f7d6 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -14,6 +14,10 @@ class PositionwiseFeedForward(nn.Module): d_ff (int): the hidden layer size of the second-layer of the FNN. dropout (float): dropout probability in :math:`[0, 1)`. + activation (str): activation function to use. ['ReLU', 'GeLU'] + is_bert (bool): default False. When set True, + layer_norm will be performed on the + direct connection of residual block. """ def __init__(self, d_model, d_ff, dropout=0.1, diff --git a/onmt/trainer.py b/onmt/trainer.py index d3021090e6..2b109b49e4 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -31,14 +31,10 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object used to save the model """ - if opt.is_bert: - train_loss = onmt.utils.loss.build_bert_loss_compute(opt) - valid_loss = onmt.utils.loss.build_bert_loss_compute(opt, train=False) - else: - tgt_field = dict(fields)["tgt"].base_field - train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) - valid_loss = onmt.utils.loss.build_loss_compute( - model, tgt_field, opt, train=False) + tgt_field = dict(fields)["tgt"].base_field if not opt.is_bert else None + train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) + valid_loss = onmt.utils.loss.build_loss_compute( + model, tgt_field, opt, train=False) trunc_size = opt.truncated_decoder # Badly named... 
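Going back to the embedding module above: the BERT input representation is simply the sum of three look-ups. A minimal standalone sketch with made-up sizes (the real module also applies dropout and defers LayerNorm to the encoder layer):

import torch
import torch.nn as nn

vocab_size, embed_size, max_position, num_sentence = 100, 16, 512, 2
word_emb = nn.Embedding(vocab_size, embed_size, padding_idx=0)   # token embeddings
pos_emb = nn.Embedding(max_position, embed_size)                 # position embeddings
seg_emb = nn.Embedding(num_sentence, embed_size)                 # segment (A/B) embeddings

input_ids = torch.tensor([[5, 6, 7, 0]])                         # (batch=1, seq=4), 0 is [PAD]
positions = torch.arange(input_ids.size(1)).unsqueeze(0)         # (1, seq)
token_type_ids = torch.zeros_like(input_ids)                     # everything is sentence A
embeddings = word_emb(input_ids) + pos_emb(positions) + seg_emb(token_type_ids)
print(embeddings.shape)                                          # torch.Size([1, 4, 16])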
shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0 @@ -171,9 +167,7 @@ def _accum_batches(self, iterator): self.accum_count = self._accum_count(self.optim.training_step) for batch in iterator: batches.append(batch) - if self.is_bert: - normalization += 1 - else: + if self.is_bert is False: # Bert don't need normalization if self.norm_method == "tokens": num_tokens = batch.tgt[1:, :, 0].ne( self.train_loss.padding_idx).sum() @@ -258,7 +252,7 @@ def train(self, # Training Step: Forward -> compute Loss -> optimize if self.is_bert: - self._bert_gradient_accumulation(batches, normalization, total_stats, report_stats) + self._bert_gradient_accumulation(batches, total_stats, report_stats) else: self._gradient_accumulation( batches, normalization, total_stats, @@ -328,7 +322,6 @@ def validate(self, valid_iter, moving_average=None): valid_model.eval() with torch.no_grad(): - # TODO:if not Bert if self.is_bert: stats = onmt.utils.BertStatistics() for batch in valid_iter: @@ -340,7 +333,12 @@ def validate(self, valid_iter, moving_average=None): is_next = batch.is_next lm_labels_ids = batch.lm_labels_ids # -1 padding, others for predict in lm task # F-prop through the model. # NOTE: keyword args: input_mask, output_all_encoded_layers - seq_class_log_prob, prediction_log_prob = valid_model(input_ids, token_type_ids) + # Version 2: + all_encoder_layers, pooled_out = valid_model.bert(input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = valid_model.cls(all_encoder_layers, pooled_out) + # Version 1: + # seq_class_log_prob, prediction_log_prob = valid_model(input_ids, token_type_ids) + # TODO: Heads outputs = (seq_class_log_prob, prediction_log_prob) # Compute loss. _, batch_stats = self.valid_loss(batch, outputs) @@ -501,8 +499,10 @@ def _report_step(self, learning_rate, step, train_stats=None, valid_stats=valid_stats) - def _bert_gradient_accumulation(self, true_batches, normalization, total_stats, - report_stats): + def _bert_gradient_accumulation(self, true_batches, total_stats, report_stats): + """As the loss will be reduced by mean, normalization is not needed anymore. + But we still need to average between GPUs. + """ if self.accum_count > 1: self.optim.zero_grad() diff --git a/onmt/utils/fn_activation.py b/onmt/utils/fn_activation.py index 8fd0e0c027..ff9f60b1f7 100644 --- a/onmt/utils/fn_activation.py +++ b/onmt/utils/fn_activation.py @@ -3,6 +3,10 @@ import math +""" +Adapted from huggingface implementation to reproduce the result +https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py +""" class GELU(nn.Module): """ Implementation of the gelu activation function diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index f5abd9034d..1e2e6d48e9 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -22,50 +22,56 @@ def build_loss_compute(model, tgt_field, opt, train=True): for when using a copy mechanism. 
""" device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") - - padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] - unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token] - if opt.copy_attn: - criterion = onmt.modules.CopyGeneratorLoss( - len(tgt_field.vocab), opt.copy_attn_force, - unk_index=unk_idx, ignore_index=padding_idx - ) - elif opt.label_smoothing > 0 and train: - criterion = LabelSmoothingLoss( - opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx - ) - elif isinstance(model.generator[-1], LogSparsemax): - criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') - else: - criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') - - # if the loss function operates on vectors of raw logits instead of - # probabilities, only the first part of the generator needs to be - # passed to the NMTLossCompute. At the moment, the only supported - # loss function of this kind is the sparsemax loss. - use_raw_logits = isinstance(criterion, SparsemaxLoss) - loss_gen = model.generator[0] if use_raw_logits else model.generator - if opt.copy_attn: - compute = onmt.modules.CopyGeneratorLossCompute( - criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength - ) + if opt.is_bert is True: + assert hasattr(model, 'bert') + assert tgt_field is None + # BERT use -1 for unmasked token in lm_label_ids + criterion = nn.NLLLoss(ignore_index=-1, reduction='mean') + compute = BertLoss(criterion) else: - compute = NMTLossCompute(criterion, loss_gen) + assert isinstance(model, onmt.models.NMTModel) + padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] + unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token] + if opt.copy_attn: + criterion = onmt.modules.CopyGeneratorLoss( + len(tgt_field.vocab), opt.copy_attn_force, + unk_index=unk_idx, ignore_index=padding_idx + ) + elif opt.label_smoothing > 0 and train: + criterion = LabelSmoothingLoss( + opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx + ) + elif isinstance(model.generator[-1], LogSparsemax): + criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') + else: + criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') + + # if the loss function operates on vectors of raw logits instead of + # probabilities, only the first part of the generator needs to be + # passed to the NMTLossCompute. At the moment, the only supported + # loss function of this kind is the sparsemax loss. + use_raw_logits = isinstance(criterion, SparsemaxLoss) + loss_gen = model.generator[0] if use_raw_logits else model.generator + if opt.copy_attn: + compute = onmt.modules.CopyGeneratorLossCompute( + criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength + ) + else: + compute = NMTLossCompute(criterion, loss_gen) compute.to(device) - return compute -def build_bert_loss_compute(opt, train=True): - """FOR BERT PRETRAINING. - Returns a LossCompute subclass which wraps around an nn.Module subclass - (such as nn.NLLLoss) which defines the loss criterion. - """ - device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") - # BERT use -1 for unmasked token in lm_label_ids - criterion = nn.NLLLoss(ignore_index=-1, reduction='mean') - compute = BertLoss(criterion).to(device) - return compute +# def build_bert_loss_compute(opt, train=True): +# """FOR BERT PRETRAINING. +# Returns a LossCompute subclass which wraps around an nn.Module subclass +# (such as nn.NLLLoss) which defines the loss criterion. 
+# """ +# device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") +# # BERT use -1 for unmasked token in lm_label_ids +# criterion = nn.NLLLoss(ignore_index=-1, reduction='mean') +# compute = BertLoss(criterion).to(device) +# return compute class BertLoss(nn.Module): diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index a79bb56ad7..8d60c2fc3d 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -168,6 +168,8 @@ def update(self, stat, update_n_src_words=False): """ assert isinstance(stat, BertStatistics) + # Loss for BERT is computed and reduced by average. + # Which is different from the NMTModel reduced by sum. self.loss = (self.loss * self.n_update + stat.loss * stat.n_update) / (self.n_update + stat.n_update) self.n_update += 1 diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py index fe1ac44221..9539323d57 100755 --- a/pregenerate_bert_training_data.py +++ b/pregenerate_bert_training_data.py @@ -1,3 +1,7 @@ +""" +This file is massively inspired from huggingface and adapted into onmt custom. +Ref: https://github.com/huggingface/pytorch-transformers/blob/master/examples/lm_finetuning/pregenerate_training_data.py +""" from argparse import ArgumentParser from pathlib import Path from tqdm import tqdm, trange diff --git a/train.py b/train.py index adfb1e8dcc..2b0cd86466 100755 --- a/train.py +++ b/train.py @@ -205,7 +205,7 @@ def _get_parser(): # opts.model_opts(parser) # opts.train_opts(parser) opts.bert_model_opts(parser) - opts.bert_pretrainning(parser) + opts.bert_pretraining(parser) return parser From c1dd1f9773375648ca86acd1971d427272bcb764 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 23 Jul 2019 14:46:57 +0200 Subject: [PATCH 07/28] adapt BERT related module to ONMT habit --- bert_ckp_convert.py | 74 ++++---- onmt/encoders/__init__.py | 5 +- onmt/{models => encoders}/bert.py | 80 +++----- onmt/encoders/transformer.py | 8 +- onmt/inputters/dataset_bert.py | 18 +- onmt/model_builder.py | 174 ++++++------------ onmt/models/__init__.py | 10 +- .../{language_model.py => bert_generators.py} | 127 +++++++------ onmt/models/model_saver.py | 86 +-------- onmt/modules/position_ffn.py | 10 +- onmt/train_single.py | 26 +-- onmt/trainer.py | 6 +- onmt/utils/__init__.py | 4 +- onmt/utils/fn_activation.py | 13 ++ onmt/utils/optimizers.py | 3 +- train.py | 8 +- 16 files changed, 256 insertions(+), 396 deletions(-) rename onmt/{models => encoders}/bert.py (57%) rename onmt/models/{language_model.py => bert_generators.py} (57%) diff --git a/bert_ckp_convert.py b/bert_ckp_convert.py index 73d7286f82..d033f6cd83 100644 --- a/bert_ckp_convert.py +++ b/bert_ckp_convert.py @@ -49,17 +49,17 @@ def convert_key(key, max_layers): r'bert.encoder.layer.'+str(max_layers-1)+'.output.LayerNorm', key) elif 'bert.pooler' in key: key = key - elif 'cls.next_sentence' in key: - key = re.sub(r'cls.next_sentence.linear\.(.*)', + elif 'generator.next_sentence' in key: + key = re.sub(r'generator.next_sentence.linear\.(.*)', r'cls.seq_relationship.\1', key) - elif 'cls.mask_lm' in key: - key = re.sub(r'cls.mask_lm.bias', + elif 'generator.mask_lm' in key: + key = re.sub(r'generator.mask_lm.bias', r'cls.predictions.bias', key) - key = re.sub(r'cls.mask_lm.decode.weight', + key = re.sub(r'generator.mask_lm.decode.weight', r'cls.predictions.decoder.weight', key) - key = re.sub(r'cls.mask_lm.transform.dense\.(.*)', + key = re.sub(r'generator.mask_lm.transform.dense\.(.*)', r'cls.predictions.transform.dense.\1', key) - key = 
re.sub(r'cls.mask_lm.transform.layer_norm\.(.*)', + key = re.sub(r'generator.mask_lm.transform.layer_norm\.(.*)', r'cls.predictions.transform.LayerNorm.\1', key) else: raise ValueError("Unexpected keys!") @@ -69,13 +69,20 @@ def convert_key(key, max_layers): def load_bert_weights(bert_model, weights_dict, n_layers=12): bert_model_keys = bert_model.state_dict().keys() weights_keys = weights_dict.keys() - model_weights = OrderedDict() - + bert_weights = OrderedDict() + generator_weights = OrderedDict() + model_weights = {"bert": bert_weights, + "generator": generator_weights} try: for key in bert_model_keys: key_huggingface = convert_key(key, n_layers) # model_weights[key] = converted_key - model_weights[key] = weights_dict[key_huggingface] + if 'generator' not in key: + truncted_key = re.sub(r'bert\.(.*)', r'\1', key) + model_weights['bert'][truncted_key] = weights_dict[key_huggingface] + else: + truncted_key = re.sub(r'generator\.(.*)', r'\1', key) + model_weights['generator'][truncted_key] = weights_dict[key_huggingface] except ValueError: print("Unsuccessful convert!") exit() @@ -84,37 +91,38 @@ def load_bert_weights(bert_model, weights_dict, n_layers=12): def main(): parser = ArgumentParser() - parser.add_argument("--layers", type=int, default=None) - parser.add_argument("--bert_model", type=str, default="bert-base-multilingual-uncased")#, # required=True, + parser.add_argument("--layers", type=int, default=None, required=True) + # parser.add_argument("--bert_model", type=str, default="bert-base-multilingual-uncased")#, # required=True, # choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", # "bert-base-multilingual-uncased", "bert-base-chinese"]) - parser.add_argument("--bert_model_weights_path", type=str, default="PreTrainedBertckp/") - parser.add_argument("--output_dir", type=Path, default="PreTrainedBertckp/") - parser.add_argument("--output_name", type=str, default="onmt-bert-base-multilingual-uncased.pt") + # parser.add_argument("--bert_model_weights_path", type=str, default="PreTrainedBertckp/") + # parser.add_argument("--output_dir", type=Path, default="PreTrainedBertckp/") + # parser.add_argument("--output_name", type=str, default="onmt-bert-base-multilingual-uncased.weights") + parser.add_argument("--bert_model_weights_file", "-i", type=str, default="PreTrainedBertckp/bert-base-multilingual-uncased.pt") + parser.add_argument("--output_name", "-o", type=str, default="PreTrainedBertckp/onmt-bert-base-multilingual-uncased.weights") args = parser.parse_args() - bert_model_weights = args.bert_model_weights_path + args.bert_model +".pt" - print(bert_model_weights) - args.output_dir.mkdir(exist_ok=True) - outfile = args.output_dir.joinpath(args.output_name) - - # pretrained_model_name_or_path = args.bert_model - # bert_pretrained = BertForPreTraining.from_pretrained(pretrained_model_name_or_path, cache=args.output_dir) - if args.layers is None: - if 'large' in args.bert_model: - n_layers = 24 - else: - n_layers = 12 - else: - n_layers = args.layers + n_layers = args.layers + print("Model contain {} layers.".format(n_layers)) + + bert_model_weights = args.bert_model_weights_file + print("Load weights from {}.".format(bert_model_weights)) bert_weights = torch.load(bert_model_weights) - bert = onmt.models.BERT(105879) - bertlm = onmt.models.BertLM(bert) + embeddings = onmt.modules.bert_embeddings.BertEmbeddings(105879) + bert_encoder = onmt.encoders.BertEncoder(embeddings) + generator = onmt.models.BertPreTrainingHeads(bert_encoder.d_model, bert_encoder.vocab_size) + 
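The conversion script above is essentially a table of re.sub rewrites from ONMT parameter names to the huggingface ones; a tiny standalone example of one such rewrite:

import re

onmt_key = 'generator.next_sentence.linear.weight'
hf_key = re.sub(r'generator.next_sentence.linear\.(.*)',
                r'cls.seq_relationship.\1', onmt_key)
assert hf_key == 'cls.seq_relationship.weight'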
bertlm = torch.nn.Sequential(OrderedDict([ + ('bert', bert_encoder), + ('generator', generator)])) model_weights = load_bert_weights(bertlm, bert_weights, n_layers) - ckp={'model': model_weights} + + ckp={'model': model_weights['bert'], 'generator': model_weights['generator']} + + outfile = args.output_name + print("Converted weights file in {}".format(outfile)) torch.save(ckp, outfile) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/onmt/encoders/__init__.py b/onmt/encoders/__init__.py index 53daac6d82..e5e9a58bc5 100644 --- a/onmt/encoders/__init__.py +++ b/onmt/encoders/__init__.py @@ -6,11 +6,12 @@ from onmt.encoders.mean_encoder import MeanEncoder from onmt.encoders.audio_encoder import AudioEncoder from onmt.encoders.image_encoder import ImageEncoder +from onmt.encoders.bert import BertEncoder, BertLayerNorm str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder, "transformer": TransformerEncoder, "img": ImageEncoder, - "audio": AudioEncoder, "mean": MeanEncoder} + "audio": AudioEncoder, "mean": MeanEncoder, "bert": BertEncoder} __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", - "MeanEncoder", "str2enc"] + "MeanEncoder", "str2enc", "BertEncoder"] diff --git a/onmt/models/bert.py b/onmt/encoders/bert.py similarity index 57% rename from onmt/models/bert.py rename to onmt/encoders/bert.py index 1d7b481241..0f7fde4ccd 100644 --- a/onmt/models/bert.py +++ b/onmt/encoders/bert.py @@ -4,44 +4,43 @@ from onmt.encoders.transformer import TransformerEncoderLayer -class BERT(nn.Module): +class BertEncoder(nn.Module): """ BERT Implementation: https://arxiv.org/abs/1810.04805 Use a Transformer Encoder as Language modeling. """ - def __init__(self, vocab_size, num_layers=12, d_model=768, heads=12, - dropout=0.1, max_relative_positions=0, embeds=None): - super(BERT, self).__init__() - self.vocab_size = vocab_size + def __init__(self, embeddings, num_layers=12, d_model=768, + heads=12, d_ff=3072, dropout=0.1, + max_relative_positions=0): + super(BertEncoder, self).__init__() self.num_layers = num_layers - self.d_model = d_model # = hidden_size = embed_size + self.d_model = d_model self.heads = heads self.dropout = dropout - # Feed-Forward size is set to be 4H as in paper - self.d_ff = 4 * d_model - - # Build Embeddings according to vocab_size and d_model - # --DONE--: BERTEmbeddings() - # ref. build_embeddings in onmt.model_builder.py - # BERT input embeddings is sum of: - # 1. Token embeddings - # 2. Segmentation embeddings - # 3. 
Position embeddings - if embeds is not None: - self.embeddings = embeds - else: - self.embeddings = BertEmbeddings(vocab_size=vocab_size, - embed_size=d_model, dropout=dropout) + # Feed-Forward size should be 4*d_model as in paper + self.d_ff = d_ff + self.embeddings = embeddings # Transformer Encoder Block self.transformer_encoder = nn.ModuleList( - [TransformerEncoderLayer(d_model, heads, self.d_ff, dropout, + [TransformerEncoderLayer(d_model, heads, d_ff, dropout, max_relative_positions=max_relative_positions, - activation='GeLU', is_bert=True) for _ in range(num_layers)]) + activation='gelu', is_bert=True) for _ in range(num_layers)]) self.layer_norm = BertLayerNorm(d_model, eps=1e-12) self.pooler = BertPooler(d_model) - # TODO: self.apply(self.init_bert_weight) + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + embeddings, + opt.enc_layers, + opt.enc_rnn_size, + opt.heads, + opt.transformer_ff, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.max_relative_positions) def forward(self, input_ids, token_type_ids=None, input_mask=None, output_all_encoded_layers=False): @@ -57,44 +56,13 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, pooled_output: shape (batch, d_model), to be used for classification task """ - # # version 1: coder timo waiting for mask of size [B,1,T,T] - # [batch, seq] -> [batch, 1, seq] - # -> [batch, seq, seq] -> [batch, 1, seq, seq] - # attention masking for padded token - # mask: torch.ByteTensor([batch, 1, seq, seq]) - # mask = (input_ids > 0).unsqueeze(1) - # .repeat(1, input_ids.size(1), 1).unsqueeze(1) - # # This version mask 0, different masked_fill in Attention - - # # version 2: hugging face waiting for mask of size [B,1,1,T] - # if attention_mask is None: - # attention_mask = torch.ones_like(input_ids) - # if token_type_ids is None: - # token_type_ids = torch.zeros_like(input_ids) - # # extended_attention_mask.shape = [batch_size, 1, 1, seq_length] - # -> broadcast to [batch, num_heads, from_seq_length, to_seq_length] - # extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - # # for fp16 compatibility - # extended_attention_mask = extended_attention_mask - # .to(dtype=next(self.parameters()).dtype) - # -10000.0 for mask, 0 otherwise - # extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - # # version 3: OpenNMT waiting for mask of size [B, 1, T], + # OpenNMT waiting for mask of size [B, 1, T], # while in MultiHeadAttention part2 -> [B, 1, 1, T] - # TODO: create_attention_mask_from_input_mask - # padding_idx = self.embeddings.word_padding_idx - # mask = input_ids.data.eq(padding_idx).unsqueeze(1) if input_mask is None: - # input_mask = torch.ones_like(input_ids) # shape: 2D tensor [batch, seq] padding_idx = self.embeddings.word_padding_idx - # input_mask = input_ids.data.ne(padding_idx) # shape: 2D tensor [batch, seq]: 1 for tokens, 0 for paddings input_mask = input_ids.data.eq(padding_idx) - # if token_type_ids is None: - # NOTE: not needed! already done in bert_embeddings.py - # token_type_ids = torch.zeros_like(input_ids) # [batch, seq] -> [batch, 1, seq] attention_mask = input_mask.unsqueeze(1) diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 936d6edc57..183ece3690 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -22,21 +22,23 @@ class TransformerEncoderLayer(nn.Module): d_ff (int): the second-layer of the PositionwiseFeedForward. 
dropout (float): dropout probability(0-1.0). activation (str): activation function to chose from - ['ReLU', 'GeLU'] + ['relu', 'gelu'] is_bert (bool): default False. When set True, layer_norm will be performed on the direct connection of residual block. """ def __init__(self, d_model, heads, d_ff, dropout, - max_relative_positions=0, activation='ReLU', is_bert=False): + max_relative_positions=0, activation='relu', is_bert=False): super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiHeadedAttention( heads, d_model, dropout=dropout, max_relative_positions=max_relative_positions) self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, activation, is_bert) - self.layer_norm = onmt.models.BertLayerNorm(d_model, eps=1e-12) if is_bert else nn.LayerNorm(d_model, eps=1e-6) + self.layer_norm = (onmt.encoders.BertLayerNorm(d_model,eps=1e-12) + if is_bert + else nn.LayerNorm(d_model, eps=1e-6)) self.dropout = nn.Dropout(dropout) self.is_bert = is_bert diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py index 92f7ea22b7..ae9ceaf08f 100644 --- a/onmt/inputters/dataset_bert.py +++ b/onmt/inputters/dataset_bert.py @@ -11,26 +11,20 @@ def bert_text_sort_key(ex): class BertDataset(TorchtextDataset): """Defines a BERT dataset composed of Examples along with its Fields. Args: - fields_dict (dict[str, Field]): a dict with the structure - returned by :func:`onmt.inputters.get_bert_fields()`. - instances (Iterable[dict[]]): a list of document instance that - are going to be transfored into Examples + fields_dict (dict[str, Field]): a dict containing all Field with + its name. + instances (Iterable[dict[]]): a list of dictionary, each dict + represent one Example with its field specified by fields_dict. """ - def __init__(self, fields_dict, instances, sort_key=bert_text_sort_key, filter_pred=None): + def __init__(self, fields_dict, instances, + sort_key=bert_text_sort_key, filter_pred=None): self.sort_key = sort_key examples = [] - # NOTE: need to adapt ? ex_fields = {k: [(k, v)] for k, v in fields_dict.items()} - # print(ex_fields) for instance in instances: - # print("###################") - # print(instance) - # print("###################") ex = Example.fromdict(instance, ex_fields) - # print(ex) examples.append(ex) - # exit(1) fields_list = list(fields_dict.items()) super(BertDataset, self).__init__(examples, fields_list, filter_pred) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index e5b2fa733b..dddf3cd08c 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -9,7 +9,7 @@ import onmt.inputters as inputters import onmt.modules -from onmt.encoders import str2enc +from onmt.encoders import str2enc, BertEncoder from onmt.decoders import str2dec @@ -19,7 +19,8 @@ from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser -from onmt.models import BertLM, BERT, BertPreTrainingHeads +from onmt.models import BertPreTrainingHeads, ClassificationHead, \ + TokenGenerationHead from onmt.modules.bert_embeddings import BertEmbeddings from collections import OrderedDict @@ -229,90 +230,6 @@ def build_model(model_opt, opt, fields, checkpoint): return model -def build_bert(model_opt, opt, fields, checkpoint): - logger.info('Building BERT model...') - model = build_bertLM(model_opt, fields, use_gpu(opt), checkpoint) - logger.info(model) - return model - - -def build_bertLM(model_opt, fields, gpu, checkpoint=None, gpu_id=None): - """Build a model from opts. - - Args: - model_opt: the option loaded from checkpoint. 
It's important that - the opts have been updated and validated. See - :class:`onmt.utils.parse.ArgumentParser`. - fields (dict[str, torchtext.data.Field]): - `Field` objects for the model. - gpu (bool): whether to use gpu. - checkpoint: the model generated by train phase, or a resumed snapshot - model from a stopped training. - gpu_id (int or NoneType): Which GPU to use. - - Returns: - the BertLM, composed of Bert with 2 generator heads for 2 task. - """ - # TODO: compability of opt.vocab_size - # Build BertEmbeddings - # tokens_fields = fields['tokens'] - # vocab_size = len(tokens_fields.vocab) - - # Build BertModel(= encoder), BertEmbeddings also built inside Bert. - if gpu and gpu_id is not None: - device = torch.device("cuda", gpu_id) - elif gpu and not gpu_id: - device = torch.device("cuda") - elif not gpu: - device = torch.device("cpu") - bert = build_bert_encoder(model_opt, fields, gpu, checkpoint) - # BertEmbeddings is built inside Bert - # tokens_emb = bert.embeddings - model = BertLM(bert) - - # # if model_opt.task == 'classification': - # generator = nn.Sequential(nn.Linear(bert_encoder.d_model, opt.num_labels), - # nn.LogSoftmax(dim=-1)) - # # if model_opt.task == 'generation': - # generator = MaskedLanguageModel(bert_encoder.d_model, bert_encoder.vocab_size, - # bert_encoder.embeddings.word_embeddings.weight) - # model.cls = generator - # load states from checkpoints - if checkpoint is not None: - logger.info("load states from checkpoints...") - # TODO: check model.load_state_dict(...) - model.load_state_dict(checkpoint['model'], strict=False) - else: - logger.info("No checkpoint, Initialize Parameters...") - if model_opt.param_init_normal != 0.0: - normal_std = model_opt.param_init_normal - for p in model.parameters(): - p.data.normal_(mean=0, std=normal_std) - elif model_opt.param_init != 0.0: - for p in model.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - elif model_opt.param_init_glorot: - for p in model.parameters(): - if p.dim() > 1: - xavier_uniform_(p) - else: - raise AttributeError("Initialization method haven't be used!") - - model.to(device) - return model - - -def build_bert_encoder(model_opt, fields, gpu, checkpoint=None, gpu_id=None): - # TODO: need to be more elegent - token_fields_vocab = fields['tokens'].vocab - vocab_size = len(token_fields_vocab) - bert = BERT(vocab_size, num_layers=model_opt.layers, - d_model=model_opt.word_vec_size, heads=model_opt.heads, - dropout=model_opt.dropout[0], - max_relative_positions=model_opt.max_relative_positions) - return bert - - def build_bert_embeddings(opt, fields): token_fields_vocab = fields['tokens'].vocab vocab_size = len(token_fields_vocab) @@ -322,16 +239,41 @@ def build_bert_embeddings(opt, fields): return bert_emb -def build_bert_encoder_v2(model_opt, fields, embs): - # TODO: need to be more elegent - vocab_size = embs.vocab_size - bert = BERT(vocab_size, num_layers=model_opt.layers, - d_model=model_opt.word_vec_size, heads=model_opt.heads, - dropout=model_opt.dropout[0], embeds=embs, - max_relative_positions=model_opt.max_relative_positions) +def build_bert_encoder(model_opt, fields, embeddings): + bert = BertEncoder(embeddings, num_layers=model_opt.layers, + d_model=model_opt.word_vec_size, heads=model_opt.heads, + d_ff=model_opt.transformer_ff, dropout=model_opt.dropout[0], + max_relative_positions=model_opt.max_relative_positions) return bert +def build_bert_generator(model_opt, fields, bert_encoder): + """Main part for transfer learning: + set opt.task_type to `pretraining` if 
want finetuning; + set opt.task_type to `classification` if want use Bert to classification task; + set opt.task_type to `generation` if want use Bert to generate tokens. + Both all_encoder_layers and pooled_output will be feed to generator, + pretraining task will use the two, + while only pooled_output will be used for classification generator; + only all_encoder_layers will be used for generation generator; + """ + task = model_opt.task_type + if task == 'pretraining': + generator = BertPreTrainingHeads(bert_encoder.d_model, + bert_encoder.embeddings.vocab_size) + if model_opt.reuse_embeddings: + generator.mask_lm.decode.weight = bert_encoder.embeddings.word_embeddings.weight + elif task == 'generation': + generator = TokenGenerationHead(bert_encoder.d_model, + bert_encoder.vocab_size) + if model_opt.reuse_embeddings: + generator.decode.weight = bert_encoder.embeddings.word_embeddings.weight + elif task == 'classification': + n_class = model_opt.classification + generator = ClassificationHead(bert_encoder.d_model, n_class) + return generator + + def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): """Build a model from opts. @@ -342,7 +284,7 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): fields (dict[str, torchtext.data.Field]): `Field` objects for the model. gpu (bool): whether to use gpu. - checkpoint: the model gnerated by train phase, or a resumed snapshot + checkpoint: the model generated by train phase, or a resumed snapshot model from a stopped training. gpu_id (int or NoneType): Which GPU to use. @@ -354,10 +296,8 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): bert_emb = build_bert_embeddings(model_opt, fields) # Build encoder. - bert_encoder = build_bert_encoder_v2(model_opt, fields, bert_emb) + bert_encoder = build_bert_encoder(model_opt, fields, bert_emb) - - # Build NMTModel(= encoder + decoder). gpu = use_gpu(opt) if gpu and gpu_id is not None: device = torch.device("cuda", gpu_id) @@ -366,45 +306,39 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): elif not gpu: device = torch.device("cpu") - """Main part for transfer learning: - set opt.task to `pretraining` if want finetuning; - set opt.task to `classification` if want use Bert to classification task; - set opt.task to `generation` if want use Bert to generate a sequence. - The pooled_output from bert encoder will be feed to classification generator; - The all_encoder_layers from bert encoder will be feed to generation generator; - """ # Build Generator. - # if model_opt.task == 'pretraining': - generator = BertPreTrainingHeads(bert_encoder.d_model, bert_encoder.vocab_size, - bert_encoder.embeddings.word_embeddings.weight) - # if model_opt.share_embeddings: - # generator.mask_lm.decode.weight = bert_emb.word_embeddings.weight - # if model_opt.task == 'classification': - # generator = nn.Sequential(nn.Linear(bert_encoder.d_model, opt.num_labels), - # nn.LogSoftmax(dim=-1)) - # if model_opt.task == 'generation': - # generator = MaskedLanguageModel(bert_encoder.d_model, bert_encoder.vocab_size, - # bert_encoder.embeddings.word_embeddings.weight) - # if model_opt.share_embeddings: - # generator.mask_lm.decode.weight = bert_emb.word_embeddings.weight + generator = build_bert_generator(model_opt, fields, bert_encoder) + # Build Bert Model(= encoder + generator). 
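Before the encoder and generator are wrapped together below, a standalone sketch of the reuse_embeddings weight tying referred to above (sizes are hypothetical; in the real model the tied matrix is bert_encoder.embeddings.word_embeddings.weight):

import torch.nn as nn

vocab_size, d_model = 1000, 64
word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=0)
decode = nn.Linear(d_model, vocab_size, bias=False)      # the MLM output projection
decode.weight = word_embeddings.weight                   # tie: both now share one Parameter
assert decode.weight.data_ptr() == word_embeddings.weight.data_ptr()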
model = nn.Sequential(OrderedDict([ ('bert', bert_encoder), - ('cls', generator)])) + ('generator', generator)])) # Load the model states from checkpoint or initialize them. if checkpoint is not None: - model.load_state_dict(checkpoint['model'], strict=False) - # generator.load_state_dict(checkpoint['generator'], strict=False) + assert 'model' in checkpoint + logger.info("Load Model Parameters...") + model.bert.load_state_dict(checkpoint['model'], strict=False) + if model_opt.task_type == 'pretraining': + logger.info("Load generator Parameters...") + model.generator.load_state_dict(checkpoint['generator'], strict=False) + else: + logger.info("Initialize generator Parameters...") + for p in model.generator.parameters(): + if p.dim() > 1: + xavier_uniform_(p) else: logger.info("No checkpoint, Initialize Parameters...") if model_opt.param_init_normal != 0.0: + logger.info('Initialize weights using a normal distribution') normal_std = model_opt.param_init_normal for p in model.parameters(): p.data.normal_(mean=0, std=normal_std) elif model_opt.param_init != 0.0: + logger.info('Initialize weights using a uniform distribution') for p in model.parameters(): p.data.uniform_(-model_opt.param_init, model_opt.param_init) elif model_opt.param_init_glorot: + logger.info('Glorot initialization') for p in model.parameters(): if p.dim() > 1: xavier_uniform_(p) diff --git a/onmt/models/__init__.py b/onmt/models/__init__.py index 09a54a1e33..0185af2b46 100644 --- a/onmt/models/__init__.py +++ b/onmt/models/__init__.py @@ -1,9 +1,9 @@ """Module defining models.""" from onmt.models.model_saver import build_model_saver, ModelSaver from onmt.models.model import NMTModel -from onmt.models.bert import BERT, BertLayerNorm -from onmt.models.language_model import BertLM, BertPreTrainingHeads +from onmt.models.bert_generators import BertPreTrainingHeads,\ + ClassificationHead, TokenGenerationHead -__all__ = ["build_model_saver", "ModelSaver", "NMTModel", "BERT", - "BertLM", "BertLayerNorm", "BertPreTrainingHeads", - "check_sru_requirement"] +__all__ = ["build_model_saver", "ModelSaver", "NMTModel", + "BertPreTrainingHeads", "ClassificationHead", + "TokenGenerationHead" ,"check_sru_requirement"] diff --git a/onmt/models/language_model.py b/onmt/models/bert_generators.py similarity index 57% rename from onmt/models/language_model.py rename to onmt/models/bert_generators.py index 20b1175791..83b331a880 100644 --- a/onmt/models/language_model.py +++ b/onmt/models/bert_generators.py @@ -2,51 +2,17 @@ import torch.nn as nn import onmt -from onmt.utils.fn_activation import GELU - - -class BertLM(nn.Module): - """ - BERT Language Model for pretraining, trained with 2 task : - Next Sentence Prediction Model + Masked Language Model - """ - def __init__(self, bert): - """ - Args: - bert: BERT model which should be trained - """ - super(BertLM, self).__init__() - self.bert = bert - self.vocab_size = bert.vocab_size - self.cls = BertPreTrainingHeads(self.bert.d_model, self.vocab_size, - self.bert.embeddings.word_embeddings.weight) - - def forward(self, input_ids, token_type_ids, input_mask=None, - output_all_encoded_layers=False): - """ - Args: - input_ids: shape [batch, seq] padding ids=0 - token_type_ids: shape [batch, seq], A(0), B(1), pad(0) - input_mask: shape [batch, seq], 1 for masked position(that padding) - Returns: - seq_class_log_prob: next sentence predi, (batch, 2) - prediction_log_prob: masked lm predi, (batch, seq, vocab) - """ - x, pooled_out = self.bert(input_ids, token_type_ids, input_mask, - 
output_all_encoded_layers) - seq_class_log_prob, prediction_log_prob = self.cls(x, pooled_out) - return seq_class_log_prob, prediction_log_prob +from onmt.utils import get_activation_fn class BertPreTrainingHeads(nn.Module): """ Bert Pretraining Heads: Masked Language Models, Next Sentence Prediction """ - def __init__(self, hidden_size, vocab_size, embedding_weights): + def __init__(self, hidden_size, vocab_size): super(BertPreTrainingHeads, self).__init__() self.next_sentence = NextSentencePrediction(hidden_size) - self.mask_lm = MaskedLanguageModel(hidden_size, vocab_size, - embedding_weights) + self.mask_lm = MaskedLanguageModel(hidden_size, vocab_size) def forward(self, x, pooled_out): """ @@ -68,8 +34,7 @@ class MaskedLanguageModel(nn.Module): n-class classification problem, n-class = vocab_size """ - def __init__(self, hidden_size, vocab_size, - bert_word_embedding_weights=None): + def __init__(self, hidden_size, vocab_size): """ Args: hidden_size: output size of BERT model @@ -78,21 +43,10 @@ def __init__(self, hidden_size, vocab_size, """ super(MaskedLanguageModel, self).__init__() self.transform = BertPredictionTransform(hidden_size) - self.reuse_emb = (True - if bert_word_embedding_weights is not None - else False) - if self.reuse_emb: # NOTE: reinit ? - assert hidden_size == bert_word_embedding_weights.size(1) - assert vocab_size == bert_word_embedding_weights.size(0) - self.decode = nn.Linear(bert_word_embedding_weights.size(1), - bert_word_embedding_weights.size(0), - bias=False) - self.decode.weight = bert_word_embedding_weights - self.bias = nn.Parameter(torch.zeros(vocab_size)) - else: - self.decode = nn.Linear(hidden_size, vocab_size, bias=False) - self.bias = nn.Parameter(torch.zeros(vocab_size)) - + + self.decode = nn.Linear(hidden_size, vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(vocab_size)) + self.softmax = nn.LogSoftmax(dim=-1) def forward(self, x): @@ -138,8 +92,8 @@ class BertPredictionTransform(nn.Module): def __init__(self, hidden_size): super(BertPredictionTransform, self).__init__() self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = GELU() # get_activation fn - self.layer_norm = onmt.models.BertLayerNorm(hidden_size, eps=1e-12) + self.activation = get_activation_fn('gelu') #GELU() # get_activation fn + self.layer_norm = onmt.encoders.BertLayerNorm(hidden_size, eps=1e-12) def forward(self, hidden_states): """ @@ -149,3 +103,64 @@ def forward(self, hidden_states): hidden_states = self.layer_norm(self.activation( self.dense(hidden_states))) return hidden_states + + +class ClassificationHead(nn.Module): + """ + n-class classification head + """ + + def __init__(self, hidden_size, n_class): + """ + Args: + hidden_size: BERT model output size + """ + super(ClassificationHead, self).__init__() + self.linear = nn.Linear(hidden_size, n_class) + self.softmax = nn.LogSoftmax(dim=-1) + + def forward(self, all_hidden, pooled): + """ + Args: + all_hidden: first output argument of Bert encoder (batch, src, d_model) + pooled: last layer's output of bert encoder, shape (batch, src, d_model) + Returns: + class_log_prob: shape (batch_size, 2) + """ + score = self.linear(pooled) # (batch, n_class) + class_log_prob = self.softmax(score) # (batch, n_class) + return class_log_prob + + +class TokenGenerationHead(nn.Module): + """ + Token generation head: generation token from input sequence + """ + + def __init__(self, hidden_size, vocab_size): + """ + Args: + hidden_size: output size of BERT model + vocab_size: total vocab size + 
bert_word_embedding_weights: reuse embedding weights if set + """ + super(TokenGenerationHead, self).__init__() + self.transform = BertPredictionTransform(hidden_size) + + self.decode = nn.Linear(hidden_size, vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(vocab_size)) + + self.softmax = nn.LogSoftmax(dim=-1) + + def forward(self, x, pooled): + """ + Args: + x: last layer output of bert, shape (batch, seq, d_model) + Returns: + prediction_log_prob: shape (batch, seq, vocab) + """ + last_hidden = x[-1] + y = self.transform(last_hidden) # (batch, seq, d_model) + prediction_scores = self.decode(y) + self.bias # (batch, seq, vocab) + prediction_log_prob = self.softmax(prediction_scores) + return prediction_log_prob diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py index b76e0a010e..befc5a4b5e 100644 --- a/onmt/models/model_saver.py +++ b/onmt/models/model_saver.py @@ -9,14 +9,6 @@ def build_model_saver(model_opt, opt, model, fields, optim): - # if opt.is_bert: - # model_saver = BertModelSaver(opt.save_model, - # model, - # model_opt, - # fields, - # optim, - # opt.keep_checkpoint) - # else: model_saver = ModelSaver(opt.save_model, model, model_opt, @@ -108,16 +100,9 @@ def _save(self, step, model): real_model = (model.module if isinstance(model, nn.DataParallel) else model) - if hasattr(real_model, "generator"): - print('NMT generator saving') - real_generator = (real_model.generator.module - if isinstance(real_model.generator, nn.DataParallel) - else real_model.generator) - if hasattr(real_model, "cls"): - print('BERT generator saving') - real_generator = (real_model.cls.module - if isinstance(real_model.cls, nn.DataParallel) - else real_model.cls) + real_generator = (real_model.generator.module + if isinstance(real_model.generator, nn.DataParallel) + else real_model.generator) model_state_dict = real_model.state_dict() model_state_dict = {k: v for k, v in model_state_dict.items() @@ -151,16 +136,6 @@ def _save(self, step, model): for key in keys_to_pop: field.fields[0][1].vocab.stoi.pop(key, None) - # for side in ["src", "tgt"]: - # keys_to_pop = [] - # if hasattr(vocab[side], "fields"): - # unk_token = vocab[side].fields[0][1].vocab.itos[0] - # for key, value in vocab[side].fields[0][1].vocab.stoi.items(): - # if value == 0 and key != unk_token: - # keys_to_pop.append(key) - # for key in keys_to_pop: - # vocab[side].fields[0][1].vocab.stoi.pop(key, None) - checkpoint = { 'model': model_state_dict, 'generator': generator_state_dict, @@ -176,58 +151,3 @@ def _save(self, step, model): def _rm_checkpoint(self, name): os.remove(name) - - -class BertModelSaver(ModelSaverBase): - """Simple model saver to filesystem""" - - def _save(self, step, model): - real_model = (model.module - if isinstance(model, nn.DataParallel) - else model) - # real_generator = (real_model.generator.module - # if isinstance(real_model.generator, nn.DataParallel) - # else real_model.generator) - - model_state_dict = real_model.state_dict() - model_state_dict = {k: v for k, v in model_state_dict.items() - if 'generator' not in k} - # generator_state_dict = real_generator.state_dict() - - # NOTE: We need to trim the vocab to remove any unk tokens that - # were not originally here. 
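To make the classification path concrete, a minimal standalone equivalent of the ClassificationHead defined above: the pooled [CLS] representation goes through a linear layer and a log-softmax (the sizes are illustrative only):

import torch
import torch.nn as nn

d_model, n_class, batch = 768, 3, 4
head = nn.Sequential(nn.Linear(d_model, n_class), nn.LogSoftmax(dim=-1))
pooled = torch.randn(batch, d_model)      # stand-in for the encoder's pooled output
class_log_prob = head(pooled)             # (batch, n_class) log-probabilities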
- - vocab = deepcopy(self.fields) - for side in ["tokens"]: - keys_to_pop = [] - # if hasattr(vocab[side], "fields"): - # unk_token = vocab[side].fields[0][1].vocab.itos[0] - # for key, value in vocab[side].fields[0][1] - # .vocab.stoi.items(): - # if value == 0 and key != unk_token: - # keys_to_pop.append(key) - # for key in keys_to_pop: - # vocab[side].fields[0][1].vocab.stoi.pop(key, None) - unk_token = vocab[side].unk_token - unk_id = vocab[side].vocab.stoi[unk_token] - for key, value in vocab[side].vocab.stoi.items(): - if value == unk_id and key != unk_token: - keys_to_pop.append(key) - for key in keys_to_pop: - vocab[side].vocab.stoi.pop(key, None) - - checkpoint = { - 'model': model_state_dict, - # 'generator': generator_state_dict, - 'vocab': vocab, - 'opt': self.model_opt, - 'optim': self.optim.state_dict(), - } - - logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) - checkpoint_path = '%s_step_%d.pt' % (self.base_path, step) - torch.save(checkpoint, checkpoint_path) - return checkpoint, checkpoint_path - - def _rm_checkpoint(self, name): - os.remove(name) diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index 36f5a9f7d6..7a68794f46 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -3,7 +3,7 @@ import torch.nn as nn import onmt -from onmt.utils.fn_activation import GELU +from onmt.utils import get_activation_fn class PositionwiseFeedForward(nn.Module): @@ -21,15 +21,15 @@ class PositionwiseFeedForward(nn.Module): """ def __init__(self, d_model, d_ff, dropout=0.1, - activation='ReLU', is_bert=False): + activation='relu', is_bert=False): super(PositionwiseFeedForward, self).__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) - self.layer_norm = (onmt.models.BertLayerNorm(d_model, eps=1e-12) + self.layer_norm = (onmt.encoders.BertLayerNorm(d_model, eps=1e-12) if is_bert else nn.LayerNorm(d_model, eps=1e-6)) self.dropout_1 = nn.Dropout(dropout) - self.relu = GELU() if activation == 'GeLU' else nn.ReLU() + self.activation = get_activation_fn(activation) self.dropout_2 = nn.Dropout(dropout) self.is_bert = is_bert @@ -47,7 +47,7 @@ def forward(self, x): (FloatTensor): Output ``(batch_size, input_len, model_dim)``. """ - inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) + inter = self.dropout_1(self.activation(self.w_1(self.layer_norm(x)))) output = self.dropout_2(self.w_2(inter)) return self.residual(output, x) diff --git a/onmt/train_single.py b/onmt/train_single.py index a9413a963a..59c5f7e95f 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -6,7 +6,7 @@ from onmt.inputters.inputter import build_dataset_iter, \ load_old_vocab, old_style_vocab, build_dataset_iter_multiple -from onmt.model_builder import build_model, build_bert, build_bert_model +from onmt.model_builder import build_model, build_bert_model from onmt.utils.optimizers import Optimizer from onmt.utils.misc import set_random_seed from onmt.trainer import build_trainer @@ -51,13 +51,18 @@ def main(opt, device_id, batch_queue=None, semaphore=None): logger.info('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) - # model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) - model_opt = opt # TODO: test - # ArgumentParser.update_model_opts(model_opt) # TODO - # ArgumentParser.validate_model_opts(model_opt) # TODO - logger.info('Loading vocab from checkpoint at %s.' 
% opt.train_from) - # vocab = checkpoint['vocab'] - vocab = torch.load(opt.data + '.vocab.pt') # TODO + if 'opt' in checkpoint: + model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) + ArgumentParser.update_model_opts(model_opt) + ArgumentParser.validate_model_opts(model_opt) + else: + model_opt = opt + + if 'vocab' in checkpoint: + logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) + vocab = checkpoint['vocab'] + else: + vocab = torch.load(opt.data + '.vocab.pt') else: checkpoint = None model_opt = opt @@ -92,7 +97,6 @@ def main(opt, device_id, batch_queue=None, semaphore=None): # Build model. if opt.is_bert: - # model = build_bert(model_opt, opt, fields, checkpoint) # V1 model = build_bert_model(model_opt, opt, fields, checkpoint) # V2 n_params = 0 for param in model.parameters(): @@ -106,8 +110,8 @@ def main(opt, device_id, batch_queue=None, semaphore=None): logger.info('* number of parameters: %d' % n_params) _check_save_model_path(opt) - # Build optimizer. # TODO: checkpoint=checkpoint # DEBUG - optim = Optimizer.from_opt(model, opt, checkpoint=None) + # Build optimizer. + optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint) # Build model saver model_saver = build_model_saver(model_opt, opt, model, fields, optim) diff --git a/onmt/trainer.py b/onmt/trainer.py index 2b109b49e4..9ecef71f38 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -131,8 +131,6 @@ def __init__(self, model, train_loss, valid_loss, optim, self.dropout = dropout self.dropout_steps = dropout_steps self.is_bert = True if hasattr(self.model, 'bert') else False - # self.is_bert = True if isinstance(self.model, - # onmt.models.language_model.BertLM) else False # NOTE: NEW parameter for bert training for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 @@ -335,7 +333,7 @@ def validate(self, valid_iter, moving_average=None): # F-prop through the model. 
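The hunks below standardize on a two-step call: the encoder is invoked first, then its pretraining head (renamed from `cls` to `generator`). A minimal sketch of that contract, assuming the `nn.Sequential` model assembled in `onmt/model_builder.py` and the shapes noted in the surrounding comments:

# Sketch only -- call pattern the trainer relies on; `model` is assumed to be
# nn.Sequential(OrderedDict([('bert', BertEncoder(...)), ('generator', BertPreTrainingHeads(...))])).
all_encoder_layers, pooled_out = model.bert(input_ids, token_type_ids)
seq_class_log_prob, prediction_log_prob = model.generator(all_encoder_layers, pooled_out)
# seq_class_log_prob:  (batch_size, 2)                 -- next-sentence prediction
# prediction_log_prob: (batch_size, seq_len, vocab)    -- masked language model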
# NOTE: keyword args: input_mask, output_all_encoded_layers # Version 2: all_encoder_layers, pooled_out = valid_model.bert(input_ids, token_type_ids) - seq_class_log_prob, prediction_log_prob = valid_model.cls(all_encoder_layers, pooled_out) + seq_class_log_prob, prediction_log_prob = valid_model.generator(all_encoder_layers, pooled_out) # Version 1: # seq_class_log_prob, prediction_log_prob = valid_model(input_ids, token_type_ids) # TODO: Heads @@ -535,7 +533,7 @@ def _bert_gradient_accumulation(self, true_batches, total_stats, report_stats): self.optim.zero_grad() # Version 2: all_encoder_layers, pooled_out = self.model.bert(input_ids, token_type_ids) - seq_class_log_prob, prediction_log_prob = self.model.cls(all_encoder_layers, pooled_out) + seq_class_log_prob, prediction_log_prob = self.model.generator(all_encoder_layers, pooled_out) # Version 1: # seq_class_log_prob, prediction_log_prob = self.model(input_ids, token_type_ids) # NOTE: (batch_size, 2), (batch_size, seq_size, vocab_size) diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py index a1f422333a..8dcc2ffaee 100644 --- a/onmt/utils/__init__.py +++ b/onmt/utils/__init__.py @@ -6,9 +6,9 @@ from onmt.utils.optimizers import MultipleOptimizer, \ Optimizer, AdaFactor, BertAdam from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts - +from onmt.utils.fn_activation import get_activation_fn __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", "build_report_manager", "Statistics", "BertStatistics", "MultipleOptimizer", "Optimizer", "AdaFactor", "BertAdam", - "EarlyStopping", "scorers_from_opts"] + "EarlyStopping", "scorers_from_opts", "get_activation_fn"] diff --git a/onmt/utils/fn_activation.py b/onmt/utils/fn_activation.py index ff9f60b1f7..ec1d9336a7 100644 --- a/onmt/utils/fn_activation.py +++ b/onmt/utils/fn_activation.py @@ -3,6 +3,19 @@ import math +def get_activation_fn(activation): + if activation is 'gelu': + fn = GELU() + elif activation is 'relu': + fn = nn.ReLU() + elif activation is 'tanh': + fn = nn.Tanh() + else: + raise ValueError("Please pass a valid \ + activation function") + return fn + + """ Adapted from huggingface implementation to reproduce the result https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index 3d2ed3bd7b..07787b9037 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -353,7 +353,8 @@ def from_opt(cls, model, opt, checkpoint=None): optim_opt = opt optim_state_dict = None - if opt.train_from and checkpoint is not None: + if opt.train_from and checkpoint is not None \ + and 'optim' in checkpoint: optim = checkpoint['optim'] ckpt_opt = checkpoint['opt'] ckpt_state_dict = {} diff --git a/train.py b/train.py index 2b0cd86466..2b56b4c052 100755 --- a/train.py +++ b/train.py @@ -28,9 +28,11 @@ def main(opt): logger.info('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) - logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) - # vocab = checkpoint['vocab'] TODO:test - vocab = torch.load(opt.data + '.vocab.pt') + if 'vocab' in checkpoint: + logger.info('Loading vocab from checkpoint at %s.' 
% opt.train_from) + vocab = checkpoint['vocab'] + else: + vocab = torch.load(opt.data + '.vocab.pt') else: vocab = torch.load(opt.data + '.vocab.pt') From 8c3436f4a313a858c164fd0f98762854efeb10ee Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 23 Jul 2019 16:23:47 +0200 Subject: [PATCH 08/28] add downsteam task support --- onmt/inputters/inputter.py | 32 ++++--- onmt/models/bert_generators.py | 6 +- onmt/utils/loss.py | 160 +++++++++++++++++++-------------- onmt/utils/statistics.py | 80 ++++++++++++----- 4 files changed, 176 insertions(+), 102 deletions(-) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index ee7ec86e55..1173b50e88 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -137,7 +137,8 @@ def get_fields( return fields -def get_bert_fields(pad='[PAD]', bos='[CLS]', eos='[SEP]', unk='[UNK]'): +def get_bert_fields(task='pretraining', pad='[PAD]', bos='[CLS]', + eos='[SEP]', unk='[UNK]'): fields = {} # tokens_kwargs = {"n_feats": 0, # "include_lengths": True, @@ -152,21 +153,28 @@ def get_bert_fields(pad='[PAD]', bos='[CLS]', eos='[SEP]', unk='[UNK]'): segment_ids = Field(use_vocab=False, dtype=torch.long, sequential=True, pad_token=0, batch_first=True) fields["segment_ids"] = segment_ids + if task == 'pretraining': + is_next = Field(use_vocab=False, dtype=torch.long, + sequential=False, batch_first=True) # 0/1 + fields["is_next"] = is_next - is_next = Field(use_vocab=False, dtype=torch.long, - sequential=False, batch_first=True) # 0/1 - fields["is_next"] = is_next + lm_labels_ids = Field(sequential=True, use_vocab=False, + pad_token=-1, batch_first=True) + fields["lm_labels_ids"] = lm_labels_ids - # masked_lm_positions = Field(use_vocab=False, dtype=torch.int, - # sequential=False) # indices that masked: [int] - # fields["masked_lm_positions"] = masked_lm_positions + elif task == 'classification': + category = Field(use_vocab=False, dtype=torch.long, + sequential=False, batch_first=True) # 0/1 + fields["category"] = category - # masked_lm_labels = Field(use_vocab=True, sequential=False)# tokens masked - # fields["masked_lm_labels"] = masked_lm_labels + elif task == 'generation': + token_labels_ids = Field(sequential=True, use_vocab=False, + pad_token=-1, batch_first=True) + fields["token_labels_ids"] = token_labels_ids + + else: + raise ValueError("task '{}' has not been implemented yet!".format(task)) - lm_labels_ids = Field(sequential=True, use_vocab=False, - pad_token=-1, batch_first=True) - fields["lm_labels_ids"] = lm_labels_ids return fields diff --git a/onmt/models/bert_generators.py b/onmt/models/bert_generators.py index 83b331a880..8ddf182657 100644 --- a/onmt/models/bert_generators.py +++ b/onmt/models/bert_generators.py @@ -126,10 +126,11 @@ def forward(self, all_hidden, pooled): pooled: last layer's output of bert encoder, shape (batch, src, d_model) Returns: class_log_prob: shape (batch_size, 2) + None: this is a placeholder for token level prediction task """ score = self.linear(pooled) # (batch, n_class) class_log_prob = self.softmax(score) # (batch, n_class) - return class_log_prob + return class_log_prob, None class TokenGenerationHead(nn.Module): @@ -157,10 +158,11 @@ def forward(self, x, pooled): Args: x: last layer output of bert, shape (batch, seq, d_model) Returns: + None: this is a placeholder for sentence level task prediction_log_prob: shape (batch, seq, vocab) """ last_hidden = x[-1] y = self.transform(last_hidden) # (batch, seq, d_model) prediction_scores = self.decode(y) + self.bias # (batch, seq, 
vocab) prediction_log_prob = self.softmax(prediction_scores) - return prediction_log_prob + return None, prediction_log_prob diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 1e2e6d48e9..1de1691e65 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -27,7 +27,8 @@ def build_loss_compute(model, tgt_field, opt, train=True): assert tgt_field is None # BERT use -1 for unmasked token in lm_label_ids criterion = nn.NLLLoss(ignore_index=-1, reduction='mean') - compute = BertLoss(criterion) + task = opt.task_type + compute = BertLoss(criterion, task) else: assert isinstance(model, onmt.models.NMTModel) padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] @@ -62,22 +63,11 @@ def build_loss_compute(model, tgt_field, opt, train=True): return compute -# def build_bert_loss_compute(opt, train=True): -# """FOR BERT PRETRAINING. -# Returns a LossCompute subclass which wraps around an nn.Module subclass -# (such as nn.NLLLoss) which defines the loss criterion. -# """ -# device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") -# # BERT use -1 for unmasked token in lm_label_ids -# criterion = nn.NLLLoss(ignore_index=-1, reduction='mean') -# compute = BertLoss(criterion).to(device) -# return compute - - class BertLoss(nn.Module): - def __init__(self, criterion): + def __init__(self, criterion, task): super(BertLoss, self).__init__() self.criterion = criterion + self.task =task @property def padding_idx(self): @@ -86,53 +76,62 @@ def padding_idx(self): def _bottle(self, _v): return _v.view(-1, _v.size(2)) - def _stats(self, loss, mlm_scores, mlm_target, - nx_sent_scores, nx_sent_target): + def _stats(self, loss, tokens_scores, tokens_target, + sents_scores, sents_target): """ Args: loss (:obj:`FloatTensor`): the loss computed by the loss criterion. - scores (:obj:`FloatTensor`): a score for each possible output - target (:obj:`FloatTensor`): true targets + tokens_scores (:obj:`FloatTensor`): scores for each token + tokens_target (:obj:`FloatTensor`): true targets for each token + sents_scores (:obj:`FloatTensor`): scores for each sentence + sents_target (:obj:`FloatTensor`): true targets for each sentence Returns: :obj:`onmt.utils.Statistics` : statistics for this batch. 
""" - # masked lm task - pred_mlm = mlm_scores.argmax(1) # (batch*seq, vocab) -> (batch*seq) - non_padding = mlm_target.ne(self.padding_idx) # mask: (batch*seq) - mlm_match = pred_mlm.eq(mlm_target).masked_select(non_padding) - num_correct = mlm_match.sum().item() - num_non_padding = non_padding.sum().item() - - # next sentence prediction task - pred_nx_sent = nx_sent_scores.argmax(-1) # (batch_size, 2) -> (2) - num_correct_nx_sent = nx_sent_target.eq(pred_nx_sent).sum().item() - num_sentence = len(nx_sent_target) - # print("lm: {}/{}".format(num_correct, num_non_padding)) - # print("nx: {}/{}".format(num_correct_nx_sent, num_sentence)) - return onmt.utils.BertStatistics(loss.item(), num_non_padding, - num_correct, num_sentence, - num_correct_nx_sent) - - # TODO: currently not support trunc_size & shard_size - # def _make_shard_state(self, batch, output): - # return { - # "output": output, - # "target": batch.tgt[range_[0] + 1: range_[1], :, 0], - # } - - # def _compute_loss(self, batch, output, target): - # bottled_output = self._bottle(output) - - # scores = self.generator(bottled_output) - # gtruth = target.view(-1) + if self.task == 'pretraining': + # masked lm task: token level + pred_tokens = tokens_scores.argmax(1) # (batch*seq, vocab) -> (batch*seq) + non_padding = tokens_target.ne(self.padding_idx) # mask: (batch*seq) + tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) + n_correct_tokens = tokens_match.sum().item() + n_tokens = non_padding.sum().item() + + # next sentence prediction task: sentence level + pred_sents = sents_scores.argmax(-1) # (batch_size, 2) -> (2) + n_correct_sents = sents_target.eq(pred_sents).sum().item() + n_sentences = len(sents_target) + + elif self.task == 'classification': + # token level task: Not valide + n_correct_tokens = 0 + n_tokens = 0 + # sentence level task: + pred_sents = sents_scores.argmax(-1) # (batch_size, n_label) -> (n_label) + n_correct_sents = sents_target.eq(pred_sents).sum().item() + n_sentences = len(sents_target) + + elif self.task == 'generation': + # token level task: + pred_tokens = tokens_scores.argmax(1) # (batch*seq, vocab) -> (batch*seq) + non_padding = tokens_target.ne(self.padding_idx) # mask: (batch*seq) + tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) + n_correct_tokens = tokens_match.sum().item() + n_tokens = non_padding.sum().item() + # sentence level task: Not valide + n_correct_sents = 0 + n_sentences = 0 + else: + raise ValueError("task '{}' has not been implemented yet!".format(self.task)) - # loss = self.criterion(scores, gtruth) - # stats = self._stats(loss.clone(), scores, gtruth) + # print("lm: {}/{}".format(n_correct_tokens, n_tokens)) + # print("nx: {}/{}".format(n_correct_sents, n_sentences)) + return onmt.utils.BertStatistics(loss.item(), n_tokens, + n_correct_tokens, n_sentences, + n_correct_sents) - # return loss, stats - def forward(self, batch, outputs, normalization=1.0): # TODO: shard=0 + def forward(self, batch, outputs): """ Args: batch: batch of examples @@ -142,28 +141,55 @@ def forward(self, batch, outputs, normalization=1.0): # TODO: shard=0 """ assert isinstance(outputs, tuple) seq_class_log_prob, prediction_log_prob = outputs - assert list(seq_class_log_prob.size()) == [len(batch), 2] - - gtruth_next_sentence = batch.is_next # (batch,) - gtruth_masked_lm = batch.lm_labels_ids # (batch, seq) - # (batch, seq, vocab) -> (batch * seq, vocab) - bottled_prediction_log_prob = self._bottle(prediction_log_prob) - bottled_gtruth_masked_lm = 
gtruth_masked_lm.view(-1) # (batch * seq) - # loss mean by number of sentence - next_loss = self.criterion(seq_class_log_prob, gtruth_next_sentence) - # loss mean by number of masked token - mask_loss = self.criterion(bottled_prediction_log_prob, - bottled_gtruth_masked_lm) - total_loss = next_loss + mask_loss # total_loss reduced by mean + if self.task == 'pretraining': + assert list(seq_class_log_prob.size()) == [len(batch), 2] + # masked lm task: token level(loss mean by number of tokens) + # targets: + gtruth_tokens = batch.lm_labels_ids # (batch, seq) + bottled_gtruth_tokens = gtruth_tokens.view(-1) # (batch * seq) + # prediction: (batch, seq, vocab) -> (batch * seq, vocab) + bottled_prediction_log_prob = self._bottle(prediction_log_prob) + mask_loss = self.criterion(bottled_prediction_log_prob, + bottled_gtruth_tokens) + # next sentence prediction task: sentence level(mean by sentence) + gtruth_sentences = batch.is_next # (batch,) + next_loss = self.criterion(seq_class_log_prob, gtruth_sentences) + total_loss = next_loss + mask_loss # total_loss reduced by mean + + elif self.task == 'classification': + assert prediction_log_prob is None + assert hasattr(batch, 'category') + # token level task: Not valide + bottled_prediction_log_prob = None + bottled_gtruth_tokens = None + # sentence level task: loss mean by number of sentences + gtruth_sentences = batch.category + total_loss = self.criterion(seq_class_log_prob, gtruth_sentences) + + elif self.task == 'generation': + assert seq_class_log_prob is None + assert hasattr(batch, 'token_labels_ids') + # token level task: loss mean by number of tokens + gtruth_tokens = batch.token_labels_ids # (batch, seq) + bottled_gtruth_tokens = gtruth_tokens.view(-1) # (batch * seq) + # prediction: (batch, seq, vocab) -> (batch * seq, vocab) + bottled_prediction_log_prob = self._bottle(prediction_log_prob) + total_loss = self.criterion(bottled_prediction_log_prob, + bottled_gtruth_tokens) + # sentence level task: Not valide + seq_class_log_prob = None + gtruth_sentences = None + else: + raise ValueError("task '{}' has not been implemented yet!".format(self.task)) # loss_accum_normalized = total_loss #/ float(normalization) # print("loss: ({} + {})/{} = {}".format(next_loss, mask_loss, # float(normalization), loss_accum_normalized)) - # print("nx: {}/{}".format(num_correct_nx_sent, num_sentence)) + # print("nx: {}/{}".format(n_correct_sents, n_sentences)) stats = self._stats(total_loss.clone(), bottled_prediction_log_prob, - bottled_gtruth_masked_lm, + bottled_gtruth_tokens, seq_class_log_prob, - gtruth_next_sentence) + gtruth_sentences) return total_loss, stats diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 8d60c2fc3d..1afd976e22 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -139,15 +139,25 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): class BertStatistics(Statistics): """ Bert Statistics as the loss is reduced by mean """ def __init__(self, loss=0, n_words=0, n_correct=0, - n_sentence=0, n_correct_nx_sentence=0): + n_sentence=0, n_correct_sentence=0): super(BertStatistics, self).__init__(loss, n_words, n_correct) self.n_update = 0 if n_words == 0 else 1 self.n_sentence = n_sentence - self.n_correct_nx_sentence = n_correct_nx_sentence + self.n_correct_sentence = n_correct_sentence - def next_sentence_accuracy(self): - """ compute accuracy """ - return 100 * (self.n_correct_nx_sentence / self.n_sentence) + def accuracy(self): + """ compute token level accuracy """ + if self.n_words != 
0: + return 100 * (self.n_correct / self.n_words) + else: + return None + + def sentence_accuracy(self): + """ compute sentence level accuracy """ + if self.n_sentence != 0: + return 100 * (self.n_correct_sentence / self.n_sentence) + else: + return None def xent(self): """ compute cross entropy """ @@ -176,7 +186,7 @@ def update(self, stat, update_n_src_words=False): self.n_words += stat.n_words self.n_correct += stat.n_correct self.n_sentence += stat.n_sentence - self.n_correct_nx_sentence += stat.n_correct_nx_sentence + self.n_correct_sentence += stat.n_correct_sentence if update_n_src_words: self.n_src_words += stat.n_src_words @@ -193,18 +203,43 @@ def output(self, step, num_steps, learning_rate, start): step_fmt = "%2d" % step if num_steps > 0: step_fmt = "%s/%5d" % (step_fmt, num_steps) - logger.info( - ("Step %s; acc(mlm/nx):%6.2f/%6.2f; total ppl: %5.2f; " + - "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") - % (step_fmt, - self.accuracy(), - self.next_sentence_accuracy(), - self.ppl(), - self.xent(), - learning_rate, - self.n_src_words / (t + 1e-5), - self.n_words / (t + 1e-5), - time.time() - start)) + if self.n_words == 0: # sentence level task + logger.info( + ("Step %s; acc(sent):%6.2f; ppl: %5.2f; " + + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + % (step_fmt, + self.sentence_accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) + elif self.n_sentence == 0: # token level task + logger.info( + ("Step %s; acc(token):%6.2f; ppl: %5.2f; " + + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + % (step_fmt, + self.accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) + else: # pretraining + logger.info( + ("Step %s; acc(mlm/nx):%6.2f/%6.2f; total ppl: %5.2f; " + + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + % (step_fmt, + self.accuracy(), + self.sentence_accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) sys.stdout.flush() def log_tensorboard(self, prefix, writer, learning_rate, step): @@ -212,8 +247,11 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): t = self.elapsed_time() writer.add_scalar(prefix + "/xent", self.xent(), step) writer.add_scalar(prefix + "/ppl", self.ppl(), step) - writer.add_scalar(prefix + "/accuracy(mlm)", self.accuracy(), step) - writer.add_scalar(prefix + "/accuracy(nx)", - self.next_sentence_accuracy(), step) + if self.n_words != 0: + writer.add_scalar(prefix + "/accuracy(token)", + self.accuracy(), step) + if self.n_sentence != 0: + writer.add_scalar(prefix + "/accuracy(sent)", + self.sentence_accuracy(), step) writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) writer.add_scalar(prefix + "/lr", learning_rate, step) From 12a909aedacb1721543ad2f265112b873dc1a8ed Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Thu, 25 Jul 2019 16:07:20 +0200 Subject: [PATCH 09/28] update --- bert_ckp_convert.py | 111 +++++++++++++------------- onmt/encoders/bert.py | 12 ++- onmt/encoders/transformer.py | 5 +- onmt/inputters/inputter.py | 11 +-- onmt/model_builder.py | 81 ++++++++++--------- onmt/models/bert_generators.py | 18 ++--- onmt/models/model_saver.py | 31 ++++---- onmt/modules/__init__.py | 3 +- onmt/modules/bert_embeddings.py | 13 ++- onmt/trainer.py | 135 ++++++++++++++------------------ onmt/utils/optimizers.py | 10 +-- 
onmt/utils/report_manager.py | 19 ++++- onmt/utils/statistics.py | 52 ++++++------ train.py | 16 +++- 14 files changed, 268 insertions(+), 249 deletions(-) mode change 100644 => 100755 bert_ckp_convert.py diff --git a/bert_ckp_convert.py b/bert_ckp_convert.py old mode 100644 new mode 100755 index d033f6cd83..a5eb55fd22 --- a/bert_ckp_convert.py +++ b/bert_ckp_convert.py @@ -1,15 +1,14 @@ #!/usr/bin/env python -""" Convert ckp of huggingface to onmt version""" +""" Convert weights of huggingface Bert to onmt Bert""" from argparse import ArgumentParser -from pathlib import Path -# import pytorch_pretrained_bert -# from pytorch_pretrained_bert.modeling import BertForPreTraining import torch -import onmt +from onmt.encoders.bert import BertEncoder +from onmt.models.bert_generators import BertPreTrainingHeads +from onmt.modules.bert_embeddings import BertEmbeddings from collections import OrderedDict import re -# -1 + def decrement(matched): value = int(matched.group(1)) if value < 1: @@ -17,72 +16,73 @@ def decrement(matched): string = "bert.encoder.layer.{}.output.LayerNorm".format(value-1) return string -def convert_key(key, max_layers): + +def mapping_key(key, max_layers): if 'bert.embeddings' in key: key = key - elif 'bert.transformer_encoder' in key: + elif 'bert.encoder' in key: # convert layer_norm weights - key = re.sub(r'bert.transformer_encoder.0.layer_norm\.(.*)', - r'bert.embeddings.LayerNorm.\1', key) - key = re.sub(r'bert.transformer_encoder\.(\d+)\.layer_norm', - decrement, key) # TODO + key = re.sub(r'bert.encoder.0.layer_norm\.(.*)', + r'bert.embeddings.LayerNorm.\1', key) + key = re.sub(r'bert.encoder\.(\d+)\.layer_norm', + decrement, key) # convert attention weights - key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.linear_keys\.(.*)', - r'bert.encoder.layer.\1.attention.self.key.\2', key) - key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.linear_values\.(.*)', - r'bert.encoder.layer.\1.attention.self.value.\2', key) - key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.linear_query\.(.*)', - r'bert.encoder.layer.\1.attention.self.query.\2', key) - key = re.sub(r'bert.transformer_encoder\.(\d+)\.self_attn.final_linear\.(.*)', - r'bert.encoder.layer.\1.attention.output.dense.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_keys\.(.*)', + r'bert.encoder.layer.\1.attention.self.key.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_values\.(.*)', + r'bert.encoder.layer.\1.attention.self.value.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_query\.(.*)', + r'bert.encoder.layer.\1.attention.self.query.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.final_linear\.(.*)', + r'bert.encoder.layer.\1.attention.output.dense.\2', key) # convert feed forward weights - key = re.sub(r'bert.transformer_encoder\.(\d+)\.feed_forward.layer_norm\.(.*)', - r'bert.encoder.layer.\1.attention.output.LayerNorm.\2', key) - key = re.sub(r'bert.transformer_encoder\.(\d+)\.feed_forward.w_1\.(.*)', - r'bert.encoder.layer.\1.intermediate.dense.\2', key) - key = re.sub(r'bert.transformer_encoder\.(\d+)\.feed_forward.w_2\.(.*)', - r'bert.encoder.layer.\1.output.dense.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.layer_norm\.(.*)', + r'bert.encoder.layer.\1.attention.output.LayerNorm.\2', + key) + key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.w_1\.(.*)', + r'bert.encoder.layer.\1.intermediate.dense.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.w_2\.(.*)', + 
r'bert.encoder.layer.\1.output.dense.\2', key) elif 'bert.layer_norm' in key: key = re.sub(r'bert.layer_norm', - r'bert.encoder.layer.'+str(max_layers-1)+'.output.LayerNorm', key) + r'bert.encoder.layer.' + str(max_layers - 1) + + '.output.LayerNorm', key) elif 'bert.pooler' in key: key = key elif 'generator.next_sentence' in key: key = re.sub(r'generator.next_sentence.linear\.(.*)', - r'cls.seq_relationship.\1', key) + r'cls.seq_relationship.\1', key) elif 'generator.mask_lm' in key: key = re.sub(r'generator.mask_lm.bias', - r'cls.predictions.bias', key) + r'cls.predictions.bias', key) key = re.sub(r'generator.mask_lm.decode.weight', - r'cls.predictions.decoder.weight', key) + r'cls.predictions.decoder.weight', key) key = re.sub(r'generator.mask_lm.transform.dense\.(.*)', - r'cls.predictions.transform.dense.\1', key) + r'cls.predictions.transform.dense.\1', key) key = re.sub(r'generator.mask_lm.transform.layer_norm\.(.*)', - r'cls.predictions.transform.LayerNorm.\1', key) + r'cls.predictions.transform.LayerNorm.\1', key) else: - raise ValueError("Unexpected keys!") + raise ValueError("Unexpected keys!") return key -def load_bert_weights(bert_model, weights_dict, n_layers=12): +def convert_bert_weights(bert_model, weights, n_layers=12): bert_model_keys = bert_model.state_dict().keys() - weights_keys = weights_dict.keys() bert_weights = OrderedDict() generator_weights = OrderedDict() model_weights = {"bert": bert_weights, "generator": generator_weights} try: for key in bert_model_keys: - key_huggingface = convert_key(key, n_layers) - # model_weights[key] = converted_key + hugface_key = mapping_key(key, n_layers) if 'generator' not in key: - truncted_key = re.sub(r'bert\.(.*)', r'\1', key) - model_weights['bert'][truncted_key] = weights_dict[key_huggingface] + onmt_key = re.sub(r'bert\.(.*)', r'\1', key) + model_weights['bert'][onmt_key] = weights[hugface_key] else: - truncted_key = re.sub(r'generator\.(.*)', r'\1', key) - model_weights['generator'][truncted_key] = weights_dict[key_huggingface] + onmt_key = re.sub(r'generator\.(.*)', r'\1', key) + model_weights['generator'][onmt_key] = weights[hugface_key] except ValueError: print("Unsuccessful convert!") exit() @@ -92,14 +92,15 @@ def load_bert_weights(bert_model, weights_dict, n_layers=12): def main(): parser = ArgumentParser() parser.add_argument("--layers", type=int, default=None, required=True) - # parser.add_argument("--bert_model", type=str, default="bert-base-multilingual-uncased")#, # required=True, - # choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", - # "bert-base-multilingual-uncased", "bert-base-chinese"]) - # parser.add_argument("--bert_model_weights_path", type=str, default="PreTrainedBertckp/") - # parser.add_argument("--output_dir", type=Path, default="PreTrainedBertckp/") - # parser.add_argument("--output_name", type=str, default="onmt-bert-base-multilingual-uncased.weights") - parser.add_argument("--bert_model_weights_file", "-i", type=str, default="PreTrainedBertckp/bert-base-multilingual-uncased.pt") - parser.add_argument("--output_name", "-o", type=str, default="PreTrainedBertckp/onmt-bert-base-multilingual-uncased.weights") + + parser.add_argument("--bert_model_weights_file", "-i", type=str, + default=None, required=True, help="Path to the " + "huggingface Bert weights file download from " + "https://github.com/huggingface/pytorch-transformers") + + parser.add_argument("--output_name", "-o", type=str, + default=None, required=True, + help="output onmt version Bert weight file Path") args = 
parser.parse_args() n_layers = args.layers @@ -109,15 +110,17 @@ def main(): print("Load weights from {}.".format(bert_model_weights)) bert_weights = torch.load(bert_model_weights) - embeddings = onmt.modules.bert_embeddings.BertEmbeddings(105879) - bert_encoder = onmt.encoders.BertEncoder(embeddings) - generator = onmt.models.BertPreTrainingHeads(bert_encoder.d_model, bert_encoder.vocab_size) + embeddings = BertEmbeddings(105879) + bert_encoder = BertEncoder(embeddings) + generator = BertPreTrainingHeads(bert_encoder.d_model, + embeddings.vocab_size) bertlm = torch.nn.Sequential(OrderedDict([ ('bert', bert_encoder), ('generator', generator)])) - model_weights = load_bert_weights(bertlm, bert_weights, n_layers) + model_weights = convert_bert_weights(bertlm, bert_weights, n_layers) - ckp={'model': model_weights['bert'], 'generator': model_weights['generator']} + ckp = {'model': model_weights['bert'], + 'generator': model_weights['generator']} outfile = args.output_name print("Converted weights file in {}".format(outfile)) diff --git a/onmt/encoders/bert.py b/onmt/encoders/bert.py index 0f7fde4ccd..309f222bd6 100644 --- a/onmt/encoders/bert.py +++ b/onmt/encoders/bert.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -from onmt.modules.bert_embeddings import BertEmbeddings from onmt.encoders.transformer import TransformerEncoderLayer @@ -22,10 +21,10 @@ def __init__(self, embeddings, num_layers=12, d_model=768, self.embeddings = embeddings # Transformer Encoder Block - self.transformer_encoder = nn.ModuleList( + self.encoder = nn.ModuleList( [TransformerEncoderLayer(d_model, heads, d_ff, dropout, - max_relative_positions=max_relative_positions, - activation='gelu', is_bert=True) for _ in range(num_layers)]) + max_relative_positions=max_relative_positions, + activation='gelu', is_bert=True) for _ in range(num_layers)]) self.layer_norm = BertLayerNorm(d_model, eps=1e-12) self.pooler = BertPooler(d_model) @@ -68,9 +67,8 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, # embedding vectors: [batch, seq, hidden_size] out = self.embeddings(input_ids, token_type_ids) - all_encoder_layers = [] - for layer in self.transformer_encoder: + for layer in self.encoder: out = layer(out, attention_mask) if output_all_encoded_layers: all_encoder_layers.append(self.layer_norm(out)) @@ -83,7 +81,7 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, def update_dropout(self, dropout): self.dropout = dropout self.embeddings.update_dropout(dropout) - for layer in self.transformer_encoder: + for layer in self.encoder: layer.update_dropout(dropout) diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 183ece3690..181989f9f2 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -35,8 +35,9 @@ def __init__(self, d_model, heads, d_ff, dropout, self.self_attn = MultiHeadedAttention( heads, d_model, dropout=dropout, max_relative_positions=max_relative_positions) - self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, activation, is_bert) - self.layer_norm = (onmt.encoders.BertLayerNorm(d_model,eps=1e-12) + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, + activation, is_bert) + self.layer_norm = (onmt.encoders.BertLayerNorm(d_model, eps=1e-12) if is_bert else nn.LayerNorm(d_model, eps=1e-6)) self.dropout = nn.Dropout(dropout) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index 1173b50e88..f8a8b60dbd 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ 
-173,7 +173,7 @@ def get_bert_fields(task='pretraining', pad='[PAD]', bos='[CLS]', fields["token_labels_ids"] = token_labels_ids else: - raise ValueError("task '{}' has not been implemented yet!".format(task)) + raise ValueError("task %s has not been implemented yet!" % task) return fields @@ -398,8 +398,9 @@ def _build_bert_fields_vocab(fields, counters, vocab_size, # _build_field_vocab(tokens_field, tokens_counter, # size_multiple=vocab_size_multiple, # max_size=vocab_size, min_freq=tokens_min_frequency) - tokens_field.vocab = tokens_field.vocab_cls(tokens_counter, specials=[], - max_size=vocab_size, min_freq=tokens_min_frequency) + tokens_field.vocab = tokens_field.vocab_cls( + tokens_counter, specials=[], max_size=vocab_size, + min_freq=tokens_min_frequency) if vocab_size_multiple > 1: _pad_vocab_to_multiple(tokens_field.vocab, vocab_size_multiple) @@ -806,8 +807,8 @@ def max_tok_len(new, count, sofar): such that the total number of src/tgt tokens (including padding) in a batch <= batch_size """ - if hasattr(new, 'is_next'): - # when a example has the attr 'is_next', + if hasattr(new, 'tokens'): + # when a example has the attr 'tokens', # this means we are loading Bert Data # Maintains the longest token length in the current batch global max_tokens_in_batch diff --git a/onmt/model_builder.py b/onmt/model_builder.py index dddf3cd08c..abcbc9d6a6 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -240,18 +240,19 @@ def build_bert_embeddings(opt, fields): def build_bert_encoder(model_opt, fields, embeddings): - bert = BertEncoder(embeddings, num_layers=model_opt.layers, - d_model=model_opt.word_vec_size, heads=model_opt.heads, - d_ff=model_opt.transformer_ff, dropout=model_opt.dropout[0], - max_relative_positions=model_opt.max_relative_positions) + bert = BertEncoder( + embeddings, num_layers=model_opt.layers, + d_model=model_opt.word_vec_size, heads=model_opt.heads, + d_ff=model_opt.transformer_ff, dropout=model_opt.dropout[0], + max_relative_positions=model_opt.max_relative_positions) return bert def build_bert_generator(model_opt, fields, bert_encoder): """Main part for transfer learning: - set opt.task_type to `pretraining` if want finetuning; - set opt.task_type to `classification` if want use Bert to classification task; - set opt.task_type to `generation` if want use Bert to generate tokens. + set opt.task_type to `pretraining` if want finetuning Bert; + set opt.task_type to `classification` if want sentence level task; + set opt.task_type to `generation` if want token level task. 
Both all_encoder_layers and pooled_output will be feed to generator, pretraining task will use the two, while only pooled_output will be used for classification generator; @@ -259,15 +260,17 @@ def build_bert_generator(model_opt, fields, bert_encoder): """ task = model_opt.task_type if task == 'pretraining': - generator = BertPreTrainingHeads(bert_encoder.d_model, - bert_encoder.embeddings.vocab_size) + generator = BertPreTrainingHeads( + bert_encoder.d_model, bert_encoder.embeddings.vocab_size) if model_opt.reuse_embeddings: - generator.mask_lm.decode.weight = bert_encoder.embeddings.word_embeddings.weight + generator.mask_lm.decode.weight = \ + bert_encoder.embeddings.word_embeddings.weight elif task == 'generation': - generator = TokenGenerationHead(bert_encoder.d_model, - bert_encoder.vocab_size) + generator = TokenGenerationHead( + bert_encoder.d_model, bert_encoder.vocab_size) if model_opt.reuse_embeddings: - generator.decode.weight = bert_encoder.embeddings.word_embeddings.weight + generator.decode.weight = \ + bert_encoder.embeddings.word_embeddings.weight elif task == 'classification': n_class = model_opt.classification generator = ClassificationHead(bert_encoder.d_model, n_class) @@ -313,37 +316,41 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): model = nn.Sequential(OrderedDict([ ('bert', bert_encoder), ('generator', generator)])) + # Load the model states from checkpoint or initialize them. + model_init = {'bert': False, 'generator': False} if checkpoint is not None: assert 'model' in checkpoint + if model.bert.state_dict().keys() != checkpoint['model'].keys(): + raise ValueError("Provide checkpoint don't match actual model!") logger.info("Load Model Parameters...") - model.bert.load_state_dict(checkpoint['model'], strict=False) - if model_opt.task_type == 'pretraining': + model.bert.load_state_dict(checkpoint['model'], strict=True) + model_init['bert'] = True + if model.generator.state_dict().keys() == checkpoint['generator'].keys(): logger.info("Load generator Parameters...") - model.generator.load_state_dict(checkpoint['generator'], strict=False) - else: - logger.info("Initialize generator Parameters...") - for p in model.generator.parameters(): + model.generator.load_state_dict(checkpoint['generator'], strict=True) + model_init['generator'] = True + + for sub_module, is_init in model_init.items(): + if not is_init: + logger.info("Initialize {} Parameters...".format(sub_module)) + if model_opt.param_init_normal != 0.0: + logger.info('Initialize weights using a normal distribution') + normal_std = model_opt.param_init_normal + for p in model.sub_module.parameters(): + p.data.normal_(mean=0, std=normal_std) + elif model_opt.param_init != 0.0: + logger.info('Initialize weights using a uniform distribution') + for p in model.sub_module.parameters(): + p.data.uniform_(-model_opt.param_init, + model_opt.param_init) + elif model_opt.param_init_glorot: + logger.info('Glorot initialization') + for p in model.sub_module.parameters(): if p.dim() > 1: xavier_uniform_(p) - else: - logger.info("No checkpoint, Initialize Parameters...") - if model_opt.param_init_normal != 0.0: - logger.info('Initialize weights using a normal distribution') - normal_std = model_opt.param_init_normal - for p in model.parameters(): - p.data.normal_(mean=0, std=normal_std) - elif model_opt.param_init != 0.0: - logger.info('Initialize weights using a uniform distribution') - for p in model.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - elif 
model_opt.param_init_glorot: - logger.info('Glorot initialization') - for p in model.parameters(): - if p.dim() > 1: - xavier_uniform_(p) - else: - raise AttributeError("Initialization method haven't be used!") + else: + raise AttributeError("Initialization method haven't be used!") model.to(device) logger.info(model) diff --git a/onmt/models/bert_generators.py b/onmt/models/bert_generators.py index 8ddf182657..b467db12d6 100644 --- a/onmt/models/bert_generators.py +++ b/onmt/models/bert_generators.py @@ -43,16 +43,16 @@ def __init__(self, hidden_size, vocab_size): """ super(MaskedLanguageModel, self).__init__() self.transform = BertPredictionTransform(hidden_size) - + self.decode = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size)) - + self.softmax = nn.LogSoftmax(dim=-1) def forward(self, x): """ Args: - x: last layer output of bert, shape (batch, seq, d_model) + x: first output of bert encoder, (batch, seq, d_model) Returns: prediction_log_prob: shape (batch, seq, vocab) """ @@ -79,12 +79,12 @@ def __init__(self, hidden_size): def forward(self, x): """ Args: - x: last layer's output of bert encoder, shape (batch, src, d_model) + x: second output of bert encoder, (batch, d_model) Returns: seq_class_prob: shape (batch_size, 2) """ seq_relationship_score = self.linear(x) # (batch, 2) - seq_class_log_prob = self.softmax(seq_relationship_score) # (batch, 2) + seq_class_log_prob = self.softmax(seq_relationship_score) return seq_class_log_prob @@ -92,7 +92,7 @@ class BertPredictionTransform(nn.Module): def __init__(self, hidden_size): super(BertPredictionTransform, self).__init__() self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = get_activation_fn('gelu') #GELU() # get_activation fn + self.activation = get_activation_fn('gelu') self.layer_norm = onmt.encoders.BertLayerNorm(hidden_size, eps=1e-12) def forward(self, hidden_states): @@ -122,8 +122,8 @@ def __init__(self, hidden_size, n_class): def forward(self, all_hidden, pooled): """ Args: - all_hidden: first output argument of Bert encoder (batch, src, d_model) - pooled: last layer's output of bert encoder, shape (batch, src, d_model) + all_hidden: first output of Bert encoder (batch, src, d_model) + pooled: second output of bert encoder, shape (batch, src, d_model) Returns: class_log_prob: shape (batch_size, 2) None: this is a placeholder for token level prediction task @@ -147,7 +147,7 @@ def __init__(self, hidden_size, vocab_size): """ super(TokenGenerationHead, self).__init__() self.transform = BertPredictionTransform(hidden_size) - + self.decode = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size)) diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py index befc5a4b5e..4028e8f285 100644 --- a/onmt/models/model_saver.py +++ b/onmt/models/model_saver.py @@ -10,11 +10,11 @@ def build_model_saver(model_opt, opt, model, fields, optim): model_saver = ModelSaver(opt.save_model, - model, - model_opt, - fields, - optim, - opt.keep_checkpoint) + model, + model_opt, + fields, + optim, + opt.keep_checkpoint) return model_saver @@ -101,12 +101,14 @@ def _save(self, step, model): if isinstance(model, nn.DataParallel) else model) real_generator = (real_model.generator.module - if isinstance(real_model.generator, nn.DataParallel) - else real_model.generator) - - model_state_dict = real_model.state_dict() - model_state_dict = {k: v for k, v in model_state_dict.items() - if 'generator' not in k} + if isinstance(real_model.generator, 
nn.DataParallel) + else real_model.generator) + if hasattr(real_model, 'bert'): + model_state_dict = real_model.bert.state_dict() + else: + model_state_dict = real_model.state_dict() + model_state_dict = {k: v for k, v in model_state_dict.items() + if 'generator' not in k} generator_state_dict = real_generator.state_dict() # NOTE: We need to trim the vocab to remove any unk tokens that @@ -115,7 +117,8 @@ def _save(self, step, model): vocab = deepcopy(self.fields) for name, field in vocab.items(): if isinstance(field, Field): - if hasattr(field, "vocab"): + if hasattr(field, "vocab") and \ + hasattr(field, "unk_token"): assert name == 'tokens' keys_to_pop = [] unk_token = field.unk_token @@ -123,8 +126,8 @@ def _save(self, step, model): for key, value in field.vocab.stoi.items(): if value == unk_id and key != unk_token: keys_to_pop.append(key) - for key in keys_to_pop: - field.vocab.stoi.pop(key, None) + for key in keys_to_pop: + field.vocab.stoi.pop(key, None) else: if hasattr(field, "fields"): assert name in ["src", "tgt"] diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py index 0bed435ff7..06ef91ac5d 100644 --- a/onmt/modules/__init__.py +++ b/onmt/modules/__init__.py @@ -5,6 +5,7 @@ from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention from onmt.modules.multi_headed_attn import MultiHeadedAttention from onmt.modules.embeddings import Embeddings, PositionalEncoding +from onmt.modules.bert_embeddings import BertEmbeddings from onmt.modules.weight_norm import WeightNormConv2d from onmt.modules.average_attn import AverageAttention from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ @@ -14,4 +15,4 @@ "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", "CopyGeneratorLoss", "CopyGeneratorLossCompute", "MultiHeadedAttention", "Embeddings", "PositionalEncoding", - "WeightNormConv2d", "AverageAttention"] + "WeightNormConv2d", "AverageAttention", "BertEmbeddings"] diff --git a/onmt/modules/bert_embeddings.py b/onmt/modules/bert_embeddings.py index cc3bda783e..8cd0bf54ae 100644 --- a/onmt/modules/bert_embeddings.py +++ b/onmt/modules/bert_embeddings.py @@ -24,13 +24,13 @@ def __init__(self, vocab_size, embed_size=768, pad_idx=0, self.embed_size = embed_size self.word_padding_idx = pad_idx # Token embeddings: for input tokens - self.word_embeddings = nn.Embedding(vocab_size, embed_size, - padding_idx=pad_idx) + self.word_embeddings = nn.Embedding( + vocab_size, embed_size, padding_idx=pad_idx) # Position embeddings: for Position Encoding self.position_embeddings = nn.Embedding(max_position, embed_size) # Segmentation embeddings: for distinguish sentences A/B - self.token_type_embeddings = nn.Embedding(num_sentence, embed_size, - padding_idx=pad_idx) + self.token_type_embeddings = nn.Embedding( + num_sentence, embed_size, padding_idx=pad_idx) self.dropout = nn.Dropout(dropout) @@ -43,8 +43,8 @@ def forward(self, input_ids, token_type_ids=None): embeddings: word embeds in shape [batch, seq, hidden_size]. 
""" seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, - device=input_ids.device) # [0, 1,..., seq_length-1] + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) # [[0,1,...,seq_length-1]] -> [[0,1,...,seq_length-1] *batch_size] position_ids = position_ids.unsqueeze(0).expand_as(input_ids) @@ -54,7 +54,6 @@ def forward(self, input_ids, token_type_ids=None): word_embeds = self.word_embeddings(input_ids) position_embeds = self.position_embeddings(position_ids) token_type_embeds = self.token_type_embeddings(token_type_ids) - embeddings = word_embeds + position_embeds + token_type_embeds # NOTE: in our version, LayerNorm is done in EncoderLayer # before fed into Attention comparing to original implementation diff --git a/onmt/trainer.py b/onmt/trainer.py index 9ecef71f38..888cae3415 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -34,7 +34,7 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): tgt_field = dict(fields)["tgt"].base_field if not opt.is_bert else None train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) valid_loss = onmt.utils.loss.build_loss_compute( - model, tgt_field, opt, train=False) + model, tgt_field, opt, train=False) trunc_size = opt.truncated_decoder # Badly named... shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0 @@ -131,17 +131,17 @@ def __init__(self, model, train_loss, valid_loss, optim, self.dropout = dropout self.dropout_steps = dropout_steps self.is_bert = True if hasattr(self.model, 'bert') else False - + for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 if self.accum_count_l[i] > 1: assert self.trunc_size == 0, \ """To enable accumulated gradients, you must disable target sequence truncating.""" - + if self.is_bert: assert self.trunc_size == 0 - """ Bert currently not support target sequence truncating""" # TODO + """ Bert currently not support target sequence truncating""" # Set model in training mode. 
self.model.train() @@ -168,7 +168,7 @@ def _accum_batches(self, iterator): if self.is_bert is False: # Bert don't need normalization if self.norm_method == "tokens": num_tokens = batch.tgt[1:, :, 0].ne( - self.train_loss.padding_idx).sum() + self.train_loss.padding_idx).sum() normalization += num_tokens.item() else: normalization += batch.batch_size @@ -241,21 +241,23 @@ def train(self, n_minibatch %d" % (self.gpu_rank, i + 1, len(batches))) - if self.n_gpu > 1: # NOTE: DEBUG - list_norm = onmt.utils.distributed.all_gather_list(normalization) + if self.n_gpu > 1: + l_norm = onmt.utils.distributed.all_gather_list(normalization) + # NOTE: DEBUG # current_rank = torch.distributed.get_rank() # print("-> RANK: {}".format(current_rank)) # print(list_norm) - normalization = sum(list_norm) + normalization = sum(l_norm) # Training Step: Forward -> compute Loss -> optimize if self.is_bert: - self._bert_gradient_accumulation(batches, total_stats, report_stats) + self._bert_gradient_accumulation( + batches, total_stats, report_stats) else: self._gradient_accumulation( - batches, normalization, total_stats, - report_stats) - + batches, normalization, total_stats, + report_stats) + # Moving average if self.average_decay > 0 and i % self.average_every == 0: self._update_average(step) @@ -264,9 +266,7 @@ def train(self, step, train_steps, self.optim.learning_rate(), report_stats) - # NOTE: DEBUG - # exit() - + # Part: validation if valid_iter is not None and step % valid_steps == 0: if self.gpu_verbose_level > 0: @@ -323,34 +323,35 @@ def validate(self, valid_iter, moving_average=None): if self.is_bert: stats = onmt.utils.BertStatistics() for batch in valid_iter: - # input_ids: Size([batch_size, max_seq_length_in_batch]), seq_lengths: Size([batch_size]) - input_ids, seq_lengths = batch.tokens if isinstance(batch.tokens, tuple) \ - else (batch.tokens, None) - # segment_ids, lm_labels_ids: Size([batch_size, max_seq_length_in_batch]), is_next: Size([batch_size]) - token_type_ids = batch.segment_ids # 0 for sens A, 1 for sens B. 0 padding - is_next = batch.is_next - lm_labels_ids = batch.lm_labels_ids # -1 padding, others for predict in lm task - # F-prop through the model. # NOTE: keyword args: input_mask, output_all_encoded_layers - # Version 2: - all_encoder_layers, pooled_out = valid_model.bert(input_ids, token_type_ids) - seq_class_log_prob, prediction_log_prob = valid_model.generator(all_encoder_layers, pooled_out) - # Version 1: - # seq_class_log_prob, prediction_log_prob = valid_model(input_ids, token_type_ids) - # TODO: Heads + # input_ids: Size([batch_size, max_seq_length_in_batch]), + # seq_lengths: Size([batch_size]) + if isinstance(batch.tokens, tuple): + input_ids, _ = batch.tokens + else: + input_ids, _ = (batch.tokens, None) + # segment_ids: Size([batch_size, max_seq_length_in_batch]) + # 0 for sens A, 1 for sens B. 0 padding + token_type_ids = batch.segment_ids + # F-prop through the model. + all_encoder_layers, pooled_out = \ + valid_model.bert(input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = \ + valid_model.generator(all_encoder_layers, pooled_out) + outputs = (seq_class_log_prob, prediction_log_prob) # Compute loss. _, batch_stats = self.valid_loss(batch, outputs) # Update statistics. 
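For orientation, the attributes these Bert code paths read off a batch come from `get_bert_fields`; a sketch of what a batch is assumed to carry, based on the field definitions earlier in this series:

# Illustrative sketch of the Bert batch contract (shapes are assumptions).
input_ids, lengths = batch.tokens       # (batch, seq), (batch,) -- include_lengths=True
token_type_ids = batch.segment_ids      # (batch, seq), 0 for sentence A, 1 for sentence B
# task-dependent targets:
#   pretraining:     batch.is_next (batch,), batch.lm_labels_ids (batch, seq) with -1 on unmasked tokens
#   classification:  batch.category (batch,)
#   generation:      batch.token_labels_ids (batch, seq) with -1 on padding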
- stats.update(batch_stats) + stats.update(batch_stats) else: stats = onmt.utils.Statistics() for batch in valid_iter: - src, src_lengths = batch.src if isinstance(batch.src, tuple) \ - else (batch.src, None) + src, src_lengths = batch.src if isinstance( + batch.src, tuple) else (batch.src, None) tgt = batch.tgt - # F-prop through the model. + # F-prop through the model. outputs, attns = valid_model(src, tgt, src_lengths) # Compute loss. @@ -358,7 +359,7 @@ def validate(self, valid_iter, moving_average=None): # Update statistics. stats.update(batch_stats) - + if moving_average: del valid_model else: @@ -496,53 +497,41 @@ def _report_step(self, learning_rate, step, train_stats=None, learning_rate, step, train_stats=train_stats, valid_stats=valid_stats) - - def _bert_gradient_accumulation(self, true_batches, total_stats, report_stats): - """As the loss will be reduced by mean, normalization is not needed anymore. + def _bert_gradient_accumulation(self, true_batches, + total_stats, report_stats): + """As the loss will be reduced by mean, normalization is not needed. But we still need to average between GPUs. """ if self.accum_count > 1: self.optim.zero_grad() for k, batch in enumerate(true_batches): - # target_size = batch.tgt.size(0) - # NOTE: for batch in BERT : batch_first is True -> [batch, seq, vocab] - # # Truncated BPTT: reminder not compatible with accum > 1 - # if self.trunc_size: # TODO - # trunc_size = self.trunc_size - # else: - # trunc_size = target_size - - input_ids, seq_lengths = batch.tokens if isinstance(batch.tokens, tuple) \ - else (batch.tokens, None) + # target_size = batch.tgt.size(0) + # NOTE: for batch in BERT : + # batch_first is True -> [batch, seq, vocab] + if isinstance(batch.tokens, tuple): + input_ids, seq_lengths = batch.tokens + else: + input_ids, seq_lengths = (batch.tokens, None) + if seq_lengths is not None: report_stats.n_src_words += seq_lengths.sum().item() - # tgt_outer = batch.tgt token_type_ids = batch.segment_ids - is_next = batch.is_next - lm_labels_ids = batch.lm_labels_ids - # bptt = False - # TODO: to be removed, as not support bptt yet! - # for j in range(0, target_size-1, trunc_size): - # 1. Create truncated target. - # tgt = tgt_outer[j: j + trunc_size] - # 2. F-prop all to get log likelihood of two task. + # 1. F-prop all to get log likelihood of two task. if self.accum_count == 1: self.optim.zero_grad() - # Version 2: - all_encoder_layers, pooled_out = self.model.bert(input_ids, token_type_ids) - seq_class_log_prob, prediction_log_prob = self.model.generator(all_encoder_layers, pooled_out) - # Version 1: - # seq_class_log_prob, prediction_log_prob = self.model(input_ids, token_type_ids) + + all_encoder_layers, pooled_out = self.model.bert( + input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = self.model.generator( + all_encoder_layers, pooled_out) # NOTE: (batch_size, 2), (batch_size, seq_size, vocab_size) outputs = (seq_class_log_prob, prediction_log_prob) - # outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt) - # bptt = True - # 3. Compute loss. - try: # NOTE: unuse normalisation + # 2. Compute loss. + try: loss, batch_stats = self.train_loss(batch, outputs) # NOTE: DEBUG # loss_list = onmt.utils.distributed.all_gather_list(loss) @@ -565,33 +554,27 @@ def _bert_gradient_accumulation(self, true_batches, total_stats, report_stats): logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k) - # 4. Update the parameters and statistics. + # 3. Update the parameters and statistics. 
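Reduced to its control flow, the accumulation scheme this method implements looks roughly like the following. This is a sketch under the assumptions above: the loss is already mean-reduced so no token-count normalization is applied, and `forward_and_loss` is a hypothetical stand-in for the F-prop and loss code shown in this hunk.

# Gradient accumulation sketch (illustrative; error handling and reporting omitted).
if accum_count > 1:
    optim.zero_grad()                      # one zero_grad for the whole group of batches
for batch in true_batches:
    if accum_count == 1:
        optim.zero_grad()                  # per-batch zero_grad when not accumulating
    loss, stats = forward_and_loss(batch)  # hypothetical helper: bert + generator + train_loss
    optim.backward(loss)
    if accum_count == 1:
        optim.step()                       # multi-GPU: gradients are all-reduced before stepping
if accum_count > 1:
    optim.step()                           # single update after accum_count batches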
if self.accum_count == 1: # Multi GPU gradient gather if self.n_gpu > 1: grads = [p.grad.data for p in self.model.parameters() - if p.requires_grad - and p.grad is not None] + if p.requires_grad + and p.grad is not None] # current_rank = torch.distributed.get_rank() # print("{}-> RANK: {}, grads BEFORE:{}".format( # k, current_rank, grads[0])) # NOTE: average the gradient across the GPU onmt.utils.distributed.all_reduce_and_rescale_tensors( grads, float(self.n_gpu)) - # reduced_grads = [p.grad.data for p in self.model.parameters() + # reduced_grads = [p.grad.data for p in + # self.model.parameters() # if p.requires_grad # and p.grad is not None] # print("{}-> RANK: {}, grads AFTER:{}".format( # k, current_rank, reduced_grads[0])) self.optim.step() - # If truncated, don't backprop fully. - # TO CHECK - # if dec_state is not None: - # dec_state.detach() - # if self.model.decoder.state is not None: # TODO: ?? - # self.model.decoder.detach_state() - # in case of multi step gradient accumulation, # update only after accum batches if self.accum_count > 1: diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index 07787b9037..c1e2c4773d 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -639,7 +639,7 @@ class BertAdam(torch.optim.Optimizer): weight_decay: Weight decay. Default: 0.01 # TODO: exclude LayerNorm from weight decay? max_grad_norm: Maximum norm for the gradients (-1 means no clipping). - """ + """ # TODO: add parameter to opt def __init__(self, params, lr=None, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): if not 0.0 <= lr: @@ -675,7 +675,7 @@ def step(self, closure=None): continue grad = p.grad.data if grad.is_sparse: - raise RuntimeError('Adam : not support sparse gradients,' + + raise RuntimeError('Adam: not support sparse gradients,' + 'please consider SparseAdam instead') state = self.state[p] @@ -695,7 +695,7 @@ def step(self, closure=None): # if group['max_grad_norm'] > 0: # clip_grad_norm_(p, group['max_grad_norm']) - # Decay the first and second moment running average coefficient + # Decay first and second moment running average coefficient # In-place operations to update the averages at the same time # exp_avg = exp_avg * beta1 + (1-beta1)*grad exp_avg.mul_(beta1).add_(1 - beta1, grad) @@ -708,8 +708,8 @@ def step(self, closure=None): # is *not* the correct way of using L2/weight decay with Adam, # since it will interact with m/v parameters in strange ways. # - # Instead we want to decay the weights in a manner that doesn't - # interact with the m/v. This is equivalent to add the square + # Instead we want to decay the weights that does not interact + # with the m/v. This is equivalent to add the square # of the weights to the loss with plain (non-momentum) SGD. 
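[Illustration, not part of the patch] Written out, the decoupled update applied just below is, for each parameter p with running moments m (exp_avg) and v (exp_avg_sq):

    update = m / (sqrt(v) + eps) + weight_decay * p
    p      = p - lr * update

so the decay term is added to the Adam step itself and never flows through the m/v statistics, unlike plain L2 regularization folded into the gradient.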
if group['weight_decay'] > 0.0: update += group['weight_decay'] * p.data diff --git a/onmt/utils/report_manager.py b/onmt/utils/report_manager.py index 06a46e1e9b..4db7cc2c2c 100644 --- a/onmt/utils/report_manager.py +++ b/onmt/utils/report_manager.py @@ -134,7 +134,10 @@ def _report_training(self, step, num_steps, learning_rate, "progress", learning_rate, self.progress_step) - report_stats = onmt.utils.Statistics() + if isinstance(report_stats, onmt.utils.BertStatistics): + report_stats = onmt.utils.BertStatistics() + else: + report_stats = onmt.utils.Statistics() return report_stats @@ -144,7 +147,12 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None): """ if train_stats is not None: self.log('Train perplexity: %g' % train_stats.ppl()) - self.log('Train accuracy: %g' % train_stats.accuracy()) + if isinstance(train_stats, onmt.utils.BertStatistics) \ + and train_stats.accuracy() is None: + accuracy = train_stats.sentence_accuracy() + else: + accuracy = train_stats.accuracy() + self.log('Train accuracy: %g' % accuracy) self.maybe_log_tensorboard(train_stats, "train", @@ -153,7 +161,12 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None): if valid_stats is not None: self.log('Validation perplexity: %g' % valid_stats.ppl()) - self.log('Validation accuracy: %g' % valid_stats.accuracy()) + if isinstance(valid_stats, onmt.utils.BertStatistics) \ + and valid_stats.accuracy() is None: + accuracy = valid_stats.sentence_accuracy() + else: + accuracy = valid_stats.accuracy() + self.log('Validation accuracy: %g' % accuracy) self.maybe_log_tensorboard(valid_stats, "valid", diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 1afd976e22..14d7cb38e1 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -141,7 +141,7 @@ class BertStatistics(Statistics): def __init__(self, loss=0, n_words=0, n_correct=0, n_sentence=0, n_correct_sentence=0): super(BertStatistics, self).__init__(loss, n_words, n_correct) - self.n_update = 0 if n_words == 0 else 1 + self.n_update = 0 if n_words == 0 and n_sentence == 0 else 1 self.n_sentence = n_sentence self.n_correct_sentence = n_correct_sentence @@ -206,40 +206,40 @@ def output(self, step, num_steps, learning_rate, start): if self.n_words == 0: # sentence level task logger.info( ("Step %s; acc(sent):%6.2f; ppl: %5.2f; " + - "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + "xent: %4.2f; lr: %7.5f; %3.0f tok/%3.0f sent/s; %6.0f sec") % (step_fmt, - self.sentence_accuracy(), - self.ppl(), - self.xent(), - learning_rate, - self.n_src_words / (t + 1e-5), - self.n_words / (t + 1e-5), - time.time() - start)) + self.sentence_accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_sentence / (t + 1e-5), + time.time() - start)) elif self.n_sentence == 0: # token level task logger.info( ("Step %s; acc(token):%6.2f; ppl: %5.2f; " + - "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") % (step_fmt, - self.accuracy(), - self.ppl(), - self.xent(), - learning_rate, - self.n_src_words / (t + 1e-5), - self.n_words / (t + 1e-5), - time.time() - start)) + self.accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) else: # pretraining logger.info( ("Step %s; acc(mlm/nx):%6.2f/%6.2f; total ppl: %5.2f; " + - "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") 
% (step_fmt, - self.accuracy(), - self.sentence_accuracy(), - self.ppl(), - self.xent(), - learning_rate, - self.n_src_words / (t + 1e-5), - self.n_words / (t + 1e-5), - time.time() - start)) + self.accuracy(), + self.sentence_accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) sys.stdout.flush() def log_tensorboard(self, prefix, writer, learning_rate, step): diff --git a/train.py b/train.py index 2b56b4c052..4fba91678c 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,8 @@ def main(opt): checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) if 'vocab' in checkpoint: - logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) + logger.info('Loading vocab from checkpoint at %s.' + % opt.train_from) vocab = checkpoint['vocab'] else: vocab = torch.load(opt.data + '.vocab.pt') @@ -127,8 +128,17 @@ def next_batch(device_id): else: b.tokens = b.tokens.to(torch.device(device_id)) b.segment_ids = b.segment_ids.to(torch.device(device_id)) - b.is_next = b.is_next.to(torch.device(device_id)) - b.lm_labels_ids = b.lm_labels_ids.to(torch.device(device_id)) + if opt.task_type == 'pretraining': + b.is_next = b.is_next.to(torch.device(device_id)) + b.lm_labels_ids = b.lm_labels_ids.to(torch.device(device_id)) + elif opt.task_type == 'classification': + b.category = b.category.to(torch.device(device_id)) + elif opt.task_type == 'prediction': + b.token_labels_ids = b.token_labels_ids.to( + torch.device(device_id)) + else: + raise ValueError("task type Error") + else: if isinstance(b.src, tuple): b.src = tuple([_.to(torch.device(device_id)) From ea14b13e1925f023afb4f34b7db755f8224a9eee Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Fri, 26 Jul 2019 18:37:29 +0200 Subject: [PATCH 10/28] update --- onmt/inputters/inputter.py | 6 - onmt/model_builder.py | 6 +- onmt/utils/optimizers.py | 8 - onmt/utils/statistics.py | 4 +- pregenerate_bert_training_data.py | 6 +- preprocess_bert.py | 256 ++++++++++++++++++++++++++++++ 6 files changed, 264 insertions(+), 22 deletions(-) create mode 100644 preprocess_bert.py diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index f8a8b60dbd..72a3e82e3c 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -140,12 +140,6 @@ def get_fields( def get_bert_fields(task='pretraining', pad='[PAD]', bos='[CLS]', eos='[SEP]', unk='[UNK]'): fields = {} - # tokens_kwargs = {"n_feats": 0, - # "include_lengths": True, - # "pad": "[PAD]", "bos": "[CLS]", "eos": "[SEP]", - # "truncate": src_truncate, - # "base_name": "tokens"} - # fields["tokens"] = text_fields(**tokens_kwargs) tokens = Field(sequential=True, use_vocab=True, pad_token=pad, unk_token=unk, include_lengths=True, batch_first=True) fields["tokens"] = tokens diff --git a/onmt/model_builder.py b/onmt/model_builder.py index abcbc9d6a6..16ee7d44cc 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -337,16 +337,16 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): if model_opt.param_init_normal != 0.0: logger.info('Initialize weights using a normal distribution') normal_std = model_opt.param_init_normal - for p in model.sub_module.parameters(): + for p in getattr(model, sub_module).parameters(): p.data.normal_(mean=0, std=normal_std) elif model_opt.param_init != 0.0: logger.info('Initialize weights using a uniform distribution') - for p in model.sub_module.parameters(): + for p in getattr(model, sub_module).parameters(): 
p.data.uniform_(-model_opt.param_init, model_opt.param_init) elif model_opt.param_init_glorot: logger.info('Glorot initialization') - for p in model.sub_module.parameters(): + for p in getattr(model, sub_module).parameters(): if p.dim() > 1: xavier_uniform_(p) else: diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index c1e2c4773d..ee835d496d 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -192,8 +192,6 @@ def linear_decay(step, warmup_steps, total_steps): """ if not 0 <= warmup_steps < total_steps: raise ValueError("Invalid decay: check warmup_step & train_steps") - if step > total_steps: - raise ValueError("Invalid step: step surpass train_steps!") if step < warmup_steps: return step / warmup_steps * 1.0 else: @@ -217,8 +215,6 @@ def cosine_decay(step, warmup_steps, total_steps, cycles=0.5): if not 0 <= warmup_steps < total_steps: raise ValueError("Invalid decay: check warmup_step & train_steps") - if step > total_steps: - raise ValueError("Invalid step: step surpass train_steps!") if step < warmup_steps: return step / warmup_steps * 1.0 else: @@ -236,8 +232,6 @@ def cosine_hard_restart_decay(step, warmup_steps, total_steps, cycles=1.0): assert(cycles >= 1.0) if not 0 <= warmup_steps < total_steps: raise ValueError("Invalid decay: check warmup_step & train_steps") - if step > total_steps: - raise ValueError("Invalid step: step surpass train_steps!") if step < warmup_steps: return step / warmup_steps * 1.0 else: @@ -252,8 +246,6 @@ def cosine_warmup_restart_decay(step, warmup_steps, total_steps, cycles=1.0): """ if not 0 <= warmup_steps < total_steps: raise ValueError("Invalid decay: check warmup_step & train_steps") - if step > total_steps: - raise ValueError("Invalid step: step surpass train_steps!") if not cycles * warmup_steps / total_steps < 1.0: raise ValueError("Invalid decay: Error for decay! 
Check cycles!") warmup_ratio = warmup_steps * cycles / total_steps diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 14d7cb38e1..eb39df71f7 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -248,10 +248,10 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): writer.add_scalar(prefix + "/xent", self.xent(), step) writer.add_scalar(prefix + "/ppl", self.ppl(), step) if self.n_words != 0: - writer.add_scalar(prefix + "/accuracy(token)", + writer.add_scalar(prefix + "/accuracy_token", self.accuracy(), step) if self.n_sentence != 0: - writer.add_scalar(prefix + "/accuracy(sent)", + writer.add_scalar(prefix + "/accuracy_sent", self.sentence_accuracy(), step) writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) writer.add_scalar(prefix + "/lr", learning_rate, step) diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py index 9539323d57..8f2579f66b 100755 --- a/pregenerate_bert_training_data.py +++ b/pregenerate_bert_training_data.py @@ -264,7 +264,7 @@ def main(): # choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", # "bert-base-multilingual", "bert-base-chinese"]) # parser.add_argument("--vocab_pathname", type=Path, required=True) # vocab file correspand to bert_model - + parser.add_argument("--do_lower_case", default=True) # action="store_true") parser.add_argument("--reduce_memory", action="store_true", @@ -333,7 +333,7 @@ def main(): dataset.save(epoch_filename) num_doc_instances = len(docs_instances) print("output file {}, num_example {}, max_seq_len {}".format(epoch_filename,num_doc_instances,args.max_seq_len)) - + metrics_file = args.output_dir / f"{args.output_name}.metrics.{args.corpus_type}.{epoch}.json" with metrics_file.open('w') as metrics_file: metrics = { @@ -350,7 +350,7 @@ def main(): fields = _build_bert_fields_vocab(fields, counters, vocab_size, args.tokens_min_frequency, args.vocab_size_multiple) # bert_vocab_file = args.output_dir / f"{args.output_name}.vocab.pt" torch.save(fields, bert_vocab_file) - + if __name__ == '__main__': main() diff --git a/preprocess_bert.py b/preprocess_bert.py new file mode 100644 index 0000000000..ee23bebfe3 --- /dev/null +++ b/preprocess_bert.py @@ -0,0 +1,256 @@ +from argparse import ArgumentParser +from tqdm import tqdm +import csv +from random import random +from onmt.utils.bert_tokenization import BertTokenizer, \ + PRETRAINED_VOCAB_ARCHIVE_MAP +import json +from onmt.inputters.inputter import get_bert_fields, \ + _build_bert_fields_vocab +from onmt.inputters.dataset_bert import BertDataset +from collections import Counter, defaultdict +import torch +import os + + +def truncate_seq(tokens, max_num_tokens): + """Truncates a sequences to a maximum sequence length.""" + while True: + total_length = len(tokens) + if total_length <= max_num_tokens: + break + assert len(tokens) >= 1 + # We want to sometimes truncate from the front and sometimes + # from the back to add more randomness and avoid biases. + if random() < 0.5: + del tokens[0] + else: + tokens.pop() + + +def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): + """Truncates a pair of sequences to a maximum sequence length. 
+ Lifted from Google's BERT repo.""" + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 + + # We want to sometimes truncate from the front and sometimes from the + # back to add more randomness and avoid biases. + if random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() + + +def create_sentence_instance(sentence, tokenizer, max_seq_length): + tokens = tokenizer.tokenize(sentence) + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 2 + if len(tokens) > max_num_tokens: + tokens = tokens[:max_num_tokens] + tokens_processed = ["[CLS]"] + tokens + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens) + 2)] + return tokens_processed, segment_ids + + +def create_sentence_pair_instance(sent_a, sent_b, tokenizer, max_seq_length): + tokens_a = tokenizer.tokenize(sent_a) + tokens_b = tokenizer.tokenize(sent_b) + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) + tokens_processed = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens_a) + 2)] + \ + [1 for _ in range(len(tokens_b) + 1)] + return tokens_processed, segment_ids + + +def create_instances(records, skip_head, tokenizer, max_seq_length, + column_a, column_b, label_column, labels): + instances = [] + for _i, record in tqdm(enumerate(records), desc="Process", unit=" lines"): + if _i == 0 and skip_head: + continue + else: + sentence_a = record[column_a].strip() + if column_b is not None: + sentence_b = record[column_b].strip() + else: + sentence_b = None + if label_column is not None: + label = record[label_column].strip() + target = None + for i, label_name in enumerate(labels): + if label == label_name: + target = i + if target is None: + raise ValueError("Unregconizable label: %s" % label) + else: + target = -1 + if column_b is None: + tokens_processed, segment_ids = create_sentence_instance( + sentence_a, tokenizer, max_seq_length) + else: + tokens_processed, segment_ids = create_sentence_pair_instance( + sentence_a, sentence_b, tokenizer, max_seq_length) + instance = { + "tokens": tokens_processed, + "segment_ids": segment_ids, + "category": target} + instances.append(instance) + return instances + + +def build_instances_from_csv(data, skip_head, tokenizer, input_columns, + label_column, labels, max_seq_len): + with open(data, "r", encoding="utf-8-sig") as csvfile: + reader = csv.reader(csvfile, delimiter='\t', quotechar=None) + lines = list(reader) + print("total {} line loaded: ".format(len(lines))) + if len(input_columns) == 1: + column_a = int(input_columns[0]) + column_b = None + else: + column_a = int(input_columns[0]) + column_b = int(input_columns[1]) + instances = create_instances(lines, skip_head, tokenizer, max_seq_len, + column_a, column_b, label_column, labels) + return instances + + +def _build_bert_vocab(vocab, name, counters, min_freq=0): + """ similar to _load_vocab in inputter.py, but build from a vocab list. 
+ in place change counters + """ + vocab_size = len(vocab) + for i, token in enumerate(vocab): + counters[name][token] = vocab_size - i + min_freq + return vocab, vocab_size + + +def build_vocab_from_tokenizer(fields, tokenizer, tokens_min_frequency=0, + vocab_size_multiple=1): + vocab_list = list(tokenizer.vocab.keys()) + counters = defaultdict(Counter) + _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) + fields_vocab = _build_bert_fields_vocab(fields, counters, vocab_size, + tokens_min_frequency, + vocab_size_multiple) + return fields_vocab + + +def save_data_as_json(instances, json_name): + instances_json = [json.dumps(instance) for instance in instances] + num_instances = 0 + with open(json_name, 'w') as json_file: + for instance in instances_json: + json_file.write(instance + '\n') + num_instances += 1 + return num_instances + + +def validate_preprocess_bert_opts(opts): + assert opts.bert_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ + "Unsupported Pretrain model '%s'" % (opts.bert_model) + + assert os.path.isfile(opts.data), "Please check path of %s" % opts.data + + if opts.data_type == "csv": + assert len(opts.input_columns) in [1, 2],\ + "Please indicate N.colomn for sentence A (and B)" + + +def main(): + parser = ArgumentParser() + parser.add_argument('--data', type=str, default=None, required=True, + help="input data to prepare: path/filename.suffix") + parser.add_argument('--data_type', type=str, default="csv", + help="input data type") + parser.add_argument('--skip_head', action="store_true", + help="If csv file contain head line.") + + parser.add_argument('--input_columns', nargs='+', default=[None], + help="Column numbers where contain sentence A(,B)") + parser.add_argument('--label_column', type=int, default=None, + help="Column number where contain label") + parser.add_argument('--labels', nargs='+', default=[None], + help="labels of sentence") + + parser.add_argument('--task', type=str, default="classification", + choices=["classification", "generation"], + help="Target task to perform") + parser.add_argument("--corpus_type", type=str, default="train", + choices=["train", "valid", "test"]) + parser.add_argument('--save_data', '-save_data', type=str, + default=None, required=True, + help="Output file Prefix for the prepared data") + parser.add_argument("--bert_model", type=str, + default="bert-base-multilingual-uncased", + choices=["bert-base-uncased", "bert-large-uncased", + "bert-base-cased", "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese"], + help="Bert pretrained model to finetuning with.") + + parser.add_argument("--do_lower_case", action="store_true", + help='lowercase data') + parser.add_argument("--max_seq_len", type=int, default=512, + help="max sequence length for prepared data," + "set the limite of position encoding") + parser.add_argument("--tokens_min_frequency", type=int, default=0) + parser.add_argument("--vocab_size_multiple", type=int, default=1) + parser.add_argument("--save_json", action="store_true", + help='save a copy of data in json form.') + + args = parser.parse_args() + validate_preprocess_bert_opts(args) + + print("Load data file %s with skip head %s" % (args.data, args.skip_head)) + input_columns = args.input_columns + label_column = args.label_column + print("Input column at {}, label'{}'".format(input_columns, label_column)) + print("Task: '%s', model: '%s', corpus: '%s'." 
+ % (args.task, args.bert_model, args.corpus_type)) + + fields = get_bert_fields(args.task) + tokenizer = BertTokenizer.from_pretrained( + args.bert_model, do_lower_case=args.do_lower_case) + + # Build instances from csv file + if args.data_type == 'csv': + instances = build_instances_from_csv( + args.data, args.skip_head, tokenizer, + input_columns, label_column, args.labels, args.max_seq_len) + else: + raise NotImplementedError("Not support other file type yet!") + + onmt_filename = args.save_data + ".{}.0.pt".format(args.corpus_type) + # Build BertDataset from instances collected from different document + dataset = BertDataset(fields, instances) + dataset.save(onmt_filename) + print("save processed data {}, num_example {}, max_seq_len {}".format( + onmt_filename, len(instances), args.max_seq_len)) + + if args.save_json: + json_name = args.save_data + ".{}.json".format(args.corpus_type) + num_instances = save_data_as_json(instances, json_name) + print("output file {}, num_example {}, max_seq_len {}".format( + json_name, num_instances, args.max_seq_len)) + + # Build file Vocab.pt from tokenizer + if args.corpus_type == "train": + print("Generating vocab from corresponding text file...") + fields_vocab = build_vocab_from_tokenizer(fields, tokenizer, + args.tokens_min_frequency, + args.vocab_size_multiple) + bert_vocab_file = args.save_data + ".vocab.pt" + torch.save(fields_vocab, bert_vocab_file) + + +if __name__ == '__main__': + main() From 3fae4467c7bdfe63ad721a4b843bf1fc5926d6cd Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Fri, 9 Aug 2019 19:23:33 +0200 Subject: [PATCH 11/28] fix bug; add new feature --- bert_ckp_convert.py | 19 +- onmt/encoders/bert.py | 2 +- onmt/inputters/inputter.py | 31 ++- onmt/model_builder.py | 13 +- onmt/models/__init__.py | 5 +- onmt/models/bert_generators.py | 47 ++++- onmt/models/model_saver.py | 2 +- onmt/trainer.py | 9 +- onmt/utils/loss.py | 80 ++++--- onmt/utils/report_manager.py | 12 +- onmt/utils/statistics.py | 19 +- pregenerate_bert_training_data.py | 5 +- preprocess_bert.py | 333 ++++++++++++++++++++++++------ train.py | 4 +- 14 files changed, 441 insertions(+), 140 deletions(-) mode change 100644 => 100755 preprocess_bert.py diff --git a/bert_ckp_convert.py b/bert_ckp_convert.py index a5eb55fd22..10a45ca2c3 100755 --- a/bert_ckp_convert.py +++ b/bert_ckp_convert.py @@ -64,7 +64,7 @@ def mapping_key(key, max_layers): key = re.sub(r'generator.mask_lm.transform.layer_norm\.(.*)', r'cls.predictions.transform.LayerNorm.\1', key) else: - raise ValueError("Unexpected keys!") + raise KeyError("Unexpected keys! 
Please provide HuggingFace weights") return key @@ -74,9 +74,26 @@ def convert_bert_weights(bert_model, weights, n_layers=12): generator_weights = OrderedDict() model_weights = {"bert": bert_weights, "generator": generator_weights} + hugface_keys = weights.keys() try: for key in bert_model_keys: hugface_key = mapping_key(key, n_layers) + if hugface_key not in hugface_keys: + if 'LayerNorm' in hugface_key: + # Fix LayerNorm of old huggingface ckp + hugface_key = re.sub(r'LayerNorm.weight', + r'LayerNorm.gamma', hugface_key) + hugface_key = re.sub(r'LayerNorm.bias', + r'LayerNorm.beta', hugface_key) + if hugface_key in hugface_keys: + print("[OLD Weights file]gamma/beta is used in " + + "naming BertLayerNorm.") + else: + raise KeyError("Key %s not found in weight file" + % hugface_key) + else: + raise KeyError("Key %s not found in weight file" + % hugface_key) if 'generator' not in key: onmt_key = re.sub(r'bert\.(.*)', r'\1', key) model_weights['bert'][onmt_key] = weights[hugface_key] diff --git a/onmt/encoders/bert.py b/onmt/encoders/bert.py index 309f222bd6..4d935d1cf5 100644 --- a/onmt/encoders/bert.py +++ b/onmt/encoders/bert.py @@ -50,7 +50,7 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, input_mask: shape [batch, seq], 1 for masked position(that padding) output_all_encoded_layers: if out contain all hidden layer Returns: - all_encoder_layers: list of out in shape (batch, src, d_model), + all_encoder_layers: list of out in shape (batch, seq, d_model), to be used for generation task pooled_output: shape (batch, d_model), to be used for classification task diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index 72a3e82e3c..554d788ef2 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -9,7 +9,7 @@ import torch import torchtext.data -from torchtext.data import Field, RawField +from torchtext.data import Field, RawField, LabelField from torchtext.vocab import Vocab from torchtext.data.utils import RandomShuffler @@ -144,7 +144,7 @@ def get_bert_fields(task='pretraining', pad='[PAD]', bos='[CLS]', unk_token=unk, include_lengths=True, batch_first=True) fields["tokens"] = tokens - segment_ids = Field(use_vocab=False, dtype=torch.long, + segment_ids = Field(use_vocab=False, dtype=torch.long, unk_token=None, sequential=True, pad_token=0, batch_first=True) fields["segment_ids"] = segment_ids if task == 'pretraining': @@ -157,14 +157,14 @@ def get_bert_fields(task='pretraining', pad='[PAD]', bos='[CLS]', fields["lm_labels_ids"] = lm_labels_ids elif task == 'classification': - category = Field(use_vocab=False, dtype=torch.long, - sequential=False, batch_first=True) # 0/1 + category = LabelField(sequential=False, use_vocab=True, + pad_token=None, batch_first=True) fields["category"] = category - elif task == 'generation': - token_labels_ids = Field(sequential=True, use_vocab=False, - pad_token=-1, batch_first=True) - fields["token_labels_ids"] = token_labels_ids + elif task == 'generation' or task == 'tagging': + token_labels = Field(sequential=True, use_vocab=True, unk_token=None, + pad_token=pad, batch_first=True) + fields["token_labels"] = token_labels else: raise ValueError("task %s has not been implemented yet!" 
% task) @@ -383,8 +383,8 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab, return fields -def _build_bert_fields_vocab(fields, counters, vocab_size, - tokens_min_frequency=0, vocab_size_multiple=1): +def _build_bert_fields_vocab(fields, counters, vocab_size, label_name=None, + tokens_min_frequency=1, vocab_size_multiple=1): tokens_field = fields["tokens"] tokens_counter = counters["tokens"] # NOTE: Do not use _build_field_vocab @@ -398,6 +398,15 @@ def _build_bert_fields_vocab(fields, counters, vocab_size, if vocab_size_multiple > 1: _pad_vocab_to_multiple(tokens_field.vocab, vocab_size_multiple) + if label_name is not None: + label_field = fields[label_name] + label_counter = counters[label_name] + all_specials = [label_field.unk_token, label_field.pad_token, + label_field.init_token, label_field.eos_token] + specials = [tok for tok in all_specials if tok is not None] + + label_field.vocab = label_field.vocab_cls( + label_counter, specials=specials) return fields @@ -630,7 +639,7 @@ def __iter__(self): instead of a torchtext.data.Batch object. """ while True: - self.init_epoch() + self.init_epoch() # Inside, create_batches() will be called for idx, minibatch in enumerate(self.batches): # fast-forward if loaded from state if self._iterations_this_epoch > idx: diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 16ee7d44cc..d4b69946a5 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -20,7 +20,7 @@ from onmt.utils.parse import ArgumentParser from onmt.models import BertPreTrainingHeads, ClassificationHead, \ - TokenGenerationHead + TokenGenerationHead, TokenTaggingHead from onmt.modules.bert_embeddings import BertEmbeddings from collections import OrderedDict @@ -259,6 +259,8 @@ def build_bert_generator(model_opt, fields, bert_encoder): only all_encoder_layers will be used for generation generator; """ task = model_opt.task_type + dropout = model_opt.dropout[0] if type(model_opt.dropout) is list \ + else model_opt.dropout if task == 'pretraining': generator = BertPreTrainingHeads( bert_encoder.d_model, bert_encoder.embeddings.vocab_size) @@ -272,8 +274,13 @@ def build_bert_generator(model_opt, fields, bert_encoder): generator.decode.weight = \ bert_encoder.embeddings.word_embeddings.weight elif task == 'classification': - n_class = model_opt.classification - generator = ClassificationHead(bert_encoder.d_model, n_class) + n_class = len(fields["category"].vocab.stoi) #model_opt.labels + logger.info('Generator of classification with %s class.' % n_class) + generator = ClassificationHead(bert_encoder.d_model, n_class, dropout) + elif task == 'tagging': + n_class = len(fields["token_labels"].vocab.stoi) + logger.info('Generator of tagging with %s tag.' 
% n_class) + generator = TokenTaggingHead(bert_encoder.d_model, n_class, dropout) return generator diff --git a/onmt/models/__init__.py b/onmt/models/__init__.py index 0185af2b46..f922c8a858 100644 --- a/onmt/models/__init__.py +++ b/onmt/models/__init__.py @@ -2,8 +2,9 @@ from onmt.models.model_saver import build_model_saver, ModelSaver from onmt.models.model import NMTModel from onmt.models.bert_generators import BertPreTrainingHeads,\ - ClassificationHead, TokenGenerationHead + ClassificationHead, TokenGenerationHead, TokenTaggingHead __all__ = ["build_model_saver", "ModelSaver", "NMTModel", "BertPreTrainingHeads", "ClassificationHead", - "TokenGenerationHead" ,"check_sru_requirement"] + "TokenGenerationHead", "TokenTaggingHead", + "check_sru_requirement"] diff --git a/onmt/models/bert_generators.py b/onmt/models/bert_generators.py index b467db12d6..459c55dbda 100644 --- a/onmt/models/bert_generators.py +++ b/onmt/models/bert_generators.py @@ -107,32 +107,65 @@ def forward(self, hidden_states): class ClassificationHead(nn.Module): """ - n-class classification head + n-class Sentence classification head """ - def __init__(self, hidden_size, n_class): + def __init__(self, hidden_size, n_class, dropout=0.1): """ Args: hidden_size: BERT model output size """ super(ClassificationHead, self).__init__() + self.dropout = nn.Dropout(dropout) self.linear = nn.Linear(hidden_size, n_class) self.softmax = nn.LogSoftmax(dim=-1) def forward(self, all_hidden, pooled): """ Args: - all_hidden: first output of Bert encoder (batch, src, d_model) - pooled: second output of bert encoder, shape (batch, src, d_model) + all_hidden: layer output of BERT, list [(batch, seq, d_model)] + pooled: last layer hidden [CLS] of BERT, (batch, d_model) Returns: class_log_prob: shape (batch_size, 2) None: this is a placeholder for token level prediction task """ + pooled = self.dropout(pooled) score = self.linear(pooled) # (batch, n_class) class_log_prob = self.softmax(score) # (batch, n_class) return class_log_prob, None +class TokenTaggingHead(nn.Module): + """ + n-class Token Tagging head + """ + + def __init__(self, hidden_size, n_class, dropout=0.1): + """ + Args: + hidden_size: BERT model output size + """ + super(TokenTaggingHead, self).__init__() + self.dropout = nn.Dropout(dropout) + self.linear = nn.Linear(hidden_size, n_class) + self.softmax = nn.LogSoftmax(dim=-1) + + def forward(self, all_hidden, pooled): + """ + Args: + all_hidden: layer output of BERT, list [(batch, seq, d_model)] + pooled: last layer hidden [CLS] of BERT, (batch, d_model) + Returns: + None: this is a placeholder for sentence level task + tok_class_log_prob: shape (batch, seq, n_class) + """ + last_hidden = all_hidden[-1] + last_hidden = self.dropout(last_hidden) # (batch, seq, d_model) + score = self.linear(last_hidden) # (batch, seq, n_class) + tok_class_log_prob = self.softmax(score) # (batch, seq, n_class) + return None, tok_class_log_prob + + class TokenGenerationHead(nn.Module): """ Token generation head: generation token from input sequence @@ -153,15 +186,15 @@ def __init__(self, hidden_size, vocab_size): self.softmax = nn.LogSoftmax(dim=-1) - def forward(self, x, pooled): + def forward(self, all_hidden, pooled): """ Args: - x: last layer output of bert, shape (batch, seq, d_model) + all_hidden: layer output of BERT, list [(batch, seq, d_model)] Returns: None: this is a placeholder for sentence level task prediction_log_prob: shape (batch, seq, vocab) """ - last_hidden = x[-1] + last_hidden = all_hidden[-1] y = 
self.transform(last_hidden) # (batch, seq, d_model) prediction_scores = self.decode(y) + self.bias # (batch, seq, vocab) prediction_log_prob = self.softmax(prediction_scores) diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py index 4028e8f285..2b67929de4 100644 --- a/onmt/models/model_saver.py +++ b/onmt/models/model_saver.py @@ -118,7 +118,7 @@ def _save(self, step, model): for name, field in vocab.items(): if isinstance(field, Field): if hasattr(field, "vocab") and \ - hasattr(field, "unk_token"): + (field.unk_token is not None): assert name == 'tokens' keys_to_pop = [] unk_token = field.unk_token diff --git a/onmt/trainer.py b/onmt/trainer.py index 888cae3415..3c23fe635d 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -31,7 +31,14 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object used to save the model """ - tgt_field = dict(fields)["tgt"].base_field if not opt.is_bert else None + if not opt.is_bert: + tgt_field = dict(fields)["tgt"].base_field + elif opt.task_type == 'tagging' or opt.task_type == 'generation': + tgt_field = fields["token_labels"] + elif opt.task_type == 'classification': + tgt_field = fields["category"] + else: # pretraining task + tgt_field = fields["lm_labels_ids"] train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) valid_loss = onmt.utils.loss.build_loss_compute( model, tgt_field, opt, train=False) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 1de1691e65..a86f14bd6f 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -10,6 +10,7 @@ import onmt from onmt.modules.sparse_losses import SparsemaxLoss from onmt.modules.sparse_activations import LogSparsemax +from sklearn.metrics import f1_score def build_loss_compute(model, tgt_field, opt, train=True): @@ -24,9 +25,14 @@ def build_loss_compute(model, tgt_field, opt, train=True): device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") if opt.is_bert is True: assert hasattr(model, 'bert') - assert tgt_field is None - # BERT use -1 for unmasked token in lm_label_ids - criterion = nn.NLLLoss(ignore_index=-1, reduction='mean') + if tgt_field.pad_token is not None: + if tgt_field.use_vocab: + padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] + else: # target is pre-numerized: -1 for unmasked token in mlm + padding_idx = tgt_field.pad_token + criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='mean') + else: # sentence level + criterion = nn.NLLLoss(reduction='mean') task = opt.task_type compute = BertLoss(criterion, task) else: @@ -67,7 +73,7 @@ class BertLoss(nn.Module): def __init__(self, criterion, task): super(BertLoss, self).__init__() self.criterion = criterion - self.task =task + self.task = task @property def padding_idx(self): @@ -91,14 +97,14 @@ def _stats(self, loss, tokens_scores, tokens_target, """ if self.task == 'pretraining': # masked lm task: token level - pred_tokens = tokens_scores.argmax(1) # (batch*seq, vocab) -> (batch*seq) - non_padding = tokens_target.ne(self.padding_idx) # mask: (batch*seq) + pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) + non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) n_correct_tokens = tokens_match.sum().item() n_tokens = non_padding.sum().item() - + f1 = 0 # next sentence prediction task: sentence level - pred_sents = sents_scores.argmax(-1) # (batch_size, 2) -> (2) + pred_sents = 
sents_scores.argmax(-1) # (B, 2) -> (2) n_correct_sents = sents_target.eq(pred_sents).sum().item() n_sentences = len(sents_target) @@ -106,29 +112,46 @@ def _stats(self, loss, tokens_scores, tokens_target, # token level task: Not valide n_correct_tokens = 0 n_tokens = 0 + f1 = 0 # sentence level task: - pred_sents = sents_scores.argmax(-1) # (batch_size, n_label) -> (n_label) + pred_sents = sents_scores.argmax(-1) # (B, n_label) -> (n_label) n_correct_sents = sents_target.eq(pred_sents).sum().item() n_sentences = len(sents_target) + elif self.task == 'tagging': + # token level task: + pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) + non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) + tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) + n_correct_tokens = tokens_match.sum().item() + n_tokens = non_padding.sum().item() + # for f1: + tokens_target_select = tokens_target.masked_select(non_padding) + pred_tokens_select = pred_tokens.masked_select(non_padding) + f1 = f1_score(tokens_target_select.cpu(), + pred_tokens_select.cpu(), average="micro") + + # sentence level task: Not valide + n_correct_sents = 0 + n_sentences = 0 + elif self.task == 'generation': # token level task: - pred_tokens = tokens_scores.argmax(1) # (batch*seq, vocab) -> (batch*seq) - non_padding = tokens_target.ne(self.padding_idx) # mask: (batch*seq) + pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) + non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) n_correct_tokens = tokens_match.sum().item() n_tokens = non_padding.sum().item() + f1 = 0 # sentence level task: Not valide n_correct_sents = 0 n_sentences = 0 else: - raise ValueError("task '{}' has not been implemented yet!".format(self.task)) + raise ValueError("task %s not available!" 
% (self.task)) - # print("lm: {}/{}".format(n_correct_tokens, n_tokens)) - # print("nx: {}/{}".format(n_correct_sents, n_sentences)) return onmt.utils.BertStatistics(loss.item(), n_tokens, n_correct_tokens, n_sentences, - n_correct_sents) + n_correct_sents, f1) def forward(self, batch, outputs): @@ -144,15 +167,14 @@ def forward(self, batch, outputs): if self.task == 'pretraining': assert list(seq_class_log_prob.size()) == [len(batch), 2] # masked lm task: token level(loss mean by number of tokens) - # targets: - gtruth_tokens = batch.lm_labels_ids # (batch, seq) - bottled_gtruth_tokens = gtruth_tokens.view(-1) # (batch * seq) - # prediction: (batch, seq, vocab) -> (batch * seq, vocab) + gtruth_tokens = batch.lm_labels_ids # (B, S) + bottled_gtruth_tokens = gtruth_tokens.view(-1) # (B, S) + # prediction: (B, S, V) -> (B * S, V) bottled_prediction_log_prob = self._bottle(prediction_log_prob) mask_loss = self.criterion(bottled_prediction_log_prob, bottled_gtruth_tokens) # next sentence prediction task: sentence level(mean by sentence) - gtruth_sentences = batch.is_next # (batch,) + gtruth_sentences = batch.is_next # (B,) next_loss = self.criterion(seq_class_log_prob, gtruth_sentences) total_loss = next_loss + mask_loss # total_loss reduced by mean @@ -166,25 +188,23 @@ def forward(self, batch, outputs): gtruth_sentences = batch.category total_loss = self.criterion(seq_class_log_prob, gtruth_sentences) - elif self.task == 'generation': + elif self.task == 'tagging' or self.task == 'generation': assert seq_class_log_prob is None - assert hasattr(batch, 'token_labels_ids') + assert hasattr(batch, 'token_labels') # token level task: loss mean by number of tokens - gtruth_tokens = batch.token_labels_ids # (batch, seq) - bottled_gtruth_tokens = gtruth_tokens.view(-1) # (batch * seq) - # prediction: (batch, seq, vocab) -> (batch * seq, vocab) + gtruth_tokens = batch.token_labels # (B, S) + bottled_gtruth_tokens = gtruth_tokens.view(-1) # (B, S) + # prediction: (B, S, V) -> (B * S, V) bottled_prediction_log_prob = self._bottle(prediction_log_prob) total_loss = self.criterion(bottled_prediction_log_prob, bottled_gtruth_tokens) # sentence level task: Not valide seq_class_log_prob = None gtruth_sentences = None + else: - raise ValueError("task '{}' has not been implemented yet!".format(self.task)) - # loss_accum_normalized = total_loss #/ float(normalization) - # print("loss: ({} + {})/{} = {}".format(next_loss, mask_loss, - # float(normalization), loss_accum_normalized)) - # print("nx: {}/{}".format(n_correct_sents, n_sentences)) + raise ValueError("task %s not available!" 
% (self.task)) + stats = self._stats(total_loss.clone(), bottled_prediction_log_prob, bottled_gtruth_tokens, diff --git a/onmt/utils/report_manager.py b/onmt/utils/report_manager.py index 4db7cc2c2c..c52c278308 100644 --- a/onmt/utils/report_manager.py +++ b/onmt/utils/report_manager.py @@ -147,12 +147,14 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None): """ if train_stats is not None: self.log('Train perplexity: %g' % train_stats.ppl()) - if isinstance(train_stats, onmt.utils.BertStatistics) \ - and train_stats.accuracy() is None: + if train_stats.accuracy() is None: + assert isinstance(train_stats, onmt.utils.BertStatistics) accuracy = train_stats.sentence_accuracy() else: accuracy = train_stats.accuracy() self.log('Train accuracy: %g' % accuracy) + if hasattr(train_stats, 'f1') and train_stats.f1 != 0: + self.log('Train F1: %g' % train_stats.f1) self.maybe_log_tensorboard(train_stats, "train", @@ -161,12 +163,14 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None): if valid_stats is not None: self.log('Validation perplexity: %g' % valid_stats.ppl()) - if isinstance(valid_stats, onmt.utils.BertStatistics) \ - and valid_stats.accuracy() is None: + if valid_stats.accuracy() is None: + assert isinstance(valid_stats, onmt.utils.BertStatistics) accuracy = valid_stats.sentence_accuracy() else: accuracy = valid_stats.accuracy() self.log('Validation accuracy: %g' % accuracy) + if hasattr(valid_stats, 'f1') and valid_stats.f1 != 0: + self.log('Validation F1: %g' % valid_stats.f1) self.maybe_log_tensorboard(valid_stats, "valid", diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index eb39df71f7..4ac26e3c32 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -139,11 +139,12 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): class BertStatistics(Statistics): """ Bert Statistics as the loss is reduced by mean """ def __init__(self, loss=0, n_words=0, n_correct=0, - n_sentence=0, n_correct_sentence=0): + n_sentence=0, n_correct_sentence=0, f1=0): super(BertStatistics, self).__init__(loss, n_words, n_correct) self.n_update = 0 if n_words == 0 and n_sentence == 0 else 1 self.n_sentence = n_sentence self.n_correct_sentence = n_correct_sentence + self.f1 = f1 def accuracy(self): """ compute token level accuracy """ @@ -187,6 +188,8 @@ def update(self, stat, update_n_src_words=False): self.n_correct += stat.n_correct self.n_sentence += stat.n_sentence self.n_correct_sentence += stat.n_correct_sentence + self.f1 = (self.f1 * self.n_update + stat.f1 * + stat.n_update) / (self.n_update + stat.n_update) if update_n_src_words: self.n_src_words += stat.n_src_words @@ -203,7 +206,7 @@ def output(self, step, num_steps, learning_rate, start): step_fmt = "%2d" % step if num_steps > 0: step_fmt = "%s/%5d" % (step_fmt, num_steps) - if self.n_words == 0: # sentence level task + if self.n_words == 0: # sentence level task: Acc, PPL, X-entropy logger.info( ("Step %s; acc(sent):%6.2f; ppl: %5.2f; " + "xent: %4.2f; lr: %7.5f; %3.0f tok/%3.0f sent/s; %6.0f sec") @@ -215,13 +218,13 @@ def output(self, step, num_steps, learning_rate, start): self.n_src_words / (t + 1e-5), self.n_sentence / (t + 1e-5), time.time() - start)) - elif self.n_sentence == 0: # token level task + elif self.n_sentence == 0: # token level task: Tok Acc, F1, X-entropy logger.info( - ("Step %s; acc(token):%6.2f; ppl: %5.2f; " + + ("Step %s; acc(token):%6.2f; f1: %5.4f; " + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") % (step_fmt, self.accuracy(), - 
self.ppl(), + self.f1, self.xent(), learning_rate, self.n_src_words / (t + 1e-5), @@ -247,10 +250,12 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): t = self.elapsed_time() writer.add_scalar(prefix + "/xent", self.xent(), step) writer.add_scalar(prefix + "/ppl", self.ppl(), step) - if self.n_words != 0: + if self.n_words != 0: # Token level task writer.add_scalar(prefix + "/accuracy_token", self.accuracy(), step) - if self.n_sentence != 0: + writer.add_scalar(prefix + "/F1", + self.f1, step) + if self.n_sentence != 0: # Sentence level task writer.add_scalar(prefix + "/accuracy_sent", self.sentence_accuracy(), step) writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py index 8f2579f66b..5fc1a205f1 100755 --- a/pregenerate_bert_training_data.py +++ b/pregenerate_bert_training_data.py @@ -1,5 +1,5 @@ """ -This file is massively inspired from huggingface and adapted into onmt custom. +This file is lifted from huggingface and adapted for onmt structure. Ref: https://github.com/huggingface/pytorch-transformers/blob/master/examples/lm_finetuning/pregenerate_training_data.py """ from argparse import ArgumentParser @@ -19,6 +19,7 @@ from collections import Counter, defaultdict import torch + class DocumentDatabase: def __init__(self, reduce_memory=False): if reduce_memory: @@ -347,7 +348,7 @@ def main(): vocab_list = list(tokenizer.vocab.keys()) counters = defaultdict(Counter) _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) - fields = _build_bert_fields_vocab(fields, counters, vocab_size, args.tokens_min_frequency, args.vocab_size_multiple) # + fields = _build_bert_fields_vocab(fields, counters, vocab_size, None, args.tokens_min_frequency, args.vocab_size_multiple) # bert_vocab_file = args.output_dir / f"{args.output_name}.vocab.pt" torch.save(fields, bert_vocab_file) diff --git a/preprocess_bert.py b/preprocess_bert.py old mode 100644 new mode 100755 index ee23bebfe3..0fc4b2713e --- a/preprocess_bert.py +++ b/preprocess_bert.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from tqdm import tqdm import csv -from random import random +from random import random, shuffle from onmt.utils.bert_tokenization import BertTokenizer, \ PRETRAINED_VOCAB_ARCHIVE_MAP import json @@ -11,6 +11,7 @@ from collections import Counter, defaultdict import torch import os +import codecs def truncate_seq(tokens, max_num_tokens): @@ -46,12 +47,16 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): trunc_tokens.pop() -def create_sentence_instance(sentence, tokenizer, max_seq_length): +def create_sentence_instance(sentence, tokenizer, + max_seq_length, random_trunc=False): tokens = tokenizer.tokenize(sentence) # Account for [CLS], [SEP], [SEP] max_num_tokens = max_seq_length - 2 if len(tokens) > max_num_tokens: - tokens = tokens[:max_num_tokens] + if random_trunc is True: + truncate_seq(tokens, max_num_tokens) + else: + tokens = tokens[:max_num_tokens] tokens_processed = ["[CLS]"] + tokens + ["[SEP]"] segment_ids = [0 for _ in range(len(tokens) + 2)] return tokens_processed, segment_ids @@ -69,44 +74,41 @@ def create_sentence_pair_instance(sent_a, sent_b, tokenizer, max_seq_length): return tokens_processed, segment_ids -def create_instances(records, skip_head, tokenizer, max_seq_length, - column_a, column_b, label_column, labels): +def create_instances_from_csv(records, skip_head, tokenizer, max_seq_length, + column_a, column_b, label_column, labels): instances = [] for _i, record in 
tqdm(enumerate(records), desc="Process", unit=" lines"): if _i == 0 and skip_head: continue else: sentence_a = record[column_a].strip() - if column_b is not None: - sentence_b = record[column_b].strip() + if column_b is None: + tokens_processed, segment_ids = create_sentence_instance( + sentence_a, tokenizer, max_seq_length) else: - sentence_b = None + sentence_b = record[column_b].strip() + tokens_processed, segment_ids = create_sentence_pair_instance( + sentence_a, sentence_b, tokenizer, max_seq_length) if label_column is not None: label = record[label_column].strip() - target = None - for i, label_name in enumerate(labels): - if label == label_name: - target = i - if target is None: - raise ValueError("Unregconizable label: %s" % label) - else: - target = -1 - if column_b is None: - tokens_processed, segment_ids = create_sentence_instance( - sentence_a, tokenizer, max_seq_length) - else: - tokens_processed, segment_ids = create_sentence_pair_instance( - sentence_a, sentence_b, tokenizer, max_seq_length) - instance = { - "tokens": tokens_processed, - "segment_ids": segment_ids, - "category": target} - instances.append(instance) - return instances + if label not in labels: + labels.append(label) + instance = { + "tokens": tokens_processed, + "segment_ids": segment_ids, + "category": label} + else: # TODO: prediction dataset + label = None + instance = { + "tokens": tokens_processed, + "segment_ids": segment_ids, + "category": label} + instances.append(instance) + return instances, labels def build_instances_from_csv(data, skip_head, tokenizer, input_columns, - label_column, labels, max_seq_len): + label_column, labels, max_seq_len, do_shuffle): with open(data, "r", encoding="utf-8-sig") as csvfile: reader = csv.reader(csvfile, delimiter='\t', quotechar=None) lines = list(reader) @@ -117,28 +119,141 @@ def build_instances_from_csv(data, skip_head, tokenizer, input_columns, else: column_a = int(input_columns[0]) column_b = int(input_columns[1]) - instances = create_instances(lines, skip_head, tokenizer, max_seq_len, - column_a, column_b, label_column, labels) + instances, labels = create_instances_from_csv( + lines, skip_head, tokenizer, max_seq_len, + column_a, column_b, label_column, labels) + if do_shuffle is True: + print("Shuffle all {} instance".format(len(instances))) + shuffle(instances) + return instances, labels + + +def create_instances_from_file(records, label, tokenizer, max_seq_length): + instances = [] + for _i, record in tqdm(enumerate(records), desc="Process", unit=" lines"): + sentence = record.strip() + tokens_processed, segment_ids = create_sentence_instance( + sentence, tokenizer, max_seq_length, random_trunc=True) + instance = { + "tokens": tokens_processed, + "segment_ids": segment_ids, + "category": label} + instances.append(instance) return instances -def _build_bert_vocab(vocab, name, counters, min_freq=0): +def build_instances_from_files(data, labels, tokenizer, + max_seq_len, do_shuffle): + instances = [] + for filename in data: #zip(data, labels): + label = filename.split('/')[-2] + with codecs.open(filename, "r", encoding="utf-8") as f: + lines = f.readlines() + print("total {} line of File {} loaded for label: {}.".format( + len(lines), filename, label)) + file_instances = create_instances_from_file( + lines, label, tokenizer, max_seq_len) + instances.extend(file_instances) + if do_shuffle is True: + print("Shuffle all {} instance".format(len(instances))) + shuffle(instances) + return instances + + +def create_tag_instance_from_sentence(token_pairs, tokenizer, 
max_seq_len, + pad_tok): + """ + token_pairs: list of (word, tag) pair that form a sentence + tokenizer: tokenizer we use to tokenizer the words in token_pairs + max_seq_len: max sequence length that a instance could contain + """ + sentence = [] + tags = [] + max_num_tokens = max_seq_len - 2 + for (word, tag) in token_pairs: + tokens = tokenizer.tokenize(word) + n_pad = len(tokens) - 1 + paded_tag = [tag] + [pad_tok] * n_pad + if len(sentence) + len(tokens) > max_num_tokens: + break + else: + sentence.extend(tokens) + tags.extend(paded_tag) + sentence = ["[CLS]"] + sentence + ["[SEP]"] + tags = [pad_tok] + tags + [pad_tok] + segment_ids = [0 for _ in range(len(sentence))] + instance = { + "tokens": sentence, + "segment_ids": segment_ids, + "token_labels": tags + } + return instance + + +def build_tag_instances_from_file(filename, skip_head, tokenizer, max_seq_len, + token_column, tag_column, tags, do_shuffle, + pad_tok, delimiter=' '): + sentences = [] + labels = [] if tags is None else tags + with codecs.open(filename, "r", encoding="utf-8") as f: + lines = f.readlines() + if skip_head is True: + lines = lines[1:] + print("total {} line of file {} loaded.".format( + len(lines), filename)) + sentence_sofar = [] + for line in tqdm(lines, desc="Process", unit=" lines"): + line = line.strip() + if line is '': + if len(sentence_sofar) > 0: + sentences.append(sentence_sofar) + sentence_sofar = [] + else: + elements = line.split(delimiter) + token = elements[token_column] + tag = elements[tag_column] + if tag not in labels: + labels.append(tag) + sentence_sofar.append((token, tag)) + print("total {} sentence loaded.".format(len(sentences))) + print("All tags:", labels) + + instances = [] + for sentence in sentences: + instance = create_tag_instance_from_sentence( + sentence, tokenizer, max_seq_len, pad_tok) + instances.append(instance) + + if do_shuffle is True: + print("Shuffle all {} instance".format(len(instances))) + shuffle(instances) + return instances, labels + + +def _build_bert_vocab(vocab, name, counters): """ similar to _load_vocab in inputter.py, but build from a vocab list. in place change counters """ vocab_size = len(vocab) for i, token in enumerate(vocab): - counters[name][token] = vocab_size - i + min_freq + counters[name][token] = vocab_size - i return vocab, vocab_size -def build_vocab_from_tokenizer(fields, tokenizer, tokens_min_frequency=0, - vocab_size_multiple=1): +def build_vocab_from_tokenizer(fields, tokenizer, named_labels, + tokens_min_frequency=0, vocab_size_multiple=1): vocab_list = list(tokenizer.vocab.keys()) counters = defaultdict(Counter) _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) + + if named_labels is not None: + label_name, label_list = named_labels + _, _ = _build_bert_vocab(label_list, label_name, counters) + else: + label_name = None + fields_vocab = _build_bert_fields_vocab(fields, counters, vocab_size, - tokens_min_frequency, + label_name, tokens_min_frequency, vocab_size_multiple) return fields_vocab @@ -156,38 +271,69 @@ def save_data_as_json(instances, json_name): def validate_preprocess_bert_opts(opts): assert opts.bert_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ "Unsupported Pretrain model '%s'" % (opts.bert_model) + for filename in opts.data: + assert os.path.isfile(filename),\ + "Please check path of %s" % filename - assert os.path.isfile(opts.data), "Please check path of %s" % opts.data + if args.task == "tagging": + assert args.data_type == 'txt',\ + "For sequence tagging, only txt file is supported." 
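[Illustration, not part of the patch] The tagging path above (build_tag_instances_from_file) reads one token/tag pair per line, split on --delimiter, with a blank line closing each sentence. A hypothetical CoNLL-style snippet with the token in column 0 and the tag in column 1 would look like:

    EU B-ORG
    rejects O
    German B-MISC
    call O
    . O

    Peter B-PER
    Blackburn I-PER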
+ + assert len(opts.input_columns) == 1,\ + "For sequence tagging, only one column for input tokens." + opts.input_columns = opts.input_columns[0] if opts.data_type == "csv": + assert len(opts.data) == 1,\ + "For csv, only one file is needed." assert len(opts.input_columns) in [1, 2],\ "Please indicate N.colomn for sentence A (and B)" - - -def main(): - parser = ArgumentParser() - parser.add_argument('--data', type=str, default=None, required=True, - help="input data to prepare: path/filename.suffix") + # if opts.label_column is not None: + # assert len(opts.labels) != 0,\ + # "label list is needed when csv contain label column" + + # elif opts.data_type == "txt": + # if opts.task == "classification": + # assert len(opts.datas) == len(opts.labels), \ + # "Label should correspond to input files" + return opts + + +def _get_parser(): + parser = ArgumentParser(description='preprocess_bert.py') + + parser.add_argument('--data', type=str, nargs='+', default=[], + required=True, help="input datas to prepare: [CLS]" + + "Single file for csv with column indicate label," + + "One file for each class as path/label/file; [TAG]" + + "Single file contain (tok, tag) in each line,"+ + "Sentence separated by blank line.") parser.add_argument('--data_type', type=str, default="csv", + choices=["csv", "txt"], help="input data type") parser.add_argument('--skip_head', action="store_true", - help="If csv file contain head line.") + help="CSV: If csv file contain head line.") - parser.add_argument('--input_columns', nargs='+', default=[None], - help="Column numbers where contain sentence A(,B)") + parser.add_argument('--input_columns', type=int, nargs='+', default=[], + help="CSV: Column where contain sentence A(,B)") parser.add_argument('--label_column', type=int, default=None, - help="Column number where contain label") - parser.add_argument('--labels', nargs='+', default=[None], - help="labels of sentence") + help="CSV: Column where contain label") + parser.add_argument('--labels', type=str, nargs='+', default=[], + help="CSV: labels of sentence;" + + "TXT: labels for sentence in files.") + parser.add_argument('--delimiter', '-d', type=str, default=' ', + help="CSV: delimiter used for seperate column.") parser.add_argument('--task', type=str, default="classification", - choices=["classification", "generation"], + choices=["classification", "tagging"], help="Target task to perform") parser.add_argument("--corpus_type", type=str, default="train", choices=["train", "valid", "test"]) parser.add_argument('--save_data', '-save_data', type=str, default=None, required=True, help="Output file Prefix for the prepared data") + parser.add_argument("--do_shuffle", action="store_true", + help='shuffle data') parser.add_argument("--bert_model", type=str, default="bert-base-multilingual-uncased", choices=["bert-base-uncased", "bert-large-uncased", @@ -206,14 +352,10 @@ def main(): parser.add_argument("--vocab_size_multiple", type=int, default=1) parser.add_argument("--save_json", action="store_true", help='save a copy of data in json form.') + return parser - args = parser.parse_args() - validate_preprocess_bert_opts(args) - print("Load data file %s with skip head %s" % (args.data, args.skip_head)) - input_columns = args.input_columns - label_column = args.label_column - print("Input column at {}, label'{}'".format(input_columns, label_column)) +def main(args): print("Task: '%s', model: '%s', corpus: '%s'." 
% (args.task, args.bert_model, args.corpus_type)) @@ -221,14 +363,55 @@ def main(): tokenizer = BertTokenizer.from_pretrained( args.bert_model, do_lower_case=args.do_lower_case) - # Build instances from csv file - if args.data_type == 'csv': - instances = build_instances_from_csv( - args.data, args.skip_head, tokenizer, - input_columns, label_column, args.labels, args.max_seq_len) - else: - raise NotImplementedError("Not support other file type yet!") - + if args.task == "classification": + # Build instances from csv file + if args.data_type == 'csv': + filename = args.data[0] + print("Load data file %s with skip head %s" % ( + filename, args.skip_head)) + input_columns = args.input_columns + label_column = args.label_column + print("Input column at {}, label at [{}]".format( + input_columns, label_column)) + instances, labels = build_instances_from_csv( + filename, args.skip_head, tokenizer, + input_columns, label_column, + args.labels, args.max_seq_len, args.do_shuffle) + labels.sort() + args.labels = labels + print("Labels:", args.labels) + elif args.data_type == 'txt': + if len(args.labels) == 0: + print("Build labels from file dir...") + labels = [] + for filename in args.data: + label = filename.split('/')[-2] + if label not in labels: + labels.append(label) + labels.sort() + args.labels = labels + print("Labels:", args.labels) + instances = build_instances_from_files( + args.data, args.labels, tokenizer, args.max_seq_len, args.do_shuffle) + else: + raise NotImplementedError("Not support other file type yet!") + + if args.task == "tagging": + pad_tok = fields["token_labels"].pad_token # "[PAD]" for Bert Paddings + filename = args.data[0] + print("Load data file %s with skip head %s" % ( + filename, args.skip_head)) + token_column = args.input_columns + tag_column = args.label_column + instances, labels = build_tag_instances_from_file( + filename, args.skip_head, tokenizer, args.max_seq_len, + token_column, tag_column, args.labels, args.do_shuffle, + pad_tok, delimiter=args.delimiter) + labels.sort() + args.labels = [pad_tok] + labels + print("Labels:", args.labels) + + # Save processed data in OpenNMT format onmt_filename = args.save_data + ".{}.0.pt".format(args.corpus_type) # Build BertDataset from instances collected from different document dataset = BertDataset(fields, instances) @@ -245,12 +428,26 @@ def main(): # Build file Vocab.pt from tokenizer if args.corpus_type == "train": print("Generating vocab from corresponding text file...") - fields_vocab = build_vocab_from_tokenizer(fields, tokenizer, - args.tokens_min_frequency, - args.vocab_size_multiple) + if args.task == "classification": + if len(args.labels) == 0: # TODO + raise AttributeError("Labels should be given") + else: + named_labels = ("category", args.labels) + print("Save Labels:", named_labels, "in vocab file.") + + if args.task == "tagging": + named_labels = ("token_labels", args.labels) + print("Save Labels:", named_labels, "in vocab file.") + + fields_vocab = build_vocab_from_tokenizer( + fields, tokenizer, named_labels, + args.tokens_min_frequency, args.vocab_size_multiple) bert_vocab_file = args.save_data + ".vocab.pt" torch.save(fields_vocab, bert_vocab_file) if __name__ == '__main__': - main() + parser = _get_parser() + args = parser.parse_args() + args = validate_preprocess_bert_opts(args) + main(args) diff --git a/train.py b/train.py index 4fba91678c..237560f429 100755 --- a/train.py +++ b/train.py @@ -133,8 +133,8 @@ def next_batch(device_id): b.lm_labels_ids = b.lm_labels_ids.to(torch.device(device_id)) 
elif opt.task_type == 'classification': b.category = b.category.to(torch.device(device_id)) - elif opt.task_type == 'prediction': - b.token_labels_ids = b.token_labels_ids.to( + elif opt.task_type == 'prediction' or opt.task_type == 'tagging': + b.token_labels = b.token_labels.to( torch.device(device_id)) else: raise ValueError("task type Error") From 4b511e4f363425cbf72fa0027233740e40a485be Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 13 Aug 2019 14:17:41 +0200 Subject: [PATCH 12/28] add prediction file --- bert_ckp_convert.py | 14 +- onmt/inputters/__init__.py | 7 +- onmt/inputters/dataset_bert.py | 143 ++++++++++++++ onmt/model_builder.py | 18 ++ onmt/opts.py | 53 +++++ onmt/translate/__init__.py | 3 +- onmt/translate/predictor.py | 352 +++++++++++++++++++++++++++++++++ onmt/utils/parse.py | 23 +++ predict.py | 56 ++++++ preprocess_bert.py | 112 +++-------- 10 files changed, 681 insertions(+), 100 deletions(-) create mode 100644 onmt/translate/predictor.py create mode 100755 predict.py diff --git a/bert_ckp_convert.py b/bert_ckp_convert.py index 10a45ca2c3..7fe778d7b6 100755 --- a/bert_ckp_convert.py +++ b/bert_ckp_convert.py @@ -87,7 +87,7 @@ def convert_bert_weights(bert_model, weights, n_layers=12): r'LayerNorm.beta', hugface_key) if hugface_key in hugface_keys: print("[OLD Weights file]gamma/beta is used in " + - "naming BertLayerNorm.") + "naming BertLayerNorm. Mapping succeed.") else: raise KeyError("Key %s not found in weight file" % hugface_key) @@ -120,21 +120,19 @@ def main(): help="output onmt version Bert weight file Path") args = parser.parse_args() - n_layers = args.layers - print("Model contain {} layers.".format(n_layers)) + print("Model contain {} layers.".format(args.layers)) - bert_model_weights = args.bert_model_weights_file - print("Load weights from {}.".format(bert_model_weights)) + print("Load weights from {}.".format(args.bert_model_weights_file)) - bert_weights = torch.load(bert_model_weights) - embeddings = BertEmbeddings(105879) + bert_weights = torch.load(args.bert_model_weights_file) + embeddings = BertEmbeddings(28996) # vocab don't bother the conversion bert_encoder = BertEncoder(embeddings) generator = BertPreTrainingHeads(bert_encoder.d_model, embeddings.vocab_size) bertlm = torch.nn.Sequential(OrderedDict([ ('bert', bert_encoder), ('generator', generator)])) - model_weights = convert_bert_weights(bertlm, bert_weights, n_layers) + model_weights = convert_bert_weights(bertlm, bert_weights, args.layers) ckp = {'model': model_weights['bert'], 'generator': model_weights['generator']} diff --git a/onmt/inputters/__init__.py b/onmt/inputters/__init__.py index a3af38e144..f9f17f17fa 100644 --- a/onmt/inputters/__init__.py +++ b/onmt/inputters/__init__.py @@ -11,7 +11,8 @@ from onmt.inputters.image_dataset import img_sort_key, ImageDataReader from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader from onmt.inputters.datareader_base import DataReaderBase -from onmt.inputters.dataset_bert import BertDataset, bert_text_sort_key +from onmt.inputters.dataset_bert import BertDataset, bert_text_sort_key,\ + ClassifierDataset, TaggerDataset str2reader = { "text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader} @@ -22,6 +23,6 @@ __all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'get_bert_fields', 'DataReaderBase', 'filter_example', 'old_style_vocab', 'build_vocab', 'OrderedIterator', 'text_sort_key', - 'img_sort_key', 'audio_sort_key', - 'BertDataset', 'bert_text_sort_key', + 'img_sort_key', 'audio_sort_key', 
'BertDataset', + 'bert_text_sort_key', 'ClassifierDataset', 'TaggerDataset', 'TextDataReader', 'ImageDataReader', 'AudioDataReader'] diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py index ae9ceaf08f..391e3e1621 100644 --- a/onmt/inputters/dataset_bert.py +++ b/onmt/inputters/dataset_bert.py @@ -1,6 +1,7 @@ import torch from torchtext.data import Dataset as TorchtextDataset from torchtext.data import Example +from random import random def bert_text_sort_key(ex): @@ -8,6 +9,66 @@ def bert_text_sort_key(ex): return len(ex.tokens) +def truncate_seq(tokens, max_num_tokens): + """Truncates a sequences to a maximum sequence length.""" + while True: + total_length = len(tokens) + if total_length <= max_num_tokens: + break + assert len(tokens) >= 1 + # We want to sometimes truncate from the front and sometimes + # from the back to add more randomness and avoid biases. + if random() < 0.5: + del tokens[0] + else: + tokens.pop() + + +def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): + """Truncates a pair of sequences to a maximum sequence length. + Lifted from Google's BERT repo.""" + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 + + # We want to sometimes truncate from the front and sometimes from the + # back to add more randomness and avoid biases. + if random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() + + +def create_sentence_instance(sentence, tokenizer, + max_seq_length, random_trunc=False): + tokens = tokenizer.tokenize(sentence) + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 2 + if len(tokens) > max_num_tokens: + if random_trunc is True: + truncate_seq(tokens, max_num_tokens) + else: + tokens = tokens[:max_num_tokens] + tokens_processed = ["[CLS]"] + tokens + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens) + 2)] + return tokens_processed, segment_ids + + +def create_sentence_pair_instance(sent_a, sent_b, tokenizer, max_seq_length): + tokens_a = tokenizer.tokenize(sent_a) + tokens_b = tokenizer.tokenize(sent_b) + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) + tokens_processed = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens_a) + 2)] + \ + [1 for _ in range(len(tokens_b) + 1)] + return tokens_processed, segment_ids + + class BertDataset(TorchtextDataset): """Defines a BERT dataset composed of Examples along with its Fields. Args: @@ -42,3 +103,85 @@ def save(self, path, remove_fields=True): if remove_fields: self.fields = [] torch.save(self, path) + + +class ClassifierDataset(BertDataset): + """Defines a BERT dataset composed of Examples along with its Fields. + Args: + fields_dict (dict[str, Field]): a dict containing all Field with + its name. + data (list[]): a list of sequence, each sequence can be one sentence + or one sentence pair seperate by ' ||| '. 
+ """ + + def __init__(self, fields_dict, data, tokenizer, + delimiter=' ||| ', max_seq_len=256): + data = [seq.decode("utf-8") for seq in data] + instances = self.create_instances( + data, tokenizer, delimiter, max_seq_len) + super(ClassifierDataset, self).__init__(fields_dict, instances) + + def create_instances(self, datas, tokenizer, delimiter, max_seq_len): + instances = [] + for data in datas: + sentences = data.strip().split(delimiter, 1) + if len(sentences) == 2: + sent_a, sent_b = sentences + tokens, segment_ids = create_sentence_pair_instance( + sent_a, sent_b, tokenizer, max_seq_len) + else: + sentence = sentences[0] + tokens, segment_ids = create_sentence_instance( + sentence, tokenizer, max_seq_len, random_trunc=False) + instance = { + "tokens": tokens, + "segment_ids": segment_ids, + "category": None} + instances.append(instance) + return instances + + +class TaggerDataset(BertDataset): + """Defines a BERT dataset composed of Examples along with its Fields. + Args: + fields_dict (dict[str, Field]): a dict containing all Field with + its name. + data (list[]): a list of sequence, each sequence is composed with + tokens that to be tagging. + """ + + def __init__(self, fields_dict, data, tokenizer, + delimiter=' ', max_seq_len=256): + targer_field = fields_dict["token_labels"] + self.pad_tok = targer_field.pad_token + self.predict_tok = targer_field.vocab.itos[-1] + data = [seq.decode("utf-8") for seq in data] + instances = self.create_instances( + data, tokenizer, delimiter, max_seq_len) + super(TaggerDataset, self).__init__(fields_dict, instances) + + def create_instances(self, datas, tokenizer, delimiter, max_seq_len): + instances = [] + for data in datas: + words = data.strip().split(delimiter) + sentence = [] + tags = [] + max_num_tokens = max_seq_len - 2 + for word in words: + tokens = tokenizer.tokenize(word) + n_pad = len(tokens) - 1 + paded_tag = [self.predict_tok] + [self.pad_tok] * n_pad + if len(sentence) + len(tokens) > max_num_tokens: + break + else: + sentence.extend(tokens) + tags.extend(paded_tag) + sentence = ["[CLS]"] + sentence + ["[SEP]"] + tags = [self.pad_tok] + tags + [self.pad_tok] + segment_ids = [0 for _ in range(len(sentence))] + instance = { + "tokens": sentence, + "segment_ids": segment_ids, + "token_labels": tags} + instances.append(instance) + return instances diff --git a/onmt/model_builder.py b/onmt/model_builder.py index d4b69946a5..0bc738476f 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -362,3 +362,21 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): model.to(device) logger.info(model) return model + +def load_bert_model(opt, model_path): + checkpoint = torch.load(model_path, + map_location=lambda storage, loc: storage) + logger.info("Checkpoint from {} Loaded.".format(model_path)) + model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt']) + ArgumentParser.update_model_opts(model_opt) + # ArgumentParser.validate_model_opts(model_opt) + vocab = checkpoint['vocab'] + fields = vocab + model = build_bert_model(model_opt, opt, fields, checkpoint, gpu_id=opt.gpu) + + if opt.fp32: + model.float() + model.eval() + model.bert.eval() + model.generator.eval() + return fields, model, model_opt diff --git a/onmt/opts.py b/onmt/opts.py index 0c154ea6d7..b5cddab047 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -715,6 +715,59 @@ def translate_opts(parser): "model faster and smaller") +def predict_opts(parser): + """ Prediction [Using Pretrained model] options """ + group = 
parser.add_argument_group('Model')
+    group.add("--bert_model", type=str,
+              default="bert-base-uncased",
+              choices=["bert-base-uncased", "bert-large-uncased",
+                       "bert-base-cased", "bert-large-cased",
+                       "bert-base-multilingual-uncased",
+                       "bert-base-multilingual-cased",
+                       "bert-base-chinese"],
+              help="Bert pretrained tokenizer model to use.")
+    group.add("--model", type=str, default=None, required=True,
+              help="Path to the Bert model to use for predicting.")
+    group.add('--task', type=str, default=None, required=True,
+              choices=["classification", "tagging"],
+              help="Target task to perform")
+
+    group = parser.add_argument_group('Data')
+    group.add('--data', '-i', type=str, default=None, required=True,
+              help="Data to predict for classification / tagging. " +
+                   "Classification: Sentence1 ||| Sentence2, " +
+                   "Tagging: one tokenized sentence per line")
+    group.add("--do_lower_case", action="store_true", help='lowercase data')
+    group.add('--delimiter', '-d', type=str, default=None,
+              help="Delimiter used to separate sentence/word. " +
+                   "Default: ' ||| ' for sentence used in [CLS], " +
+                   " ' ' for word used in [TAG].")
+    group.add("--max_seq_len", type=int, default=256,
+              help="max sequence length for prepared data, "
+                   "sets the limit of position encoding")
+    group.add('--output', '-output', default=None, required=True,
+              help="Path to output the predictions")
+    group.add('--shard_size', '-shard_size', type=int, default=10000,
+              help="Divide data into smaller multiple data files, "
+                   "then build shards, each shard will have "
+                   "opt.shard_size samples except last shard. "
+                   "shard_size=0 means no segmentation "
+                   "shard_size>0 segment data into multiple shards, "
+                   "each shard has shard_size samples")
+
+    group = parser.add_argument_group('Efficiency')
+    group.add('--batch_size', '-batch_size', type=int, default=8,
+              help='Batch size')
+    group.add('--gpu', '-gpu', type=int, default=-1, help="Device to run on")
+    group.add('--seed', '-seed', type=int, default=829, help="Random seed")
+    group.add('--log_file', '-log_file', type=str, default="",
+              help="Output logs to a file under this path.")
+    group.add('--fp32', '-fp32', action='store_true',
+              help="Force the model to be in FP32 "
+                   "because FP16 is very slow on GTX1080(ti).")
+    group.add('--verbose', '-verbose', action="store_true",
+              help='Print scores and predictions for each sentence')
+
 # Copyright 2016 The Chromium Authors. All rights reserved.
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
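For reference, a minimal usage sketch of the prediction options above, driving the Classifier that onmt/translate/predictor.py (added below in this same patch) exposes; the checkpoint and data paths are placeholders, and sentences are passed as already-decoded strings as in the later revision of predict.py:

# Hypothetical usage sketch -- not part of the patch itself. Option names
# come from predict_opts() above; build_classifier() and
# validate_predict_opts() are introduced later in this patch series, and
# every path below is a placeholder.
from onmt.utils.parse import ArgumentParser
from onmt.utils.bert_tokenization import BertTokenizer
from onmt.translate.predictor import build_classifier
import onmt.opts as opts

parser = ArgumentParser(description='predict_sketch.py')
opts.config_opts(parser)
opts.predict_opts(parser)
opt = parser.parse_args([
    '--task', 'classification',
    '--model', 'models/bert_cls_step_4000.pt',   # placeholder checkpoint
    '--data', 'data/test_sentences.txt',         # placeholder input file
    '--output', 'predictions.txt',
    '--bert_model', 'bert-base-uncased',
    '--do_lower_case',
])
opt = ArgumentParser.validate_predict_opts(opt)  # fills in opt.delimiter

tokenizer = BertTokenizer.from_pretrained(
    opt.bert_model, do_lower_case=opt.do_lower_case)
classifier = build_classifier(opt)
# classify() writes one predicted label per input sequence to opt.output
# and also returns the labels; sentence pairs are joined with ' ||| '.
predictions = classifier.classify(
    ["The movie was great ||| I would watch it again ."],
    opt.batch_size, tokenizer,
    delimiter=opt.delimiter, max_seq_len=opt.max_seq_len)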
diff --git a/onmt/translate/__init__.py b/onmt/translate/__init__.py index 2b1ba49133..e61a9fb5ff 100644 --- a/onmt/translate/__init__.py +++ b/onmt/translate/__init__.py @@ -8,8 +8,9 @@ from onmt.translate.penalties import PenaltyBuilder from onmt.translate.translation_server import TranslationServer, \ ServerModelError +from onmt.translate.predictor import Classifier, Tagger __all__ = ['Translator', 'Translation', 'Beam', 'BeamSearch', 'GNMTGlobalScorer', 'TranslationBuilder', 'PenaltyBuilder', 'TranslationServer', 'ServerModelError', - "DecodeStrategy", "RandomSampling"] + "DecodeStrategy", "RandomSampling", "Classifier", "Tagger"] diff --git a/onmt/translate/predictor.py b/onmt/translate/predictor.py new file mode 100644 index 0000000000..d9f119e1ca --- /dev/null +++ b/onmt/translate/predictor.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python +""" Classifier Class and builder """ +from __future__ import print_function +import codecs +import time + +import torch +import torchtext.data +import onmt.model_builder +import onmt.inputters as inputters +from onmt.utils.misc import set_random_seed + + +def build_classifier(opt, logger=None, out_file=None): + if out_file is None: + out_file = codecs.open(opt.output, 'w+', 'utf-8') + + load_bert_model = onmt.model_builder.load_bert_model + fields, model, model_opt = load_bert_model(opt, opt.model) + + classifier = Classifier.from_opt( + model, + fields, + opt, + model_opt, + out_file=out_file, + logger=logger + ) + return classifier + + +def build_tagger(opt, logger=None, out_file=None): + if out_file is None: + out_file = codecs.open(opt.output, 'w+', 'utf-8') + + load_bert_model = onmt.model_builder.load_bert_model + fields, model, model_opt = load_bert_model(opt, opt.model) + + tagger = Tagger.from_opt( + model, + fields, + opt, + model_opt, + out_file=out_file, + logger=logger + ) + return tagger + + +class Predictor(object): + """Predictor a batch of data with a saved model. + + Args: + model (onmt.modules.Sequential): model to use + fields (dict[str, torchtext.data.Field]): A dict of field. + gpu (int): GPU device. Set to negative for no GPU. + data_type (str): Source data type. + verbose (bool): Print/log every translation. + report_time (bool): Print/log total time/frequency. + out_file (TextIO or codecs.StreamReaderWriter): Output file. + logger (logging.Logger or NoneType): Logger. + """ + + def __init__( + self, + model, + fields, + gpu=-1, + verbose=False, + out_file=None, + report_time=True, + logger=None, + seed=-1): + self.model = model + self.fields = fields + # tgt_field = dict(self.fields)["tgt"].base_field + # self._tgt_vocab = tgt_field.vocab + # self._tgt_eos_idx = self._tgt_vocab.stoi[tgt_field.eos_token] + # self._tgt_pad_idx = self._tgt_vocab.stoi[tgt_field.pad_token] + # self._tgt_bos_idx = self._tgt_vocab.stoi[tgt_field.init_token] + # self._tgt_unk_idx = self._tgt_vocab.stoi[tgt_field.unk_token] + # self._tgt_vocab_len = len(self._tgt_vocab) + + self._gpu = gpu + self._use_cuda = gpu > -1 + self._dev = torch.device("cuda", self._gpu) \ + if self._use_cuda else torch.device("cpu") + + self.verbose = verbose + self.report_time = report_time + self.out_file = out_file + self.logger = logger + + # self.use_filter_pred = False + # self._filter_pred = None + + set_random_seed(seed, self._use_cuda) + + @classmethod + def from_opt( + cls, + model, + fields, + opt, + model_opt, + out_file=None, + logger=None): + """Alternate constructor. + + Args: + model (onmt.modules): See :func:`__init__()`. 
+ fields (dict[str, torchtext.data.Field]): See + :func:`__init__()`. + opt (argparse.Namespace): Command line options + model_opt (argparse.Namespace): Command line options saved with + the model checkpoint. + out_file (TextIO or codecs.StreamReaderWriter): See + :func:`__init__()`. + logger (logging.Logger or NoneType): See :func:`__init__()`. + """ + + return cls( + model, + fields, + gpu=opt.gpu, + verbose=opt.verbose, + out_file=out_file, + logger=logger, + seed=opt.seed) + + def _log(self, msg): + if self.logger: + self.logger.info(msg) + else: + print(msg) + + +class Classifier(Predictor): + """classify a batch of sentences with a saved model. + + Args: + model (onmt.modules.Sequential): BERT model to use for classify + fields (dict[str, torchtext.data.Field]): A dict of field. + gpu (int): GPU device. Set to negative for no GPU. + data_type (str): Source data type. + verbose (bool): Print/log every translation. + report_time (bool): Print/log total time/frequency. + out_file (TextIO or codecs.StreamReaderWriter): Output file. + logger (logging.Logger or NoneType): Logger. + """ + + def __init__( + self, + model, + fields, + gpu=-1, + verbose=False, + out_file=None, + report_time=True, + logger=None, + seed=-1): + super(Classifier, self).__init__( + model, + fields, + gpu=gpu, + verbose=verbose, + out_file=out_file, + report_time=report_time, + logger=logger, + seed=seed) + label_field = self.fields["category"] + self.label_vocab = label_field.vocab + + def classify(self, data, batch_size, tokenizer, + delimiter=' ||| ', max_seq_len=256): + """Classify content of ``data``. + + Args: + data: list of sentences to classify,ex. Sentence1 ||| Sentence2. + batch_size (int): size of examples per mini-batch + + Returns: + (`list`, `list`) + + * all_scores is a list of `batch_size` lists of `n_best` scores + * all_predictions is a list of `batch_size` lists + of `n_best` predictions + """ + + dataset = inputters.ClassifierDataset( + self.fields, data, tokenizer, delimiter, max_seq_len) + + data_iter = torchtext.data.Iterator( + dataset=dataset, + batch_size=batch_size, + device=self._dev, + train=False, + sort=False, + sort_within_batch=False, + shuffle=False + ) + + all_predictions = [] + + start_time = time.time() + + for batch in data_iter: + pred_sents_labels = self.classify_batch(batch) + all_predictions.extend(pred_sents_labels) + self.out_file.write('\n'.join(pred_sents_labels) + '\n') + self.out_file.flush() + + end_time = time.time() + + if self.report_time: + total_time = end_time - start_time + self._log("Total classification time: %f s" % total_time) + self._log("Average classification time: %f s" % ( + total_time / len(all_predictions))) + self._log("Sentences per second: %f" % ( + len(all_predictions) / total_time)) + return all_predictions + + def classify_batch(self, batch): + """Translate a batch of sentences.""" + with torch.no_grad(): + input_ids, seq_lengths = batch.tokens + token_type_ids = batch.segment_ids + all_encoder_layers, pooled_out = self.model.bert( + input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = self.model.generator( + all_encoder_layers, pooled_out) + # outputs = (seq_class_log_prob, prediction_log_prob) + + pred_sents_ids = seq_class_log_prob.argmax(-1).tolist() + pred_sents_labels = [self.label_vocab.itos[index] + for index in pred_sents_ids] + return pred_sents_labels + + +class Tagger(Predictor): + """Tagging a batch of sentences with a saved model. 
+ + Args: + model (onmt.modules.Sequential): BERT model to use for Tagging + fields (dict[str, torchtext.data.Field]): A dict of field. + gpu (int): GPU device. Set to negative for no GPU. + data_type (str): Source data type. + verbose (bool): Print/log every translation. + report_time (bool): Print/log total time/frequency. + out_file (TextIO or codecs.StreamReaderWriter): Output file. + logger (logging.Logger or NoneType): Logger. + """ + + def __init__( + self, + model, + fields, + gpu=-1, + verbose=False, + out_file=None, + report_time=True, + logger=None, + seed=-1): + super(Tagger, self).__init__( + model, + fields, + gpu=gpu, + verbose=verbose, + out_file=out_file, + report_time=report_time, + logger=logger, + seed=seed) + label_field = self.fields["token_labels"] + self.label_vocab = label_field.vocab + self.pad_token = label_field.pad_token + self.pad_index = self.label_vocab.stoi[self.pad_token] + + def tagging(self, data, batch_size, tokenizer, + delimiter=' ', max_seq_len=256): + """Tagging content of ``data``. + + Args: + data: list of sentences to classify,ex. Sentence1 ||| Sentence2. + batch_size (int): size of examples per mini-batch + + Returns: + (`list`, `list`) + + * all_scores is a list of `batch_size` lists of `n_best` scores + * all_predictions is a list of `batch_size` lists + of `n_best` predictions + """ + dataset = inputters.TaggerDataset( + self.fields, data, tokenizer, delimiter, max_seq_len) + + data_iter = torchtext.data.Iterator( + dataset=dataset, + batch_size=batch_size, + device=self._dev, + train=False, + sort=False, + sort_within_batch=False, + shuffle=False + ) + + all_predictions = [] + + start_time = time.time() + + for batch in data_iter: + pred_tokens_tag = self.tagging_batch(batch) + all_predictions.extend(pred_tokens_tag) + for pred_sent in pred_tokens_tag: + self.out_file.write('\n'.join(pred_sent) + '\n' + '\n') + self.out_file.flush() + + end_time = time.time() + + if self.report_time: + total_time = end_time - start_time + self._log("Total tagging time (s): %f" % total_time) + self._log("Average tagging time (s): %f" % ( + total_time / len(all_predictions))) + self._log("Sentence per second: %f" % ( + len(all_predictions) / total_time)) + return all_predictions + + def tagging_batch(self, batch): + """Translate a batch of sentences.""" + with torch.no_grad(): + # Batch + input_ids, seq_lengths = batch.tokens + token_type_ids = batch.segment_ids + taggings = batch.token_labels + # Forward + all_encoder_layers, pooled_out = self.model.bert( + input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = self.model.generator( + all_encoder_layers, pooled_out) + # Predicting + pred_tag_ids = prediction_log_prob.argmax(-1) + non_padding = taggings.ne(self.pad_index) + batch_tag_ids, batch_mask = list(pred_tag_ids), list(non_padding) + batch_tag_select_ids = [pred.masked_select(mask).tolist() + for pred, mask in + zip(batch_tag_ids, batch_mask)] + + pred_tokens_tag = [[self.label_vocab.itos[index] + for index in tag_select_ids] + for tag_select_ids in batch_tag_select_ids] + return pred_tokens_tag diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index a96bf66f9a..97f7c1129d 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -5,6 +5,7 @@ import onmt.opts as opts from onmt.utils.logging import logger +from onmt.utils.bert_tokenization import PRETRAINED_VOCAB_ARCHIVE_MAP class ArgumentParser(cfargparse.ArgumentParser): @@ -134,3 +135,25 @@ def validate_preprocess_args(cls, opt): "Please check path of your src vocab!" 
assert not opt.tgt_vocab or os.path.isfile(opt.tgt_vocab), \ "Please check path of your tgt vocab!" + + @classmethod + def validate_predict_opts(cls, opt): + if opt.delimiter is None: + if opt.task == 'classification': + opt.delimiter = ' ||| ' + else: + opt.delimiter = ' ' + logger.info("NOTICE: opt.delimiter set to `%s`" % opt.delimiter) + assert opt.bert_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ + "Unsupported Pretrain model '%s'" % (opt.bert_model) + if '-cased' in opt.bert_model and opt.do_lower_case is True: + logger.info("WARNING: The pre-trained model you are loading " + + "is cased model, you shouldn't set `do_lower_case`," + + "we turned it off for you.") + opt.do_lower_case = False + elif '-cased' not in opt.bert_model and opt.do_lower_case is False: + logger.info("WARNING: The pre-trained model you are loading " + + "is uncased model, you should set `do_lower_case`, " + + "we turned it on for you.") + opt.do_lower_case = True + return opt diff --git a/predict.py b/predict.py new file mode 100755 index 0000000000..8d706b8f9e --- /dev/null +++ b/predict.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals + +from onmt.utils.logging import init_logger +from onmt.utils.misc import split_corpus +from onmt.translate.predictor import build_classifier, build_tagger + +import onmt.opts as opts +from onmt.utils.parse import ArgumentParser +from onmt.utils.bert_tokenization import BertTokenizer + + +def main(opt): + logger = init_logger(opt.log_file) + opt = ArgumentParser.validate_predict_opts(opt) + tokenizer = BertTokenizer.from_pretrained( + opt.bert_model, do_lower_case=opt.do_lower_case) + data_shards = split_corpus(opt.data, opt.shard_size) + if opt.task == 'classification': + classifier = build_classifier(opt) + for i, data_shard in enumerate(data_shards): + logger.info("Classify shard %d." % i) + classifier.classify( + data_shard, + opt.batch_size, + tokenizer, + delimiter=opt.delimiter, + max_seq_len=opt.max_seq_len + ) + if opt.task == 'tagging': + tagger = build_tagger(opt) + for i, data_shard in enumerate(data_shards): + logger.info("Tagging shard %d." 
% i) + tagger.tagging( + data_shard, + opt.batch_size, + tokenizer, + delimiter=opt.delimiter, + max_seq_len=opt.max_seq_len + ) + + +def _get_parser(): + parser = ArgumentParser(description='predict.py') + opts.config_opts(parser) + opts.predict_opts(parser) + return parser + + +if __name__ == "__main__": + parser = _get_parser() + + opt = parser.parse_args() + main(opt) diff --git a/preprocess_bert.py b/preprocess_bert.py index 0fc4b2713e..a9409d9ae4 100755 --- a/preprocess_bert.py +++ b/preprocess_bert.py @@ -1,79 +1,20 @@ from argparse import ArgumentParser from tqdm import tqdm import csv -from random import random, shuffle +from random import shuffle from onmt.utils.bert_tokenization import BertTokenizer, \ PRETRAINED_VOCAB_ARCHIVE_MAP import json from onmt.inputters.inputter import get_bert_fields, \ _build_bert_fields_vocab -from onmt.inputters.dataset_bert import BertDataset +from onmt.inputters.dataset_bert import BertDataset, \ + create_sentence_instance, create_sentence_pair_instance from collections import Counter, defaultdict import torch import os import codecs -def truncate_seq(tokens, max_num_tokens): - """Truncates a sequences to a maximum sequence length.""" - while True: - total_length = len(tokens) - if total_length <= max_num_tokens: - break - assert len(tokens) >= 1 - # We want to sometimes truncate from the front and sometimes - # from the back to add more randomness and avoid biases. - if random() < 0.5: - del tokens[0] - else: - tokens.pop() - - -def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): - """Truncates a pair of sequences to a maximum sequence length. - Lifted from Google's BERT repo.""" - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_num_tokens: - break - trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b - assert len(trunc_tokens) >= 1 - - # We want to sometimes truncate from the front and sometimes from the - # back to add more randomness and avoid biases. 
- if random() < 0.5: - del trunc_tokens[0] - else: - trunc_tokens.pop() - - -def create_sentence_instance(sentence, tokenizer, - max_seq_length, random_trunc=False): - tokens = tokenizer.tokenize(sentence) - # Account for [CLS], [SEP], [SEP] - max_num_tokens = max_seq_length - 2 - if len(tokens) > max_num_tokens: - if random_trunc is True: - truncate_seq(tokens, max_num_tokens) - else: - tokens = tokens[:max_num_tokens] - tokens_processed = ["[CLS]"] + tokens + ["[SEP]"] - segment_ids = [0 for _ in range(len(tokens) + 2)] - return tokens_processed, segment_ids - - -def create_sentence_pair_instance(sent_a, sent_b, tokenizer, max_seq_length): - tokens_a = tokenizer.tokenize(sent_a) - tokens_b = tokenizer.tokenize(sent_b) - # Account for [CLS], [SEP], [SEP] - max_num_tokens = max_seq_length - 3 - truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) - tokens_processed = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] - segment_ids = [0 for _ in range(len(tokens_a) + 2)] + \ - [1 for _ in range(len(tokens_b) + 1)] - return tokens_processed, segment_ids - - def create_instances_from_csv(records, skip_head, tokenizer, max_seq_length, column_a, column_b, label_column, labels): instances = [] @@ -89,20 +30,14 @@ def create_instances_from_csv(records, skip_head, tokenizer, max_seq_length, sentence_b = record[column_b].strip() tokens_processed, segment_ids = create_sentence_pair_instance( sentence_a, sentence_b, tokenizer, max_seq_length) - if label_column is not None: - label = record[label_column].strip() - if label not in labels: - labels.append(label) - instance = { - "tokens": tokens_processed, - "segment_ids": segment_ids, - "category": label} - else: # TODO: prediction dataset - label = None - instance = { - "tokens": tokens_processed, - "segment_ids": segment_ids, - "category": label} + + label = record[label_column].strip() + if label not in labels: + labels.append(label) + instance = { + "tokens": tokens_processed, + "segment_ids": segment_ids, + "category": label} instances.append(instance) return instances, labels @@ -145,7 +80,7 @@ def create_instances_from_file(records, label, tokenizer, max_seq_length): def build_instances_from_files(data, labels, tokenizer, max_seq_len, do_shuffle): instances = [] - for filename in data: #zip(data, labels): + for filename in data: label = filename.split('/')[-2] with codecs.open(filename, "r", encoding="utf-8") as f: lines = f.readlines() @@ -283,11 +218,17 @@ def validate_preprocess_bert_opts(opts): "For sequence tagging, only one column for input tokens." opts.input_columns = opts.input_columns[0] + assert args.label_column is not None,\ + "For sequence tagging, label column should be given." + if opts.data_type == "csv": assert len(opts.data) == 1,\ "For csv, only one file is needed." assert len(opts.input_columns) in [1, 2],\ "Please indicate N.colomn for sentence A (and B)" + assert args.label_column is not None,\ + "For csv file, label column should be given." 
+ # if opts.label_column is not None: # assert len(opts.labels) != 0,\ # "label list is needed when csv contain label column" @@ -306,7 +247,7 @@ def _get_parser(): required=True, help="input datas to prepare: [CLS]" + "Single file for csv with column indicate label," + "One file for each class as path/label/file; [TAG]" + - "Single file contain (tok, tag) in each line,"+ + "Single file contain (tok, tag) in each line," + "Sentence separated by blank line.") parser.add_argument('--data_type', type=str, default="csv", choices=["csv", "txt"], @@ -392,7 +333,8 @@ def main(args): args.labels = labels print("Labels:", args.labels) instances = build_instances_from_files( - args.data, args.labels, tokenizer, args.max_seq_len, args.do_shuffle) + args.data, args.labels, tokenizer, + args.max_seq_len, args.do_shuffle) else: raise NotImplementedError("Not support other file type yet!") @@ -401,8 +343,7 @@ def main(args): filename = args.data[0] print("Load data file %s with skip head %s" % ( filename, args.skip_head)) - token_column = args.input_columns - tag_column = args.label_column + token_column, tag_column = args.input_columns, args.label_column instances, labels = build_tag_instances_from_file( filename, args.skip_head, tokenizer, args.max_seq_len, token_column, tag_column, args.labels, args.do_shuffle, @@ -429,15 +370,10 @@ def main(args): if args.corpus_type == "train": print("Generating vocab from corresponding text file...") if args.task == "classification": - if len(args.labels) == 0: # TODO - raise AttributeError("Labels should be given") - else: - named_labels = ("category", args.labels) - print("Save Labels:", named_labels, "in vocab file.") - + named_labels = ("category", args.labels) if args.task == "tagging": named_labels = ("token_labels", args.labels) - print("Save Labels:", named_labels, "in vocab file.") + print("Save Labels:", named_labels, "in vocab file.") fields_vocab = build_vocab_from_tokenizer( fields, tokenizer, named_labels, From 7f6a1278ea784626171b394e1ca011984df8c951 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Wed, 14 Aug 2019 18:41:02 +0200 Subject: [PATCH 13/28] clean up code --- onmt/inputters/dataset_bert.py | 39 +++--- onmt/opts.py | 91 ++++++++++++- onmt/translate/predictor.py | 25 ++-- onmt/utils/parse.py | 45 +++++++ predict.py | 6 +- preprocess_bert.py | 4 +- preprocess_bert_new.py | 234 +++++++++++++++++++++++++++++++++ train.py | 14 +- 8 files changed, 411 insertions(+), 47 deletions(-) create mode 100755 preprocess_bert_new.py diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py index 391e3e1621..60421efacc 100644 --- a/onmt/inputters/dataset_bert.py +++ b/onmt/inputters/dataset_bert.py @@ -110,21 +110,23 @@ class ClassifierDataset(BertDataset): Args: fields_dict (dict[str, Field]): a dict containing all Field with its name. - data (list[]): a list of sequence, each sequence can be one sentence - or one sentence pair seperate by ' ||| '. + data (list[]): a list of sequence (sentence or sentence pair), + possible with its label becoming tuple(list[]). 
""" def __init__(self, fields_dict, data, tokenizer, - delimiter=' ||| ', max_seq_len=256): - data = [seq.decode("utf-8") for seq in data] + max_seq_len=256, delimiter=' ||| '): + if isinstance(data, tuple) is False: + data = data, [None for _ in range(len(data))] instances = self.create_instances( data, tokenizer, delimiter, max_seq_len) super(ClassifierDataset, self).__init__(fields_dict, instances) - def create_instances(self, datas, tokenizer, delimiter, max_seq_len): + def create_instances(self, data, tokenizer, + delimiter, max_seq_len): instances = [] - for data in datas: - sentences = data.strip().split(delimiter, 1) + for sentence, label in zip(*data): + sentences = sentence.strip().split(delimiter, 1) if len(sentences) == 2: sent_a, sent_b = sentences tokens, segment_ids = create_sentence_pair_instance( @@ -136,7 +138,7 @@ def create_instances(self, datas, tokenizer, delimiter, max_seq_len): instance = { "tokens": tokens, "segment_ids": segment_ids, - "category": None} + "category": label} instances.append(instance) return instances @@ -146,31 +148,36 @@ class TaggerDataset(BertDataset): Args: fields_dict (dict[str, Field]): a dict containing all Field with its name. - data (list[]): a list of sequence, each sequence is composed with - tokens that to be tagging. + data (list[]|tuple(list[])): a list of sequence, each sequence is + composed with tokens that to be tagging. Can also combined with + its tags as tuple([tokens], [tags]) """ def __init__(self, fields_dict, data, tokenizer, - delimiter=' ', max_seq_len=256): + max_seq_len=256, delimiter=' '): targer_field = fields_dict["token_labels"] self.pad_tok = targer_field.pad_token self.predict_tok = targer_field.vocab.itos[-1] - data = [seq.decode("utf-8") for seq in data] + if isinstance(data, tuple) is False: + data = (data, [None for _ in range(len(data))]) instances = self.create_instances( data, tokenizer, delimiter, max_seq_len) super(TaggerDataset, self).__init__(fields_dict, instances) def create_instances(self, datas, tokenizer, delimiter, max_seq_len): instances = [] - for data in datas: - words = data.strip().split(delimiter) + for words, taggings in zip(*datas): + if isinstance(words, str): # build from raw sentence + words = words.strip().split(delimiter) + if taggings is None: # when predicting + taggings = [self.predict_tok for _ in range(len(words))] sentence = [] tags = [] max_num_tokens = max_seq_len - 2 - for word in words: + for word, tag in zip(words, taggings): tokens = tokenizer.tokenize(word) n_pad = len(tokens) - 1 - paded_tag = [self.predict_tok] + [self.pad_tok] * n_pad + paded_tag = [tag] + [self.pad_tok] * n_pad if len(sentence) + len(tokens) > max_num_tokens: break else: diff --git a/onmt/opts.py b/onmt/opts.py index b5cddab047..e872b45eec 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -151,6 +151,9 @@ def model_opts(parser): help='Number of heads for transformer self-attention') group.add('--transformer_ff', '-transformer_ff', type=int, default=2048, help='Size of hidden transformer feed-forward') + group.add('--activation', '-activation', default='relu', + choices=['relu', 'gelu'], + help='type of activation function used in Bert encoder.') # Generator and loss options. 
group.add('--copy_attn', '-copy_attn', action="store_true",
@@ -309,9 +312,85 @@ def preprocess_opts(parser):
                    "model faster and smaller")

+
+def preprocess_bert_opts(parser):
+    """ Pre-processing options for pretrained model """
+    # Data options
+    group = parser.add_argument_group('Common')
+    group.add('--task', '-task', type=str, required=True,
+              choices=["classification", "tagging"],
+              help="Target task to perform")
+    group.add('--corpus_type', '-corpus_type', type=str, default="train",
+              choices=['train', 'valid'],
+              help="corpus type, choose from ['train', 'valid']. " +
+                   "Vocab file will be generated if `train`")
+
+    group = parser.add_argument_group('Data')
+    group.add('--file_type', type=str, default="txt", choices=["csv", "txt"],
+              help="input file type. Choose [txt|csv]")
+    group.add('--data', '-data', type=str, nargs='+', default=[],
+              required=True,
+              help="input data to prepare: [CLS] " +
+                   "Single file for csv with a column indicating the label, " +
+                   "One file for each class as path/label/file; [TAG] " +
+                   "Single file containing (tok, tag) in each line, " +
+                   "Sentences separated by blank line.")
+    group.add('--skip_head', '-skip_head', action="store_true",
+              help="CSV: If csv file contains a header line.")
+    group.add('--do_lower_case', '-lower', action='store_true',
+              help='lowercase data')
+    group.add("--max_seq_len", type=int, default=256,
+              help="Maximum sequence length to keep.")
+    group.add('--save_data', '-save_data', type=str, required=True,
+              help="Output file prefix for the prepared data")
+
+    group = parser.add_argument_group('Columns')
+    # options for column-like input file with fields separated by -delimiter
+    group.add('--delimiter', '-delimiter', type=str, default=' ',
+              help="delimiter used in input file to separate fields.")
+    group.add('--input_columns', type=int, nargs='+', default=[],
+              help="Column(s) containing sentence A (and B)")
+    group.add('--label_column', type=int, default=None,
+              help="Column containing the label")
+
+    group = parser.add_argument_group('Vocab')
+    group.add('--labels', '-labels', type=str, nargs='+', default=[],
+              help="Candidate labels, will be used to build label vocab. " +
" + + "If not given, this will be built from input file.") + group.add('--sort_label_vocab', '-sort_label', type=bool, default=True, + help="sort label vocab in alphabetic order.") + group.add("--vocab_model", "-vm", type=str, default="bert-base-uncased", + choices=["bert-base-uncased", "bert-large-uncased", + "bert-base-cased", "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese"], + help="Pretrained BertTokenizer model use to tokenizer text.") + + # Data processing options + group = parser.add_argument_group('Random') + group.add('--do_shuffle', '-shuffle', action="store_true", + help="Shuffle data") + + group = parser.add_argument_group('Logging') + group.add('--log_file', '-log_file', type=str, default="", + help="Output logs to a file under this path.") + + def train_opts(parser): """ Training and saving options """ + group = parser.add_argument_group('Pretrain-finetuning') + group.add('--is_bert', '-is_bert', action='store_true') + group.add('--task_type', '-task_type', type=str, default='classification', + choices=["pretraining", "classification", "tagging"], + help="Downstream task for Bert if is_bert set True" + "Choose from pretraining Bert," + "use pretrained Bert for classification," + "use pretrained Bert for token generation.") + group.add('--reuse_embeddings', '-reuse_embeddings', type=bool, + default=False, help="if reuse embeddings for generator " + + "currently not available") + group = parser.add_argument_group('General') group.add('--data', '-data', required=True, help='Path prefix to the ".train.pt" and ' @@ -366,6 +445,10 @@ def train_opts(parser): group.add('--param_init_glorot', '-param_init_glorot', action='store_true', help="Init parameters with xavier_uniform. " "Required for transformer.") + group.add('--param_init_normal', '-param_normal', type=float, default=0.0, + help="Parameters are initialized over normal distribution " + "with (mean=0, std=param_init_normal). Used in BERT with 0.02." 
+ "Set value > 0 and param_init 0.0 to activate.") group.add('--train_from', '-train_from', default='', type=str, help="If training from a checkpoint then this is the " @@ -438,7 +521,7 @@ def train_opts(parser): nargs="*", default=None, help='Criteria to use for early stopping.') group.add('--optim', '-optim', default='sgd', - choices=['sgd', 'adagrad', 'adadelta', 'adam', + choices=['sgd', 'adagrad', 'adadelta', 'adam', 'bertadam', 'sparseadam', 'adafactor', 'fusedadam'], help="Optimization method.") group.add('--adagrad_accumulator_init', '-adagrad_accumulator_init', @@ -513,10 +596,14 @@ def train_opts(parser): help="Decay every decay_steps") group.add('--decay_method', '-decay_method', type=str, default="none", - choices=['noam', 'noamwd', 'rsqrt', 'none'], + choices=['none', 'noam', 'noamwd', 'rsqrt', 'linear', + 'linearconst', 'cosine', 'cosine_hard_restart', + 'cosine_warmup_restart'], help="Use a custom decay rate.") group.add('--warmup_steps', '-warmup_steps', type=int, default=4000, help="Number of warmup steps for custom decay.") + group.add('--cycles', '-cycles', type=int, default=None, + help="required for cosine related decay.") group = parser.add_argument_group('Logging') group.add('--report_every', '-report_every', type=int, default=50, diff --git a/onmt/translate/predictor.py b/onmt/translate/predictor.py index d9f119e1ca..edaa019749 100644 --- a/onmt/translate/predictor.py +++ b/onmt/translate/predictor.py @@ -55,7 +55,7 @@ class Predictor(object): fields (dict[str, torchtext.data.Field]): A dict of field. gpu (int): GPU device. Set to negative for no GPU. data_type (str): Source data type. - verbose (bool): Print/log every translation. + verbose (bool): Print/log every predition. report_time (bool): Print/log total time/frequency. out_file (TextIO or codecs.StreamReaderWriter): Output file. logger (logging.Logger or NoneType): Logger. @@ -91,9 +91,6 @@ def __init__( self.out_file = out_file self.logger = logger - # self.use_filter_pred = False - # self._filter_pred = None - set_random_seed(seed, self._use_cuda) @classmethod @@ -143,7 +140,7 @@ class Classifier(Predictor): fields (dict[str, torchtext.data.Field]): A dict of field. gpu (int): GPU device. Set to negative for no GPU. data_type (str): Source data type. - verbose (bool): Print/log every translation. + verbose (bool): Print/log every predition. report_time (bool): Print/log total time/frequency. out_file (TextIO or codecs.StreamReaderWriter): Output file. logger (logging.Logger or NoneType): Logger. @@ -180,15 +177,12 @@ def classify(self, data, batch_size, tokenizer, batch_size (int): size of examples per mini-batch Returns: - (`list`, `list`) - - * all_scores is a list of `batch_size` lists of `n_best` scores * all_predictions is a list of `batch_size` lists - of `n_best` predictions + of sentence classification """ dataset = inputters.ClassifierDataset( - self.fields, data, tokenizer, delimiter, max_seq_len) + self.fields, data, tokenizer, max_seq_len, delimiter) data_iter = torchtext.data.Iterator( dataset=dataset, @@ -246,7 +240,7 @@ class Tagger(Predictor): fields (dict[str, torchtext.data.Field]): A dict of field. gpu (int): GPU device. Set to negative for no GPU. data_type (str): Source data type. - verbose (bool): Print/log every translation. + verbose (bool): Print/log every predition. report_time (bool): Print/log total time/frequency. out_file (TextIO or codecs.StreamReaderWriter): Output file. logger (logging.Logger or NoneType): Logger. 
@@ -285,14 +279,11 @@ def tagging(self, data, batch_size, tokenizer, batch_size (int): size of examples per mini-batch Returns: - (`list`, `list`) - - * all_scores is a list of `batch_size` lists of `n_best` scores * all_predictions is a list of `batch_size` lists - of `n_best` predictions + of token taggings """ dataset = inputters.TaggerDataset( - self.fields, data, tokenizer, delimiter, max_seq_len) + self.fields, data, tokenizer, max_seq_len, delimiter) data_iter = torchtext.data.Iterator( dataset=dataset, @@ -327,7 +318,7 @@ def tagging(self, data, batch_size, tokenizer, return all_predictions def tagging_batch(self, batch): - """Translate a batch of sentences.""" + """Tagging a batch of sentences.""" with torch.no_grad(): # Batch input_ids, seq_lengths = batch.tokens diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 97f7c1129d..f988edf89c 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -136,6 +136,51 @@ def validate_preprocess_args(cls, opt): assert not opt.tgt_vocab or os.path.isfile(opt.tgt_vocab), \ "Please check path of your tgt vocab!" + @classmethod + def validate_preprocess_bert_opts(cls, opt): + assert opt.vocab_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ + "Unsupported Pretrain model '%s'" % (opt.vocab_model) + if '-cased' in opt.vocab_model and opt.do_lower_case is True: + logger.warning("The pre-trained model you are loading is " + + "cased model, you shouldn't set `do_lower_case`," + + "we turned it off for you.") + opt.do_lower_case = False + elif '-cased' not in opt.vocab_model and opt.do_lower_case is False: + logger.warning("The pre-trained model you are loading is " + + "uncased model, you should set `do_lower_case`, " + + "we turned it on for you.") + opt.do_lower_case = True + + for filename in opt.data: + assert os.path.isfile(filename),\ + "Please check path of %s" % filename + + if opt.task == "tagging": + assert opt.file_type == 'txt' and len(opt.data) == 1,\ + "For sequence tagging, only single txt file is supported." + opt.data = opt.data[0] + + assert len(opt.input_columns) == 1,\ + "For sequence tagging, only one column for input tokens." + opt.input_columns = opt.input_columns[0] + + assert opt.label_column is not None,\ + "For sequence tagging, label column should be given." + + if opt.task == "classification": + if opt.file_type == "csv": + assert len(opt.data) == 1,\ + "For csv, only single file is needed." + opt.data = opt.data[0] + assert len(opt.input_columns) in [1, 2],\ + "Please indicate colomn of sentence A (and B)" + assert opt.label_column is not None,\ + "For csv file, label column should be given." + if opt.delimiter != '\t': + logger.warning("for csv file, we set delimiter to '\t'") + opt.delimiter = '\t' + return opt + @classmethod def validate_predict_opts(cls, opt): if opt.delimiter is None: diff --git a/predict.py b/predict.py index 8d706b8f9e..779e4e0283 100755 --- a/predict.py +++ b/predict.py @@ -22,8 +22,9 @@ def main(opt): classifier = build_classifier(opt) for i, data_shard in enumerate(data_shards): logger.info("Classify shard %d." % i) + data = [seq.decode("utf-8") for seq in data_shard] classifier.classify( - data_shard, + data, opt.batch_size, tokenizer, delimiter=opt.delimiter, @@ -33,8 +34,9 @@ def main(opt): tagger = build_tagger(opt) for i, data_shard in enumerate(data_shards): logger.info("Tagging shard %d." 
% i) + data = [seq.decode("utf-8") for seq in data_shard] tagger.tagging( - data_shard, + data, opt.batch_size, tokenizer, delimiter=opt.delimiter, diff --git a/preprocess_bert.py b/preprocess_bert.py index a9409d9ae4..57bd211033 100755 --- a/preprocess_bert.py +++ b/preprocess_bert.py @@ -260,8 +260,8 @@ def _get_parser(): parser.add_argument('--label_column', type=int, default=None, help="CSV: Column where contain label") parser.add_argument('--labels', type=str, nargs='+', default=[], - help="CSV: labels of sentence;" + - "TXT: labels for sentence in files.") + help="Candidate labels. If not given, build from " + + "input file and sort in alphabetic order.") parser.add_argument('--delimiter', '-d', type=str, default=' ', help="CSV: delimiter used for seperate column.") diff --git a/preprocess_bert_new.py b/preprocess_bert_new.py new file mode 100755 index 0000000000..1277e9b28f --- /dev/null +++ b/preprocess_bert_new.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" + Pre-process Data files and build vocabulary for Bert model. +""" +from onmt.utils.parse import ArgumentParser +from tqdm import tqdm +import csv +from collections import Counter, defaultdict +import torch +import codecs +from random import shuffle +from onmt.utils.bert_tokenization import BertTokenizer +from onmt.inputters.inputter import get_bert_fields, \ + _build_bert_fields_vocab +import onmt.opts as opts +from onmt.inputters.dataset_bert import ClassifierDataset, \ + TaggerDataset +from onmt.utils.logging import init_logger, logger + + +def shuffle_pair_list(list_a, list_b): + assert len(list_a) == len(list_b),\ + "Two list to shuffle should be equal length" + logger.info("Shuffle all instance") + pair_list = list(zip(list_a, list_b)) + shuffle(pair_list) + list_a, list_b = zip(*pair_list) + return list_a, list_b + + +def build_label_vocab_from_path(paths): + labels = [] + for filename in paths: + label = filename.split('/')[-2] + if label not in labels: + labels.append(label) + return labels + + +def _build_bert_vocab(vocab, name, counters): + """ similar to _load_vocab in inputter.py, but build from a vocab list. 
+ in place change counters + """ + vocab_size = len(vocab) + for i, token in enumerate(vocab): + counters[name][token] = vocab_size - i + return vocab, vocab_size + + +def build_vocab_from_tokenizer(fields, tokenizer, named_labels): + logger.info("Building token vocab from BertTokenizer...") + vocab_list = list(tokenizer.vocab.keys()) + counters = defaultdict(Counter) + _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) + + label_name, label_list = named_labels + logger.info("Building label vocab {}...".format(named_labels)) + _, _ = _build_bert_vocab(label_list, label_name, counters) + + fields_vocab = _build_bert_fields_vocab(fields, counters, vocab_size, + label_name) + return fields_vocab + + +def build_save_vocab(fields, tokenizer, label_vocab, opt): + if opt.sort_label_vocab is True: + label_vocab.sort() + if opt.task == "classification": + named_labels = ("category", label_vocab) + if opt.task == "tagging": + named_labels = ("token_labels", label_vocab) + + fields_vocab = build_vocab_from_tokenizer( + fields, tokenizer, named_labels) + bert_vocab_file = opt.save_data + ".vocab.pt" + torch.save(fields_vocab, bert_vocab_file) + + +def create_cls_instances_from_csv(opt): + logger.info("Reading csv with input in column %s, label in column %s" + % (opt.input_columns, opt.label_column)) + with codecs.open(opt.data, "r", encoding="utf-8-sig") as csvfile: + reader = csv.reader(csvfile, delimiter=opt.delimiter, quotechar=None) + lines = list(reader) + if opt.skip_head is True: + lines = lines[1:] + if len(opt.input_columns) == 1: + column_a = int(opt.input_columns[0]) + column_b = None + else: + column_a = int(opt.input_columns[0]) + column_b = int(opt.input_columns[1]) + + instances, labels, label_vocab = [], [], opt.labels + for line in tqdm(lines, desc="Process", unit=" lines"): + label = line[opt.label_column].strip() + if label not in label_vocab: + label_vocab.append(label) + sentence = line[column_a].strip() + if column_b is not None: + sentence_b = line[column_b].strip() + sentence = sentence + ' ||| ' + sentence_b + instances.append(sentence) + labels.append(label) + logger.info("total %d line loaded with skip_head [%s]" + % (len(lines), opt.skip_head)) + + return instances, labels, label_vocab + + +def create_cls_instances_from_files(opt): + instances = [] + labels = [] + label_vocab = build_label_vocab_from_path(opt.data) + for filename in opt.data: + label = filename.split('/')[-2] + with codecs.open(filename, "r", encoding="utf-8") as f: + lines = f.readlines() + print("total {} line of File {} loaded for label: {}.".format( + len(lines), filename, label)) + lines_labels = [label for _ in range(len(lines))] + instances.extend(lines) + labels.extend(lines_labels) + return instances, labels, label_vocab + + +def build_cls_dataset(corpus_type, fields, tokenizer, opt): + """Build classification dataset with vocab file if train set""" + assert corpus_type in ['train', 'valid'] + if opt.file_type == 'csv': + instances, labels, label_vocab = create_cls_instances_from_csv(opt) + else: + instances, labels, label_vocab = create_cls_instances_from_files(opt) + logger.info("Exiting labels:%s" % label_vocab) + if corpus_type == 'train': + build_save_vocab(fields, tokenizer, label_vocab, opt) + + if opt.do_shuffle is True: + instances, labels = shuffle_pair_list(instances, labels) + cls_instances = instances, labels + logger.info("Building %s dataset..." 
% corpus_type) + dataset = ClassifierDataset( + fields, cls_instances, tokenizer, opt.max_seq_len) + return dataset, len(cls_instances[0]) + + +def create_tag_instances_from_file(opt): + logger.info("Reading tag with token in column %s, tag in column %s" + % (opt.input_columns, opt.label_column)) + sentences, taggings = [], [] + tag_vocab = opt.labels + with codecs.open(opt.data, "r", encoding="utf-8") as f: + lines = f.readlines() + print("total {} line of file {} loaded.".format( + len(lines), opt.data)) + sentence_sofar = [] + for line in tqdm(lines, desc="Process", unit=" lines"): + line = line.strip() + if line is '': + if len(sentence_sofar) > 0: + tokens, tags = zip(*sentence_sofar) + sentences.append(tokens) + taggings.append(tags) + sentence_sofar = [] + else: + elements = line.split(opt.delimiter) + token = elements[opt.input_columns] + tag = elements[opt.label_column] + if tag not in tag_vocab: + tag_vocab.append(tag) + sentence_sofar.append((token, tag)) + print("total {} sentence loaded.".format(len(sentences))) + print("All tags:", tag_vocab) + + return sentences, taggings, tag_vocab + + +def build_tag_dataset(corpus_type, fields, tokenizer, opt): + """Build tagging dataset with vocab file if train set""" + assert corpus_type in ['train', 'valid'] + sentences, taggings, tag_vocab = create_tag_instances_from_file(opt) + logger.info("Exiting Tags:%s" % tag_vocab) + if corpus_type == 'train': + build_save_vocab(fields, tokenizer, tag_vocab, opt) + + if opt.do_shuffle is True: + sentences, taggings = shuffle_pair_list(sentences, taggings) + + tag_instances = sentences, taggings + logger.info("Building %s dataset..." % corpus_type) + dataset = TaggerDataset( + fields, tag_instances, tokenizer, opt.max_seq_len) + return dataset, len(tag_instances[0]) + + +def _get_parser(): + parser = ArgumentParser(description='preprocess_bert.py') + opts.config_opts(parser) + opts.preprocess_bert_opts(parser) + return parser + + +def main(opt): + init_logger(opt.log_file) + opt = ArgumentParser.validate_preprocess_bert_opts(opt) + logger.info("Preprocess dataset...") + + fields = get_bert_fields(opt.task) + logger.info("Get fields for Task: '%s'." % opt.task) + + tokenizer = BertTokenizer.from_pretrained( + opt.vocab_model, do_lower_case=opt.do_lower_case) + logger.info("Use pretrained tokenizer: '%s', do_lower_case [%s]" + % (opt.vocab_model, opt.do_lower_case)) + + if opt.task == "classification": + dataset, n_instance = build_cls_dataset( + opt.corpus_type, fields, tokenizer, opt) + + elif opt.task == "tagging": + dataset, n_instance = build_tag_dataset( + opt.corpus_type, fields, tokenizer, opt) + # Save processed data in OpenNMT format + onmt_filename = opt.save_data + ".{}.0.pt".format(opt.corpus_type) + dataset.save(onmt_filename) + logger.info("* save num_example [%d], max_seq_len [%d] to [%s]." 
+ % (n_instance, opt.max_seq_len, onmt_filename)) + + +if __name__ == '__main__': + parser = _get_parser() + opt = parser.parse_args() + main(opt) diff --git a/train.py b/train.py index 237560f429..72666e1b05 100755 --- a/train.py +++ b/train.py @@ -4,7 +4,7 @@ import signal import torch -import onmt.opts_bert as opts +import onmt.opts as opts import onmt.utils.distributed from onmt.utils.misc import set_random_seed @@ -19,9 +19,9 @@ def main(opt): # JUST FOR verify the options - # ArgumentParser.validate_train_opts(opt) - # ArgumentParser.update_model_opts(opt) - # ArgumentParser.validate_model_opts(opt) + ArgumentParser.validate_train_opts(opt) + ArgumentParser.update_model_opts(opt) + ArgumentParser.validate_model_opts(opt) # Load checkpoint if we resume from a previous training. if opt.train_from: @@ -214,10 +214,8 @@ def _get_parser(): parser = ArgumentParser(description='train.py') opts.config_opts(parser) - # opts.model_opts(parser) - # opts.train_opts(parser) - opts.bert_model_opts(parser) - opts.bert_pretraining(parser) + opts.model_opts(parser) + opts.train_opts(parser) return parser From 892e0a010437ae3d6031b87c608d6230203c5ef9 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 20 Aug 2019 14:56:17 +0200 Subject: [PATCH 14/28] tagging bug fix --- onmt/inputters/dataset_bert.py | 4 +++- preprocess_bert_new.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py index 60421efacc..4ffbd7c057 100644 --- a/onmt/inputters/dataset_bert.py +++ b/onmt/inputters/dataset_bert.py @@ -157,7 +157,8 @@ def __init__(self, fields_dict, data, tokenizer, max_seq_len=256, delimiter=' '): targer_field = fields_dict["token_labels"] self.pad_tok = targer_field.pad_token - self.predict_tok = targer_field.vocab.itos[-1] + if hasattr(targer_field, 'vocab'): # when predicting + self.predict_tok = targer_field.vocab.itos[-1] if isinstance(data, tuple) is False: data = (data, [None for _ in range(len(data))]) instances = self.create_instances( @@ -170,6 +171,7 @@ def create_instances(self, datas, tokenizer, delimiter, max_seq_len): if isinstance(words, str): # build from raw sentence words = words.strip().split(delimiter) if taggings is None: # when predicting + assert hasattr(self, 'predict_tok') taggings = [self.predict_tok for _ in range(len(words))] sentence = [] tags = [] diff --git a/preprocess_bert_new.py b/preprocess_bert_new.py index 1277e9b28f..c521edecca 100755 --- a/preprocess_bert_new.py +++ b/preprocess_bert_new.py @@ -152,7 +152,7 @@ def create_tag_instances_from_file(opt): tag_vocab = opt.labels with codecs.open(opt.data, "r", encoding="utf-8") as f: lines = f.readlines() - print("total {} line of file {} loaded.".format( + logger.info("total {} line of file {} loaded.".format( len(lines), opt.data)) sentence_sofar = [] for line in tqdm(lines, desc="Process", unit=" lines"): @@ -170,8 +170,8 @@ def create_tag_instances_from_file(opt): if tag not in tag_vocab: tag_vocab.append(tag) sentence_sofar.append((token, tag)) - print("total {} sentence loaded.".format(len(sentences))) - print("All tags:", tag_vocab) + logger.info("total {} sentence loaded.".format(len(sentences))) + logger.info("All tags:{}".format(tag_vocab)) return sentences, taggings, tag_vocab From ed0cf4db6561fb255682a7737ea5b89035d35e6c Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Mon, 26 Aug 2019 14:46:23 +0200 Subject: [PATCH 15/28] clean code --- docs/source/FAQ.md | 149 ++++++- docs/source/refs.bib | 49 +++ onmt/encoders/bert.py | 57 
++-
 onmt/encoders/transformer.py | 8 +
 onmt/inputters/dataset_bert.py | 75 +++-
 onmt/model_builder.py | 20 +-
 onmt/models/bert_generators.py | 98 +++--
 onmt/modules/bert_embeddings.py | 29 +-
 onmt/modules/position_ffn.py | 9 +-
 onmt/opts.py | 18 +-
 onmt/trainer.py | 30 +-
 onmt/translate/predictor.py | 40 +-
 onmt/utils/__init__.py | 2 +-
 .../{fn_activation.py => activation_fn.py} | 3 +-
 onmt/utils/bert_tokenization.py | 24 +-
 onmt/utils/file_utils.py | 48 +--
 onmt/utils/loss.py | 35 +-
 onmt/utils/optimizers.py | 30 +-
 onmt/utils/parse.py | 18 +-
 onmt/utils/statistics.py | 13 +-
 predict.py | 2 +-
 pregenerate_bert_training_data.py | 379 ++++++++++-------
 preprocess_bert.py | 389 ------------------
 preprocess_bert_new.py | 9 +-
 24 files changed, 743 insertions(+), 791 deletions(-)
 rename onmt/utils/{fn_activation.py => activation_fn.py} (90%)
 delete mode 100755 preprocess_bert.py

diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index c5ebeb201f..f5808e8b10 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -77,7 +77,7 @@ python train.py -data /tmp/de2/data -save_model /tmp/extra \
        -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 \
        -max_grad_norm 0 -param_init 0 -param_init_glorot \
        -label_smoothing 0.1 -valid_steps 10000 -save_checkpoint_steps 10000 \
-        -world_size 4 -gpu_ranks 0 1 2 3 
+        -world_size 4 -gpu_ranks 0 1 2 3
 ```
 Here are what each of the parameters mean:
@@ -85,8 +85,8 @@ Here are what each of the parameters mean:
 * `param_init_glorot` `-param_init 0`: correct initialization of parameters
 * `position_encoding`: add sinusoidal position encoding to each embedding
 * `optim adam`, `decay_method noam`, `warmup_steps 8000`: use special learning rate.
-* `batch_type tokens`, `normalization tokens`, `accum_count 4`: batch and normalize based on number of tokens and not sentences. Compute gradients based on four batches. 
-- `label_smoothing 0.1`: use label smoothing loss. 
+* `batch_type tokens`, `normalization tokens`, `accum_count 4`: batch and normalize based on number of tokens and not sentences. Compute gradients based on four batches.
+- `label_smoothing 0.1`: use label smoothing loss.
 
 Multi GPU settings
 First you need to make sure you export CUDA_VISIBLE_DEVICES=0,1,2,3
@@ -136,3 +136,146 @@ E.g.
 will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`, and that when building batches, we'll take 1 example from corpus A, then 7 examples from corpus B, and so on.
 
 **Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing.
+
+## How do I use BERT?
+BERT is a general-purpose "language understanding" model introduced by Google. It can be used for various downstream NLP tasks and is easily adapted to a new task using transfer learning. Using BERT has two stages: pre-training and fine-tuning. As pre-training is very expensive, we do not recommend pre-training a BERT from scratch; instead, load weights from an existing pretrained model and fine-tune them. Currently we support the sentence(-pair) classification and token tagging downstream tasks.
+
+### Use pretrained BERT weights
+To use weights from an existing huggingface pretrained model, we provide a script that converts huggingface's BERT model weights into our format.
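+
+The converter needs the number of encoder layers (`--layers`). If you are not sure what to pass, one way to check is to inspect the huggingface weights file (typically named `pytorch_model.bin`) directly. The sketch below is only an illustration: it assumes the file is a standard PyTorch state dict whose parameter names contain `encoder.layer.N` (the exact prefix may vary between checkpoints).
+```python
+import re
+
+import torch
+
+# Load the huggingface weights file on CPU; it is a plain state dict.
+state_dict = torch.load("pytorch_model.bin", map_location="cpu")
+
+# Collect the distinct layer indices appearing in the parameter names,
+# e.g. "bert.encoder.layer.11.output.dense.weight" -> 11.
+layer_ids = set()
+for key in state_dict:
+    match = re.search(r"encoder\.layer\.(\d+)\.", key)
+    if match:
+        layer_ids.add(int(match.group(1)))
+print("number of encoder layers:", len(layer_ids))
+```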
+
+Usage:
+```bash
+bert_ckp_convert.py --layers NUMBER_LAYER
+    --bert_model_weights_file HUGGINGFACE_BERT_WEIGHTS
+    --output_name OUTPUT_FILE
+```
+* Go to modeling_bert.py in https://github.com/huggingface/pytorch-transformers/ to check all available pretrained models.
+
+### Preprocess train/dev dataset
+To generate train/dev data for BERT, you can use preprocess_bert.py by providing raw data in a supported format and choosing a BERT tokenizer model `-vm` consistent with the pretrained model.
+#### Classification
+For classification datasets, we support input files in csv or plain text format.
+
+* For a csv file, each line should contain an instance with one or two sentence columns and one label column, as in the GLUE datasets; other csv datasets should be compatible. A typical csv file looks like:
+
+  | ID | SENTENCE_A | SENTENCE_B(Optional) | LABEL |
+  | -- | ------------------------ | ------------------------ | ------- |
+  | 0 | sentence a of instance 0 | sentence b of instance 0 | class 2 |
+  | 1 | sentence a of instance 1 | sentence b of instance 1 | class 1 |
+  | ...| ... | ... | ... |
+
+  Then call preprocess_bert.py, providing the input sentence columns and the label column:
+  ```bash
+  python preprocess_bert.py --task classification --corpus_type {train, valid}
+      --file_type csv [--delimiter '\t'] [--skip_head]
+      --input_columns 1 2 --label_column 3
+      --data DATA_DIR/FILENAME.tsv
+      --save_data dataset
+      -vm bert-base-cased --max_seq_len 256 [--do_lower_case]
+      [--sort_label_vocab] [--do_shuffle]
+  ```
+* For plain text format, we accept multiple files as input, where each file contains instances of one specific class. Each line of a file contains one instance, composed of one sentence or two sentences separated by ` ||| `. All input files should be arranged in the following way:
+  ```
+  .../LABEL_1/filename
+  .../LABEL_2/filename
+  .../LABEL_3/filename
+  ```
+  Then call preprocess_bert.py as follows to generate training data:
+  ```bash
+  python preprocess_bert.py --task classification --corpus_type {'train', 'valid'}
+      --file_type txt [--delimiter ' ||| ']
+      --data DIR_BASE/LABEL_1/FILENAME1 ... DIR_BASE/LABEL_N/FILENAME2
+      --save_data dataset
+      --vocab_model {bert-base-uncased,...}
+      --max_seq_len 256 [--do_lower_case]
+      [--sort_label_vocab] [--do_shuffle]
+  ```
+#### Tagging
+For tagging datasets, we support input files in plain text format.
+
+Each line of the input file should contain a token and its tag; fields should be separated by a delimiter (default space), while sentences are separated by a blank line.
+
+An example of an input file is given below (`Token X X Label`):
+  ```
+  -DOCSTART- -X- O O
+
+  CRICKET NNP I-NP O
+  - : O O
+  LEICESTERSHIRE NNP I-NP I-ORG
+  TAKE NNP I-NP O
+  OVER IN I-PP O
+  AT NNP I-NP O
+  TOP NNP I-NP O
+  AFTER NNP I-NP O
+  INNINGS NNP I-NP O
+  VICTORY NN I-NP O
+  . . O O
+
+  LONDON NNP I-NP I-LOC
+  1996-08-30 CD I-NP O
+
+  ```
+Then call preprocess_bert.py, providing the token column and the label column, as follows to generate training data for the token tagging task:
+  ```bash
+  python preprocess_bert.py --task tagging --corpus_type {'train', 'valid'}
+      --file_type txt [--delimiter ' ']
+      --input_columns 1 --label_column 3
+      --data DATA_DIR/FILENAME
+      --save_data dataset
+      --vocab_model {bert-base-uncased,...}
+      --max_seq_len 256 [--do_lower_case]
+      [--sort_label_vocab] [--do_shuffle]
+  ```
+#### Pretraining objective
+Even if it is not recommended, we also provide a script to generate a pretraining dataset, as you may want to fine-tune an existing pretrained model on masked language modeling and next sentence prediction.
+
+The script expects a single file as input, consisting of untokenized text, with one sentence per line, and one blank line between documents.
+A usage example is given below:
+```bash
+python3 pregenerate_bert_training_data.py --input_file INPUT_FILE
+    --output_dir OUTPUT_DIR
+    --output_name OUTPUT_FILE_PREFIX
+    --corpus_type {'train', 'valid'}
+    --vocab_model {bert-base-uncased,...}
+    [--do_lower_case] [--do_whole_word_mask] [--reduce_memory]
+    --epochs_to_generate 2
+    --max_seq_len 128
+    --short_seq_prob 0.1 --masked_lm_prob 0.15
+    --max_predictions_per_seq 20
+    [--save_json]
+```
+
+### Training
+After the preprocessed data have been generated, you can load weights from a pretrained BERT and transfer them to a downstream task with a task-specific output head. This task-specific head will be initialized by the method you choose if no such architecture exists in the weights file specified by `--train_from`. Among all available optimizers, we suggest `--optim bertadam`, as it is the method used to train BERT. `warmup_steps` can be set to 1% of `train_steps`, as in the original paper, when using the linear decay method.
+
+A usage example is given below:
+```bash
+python3 train.py --is_bert --task_type {pretraining, classification, tagging}
+    --data PREPROCESSED_DATAFILE
+    --train_from CONVERTED_CHECKPOINT.pt [--param_init 0.1]
+    --save_model MODEL_PREFIX --save_checkpoint_steps 1000
+    [--world_size 2] [--gpu_ranks 0 1]
+    --word_vec_size 768 --rnn_size 768
+    --layers 12 --heads 8 --transformer_ff 3072
+    --activation gelu --dropout 0.1 --average_decay 0.0001
+    --batch_size 8 [--accum_count 4] --optim bertadam [--max_grad_norm 0]
+    --learning_rate 2e-5 --learning_rate_decay 0.99 --decay_method linear
+    --train_steps 4000 --valid_steps 200 --warmup_steps 40
+    [--report_every 10] [--seed 3435]
+    [--tensorboard] [--tensorboard_log_dir LOGDIR]
+```
+
+### Predicting
+After training, you can use `predict.py` to generate predictions for a raw file. Make sure to use the same BERT tokenizer model `--vocab_model` as for the training data.
+
+For the classification task, the file to be predicted should contain one sentence(-pair) per line, with ` ||| ` separating the two sentences.
+For the tagging task, each line should be a tokenized sentence with tokens separated by spaces.
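+
+For example, a prediction input file for a sentence-pair classification model could look like this (the sentences are made up):
+```
+It is a great movie . ||| I would watch it again .
+The plot is hard to follow . ||| I still enjoyed it .
+```
+and a prediction input file for a tagging model could look like this (reusing tokens from the CoNLL-style sample above):
+```
+CRICKET - LEICESTERSHIRE TAKE OVER AT TOP
+LONDON 1996-08-30
+```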
+ +Usage: +```bash +python3 predict.py --task {classification, tagging} + --model ONMT_BERT_CHECKPOINT.pt + --vocab_model bert-base-uncased [--do_lower_case] + --data DATA_2_PREDICT [--delimiter {' ||| ', ' '}] --max_seq_len 256 + --output PREDICT.txt [--batch_size 8] [--gpu 1] [--seed 3435] +``` diff --git a/docs/source/refs.bib b/docs/source/refs.bib index 045b494d19..8d2617f775 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -435,3 +435,52 @@ @article{DBLP:journals/corr/MartinsA16 biburl = {https://dblp.org/rec/bib/journals/corr/MartinsA16}, bibsource = {dblp computer science bibliography, https://dblp.org} } + +@article{DBLP:journals/corr/abs-1711-05101, + author = {Ilya Loshchilov and + Frank Hutter}, + title = {Fixing Weight Decay Regularization in Adam}, + journal = {CoRR}, + volume = {abs/1711.05101}, + year = {2017}, + url = {http://arxiv.org/abs/1711.05101}, + archivePrefix = {arXiv}, + eprint = {1711.05101}, + timestamp = {Mon, 13 Aug 2018 16:48:18 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/abs-1711-05101}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + +@article{DBLP:journals/corr/abs-1810-04805, + author = {Jacob Devlin and + Ming{-}Wei Chang and + Kenton Lee and + Kristina Toutanova}, + title = {{BERT:} Pre-training of Deep Bidirectional Transformers for Language + Understanding}, + journal = {CoRR}, + volume = {abs/1810.04805}, + year = {2018}, + url = {http://arxiv.org/abs/1810.04805}, + archivePrefix = {arXiv}, + eprint = {1810.04805}, + timestamp = {Tue, 30 Oct 2018 20:39:56 +0100}, + biburl = {https://dblp.org/rec/bib/journals/corr/abs-1810-04805}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + +@article{DBLP:journals/corr/HendrycksG16, + author = {Dan Hendrycks and + Kevin Gimpel}, + title = {Bridging Nonlinearities and Stochastic Regularizers with Gaussian + Error Linear Units}, + journal = {CoRR}, + volume = {abs/1606.08415}, + year = {2016}, + url = {http://arxiv.org/abs/1606.08415}, + archivePrefix = {arXiv}, + eprint = {1606.08415}, + timestamp = {Mon, 13 Aug 2018 16:46:20 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/HendrycksG16}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} diff --git a/onmt/encoders/bert.py b/onmt/encoders/bert.py index 4d935d1cf5..49571bbd69 100644 --- a/onmt/encoders/bert.py +++ b/onmt/encoders/bert.py @@ -4,10 +4,18 @@ class BertEncoder(nn.Module): + """BERT Encoder: A Transformer Encoder with BertLayerNorm and BertPooler. + :cite:`DBLP:journals/corr/abs-1810-04805` + + Args: + embeddings (onmt.modules.BertEmbeddings): embeddings to use + num_layers (int): number of encoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + dropout (float): dropout parameters """ - BERT Implementation: https://arxiv.org/abs/1810.04805 - Use a Transformer Encoder as Language modeling. 
- """ + def __init__(self, embeddings, num_layers=12, d_model=768, heads=12, d_ff=3072, dropout=0.1, max_relative_positions=0): @@ -34,8 +42,8 @@ def from_opt(cls, opt, embeddings): """Alternate constructor.""" return cls( embeddings, - opt.enc_layers, - opt.enc_rnn_size, + opt.layers, + opt.word_vec_size, opt.heads, opt.transformer_ff, opt.dropout[0] if type(opt.dropout) is list else opt.dropout, @@ -45,16 +53,15 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, output_all_encoded_layers=False): """ Args: - input_ids: shape [batch, seq] padding ids=0 - token_type_ids: shape [batch, seq], A(0), B(1), pad(0) - input_mask: shape [batch, seq], 1 for masked position(that padding) - output_all_encoded_layers: if out contain all hidden layer + input_ids (Tensor): ``(B, S)``, padding ids=0 + token_type_ids (Tensor): ``(B, S)``, A(0), B(1), pad(0) + input_mask (Tensor): ``(B, S)``, 1 for masked (padding) + output_all_encoded_layers (bool): if out contain all hidden layer Returns: - all_encoder_layers: list of out in shape (batch, seq, d_model), - to be used for generation task - pooled_output: shape (batch, d_model), - to be used for classification task + all_encoder_layers (list of Tensor): ``(B, S, H)``, token level + pooled_output (Tensor): ``(B, H)``, sequence level """ + # OpenNMT waiting for mask of size [B, 1, T], # while in MultiHeadAttention part2 -> [B, 1, 1, T] if input_mask is None: @@ -87,17 +94,25 @@ def update_dropout(self, dropout): class BertPooler(nn.Module): def __init__(self, hidden_size): + """A pooling block (Linear layer followed by Tanh activation). + + Args: + hidden_size (int): size of hidden layer. + """ + super(BertPooler, self).__init__() self.dense = nn.Linear(hidden_size, hidden_size) self.activation_fn = nn.Tanh() def forward(self, hidden_states): - """ + """hidden_states[:, 0, :] --> {Linear, Tanh} --> Returns. + Args: - hidden_states: last layer's hidden_states,(batch, src, d_model) + hidden_states (Tensor): last layer's hidden_states, ``(B, S, H)`` Returns: - pooled_output: transformed output of last layer's hidden_states + pooled_output (Tensor): transformed output of last layer's hidden """ + first_token_tensor = hidden_states[:, 0, :] # [batch, d_model] pooled_output = self.activation_fn(self.dense(first_token_tensor)) return pooled_output @@ -105,15 +120,21 @@ def forward(self, hidden_states): class BertLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): - """Construct a layernorm module in the TF style - (epsilon inside the square root). + """Layernorm module in the TF style(epsilon inside the square root). + https://www.tensorflow.org/api_docs/python/tf/contrib/layers/layer_norm. + + Args: + hidden_size (int): size of hidden layer. """ + super(BertLayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, x): + """layer normalization is perform on input x.""" + u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 181989f9f2..7fb72d6459 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -44,6 +44,13 @@ def __init__(self, d_model, heads, d_ff, dropout, self.is_bert = is_bert def residual(self, output, x): + """A Residual connection. + + Official BERT perform residual connection on layer normed input. 
+ BERT's layer_norm is done before pass into next block while onmt's + layer_norm is performed at the begining. + """ + maybe_norm = self.layer_norm(x) if self.is_bert else x return output + maybe_norm @@ -58,6 +65,7 @@ def forward(self, inputs, mask): * outputs ``(batch_size, src_len, model_dim)`` """ + input_norm = self.layer_norm(inputs) context, _ = self.self_attn(input_norm, input_norm, input_norm, mask=mask, type="self") diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py index 4ffbd7c057..512ec21a25 100644 --- a/onmt/inputters/dataset_bert.py +++ b/onmt/inputters/dataset_bert.py @@ -10,7 +10,8 @@ def bert_text_sort_key(ex): def truncate_seq(tokens, max_num_tokens): - """Truncates a sequences to a maximum sequence length.""" + """Truncates a sequences randomly from front or back + to a maximum sequence length.""" while True: total_length = len(tokens) if total_length <= max_num_tokens: @@ -26,7 +27,9 @@ def truncate_seq(tokens, max_num_tokens): def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): """Truncates a pair of sequences to a maximum sequence length. - Lifted from Google's BERT repo.""" + Lifted from Google's BERT repo: create_pretraining_data.py in + https://github.com/google-research/bert/""" + while True: total_length = len(tokens_a) + len(tokens_b) if total_length <= max_num_tokens: @@ -44,6 +47,20 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): def create_sentence_instance(sentence, tokenizer, max_seq_length, random_trunc=False): + """Create single processed instance in BERT format. + + Args: + sentence (str): a raw single sentence. + tokenizer (onmt.utils.BertTokenizer): tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. + random_trunc (bool): if false, trunc tail. + + Returns: + (list, list): + + * tokens_processed: ["[CLS]", sent_a, "[SEP]"] + * segment_ids: [0, ..., 0] + """ tokens = tokenizer.tokenize(sentence) # Account for [CLS], [SEP], [SEP] max_num_tokens = max_seq_length - 2 @@ -58,6 +75,20 @@ def create_sentence_instance(sentence, tokenizer, def create_sentence_pair_instance(sent_a, sent_b, tokenizer, max_seq_length): + """Create single processed instance in BERT format. + + Args: + sent_a (str): a raw single sentence. + sent_b (str): another raw single sentence. + tokenizer (onmt.utils.BertTokenizer): tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. + + Returns: + (list, list): + + * tokens_processed: ["[CLS]", sent_a, "[SEP]", sent_b, "[SEP]"] + * segment_ids: [0, ..., 0, 1, ..., 1] + """ tokens_a = tokenizer.tokenize(sent_a) tokens_b = tokenizer.tokenize(sent_b) # Account for [CLS], [SEP], [SEP] @@ -71,6 +102,7 @@ def create_sentence_pair_instance(sent_a, sent_b, tokenizer, max_seq_length): class BertDataset(TorchtextDataset): """Defines a BERT dataset composed of Examples along with its Fields. + Args: fields_dict (dict[str, Field]): a dict containing all Field with its name. @@ -107,11 +139,16 @@ def save(self, path, remove_fields=True): class ClassifierDataset(BertDataset): """Defines a BERT dataset composed of Examples along with its Fields. + Fields include "tokens", "segment_ids", "category". + Args: fields_dict (dict[str, Field]): a dict containing all Field with its name. data (list[]): a list of sequence (sentence or sentence pair), possible with its label becoming tuple(list[]). + tokenizer (onmt.utils.BertTokenizer): a tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. 
+ delimiter (str): delimiter used to separate tokens in input sequence. """ def __init__(self, fields_dict, data, tokenizer, @@ -124,6 +161,19 @@ def __init__(self, fields_dict, data, tokenizer, def create_instances(self, data, tokenizer, delimiter, max_seq_len): + """Return data instances in the form of list of dict. + + Args: + data (list[]): a list of sequence (sentence or sentence pair), + possible with its label becoming tuple(list[]). + tokenizer (onmt.utils.BertTokenizer): tokenizer to use on data. + max_seq_len (int): maximum length of sequence. + delimiter (str): delimiter used to separate tokens in sequence. + + Returns: + instances (list of dict): list of sequence classification instance. + """ + instances = [] for sentence, label in zip(*data): sentences = sentence.strip().split(delimiter, 1) @@ -145,12 +195,16 @@ def create_instances(self, data, tokenizer, class TaggerDataset(BertDataset): """Defines a BERT dataset composed of Examples along with its Fields. + Args: fields_dict (dict[str, Field]): a dict containing all Field with its name. - data (list[]|tuple(list[])): a list of sequence, each sequence is + data (list of str|tuple of list): a list of sequence, each sequence is composed with tokens that to be tagging. Can also combined with - its tags as tuple([tokens], [tags]) + its tags as tuple([tokens], [tags]). + tokenizer (onmt.utils.BertTokenizer): a tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. + delimiter (str): delimiter used to separate tokens in input sequence. """ def __init__(self, fields_dict, data, tokenizer, @@ -166,6 +220,19 @@ def __init__(self, fields_dict, data, tokenizer, super(TaggerDataset, self).__init__(fields_dict, instances) def create_instances(self, datas, tokenizer, delimiter, max_seq_len): + """Return data instances in the form of list of dict. + + Args: + data (list[]): a list of sequence (sentence or sentence pair), + possible with its label becoming tuple(list[]). + tokenizer (onmt.utils.BertTokenizer): tokenizer to use on data. + max_seq_len (int): maximum length of sequence. + delimiter (str): delimiter used to separate tokens in sequence. + + Returns: + instances (list of dict): list of tokens tagging instance. + """ + instances = [] for words, taggings in zip(*datas): if isinstance(words, str): # build from raw sentence diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 0bc738476f..b8d9691749 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -9,7 +9,7 @@ import onmt.inputters as inputters import onmt.modules -from onmt.encoders import str2enc, BertEncoder +from onmt.encoders import str2enc from onmt.decoders import str2dec @@ -239,12 +239,8 @@ def build_bert_embeddings(opt, fields): return bert_emb -def build_bert_encoder(model_opt, fields, embeddings): - bert = BertEncoder( - embeddings, num_layers=model_opt.layers, - d_model=model_opt.word_vec_size, heads=model_opt.heads, - d_ff=model_opt.transformer_ff, dropout=model_opt.dropout[0], - max_relative_positions=model_opt.max_relative_positions) +def build_bert_encoder(model_opt, embeddings): + bert = str2enc['bert'].from_opt(model_opt, embeddings) return bert @@ -274,7 +270,7 @@ def build_bert_generator(model_opt, fields, bert_encoder): generator.decode.weight = \ bert_encoder.embeddings.word_embeddings.weight elif task == 'classification': - n_class = len(fields["category"].vocab.stoi) #model_opt.labels + n_class = len(fields["category"].vocab.stoi) logger.info('Generator of classification with %s class.' 
% n_class) generator = ClassificationHead(bert_encoder.d_model, n_class, dropout) elif task == 'tagging': @@ -306,7 +302,7 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): bert_emb = build_bert_embeddings(model_opt, fields) # Build encoder. - bert_encoder = build_bert_encoder(model_opt, fields, bert_emb) + bert_encoder = build_bert_encoder(model_opt, bert_emb) gpu = use_gpu(opt) if gpu and gpu_id is not None: @@ -363,16 +359,16 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): logger.info(model) return model + def load_bert_model(opt, model_path): checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) logger.info("Checkpoint from {} Loaded.".format(model_path)) model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt']) - ArgumentParser.update_model_opts(model_opt) - # ArgumentParser.validate_model_opts(model_opt) vocab = checkpoint['vocab'] fields = vocab - model = build_bert_model(model_opt, opt, fields, checkpoint, gpu_id=opt.gpu) + model = build_bert_model( + model_opt, opt, fields, checkpoint, gpu_id=opt.gpu) if opt.fp32: model.float() diff --git a/onmt/models/bert_generators.py b/onmt/models/bert_generators.py index 459c55dbda..33ee221652 100644 --- a/onmt/models/bert_generators.py +++ b/onmt/models/bert_generators.py @@ -8,6 +8,10 @@ class BertPreTrainingHeads(nn.Module): """ Bert Pretraining Heads: Masked Language Models, Next Sentence Prediction + + Args: + hidden_size (int): output size of BERT model + vocab_size (int): total vocab size """ def __init__(self, hidden_size, vocab_size): super(BertPreTrainingHeads, self).__init__() @@ -17,11 +21,11 @@ def __init__(self, hidden_size, vocab_size): def forward(self, x, pooled_out): """ Args: - x: list of out of all_encoder_layers, shape (batch, seq, d_model) - pooled_output: transformed output of last layer's hidden_states + x (list of Tensor): all_encoder_layers, shape ``(B, S, H)`` + pooled_output (Tensor): second output of bert encoder, ``(B, H)`` Returns: - seq_class_log_prob: next sentence prediction, (batch, 2) - prediction_log_prob: masked lm prediction, (batch, seq, vocab) + seq_class_log_prob (Tensor): next sentence prediction, ``(B, 2)`` + prediction_log_prob (Tensor): mlm prediction, ``(B, S, vocab)`` """ seq_class_log_prob = self.next_sentence(pooled_out) prediction_log_prob = self.mask_lm(x[-1]) @@ -29,18 +33,15 @@ def forward(self, x, pooled_out): class MaskedLanguageModel(nn.Module): - """ - predicting origin token from masked input sequence + """predicting origin token from masked input sequence n-class classification problem, n-class = vocab_size + + Args: + hidden_size (int): output size of BERT model + vocab_size (int): total vocab size """ def __init__(self, hidden_size, vocab_size): - """ - Args: - hidden_size: output size of BERT model - vocab_size: total vocab size - bert_word_embedding_weights: reuse embedding weights if set - """ super(MaskedLanguageModel, self).__init__() self.transform = BertPredictionTransform(hidden_size) @@ -52,9 +53,9 @@ def __init__(self, hidden_size, vocab_size): def forward(self, x): """ Args: - x: first output of bert encoder, (batch, seq, d_model) + x (Tensor): first output of bert encoder, ``(B, S, H)`` Returns: - prediction_log_prob: shape (batch, seq, vocab) + prediction_log_prob (Tensor): shape ``(B, S, vocab)`` """ x = self.transform(x) # (batch, seq, d_model) prediction_scores = self.decode(x) + self.bias # (batch, seq, vocab) @@ -65,13 +66,12 @@ def forward(self, x): class 
NextSentencePrediction(nn.Module): """ 2-class classification model : is_next, is_random_next + + Args: + hidden_size (int): BERT model output size """ def __init__(self, hidden_size): - """ - Args: - hidden_size: BERT model output size - """ super(NextSentencePrediction, self).__init__() self.linear = nn.Linear(hidden_size, 2) self.softmax = nn.LogSoftmax(dim=-1) @@ -79,9 +79,9 @@ def __init__(self, hidden_size): def forward(self, x): """ Args: - x: second output of bert encoder, (batch, d_model) + x (Tensor): second output of bert encoder, ``(B, H)`` Returns: - seq_class_prob: shape (batch_size, 2) + seq_class_prob (Tensor): ``(B, 2)`` """ seq_relationship_score = self.linear(x) # (batch, 2) seq_class_log_prob = self.softmax(seq_relationship_score) @@ -89,7 +89,14 @@ def forward(self, x): class BertPredictionTransform(nn.Module): + """{Linear(h,h), Activation, LN} block.""" + def __init__(self, hidden_size): + """ + Args: + hidden_size (int): BERT model hidden layer size. + """ + super(BertPredictionTransform, self).__init__() self.dense = nn.Linear(hidden_size, hidden_size) self.activation = get_activation_fn('gelu') @@ -98,22 +105,24 @@ def __init__(self, hidden_size): def forward(self, hidden_states): """ Args: - hidden_states: BERT model output size (batch, seq, d_model) + hidden_states (Tensor): BERT encoder output ``(B, S, H)`` """ + hidden_states = self.layer_norm(self.activation( self.dense(hidden_states))) return hidden_states class ClassificationHead(nn.Module): - """ - n-class Sentence classification head + """n-class Sentence classification head + + Args: + hidden_size (int): BERT model output size + n_class (int): number of classification label """ def __init__(self, hidden_size, n_class, dropout=0.1): """ - Args: - hidden_size: BERT model output size """ super(ClassificationHead, self).__init__() self.dropout = nn.Dropout(dropout) @@ -123,12 +132,13 @@ def __init__(self, hidden_size, n_class, dropout=0.1): def forward(self, all_hidden, pooled): """ Args: - all_hidden: layer output of BERT, list [(batch, seq, d_model)] - pooled: last layer hidden [CLS] of BERT, (batch, d_model) + all_hidden (list of Tensor): layers output, list [``(B, S, H)``] + pooled (Tensor): last layer hidden [CLS], ``(B, H)`` Returns: - class_log_prob: shape (batch_size, 2) + class_log_prob (Tensor): shape ``(B, 2)`` None: this is a placeholder for token level prediction task """ + pooled = self.dropout(pooled) score = self.linear(pooled) # (batch, n_class) class_log_prob = self.softmax(score) # (batch, n_class) @@ -136,15 +146,14 @@ def forward(self, all_hidden, pooled): class TokenTaggingHead(nn.Module): - """ - n-class Token Tagging head + """n-class Token Tagging head + + Args: + hidden_size (int): BERT model output size + n_class (int): number of tagging label """ def __init__(self, hidden_size, n_class, dropout=0.1): - """ - Args: - hidden_size: BERT model output size - """ super(TokenTaggingHead, self).__init__() self.dropout = nn.Dropout(dropout) self.linear = nn.Linear(hidden_size, n_class) @@ -153,11 +162,11 @@ def __init__(self, hidden_size, n_class, dropout=0.1): def forward(self, all_hidden, pooled): """ Args: - all_hidden: layer output of BERT, list [(batch, seq, d_model)] - pooled: last layer hidden [CLS] of BERT, (batch, d_model) + all_hidden (list of Tensor): layers output, list [``(B, S, H)``] + pooled (Tensor): last layer hidden [CLS], ``(B, H)`` Returns: None: this is a placeholder for sentence level task - tok_class_log_prob: shape (batch, seq, n_class) + tok_class_log_prob 
(Tensor): shape ``(B, S, n_class)`` """ last_hidden = all_hidden[-1] last_hidden = self.dropout(last_hidden) # (batch, seq, d_model) @@ -169,15 +178,13 @@ def forward(self, all_hidden, pooled): class TokenGenerationHead(nn.Module): """ Token generation head: generation token from input sequence + + Args: + hidden_size (int): output size of BERT model + vocab_size (int): total vocab size """ def __init__(self, hidden_size, vocab_size): - """ - Args: - hidden_size: output size of BERT model - vocab_size: total vocab size - bert_word_embedding_weights: reuse embedding weights if set - """ super(TokenGenerationHead, self).__init__() self.transform = BertPredictionTransform(hidden_size) @@ -189,10 +196,11 @@ def __init__(self, hidden_size, vocab_size): def forward(self, all_hidden, pooled): """ Args: - all_hidden: layer output of BERT, list [(batch, seq, d_model)] + all_hidden (list of Tensor): layers output, list [``(B, S, H)``] + pooled (Tensor): last layer hidden [CLS], ``(B, H)`` Returns: None: this is a placeholder for sentence level task - prediction_log_prob: shape (batch, seq, vocab) + prediction_log_prob (Tensor): shape ``(B, S, vocab)`` """ last_hidden = all_hidden[-1] y = self.transform(last_hidden) # (batch, seq, d_model) diff --git a/onmt/modules/bert_embeddings.py b/onmt/modules/bert_embeddings.py index 8cd0bf54ae..4b8f647b0c 100644 --- a/onmt/modules/bert_embeddings.py +++ b/onmt/modules/bert_embeddings.py @@ -7,18 +7,18 @@ class BertEmbeddings(nn.Module): 1. Token embeddings: called word_embeddings 2. Segmentation embeddings: called token_type_embeddings 3. Position embeddings: called position_embeddings - Ref: https://arxiv.org/abs/1810.04805 section 3.2 + :cite:`DBLP:journals/corr/abs-1810-04805` section 3.2 + + Args: + vocab_size (int): Size of the embedding vocabulary. + embed_size (int): Width of the word embeddings. + pad_idx (int): padding index + dropout (float): dropout rate + max_position (int): max sentence length in input + num_sentence (int): number of segment """ def __init__(self, vocab_size, embed_size=768, pad_idx=0, dropout=0.1, max_position=512, num_sentence=2): - """ - Args: - vocab_size: int. Size of the embedding vocabulary. - embed_size: int. Width of the word embeddings. - dropout: dropout rate - pad_idx: padding index - max_position: max sentence length in input - """ super(BertEmbeddings, self).__init__() self.vocab_size = vocab_size self.embed_size = embed_size @@ -37,10 +37,11 @@ def __init__(self, vocab_size, embed_size=768, pad_idx=0, def forward(self, input_ids, token_type_ids=None): """ Args: - input_ids: word ids in shape [batch, seq, hidden_size]. - token_type_ids: token type ids in shape [batch, seq]. - Output: - embeddings: word embeds in shape [batch, seq, hidden_size]. + input_ids (Tensor): ``(B, S)``. + token_type_ids (Tensor): segment id ``(B, S)``. + + Returns: + embeddings (Tensor): final embeddings, ``(B, S, H)``. 
""" seq_length = input_ids.size(1) position_ids = torch.arange( @@ -56,7 +57,7 @@ def forward(self, input_ids, token_type_ids=None): token_type_embeds = self.token_type_embeddings(token_type_ids) embeddings = word_embeds + position_embeds + token_type_embeds # NOTE: in our version, LayerNorm is done in EncoderLayer - # before fed into Attention comparing to original implementation + # before fed into Attention comparing to original one # embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index 7a68794f46..435b642560 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -14,7 +14,7 @@ class PositionwiseFeedForward(nn.Module): d_ff (int): the hidden layer size of the second-layer of the FNN. dropout (float): dropout probability in :math:`[0, 1)`. - activation (str): activation function to use. ['ReLU', 'GeLU'] + activation (str): activation function to use. ['relu', 'gelu'] is_bert (bool): default False. When set True, layer_norm will be performed on the direct connection of residual block. @@ -34,6 +34,13 @@ def __init__(self, d_model, d_ff, dropout=0.1, self.is_bert = is_bert def residual(self, output, x): + """A Residual connection. + + Official BERT perform residual connection on layer normed input. + BERT's layer_norm is done before pass into next block while onmt's + layer_norm is performed at the begining. + """ + maybe_norm = self.layer_norm(x) if self.is_bert else x return output + maybe_norm diff --git a/onmt/opts.py b/onmt/opts.py index e872b45eec..ba63b134f9 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -363,7 +363,10 @@ def preprocess_bert_opts(parser): "bert-base-cased", "bert-large-cased", "bert-base-multilingual-uncased", "bert-base-multilingual-cased", - "bert-base-chinese"], + "bert-base-chinese", "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-base-cased-finetuned-mrpc"], help="Pretrained BertTokenizer model use to tokenizer text.") # Data processing options @@ -381,15 +384,15 @@ def train_opts(parser): group = parser.add_argument_group('Pretrain-finetuning') group.add('--is_bert', '-is_bert', action='store_true') - group.add('--task_type', '-task_type', type=str, default='classification', - choices=["pretraining", "classification", "tagging"], + group.add('--task_type', '-task_type', type=str, default="none", + choices=["none", "pretraining", "classification", "tagging"], help="Downstream task for Bert if is_bert set True" "Choose from pretraining Bert," "use pretrained Bert for classification," "use pretrained Bert for token generation.") group.add('--reuse_embeddings', '-reuse_embeddings', type=bool, default=False, help="if reuse embeddings for generator " + - "currently not available") + "only for generation or pretraining task") group = parser.add_argument_group('General') group.add('--data', '-data', required=True, @@ -805,13 +808,16 @@ def translate_opts(parser): def predict_opts(parser): """ Prediction [Using Pretrained model] options """ group = parser.add_argument_group('Model') - group.add("--bert_model", type=str, + group.add("--vocab_model", type=str, default="bert-base-uncased", choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", "bert-large-cased", "bert-base-multilingual-uncased", "bert-base-multilingual-cased", - "bert-base-chinese"], + "bert-base-chinese", "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", 
+ "bert-large-cased-whole-word-masking", + "bert-base-cased-finetuned-mrpc"], help="Bert pretrained tokenizer model to use.") group.add("--model", type=str, default=None, required=True, help="Path to Bert model that for predicting.") diff --git a/onmt/trainer.py b/onmt/trainer.py index 3c23fe635d..965cd65d8c 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -249,12 +249,9 @@ def train(self, % (self.gpu_rank, i + 1, len(batches))) if self.n_gpu > 1: - l_norm = onmt.utils.distributed.all_gather_list(normalization) - # NOTE: DEBUG - # current_rank = torch.distributed.get_rank() - # print("-> RANK: {}".format(current_rank)) - # print(list_norm) - normalization = sum(l_norm) + normalization = sum(onmt.utils.distributed + .all_gather_list + (normalization)) # Training Step: Forward -> compute Loss -> optimize if self.is_bert: @@ -540,22 +537,12 @@ def _bert_gradient_accumulation(self, true_batches, # 2. Compute loss. try: loss, batch_stats = self.train_loss(batch, outputs) - # NOTE: DEBUG - # loss_list = onmt.utils.distributed.all_gather_list(loss) - # current_rank = torch.distributed.get_rank() - # print("{}-> RANK: {}, loss:{} in {}".format( - # k, current_rank, loss, loss_list)) - # print("{}-> RANK: {}, stat:{}".format( - # k, current_rank, batch_stats.loss)) - # print(str(loss) + " ~ " +str(loss_list)) if loss is not None: self.optim.backward(loss) total_stats.update(batch_stats) report_stats.update(batch_stats) - # print(str(loss.item())+ " - " + str(report_stats.loss)) - # exit() except Exception: traceback.print_exc() logger.info("At step %d, we removed a batch - accum %d", @@ -568,18 +555,11 @@ def _bert_gradient_accumulation(self, true_batches, grads = [p.grad.data for p in self.model.parameters() if p.requires_grad and p.grad is not None] - # current_rank = torch.distributed.get_rank() - # print("{}-> RANK: {}, grads BEFORE:{}".format( - # k, current_rank, grads[0])) + # NOTE: average the gradient across the GPU onmt.utils.distributed.all_reduce_and_rescale_tensors( grads, float(self.n_gpu)) - # reduced_grads = [p.grad.data for p in - # self.model.parameters() - # if p.requires_grad - # and p.grad is not None] - # print("{}-> RANK: {}, grads AFTER:{}".format( - # k, current_rank, reduced_grads[0])) + self.optim.step() # in case of multi step gradient accumulation, diff --git a/onmt/translate/predictor.py b/onmt/translate/predictor.py index edaa019749..b49c789b25 100644 --- a/onmt/translate/predictor.py +++ b/onmt/translate/predictor.py @@ -12,6 +12,8 @@ def build_classifier(opt, logger=None, out_file=None): + """Return a classifier with result redirect to `out_file`.""" + if out_file is None: out_file = codecs.open(opt.output, 'w+', 'utf-8') @@ -30,6 +32,8 @@ def build_classifier(opt, logger=None, out_file=None): def build_tagger(opt, logger=None, out_file=None): + """Return a tagger with result redirect to `out_file`.""" + if out_file is None: out_file = codecs.open(opt.output, 'w+', 'utf-8') @@ -51,7 +55,7 @@ class Predictor(object): """Predictor a batch of data with a saved model. Args: - model (onmt.modules.Sequential): model to use + model (nn.Sequential): model to use fields (dict[str, torchtext.data.Field]): A dict of field. gpu (int): GPU device. Set to negative for no GPU. data_type (str): Source data type. 
@@ -73,13 +77,6 @@ def __init__( seed=-1): self.model = model self.fields = fields - # tgt_field = dict(self.fields)["tgt"].base_field - # self._tgt_vocab = tgt_field.vocab - # self._tgt_eos_idx = self._tgt_vocab.stoi[tgt_field.eos_token] - # self._tgt_pad_idx = self._tgt_vocab.stoi[tgt_field.pad_token] - # self._tgt_bos_idx = self._tgt_vocab.stoi[tgt_field.init_token] - # self._tgt_unk_idx = self._tgt_vocab.stoi[tgt_field.unk_token] - # self._tgt_vocab_len = len(self._tgt_vocab) self._gpu = gpu self._use_cuda = gpu > -1 @@ -136,7 +133,7 @@ class Classifier(Predictor): """classify a batch of sentences with a saved model. Args: - model (onmt.modules.Sequential): BERT model to use for classify + model (nn.Sequential): BERT model to use for classify fields (dict[str, torchtext.data.Field]): A dict of field. gpu (int): GPU device. Set to negative for no GPU. data_type (str): Source data type. @@ -173,12 +170,11 @@ def classify(self, data, batch_size, tokenizer, """Classify content of ``data``. Args: - data: list of sentences to classify,ex. Sentence1 ||| Sentence2. + data (list of str): ['Sentence1 ||| Sentence2',...]. batch_size (int): size of examples per mini-batch Returns: - * all_predictions is a list of `batch_size` lists - of sentence classification + all_predictions (list of str):[c1, ..., cn]. """ dataset = inputters.ClassifierDataset( @@ -216,16 +212,15 @@ def classify(self, data, batch_size, tokenizer, return all_predictions def classify_batch(self, batch): - """Translate a batch of sentences.""" + """Classify a batch of sentences.""" with torch.no_grad(): - input_ids, seq_lengths = batch.tokens + input_ids, _ = batch.tokens token_type_ids = batch.segment_ids all_encoder_layers, pooled_out = self.model.bert( input_ids, token_type_ids) - seq_class_log_prob, prediction_log_prob = self.model.generator( + seq_class_log_prob, _ = self.model.generator( all_encoder_layers, pooled_out) - # outputs = (seq_class_log_prob, prediction_log_prob) - + # Predicting pred_sents_ids = seq_class_log_prob.argmax(-1).tolist() pred_sents_labels = [self.label_vocab.itos[index] for index in pred_sents_ids] @@ -236,7 +231,7 @@ class Tagger(Predictor): """Tagging a batch of sentences with a saved model. Args: - model (onmt.modules.Sequential): BERT model to use for Tagging + model (nn.Sequential): BERT model to use for Tagging fields (dict[str, torchtext.data.Field]): A dict of field. gpu (int): GPU device. Set to negative for no GPU. data_type (str): Source data type. @@ -275,12 +270,11 @@ def tagging(self, data, batch_size, tokenizer, """Tagging content of ``data``. Args: - data: list of sentences to classify,ex. Sentence1 ||| Sentence2. + data (list of str): ['T1 T2 ... Tn',...]. batch_size (int): size of examples per mini-batch Returns: - * all_predictions is a list of `batch_size` lists - of token taggings + all_predictions (list of list of str): [['L1', ..., 'Ln'],...]. 
""" dataset = inputters.TaggerDataset( self.fields, data, tokenizer, max_seq_len, delimiter) @@ -321,13 +315,13 @@ def tagging_batch(self, batch): """Tagging a batch of sentences.""" with torch.no_grad(): # Batch - input_ids, seq_lengths = batch.tokens + input_ids, _ = batch.tokens token_type_ids = batch.segment_ids taggings = batch.token_labels # Forward all_encoder_layers, pooled_out = self.model.bert( input_ids, token_type_ids) - seq_class_log_prob, prediction_log_prob = self.model.generator( + _, prediction_log_prob = self.model.generator( all_encoder_layers, pooled_out) # Predicting pred_tag_ids = prediction_log_prob.argmax(-1) diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py index 8dcc2ffaee..b9352e3d14 100644 --- a/onmt/utils/__init__.py +++ b/onmt/utils/__init__.py @@ -6,7 +6,7 @@ from onmt.utils.optimizers import MultipleOptimizer, \ Optimizer, AdaFactor, BertAdam from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts -from onmt.utils.fn_activation import get_activation_fn +from onmt.utils.activation_fn import get_activation_fn __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", "build_report_manager", "Statistics", "BertStatistics", diff --git a/onmt/utils/fn_activation.py b/onmt/utils/activation_fn.py similarity index 90% rename from onmt/utils/fn_activation.py rename to onmt/utils/activation_fn.py index ec1d9336a7..d8e18e9d4a 100644 --- a/onmt/utils/fn_activation.py +++ b/onmt/utils/activation_fn.py @@ -4,6 +4,7 @@ def get_activation_fn(activation): + """Return an activation function Module according to its name.""" if activation is 'gelu': fn = GELU() elif activation is 'relu': @@ -22,12 +23,12 @@ def get_activation_fn(activation): """ class GELU(nn.Module): """ Implementation of the gelu activation function + :cite:`DBLP:journals/corr/HendrycksG16` For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - see https://arxiv.org/abs/1606.08415 Examples:: >>> m = GELU() diff --git a/onmt/utils/bert_tokenization.py b/onmt/utils/bert_tokenization.py index 3937d6e011..9762cf9109 100644 --- a/onmt/utils/bert_tokenization.py +++ b/onmt/utils/bert_tokenization.py @@ -34,6 +34,12 @@ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", + 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-base-cased-finetuned-mrpc': 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", } PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 'bert-base-uncased': 512, @@ -43,6 +49,12 @@ 'bert-base-multilingual-uncased': 512, 'bert-base-multilingual-cased': 512, 'bert-base-chinese': 512, + 'bert-base-german-cased': 512, + 'bert-large-uncased-whole-word-masking': 512, + 'bert-large-cased-whole-word-masking': 512, + 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, + 'bert-large-cased-whole-word-masking-finetuned-squad': 512, + 'bert-base-cased-finetuned-mrpc': 512, } VOCAB_NAME = 'vocab.txt' @@ -50,15 +62,11 @@ def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() - index = 0 with open(vocab_file, "r", encoding="utf-8") as reader: - while True: - token = reader.readline() - if not token: - break - token = token.strip() - vocab[token] = index - index += 1 + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip('\n') + vocab[token] = index return vocab diff --git a/onmt/utils/file_utils.py b/onmt/utils/file_utils.py index 17bdd258ea..9abfddeef4 100644 --- a/onmt/utils/file_utils.py +++ b/onmt/utils/file_utils.py @@ -1,5 +1,6 @@ """ Utilities for working with the local dataset cache. +Get from https://github.com/huggingface/pytorch-transformers. This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp Copyright by the AllenNLP authors. """ @@ -35,9 +36,6 @@ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) -CONFIG_NAME = "config.json" -WEIGHTS_NAME = "pytorch_model.bin" - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -59,32 +57,6 @@ def url_to_filename(url, etag=None): return filename -def filename_to_url(filename, cache_dir=None): - """ - Return the url and etag (which may be ``None``) stored for `filename`. - Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. - """ - if cache_dir is None: - cache_dir = PYTORCH_PRETRAINED_BERT_CACHE - if sys.version_info[0] == 3 and isinstance(cache_dir, Path): - cache_dir = str(cache_dir) - - cache_path = os.path.join(cache_dir, filename) - if not os.path.exists(cache_path): - raise EnvironmentError("file {} not found".format(cache_path)) - - meta_path = cache_path + '.json' - if not os.path.exists(meta_path): - raise EnvironmentError("file {} not found".format(meta_path)) - - with open(meta_path, encoding="utf-8") as meta_file: - metadata = json.load(meta_file) - url = metadata['url'] - etag = metadata['etag'] - - return url, etag - - def cached_path(url_or_filename, cache_dir=None): """ Given something that might be a URL (or might be a local path), @@ -250,21 +222,3 @@ def get_from_cache(url, cache_dir=None): logger.info("removing temp file %s", temp_file.name) return cache_path - - -def read_set_from_file(filename): - ''' - Extract a de-duped collection (set) of text from a file. - Expected file format is one item per line. 
- ''' - collection = set() - with open(filename, 'r', encoding='utf-8') as file_: - for line in file_: - collection.add(line.rstrip()) - return collection - - -def get_file_extension(path, dot=True, lower=True): - ext = os.path.splitext(path)[1] - ext = ext if dot else ext[1:] - return ext.lower() if lower else ext diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index a86f14bd6f..c018307ad1 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -24,7 +24,6 @@ def build_loss_compute(model, tgt_field, opt, train=True): """ device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") if opt.is_bert is True: - assert hasattr(model, 'bert') if tgt_field.pad_token is not None: if tgt_field.use_vocab: padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] @@ -70,6 +69,13 @@ def build_loss_compute(model, tgt_field, opt, train=True): class BertLoss(nn.Module): + """Class for managing BERT loss computation which is reduced by mean. + + Args: + criterion (:obj:`nn.NLLLoss`) : module that measures loss + between input and target. + task (str): BERT downstream task. + """ def __init__(self, criterion, task): super(BertLoss, self).__init__() self.criterion = criterion @@ -86,20 +92,21 @@ def _stats(self, loss, tokens_scores, tokens_target, sents_scores, sents_target): """ Args: - loss (:obj:`FloatTensor`): the loss computed by the loss criterion. + loss (:obj:`FloatTensor`): the loss reduced by mean. tokens_scores (:obj:`FloatTensor`): scores for each token tokens_target (:obj:`FloatTensor`): true targets for each token sents_scores (:obj:`FloatTensor`): scores for each sentence sents_target (:obj:`FloatTensor`): true targets for each sentence Returns: - :obj:`onmt.utils.Statistics` : statistics for this batch. + :obj:`onmt.utils.BertStatistics` : statistics for this batch. 
""" if self.task == 'pretraining': # masked lm task: token level pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) - tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) + tokens_match = pred_tokens.eq( + tokens_target).masked_select(non_padding) n_correct_tokens = tokens_match.sum().item() n_tokens = non_padding.sum().item() f1 = 0 @@ -122,7 +129,8 @@ def _stats(self, loss, tokens_scores, tokens_target, # token level task: pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) - tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) + tokens_match = pred_tokens.eq( + tokens_target).masked_select(non_padding) n_correct_tokens = tokens_match.sum().item() n_tokens = non_padding.sum().item() # for f1: @@ -139,7 +147,8 @@ def _stats(self, loss, tokens_scores, tokens_target, # token level task: pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) - tokens_match = pred_tokens.eq(tokens_target).masked_select(non_padding) + tokens_match = pred_tokens.eq( + tokens_target).masked_select(non_padding) n_correct_tokens = tokens_match.sum().item() n_tokens = non_padding.sum().item() f1 = 0 @@ -153,15 +162,19 @@ def _stats(self, loss, tokens_scores, tokens_target, n_correct_tokens, n_sentences, n_correct_sents, f1) - def forward(self, batch, outputs): """ Args: - batch: batch of examples - outputs: tuple of log proba for next sentense & lm - (seq_class_log_prob:(batch, 2), - prediction_log_prob:(batch, seq, vocab)) + batch (Tensor): batch of examples + outputs (tuple of Tensor): (seq_class_log_prob:``(B, 2)``, + prediction_log_prob:``(B, S, vocab)``) + + Returns: + (float, BertStatistics) + * total_loss: total loss of input batch reduced by 'mean'. + * stats: A statistic object. """ + assert isinstance(outputs, tuple) seq_class_log_prob, prediction_log_prob = outputs if self.task == 'pretraining': diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index ee835d496d..5e853e38f7 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -52,7 +52,7 @@ def build_torch_optimizer(model, opt): lr=opt.learning_rate, betas=betas, eps=1e-9) - elif opt.optim == 'bertadam': # TODO:to be verified + elif opt.optim == 'bertadam': optimizer = BertAdam( params, lr=opt.learning_rate, @@ -621,17 +621,18 @@ def step(self, closure=None): class BertAdam(torch.optim.Optimizer): - """Implements BERT version of Adam algorithm with weight decay fix - (while doesn't compensate for bias). - Ref: https://arxiv.org/abs/1711.05101 - Params: - lr: learning rate - betas: Adam betas(beta1, beta2). Default: (0.9, 0.999) - eps: Adams epsilon. Default: 1e-6 - weight_decay: Weight decay. Default: 0.01 - # TODO: exclude LayerNorm from weight decay? - max_grad_norm: Maximum norm for the gradients (-1 means no clipping). - """ # TODO: add parameter to opt + """Implements Adam algorithm with weight decay fix + (used in BERT while doesn't compensate for bias). + :cite:`DBLP:journals/corr/abs-1711-05101` + + Args: + lr (float): learning rate + betas (tuple of float): Adam (beta1, beta2). Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.01 + max_grad_norm (float): -1 means no gradients clipping. 
+ """ + def __init__(self, params, lr=None, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): if not 0.0 <= lr: @@ -653,7 +654,8 @@ def __init__(self, params, lr=None, betas=(0.9, 0.999), def step(self, closure=None): """Performs a single optimization step. - Arguments: + + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ @@ -695,7 +697,7 @@ def step(self, closure=None): exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) update = exp_avg / (exp_avg_sq.sqrt() + group['eps']) - # ref: https://arxiv.org/abs/1711.05101 + # Ref: https://arxiv.org/abs/1711.05101 # Just adding the square of the weights to the loss function # is *not* the correct way of using L2/weight decay with Adam, # since it will interact with m/v parameters in strange ways. diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index f988edf89c..2d090e3dd5 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -76,6 +76,16 @@ def ckpt_model_opts(cls, ckpt_opt): @classmethod def validate_train_opts(cls, opt): + if opt.is_bert: + logger.info("WE ARE IN BERT MODE.") + if opt.task_type is "none": + raise ValueError( + "Downstream task should be chosen when use BERT.") + if opt.reuse_embeddings is True: + if opt.task_type != "pretraining": + opt.reuse_embeddings = False + logger.warning( + "reuse_embeddings not available for this task.") if opt.epochs: raise AssertionError( "-epochs is deprecated please use -train_steps.") @@ -189,14 +199,14 @@ def validate_predict_opts(cls, opt): else: opt.delimiter = ' ' logger.info("NOTICE: opt.delimiter set to `%s`" % opt.delimiter) - assert opt.bert_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ - "Unsupported Pretrain model '%s'" % (opt.bert_model) - if '-cased' in opt.bert_model and opt.do_lower_case is True: + assert opt.vocab_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ + "Unsupported Pretrain model '%s'" % (opt.vocab_model) + if '-cased' in opt.vocab_model and opt.do_lower_case is True: logger.info("WARNING: The pre-trained model you are loading " + "is cased model, you shouldn't set `do_lower_case`," + "we turned it off for you.") opt.do_lower_case = False - elif '-cased' not in opt.bert_model and opt.do_lower_case is False: + elif '-cased' not in opt.vocab_model and opt.do_lower_case is False: logger.info("WARNING: The pre-trained model you are loading " + "is uncased model, you should set `do_lower_case`, " + "we turned it on for you.") diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 4ac26e3c32..b3a2629901 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -137,7 +137,14 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): class BertStatistics(Statistics): - """ Bert Statistics as the loss is reduced by mean """ + """ Bert Statistics as the loss is reduced by mean. 
+ + Currently calculates: + * accuracy in token/sentence level + * perplexity + * elapsed time + * micro f1 for tagging + """ def __init__(self, loss=0, n_words=0, n_correct=0, n_sentence=0, n_correct_sentence=0, f1=0): super(BertStatistics, self).__init__(loss, n_words, n_correct) @@ -173,8 +180,8 @@ def update(self, stat, update_n_src_words=False): Update statistics by suming values with another `Statistics` object Args: - stat: another statistic object - update_n_src_words(bool): whether to update (sum) `n_src_words` + stat (BertStatistics): another statistic object + update_n_src_words (bool): whether to update (sum) `n_src_words` or not """ diff --git a/predict.py b/predict.py index 779e4e0283..8ca5c78fb6 100755 --- a/predict.py +++ b/predict.py @@ -16,7 +16,7 @@ def main(opt): logger = init_logger(opt.log_file) opt = ArgumentParser.validate_predict_opts(opt) tokenizer = BertTokenizer.from_pretrained( - opt.bert_model, do_lower_case=opt.do_lower_case) + opt.vocab_model, do_lower_case=opt.do_lower_case) data_shards = split_corpus(opt.data, opt.shard_size) if opt.task == 'classification': classifier = build_classifier(opt) diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py index 5fc1a205f1..0df5be4373 100755 --- a/pregenerate_bert_training_data.py +++ b/pregenerate_bert_training_data.py @@ -1,6 +1,6 @@ """ This file is lifted from huggingface and adapted for onmt structure. -Ref: https://github.com/huggingface/pytorch-transformers/blob/master/examples/lm_finetuning/pregenerate_training_data.py +Ref in https://github.com/huggingface/pytorch-transformers/. """ from argparse import ArgumentParser from pathlib import Path @@ -8,16 +8,17 @@ from tempfile import TemporaryDirectory import shelve -from random import random, randrange, randint, shuffle, choice, sample -from onmt.utils.bert_tokenization import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP +from random import random, randrange, randint, shuffle, choice +from onmt.utils.bert_tokenization import BertTokenizer, \ + PRETRAINED_VOCAB_ARCHIVE_MAP from onmt.utils.file_utils import cached_path +from preprocess_bert_new import build_vocab_from_tokenizer import numpy as np import json -from onmt.inputters.inputter import get_bert_fields, _build_bert_fields_vocab -from onmt.inputters.dataset_bert import BertDataset -import os -from collections import Counter, defaultdict +from onmt.inputters.inputter import get_bert_fields +from onmt.inputters.dataset_bert import BertDataset, truncate_seq_pair import torch +import collections class DocumentDatabase: @@ -26,8 +27,8 @@ def __init__(self, reduce_memory=False): self.temp_dir = TemporaryDirectory() self.working_dir = Path(self.temp_dir.name) self.document_shelf_filepath = self.working_dir / 'shelf.db' - self.document_shelf = shelve.open(str(self.document_shelf_filepath), - flag='n', protocol=-1) + self.document_shelf = shelve.open( + str(self.document_shelf_filepath), flag='n', protocol=-1) self.documents = None else: self.documents = [] @@ -54,9 +55,10 @@ def _precalculate_doc_weights(self): self.cumsum_max = self.doc_cumsum[-1] def sample_doc(self, current_idx, sentence_weighted=True): - # Uses the current iteration counter to ensure we don't sample the same doc twice + # Uses the current_idx to ensure we don't sample the same doc twice if sentence_weighted: - # With sentence weighting, we sample docs proportionally to their sentence length + # With sentence weighting, we sample docs + # proportionally to their sentence length if self.doc_cumsum is None or 
len(self.doc_cumsum) != len(self.doc_lengths): self._precalculate_doc_weights() rand_start = self.doc_cumsum[current_idx] @@ -64,7 +66,7 @@ def sample_doc(self, current_idx, sentence_weighted=True): sentence_index = randrange(rand_start, rand_end) % self.cumsum_max sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') else: - # If we don't use sentence weighting, then every doc has an equal chance to be chosen + # If sentence weighting is False, chose doc equally sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths) assert sampled_doc_index != current_idx if self.reduce_memory: @@ -91,68 +93,90 @@ def __exit__(self, exc_type, exc_val, traceback): self.temp_dir.cleanup() -def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): - """Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo.""" - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_num_tokens: - break - - trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b - assert len(trunc_tokens) >= 1 - - # We want to sometimes truncate from the front and sometimes from the - # back to add more randomness and avoid biases. - if random() < 0.5: - del trunc_tokens[0] - else: - trunc_tokens.pop() +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) -def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, tokenizer): - """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but - with several refactors to clean it up and remove a lot of unnecessary variables.""" - vocab_dict = tokenizer.vocab +def create_masked_lm_predictions(tokens, masked_lm_prob, + max_predictions_per_seq, + whole_word_mask, vocab_dict): + """Creates the predictions for the masked LM. This is mostly copied from + the Huggingface BERT repo, but pregenerate lm_labels_ids.""" vocab_list = list(vocab_dict.keys()) cand_indices = [] for (i, token) in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue - cand_indices.append(i) + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. When a word has been split into + # WordPieces, the first token does not have any marker and any + # subsequence tokens are prefixed with ##. So whenever we see the ##, + # we append it to the previous set of word indexes. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (whole_word_mask and len(cand_indices) >= 1 + and token.startswith("##")): + cand_indices[-1].append(i) + else: + cand_indices.append([i]) num_to_mask = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) shuffle(cand_indices) - mask_indices = sorted(sample(cand_indices, num_to_mask)) - masked_token_labels = [] - for index in mask_indices: - # 80% of the time, replace with [MASK] - if random() < 0.8: - masked_token = "[MASK]" - else: - # 10% of the time, keep original - if random() < 0.5: - masked_token = tokens[index] - # 10% of the time, replace with random word + masked_lms = [] + covered_indexes = set() + for index_set in cand_indices: + if len(masked_lms) >= num_to_mask: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. 
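A small self-contained illustration (not from the patch) of the whole-word-masking grouping built above: wordpieces prefixed with "##" are appended to the index set of the word they continue, so a word is masked or kept as one unit.

    tokens = ["[CLS]", "play", "##ing", "chess", "[SEP]"]
    cand_indices = []
    for i, token in enumerate(tokens):
        if token in ("[CLS]", "[SEP]"):
            continue
        if cand_indices and token.startswith("##"):   # whole_word_mask enabled
            cand_indices[-1].append(i)                # "##ing" joins "play"
        else:
            cand_indices.append([i])
    assert cand_indices == [[1, 2], [3]]  # masking [1, 2] masks the whole word "playing"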
+ if len(masked_lms) + len(index_set) > num_to_mask: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if random() < 0.8: + masked_token = "[MASK]" else: - masked_token = choice(vocab_list) - masked_token_labels.append(tokens[index]) - # Once we've saved the true label for that token, we can overwrite it with the masked version - tokens[index] = masked_token + # 10% of the time, keep original + if random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = choice(vocab_list) + masked_lms.append(MaskedLmInstance(index=index, + label=tokens[index])) + # Replace true token with masked token + tokens[index] = masked_token + + assert len(masked_lms) <= num_to_mask + masked_lms = sorted(masked_lms, key=lambda x: x.index) + mask_indices = [p.index for p in masked_lms] + masked_token_labels = [p.label for p in masked_lms] lm_labels_ids = [-1 for _ in tokens] for (i, token) in zip(mask_indices, masked_token_labels): lm_labels_ids[i] = vocab_dict[token] assert len(lm_labels_ids) == len(tokens) - return tokens, mask_indices, masked_token_labels, lm_labels_ids + return tokens, mask_indices, masked_token_labels, lm_labels_ids def create_instances_from_document( - doc_database, doc_idx, max_seq_length, short_seq_prob, - masked_lm_prob, max_predictions_per_seq, tokenizer): - """This code is mostly a duplicate of the equivalent function from Google BERT's repo. - However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function. - Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence - (rather than each document) has an equal chance of being sampled as a false example for the NextSentence task.""" + doc_database, doc_idx, vocab_dict, max_seq_length, short_seq_prob, + masked_lm_prob, max_predictions_per_seq, whole_word_mask): + """This code is mostly a duplicate of the equivalent function from + HuggingFace BERT's repo. But we use lm_labels_ids rather than + mask_indices and masked_token_labels.""" document = doc_database[doc_idx] # Account for [CLS], [SEP], [SEP] max_num_tokens = max_seq_length - 3 @@ -171,7 +195,7 @@ def create_instances_from_document( # We DON'T just concatenate all of the tokens from a document into a long # sequence and choose an arbitrary split point because this would make the # next sentence prediction task too easy. Instead, we split the input into - # segments "A" and "B" based on the actual "sentences" provided by the user + # segments "A" and "B" based on the actual "sentences" provided by user's # input. instances = [] current_chunk = [] @@ -183,8 +207,8 @@ def create_instances_from_document( current_length += len(segment) if i == len(document) - 1 or current_length >= target_seq_length: if current_chunk: - # `a_end` is how many segments from `current_chunk` go into the `A` - # (first) sentence. + # `a_end` is how many segments from `current_chunk` go into + # `A` (first) sentence. 
a_end = 1 if len(current_chunk) >= 2: a_end = randrange(1, len(current_chunk)) @@ -200,16 +224,18 @@ def create_instances_from_document( is_next = False target_b_length = target_seq_length - len(tokens_a) - # Sample a random document, with longer docs being sampled more frequently - random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True) + # Sample a random document with longer docs being + # sampled more frequently + random_document = doc_database.sample_doc( + current_idx=doc_idx, sentence_weighted=True) random_start = randrange(0, len(random_document)) for j in range(random_start, len(random_document)): tokens_b.extend(random_document[j]) if len(tokens_b) >= target_b_length: break - # We didn't actually use these segments so we "put them back" so - # they don't go to waste. + # We didn't actually use these segments so we + # "put them back" so they don't go to waste. num_unused_segments = len(current_chunk) - a_end i -= num_unused_segments # Actual next @@ -222,13 +248,14 @@ def create_instances_from_document( assert len(tokens_a) >= 1 assert len(tokens_b) >= 1 - tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] - # The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP] - # They are 1 for the B tokens and the final [SEP] - segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)] + tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + \ + tokens_b + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens_a) + 2)] + \ + [1 for _ in range(len(tokens_b) + 1)] - tokens, masked_lm_positions, masked_lm_labels, lm_labels_ids = create_masked_lm_predictions( - tokens, masked_lm_prob, max_predictions_per_seq, tokenizer) + tokens, _, _, lm_labels_ids = create_masked_lm_predictions( + tokens, masked_lm_prob, max_predictions_per_seq, + whole_word_mask, vocab_dict) instance = { "tokens": tokens, @@ -245,54 +272,9 @@ def create_instances_from_document( return instances -def _build_bert_vocab(vocab, name, counters, min_freq=0): - """ similar to _load_vocab in inputter.py, but build from a vocab list. 
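For reference, a minimal illustration (not part of the patch) of the sentence-pair layout assembled above: segment ids are 0 for [CLS], the A tokens and the first [SEP], and 1 for the B tokens and the final [SEP]; lm_labels_ids stays -1 everywhere except at masked positions, where it holds the vocabulary id of the original token.

    tokens_a = ["the", "cat", "sat"]
    tokens_b = ["on", "the", "mat"]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    assert len(tokens) == len(segment_ids) == 9
    # tokens      : [CLS] the cat sat [SEP] on the mat [SEP]
    # segment_ids :   0    0   0   0    0    1   1   1    1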
- in place change counters - """ - vocab_size = len(vocab) - for i, token in enumerate(vocab): - counters[name][token] = vocab_size - i + min_freq - return vocab, vocab_size - - -def main(): - parser = ArgumentParser() - parser.add_argument('--train_corpus', type=Path, default="/home/lzeng/Documents/OpenNMT-py/onmt/inputters/small_wiki_sentence_corpus.txt") # required=True) - parser.add_argument("--corpus_type", type=str, default="train") # required=True) - parser.add_argument("--output_dir", type=Path, default="/home/lzeng/Documents/OpenNMT-py/onmt/inputters/test_opennmt/") # required=True) - parser.add_argument("--output_name", type=str, default="dataset") - parser.add_argument("--bert_model", type=str, default="bert-base-uncased")#, # required=True, - # choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", - # "bert-base-multilingual", "bert-base-chinese"]) - # parser.add_argument("--vocab_pathname", type=Path, required=True) # vocab file correspand to bert_model - - parser.add_argument("--do_lower_case", default=True) # action="store_true") - - parser.add_argument("--reduce_memory", action="store_true", - help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") - - parser.add_argument("--epochs_to_generate", type=int, default=20, - help="Number of epochs of data to pregenerate") - parser.add_argument("--max_seq_len", type=int, default=256) # 128 - parser.add_argument("--short_seq_prob", type=float, default=0.1, - help="Probability of making a short sentence as a training example") - parser.add_argument("--masked_lm_prob", type=float, default=0.15, - help="Probability of masking each token for the LM task") - parser.add_argument("--max_predictions_per_seq", type=int, default=20, - help="Maximum number of tokens to mask in each sequence") - parser.add_argument("--tokens_min_frequency", type=int, default=0) # not tested - parser.add_argument("--vocab_size_multiple", type=int, default=1) # not tested - - args = parser.parse_args() - fields = get_bert_fields() # - tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) - # save Vocab file - vocab_file_url = PRETRAINED_VOCAB_ARCHIVE_MAP[args.bert_model] - vocab_dir = Path.joinpath(args.output_dir, f"{args.bert_model}-vocab.txt") - vocab_file = cached_path(vocab_file_url, cache_dir=vocab_dir) - print("Donwload ") - with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: - with args.train_corpus.open() as f: +def build_document_database(input_file, tokenizer, reduce_memory): + with DocumentDatabase(reduce_memory=reduce_memory) as docs: + with input_file.open() as f: doc = [] for line in tqdm(f, desc="Loading Dataset", unit=" lines"): line = line.strip() @@ -302,56 +284,137 @@ def main(): else: tokens = tokenizer.tokenize(line) doc.append(tokens) - if doc: - docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added + if len(doc) != 0: # If didn't end on a newline, still add + docs.add_document(doc) if len(docs) <= 1: - exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to " - "ensure that random NextSentences are not sampled from the same document. Please add blank lines to " - "indicate breaks between documents in your input file. 
If your dataset does not contain multiple " - "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, " - "sections or paragraphs.") - - args.output_dir.mkdir(exist_ok=True) - for epoch in trange(args.epochs_to_generate, desc="Epoch"): - # epoch_filename = args.output_dir / f"epoch_{epoch}.json" - epoch_filename = args.output_dir / f"{args.output_name}.{args.corpus_type}.{epoch}.pt" - json_name = args.output_dir / f"{args.output_name}.{args.corpus_type}.{epoch}.json" - num_instances = 0 - with json_name.open('w') as epoch_file: - docs_instances = [] - for doc_idx in trange(len(docs), desc="Document"): - doc_instances = create_instances_from_document( - docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, - masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, - tokenizer=tokenizer) # return a list of dict [{}] - docs_instances.extend(doc_instances) - doc_instances_json = [json.dumps(instance) for instance in doc_instances] - for instance in doc_instances_json: - epoch_file.write(instance + '\n') - num_instances += 1 - # build BertDataset from instances collected from different document - dataset = BertDataset(fields, docs_instances) - dataset.save(epoch_filename) - num_doc_instances = len(docs_instances) - print("output file {}, num_example {}, max_seq_len {}".format(epoch_filename,num_doc_instances,args.max_seq_len)) - - metrics_file = args.output_dir / f"{args.output_name}.metrics.{args.corpus_type}.{epoch}.json" + exit("""ERROR: No document breaks were found in the input file! + These are necessary to ensure that random NextSentences + are not sampled from the same document. Please add blank + lines to indicate breaks between documents in your file. 
+ If your dataset does not contain multiple documents, + blank lines can be inserted at any natural boundary, + such as the ends of chapters, sections or paragraphs.""") + return docs + + +def create_instances_from_docs(doc_database, vocab_dict, args): + docs_instances = [] + for doc_idx in trange(len(doc_database), desc="Document"): + doc_instances = create_instances_from_document( + doc_database, doc_idx, vocab_dict=vocab_dict, + max_seq_length=args.max_seq_len, + short_seq_prob=args.short_seq_prob, + masked_lm_prob=args.masked_lm_prob, + max_predictions_per_seq=args.max_predictions_per_seq, + whole_word_mask=args.do_whole_word_mask) + docs_instances.extend(doc_instances) + return docs_instances + + +def save_data_as_json(instances, json_name): + instances_json = [json.dumps(instance) for instance in instances] + num_instances = 0 + with open(json_name, 'w') as json_file: + for instance in instances_json: + json_file.write(instance + '\n') + num_instances += 1 + return num_instances + + +def _get_parser(): + parser = ArgumentParser() + parser.add_argument('--input_file', type=Path, required=True) + parser.add_argument("--output_dir", type=Path, required=True) + parser.add_argument("--output_name", type=str, default="dataset") + parser.add_argument('--corpus_type', type=str, default="train", + choices=['train', 'valid'], + help="Choose from ['train', 'valid'], " + + "Vocab file will be generate if `train`") + parser.add_argument("--vocab_model", type=str, required=True, + choices=["bert-base-uncased", "bert-large-uncased", + "bert-base-cased", "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-base-cased-finetuned-mrpc"], + help="Pretrained vocab model use to tokenizer text.") + + parser.add_argument("--do_lower_case", action="store_true") + parser.add_argument("--do_whole_word_mask", action="store_true", + help="Whether to use whole word masking.") + parser.add_argument("--reduce_memory", action="store_true", + help="""Reduce memory usage for large datasets + by keeping data on disc rather than in memory""") + + parser.add_argument("--epochs_to_generate", type=int, default=2, + help="Number of epochs of data to pregenerate") + parser.add_argument("--max_seq_len", type=int, default=128) + parser.add_argument("--short_seq_prob", type=float, default=0.1, + help="Prob. of a short sentence as training example") + parser.add_argument("--masked_lm_prob", type=float, default=0.15, + help="Prob. 
of masking each token for the LM task") + parser.add_argument("--max_predictions_per_seq", type=int, default=20, + help="Max number of tokens to mask in each sequence") + parser.add_argument("--save_json", action="store_true", + help='save a copy of data in json form.') + return parser + + +def main(args): + tokenizer = BertTokenizer.from_pretrained( + args.vocab_model, do_lower_case=args.do_lower_case) + + docs = build_document_database( + args.input_file, tokenizer, args.reduce_memory) + + fields = get_bert_fields() + vocab_dict = tokenizer.vocab + args.output_dir.mkdir(exist_ok=True) + + # Build file corpus.pt + for epoch in trange(args.epochs_to_generate, desc="Epoch"): + docs_instances = create_instances_from_docs(docs, vocab_dict, args) + + # build BertDataset from instances collected from different document + dataset = BertDataset(fields, docs_instances) + epoch_filename = args.output_dir / "{}.{}.{}.pt".format( + args.output_name, args.corpus_type, epoch) + dataset.save(epoch_filename) + print("output file {}, num_example {}, max_seq_len {}".format( + epoch_filename, len(docs_instances), args.max_seq_len)) + + if args.save_json: + json_name = args.output_dir / "{}.{}.{}.json".format( + args.output_name, args.corpus_type, epoch) + num_instances = save_data_as_json(docs_instances, json_name) + metrics_file = args.output_dir / "{}.{}.{}.metrics.json".format( + args.output_name, args.corpus_type, epoch) with metrics_file.open('w') as metrics_file: metrics = { "num_training_examples": num_instances, "max_seq_len": args.max_seq_len } metrics_file.write(json.dumps(metrics)) - # Build file Vocab.pt + + # Build file Vocab.pt if args.corpus_type == "train": - print("Building vocab from text file...") - vocab_list = list(tokenizer.vocab.keys()) - counters = defaultdict(Counter) - _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) - fields = _build_bert_fields_vocab(fields, counters, vocab_size, None, args.tokens_min_frequency, args.vocab_size_multiple) # - bert_vocab_file = args.output_dir / f"{args.output_name}.vocab.pt" - torch.save(fields, bert_vocab_file) + vocab_file_url = PRETRAINED_VOCAB_ARCHIVE_MAP[args.vocab_model] + vocab_dir = Path.joinpath(args.output_dir, + "%s-vocab.txt" % (args.vocab_model)) + cached_vocab = cached_path(vocab_file_url, cache_dir=vocab_dir) + print("Vocab file is Cached at %s." 
% cached_vocab) + fields_vocab = build_vocab_from_tokenizer( + fields, tokenizer, None) + bert_vocab_file = Path.joinpath(args.output_dir, + "%s.vocab.pt" % (args.output_name)) + print("Build Fields Vocab file.") + torch.save(fields_vocab, bert_vocab_file) if __name__ == '__main__': - main() + parser = _get_parser() + args = parser.parse_args() + main(args) diff --git a/preprocess_bert.py b/preprocess_bert.py deleted file mode 100755 index 57bd211033..0000000000 --- a/preprocess_bert.py +++ /dev/null @@ -1,389 +0,0 @@ -from argparse import ArgumentParser -from tqdm import tqdm -import csv -from random import shuffle -from onmt.utils.bert_tokenization import BertTokenizer, \ - PRETRAINED_VOCAB_ARCHIVE_MAP -import json -from onmt.inputters.inputter import get_bert_fields, \ - _build_bert_fields_vocab -from onmt.inputters.dataset_bert import BertDataset, \ - create_sentence_instance, create_sentence_pair_instance -from collections import Counter, defaultdict -import torch -import os -import codecs - - -def create_instances_from_csv(records, skip_head, tokenizer, max_seq_length, - column_a, column_b, label_column, labels): - instances = [] - for _i, record in tqdm(enumerate(records), desc="Process", unit=" lines"): - if _i == 0 and skip_head: - continue - else: - sentence_a = record[column_a].strip() - if column_b is None: - tokens_processed, segment_ids = create_sentence_instance( - sentence_a, tokenizer, max_seq_length) - else: - sentence_b = record[column_b].strip() - tokens_processed, segment_ids = create_sentence_pair_instance( - sentence_a, sentence_b, tokenizer, max_seq_length) - - label = record[label_column].strip() - if label not in labels: - labels.append(label) - instance = { - "tokens": tokens_processed, - "segment_ids": segment_ids, - "category": label} - instances.append(instance) - return instances, labels - - -def build_instances_from_csv(data, skip_head, tokenizer, input_columns, - label_column, labels, max_seq_len, do_shuffle): - with open(data, "r", encoding="utf-8-sig") as csvfile: - reader = csv.reader(csvfile, delimiter='\t', quotechar=None) - lines = list(reader) - print("total {} line loaded: ".format(len(lines))) - if len(input_columns) == 1: - column_a = int(input_columns[0]) - column_b = None - else: - column_a = int(input_columns[0]) - column_b = int(input_columns[1]) - instances, labels = create_instances_from_csv( - lines, skip_head, tokenizer, max_seq_len, - column_a, column_b, label_column, labels) - if do_shuffle is True: - print("Shuffle all {} instance".format(len(instances))) - shuffle(instances) - return instances, labels - - -def create_instances_from_file(records, label, tokenizer, max_seq_length): - instances = [] - for _i, record in tqdm(enumerate(records), desc="Process", unit=" lines"): - sentence = record.strip() - tokens_processed, segment_ids = create_sentence_instance( - sentence, tokenizer, max_seq_length, random_trunc=True) - instance = { - "tokens": tokens_processed, - "segment_ids": segment_ids, - "category": label} - instances.append(instance) - return instances - - -def build_instances_from_files(data, labels, tokenizer, - max_seq_len, do_shuffle): - instances = [] - for filename in data: - label = filename.split('/')[-2] - with codecs.open(filename, "r", encoding="utf-8") as f: - lines = f.readlines() - print("total {} line of File {} loaded for label: {}.".format( - len(lines), filename, label)) - file_instances = create_instances_from_file( - lines, label, tokenizer, max_seq_len) - instances.extend(file_instances) - if do_shuffle is 
True: - print("Shuffle all {} instance".format(len(instances))) - shuffle(instances) - return instances - - -def create_tag_instance_from_sentence(token_pairs, tokenizer, max_seq_len, - pad_tok): - """ - token_pairs: list of (word, tag) pair that form a sentence - tokenizer: tokenizer we use to tokenizer the words in token_pairs - max_seq_len: max sequence length that a instance could contain - """ - sentence = [] - tags = [] - max_num_tokens = max_seq_len - 2 - for (word, tag) in token_pairs: - tokens = tokenizer.tokenize(word) - n_pad = len(tokens) - 1 - paded_tag = [tag] + [pad_tok] * n_pad - if len(sentence) + len(tokens) > max_num_tokens: - break - else: - sentence.extend(tokens) - tags.extend(paded_tag) - sentence = ["[CLS]"] + sentence + ["[SEP]"] - tags = [pad_tok] + tags + [pad_tok] - segment_ids = [0 for _ in range(len(sentence))] - instance = { - "tokens": sentence, - "segment_ids": segment_ids, - "token_labels": tags - } - return instance - - -def build_tag_instances_from_file(filename, skip_head, tokenizer, max_seq_len, - token_column, tag_column, tags, do_shuffle, - pad_tok, delimiter=' '): - sentences = [] - labels = [] if tags is None else tags - with codecs.open(filename, "r", encoding="utf-8") as f: - lines = f.readlines() - if skip_head is True: - lines = lines[1:] - print("total {} line of file {} loaded.".format( - len(lines), filename)) - sentence_sofar = [] - for line in tqdm(lines, desc="Process", unit=" lines"): - line = line.strip() - if line is '': - if len(sentence_sofar) > 0: - sentences.append(sentence_sofar) - sentence_sofar = [] - else: - elements = line.split(delimiter) - token = elements[token_column] - tag = elements[tag_column] - if tag not in labels: - labels.append(tag) - sentence_sofar.append((token, tag)) - print("total {} sentence loaded.".format(len(sentences))) - print("All tags:", labels) - - instances = [] - for sentence in sentences: - instance = create_tag_instance_from_sentence( - sentence, tokenizer, max_seq_len, pad_tok) - instances.append(instance) - - if do_shuffle is True: - print("Shuffle all {} instance".format(len(instances))) - shuffle(instances) - return instances, labels - - -def _build_bert_vocab(vocab, name, counters): - """ similar to _load_vocab in inputter.py, but build from a vocab list. 
- in place change counters - """ - vocab_size = len(vocab) - for i, token in enumerate(vocab): - counters[name][token] = vocab_size - i - return vocab, vocab_size - - -def build_vocab_from_tokenizer(fields, tokenizer, named_labels, - tokens_min_frequency=0, vocab_size_multiple=1): - vocab_list = list(tokenizer.vocab.keys()) - counters = defaultdict(Counter) - _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) - - if named_labels is not None: - label_name, label_list = named_labels - _, _ = _build_bert_vocab(label_list, label_name, counters) - else: - label_name = None - - fields_vocab = _build_bert_fields_vocab(fields, counters, vocab_size, - label_name, tokens_min_frequency, - vocab_size_multiple) - return fields_vocab - - -def save_data_as_json(instances, json_name): - instances_json = [json.dumps(instance) for instance in instances] - num_instances = 0 - with open(json_name, 'w') as json_file: - for instance in instances_json: - json_file.write(instance + '\n') - num_instances += 1 - return num_instances - - -def validate_preprocess_bert_opts(opts): - assert opts.bert_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ - "Unsupported Pretrain model '%s'" % (opts.bert_model) - for filename in opts.data: - assert os.path.isfile(filename),\ - "Please check path of %s" % filename - - if args.task == "tagging": - assert args.data_type == 'txt',\ - "For sequence tagging, only txt file is supported." - - assert len(opts.input_columns) == 1,\ - "For sequence tagging, only one column for input tokens." - opts.input_columns = opts.input_columns[0] - - assert args.label_column is not None,\ - "For sequence tagging, label column should be given." - - if opts.data_type == "csv": - assert len(opts.data) == 1,\ - "For csv, only one file is needed." - assert len(opts.input_columns) in [1, 2],\ - "Please indicate N.colomn for sentence A (and B)" - assert args.label_column is not None,\ - "For csv file, label column should be given." - - # if opts.label_column is not None: - # assert len(opts.labels) != 0,\ - # "label list is needed when csv contain label column" - - # elif opts.data_type == "txt": - # if opts.task == "classification": - # assert len(opts.datas) == len(opts.labels), \ - # "Label should correspond to input files" - return opts - - -def _get_parser(): - parser = ArgumentParser(description='preprocess_bert.py') - - parser.add_argument('--data', type=str, nargs='+', default=[], - required=True, help="input datas to prepare: [CLS]" + - "Single file for csv with column indicate label," + - "One file for each class as path/label/file; [TAG]" + - "Single file contain (tok, tag) in each line," + - "Sentence separated by blank line.") - parser.add_argument('--data_type', type=str, default="csv", - choices=["csv", "txt"], - help="input data type") - parser.add_argument('--skip_head', action="store_true", - help="CSV: If csv file contain head line.") - - parser.add_argument('--input_columns', type=int, nargs='+', default=[], - help="CSV: Column where contain sentence A(,B)") - parser.add_argument('--label_column', type=int, default=None, - help="CSV: Column where contain label") - parser.add_argument('--labels', type=str, nargs='+', default=[], - help="Candidate labels. 
If not given, build from " + - "input file and sort in alphabetic order.") - parser.add_argument('--delimiter', '-d', type=str, default=' ', - help="CSV: delimiter used for seperate column.") - - parser.add_argument('--task', type=str, default="classification", - choices=["classification", "tagging"], - help="Target task to perform") - parser.add_argument("--corpus_type", type=str, default="train", - choices=["train", "valid", "test"]) - parser.add_argument('--save_data', '-save_data', type=str, - default=None, required=True, - help="Output file Prefix for the prepared data") - parser.add_argument("--do_shuffle", action="store_true", - help='shuffle data') - parser.add_argument("--bert_model", type=str, - default="bert-base-multilingual-uncased", - choices=["bert-base-uncased", "bert-large-uncased", - "bert-base-cased", "bert-large-cased", - "bert-base-multilingual-uncased", - "bert-base-multilingual-cased", - "bert-base-chinese"], - help="Bert pretrained model to finetuning with.") - - parser.add_argument("--do_lower_case", action="store_true", - help='lowercase data') - parser.add_argument("--max_seq_len", type=int, default=512, - help="max sequence length for prepared data," - "set the limite of position encoding") - parser.add_argument("--tokens_min_frequency", type=int, default=0) - parser.add_argument("--vocab_size_multiple", type=int, default=1) - parser.add_argument("--save_json", action="store_true", - help='save a copy of data in json form.') - return parser - - -def main(args): - print("Task: '%s', model: '%s', corpus: '%s'." - % (args.task, args.bert_model, args.corpus_type)) - - fields = get_bert_fields(args.task) - tokenizer = BertTokenizer.from_pretrained( - args.bert_model, do_lower_case=args.do_lower_case) - - if args.task == "classification": - # Build instances from csv file - if args.data_type == 'csv': - filename = args.data[0] - print("Load data file %s with skip head %s" % ( - filename, args.skip_head)) - input_columns = args.input_columns - label_column = args.label_column - print("Input column at {}, label at [{}]".format( - input_columns, label_column)) - instances, labels = build_instances_from_csv( - filename, args.skip_head, tokenizer, - input_columns, label_column, - args.labels, args.max_seq_len, args.do_shuffle) - labels.sort() - args.labels = labels - print("Labels:", args.labels) - elif args.data_type == 'txt': - if len(args.labels) == 0: - print("Build labels from file dir...") - labels = [] - for filename in args.data: - label = filename.split('/')[-2] - if label not in labels: - labels.append(label) - labels.sort() - args.labels = labels - print("Labels:", args.labels) - instances = build_instances_from_files( - args.data, args.labels, tokenizer, - args.max_seq_len, args.do_shuffle) - else: - raise NotImplementedError("Not support other file type yet!") - - if args.task == "tagging": - pad_tok = fields["token_labels"].pad_token # "[PAD]" for Bert Paddings - filename = args.data[0] - print("Load data file %s with skip head %s" % ( - filename, args.skip_head)) - token_column, tag_column = args.input_columns, args.label_column - instances, labels = build_tag_instances_from_file( - filename, args.skip_head, tokenizer, args.max_seq_len, - token_column, tag_column, args.labels, args.do_shuffle, - pad_tok, delimiter=args.delimiter) - labels.sort() - args.labels = [pad_tok] + labels - print("Labels:", args.labels) - - # Save processed data in OpenNMT format - onmt_filename = args.save_data + ".{}.0.pt".format(args.corpus_type) - # Build BertDataset from 
instances collected from different document - dataset = BertDataset(fields, instances) - dataset.save(onmt_filename) - print("save processed data {}, num_example {}, max_seq_len {}".format( - onmt_filename, len(instances), args.max_seq_len)) - - if args.save_json: - json_name = args.save_data + ".{}.json".format(args.corpus_type) - num_instances = save_data_as_json(instances, json_name) - print("output file {}, num_example {}, max_seq_len {}".format( - json_name, num_instances, args.max_seq_len)) - - # Build file Vocab.pt from tokenizer - if args.corpus_type == "train": - print("Generating vocab from corresponding text file...") - if args.task == "classification": - named_labels = ("category", args.labels) - if args.task == "tagging": - named_labels = ("token_labels", args.labels) - print("Save Labels:", named_labels, "in vocab file.") - - fields_vocab = build_vocab_from_tokenizer( - fields, tokenizer, named_labels, - args.tokens_min_frequency, args.vocab_size_multiple) - bert_vocab_file = args.save_data + ".vocab.pt" - torch.save(fields_vocab, bert_vocab_file) - - -if __name__ == '__main__': - parser = _get_parser() - args = parser.parse_args() - args = validate_preprocess_bert_opts(args) - main(args) diff --git a/preprocess_bert_new.py b/preprocess_bert_new.py index c521edecca..0505b584df 100755 --- a/preprocess_bert_new.py +++ b/preprocess_bert_new.py @@ -54,9 +54,12 @@ def build_vocab_from_tokenizer(fields, tokenizer, named_labels): counters = defaultdict(Counter) _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) - label_name, label_list = named_labels - logger.info("Building label vocab {}...".format(named_labels)) - _, _ = _build_bert_vocab(label_list, label_name, counters) + if named_labels is not None: + label_name, label_list = named_labels + logger.info("Building label vocab {}...".format(named_labels)) + _, _ = _build_bert_vocab(label_list, label_name, counters) + else: + label_name = None fields_vocab = _build_bert_fields_vocab(fields, counters, vocab_size, label_name) From 6c5ec3a423ee79fe0e07eac5e25e6d9c60c924ed Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Mon, 26 Aug 2019 16:37:32 +0200 Subject: [PATCH 16/28] Fix flake8 --- onmt/encoders/__init__.py | 2 +- onmt/model_builder.py | 8 +- onmt/train_single.py | 3 +- onmt/utils/__init__.py | 5 +- onmt/utils/activation_fn.py | 4 - onmt/utils/bert_tokenization.py | 144 ++++++++++++++------------- onmt/utils/bert_vocab_archive_map.py | 18 ++++ onmt/utils/file_utils.py | 50 ++++++---- onmt/utils/loss.py | 9 +- onmt/utils/parse.py | 2 +- pregenerate_bert_training_data.py | 16 +-- 11 files changed, 152 insertions(+), 109 deletions(-) create mode 100644 onmt/utils/bert_vocab_archive_map.py diff --git a/onmt/encoders/__init__.py b/onmt/encoders/__init__.py index e5e9a58bc5..fc054b8208 100644 --- a/onmt/encoders/__init__.py +++ b/onmt/encoders/__init__.py @@ -14,4 +14,4 @@ "audio": AudioEncoder, "mean": MeanEncoder, "bert": BertEncoder} __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", - "MeanEncoder", "str2enc", "BertEncoder"] + "MeanEncoder", "str2enc", "BertEncoder", "BertLayerNorm"] diff --git a/onmt/model_builder.py b/onmt/model_builder.py index cfa98e6085..df13264e3d 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -268,7 +268,7 @@ def build_bert_generator(model_opt, fields, bert_encoder): Both all_encoder_layers and pooled_output will be feed to generator, pretraining task will use the two, while only pooled_output will be used for classification generator; - only 
all_encoder_layers will be used for generation generator; + only all_encoder_layers will be used for generation generator """ task = model_opt.task_type dropout = model_opt.dropout[0] if type(model_opt.dropout) is list \ @@ -345,9 +345,11 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): logger.info("Load Model Parameters...") model.bert.load_state_dict(checkpoint['model'], strict=True) model_init['bert'] = True - if model.generator.state_dict().keys() == checkpoint['generator'].keys(): + if (model.generator.state_dict().keys() == + checkpoint['generator'].keys()): logger.info("Load generator Parameters...") - model.generator.load_state_dict(checkpoint['generator'], strict=True) + model.generator.load_state_dict(checkpoint['generator'], + strict=True) model_init['generator'] = True for sub_module, is_init in model_init.items(): diff --git a/onmt/train_single.py b/onmt/train_single.py index 2cb790db9b..b4c36bb9da 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -59,7 +59,8 @@ def main(opt, device_id, batch_queue=None, semaphore=None): model_opt = opt if 'vocab' in checkpoint: - logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) + logger.info('Loading vocab from checkpoint at %s.', + opt.train_from) vocab = checkpoint['vocab'] else: vocab = torch.load(opt.data + '.vocab.pt') diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py index b9352e3d14..628ec83ca6 100644 --- a/onmt/utils/__init__.py +++ b/onmt/utils/__init__.py @@ -7,8 +7,11 @@ Optimizer, AdaFactor, BertAdam from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts from onmt.utils.activation_fn import get_activation_fn +from onmt.utils.bert_tokenization import BertTokenizer +from onmt.utils.bert_vocab_archive_map import PRETRAINED_VOCAB_ARCHIVE_MAP __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", "build_report_manager", "Statistics", "BertStatistics", "MultipleOptimizer", "Optimizer", "AdaFactor", "BertAdam", - "EarlyStopping", "scorers_from_opts", "get_activation_fn"] + "EarlyStopping", "scorers_from_opts", "get_activation_fn", + "BertTokenizer", "PRETRAINED_VOCAB_ARCHIVE_MAP"] diff --git a/onmt/utils/activation_fn.py b/onmt/utils/activation_fn.py index d8e18e9d4a..705f730d54 100644 --- a/onmt/utils/activation_fn.py +++ b/onmt/utils/activation_fn.py @@ -17,10 +17,6 @@ def get_activation_fn(activation): return fn -""" -Adapted from huggingface implementation to reproduce the result -https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py -""" class GELU(nn.Module): """ Implementation of the gelu activation function :cite:`DBLP:journals/corr/HendrycksG16` diff --git a/onmt/utils/bert_tokenization.py b/onmt/utils/bert_tokenization.py index 9762cf9109..7dda9cbb63 100644 --- a/onmt/utils/bert_tokenization.py +++ b/onmt/utils/bert_tokenization.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,8 @@ # limitations under the License. 
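As a usage reference for the tokenizer reworked below, the rest of this series drives it roughly as follows (a minimal sketch; the model name is one illustrative choice from PRETRAINED_VOCAB_ARCHIVE_MAP):

    from onmt.utils.bert_tokenization import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    tokens = tokenizer.tokenize("OpenNMT meets BERT.")   # basic + wordpiece tokenization
    ids = tokenizer.convert_tokens_to_ids(tokens)        # wordpieces -> vocab indices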
"""Tokenization classes.""" -from __future__ import absolute_import, division, print_function, unicode_literals +from __future__ import absolute_import, division, \ + print_function, unicode_literals import collections import logging @@ -23,24 +24,10 @@ from io import open from .file_utils import cached_path +from onmt.utils.bert_vocab_archive_map import PRETRAINED_VOCAB_ARCHIVE_MAP logger = logging.getLogger(__name__) -PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", - 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", - 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", - 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", - 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", - 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", - 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", -} PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 'bert-base-uncased': 512, 'bert-large-uncased': 512, @@ -82,33 +69,37 @@ def whitespace_tokenize(text): class BertTokenizer(object): """Runs end-to-end tokenization: punctuation splitting + wordpiece""" - def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + def __init__(self, vocab_file, do_lower_case=True, max_len=None, + do_basic_tokenize=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): """Constructs a BertTokenizer. Args: - vocab_file: Path to a one-wordpiece-per-line vocabulary file - do_lower_case: Whether to lower case the input - Only has an effect when do_wordpiece_only=False - do_basic_tokenize: Whether to do basic tokenization before wordpiece. - max_len: An artificial maximum length to truncate tokenized sequences to; - Effective maximum length is always the minimum of this - value (if specified) and the underlying BERT model's - sequence length. - never_split: List of tokens which will never be split during tokenization. - Only has an effect when do_wordpiece_only=False + vocab_file (str): Path to a one-wordpiece-per-line vocabulary file + do_lower_case (bool): If to lower case the input, Only has + an effect when do_wordpiece_only=False + do_basic_tokenize (bool): If to do basic tokenization before WP. 
+ max_len (int): Maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of + this value (if specified) and the underlying BERT + model's sequence length. + never_split (list): List of tokens which will never be split during + tokenization. Only has an effect when + do_wordpiece_only=False. """ if not os.path.isfile(vocab_file): raise ValueError( - "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " - "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + "Can't find a vocabulary file at path '{}'. " + "To load the vocabulary from a Google pretrained model use " + "`tokenizer = BertTokenizer.from_pretrained(" + "PRETRAINED_MODEL_NAME)`".format(vocab_file)) self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict( [(ids, tok) for tok, ids in self.vocab.items()]) self.do_basic_tokenize = do_basic_tokenize if do_basic_tokenize: - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, - never_split=never_split) + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.max_len = max_len if max_len is not None else int(1e12) @@ -129,9 +120,10 @@ def convert_tokens_to_ids(self, tokens): ids.append(self.vocab[token]) if len(ids) > self.max_len: logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this BERT model ({} > {}). Running this" - " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + "Token indices sequence length is longer than the specified " + "maximum sequence length for this BERT model ({} > {}). " + "Running this sequence through BERT will result in " + "indexing errors".format(len(ids), self.max_len) ) return ids @@ -148,32 +140,43 @@ def save_vocabulary(self, vocab_path): if os.path.isdir(vocab_path): vocab_file = os.path.join(vocab_path, VOCAB_NAME) with open(vocab_file, "w", encoding="utf-8") as writer: - for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + for token, token_index in sorted(self.vocab.items(), + key=lambda kv: kv[1]): if index != token_index: - logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." - " Please check that the vocabulary is not corrupted!".format(vocab_file)) + logger.warning("Saving vocabulary to {}: vocabulary " + "indices are not consecutive. Please " + "check that the vocabulary is not " + "corrupted!".format(vocab_file)) index = token_index writer.write(token + u'\n') index += 1 return vocab_file @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, + *inputs, **kwargs): """ Instantiate a PreTrainedBertModel from a pre-trained model file. Download and cache the pre-trained model file if needed. """ if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] - if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is a cased model but you have not set " - "`do_lower_case` to False. 
We are setting `do_lower_case=False` for you but " + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ + pretrained_model_name_or_path] + if ('-cased' in pretrained_model_name_or_path + and kwargs.get('do_lower_case', True)): + logger.warning("The pre-trained model you are loading is " + "a cased model but you have not set " + "`do_lower_case` to False. We are setting " + "`do_lower_case=False` for you but " "you may want to check this behavior.") kwargs['do_lower_case'] = False - elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is an uncased model but you have set " - "`do_lower_case` to False. We are setting `do_lower_case=True` for you " - "but you may want to check this behavior.") + elif ('-cased' not in pretrained_model_name_or_path + and not kwargs.get('do_lower_case', True)): + logger.warning("The pre-trained model you are loading is " + "a uncased model but you have set " + "`do_lower_case` to False. We are setting " + "`do_lower_case=True` for you but " + "you may want to check this behavior.") kwargs['do_lower_case'] = True else: vocab_file = pretrained_model_name_or_path @@ -185,8 +188,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, except EnvironmentError: logger.error( "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( + "We assumed '{}' was a path or url but couldn't find any file" + " associated to this path or url.".format( pretrained_model_name_or_path, ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), vocab_file)) @@ -196,10 +199,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, else: logger.info("loading vocabulary file {} from cache at {}".format( vocab_file, resolved_vocab_file)) - if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer - # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + if (pretrained_model_name_or_path + in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP): + # if we're using a pretrained model, ensure the tokenizer wont + # index sequences longer than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ + pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) @@ -223,12 +228,12 @@ def __init__(self, def tokenize(self, text): """Tokenizes a piece of text.""" text = self._clean_text(text) - # This was added on November 1st, 2018 for the multilingual and Chinese - # models. This is also applied to the English models now, but it doesn't - # matter since the English models were not trained on any Chinese data - # and generally don't have any Chinese data in them (there are Chinese - # characters in the vocabulary because Wikipedia does have some Chinese - # words in the English Wikipedia.). + # This was added on Nov. 1st, 2018 for the multilingual and Chinese + # models. 
This is also applied to the English models now, but it does + # not matter since the English models were not trained on any Chinese + # data and generally don't have any Chinese data in them (there are + # Chinese characters in the vocabulary because Wikipedia does have + # some Chinese words in the English Wikipedia.). text = self._tokenize_chinese_chars(text) orig_tokens = whitespace_tokenize(text) split_tokens = [] @@ -289,14 +294,14 @@ def _tokenize_chinese_chars(self, text): def _is_chinese_char(self, cp): """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # This defines a "Chinese character" as anything in CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. + # Note that CJK Unicode block is NOT all Japanese and Korean chars, + # despite its name. Modern Korean Hangul alphabet is a different block + # as is Japanese Hiragana and Katakana. Those alphabets are used to + # write space-separated words, so they are not treated specially and + # handled like the all of the other languages. if ((cp >= 0x4E00 and cp <= 0x9FFF) or # (cp >= 0x3400 and cp <= 0x4DBF) or # (cp >= 0x20000 and cp <= 0x2A6DF) or # @@ -310,7 +315,8 @@ def _is_chinese_char(self, cp): return False def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" + """Performs invalid character removal + and whitespace cleanup on text.""" output = [] for char in text: cp = ord(char) @@ -334,8 +340,8 @@ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): def tokenize(self, text): """Tokenizes a piece of text into its word pieces. - This uses a greedy longest-match-first algorithm to perform tokenization - using the given vocabulary. + This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. 
For example: input = "unaffable" diff --git a/onmt/utils/bert_vocab_archive_map.py b/onmt/utils/bert_vocab_archive_map.py new file mode 100644 index 0000000000..9987a2edb9 --- /dev/null +++ b/onmt/utils/bert_vocab_archive_map.py @@ -0,0 +1,18 @@ +# coding=utf-8 +# flake8: noqa + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", + 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", +} \ No newline at end of file diff --git a/onmt/utils/file_utils.py b/onmt/utils/file_utils.py index 9abfddeef4..caddaa819d 100644 --- a/onmt/utils/file_utils.py +++ b/onmt/utils/file_utils.py @@ -1,10 +1,12 @@ """ Utilities for working with the local dataset cache. Get from https://github.com/huggingface/pytorch-transformers. -This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +This file is adapted from the AllenNLP library +at https://github.com/allenai/allennlp Copyright by the AllenNLP authors. 
""" -from __future__ import (absolute_import, division, print_function, unicode_literals) +from __future__ import absolute_import, division, \ + print_function, unicode_literals import sys import json @@ -15,12 +17,11 @@ import fnmatch from functools import wraps from hashlib import sha256 -import sys from io import open import boto3 -import requests from botocore.exceptions import ClientError +import requests from tqdm import tqdm try: @@ -30,11 +31,13 @@ try: from pathlib import Path - PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', - Path.home() / '.pytorch_pretrained_bert')) + PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv( + 'PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) except (AttributeError, ImportError): - PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', - os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( + 'PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -84,7 +87,8 @@ def cached_path(url_or_filename, cache_dir=None): raise EnvironmentError("file {} not found".format(url_or_filename)) else: # Something unknown - raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + raise ValueError( + "unable to parse {} as a URL/local path".format(url_or_filename)) def split_s3_path(url): @@ -142,7 +146,7 @@ def http_get(url, temp_file): total = int(content_length) if content_length is not None else None progress = tqdm(unit="B", total=total) for chunk in req.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks + if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) temp_file.write(chunk) progress.close() @@ -184,16 +188,19 @@ def get_from_cache(url, cache_dir=None): # If we don't have a connection (etag is None) and can't identify the file # try to get the last downloaded one if not os.path.exists(cache_path) and etag is None: - matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') - matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) + matching_files = fnmatch.filter(os.listdir(cache_dir), + filename + '.*') + matching_files = list( + filter(lambda s: not s.endswith('.json'), matching_files)) if matching_files: cache_path = os.path.join(cache_dir, matching_files[-1]) if not os.path.exists(cache_path): # Download to temporary file, then copy to cache dir once finished. - # Otherwise you get corrupt cache entries if the download gets interrupted. + # Or you get corrupt cache entries if the download gets interrupted. 
with tempfile.NamedTemporaryFile() as temp_file: - logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + logger.info("%s not found in cache, downloading to %s", + url, temp_file.name) # GET file object if url.startswith("s3://"): @@ -201,12 +208,14 @@ def get_from_cache(url, cache_dir=None): else: http_get(url, temp_file) - # we are copying the file before closing it, so flush to avoid truncation + # we are copying the file before close it, so flush to avoid trunc temp_file.flush() - # shutil.copyfileobj() starts at the current position, so go to the start + # shutil.copyfileobj() starts at the current position, + # so go to the start temp_file.seek(0) - logger.info("copying %s to cache at %s", temp_file.name, cache_path) + logger.info("copying %s to cache at %s", + temp_file.name, cache_path) with open(cache_path, 'wb') as cache_file: shutil.copyfileobj(temp_file, cache_file) @@ -215,8 +224,11 @@ def get_from_cache(url, cache_dir=None): meta_path = cache_path + '.json' with open(meta_path, 'w') as meta_file: output_string = json.dumps(meta) - if sys.version_info[0] == 2 and isinstance(output_string, str): - output_string = unicode(output_string, 'utf-8') # The beauty of python 2 + if (sys.version_info[0] == 2 + and isinstance(output_string, str)): + # The beauty of python 2 + output_string = unicode( # noqa: F821 + output_string, 'utf-8') meta_file.write(output_string) logger.info("removing temp file %s", temp_file.name) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 772d55fba1..824272fab4 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -47,11 +47,12 @@ def build_loss_compute(model, tgt_field, opt, train=True): unk_index=unk_idx, ignore_index=padding_idx ) elif opt.label_smoothing > 0 and train: - criterion = LabelSmoothingLoss( - opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx - ) + criterion = LabelSmoothingLoss(opt.label_smoothing, + len(tgt_field.vocab), + ignore_index=padding_idx) elif isinstance(model.generator[-1], LogSparsemax): - criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') + criterion = SparsemaxLoss(ignore_index=padding_idx, + reduction='sum') else: criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index e88165a958..901c718dc5 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -5,7 +5,7 @@ import onmt.opts as opts from onmt.utils.logging import logger -from onmt.utils.bert_tokenization import PRETRAINED_VOCAB_ARCHIVE_MAP +from onmt.utils import PRETRAINED_VOCAB_ARCHIVE_MAP class ArgumentParser(cfargparse.ArgumentParser): diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py index 0df5be4373..92ba10030c 100755 --- a/pregenerate_bert_training_data.py +++ b/pregenerate_bert_training_data.py @@ -9,8 +9,7 @@ import shelve from random import random, randrange, randint, shuffle, choice -from onmt.utils.bert_tokenization import BertTokenizer, \ - PRETRAINED_VOCAB_ARCHIVE_MAP +from onmt.utils import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP from onmt.utils.file_utils import cached_path from preprocess_bert_new import build_vocab_from_tokenizer import numpy as np @@ -59,15 +58,20 @@ def sample_doc(self, current_idx, sentence_weighted=True): if sentence_weighted: # With sentence weighting, we sample docs # proportionally to their sentence length - if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths): + if (self.doc_cumsum is None + or len(self.doc_cumsum) != 
len(self.doc_lengths)): self._precalculate_doc_weights() rand_start = self.doc_cumsum[current_idx] - rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] + rand_end = (rand_start + self.cumsum_max + - self.doc_lengths[current_idx]) sentence_index = randrange(rand_start, rand_end) % self.cumsum_max - sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') + sampled_doc_index = np.searchsorted( + self.doc_cumsum, sentence_index, side='right') else: # If sentence weighting is False, chose doc equally - sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths) + sampled_doc_index = ((current_idx + + randrange(1, len(self.doc_lengths))) + % len(self.doc_lengths)) assert sampled_doc_index != current_idx if self.reduce_memory: return self.document_shelf[str(sampled_doc_index)] From ba8a358fcd8f191914eda53c0650a723bad18210 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Mon, 26 Aug 2019 16:47:37 +0200 Subject: [PATCH 17/28] solve PR check --- onmt/utils/activation_fn.py | 6 +++--- onmt/utils/parse.py | 2 +- preprocess_bert_new.py | 2 +- requirements.opt.txt | 2 ++ 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/onmt/utils/activation_fn.py b/onmt/utils/activation_fn.py index 705f730d54..9497a8b845 100644 --- a/onmt/utils/activation_fn.py +++ b/onmt/utils/activation_fn.py @@ -5,11 +5,11 @@ def get_activation_fn(activation): """Return an activation function Module according to its name.""" - if activation is 'gelu': + if activation == 'gelu': fn = GELU() - elif activation is 'relu': + elif activation == 'relu': fn = nn.ReLU() - elif activation is 'tanh': + elif activation == 'tanh': fn = nn.Tanh() else: raise ValueError("Please pass a valid \ diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 901c718dc5..80e1cb2804 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -78,7 +78,7 @@ def ckpt_model_opts(cls, ckpt_opt): def validate_train_opts(cls, opt): if opt.is_bert: logger.info("WE ARE IN BERT MODE.") - if opt.task_type is "none": + if opt.task_type == "none": raise ValueError( "Downstream task should be chosen when use BERT.") if opt.reuse_embeddings is True: diff --git a/preprocess_bert_new.py b/preprocess_bert_new.py index 0505b584df..82b3590f75 100755 --- a/preprocess_bert_new.py +++ b/preprocess_bert_new.py @@ -160,7 +160,7 @@ def create_tag_instances_from_file(opt): sentence_sofar = [] for line in tqdm(lines, desc="Process", unit=" lines"): line = line.strip() - if line is '': + if line == '': if len(sentence_sofar) > 0: tokens, tags = zip(*sentence_sofar) sentences.append(tokens) diff --git a/requirements.opt.txt b/requirements.opt.txt index 1d800852e9..9488683e5d 100644 --- a/requirements.opt.txt +++ b/requirements.opt.txt @@ -9,3 +9,5 @@ pyonmttok opencv-python git+https://github.com/NVIDIA/apex flask +boto3 +sklearn From 08b10802616f7ce3be0c3bc554d75d8c10b6790a Mon Sep 17 00:00:00 2001 From: pltrdy Date: Mon, 26 Aug 2019 19:15:00 +0200 Subject: [PATCH 18/28] minor changes to make code simpler/more explicit --- onmt/inputters/dataset_bert.py | 4 ++-- onmt/model_builder.py | 5 +++-- onmt/models/bert_generators.py | 20 ++++++++++---------- onmt/train_single.py | 2 +- onmt/trainer.py | 6 +++--- onmt/utils/loss.py | 2 +- onmt/utils/optimizers.py | 2 +- onmt/utils/parse.py | 4 ++-- 8 files changed, 23 insertions(+), 22 deletions(-) diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py index 512ec21a25..69d6306cb8 100644 --- a/onmt/inputters/dataset_bert.py 
+++ b/onmt/inputters/dataset_bert.py @@ -153,7 +153,7 @@ class ClassifierDataset(BertDataset): def __init__(self, fields_dict, data, tokenizer, max_seq_len=256, delimiter=' ||| '): - if isinstance(data, tuple) is False: + if not isinstance(data, tuple): data = data, [None for _ in range(len(data))] instances = self.create_instances( data, tokenizer, delimiter, max_seq_len) @@ -213,7 +213,7 @@ def __init__(self, fields_dict, data, tokenizer, self.pad_tok = targer_field.pad_token if hasattr(targer_field, 'vocab'): # when predicting self.predict_tok = targer_field.vocab.itos[-1] - if isinstance(data, tuple) is False: + if not isinstance(data, tuple): data = (data, [None for _ in range(len(data))]) instances = self.create_instances( data, tokenizer, delimiter, max_seq_len) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index df13264e3d..2cfba57a8d 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -333,8 +333,9 @@ def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): # Build Bert Model(= encoder + generator). model = nn.Sequential(OrderedDict([ - ('bert', bert_encoder), - ('generator', generator)])) + ('bert', bert_encoder), + ('generator', generator) + ])) # Load the model states from checkpoint or initialize them. model_init = {'bert': False, 'generator': False} diff --git a/onmt/models/bert_generators.py b/onmt/models/bert_generators.py index 33ee221652..c82fb7481e 100644 --- a/onmt/models/bert_generators.py +++ b/onmt/models/bert_generators.py @@ -48,7 +48,7 @@ def __init__(self, hidden_size, vocab_size): self.decode = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size)) - self.softmax = nn.LogSoftmax(dim=-1) + self.log_softmax = nn.LogSoftmax(dim=-1) def forward(self, x): """ @@ -59,7 +59,7 @@ def forward(self, x): """ x = self.transform(x) # (batch, seq, d_model) prediction_scores = self.decode(x) + self.bias # (batch, seq, vocab) - prediction_log_prob = self.softmax(prediction_scores) + prediction_log_prob = self.log_softmax(prediction_scores) return prediction_log_prob @@ -74,7 +74,7 @@ class NextSentencePrediction(nn.Module): def __init__(self, hidden_size): super(NextSentencePrediction, self).__init__() self.linear = nn.Linear(hidden_size, 2) - self.softmax = nn.LogSoftmax(dim=-1) + self.log_softmax = nn.LogSoftmax(dim=-1) def forward(self, x): """ @@ -84,7 +84,7 @@ def forward(self, x): seq_class_prob (Tensor): ``(B, 2)`` """ seq_relationship_score = self.linear(x) # (batch, 2) - seq_class_log_prob = self.softmax(seq_relationship_score) + seq_class_log_prob = self.log_softmax(seq_relationship_score) return seq_class_log_prob @@ -127,7 +127,7 @@ def __init__(self, hidden_size, n_class, dropout=0.1): super(ClassificationHead, self).__init__() self.dropout = nn.Dropout(dropout) self.linear = nn.Linear(hidden_size, n_class) - self.softmax = nn.LogSoftmax(dim=-1) + self.log_softmax = nn.LogSoftmax(dim=-1) def forward(self, all_hidden, pooled): """ @@ -141,7 +141,7 @@ def forward(self, all_hidden, pooled): pooled = self.dropout(pooled) score = self.linear(pooled) # (batch, n_class) - class_log_prob = self.softmax(score) # (batch, n_class) + class_log_prob = self.log_softmax(score) # (batch, n_class) return class_log_prob, None @@ -157,7 +157,7 @@ def __init__(self, hidden_size, n_class, dropout=0.1): super(TokenTaggingHead, self).__init__() self.dropout = nn.Dropout(dropout) self.linear = nn.Linear(hidden_size, n_class) - self.softmax = nn.LogSoftmax(dim=-1) + self.log_softmax = 
nn.LogSoftmax(dim=-1) def forward(self, all_hidden, pooled): """ @@ -171,7 +171,7 @@ def forward(self, all_hidden, pooled): last_hidden = all_hidden[-1] last_hidden = self.dropout(last_hidden) # (batch, seq, d_model) score = self.linear(last_hidden) # (batch, seq, n_class) - tok_class_log_prob = self.softmax(score) # (batch, seq, n_class) + tok_class_log_prob = self.log_softmax(score) # (batch, seq, n_class) return None, tok_class_log_prob @@ -191,7 +191,7 @@ def __init__(self, hidden_size, vocab_size): self.decode = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size)) - self.softmax = nn.LogSoftmax(dim=-1) + self.log_softmax = nn.LogSoftmax(dim=-1) def forward(self, all_hidden, pooled): """ @@ -205,5 +205,5 @@ def forward(self, all_hidden, pooled): last_hidden = all_hidden[-1] y = self.transform(last_hidden) # (batch, seq, d_model) prediction_scores = self.decode(y) + self.bias # (batch, seq, vocab) - prediction_log_prob = self.softmax(prediction_scores) + prediction_log_prob = self.log_softmax(prediction_scores) return None, prediction_log_prob diff --git a/onmt/train_single.py b/onmt/train_single.py index b4c36bb9da..105a52f5e0 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -118,7 +118,7 @@ def main(opt, device_id, batch_queue=None, semaphore=None): model_saver = build_model_saver(model_opt, opt, model, fields, optim) trainer = build_trainer( - opt, device_id, model, fields, optim, model_saver=model_saver) + opt, device_id, model, fields, optim, model_saver=model_saver) if batch_queue is None: if len(opt.data_ids) > 1: diff --git a/onmt/trainer.py b/onmt/trainer.py index 965cd65d8c..aed5fdedcc 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -41,7 +41,7 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): tgt_field = fields["lm_labels_ids"] train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) valid_loss = onmt.utils.loss.build_loss_compute( - model, tgt_field, opt, train=False) + model, tgt_field, opt, train=False) trunc_size = opt.truncated_decoder # Badly named... shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0 @@ -137,7 +137,7 @@ def __init__(self, model, train_loss, valid_loss, optim, self.earlystopper = earlystopper self.dropout = dropout self.dropout_steps = dropout_steps - self.is_bert = True if hasattr(self.model, 'bert') else False + self.is_bert = hasattr(self.model, 'bert') for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 @@ -172,7 +172,7 @@ def _accum_batches(self, iterator): self.accum_count = self._accum_count(self.optim.training_step) for batch in iterator: batches.append(batch) - if self.is_bert is False: # Bert don't need normalization + if not self.is_bert: # Bert don't need normalization if self.norm_method == "tokens": num_tokens = batch.tgt[1:, :, 0].ne( self.train_loss.padding_idx).sum() diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 824272fab4..57f39efc17 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -23,7 +23,7 @@ def build_loss_compute(model, tgt_field, opt, train=True): for when using a copy mechanism. 
""" device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") - if opt.is_bert is True: + if opt.is_bert: if tgt_field.pad_token is not None: if tgt_field.use_vocab: padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index bac68e49e0..1e82937137 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -635,7 +635,7 @@ class BertAdam(torch.optim.Optimizer): def __init__(self, params, lr=None, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): - if not 0.0 <= lr: + if not lr >= 0.0: raise ValueError("Invalid learning rate: {}".format(lr) + " - should be >= 0.0") if not 0.0 <= betas[0] < 1.0: diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 80e1cb2804..0fb916e34a 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -158,7 +158,7 @@ def validate_preprocess_bert_opts(cls, opt): "cased model, you shouldn't set `do_lower_case`," + "we turned it off for you.") opt.do_lower_case = False - elif '-cased' not in opt.vocab_model and opt.do_lower_case is False: + elif '-cased' not in opt.vocab_model and not opt.do_lower_case: logger.warning("The pre-trained model you are loading is " + "uncased model, you should set `do_lower_case`, " + "we turned it on for you.") @@ -209,7 +209,7 @@ def validate_predict_opts(cls, opt): "is cased model, you shouldn't set `do_lower_case`," + "we turned it off for you.") opt.do_lower_case = False - elif '-cased' not in opt.vocab_model and opt.do_lower_case is False: + elif '-cased' not in opt.vocab_model and not opt.do_lower_case: logger.info("WARNING: The pre-trained model you are loading " + "is uncased model, you should set `do_lower_case`, " + "we turned it on for you.") From 660e459af52c902f56bc0e9c75c08b149c5e6af7 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 27 Aug 2019 14:29:54 +0200 Subject: [PATCH 19/28] simplify code --- bert_ckp_convert.py | 10 +- docs/source/FAQ.md | 34 ++--- onmt/encoders/bert.py | 9 +- onmt/model_builder.py | 258 +++++++++++--------------------- onmt/train_single.py | 25 ++-- onmt/trainer.py | 18 +-- onmt/translate/predictor.py | 18 +-- onmt/utils/bert_tokenization.py | 2 +- train.py | 4 +- 9 files changed, 148 insertions(+), 230 deletions(-) diff --git a/bert_ckp_convert.py b/bert_ckp_convert.py index 7fe778d7b6..c8d1814b49 100755 --- a/bert_ckp_convert.py +++ b/bert_ckp_convert.py @@ -89,10 +89,10 @@ def convert_bert_weights(bert_model, weights, n_layers=12): print("[OLD Weights file]gamma/beta is used in " + "naming BertLayerNorm. 
Mapping succeed.") else: - raise KeyError("Key %s not found in weight file" + raise KeyError("Failed fix LayerNorm %s, check file" % hugface_key) else: - raise KeyError("Key %s not found in weight file" + raise KeyError("Mapped key %s not in weight file" % hugface_key) if 'generator' not in key: onmt_key = re.sub(r'bert\.(.*)', r'\1', key) @@ -100,9 +100,9 @@ def convert_bert_weights(bert_model, weights, n_layers=12): else: onmt_key = re.sub(r'generator\.(.*)', r'\1', key) model_weights['generator'][onmt_key] = weights[hugface_key] - except ValueError: - print("Unsuccessful convert!") - exit() + except KeyError: + print("Unsuccessful convert.") + raise return model_weights diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index e6e6a21e02..15dbc056a7 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -138,21 +138,21 @@ will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`, **Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing. ## How do I use BERT? -BERT is a general-purpose "language understanding" model introduced by Google, it can be used for various downstream NLP tasks and easily adapted into a new task using transfer learning. Using BERT has two stages: Pre-training and fine-tuning. But as the Pre-training is super expensive, we do not recommande you to pre-train a BERT from scratch. Instead loading weights from a existing pretrained model and fine-tuning it is suggested. Currently we support sentence(-pair) classification and token tagging downstream task. +BERT is a general-purpose "language understanding" model introduced by Google, it can be used for various downstream NLP tasks and easily adapted into a new task using transfer learning. Using BERT has two stages: Pre-training and fine-tuning. But as the Pre-training is super expensive, we do not recommand you to pre-train a BERT from scratch. Instead loading weights from a existing pretrained model and fine-tuning is suggested. Currently we support sentence(-pair) classification and token tagging downstream task. ### Use pretrained BERT weights To use weights from a existing huggingface's pretrained model, we provide you a script to convert huggingface's BERT model weights into ours. Usage: ```bash -bert_ckp_convert.py --layers NUMBER_LAYER - --bert_model_weights_file HUGGINGFACE_BERT_WEIGHTS - --output_name OUTPUT_FILE +python bert_ckp_convert.py --layers NUMBER_LAYER + --bert_model_weights_file HUGGINGFACE_BERT_WEIGHTS + --output_name OUTPUT_FILE ``` -* Go to modeling_bert.py in https://github.com/huggingface/pytorch-transformers/ to check all available pretrained model. +* Go to [modeling_bert.py](https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py) to check all available pretrained model. ### Preprocess train/dev dataset -To genenrate train/dev data for BERT, you can use preprocess_bert.py by providing raw data in certain format and choose a BERT Tokenizer model `-vm` coherent with pretrained model. +To generate train/dev data for BERT, you can use preprocess_bert.py by providing raw data in certain format and choose a BERT Tokenizer model `-vm` coherent with pretrained model. #### Classification For classification dataset, we support input file in csv or plain text file format. 
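For reference, the prefix handling in `bert_ckp_convert.py` shown earlier in this patch reduces to routing each OpenNMT-side parameter name to either the `bert` or the `generator` state dict and stripping the module prefix. The snippet below is a minimal, self-contained sketch of that idea only; the example keys are invented, and the real script additionally maps every ONMT name to its HuggingFace counterpart and falls back to the old `gamma`/`beta` LayerNorm names when the weight file requires it.

```python
import re

def route_key(key):
    """Return (sub_module, stripped_key) for one ONMT parameter name."""
    if 'generator' not in key:
        # e.g. 'bert.embeddings.word_embeddings.weight'
        return 'bert', re.sub(r'bert\.(.*)', r'\1', key)
    # e.g. 'generator.next_sentence.linear.bias'
    return 'generator', re.sub(r'generator\.(.*)', r'\1', key)

print(route_key('bert.embeddings.word_embeddings.weight'))
# -> ('bert', 'embeddings.word_embeddings.weight')
print(route_key('generator.next_sentence.linear.bias'))
# -> ('generator', 'next_sentence.linear.bias')
```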
@@ -161,10 +161,10 @@ For classification dataset, we support input file in csv or plain text file form | ID | SENTENCE_A | SENTENCE_B(Optional) | LABEL | | -- | ------------------------ | ------------------------ | ------- | | 0 | sentence a of instance 0 | sentence b of instance 0 | class 2 | - | 1 | sentence a of instance 0 | sentence b of instance 1 | class 1 | + | 1 | sentence a of instance 1 | sentence b of instance 1 | class 1 | | ...| ... | ... | ... | - Then calling preprocess_bert.py and providing input sentence columns and label column: + Then calling `preprocess_bert.py` and providing input sentence columns and label column: ```bash python preprocess_bert.py --task classification --corpus_type {train, valid} --file_type csv [--delimiter '\t'] [--skip_head] @@ -174,13 +174,13 @@ For classification dataset, we support input file in csv or plain text file form -vm bert-base-cased --max_seq_len 256 [--do_lower_case] [--sort_label_vocab] [--do_shuffle] ``` -* For plain text format, we accept multiply files as input, each file contains instances for one specific class. Each line of the file contain one instance which could be composed by one sentence or two separated by ` ||| `. All input file should be arranged in following way: +* For plain text format, we accept multiply files as input, each file contains instances for one specific class. Each line of the file contains one instance which could be composed by one sentence or two separated by ` ||| `. All input file should be arranged in following way: ``` .../LABEL_1/filename .../LABEL_2/filename .../LABEL_3/filename ``` - Then call preprocess_bert.py as following to generate training data: + Then call `preprocess_bert.py` as following to generate training data: ```bash python preprocess_bert.py --task classification --corpus_type {'train', 'valid'} --file_type txt [--delimiter ' ||| '] @@ -193,7 +193,7 @@ For classification dataset, we support input file in csv or plain text file form #### Tagging For tagging dataset, we support input file in plain text file format. -Each line of the input file should contain token and its tagging, different fields should be separated by a delimiter(default space) while sentences are separated by a blank line. +Each line of the input file should contain one token and its tagging, different fields should be separated by a delimiter(default space) while sentences are separated by a blank line. A example of input file is given below (`Token X X Label`): ``` @@ -229,10 +229,10 @@ Then call preprocess_bert.py providing token column and label column as followin #### Pretraining objective Even if it's not recommended, we also provide you a script to generate pretraining dataset as you may want to finetuning a existing pretrained model on masked language modeling and next sentence prediction. -The script expect a single file as input, consisting of untokenized text, with one sentence per line, and one blank line between documents. +The script expects a single file as input, consisting of untokenized text, with one sentence per line, and one blank line between documents. 
A usage example is given below: ```bash -python3 pregenerate_bert_training_data.py --input_file INPUT_FILE +python pregenerate_bert_training_data.py --input_file INPUT_FILE --output_dir OUTPUT_DIR --output_name OUTPUT_FILE_PREFIX --corpus_type {'train', 'valid'} @@ -246,11 +246,11 @@ python3 pregenerate_bert_training_data.py --input_file INPUT_FILE ``` ### Training -After preprocessed data have been generated, you can load weights from a pretrained BERT and transfer it to downstream task with a task specific output head. This task specific head will be initialized by a method you choose if there is no such architecture in weights file specified by `--train_from`. Among all available optimizer, you are suggest to use `--optim bertadam` as it is the method used to train BERT. `warmup_steps` could be set as 1% of `train_steps` as in original paper if use linear decay method. +After preprocessed data have been generated, you can load weights from a pretrained BERT and transfer it to downstream task with a task specific output head. This task specific head will be initialized by a method you choose if there is no such architecture in weights file specified by `--train_from`. Among all available optimizers, you are suggested to use `--optim bertadam` as it is the method used to train BERT. `warmup_steps` could be set as 1% of `train_steps` as in original paper if use linear decay method. A usage example is given below: ```bash -python3 train.py --is_bert --task_type {pretraining, classification, tagging} +python train.py --is_bert --task_type {pretraining, classification, tagging} --data PREPROCESSED_DATAIFILE --train_from CONVERTED_CHECKPOINT.pt [--param_init 0.1] --save_model MODEL_PREFIX --save_checkpoint_steps 1000 @@ -268,12 +268,12 @@ python3 train.py --is_bert --task_type {pretraining, classification, tagging} ### Predicting After training, you can use `predict.py` to generate predicting for raw file. Make sure to use the same BERT Tokenizer model `--vocab_model` as in training data. -For classification task, file to be predicting should be one sentence(-pair) a line with ` ||| ` separating sentence. +For classification task, file to be predicted should be one sentence(-pair) a line with ` ||| ` separating sentence. For tagging task, each line should be a tokenized sentence with tokens separated by space. 
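As an illustration (the sentences below are invented; only the delimiters follow the options described above), a classification input file for `predict.py` could look like:

```
the service was excellent ||| i will definitely come back
the plot felt predictable
```

while a tagging input file holds one pre-tokenized sentence per line, with tokens separated by a single space.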
Usage: ```bash -python3 predict.py --task {classification, tagging} +python predict.py --task {classification, tagging} --model ONMT_BERT_CHECKPOINT.pt --vocab_model bert-base-uncased [--do_lower_case] --data DATA_2_PREDICT [--delimiter {' ||| ', ' '}] --max_seq_len 256 diff --git a/onmt/encoders/bert.py b/onmt/encoders/bert.py index 49571bbd69..853e74dfb5 100644 --- a/onmt/encoders/bert.py +++ b/onmt/encoders/bert.py @@ -16,8 +16,8 @@ class BertEncoder(nn.Module): dropout (float): dropout parameters """ - def __init__(self, embeddings, num_layers=12, d_model=768, - heads=12, d_ff=3072, dropout=0.1, + def __init__(self, embeddings, num_layers=12, d_model=768, heads=12, + d_ff=3072, dropout=0.1, attention_dropout=0.1, max_relative_positions=0): super(BertEncoder, self).__init__() self.num_layers = num_layers @@ -30,7 +30,8 @@ def __init__(self, embeddings, num_layers=12, d_model=768, self.embeddings = embeddings # Transformer Encoder Block self.encoder = nn.ModuleList( - [TransformerEncoderLayer(d_model, heads, d_ff, dropout, + [TransformerEncoderLayer(d_model, heads, d_ff, + dropout, attention_dropout, max_relative_positions=max_relative_positions, activation='gelu', is_bert=True) for _ in range(num_layers)]) @@ -47,6 +48,8 @@ def from_opt(cls, opt, embeddings): opt.heads, opt.transformer_ff, opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.attention_dropout[0] if type(opt.attention_dropout) + is list else opt.attention_dropout, opt.max_relative_positions) def forward(self, input_ids, token_type_ids=None, input_mask=None, diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 2cfba57a8d..867d324277 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -13,25 +13,34 @@ from onmt.decoders import str2dec -from onmt.modules import Embeddings, VecEmbedding, CopyGenerator +from onmt.modules import Embeddings, VecEmbedding, CopyGenerator, \ + BertEmbeddings from onmt.modules.util_class import Cast from onmt.utils.misc import use_gpu from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser from onmt.models import BertPreTrainingHeads, ClassificationHead, \ - TokenGenerationHead, TokenTaggingHead -from onmt.modules.bert_embeddings import BertEmbeddings -from collections import OrderedDict + TokenGenerationHead, TokenTaggingHead def build_embeddings(opt, text_field, for_encoder=True): """ Args: opt: the option in current environment. - text_field(TextMultiField): word and feats field. + text_field(TextMultiField | Field): word and feats field. for_encoder(bool): build Embeddings for encoder or decoder? """ + if opt.is_bert: + token_fields_vocab = text_field.vocab + vocab_size = len(token_fields_vocab) + emb_dim = opt.word_vec_size + return BertEmbeddings( + vocab_size, emb_dim, + dropout=(opt.dropout[0] if type(opt.dropout) is list + else opt.dropout) + ) + emb_dim = opt.src_word_vec_size if for_encoder else opt.tgt_word_vec_size if opt.model_type == "vec" and for_encoder: @@ -76,8 +85,11 @@ def build_encoder(opt, embeddings): opt: the option in current environment. embeddings (Embeddings): vocab embeddings for this encoder. 
""" - enc_type = opt.encoder_type if opt.model_type == "text" \ - or opt.model_type == "vec" else opt.model_type + if opt.is_bert: + enc_type = 'bert' + else: + enc_type = opt.encoder_type if opt.model_type == "text" \ + or opt.model_type == "vec" else opt.model_type return str2enc[enc_type].from_opt(opt, embeddings) @@ -95,6 +107,7 @@ def build_decoder(opt, embeddings): def load_test_model(opt, model_path=None): if model_path is None: + assert hasattr(opt, 'models') model_path = opt.models[0] checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) @@ -134,7 +147,7 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): gpu_id (int or NoneType): Which GPU to use. Returns: - the NMTModel. + the NMTModel or BertEncoder(with generator). """ # for back compat when attention_dropout was not defined @@ -144,7 +157,10 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): model_opt.attention_dropout = model_opt.dropout # Build embeddings. - if model_opt.model_type == "text" or model_opt.model_type == "vec": + if model_opt.is_bert: + src_field = fields["tokens"] + src_emb = build_embeddings(model_opt, src_field) + elif model_opt.model_type == "text" or model_opt.model_type == "vec": src_field = fields["src"] src_emb = build_embeddings(model_opt, src_field) else: @@ -153,31 +169,32 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): # Build encoder. encoder = build_encoder(model_opt, src_emb) - # Build decoder. - tgt_field = fields["tgt"] - tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False) + if not model_opt.is_bert: + # Build decoder. + tgt_field = fields["tgt"] + tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False) - # Share the embedding matrix - preprocess with share_vocab required. - if model_opt.share_embeddings: - # src/tgt vocab should be the same if `-share_vocab` is specified. - assert src_field.base_field.vocab == tgt_field.base_field.vocab, \ - "preprocess with -share_vocab if you use share_embeddings" + # Share the embedding matrix - preprocess with share_vocab required. + if model_opt.share_embeddings: + # src/tgt vocab should be the same if `-share_vocab` is specified. + assert src_field.base_field.vocab == tgt_field.base_field.vocab, \ + "preprocess with -share_vocab if you use share_embeddings" - tgt_emb.word_lut.weight = src_emb.word_lut.weight + tgt_emb.word_lut.weight = src_emb.word_lut.weight - decoder = build_decoder(model_opt, tgt_emb) + decoder = build_decoder(model_opt, tgt_emb) - # Build NMTModel(= encoder + decoder). if gpu and gpu_id is not None: device = torch.device("cuda", gpu_id) elif gpu and not gpu_id: device = torch.device("cuda") elif not gpu: device = torch.device("cpu") - model = onmt.models.NMTModel(encoder, decoder) # Build Generator. - if not model_opt.copy_attn: + if model_opt.is_bert: + generator = build_bert_generator(model_opt, fields, encoder) + elif not model_opt.copy_attn: if model_opt.generator_function == "sparsemax": gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) else: @@ -196,40 +213,62 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx) + if model_opt.is_bert: + model = encoder + else: + # Build NMTModel(= encoder + decoder). + model = onmt.models.NMTModel(encoder, decoder) # Load the model states from checkpoint or initialize them. 
+ model_init = {'model': False, 'generator': False} if checkpoint is not None: - # This preserves backward-compat for models using customed layernorm - def fix_key(s): - s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2', - r'\1.layer_norm\2.bias', s) - s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2', - r'\1.layer_norm\2.weight', s) - return s - - checkpoint['model'] = {fix_key(k): v - for k, v in checkpoint['model'].items()} - # end of patch for backward compatibility - - model.load_state_dict(checkpoint['model'], strict=False) - generator.load_state_dict(checkpoint['generator'], strict=False) - else: - if model_opt.param_init != 0.0: - for p in model.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - for p in generator.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - if model_opt.param_init_glorot: - for p in model.parameters(): - if p.dim() > 1: - xavier_uniform_(p) - for p in generator.parameters(): - if p.dim() > 1: - xavier_uniform_(p) - - if hasattr(model.encoder, 'embeddings'): + assert 'model' in checkpoint + if not model_opt.is_bert: + # This preserves back-compat for models using customed layernorm + def fix_key(s): + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2', + r'\1.layer_norm\2.bias', s) + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2', + r'\1.layer_norm\2.weight', s) + return s + + checkpoint['model'] = {fix_key(k): v + for k, v in checkpoint['model'].items()} + # end of patch for backward compatibility + if model.state_dict().keys() != checkpoint['model'].keys(): + raise ValueError("Checkpoint don't match actual model!") + logger.info("Load Model Parameters...") + model.load_state_dict(checkpoint['model'], strict=True) + model_init['model'] = True + if generator.state_dict().keys() == checkpoint['generator'].keys(): + logger.info("Load generator Parameters...") + generator.load_state_dict(checkpoint['generator'], strict=True) + model_init['generator'] = True + + for module_name, is_init in model_init.items(): + if not is_init: + logger.info("Initialize {} Parameters...".format(module_name)) + sub_module = model if module_name == 'model' else generator + if model_opt.param_init != 0.0: + logger.info('Initialize weights using a uniform distribution') + for p in sub_module.parameters(): + p.data.uniform_(-model_opt.param_init, + model_opt.param_init) + if model_opt.param_init_normal != 0.0: + logger.info('Initialize weights using a normal distribution') + normal_std = model_opt.param_init_normal + for p in sub_module.parameters(): + p.data.normal_(mean=0, std=normal_std) + if model_opt.param_init_glorot: + logger.info('Glorot initialization') + for p in sub_module.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + if checkpoint is None: + if hasattr(model, 'encoder') and hasattr(model.encoder, 'embeddings'): model.encoder.embeddings.load_pretrained_vectors( model_opt.pre_word_vecs_enc) - if hasattr(model.decoder, 'embeddings'): + if hasattr(model, 'decoder') and hasattr(model.decoder, 'embeddings'): model.decoder.embeddings.load_pretrained_vectors( model_opt.pre_word_vecs_dec) @@ -246,20 +285,6 @@ def build_model(model_opt, opt, fields, checkpoint): return model -def build_bert_embeddings(opt, fields): - token_fields_vocab = fields['tokens'].vocab - vocab_size = len(token_fields_vocab) - emb_size = opt.word_vec_size - bert_emb = BertEmbeddings(vocab_size, emb_size, - dropout=opt.dropout[0]) - return bert_emb - - -def build_bert_encoder(model_opt, embeddings): - bert = str2enc['bert'].from_opt(model_opt, embeddings) - 
return bert - - def build_bert_generator(model_opt, fields, bert_encoder): """Main part for transfer learning: set opt.task_type to `pretraining` if want finetuning Bert; @@ -294,104 +319,3 @@ def build_bert_generator(model_opt, fields, bert_encoder): logger.info('Generator of tagging with %s tag.' % n_class) generator = TokenTaggingHead(bert_encoder.d_model, n_class, dropout) return generator - - -def build_bert_model(model_opt, opt, fields, checkpoint=None, gpu_id=None): - """Build a model from opts. - - Args: - model_opt: the option loaded from checkpoint. It's important that - the opts have been updated and validated. See - :class:`onmt.utils.parse.ArgumentParser`. - fields (dict[str, torchtext.data.Field]): - `Field` objects for the model. - gpu (bool): whether to use gpu. - checkpoint: the model generated by train phase, or a resumed snapshot - model from a stopped training. - gpu_id (int or NoneType): Which GPU to use. - - Returns: - the BERT model. - """ - logger.info('Building BERT model...') - # Build embeddings. - bert_emb = build_bert_embeddings(model_opt, fields) - - # Build encoder. - bert_encoder = build_bert_encoder(model_opt, bert_emb) - - gpu = use_gpu(opt) - if gpu and gpu_id is not None: - device = torch.device("cuda", gpu_id) - elif gpu and not gpu_id: - device = torch.device("cuda") - elif not gpu: - device = torch.device("cpu") - - # Build Generator. - generator = build_bert_generator(model_opt, fields, bert_encoder) - - # Build Bert Model(= encoder + generator). - model = nn.Sequential(OrderedDict([ - ('bert', bert_encoder), - ('generator', generator) - ])) - - # Load the model states from checkpoint or initialize them. - model_init = {'bert': False, 'generator': False} - if checkpoint is not None: - assert 'model' in checkpoint - if model.bert.state_dict().keys() != checkpoint['model'].keys(): - raise ValueError("Provide checkpoint don't match actual model!") - logger.info("Load Model Parameters...") - model.bert.load_state_dict(checkpoint['model'], strict=True) - model_init['bert'] = True - if (model.generator.state_dict().keys() == - checkpoint['generator'].keys()): - logger.info("Load generator Parameters...") - model.generator.load_state_dict(checkpoint['generator'], - strict=True) - model_init['generator'] = True - - for sub_module, is_init in model_init.items(): - if not is_init: - logger.info("Initialize {} Parameters...".format(sub_module)) - if model_opt.param_init_normal != 0.0: - logger.info('Initialize weights using a normal distribution') - normal_std = model_opt.param_init_normal - for p in getattr(model, sub_module).parameters(): - p.data.normal_(mean=0, std=normal_std) - elif model_opt.param_init != 0.0: - logger.info('Initialize weights using a uniform distribution') - for p in getattr(model, sub_module).parameters(): - p.data.uniform_(-model_opt.param_init, - model_opt.param_init) - elif model_opt.param_init_glorot: - logger.info('Glorot initialization') - for p in getattr(model, sub_module).parameters(): - if p.dim() > 1: - xavier_uniform_(p) - else: - raise AttributeError("Initialization method haven't be used!") - - model.to(device) - logger.info(model) - return model - - -def load_bert_model(opt, model_path): - checkpoint = torch.load(model_path, - map_location=lambda storage, loc: storage) - logger.info("Checkpoint from {} Loaded.".format(model_path)) - model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt']) - vocab = checkpoint['vocab'] - fields = vocab - model = build_bert_model( - model_opt, opt, fields, checkpoint, gpu_id=opt.gpu) - 
- if opt.fp32: - model.float() - model.eval() - model.bert.eval() - model.generator.eval() - return fields, model, model_opt diff --git a/onmt/train_single.py b/onmt/train_single.py index 105a52f5e0..5cfa2fc94b 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -6,7 +6,7 @@ from onmt.inputters.inputter import build_dataset_iter, \ load_old_vocab, old_style_vocab, build_dataset_iter_multiple -from onmt.model_builder import build_model, build_bert_model +from onmt.model_builder import build_model from onmt.utils.optimizers import Optimizer from onmt.utils.misc import set_random_seed from onmt.trainer import build_trainer @@ -71,9 +71,7 @@ def main(opt, device_id, batch_queue=None, semaphore=None): # check for code where vocab is saved instead of fields # (in the future this will be done in a smarter way) - if opt.is_bert: # TODO: test amelioration - fields = vocab - elif old_style_vocab(vocab): + if old_style_vocab(vocab): fields = load_old_vocab( vocab, opt.model_type, dynamic_dict=opt.copy_attn) else: @@ -82,8 +80,7 @@ def main(opt, device_id, batch_queue=None, semaphore=None): if opt.is_bert: # Report bert tokens vocab sizes, including for features f = fields['tokens'] - if f.use_vocab: # NOTE: useless! - logger.info(' * %s vocab size = %d' % ("BERT", len(f.vocab))) + logger.info(' * %s vocab size = %d' % ("BERT", len(f.vocab))) else: # Report src and tgt vocab sizes, including for features for side in ['src', 'tgt']: @@ -96,19 +93,15 @@ def main(opt, device_id, batch_queue=None, semaphore=None): if sf.use_vocab: logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab))) - # Build model. + model = build_model(model_opt, opt, fields, checkpoint) + n_params, enc, dec = _tally_parameters(model) + logger.info('encoder: %d' % enc) if opt.is_bert: - model = build_bert_model(model_opt, opt, fields, checkpoint) # V2 - n_params = 0 - for param in model.parameters(): - n_params += param.nelement() - logger.info('* number of parameters: %d' % n_params) + logger.info('generator: %d' % dec) else: - model = build_model(model_opt, opt, fields, checkpoint) - n_params, enc, dec = _tally_parameters(model) - logger.info('encoder: %d' % enc) logger.info('decoder: %d' % dec) - logger.info('* number of parameters: %d' % n_params) + logger.info('* number of parameters: %d' % n_params) + _check_save_model_path(opt) # Build optimizer. diff --git a/onmt/trainer.py b/onmt/trainer.py index aed5fdedcc..f33a64f860 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -76,7 +76,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): model_dtype=opt.model_dtype, earlystopper=earlystopper, dropout=dropout, - dropout_steps=dropout_steps) + dropout_steps=dropout_steps, + is_bert=opt.is_bert) return trainer @@ -107,11 +108,10 @@ class Trainer(object): """ def __init__(self, model, train_loss, valid_loss, optim, - trunc_size=0, shard_size=32, - norm_method="sents", accum_count=[1], - accum_steps=[0], - n_gpu=1, gpu_rank=1, - gpu_verbose_level=0, report_manager=None, model_saver=None, + trunc_size=0, shard_size=32, norm_method="sents", + accum_count=[1], accum_steps=[0], + n_gpu=1, gpu_rank=1, gpu_verbose_level=0, + report_manager=None, model_saver=None, is_bert=False, average_decay=0, average_every=1, model_dtype='fp32', earlystopper=None, dropout=[0.3], dropout_steps=[0]): # Basic attributes. 
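To make the new `is_bert` wiring concrete, here is a simplified, self-contained sketch of what the diff above does: `build_trainer` now reads the flag from the options and passes it explicitly, and the `Trainer` checks the stored flag instead of probing the model for a `bert` attribute. The classes and option names below are illustrative stand-ins, not the real `onmt.Trainer`.

```python
class Trainer:
    def __init__(self, model, optim, norm_method="sents",
                 accum_count=(1,), is_bert=False):
        self.model = model
        self.optim = optim
        self.norm_method = norm_method
        self.accum_count = accum_count
        self.is_bert = is_bert  # explicit flag, no hasattr(model, 'bert')

    def needs_token_normalization(self):
        # BERT batches skip the sents/tokens normalization used for NMT
        return (not self.is_bert) and self.norm_method == "tokens"


def build_trainer(opt, model, optim):
    # mirrors the keyword now passed explicitly in the patch
    return Trainer(model, optim,
                   norm_method=getattr(opt, "normalization", "sents"),
                   is_bert=opt.is_bert)
```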
@@ -137,7 +137,7 @@ def __init__(self, model, train_loss, valid_loss, optim, self.earlystopper = earlystopper self.dropout = dropout self.dropout_steps = dropout_steps - self.is_bert = hasattr(self.model, 'bert') + self.is_bert = is_bert for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 @@ -159,7 +159,7 @@ def _accum_count(self, step): _accum = self.accum_count_l[i] return _accum - def _maybe_update_dropout(self, step): # TODO: to be test with Bert + def _maybe_update_dropout(self, step): for i in range(len(self.dropout_steps)): if step > 1 and step == self.dropout_steps[i] + 1: self.model.update_dropout(self.dropout[i]) @@ -527,7 +527,7 @@ def _bert_gradient_accumulation(self, true_batches, if self.accum_count == 1: self.optim.zero_grad() - all_encoder_layers, pooled_out = self.model.bert( + all_encoder_layers, pooled_out = self.model( input_ids, token_type_ids) seq_class_log_prob, prediction_log_prob = self.model.generator( all_encoder_layers, pooled_out) diff --git a/onmt/translate/predictor.py b/onmt/translate/predictor.py index b49c789b25..b95391d3fc 100644 --- a/onmt/translate/predictor.py +++ b/onmt/translate/predictor.py @@ -15,10 +15,10 @@ def build_classifier(opt, logger=None, out_file=None): """Return a classifier with result redirect to `out_file`.""" if out_file is None: - out_file = codecs.open(opt.output, 'w+', 'utf-8') + out_file = codecs.open(opt.output, 'w', 'utf-8') - load_bert_model = onmt.model_builder.load_bert_model - fields, model, model_opt = load_bert_model(opt, opt.model) + load_model = onmt.model_builder.load_test_model + fields, model, model_opt = load_model(opt, opt.model) classifier = Classifier.from_opt( model, @@ -35,10 +35,10 @@ def build_tagger(opt, logger=None, out_file=None): """Return a tagger with result redirect to `out_file`.""" if out_file is None: - out_file = codecs.open(opt.output, 'w+', 'utf-8') + out_file = codecs.open(opt.output, 'w', 'utf-8') - load_bert_model = onmt.model_builder.load_bert_model - fields, model, model_opt = load_bert_model(opt, opt.model) + load_model = onmt.model_builder.load_test_model + fields, model, model_opt = load_model(opt, opt.model) tagger = Tagger.from_opt( model, @@ -123,7 +123,7 @@ def from_opt( seed=opt.seed) def _log(self, msg): - if self.logger: + if self.logger is not None: self.logger.info(msg) else: print(msg) @@ -216,7 +216,7 @@ def classify_batch(self, batch): with torch.no_grad(): input_ids, _ = batch.tokens token_type_ids = batch.segment_ids - all_encoder_layers, pooled_out = self.model.bert( + all_encoder_layers, pooled_out = self.model( input_ids, token_type_ids) seq_class_log_prob, _ = self.model.generator( all_encoder_layers, pooled_out) @@ -319,7 +319,7 @@ def tagging_batch(self, batch): token_type_ids = batch.segment_ids taggings = batch.token_labels # Forward - all_encoder_layers, pooled_out = self.model.bert( + all_encoder_layers, pooled_out = self.model( input_ids, token_type_ids) _, prediction_log_prob = self.model.generator( all_encoder_layers, pooled_out) diff --git a/onmt/utils/bert_tokenization.py b/onmt/utils/bert_tokenization.py index 7dda9cbb63..6a5ba03045 100644 --- a/onmt/utils/bert_tokenization.py +++ b/onmt/utils/bert_tokenization.py @@ -24,7 +24,7 @@ from io import open from .file_utils import cached_path -from onmt.utils.bert_vocab_archive_map import PRETRAINED_VOCAB_ARCHIVE_MAP +from onmt.utils import PRETRAINED_VOCAB_ARCHIVE_MAP logger = logging.getLogger(__name__) diff --git a/train.py b/train.py index 72666e1b05..e4941e5df1 100755 --- a/train.py +++ 
b/train.py @@ -39,9 +39,7 @@ def main(opt): # check for code where vocab is saved instead of fields # (in the future this will be done in a smarter way) - if opt.is_bert: - fields = vocab - elif old_style_vocab(vocab): + if old_style_vocab(vocab): fields = load_old_vocab( vocab, opt.model_type, dynamic_dict=opt.copy_attn) else: From e5b035527823646fda4feb4ceab173687a7c97c1 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 27 Aug 2019 14:42:05 +0200 Subject: [PATCH 20/28] fix import; clarify FAQ --- docs/source/FAQ.md | 8 +++++--- onmt/utils/bert_tokenization.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index 15dbc056a7..b8c5e421f4 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -176,9 +176,11 @@ For classification dataset, we support input file in csv or plain text file form ``` * For plain text format, we accept multiply files as input, each file contains instances for one specific class. Each line of the file contains one instance which could be composed by one sentence or two separated by ` ||| `. All input file should be arranged in following way: ``` - .../LABEL_1/filename - .../LABEL_2/filename - .../LABEL_3/filename + . + ├── LABEL_A + │   └── FILE_WITH_INSTANCE_A + └── LABEL_B + └── FILE_WITH_INSTANCE_B ``` Then call `preprocess_bert.py` as following to generate training data: ```bash diff --git a/onmt/utils/bert_tokenization.py b/onmt/utils/bert_tokenization.py index 6a5ba03045..7dda9cbb63 100644 --- a/onmt/utils/bert_tokenization.py +++ b/onmt/utils/bert_tokenization.py @@ -24,7 +24,7 @@ from io import open from .file_utils import cached_path -from onmt.utils import PRETRAINED_VOCAB_ARCHIVE_MAP +from onmt.utils.bert_vocab_archive_map import PRETRAINED_VOCAB_ARCHIVE_MAP logger = logging.getLogger(__name__) From 1a676b225b96b31314ab1c92531d91658becadd9 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 27 Aug 2019 15:06:24 +0200 Subject: [PATCH 21/28] fix build --- onmt/opts.py | 2 +- onmt/utils/parse.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/onmt/opts.py b/onmt/opts.py index b4fee51725..c32a5c8b23 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -59,6 +59,7 @@ def model_opts(parser): # Encoder-Decoder Options group = parser.add_argument_group('Model- Encoder-Decoder') + group.add('--is_bert', '-is_bert', action='store_true') group.add('--model_type', '-model_type', default='text', choices=['text', 'img', 'audio', 'vec'], help="Type of source model to use. 
Allows " @@ -385,7 +386,6 @@ def train_opts(parser): """ Training and saving options """ group = parser.add_argument_group('Pretrain-finetuning') - group.add('--is_bert', '-is_bert', action='store_true') group.add('--task_type', '-task_type', type=str, default="none", choices=["none", "pretraining", "classification", "tagging"], help="Downstream task for Bert if is_bert set True" diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 0fb916e34a..420d1dc650 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -47,6 +47,9 @@ def update_model_opts(cls, model_opt): if model_opt.copy_attn_type is None: model_opt.copy_attn_type = model_opt.global_attention + if not hasattr(model_opt, 'is_bert'): + model_opt.is_bert = False + @classmethod def validate_model_opts(cls, model_opt): assert model_opt.model_type in ["text", "img", "audio", "vec"], \ From 4335a1322aeb2b79a64b62710a5f62f2d7b4a2e2 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Tue, 27 Aug 2019 15:32:13 +0200 Subject: [PATCH 22/28] fix exception --- onmt/model_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 867d324277..6f77968ee8 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -234,8 +234,8 @@ def fix_key(s): checkpoint['model'] = {fix_key(k): v for k, v in checkpoint['model'].items()} # end of patch for backward compatibility - if model.state_dict().keys() != checkpoint['model'].keys(): - raise ValueError("Checkpoint don't match actual model!") + # if model.state_dict().keys() != checkpoint['model'].keys(): + # raise ValueError("Checkpoint don't match actual model!") logger.info("Load Model Parameters...") model.load_state_dict(checkpoint['model'], strict=True) model_init['model'] = True From f5aec9f4112e3a832f47f3e9c77e96ab4b3dca10 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Mon, 2 Sep 2019 12:05:21 +0200 Subject: [PATCH 23/28] switch BertLayerNorm to offical LayerNorm, change BertAdam to AdamW with option correct_bias --- onmt/encoders/__init__.py | 4 +- onmt/encoders/bert.py | 28 +-------- onmt/encoders/transformer.py | 6 +- onmt/models/bert_generators.py | 3 +- onmt/modules/position_ffn.py | 6 +- onmt/utils/__init__.py | 4 +- onmt/utils/optimizers.py | 106 +++++++++++++++++++-------------- 7 files changed, 72 insertions(+), 85 deletions(-) diff --git a/onmt/encoders/__init__.py b/onmt/encoders/__init__.py index fc054b8208..d59b097fea 100644 --- a/onmt/encoders/__init__.py +++ b/onmt/encoders/__init__.py @@ -6,7 +6,7 @@ from onmt.encoders.mean_encoder import MeanEncoder from onmt.encoders.audio_encoder import AudioEncoder from onmt.encoders.image_encoder import ImageEncoder -from onmt.encoders.bert import BertEncoder, BertLayerNorm +from onmt.encoders.bert import BertEncoder str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder, @@ -14,4 +14,4 @@ "audio": AudioEncoder, "mean": MeanEncoder, "bert": BertEncoder} __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", - "MeanEncoder", "str2enc", "BertEncoder", "BertLayerNorm"] + "MeanEncoder", "str2enc", "BertEncoder"] diff --git a/onmt/encoders/bert.py b/onmt/encoders/bert.py index 853e74dfb5..03f714d00e 100644 --- a/onmt/encoders/bert.py +++ b/onmt/encoders/bert.py @@ -1,10 +1,9 @@ -import torch import torch.nn as nn from onmt.encoders.transformer import TransformerEncoderLayer class BertEncoder(nn.Module): - """BERT Encoder: A Transformer Encoder with BertLayerNorm and BertPooler. 
+ """BERT Encoder: A Transformer Encoder with LayerNorm and BertPooler. :cite:`DBLP:journals/corr/abs-1810-04805` Args: @@ -35,7 +34,7 @@ def __init__(self, embeddings, num_layers=12, d_model=768, heads=12, max_relative_positions=max_relative_positions, activation='gelu', is_bert=True) for _ in range(num_layers)]) - self.layer_norm = BertLayerNorm(d_model, eps=1e-12) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-12) self.pooler = BertPooler(d_model) @classmethod @@ -119,26 +118,3 @@ def forward(self, hidden_states): first_token_tensor = hidden_states[:, 0, :] # [batch, d_model] pooled_output = self.activation_fn(self.dense(first_token_tensor)) return pooled_output - - -class BertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-12): - """Layernorm module in the TF style(epsilon inside the square root). - https://www.tensorflow.org/api_docs/python/tf/contrib/layers/layer_norm. - - Args: - hidden_size (int): size of hidden layer. - """ - - super(BertLayerNorm, self).__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward(self, x): - """layer normalization is perform on input x.""" - - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x + self.bias diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 6dd6d5a315..fb9da4eb88 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -4,7 +4,6 @@ import torch.nn as nn -import onmt from onmt.encoders.encoder import EncoderBase from onmt.modules import MultiHeadedAttention from onmt.modules.position_ffn import PositionwiseFeedForward @@ -38,9 +37,8 @@ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout, max_relative_positions=max_relative_positions) self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, activation, is_bert) - self.layer_norm = (onmt.encoders.BertLayerNorm(d_model, eps=1e-12) - if is_bert - else nn.LayerNorm(d_model, eps=1e-6)) + self.layer_norm = nn.LayerNorm( + d_model, eps=1e-12 if is_bert else 1e-6) self.dropout = nn.Dropout(dropout) self.is_bert = is_bert diff --git a/onmt/models/bert_generators.py b/onmt/models/bert_generators.py index c82fb7481e..34f734df76 100644 --- a/onmt/models/bert_generators.py +++ b/onmt/models/bert_generators.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn -import onmt from onmt.utils import get_activation_fn @@ -100,7 +99,7 @@ def __init__(self, hidden_size): super(BertPredictionTransform, self).__init__() self.dense = nn.Linear(hidden_size, hidden_size) self.activation = get_activation_fn('gelu') - self.layer_norm = onmt.encoders.BertLayerNorm(hidden_size, eps=1e-12) + self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12) def forward(self, hidden_states): """ diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index 435b642560..a35f6a1168 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -2,7 +2,6 @@ import torch.nn as nn -import onmt from onmt.utils import get_activation_fn @@ -25,9 +24,8 @@ def __init__(self, d_model, d_ff, dropout=0.1, super(PositionwiseFeedForward, self).__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) - self.layer_norm = (onmt.encoders.BertLayerNorm(d_model, eps=1e-12) - if is_bert - else nn.LayerNorm(d_model, eps=1e-6)) + self.layer_norm = nn.LayerNorm( + d_model, eps=1e-12 if is_bert else 
1e-6)
         self.dropout_1 = nn.Dropout(dropout)
         self.activation = get_activation_fn(activation)
         self.dropout_2 = nn.Dropout(dropout)
diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py
index 628ec83ca6..54be26f797 100644
--- a/onmt/utils/__init__.py
+++ b/onmt/utils/__init__.py
@@ -4,7 +4,7 @@
 from onmt.utils.report_manager import ReportMgr, build_report_manager
 from onmt.utils.statistics import Statistics, BertStatistics
 from onmt.utils.optimizers import MultipleOptimizer, \
-    Optimizer, AdaFactor, BertAdam
+    Optimizer, AdaFactor, AdamW
 from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts
 from onmt.utils.activation_fn import get_activation_fn
 from onmt.utils.bert_tokenization import BertTokenizer
@@ -12,6 +12,6 @@
 __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed",
            "ReportMgr", "build_report_manager", "Statistics",
            "BertStatistics",
-           "MultipleOptimizer", "Optimizer", "AdaFactor", "BertAdam",
+           "MultipleOptimizer", "Optimizer", "AdaFactor", "AdamW",
            "EarlyStopping", "scorers_from_opts", "get_activation_fn",
            "BertTokenizer", "PRETRAINED_VOCAB_ARCHIVE_MAP"]
diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py
index 1e82937137..39b770d839 100644
--- a/onmt/utils/optimizers.py
+++ b/onmt/utils/optimizers.py
@@ -53,11 +53,14 @@ def build_torch_optimizer(model, opt):
             betas=betas,
             eps=1e-9)
     elif opt.optim == 'bertadam':
-        optimizer = BertAdam(
+        optimizer = AdamW(
             params,
             lr=opt.learning_rate,
             betas=betas,
-            eps=1e-9)
+            eps=1e-9,
+            amsgrad=False,
+            correct_bias=False,
+            weight_decay=0.01)
     elif opt.optim == 'sparseadam':
         dense = []
         sparse = []
@@ -620,37 +623,53 @@ def step(self, closure=None):
         return loss
 
 
-class BertAdam(torch.optim.Optimizer):
-    """Implements Adam algorithm with weight decay fix
-    (used in BERT while doesn't compensate for bias).
-    :cite:`DBLP:journals/corr/abs-1711-05101`
+class AdamW(torch.optim.Optimizer):
+    r"""Implements the Adam algorithm with decoupled weight decay (AdamW).
+    Bias correction can be turned off (as in BERT) via the option
+    correct_bias, which torch.optim.AdamW does not expose.
 
     Args:
+        params (iterable): iterable of parameters to optimize or dicts
+            defining parameter groups
         lr (float): learning rate
         betas (tuple of float): Adam (beta1, beta2). Default: (0.9, 0.999)
         eps (float): Adams epsilon. Default: 1e-6
+        amsgrad (bool): whether to use the AMSGrad variant of this algorithm
+            from the paper `On the Convergence of Adam and Beyond`_
+            Default: False.
         weight_decay (float): Weight decay. Default: 0.01
-        max_grad_norm (float): -1 means no gradients clipping.
+        correct_bias (bool): whether to use bias correction. Default: True.
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, params, lr=None, betas=(0.9, 0.999), - eps=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): - if not lr >= 0.0: + def __init__(self, params, lr=None, betas=(0.9, 0.999), eps=1e-6, + amsgrad=False, correct_bias=True, weight_decay=0.01): + if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr) + " - should be >= 0.0") + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps) + + " - should be >= 0.0") if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid betas[0] parameter: {}".format( betas[0]) + " - should be in [0.0, 1.0)") if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid betas[1] parameter: {}".format( betas[1]) + " - should be in [0.0, 1.0)") - if not eps >= 0.0: - raise ValueError("Invalid epsilon value: {}".format(eps) + - " - should be >= 0.0") - defaults = dict(lr=lr, betas=betas, - eps=eps, weight_decay=weight_decay, - max_grad_norm=max_grad_norm) - super(BertAdam, self).__init__(params, defaults) + defaults = dict(lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, + correct_bias=correct_bias, weight_decay=weight_decay) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) def step(self, closure=None): """Performs a single optimization step. @@ -671,54 +690,51 @@ def step(self, closure=None): if grad.is_sparse: raise RuntimeError('Adam: not support sparse gradients,' + 'please consider SparseAdam instead') + amsgrad = group['amsgrad'] state = self.state[p] # State initialization if len(state) == 0: - # state['step'] = 0 + state['step'] = 0 # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. + state['max_exp_avg_sq'] = torch.zeros_like(p.data) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] beta1, beta2 = group['betas'] - # NOTE: Add grad clipping, DONE before step function - # if group['max_grad_norm'] > 0: - # clip_grad_norm_(p, group['max_grad_norm']) + state['step'] += 1 # Decay first and second moment running average coefficient - # In-place operations to update the averages at the same time # exp_avg = exp_avg * beta1 + (1-beta1)*grad exp_avg.mul_(beta1).add_(1 - beta1, grad) # exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2)*grad**2 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) - update = exp_avg / (exp_avg_sq.sqrt() + group['eps']) - - # Ref: https://arxiv.org/abs/1711.05101 - # Just adding the square of the weights to the loss function - # is *not* the correct way of using L2/weight decay with Adam, - # since it will interact with m/v parameters in strange ways. - # - # Instead we want to decay the weights that does not interact - # with the m/v. This is equivalent to add the square - # of the weights to the loss with plain (non-momentum) SGD. - if group['weight_decay'] > 0.0: - update += group['weight_decay'] * p.data - - lr_scheduled = group['lr'] - - update_with_lr = lr_scheduled * update - p.data.add_(-update_with_lr) + if amsgrad: + # Maintains max of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
 
-                # state['step'] += 1
+                step_size = group['lr']
+                # NOTE: AdamW as used in BERT applies no bias correction
+                if group['correct_bias']:
+                    bias_correction1 = 1 - beta1 ** state['step']
+                    bias_correction2 = 1 - beta2 ** state['step']
+                    step_size = (step_size * sqrt(bias_correction2)
+                                 / bias_correction1)
 
-                # NOTE: BertAdam "No bias correction" comparing to standard
-                # bias_correction1 = 1 - betas[0] ** state['step']
-                # bias_correction2 = 1 - betas[1] ** state['step']
-                # step_size = lr_scheduled * math.sqrt(bias_correction2)
-                # / bias_correction1
+                p.data.addcdiv_(-step_size, exp_avg, denom)
 
+                # Perform decoupled weight decay (rather than an L2 penalty)
+                p.data.mul_(1 - group['lr'] * group['weight_decay'])
 
         return loss
 

From 4938a932960abc8c24d30f910d6d2de14647694e Mon Sep 17 00:00:00 2001
From: Linxiao Zeng
Date: Wed, 4 Sep 2019 11:24:19 +0200
Subject: [PATCH 24/28] fix bert valid step, remove unused part in saver

---
 onmt/models/model_saver.py | 10 ++++------
 onmt/trainer.py            |  2 +-
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py
index 2b67929de4..1d744e81c6 100644
--- a/onmt/models/model_saver.py
+++ b/onmt/models/model_saver.py
@@ -103,12 +103,10 @@ def _save(self, step, model):
         real_generator = (real_model.generator.module
                           if isinstance(real_model.generator, nn.DataParallel)
                           else real_model.generator)
-        if hasattr(real_model, 'bert'):
-            model_state_dict = real_model.bert.state_dict()
-        else:
-            model_state_dict = real_model.state_dict()
-        model_state_dict = {k: v for k, v in model_state_dict.items()
-                            if 'generator' not in k}
+
+        model_state_dict = real_model.state_dict()
+        model_state_dict = {k: v for k, v in model_state_dict.items()
+                            if 'generator' not in k}
         generator_state_dict = real_generator.state_dict()
 
         # NOTE: We need to trim the vocab to remove any unk tokens that
diff --git a/onmt/trainer.py b/onmt/trainer.py
index f33a64f860..dac40be80f 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -338,7 +338,7 @@ def validate(self, valid_iter, moving_average=None):
                     token_type_ids = batch.segment_ids
                     # F-prop through the model.
                     all_encoder_layers, pooled_out = \
-                        valid_model.bert(input_ids, token_type_ids)
+                        valid_model(input_ids, token_type_ids)
                     seq_class_log_prob, prediction_log_prob = \
                         valid_model.generator(all_encoder_layers, pooled_out)
 

From b1658f5deb0b2943ba9d7ffe0a9b7787354085ce Mon Sep 17 00:00:00 2001
From: Linxiao Zeng
Date: Thu, 12 Sep 2019 16:42:58 +0200
Subject: [PATCH 25/28] add dynamic batching for inference

---
 onmt/opts.py                                 |  4 ++++
 onmt/translate/predictor.py                  | 15 +++++++++++----
 predict.py                                   | 10 ++++++----
 preprocess_bert_new.py => preprocess_bert.py |  0
 4 files changed, 21 insertions(+), 8 deletions(-)
 rename preprocess_bert_new.py => preprocess_bert.py (100%)

diff --git a/onmt/opts.py b/onmt/opts.py
index c32a5c8b23..4004abe3ea 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -856,6 +856,10 @@ def predict_opts(parser):
     group = parser.add_argument_group('Efficiency')
     group.add('--batch_size', '-batch_size', type=int, default=8,
               help='Batch size')
+    group.add('--batch_type', '-batch_type', default='sents',
+              choices=["sents", "tokens"],
+              help="Batch grouping for batch_size. Standard "
+                   "is sents. 
Tokens will do dynamic batching")
     group.add('--gpu', '-gpu', type=int, default=-1, help="Device to run on")
     group.add('--seed', '-seed', type=int, default=829, help="Random seed")
     group.add('--log_file', '-log_file', type=str, default="",
diff --git a/onmt/translate/predictor.py b/onmt/translate/predictor.py
index b95391d3fc..28091b8369 100644
--- a/onmt/translate/predictor.py
+++ b/onmt/translate/predictor.py
@@ -8,6 +8,7 @@
 import torchtext.data
 import onmt.model_builder
 import onmt.inputters as inputters
+from onmt.inputters.inputter import max_tok_len
 from onmt.utils.misc import set_random_seed
 
 
@@ -165,13 +166,15 @@ def __init__(
             label_field = self.fields["category"]
             self.label_vocab = label_field.vocab
 
-    def classify(self, data, batch_size, tokenizer,
-                 delimiter=' ||| ', max_seq_len=256):
+    def classify(self, data, batch_size, tokenizer, delimiter=' ||| ',
+                 max_seq_len=256, batch_type="sents"):
         """Classify content of ``data``.
 
         Args:
             data (list of str): ['Sentence1 ||| Sentence2',...].
             batch_size (int): size of examples per mini-batch
+            batch_type (str): Batch grouping for batch_size. Choose from
+                {'sents', 'tokens'}; by default batch_size counts sentences.
 
         Returns:
             all_predictions (list of str):[c1, ..., cn].
@@ -183,6 +186,7 @@ def classify(self, data, batch_size, tokenizer,
         data_iter = torchtext.data.Iterator(
             dataset=dataset,
             batch_size=batch_size,
+            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
             device=self._dev,
             train=False,
             sort=False,
@@ -265,13 +269,15 @@ def __init__(
         self.pad_token = label_field.pad_token
         self.pad_index = self.label_vocab.stoi[self.pad_token]
 
-    def tagging(self, data, batch_size, tokenizer,
-                delimiter=' ', max_seq_len=256):
+    def tagging(self, data, batch_size, tokenizer, delimiter=' ',
+                max_seq_len=256, batch_type="sents"):
         """Tagging content of ``data``.
 
         Args:
             data (list of str): ['T1 T2 ... Tn',...].
             batch_size (int): size of examples per mini-batch
+            batch_type (str): Batch grouping for batch_size. Choose from
+                {'sents', 'tokens'}; by default batch_size counts sentences.
 
         Returns:
             all_predictions (list of list of str): [['L1', ..., 'Ln'],...].
@@ -282,6 +288,7 @@ def tagging(self, data, batch_size, tokenizer,
         data_iter = torchtext.data.Iterator(
             dataset=dataset,
             batch_size=batch_size,
+            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
             device=self._dev,
             train=False,
             sort=False,
diff --git a/predict.py b/predict.py
index 8ca5c78fb6..5fdeb53f29 100755
--- a/predict.py
+++ b/predict.py
@@ -19,7 +19,7 @@ def main(opt):
         opt.vocab_model, do_lower_case=opt.do_lower_case)
     data_shards = split_corpus(opt.data, opt.shard_size)
     if opt.task == 'classification':
-        classifier = build_classifier(opt)
+        classifier = build_classifier(opt, logger)
         for i, data_shard in enumerate(data_shards):
             logger.info("Classify shard %d." % i)
             data = [seq.decode("utf-8") for seq in data_shard]
@@ -28,10 +28,11 @@
                 opt.batch_size,
                 tokenizer,
                 delimiter=opt.delimiter,
-                max_seq_len=opt.max_seq_len
+                max_seq_len=opt.max_seq_len,
+                batch_type=opt.batch_type
             )
     if opt.task == 'tagging':
-        tagger = build_tagger(opt)
+        tagger = build_tagger(opt, logger)
         for i, data_shard in enumerate(data_shards):
             logger.info("Tagging shard %d." 
% i)
             data = [seq.decode("utf-8") for seq in data_shard]
@@ -40,7 +41,8 @@
                 opt.batch_size,
                 tokenizer,
                 delimiter=opt.delimiter,
-                max_seq_len=opt.max_seq_len
+                max_seq_len=opt.max_seq_len,
+                batch_type=opt.batch_type
             )
 
 
diff --git a/preprocess_bert_new.py b/preprocess_bert.py
similarity index 100%
rename from preprocess_bert_new.py
rename to preprocess_bert.py

From 9b1abd283788cc77f2d99951b294540ab85a5d5a Mon Sep 17 00:00:00 2001
From: Linxiao Zeng
Date: Tue, 19 Nov 2019 10:27:14 +0100
Subject: [PATCH 26/28] update classifier with confidence option

---
 onmt/translate/predictor.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/onmt/translate/predictor.py b/onmt/translate/predictor.py
index 28091b8369..5ebffa2293 100644
--- a/onmt/translate/predictor.py
+++ b/onmt/translate/predictor.py
@@ -60,7 +60,7 @@ class Predictor(object):
         fields (dict[str, torchtext.data.Field]): A dict of field.
         gpu (int): GPU device. Set to negative for no GPU.
         data_type (str): Source data type.
-        verbose (bool): Print/log every predition.
+        verbose (bool): Output every prediction with its confidence scores.
         report_time (bool): Print/log total time/frequency.
         out_file (TextIO or codecs.StreamReaderWriter): Output file.
         logger (logging.Logger or NoneType): Logger.
@@ -138,7 +138,7 @@ class Classifier(Predictor):
         fields (dict[str, torchtext.data.Field]): A dict of field.
         gpu (int): GPU device. Set to negative for no GPU.
         data_type (str): Source data type.
-        verbose (bool): Print/log every predition.
+        verbose (bool): Output every prediction with its confidence scores.
         report_time (bool): Print/log total time/frequency.
         out_file (TextIO or codecs.StreamReaderWriter): Output file.
         logger (logging.Logger or NoneType): Logger.
@@ -228,6 +228,13 @@ def classify_batch(self, batch):
             pred_sents_ids = seq_class_log_prob.argmax(-1).tolist()
             pred_sents_labels = [self.label_vocab.itos[index]
                                  for index in pred_sents_ids]
+            if self.verbose:
+                seq_class_prob = seq_class_log_prob.exp()
+                category_probs = seq_class_prob.tolist()
+                preds = ['\t'.join(map(str, category_prob)) + '\t' + pred
+                         for category_prob, pred in zip(
+                            category_probs, pred_sents_labels)]
+                return preds
             return pred_sents_labels
 
 
@@ -239,7 +246,7 @@ class Tagger(Predictor):
         fields (dict[str, torchtext.data.Field]): A dict of field.
         gpu (int): GPU device. Set to negative for no GPU.
         data_type (str): Source data type.
-        verbose (bool): Print/log every predition.
+        verbose (bool): Output every prediction with its confidence scores.
         report_time (bool): Print/log total time/frequency.
         out_file (TextIO or codecs.StreamReaderWriter): Output file.
         logger (logging.Logger or NoneType): Logger.
@@ -276,8 +283,6 @@ def tagging(self, data, batch_size, tokenizer, delimiter=' ',
         Args:
             data (list of str): ['T1 T2 ... Tn',...].
             batch_size (int): size of examples per mini-batch
-            batch_type (str): Batch grouping for batch_size. Choose from
-                {'sents', 'tokens'}; by default batch_size counts sentences.
 
         Returns:
             all_predictions (list of list of str): [['L1', ..., 'Ln'],...]. 
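A note on the verbose output format introduced in PATCH 26: with `--verbose`, each line produced by `Classifier.classify_batch` is the per-class probabilities joined by tabs, followed by the predicted label. The sketch below only illustrates how such a line could be consumed downstream; the helper name and example values are hypothetical and are not part of the patches.

```python
# Hypothetical post-processing helper for the verbose classifier output
# (tab-separated class probabilities followed by the predicted label).
def parse_verbose_prediction(line):
    """Split one verbose output line into ([probabilities], label)."""
    *probs, label = line.rstrip("\n").split("\t")
    return [float(p) for p in probs], label


# Example line for a two-class model; the values are invented.
probs, label = parse_verbose_prediction("0.0312\t0.9688\tpositive")
assert label == "positive" and abs(sum(probs) - 1.0) < 1e-6
```

The same format is emitted regardless of `--batch_type` (PATCH 25), which only changes how mini-batches are grouped.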
From e352a94a6eda3e4ade7efa7432a504fb69be16d4 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Fri, 22 Nov 2019 17:10:37 +0100 Subject: [PATCH 27/28] rm tailing space --- docs/source/FAQ.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index ffded19c78..34ed2422d4 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -263,7 +263,7 @@ After preprocessed data have been generated, you can load weights from a pretrai A usage example is given below: ```bash python train.py --is_bert --task_type {pretraining, classification, tagging} - --data PREPROCESSED_DATAIFILE + --data PREPROCESSED_DATAIFILE --train_from CONVERTED_CHECKPOINT.pt [--param_init 0.1] --save_model MODEL_PREFIX --save_checkpoint_steps 1000 [--world_size 2] [--gpu_ranks 0 1] From 6c8e8e61ea22c577b05e80d454cfb57c99e26633 Mon Sep 17 00:00:00 2001 From: Linxiao Zeng Date: Fri, 22 Nov 2019 18:22:20 +0100 Subject: [PATCH 28/28] fix travis --- .travis.yml | 1 + docs/source/refs.bib | 2 ++ onmt/bin/train.py | 2 +- onmt/trainer.py | 3 +++ pregenerate_bert_training_data.py | 2 +- 5 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index f1411e0ef6..036bab819e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,6 +14,7 @@ addons: before_install: # Install CPU version of PyTorch. - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install torch==1.2.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi + - pip install --upgrade setuptools - pip install -r requirements.opt.txt - python setup.py install env: diff --git a/docs/source/refs.bib b/docs/source/refs.bib index cc95d25c0c..b25ab7879a 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -483,6 +483,8 @@ @article{DBLP:journals/corr/HendrycksG16 timestamp = {Mon, 13 Aug 2018 16:46:20 +0200}, biburl = {https://dblp.org/rec/bib/journals/corr/HendrycksG16}, bibsource = {dblp computer science bibliography, https://dblp.org} +} + @inproceedings{garg2019jointly, title = {Jointly Learning to Align and Translate with Transformer Models}, author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias}, diff --git a/onmt/bin/train.py b/onmt/bin/train.py index f7478c61b9..86068044ec 100755 --- a/onmt/bin/train.py +++ b/onmt/bin/train.py @@ -139,7 +139,7 @@ def next_batch(device_id): else: if isinstance(b.src, tuple): b.src = tuple([_.to(torch.device(device_id)) - for _ in b.src]) + for _ in b.src]) else: b.src = b.src.to(torch.device(device_id)) b.tgt = b.tgt.to(torch.device(device_id)) diff --git a/onmt/trainer.py b/onmt/trainer.py index 276392513e..42b57c6bf5 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -310,7 +310,10 @@ def train(self, def validate(self, valid_iter, moving_average=None): """ Validate model. + + Args: valid_iter: validate data iterator + Returns: :obj:`nmt.Statistics`: validation loss statistics """ diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py index 92ba10030c..a557410196 100755 --- a/pregenerate_bert_training_data.py +++ b/pregenerate_bert_training_data.py @@ -11,7 +11,7 @@ from random import random, randrange, randint, shuffle, choice from onmt.utils import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP from onmt.utils.file_utils import cached_path -from preprocess_bert_new import build_vocab_from_tokenizer +from preprocess_bert import build_vocab_from_tokenizer import numpy as np import json from onmt.inputters.inputter import get_bert_fields
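As a closing illustration of the optimizer change in PATCH 23: `-optim bertadam` now builds the `AdamW` class from `onmt.utils.optimizers` with bias correction disabled. The sketch below mirrors the arguments passed in `build_torch_optimizer`; it is not part of the patch series, and the toy module and learning rate are placeholder values chosen only for illustration.

```python
import torch
import torch.nn as nn
from onmt.utils.optimizers import AdamW

# Stand-in module; a real run would optimize the BERT model built by ONMT.
model = nn.Linear(768, 2)

# Same arguments build_torch_optimizer passes for `-optim bertadam`,
# except the learning rate, which is an assumed illustrative value.
optimizer = AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.999),
                  eps=1e-9, amsgrad=False, correct_bias=False,
                  weight_decay=0.01)

# One dummy update to show the decoupled weight decay step in action.
loss = model(torch.randn(4, 768)).sum()
loss.backward()
optimizer.step()
```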