Skip to content

Commit ca98d03

Browse files
committed
[JAX] Fix bug with pre scale bias (NVIDIA#2300)
* fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> --------- Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent c65e2e9 commit ca98d03

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

transformer_engine/jax/flax/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __call__(
197197
fused_scale_factor = scale_factor
198198
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
199199
attn_weights += bias
200+
bias = None
200201

201202
def apply_swa_mask(original_mask: Array) -> Array:
202203
"""Apply the sliding window mask to a given mask"""

0 commit comments

Comments
 (0)