Skip to content

Commit 47ab4a7

Browse files
[JAX] Add Transformer Layer tests for pre_scale_bias and post_scale_bias (#2104)
Add Transformer Layer tests for pre_scale_bias and post_scale_bias Signed-off-by: Kshitij Lakhani <[email protected]>
1 parent 2e23ad7 commit 47ab4a7

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tests/jax/test_layer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,16 @@ def enable_fused_attn():
263263
_KEY_OF_RELATIVE_EMBEDDING: False,
264264
_KEY_OF_WINDOW_SIZE: (2, 2),
265265
},
266+
# attrs29
267+
{
268+
_KEY_OF_RELATIVE_EMBEDDING: True,
269+
_KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias",
270+
},
271+
# attrs30
272+
{
273+
_KEY_OF_RELATIVE_EMBEDDING: True,
274+
_KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
275+
},
266276
]
267277

268278
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]

0 commit comments

Comments
 (0)