From 4c91621a5e0f1ec5cc36ccc23ceb18c1b41122bc Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Sat, 9 Sep 2023 01:44:21 -0700
Subject: [PATCH] Inverse state dict for BERT (#527)

---
 .gitignore                |   3 +
 flash_attn/models/bert.py | 154 +++++++++++++++++++++++++++++++++++---
 tests/models/test_bert.py |  30 +++++++-
 3 files changed, 171 insertions(+), 16 deletions(-)

diff --git a/.gitignore b/.gitignore
index e63f4d40c..c0a6c7cb1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,6 @@ var/
 
 # IDE-related
 .idea/
+
+# Dev
+venv
\ No newline at end of file
diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py
index 2cafcca15..b6136dc52 100644
--- a/flash_attn/models/bert.py
+++ b/flash_attn/models/bert.py
@@ -10,23 +10,19 @@
 from collections import OrderedDict
 from collections.abc import Sequence
 from functools import partial
+from typing import Any, Mapping
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import BertConfig
+from transformers import BertConfig, PretrainedConfig
 from transformers.models.bert.modeling_bert import (
-    BaseModelOutputWithPoolingAndCrossAttentions,
-    BertForPreTrainingOutput,
-)
-
-from flash_attn.bert_padding import (
-    index_first_axis,
-    index_first_axis_residual,
-    pad_input,
-    unpad_input,
-)
+    BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
+
+from flash_attn.bert_padding import (index_first_axis,
+                                     index_first_axis_residual, pad_input,
+                                     unpad_input)
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
@@ -511,7 +507,11 @@ def forward(
         )
 
 
-def remap_state_dict(state_dict, config):
+def remap_state_dict(state_dict, config: PretrainedConfig):
+    """
+    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
+    """
+
     # LayerNorm
     def key_mapping_ln_gamma_beta(key):
         key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -618,3 +618,133 @@ def key_mapping_decoder_bias(key):
     )
 
     return state_dict
+
+
+def inv_remap_state_dict(state_dict, config: PretrainedConfig):
+    """
+    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
+
+    This function is meant to be the inverse of remap_state_dict.
+ """ + # Word embedding + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + if pad_vocab_size_multiple > 1: + word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] + decoder_weight = state_dict["cls.predictions.decoder.weight"] + decoder_bias = state_dict["cls.predictions.decoder.bias"] + # unpad embeddings + state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[ + : config.orig_vocab_size, : + ] + state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :] + state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size] + + for d in range(config.num_hidden_layers): + last_layer_subset = getattr(config, "last_layer_subset", False) + if not last_layer_subset or d != (config.num_hidden_layers - 1): + Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight") + Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias") + state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[ + : Wqkv_weights.shape[0] // 3, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[ + Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[ + 2 * Wqkv_weights.shape[0] // 3 :, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[ + : Wqkv_biases.shape[0] // 3 + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[ + Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3 + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[ + 2 * Wqkv_biases.shape[0] // 3 : + ] + else: + Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight") + Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight") + Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias") + Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias") + state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight + state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[ + : Wkv_weights.shape[0] // 2, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[ + Wkv_weights.shape[0] // 2 :, : + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias + state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[ + : Wkv_biases.shape[0] // 2 + ] + state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[ + Wkv_biases.shape[0] // 2 : + ] + + def inv_key_mapping_ln(key): + key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key) + key = re.sub( + r"bert.encoder.layers.(\d+).norm1.(weight|bias)", + r"bert.encoder.layers.\1.attention.output.LayerNorm.\2", + key, + ) + key = re.sub( + r"bert.encoder.layers.(\d+).norm2.(weight|bias)", + r"bert.encoder.layers.\1.output.LayerNorm.\2", + key, + ) + key = re.sub( + r"cls.predictions.transform.layer_norm.(weight|bias)", + r"cls.predictions.transform.LayerNorm.\1", + key, + ) + return key + + def inv_key_mapping_ln_gamma_beta(key): + key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key) + key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key) + return key + + def inv_key_mapping_layers(key): + return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key) + + def inv_key_mapping_mlp(key): + key = re.sub( + 
r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)", + r"bert.encoder.layer.\1.intermediate.dense.\2", + key, + ) + key = re.sub( + r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)", + r"bert.encoder.layer.\1.output.dense.\2", + key, + ) + return key + + def inv_key_mapping_attn(key): + return re.sub( + r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)", + r"bert.encoder.layer.\1.attention.output.dense.\2", + key, + ) + + def inv_key_mapping_decoder_bias(key): + return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key) + + state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items()) + state_dict = OrderedDict( + (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items() + ) + state_dict = OrderedDict( + (inv_key_mapping_layers(key), value) for key, value in state_dict.items() + ) + state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items()) + state_dict = OrderedDict( + (inv_key_mapping_attn(key), value) for key, value in state_dict.items() + ) + state_dict = OrderedDict( + (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items() + ) + + return state_dict diff --git a/tests/models/test_bert.py b/tests/models/test_bert.py index f53cfb9be..c3d2ddbe1 100644 --- a/tests/models/test_bert.py +++ b/tests/models/test_bert.py @@ -5,12 +5,15 @@ import torch import torch.nn.functional as F from einops import rearrange -from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict -from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import BertConfig -from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF +from transformers.models.bert.modeling_bert import \ + BertForPreTraining as BertForPreTrainingHF from transformers.models.bert.modeling_bert import BertModel as BertModelHF +from flash_attn.models.bert import (BertForPreTraining, BertModel, + inv_remap_state_dict, remap_state_dict) +from flash_attn.utils.pretrained import state_dict_from_pretrained + @pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"]) # @pytest.mark.parametrize('model_name', ["bert-base-uncased"]) @@ -43,7 +46,7 @@ def key_mapping_ln_gamma_beta(key): return model_hf -@pytest.mark.parametrize('model_name', ["bert-base-uncased"]) +@pytest.mark.parametrize("model_name", ["bert-base-uncased"]) def test_bert_non_optimized(model_name): """Check that our implementation of BERT (without any optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF @@ -297,3 +300,22 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs ).abs().max().item() # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0. # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item() + + +@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"]) +def test_inv_remap_state_dict(model_name: str): + """ + Verify that we can convert a HF BERT model to flash_attn and back. 
+ """ + + state_dict = state_dict_from_pretrained(model_name) + config = BertConfig.from_pretrained(model_name) + + flash_state_dict = remap_state_dict(state_dict, config) + recovered_state_dict = inv_remap_state_dict(flash_state_dict, config) + + assert set(state_dict.keys()) == set(recovered_state_dict.keys()) + + for k in state_dict.keys(): + assert state_dict[k].shape == recovered_state_dict[k].shape + torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)