transformer_engine/jax/cpp_extensions (1 file changed, +9 -1 lines)
```diff
@@ -17,7 +17,10 @@
 from jax.experimental.custom_partitioning import SdyShardingRule
 
 import transformer_engine_jax
-from transformer_engine_jax import NVTE_Fused_Attn_Backend
+from transformer_engine_jax import (
+    NVTE_Fused_Attn_Backend,
+    get_device_compute_capability,
+)
 from transformer_engine.jax.attention import (
     AttnBiasType,
     AttnMaskType,
@@ -2745,6 +2748,11 @@ def fused_attn_bwd(
         assert bias is None
         bias = jnp.zeros(0, dtype=qkv[0].dtype)
 
+    if get_device_compute_capability(0) == 100:
+        assert not (
+            attn_bias_type != "no_bias" and dropout_probability != 0
+        ), "For sm100, the bprop kernel does not support dropout together with bias (determinism)"
+
     fused_config = _FusedAttnConfig(
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
```
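Note on the guard: as originally posted, the condition compared the function object itself (`get_device_compute_capability == 100`), which is always false, so the check never fired. The diff above shows the call form instead; passing device index `0` is an assumption based on how the binding is invoked elsewhere in the JAX extensions. A minimal, self-contained sketch of the same check, using a hypothetical `check_sm100_bwd_support` helper that is not part of the PR:

```python
# Standalone sketch of the sm100 backward-pass guard. The helper name and the
# explicit compute_capability parameter are illustrative assumptions.
def check_sm100_bwd_support(
    attn_bias_type: str, dropout_probability: float, compute_capability: int
) -> None:
    """Reject the unsupported bias + dropout combination in fused-attention bprop."""
    if compute_capability == 100:  # sm100
        assert not (attn_bias_type != "no_bias" and dropout_probability != 0), (
            "For sm100, the bprop kernel does not support dropout together with bias"
        )


# Each feature alone passes on sm100; only the combination is rejected.
check_sm100_bwd_support("no_bias", 0.1, 100)          # OK: dropout without bias
check_sm100_bwd_support("post_scale_bias", 0.0, 100)  # OK: bias without dropout

try:
    check_sm100_bwd_support("post_scale_bias", 0.1, 100)
except AssertionError as e:
    print(e)  # bias + dropout combined on sm100 raises
```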