diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index b3ca1d62c0..7666d3f32b 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -507,7 +507,10 @@ def get_attention_backend(
     if use_flash_attention and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
-        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)))
+        or (
+            head_dim_qk > 192
+            and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
+        )
     ):
         if _flash_attn_is_installed:
             logger.debug(