From 87441885351bd3ae9f485bf81819d40166ccb043 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Wed, 26 Feb 2025 03:09:21 +0800
Subject: [PATCH] Minor fixes for attention (#1504)

* minor fixes for attention

Signed-off-by: Charlene Yang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 transformer_engine/common/fused_attn/fused_attn.cpp | 6 +++---
 transformer_engine/pytorch/attention.py             | 9 ++++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 01151a50db..13c99ae244 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -153,7 +153,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
       // special conditions for blackwell
       // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
-      !(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
+      !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
       // architecture
       ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
        (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
@@ -238,12 +238,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
        ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
          // TODO(cyang): fix bug for BRCM + cross-attention on sm100
-         (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
+         (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
                                                   cudnn_runtime_version <= 90700) ||
                                                  cudnn_runtime_version > 90700)))) ||
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
-         (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
+         (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
                                                   cudnn_runtime_version <= 90700) ||
                                                  cudnn_runtime_version > 90700))))) &&
        max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index d6b9894fc3..7666d3f32b 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -118,7 +118,7 @@ def _get_supported_versions(version_min, version_max):
 _flash_attn_version = PkgVersion("0")
 _flash_attn_version_required = PkgVersion("2.1.1")
 _flash_attn_version_required_blackwell = PkgVersion("2.7.3")
-_flash_attn_max_version = PkgVersion("2.7.3")
+_flash_attn_max_version = PkgVersion("2.7.4.post1")
 _flash_attn_2_plus = False
 _flash_attn_2_1_plus = False
 _flash_attn_2_3_plus = False
@@ -507,13 +507,16 @@ def get_attention_backend(
     if use_flash_attention and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
-        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
+        or (
+            head_dim_qk > 192
+            and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
+        )
     ):
         if _flash_attn_is_installed:
             logger.debug(
                 "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
                 "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
-                "head_dim_qk <= 256 (>192 requires sm80/90). "
+                "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
                 "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
                 head_dim_qk,
                 head_dim_v,