
Commit 4ec98fa

Assert in fused attn bwd pass for sm100
Signed-off-by: Kshitij Lakhani <[email protected]>

Add check for sm100

Signed-off-by: Kshitij Lakhani <[email protected]>
1 parent 6a790cf commit 4ec98fa

File tree: 1 file changed (+9 −1 lines)

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 9 additions & 1 deletion
@@ -17,7 +17,10 @@
 from jax.experimental.custom_partitioning import SdyShardingRule
 
 import transformer_engine_jax
-from transformer_engine_jax import NVTE_Fused_Attn_Backend
+from transformer_engine_jax import (
+    NVTE_Fused_Attn_Backend,
+    get_device_compute_capability,
+)
 from transformer_engine.jax.attention import (
     AttnBiasType,
     AttnMaskType,
@@ -2745,6 +2748,11 @@ def fused_attn_bwd(
         assert bias is None
         bias = jnp.zeros(0, dtype=qkv[0].dtype)
 
+    if get_device_compute_capability(0) == 100:
+        assert not (
+            attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
+        ), "For sm100, the bprop kernel does not support dropout together with bias (determinism)"
+
     fused_config = _FusedAttnConfig(
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
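In effect, the new guard makes fused_attn_bwd fail fast on sm100 (Blackwell, compute capability 10.0, which get_device_compute_capability encodes as 100) when a bias and dropout are requested at the same time, the one combination the backward kernel cannot serve. Below is a minimal sketch of the predicate the assert enforces, assuming TE's AttnBiasType enum (NO_BIAS, POST_SCALE_BIAS); the helper name is hypothetical and used purely for illustration:

# Illustration only, not part of the commit: the support predicate behind
# the sm100 assert. bwd_supported_on_sm100 is a hypothetical helper name.
from transformer_engine.jax.attention import AttnBiasType

def bwd_supported_on_sm100(attn_bias_type, dropout_probability):
    # The sm100 bprop kernel handles bias alone or dropout alone, not both.
    return not (attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0)

assert bwd_supported_on_sm100(AttnBiasType.NO_BIAS, 0.1)                # dropout only: supported
assert bwd_supported_on_sm100(AttnBiasType.POST_SCALE_BIAS, 0.0)        # bias only: supported
assert not bwd_supported_on_sm100(AttnBiasType.POST_SCALE_BIAS, 0.1)    # both at once: rejected

By De Morgan's laws the same predicate could be written as attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0; the negated conjunction reads as a direct negation of the unsupported pairing, matching the wording of the error message.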
