This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Possible issue with vocabulary when providing BertForClassification a BERT model path #3233

Closed
mega002 opened this issue Sep 10, 2019 · 6 comments

Comments


mega002 commented Sep 10, 2019

System:

  • OS: Ubuntu 16.04.5 LTS
  • Python version: 3.6.9
  • AllenNLP version: v0.8.5

Question:
I am trying to do a two-step fine-tuning of BERT-base cased for classification. Concretely:

  1. Load "bert-base-cased" and train it on one dataset
  2. Load the trained model and train it further on a second dataset

Step 1 is straightforward, following the example configuration here:
https://github.com/allenai/allennlp/blob/master/allennlp/tests/fixtures/bert/bert_for_classification.jsonnet

For step 2, I tried to use the same configuration, providing the path to the trained model from step 1. However, I realized that the token indexer ultimately needs a vocabulary path, not a compressed model archive. I couldn't find BERT-base's original vocabulary under the vocabulary directory in the model's directory, so I tried directly providing the path of the cached vocabulary I used in step 1 ("bert-base-cased-vocab.txt"):

{
    "dataset_reader": {
        "token_indexers": {
            "bert": {
                "type": "bert-pretrained",
                "pretrained_model": "bert-base-cased-vocab.txt"
            }
        }
    },
    "model": {
        "type": "bert_for_classification",
        "bert_model": "experiments/bert_trained/model.tar.gz"
    }
}

According to the logs, it seems that the vocabulary and trained model were successfully loaded. However, I'm getting this error when training starts:

Traceback (most recent call last):
  File "/home/morgeva/miniconda3/envs/allennlp085/bin/allennlp", line 10, in <module>
    sys.exit(run())
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/run.py", line 18, in run
    main(prog="allennlp")
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/commands/__init__.py", line 102, in main
    args.func(args)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/commands/train.py", line 124, in train_model_from_args
    args.cache_prefix)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/commands/train.py", line 168, in train_model_from_file
    cache_directory, cache_prefix)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/commands/train.py", line 226, in train_model
    cache_prefix)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/training/trainer_pieces.py", line 61, in from_params
    model = Model.from_params(vocab=vocab, params=params.pop('model'))
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/common/from_params.py", line 365, in from_params
    return subclass.from_params(params=params, **extras)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/common/from_params.py", line 388, in from_params
    return cls(**kwargs)  # type: ignore
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/models/bert_for_classification.py", line 62, in __init__
    self.bert_model = PretrainedBertModel.load(bert_model)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/allennlp/modules/token_embedders/bert_token_embedder.py", line 38, in load
    model = BertModel.from_pretrained(model_name)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/pytorch_pretrained_bert/modeling.py", line 600, in from_pretrained
    model = cls(config, *inputs, **kwargs)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/pytorch_pretrained_bert/modeling.py", line 704, in __init__
    self.embeddings = BertEmbeddings(config)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/pytorch_pretrained_bert/modeling.py", line 251, in __init__
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
  File "/home/morgeva/miniconda3/envs/allennlp085/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 99, in __init__
    self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
RuntimeError: Trying to create tensor with negative dimension -1: [-1, 768]

I suspect the issue is related to the vocabulary, which, for some reason, has size -1 in the loaded model's config ("vocab_size": -1).

Am I missing something? Is it possible to configure BertForClassification to use a model path rather than a model name like bert-base-uncased?

Thanks

@matt-gardner (Contributor)

So, this is definitely a use case we should try to support, but the problem is that the bert_model keys expect a Hugging Face model, while you're providing an allennlp model archive. I'm not sure what the right solution is here, though - we don't have an easy way to pull out the BERT weights and config and save a new Hugging Face model that you can use. Probably the right answer is some kind of convert-to-huggingface command that takes a model archive and returns something compatible with the Hugging Face code.
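
For anyone who needs this before such a command exists, here is a rough, untested sketch of a manual conversion, assuming the archive came from a bert_for_classification model (which exposes the underlying BertModel as .bert_model); the output directory name and the vocab copy are my own choices, not an official layout:

# Untested sketch: pull the fine-tuned BERT weights out of an AllenNLP 0.8.x
# archive and write them in the layout pytorch_pretrained_bert expects.
import os
import shutil

import torch
from allennlp.models.archival import load_archive

archive = load_archive("experiments/bert_trained/model.tar.gz")
bert = archive.model.bert_model  # BertForClassification keeps the BertModel here

out_dir = "converted_bert"  # arbitrary output directory
os.makedirs(out_dir, exist_ok=True)

# Weights and config, using the file names pytorch_pretrained_bert looks for.
torch.save(bert.state_dict(), os.path.join(out_dir, "pytorch_model.bin"))
with open(os.path.join(out_dir, "bert_config.json"), "w") as f:
    f.write(bert.config.to_json_string())

# Fine-tuning doesn't change the wordpiece vocabulary, so reuse the original file.
shutil.copy("bert-base-cased-vocab.txt", os.path.join(out_dir, "vocab.txt"))

Pointing both the token indexer's pretrained_model and the model's bert_model at converted_bert should then behave like a local Hugging Face checkpoint, though this hasn't been verified end to end.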

joelgrus (Contributor) commented Oct 3, 2019

Why is this not just a standard use of allennlp fine-tune? It should rehydrate your trained model, load the same token indexer, and just be pointed at a different dataset. What does using BERT complicate about that?
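
For reference, a rough sketch of the invocation (second_dataset.jsonnet is a hypothetical config whose dataset readers point at the new data; the flags are the 0.8.x fine-tune CLI's --model-archive, --config-file and --serialization-dir):

allennlp fine-tune \
    -m experiments/bert_trained/model.tar.gz \
    -c second_dataset.jsonnet \
    -s experiments/bert_trained_step2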

@matt-gardner (Contributor)

Ah, yeah, I was thinking of using the fine-tuned BERT weights as you might want for some other (non-classification) task. If you just want to do two classifications, fine-tune might be enough.

@cemilcengiz

I got the same error while I was trying to "further fine-tune" a base BERT model on a classification task, where the model was first fine-tuned on a sequence labeling task.

MaksymDel (Contributor) commented Dec 6, 2019

The problem is that the BERT vocab currently doesn't get saved to disk after you train a model (see #3097). fine-tune then wants to load the vocabulary from disk and doesn't find it. This issue should be fixed as soon as #3097 is fixed.
The problem is not only with the fine-tune command but with all functionality that involves loading a BERT model serialized with allennlp (for example, serving a demo for transformers).

Meanwhile, you can download the BERT vocab and save it manually into the vocabulary folder inside your allennlp experiment directory, naming the file correctly (this depends on your namespace name, e.g. bert or tags); see the sketch below.
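
A minimal sketch of that manual copy; the serialization directory and the bert namespace below are assumptions, so adjust them to match your own experiment and token indexer:

# Rough sketch: drop the original wordpiece vocab into the serialized model's
# vocabulary folder under the file name AllenNLP expects for that namespace.
import os
import shutil

serialization_dir = "experiments/bert_trained"  # step-1 output directory (assumed)
namespace = "bert"                              # whatever namespace your token indexer uses

vocab_dir = os.path.join(serialization_dir, "vocabulary")
os.makedirs(vocab_dir, exist_ok=True)
shutil.copy("bert-base-cased-vocab.txt", os.path.join(vocab_dir, namespace + ".txt"))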

Alternatively, you can try running the dry-run command. I did not try it myself, but from the code it looks like it will save the vocabulary to disk correctly. You can then use that vocabulary with your model.

@DeNeutoy (Contributor)

Let's track this in #3097, as the issues are similar if not identical.
