Minor fixes for attention #1504

Merged (2 commits) on Feb 25, 2025

6 changes: 3 additions & 3 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -153,7 +153,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
// special conditions for blackwell
// TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
-!(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
+!(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
// architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
@@ -238,12 +238,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
// TODO(cyang): fix bug for BRCM + cross-attention on sm100
-(sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
+(sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700)))) ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
-(sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
+(sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700))))) &&
max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
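
Note on the hunks above: their net effect is that the >128 head-dim restriction and the bottom-right causal mask conditions are now keyed on sm_arch_ >= 100 instead of sm_arch_ == 100, so they cover the whole Blackwell family rather than sm100 only. Below is a minimal standalone Python sketch of that gating, with hypothetical helper names; it is not Transformer Engine's actual API.

# Hypothetical standalone sketch; names are illustrative, not Transformer Engine API.

def blackwell_head_dim_ok(sm_arch, head_dim_qk, head_dim_v):
    # Previously restricted only when sm_arch == 100; now any Blackwell-family
    # arch (sm_arch >= 100) is limited to head dims of at most 128.
    return not (sm_arch >= 100 and (head_dim_qk > 128 or head_dim_v > 128))


def brcm_mask_ok(sm_arch, max_seqlen_q, max_seqlen_kv, cudnn_runtime_version):
    # Bottom-right causal mask: unrestricted before sm100; on sm100+ the
    # cross-attention case (max_seqlen_q != max_seqlen_kv) needs cuDNN > 9.7.0.
    return sm_arch < 100 or (
        (max_seqlen_q == max_seqlen_kv and cudnn_runtime_version <= 90700)
        or cudnn_runtime_version > 90700
    )


# Examples: sm120 with 192-dim heads is rejected (new behavior), sm90 is not;
# BRCM cross-attention on sm100 needs cuDNN newer than 9.7.0.
assert not blackwell_head_dim_ok(120, 192, 192)
assert blackwell_head_dim_ok(90, 192, 192)
assert not brcm_mask_ok(100, 512, 1024, 90700)
assert brcm_mask_ok(100, 512, 1024, 90701)
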
9 changes: 6 additions & 3 deletions transformer_engine/pytorch/attention.py
@@ -118,7 +118,7 @@ def _get_supported_versions(version_min, version_max):
_flash_attn_version = PkgVersion("0")
_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_version_required_blackwell = PkgVersion("2.7.3")
-_flash_attn_max_version = PkgVersion("2.7.3")
+_flash_attn_max_version = PkgVersion("2.7.4.post1")
_flash_attn_2_plus = False
_flash_attn_2_1_plus = False
_flash_attn_2_3_plus = False
@@ -507,13 +507,16 @@ def get_attention_backend(
    if use_flash_attention and (
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
-       or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0)))
+       or (
+           head_dim_qk > 192
+           and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
+       )
    ):
        if _flash_attn_is_installed:
            logger.debug(
                "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
                "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
-               "head_dim_qk <= 256 (>192 requires sm80/90). "
+               "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
                "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
                head_dim_qk,
                head_dim_v,
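
Taken together, the Python-side changes raise the supported flash-attn ceiling to 2.7.4.post1 and allow head_dim_qk above 192 on sm100 and sm120 in addition to sm80/sm90. Below is a minimal sketch of the two widened checks, assuming hypothetical helper names and treating PkgVersion as packaging.version.Version; it is illustrative only, not the module's actual control flow.

# Minimal sketch of the two widened checks; helper names are hypothetical.
from packaging.version import Version as PkgVersion

_flash_attn_version_required = PkgVersion("2.1.1")
_flash_attn_max_version = PkgVersion("2.7.4.post1")  # ceiling raised from 2.7.3


def flash_attn_version_ok(installed: str) -> bool:
    # Accept any flash-attn release inside the supported window.
    v = PkgVersion(installed)
    return _flash_attn_version_required <= v <= _flash_attn_max_version


def flash_attn_head_dim_ok(head_dim_qk, head_dim_v, device_compute_capability):
    # Head dim must be a multiple of 8 and at most 256; values above 192 are
    # limited to sm80/90 and, after this change, also sm100/sm120.
    return (
        head_dim_qk == head_dim_v
        and head_dim_qk % 8 == 0
        and head_dim_qk <= 256
        and (
            head_dim_qk <= 192
            or device_compute_capability in ((8, 0), (9, 0), (10, 0), (12, 0))
        )
    )


# Examples: 2.7.4.post1 is now inside the window; 256-dim heads pass on sm100
# but are still rejected on sm89.
assert flash_attn_version_ok("2.7.4.post1")
assert not flash_attn_version_ok("2.8.0")
assert flash_attn_head_dim_ok(256, 256, (10, 0))
assert not flash_attn_head_dim_ok(256, 256, (8, 9))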