Skip to content

Commit 3a3a0cf

Browse files
abhinavgoel95 (Abhinav Goel) authored and Abhinav Goel committed
rebased
1 parent 86f8de2 commit 3a3a0cf

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

praxis/layers/attentions.py

+8-1
Original file line number | Diff line number | Diff line change
@@ -1007,6 +1007,7 @@ class DotProductAttention(base_layer.BaseLayer):
10071007
decode_cache: bool = True
10081008
attention_mask_summary: bool = False
10091009
zero_fully_masked: bool = False
1010+
mha_mask_addition_pattern: bool = True
10101011
qk_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
10111012
pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
10121013

@@ -1342,8 +1343,14 @@ def _dot_atten(
13421343
logits = self._cap_logits(logits)
13431344
# Attention softmax is always carried out in fp32.
13441345
logits = logits.astype(jnp.float32)
1346+
13451347
# Apply attention masking
1346-
padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
1348+
if self.mha_mask_addition_pattern:
1349+
padded_logits = logits + atten_mask.astype(jnp.float32)
1350+
else:
1351+
padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
1352+
1353+
13471354
if self.attention_mask_summary:
13481355
self.add_summary('attention_mask', atten_mask)
13491356
if self.attention_extra_logit is None:

0 commit comments

Comments (0)