From 94f42bf10bdffacea9955cbb62e02d548d30db78 Mon Sep 17 00:00:00 2001
From: Charlene Yang
Date: Mon, 24 Feb 2025 15:57:46 -0800
Subject: [PATCH 1/2] minor fixes for attention

Signed-off-by: Charlene Yang
---
 transformer_engine/common/fused_attn/fused_attn.cpp | 6 +++---
 transformer_engine/pytorch/attention.py             | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 01151a50db..13c99ae244 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -153,7 +153,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
       // special conditions for blackwell
       // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
-      !(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
+      !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
       // architecture
       ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
        (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
@@ -238,12 +238,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
         ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
           // TODO(cyang): fix bug for BRCM + cross-attention on sm100
-          (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
+          (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
                                                    cudnn_runtime_version <= 90700) ||
                                                   cudnn_runtime_version > 90700)))) ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
          (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
-          (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
+          (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
                                                    cudnn_runtime_version <= 90700) ||
                                                   cudnn_runtime_version > 90700))))) &&
         max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index d6b9894fc3..b3ca1d62c0 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -118,7 +118,7 @@ def _get_supported_versions(version_min, version_max):
 _flash_attn_version = PkgVersion("0")
 _flash_attn_version_required = PkgVersion("2.1.1")
 _flash_attn_version_required_blackwell = PkgVersion("2.7.3")
-_flash_attn_max_version = PkgVersion("2.7.3")
+_flash_attn_max_version = PkgVersion("2.7.4.post1")
 _flash_attn_2_plus = False
 _flash_attn_2_1_plus = False
 _flash_attn_2_3_plus = False
@@ -507,13 +507,13 @@ def get_attention_backend(
     if use_flash_attention and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
-        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
+        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)))
     ):
         if _flash_attn_is_installed:
             logger.debug(
                 "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
                 "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
-                "head_dim_qk <= 256 (>192 requires sm80/90). "
+                "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
                 "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
                 head_dim_qk,
                 head_dim_v,

From a011c2e005a03220e114ad114d963836603e2f48 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 25 Feb 2025 00:04:59 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/attention.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index b3ca1d62c0..7666d3f32b 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -507,7 +507,10 @@ def get_attention_backend(
     if use_flash_attention and (
         head_dim_qk > 256
         or head_dim_qk % 8 != 0
-        or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)))
+        or (
+            head_dim_qk > 192
+            and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
+        )
     ):
         if _flash_attn_is_installed:
             logger.debug(