@@ -1275,6 +1275,139 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
         return (hidden_states, attention_mask)


+class ProphetNetEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
+    def __init__(self, prophetnet_layer, config):
+        r"""
+        A simple conversion of the ProphetNet encoder layer to its `BetterTransformer` implementation.
+
+        Args:
+            prophetnet_layer (`torch.nn.Module`):
+                The original ProphetNet layer from which the weights need to be retrieved.
+        """
+        super().__init__(config)
+        self.config = config
+        # In_proj layer
+        self.in_proj_weight = nn.Parameter(
+            torch.cat(
+                [
+                    prophetnet_layer.self_attn.query_proj.weight,
+                    prophetnet_layer.self_attn.key_proj.weight,
+                    prophetnet_layer.self_attn.value_proj.weight,
+                ]
+            )
+        )
+        self.in_proj_bias = nn.Parameter(
+            torch.cat(
+                [
+                    prophetnet_layer.self_attn.query_proj.bias,
+                    prophetnet_layer.self_attn.key_proj.bias,
+                    prophetnet_layer.self_attn.value_proj.bias,
+                ]
+            )
+        )
+
+        # Out proj layer
+        self.out_proj_weight = prophetnet_layer.self_attn.out_proj.weight
+        self.out_proj_bias = prophetnet_layer.self_attn.out_proj.bias
+
+        # Linear layer 1
+        self.linear1_weight = prophetnet_layer.feed_forward.intermediate.weight
+        self.linear1_bias = prophetnet_layer.feed_forward.intermediate.bias
+
+        # Linear layer 2
+        self.linear2_weight = prophetnet_layer.feed_forward.output.weight
+        self.linear2_bias = prophetnet_layer.feed_forward.output.bias
+
+        # Layer norm 1
+        self.norm1_eps = prophetnet_layer.self_attn_layer_norm.eps
+        self.norm1_weight = prophetnet_layer.self_attn_layer_norm.weight
+        self.norm1_bias = prophetnet_layer.self_attn_layer_norm.bias
+
+        # Layer norm 2
+        self.norm2_eps = prophetnet_layer.feed_forward_layer_norm.eps
+        self.norm2_weight = prophetnet_layer.feed_forward_layer_norm.weight
+        self.norm2_bias = prophetnet_layer.feed_forward_layer_norm.bias
+
+        # Model hyper parameters
+        self.num_heads = prophetnet_layer.self_attn.num_attn_heads
+        self.embed_dim = prophetnet_layer.self_attn.head_dim * self.num_heads
+
+        # Last step: set the last layer to `False` -> this will be set to `True` when converting the model
+        self.is_last_layer = False
+
+        self.original_layers_mapping = {
+            "in_proj_weight": [
+                "self_attn.query_proj.weight",
+                "self_attn.key_proj.weight",
+                "self_attn.value_proj.weight",
+            ],
+            "in_proj_bias": ["self_attn.query_proj.bias", "self_attn.key_proj.bias", "self_attn.value_proj.bias"],
+            "out_proj_weight": "self_attn.out_proj.weight",
+            "out_proj_bias": "self_attn.out_proj.bias",
+            "linear1_weight": "feed_forward.intermediate.weight",
+            "linear1_bias": "feed_forward.intermediate.bias",
+            "linear2_weight": "feed_forward.output.weight",
+            "linear2_bias": "feed_forward.output.bias",
+            "norm1_weight": "self_attn_layer_norm.weight",
+            "norm1_bias": "self_attn_layer_norm.bias",
+            "norm2_weight": "feed_forward_layer_norm.weight",
+            "norm2_bias": "feed_forward_layer_norm.bias",
+        }
+
+        self.validate_bettertransformer()
+
+    def forward(self, hidden_states, attention_mask, *_, **__):
+        r"""
+        This is just a wrapper around the forward function proposed in:
+        https://github.com/huggingface/transformers/pull/19553
+        """
+        super().forward_checker()
+
+        if not hasattr(hidden_states, "original_shape"):
+            original_shape = hidden_states.shape
+        else:
+            original_shape = hidden_states.original_shape
+
+        if hidden_states.is_nested:
+            attention_mask = None
+
+        if attention_mask is not None:
+            # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
+            # 0 -> False -> keep this token; -inf -> True -> mask this token
+            attention_mask = attention_mask.squeeze(1)[:, 0]
+            attention_mask = attention_mask.bool()
+            attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
+            hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+            attention_mask = None
+
+        hidden_states = torch._transformer_encoder_layer_fwd(
+            hidden_states,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj_weight,
+            self.in_proj_bias,
+            self.out_proj_weight,
+            self.out_proj_bias,
+            self.use_gelu,
+            self.norm_first,
+            self.norm1_eps,
+            self.norm1_weight,
+            self.norm1_bias,
+            self.norm2_weight,
+            self.norm2_bias,
+            self.linear1_weight,
+            self.linear1_bias,
+            self.linear2_weight,
+            self.linear2_bias,
+            attention_mask,
+        )
+        if not self.is_last_layer:
+            hidden_states.original_shape = original_shape
+        elif hidden_states.is_nested and self.is_last_layer:
+            hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
+        return (hidden_states,)
+
+
 class CLIPLayerBetterTransformer(BetterTransformerBaseLayer):
     def __init__(self, layer, config):
         r"""
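For reference, layer classes like the one added in this diff are not instantiated by hand; the conversion is normally driven through `optimum`'s `BetterTransformer.transform` API, which swaps each supported encoder layer for its fastpath counterpart. A minimal usage sketch follows, assuming the ProphetNet checkpoint name only for illustration and that the layer above is registered in the BetterTransformer mapping:

```python
# Minimal sketch: convert a ProphetNet model with optimum's BetterTransformer API.
# The checkpoint name is an example; any ProphetNet checkpoint should work once
# ProphetNetEncoderLayerBetterTransformer is registered in the conversion mapping.
import torch
from transformers import AutoTokenizer, ProphetNetModel
from optimum.bettertransformer import BetterTransformer

checkpoint = "microsoft/prophetnet-large-uncased"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = ProphetNetModel.from_pretrained(checkpoint)

# keep_original_model=True returns a converted copy and leaves `model` untouched.
bt_model = BetterTransformer.transform(model, keep_original_model=True)

inputs = tokenizer("ProphetNet now runs on the BetterTransformer fastpath.", return_tensors="pt")
with torch.no_grad():
    # ProphetNet is an encoder-decoder model, so decoder inputs are required as well.
    outputs = bt_model(**inputs, decoder_input_ids=inputs["input_ids"])
print(outputs.last_hidden_state.shape)
```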