Inverse state dict for BERT (Dao-AILab#527)
kevinhu authored Sep 9, 2023
1 parent a86442f commit 4c91621
Showing 3 changed files with 171 additions and 16 deletions.
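
This commit adds inv_remap_state_dict, the inverse of remap_state_dict, so a flash_attn-format BERT state dict can be mapped back to the Huggingface layout. A minimal round-trip sketch (mirroring the new test below; the model name is illustrative):

    from transformers import BertConfig

    from flash_attn.models.bert import inv_remap_state_dict, remap_state_dict
    from flash_attn.utils.pretrained import state_dict_from_pretrained

    config = BertConfig.from_pretrained("bert-base-uncased")
    state_dict = state_dict_from_pretrained("bert-base-uncased")

    flash_sd = remap_state_dict(state_dict, config)     # HF -> flash_attn layout
    recovered = inv_remap_state_dict(flash_sd, config)  # flash_attn -> HF layout
    assert set(recovered.keys()) == set(state_dict.keys())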
3 changes: 3 additions & 0 deletions .gitignore
@@ -22,3 +22,6 @@ var/

# IDE-related
.idea/

# Dev
venv
154 changes: 142 additions & 12 deletions flash_attn/models/bert.py
@@ -10,23 +10,19 @@
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig, PretrainedConfig
from transformers.models.bert.modeling_bert import (
    BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)

from flash_attn.bert_padding import (index_first_axis,
                                     index_first_axis_residual, pad_input,
                                     unpad_input)
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.modules.mha import MHA
@@ -511,7 +507,11 @@ def forward(
)


def remap_state_dict(state_dict, config: PretrainedConfig):
"""
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
"""

# LayerNorm
def key_mapping_ln_gamma_beta(key):
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -618,3 +618,133 @@ def key_mapping_decoder_bias(key):
)

return state_dict


def inv_remap_state_dict(state_dict, config: PretrainedConfig):
"""
Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
This function is meant to be the inverse of remap_state_dict.
"""
# Word embedding
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
if pad_vocab_size_multiple > 1:
word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
decoder_weight = state_dict["cls.predictions.decoder.weight"]
decoder_bias = state_dict["cls.predictions.decoder.bias"]
        # Unpad the embeddings: remap_state_dict pads these rows out to a
        # multiple of pad_vocab_size_multiple, so truncating back to
        # orig_vocab_size drops exactly the padding rows appended at the end.
state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
: config.orig_vocab_size, :
]
state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]

for d in range(config.num_hidden_layers):
last_layer_subset = getattr(config, "last_layer_subset", False)
if not last_layer_subset or d != (config.num_hidden_layers - 1):
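            # flash_attn's MHA packs the projections as Wqkv = torch.cat([Wq, Wk, Wv], dim=0),
            # so slicing equal thirds along dim 0 recovers the separate HF
            # query/key/value weights and biases.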
Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
: Wqkv_weights.shape[0] // 3, :
]
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
]
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
2 * Wqkv_weights.shape[0] // 3 :, :
]
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
: Wqkv_biases.shape[0] // 3
]
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
]
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
2 * Wqkv_biases.shape[0] // 3 :
]
else:
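            # In last_layer_subset mode the last layer keeps Q separate and packs
            # only K and V (Wkv = torch.cat([Wk, Wv], dim=0)), so the two halves
            # of Wkv map back to the HF key and value projections.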
Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
: Wkv_weights.shape[0] // 2, :
]
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
Wkv_weights.shape[0] // 2 :, :
]
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
: Wkv_biases.shape[0] // 2
]
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
Wkv_biases.shape[0] // 2 :
]

def inv_key_mapping_ln(key):
key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
key = re.sub(
r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
key,
)
key = re.sub(
r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
r"bert.encoder.layers.\1.output.LayerNorm.\2",
key,
)
key = re.sub(
r"cls.predictions.transform.layer_norm.(weight|bias)",
r"cls.predictions.transform.LayerNorm.\1",
key,
)
return key

def inv_key_mapping_ln_gamma_beta(key):
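        # The original BERT checkpoints (e.g. bert-base-uncased) store LayerNorm
        # parameters as gamma/beta rather than weight/bias; remap_state_dict
        # normalizes them to weight/bias, so the inverse restores gamma/beta.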
key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
return key

def inv_key_mapping_layers(key):
return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)

def inv_key_mapping_mlp(key):
key = re.sub(
r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
r"bert.encoder.layer.\1.intermediate.dense.\2",
key,
)
key = re.sub(
r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
r"bert.encoder.layer.\1.output.dense.\2",
key,
)
return key

def inv_key_mapping_attn(key):
return re.sub(
r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
r"bert.encoder.layer.\1.attention.output.dense.\2",
key,
)

def inv_key_mapping_decoder_bias(key):
return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)

state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
state_dict = OrderedDict(
(inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
)
state_dict = OrderedDict(
(inv_key_mapping_layers(key), value) for key, value in state_dict.items()
)
state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
state_dict = OrderedDict(
(inv_key_mapping_attn(key), value) for key, value in state_dict.items()
)
state_dict = OrderedDict(
(inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
)

return state_dict
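
For intuition, a minimal self-contained sketch (toy shapes, not part of the commit) of the packing convention that the splits above invert: thirds concatenated along dim 0 come back apart exactly.

    import torch

    hidden_size = 8  # toy size; real BERT uses e.g. 768
    Wq, Wk, Wv = (torch.randn(hidden_size, hidden_size) for _ in range(3))
    Wqkv = torch.cat([Wq, Wk, Wv], dim=0)  # packed layout used by the flash_attn mixer

    third = Wqkv.shape[0] // 3
    assert torch.equal(Wqkv[:third], Wq)
    assert torch.equal(Wqkv[third : 2 * third], Wk)
    assert torch.equal(Wqkv[2 * third :], Wv)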
30 changes: 26 additions & 4 deletions tests/models/test_bert.py
@@ -5,12 +5,15 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import \
    BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF

from flash_attn.models.bert import (BertForPreTraining, BertModel,
                                    inv_remap_state_dict, remap_state_dict)
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
@@ -43,7 +46,7 @@ def key_mapping_ln_gamma_beta(key):
return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
"""Check that our implementation of BERT (without any optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -297,3 +300,22 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
).abs().max().item()
# The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
# assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_inv_remap_state_dict(model_name: str):
"""
Verify that we can convert a HF BERT model to flash_attn and back.
"""

state_dict = state_dict_from_pretrained(model_name)
config = BertConfig.from_pretrained(model_name)

flash_state_dict = remap_state_dict(state_dict, config)
recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)

assert set(state_dict.keys()) == set(recovered_state_dict.keys())

for k in state_dict.keys():
assert state_dict[k].shape == recovered_state_dict[k].shape
torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)
