Commit 2766116

[BT] add BetterTransformer support for ProphetNet (#923)
* Add ProphetNetEncoder better transformer layer
* add prophetnet to tests
* fix deprecated MutableMapping import
* set original_shape attr to hidden_states
  Set original_shape attr once instead of recomputing it every time from the attention mask
* add prophetnet entry to overview
* change test prophetnet model to `hirotasoshu/tiny-random-prophetnet`
1 parent b4a83e2 commit 2766116
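
For context (not part of the diff), a minimal usage sketch of the new support: it loads the tiny test checkpoint referenced by this PR's testing config and converts it with the existing `optimum.bettertransformer.BetterTransformer.transform` API. The dummy input ids are made up for illustration.

import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

# Tiny checkpoint added to the test config in this PR; any ProphetNet checkpoint should work.
model = AutoModel.from_pretrained("hirotasoshu/tiny-random-prophetnet")
model.eval()  # BetterTransformer layers are inference-only

# Swap each ProphetNetEncoderLayer for ProphetNetEncoderLayerBetterTransformer.
bt_model = BetterTransformer.transform(model, keep_original_model=True)

# ProphetNet is an encoder-decoder model, so a forward pass also needs decoder inputs.
input_ids = torch.tensor([[1, 2, 3, 4, 5]])
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    outputs = bt_model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids)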

File tree

docs/source/bettertransformer/overview.mdx
optimum/bettertransformer/models/__init__.py
optimum/bettertransformer/models/encoder_models.py
optimum/utils/testing_utils.py
tests/bettertransformer/test_encoder_decoder.py
tests/bettertransformer/testing_utils.py

6 files changed: +140 -2 lines changed

docs/source/bettertransformer/overview.mdx (+1)

@@ -42,6 +42,7 @@ The list of supported model below:
 - [Marian](https://arxiv.org/abs/1804.00344)
 - [MBart](https://arxiv.org/abs/2001.08210)
 - [M2M100](https://arxiv.org/abs/2010.11125)
+- [ProphetNet](https://arxiv.org/abs/2001.04063)
 - [RemBERT](https://arxiv.org/abs/2010.12821)
 - [RoBERTa](https://arxiv.org/abs/1907.11692)
 - [RoCBert](https://aclanthology.org/2022.acl-long.65.pdf)

optimum/bettertransformer/models/__init__.py (+2)

@@ -29,6 +29,7 @@
     DistilBertLayerBetterTransformer,
     FSMTEncoderLayerBetterTransformer,
     MBartEncoderLayerBetterTransformer,
+    ProphetNetEncoderLayerBetterTransformer,
     ViltLayerBetterTransformer,
     ViTLayerBetterTransformer,
     Wav2Vec2EncoderLayerBetterTransformer,
@@ -74,6 +75,7 @@ class BetterTransformerManager:
         "opt": {"OPTAttention": OPTAttentionLayerBetterTransformer},
         "pegasus": {"PegasusAttention": BartAttentionLayerBetterTransformer},
         "rembert": {"RemBertLayer": BertLayerBetterTransformer},
+        "prophetnet": {"ProphetNetEncoderLayer": ProphetNetEncoderLayerBetterTransformer},
         "roberta": {"RobertaLayer": BertLayerBetterTransformer},
         "roc_bert": {"RoCBertLayer": BertLayerBetterTransformer},
         "roformer": {"RoFormerLayer": BertLayerBetterTransformer},

optimum/bettertransformer/models/encoder_models.py (+133)

@@ -1275,6 +1275,139 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
         return (hidden_states, attention_mask)


+class ProphetNetEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
+    def __init__(self, prophetnet_layer, config):
+        r"""
+        A simple conversion of the ProphetNet Encoder layer to its `BetterTransformer` implementation.
+
+        Args:
+            prophetnet_layer (`torch.nn.Module`):
+                The original ProphetNet Layer where the weights need to be retrieved.
+        """
+        super().__init__(config)
+        self.config = config
+        # In_proj layer
+        self.in_proj_weight = nn.Parameter(
+            torch.cat(
+                [
+                    prophetnet_layer.self_attn.query_proj.weight,
+                    prophetnet_layer.self_attn.key_proj.weight,
+                    prophetnet_layer.self_attn.value_proj.weight,
+                ]
+            )
+        )
+        self.in_proj_bias = nn.Parameter(
+            torch.cat(
+                [
+                    prophetnet_layer.self_attn.query_proj.bias,
+                    prophetnet_layer.self_attn.key_proj.bias,
+                    prophetnet_layer.self_attn.value_proj.bias,
+                ]
+            )
+        )
+
+        # Out proj layer
+        self.out_proj_weight = prophetnet_layer.self_attn.out_proj.weight
+        self.out_proj_bias = prophetnet_layer.self_attn.out_proj.bias
+
+        # Linear layer 1
+        self.linear1_weight = prophetnet_layer.feed_forward.intermediate.weight
+        self.linear1_bias = prophetnet_layer.feed_forward.intermediate.bias
+
+        # Linear layer 2
+        self.linear2_weight = prophetnet_layer.feed_forward.output.weight
+        self.linear2_bias = prophetnet_layer.feed_forward.output.bias
+
+        # Layer norm 1
+        self.norm1_eps = prophetnet_layer.self_attn_layer_norm.eps
+        self.norm1_weight = prophetnet_layer.self_attn_layer_norm.weight
+        self.norm1_bias = prophetnet_layer.self_attn_layer_norm.bias
+
+        # Layer norm 2
+        self.norm2_eps = prophetnet_layer.feed_forward_layer_norm.eps
+        self.norm2_weight = prophetnet_layer.feed_forward_layer_norm.weight
+        self.norm2_bias = prophetnet_layer.feed_forward_layer_norm.bias
+
+        # Model hyper parameters
+        self.num_heads = prophetnet_layer.self_attn.num_attn_heads
+        self.embed_dim = prophetnet_layer.self_attn.head_dim * self.num_heads
+
+        # Last step: set the last layer to `False` -> this will be set to `True` when converting the model
+        self.is_last_layer = False
+
+        self.original_layers_mapping = {
+            "in_proj_weight": [
+                "self_attn.query_proj.weight",
+                "self_attn.key_proj.weight",
+                "self_attn.value_proj.weight",
+            ],
+            "in_proj_bias": ["self_attn.query_proj.bias", "self_attn.key_proj.bias", "self_attn.value_proj.bias"],
+            "out_proj_weight": "self_attn.out_proj.weight",
+            "out_proj_bias": "self_attn.out_proj.bias",
+            "linear1_weight": "feed_forward.intermediate.weight",
+            "linear1_bias": "feed_forward.intermediate.bias",
+            "linear2_weight": "feed_forward.output.weight",
+            "linear2_bias": "feed_forward.output.bias",
+            "norm1_weight": "self_attn_layer_norm.weight",
+            "norm1_bias": "self_attn_layer_norm.bias",
+            "norm2_weight": "feed_forward_layer_norm.weight",
+            "norm2_bias": "feed_forward_layer_norm.bias",
+        }
+
+        self.validate_bettertransformer()
+
+    def forward(self, hidden_states, attention_mask, *_, **__):
+        r"""
+        This is just a wrapper around the forward function proposed in:
+        https://github.com/huggingface/transformers/pull/19553
+        """
+        super().forward_checker()
+
+        if not hasattr(hidden_states, "original_shape"):
+            original_shape = hidden_states.shape
+        else:
+            original_shape = hidden_states.original_shape
+
+        if hidden_states.is_nested:
+            attention_mask = None
+
+        if attention_mask is not None:
+            # attention mask comes in with values 0 and -inf; we convert it to a torch.nn.TransformerEncoder-style bool mask
+            # 0 -> False -> keep this token; -inf -> True -> mask this token
+            attention_mask = attention_mask.squeeze(1)[:, 0]
+            attention_mask = attention_mask.bool()
+            attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
+            hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+            attention_mask = None
+
+        hidden_states = torch._transformer_encoder_layer_fwd(
+            hidden_states,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj_weight,
+            self.in_proj_bias,
+            self.out_proj_weight,
+            self.out_proj_bias,
+            self.use_gelu,
+            self.norm_first,
+            self.norm1_eps,
+            self.norm1_weight,
+            self.norm1_bias,
+            self.norm2_weight,
+            self.norm2_bias,
+            self.linear1_weight,
+            self.linear1_bias,
+            self.linear2_weight,
+            self.linear2_bias,
+            attention_mask,
+        )
+        if not self.is_last_layer:
+            hidden_states.original_shape = original_shape
+        elif hidden_states.is_nested and self.is_last_layer:
+            hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
+        return (hidden_states,)
+
+
 class CLIPLayerBetterTransformer(BetterTransformerBaseLayer):
     def __init__(self, layer, config):
         r"""

optimum/utils/testing_utils.py (+2 -2)

@@ -13,13 +13,13 @@
 # limitations under the License.


-import collections
 import importlib.util
 import itertools
 import os
 import subprocess
 import sys
 import unittest
+from collections.abc import MutableMapping
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple

 import torch
@@ -35,7 +35,7 @@ def flatten_dict(dictionary: Dict):
     items = []
     for k, v in dictionary.items():
         new_key = k
-        if isinstance(v, collections.MutableMapping):
+        if isinstance(v, MutableMapping):
            items.extend(flatten_dict(v).items())
        else:
            items.append((new_key, v))
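
This import change is needed because the long-deprecated `collections.MutableMapping` alias was removed in Python 3.10; the ABC must now come from `collections.abc`. A quick illustration (not from the commit) of the check that `flatten_dict` relies on:

from collections.abc import MutableMapping

nested = {"a": 1, "b": {"c": 2, "d": {"e": 3}}}

# dict (and any other mutable mapping) passes the isinstance check used in flatten_dict,
# so nested dictionaries are flattened recursively while leaf values are kept as-is.
assert isinstance(nested["b"], MutableMapping)
assert not isinstance(nested["a"], MutableMapping)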

tests/bettertransformer/test_encoder_decoder.py (+1)

@@ -43,6 +43,7 @@ class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest
         "marian",
         "mbart",
         "pegasus",
+        "prophetnet",
         "t5",
     ]

tests/bettertransformer/testing_utils.py (+1)

@@ -54,6 +54,7 @@
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "pegasus": "hf-internal-testing/tiny-random-PegasusModel",
+    "prophetnet": "hirotasoshu/tiny-random-prophetnet",  # the other tiny ones have a too small max_position_embeddings
     "rembert": "hf-internal-testing/tiny-random-rembert",
     "roberta": "hf-internal-testing/tiny-random-RobertaModel",
     "rocbert": "hf-internal-testing/tiny-random-RoCBertModel",
