
Commit f378eaf

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)
* Fix failing tests for dropout=0.1 and bias for fused attn for Blackwell (sm100)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix the skip message
* Assert in fused attn bwd pass for sm100; add check for sm100
* Add support to get all devices in the process for JAX
* Code clean up
* Make get_all_device_compute_capability more Pythonic, thereby avoiding unnecessary type conversion
* Represent attn bias using enum instead of string

Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: 3b4366b · commit: f378eaf

3 files changed: 25 additions, 0 deletions

tests/jax/test_fused_attn.py

Lines changed: 9 additions & 0 deletions
@@ -41,6 +41,7 @@
 from transformer_engine_jax import (
     NVTE_Fused_Attn_Backend,
     get_cudnn_version,
+    get_device_compute_capability,
 )
 
 from distributed_test_base import assert_equal_collectives
@@ -348,6 +349,14 @@ def _check_configs(self):
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )
 
+        if (
+            get_device_compute_capability(0) == 100
+            and self.dropout_prob == 0.1
+            and self.attn_bias_type is not AttnBiasType.NO_BIAS
+        ):
+            pytest.skip(
+                "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
+            )
         # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
         # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
         if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
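For context, the new test guard fires only when all three conditions hold: device 0 reports compute capability 100 (Blackwell), dropout is the non-zero parametrized value 0.1, and a bias other than NO_BIAS is requested. Below is a minimal, standalone sketch of that predicate; the AttnBiasType enum and helper name here are stand-ins for illustration, not the real transformer_engine definitions.

# Hedged sketch of the skip predicate added in _check_configs above.
# AttnBiasType is a stand-in enum, not the real transformer_engine type.
from enum import Enum, auto


class AttnBiasType(Enum):
    NO_BIAS = auto()
    PRE_SCALE_BIAS = auto()
    POST_SCALE_BIAS = auto()


def should_skip_on_sm100(compute_capability, dropout_prob, bias_type):
    # Mirrors the guard: skip only when dropout (0.1 in the test parametrization)
    # is combined with a bias on an sm100 device.
    return (
        compute_capability == 100
        and dropout_prob == 0.1
        and bias_type is not AttnBiasType.NO_BIAS
    )


assert should_skip_on_sm100(100, 0.1, AttnBiasType.POST_SCALE_BIAS)       # skipped on sm100
assert not should_skip_on_sm100(100, 0.0, AttnBiasType.POST_SCALE_BIAS)   # bias alone still runs
assert not should_skip_on_sm100(90, 0.1, AttnBiasType.POST_SCALE_BIAS)    # pre-Blackwell unaffected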

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,7 @@
     te_dtype_to_jax_dtype,
     get_padded_spec,
     get_cudnn_version,
+    get_all_device_compute_capability,
 )
 from ..sharding import (
     global_mesh_resource,
@@ -2745,6 +2746,11 @@ def fused_attn_bwd(
         assert bias is None
         bias = jnp.zeros(0, dtype=qkv[0].dtype)
 
+    if 100 in get_all_device_compute_capability():
+        assert not (
+            attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
+        ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
+
     fused_config = _FusedAttnConfig(
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
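The backward-pass guard reads as "assert not (bias and dropout)": if any local device is sm100, at most one of the two may be active. A hedged, self-contained restatement of that check follows; the helper name and the trimmed-down enum are illustrative stand-ins, not part of the diff.

# Illustrative restatement of the sm100 guard in fused_attn_bwd; the enum and
# helper name are hypothetical stand-ins for the real TE/JAX types.
from enum import Enum, auto


class AttnBiasType(Enum):
    NO_BIAS = auto()
    POST_SCALE_BIAS = auto()


def validate_sm100_bwd_config(compute_capabilities, attn_bias_type, dropout_probability):
    # Reject bias + dropout in the backward pass if any local device is sm100.
    if 100 in compute_capabilities:
        assert not (
            attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
        ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"


validate_sm100_bwd_config((100,), AttnBiasType.POST_SCALE_BIAS, 0.0)    # ok: bias without dropout
validate_sm100_bwd_config((100,), AttnBiasType.NO_BIAS, 0.1)            # ok: dropout without bias
# validate_sm100_bwd_config((100,), AttnBiasType.POST_SCALE_BIAS, 0.1)  # would raise AssertionError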

transformer_engine/jax/cpp_extensions/misc.py

Lines changed: 10 additions & 0 deletions
@@ -193,6 +193,16 @@ def get_min_device_compute_capability():
     )
 
 
+def get_all_device_compute_capability():
+    """
+    Returns a list of compute capability of all local devices.
+    """
+    return tuple(
+        transformer_engine_jax.get_device_compute_capability(local_gpu_id)
+        for local_gpu_id in range(len(jax.local_devices()))
+    )
+
+
 def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
     """
     Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
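A short usage sketch for the new helper, assuming Transformer Engine's JAX package is installed and the module is importable from the path implied by the file location (transformer_engine.jax.cpp_extensions.misc). Note that, despite the docstring saying "list", the helper returns a tuple of per-device compute capabilities.

# Hypothetical usage of get_all_device_compute_capability; assumes TE-JAX is
# installed and at least one local GPU is visible to JAX.
from transformer_engine.jax.cpp_extensions.misc import get_all_device_compute_capability

caps = get_all_device_compute_capability()  # e.g. (100, 100) on a dual-Blackwell node
if 100 in caps:
    print("sm100 present: bias + dropout is rejected in the fused attention backward pass")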
