1 file changed: +8 −1 lines changed

@@ -1007,6 +1007,7 @@ class DotProductAttention(base_layer.BaseLayer):
   decode_cache: bool = True
   attention_mask_summary: bool = False
   zero_fully_masked: bool = False
+  mha_mask_addition_pattern: bool = True
   qk_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
   pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
@@ -1342,8 +1343,14 @@ def _dot_atten(
     logits = self._cap_logits(logits)
     # Attention softmax is always carried out in fp32.
     logits = logits.astype(jnp.float32)
+
     # Apply attention masking
-    padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+    if self.mha_mask_addition_pattern:
+      padded_logits = logits + atten_mask.astype(jnp.float32)
+    else:
+      padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+
+
     if self.attention_mask_summary:
       self.add_summary('attention_mask', atten_mask)
     if self.attention_extra_logit is None:
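The new `mha_mask_addition_pattern` flag toggles between two standard ways of applying an attention mask to the logits before the softmax. Below is a minimal JAX sketch of the difference; the helper names, the additive-mask convention (0.0 at visible positions, a large negative value at masked positions), and the clamping constant are illustrative assumptions, not the exact behavior of `py_utils.apply_mask_to_logits`.

```python
import jax.numpy as jnp

def mask_by_addition(logits, atten_mask):
  # Path taken when mha_mask_addition_pattern=True (assumed semantics):
  # atten_mask is additive, holding 0.0 at visible positions and a large
  # negative value at masked ones, so a plain add pushes masked logits
  # toward -inf before the softmax.
  return logits + atten_mask.astype(jnp.float32)

def mask_by_clamping(logits, atten_mask):
  # Fallback path, sketched as a stand-in for py_utils.apply_mask_to_logits:
  # clamp masked positions to one fixed large negative value instead of
  # accumulating the mask, so repeated masking cannot overflow the dtype.
  big_neg = jnp.asarray(-0.7 * jnp.finfo(jnp.float32).max, logits.dtype)
  return jnp.where(atten_mask >= big_neg * 0.5, logits, big_neg)
```

Plain addition is the cheaper and more widely used pattern; the `where`-style clamp is safer when several masks are combined, since summing multiple large negative values can overflow the logit dtype.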