From a06d72c247aab3759727ce33e885c2dbd2d4f582 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 25 Feb 2025 00:58:14 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 qa/L0_pytorch_unittest/test.sh             |  2 +-
 .../fused_attn_f16_arbitrary_seqlen.cu     | 14 ++++----
 .../jax/csrc/extensions/attention.cpp      | 32 ++++++++++---------
 3 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 438ab3d8fd..0c7a907051 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -25,4 +25,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
 NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || FAIL=1
-exit $FAIL
\ No newline at end of file
+exit $FAIL
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 102d44359f..e12122f822 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -439,14 +439,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     void *devOffsetsK = nullptr;
     void *devOffsetsV = nullptr;
     if (is_ragged_kv) {
-      devOffsetsK =
-          static_cast<int8_t *>(devOffsets) + static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
+      devOffsetsK = static_cast<int8_t *>(devOffsets) +
+                    static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
       devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
     }
     void *devOffsetsS = nullptr;
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
       devOffsetsS = static_cast<int8_t *>(devOffsets) +
-          (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset;
+                    (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
+                        num_bytes_per_ragged_offset;
     }
     const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
     cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
@@ -900,14 +901,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     void *devOffsetsK = nullptr;
     void *devOffsetsV = nullptr;
     if (is_ragged_kv) {
-      devOffsetsK =
-          static_cast<int8_t *>(devOffsets) + static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
+      devOffsetsK = static_cast<int8_t *>(devOffsets) +
+                    static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
       devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
     }
     void *devOffsetsS = nullptr;
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
       devOffsetsS = static_cast<int8_t *>(devOffsets) +
-          (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset;
+                    (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
+                        num_bytes_per_ragged_offset;
     }
     const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
     cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index 83db39426c..4b64a113ab 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -165,16 +165,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(),
-        nullptr);
+        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
+        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
+        mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     nvte_fused_attn_fwd(
         q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
         o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
-        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
+        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
         window_size_right, query_workspace_tensor.data(), nullptr);
   } else {
@@ -271,9 +272,10 @@ static void FusedAttnForwardImpl(
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
+        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
+        is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+        window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
     auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
@@ -281,13 +283,13 @@ static void FusedAttnForwardImpl(
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto k_tensor = TensorWrapper(k, k_shape, dtype);
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-                        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-                        window_size_left, window_size_right, workspace_tensor.data(), stream);
+    nvte_fused_attn_fwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
+        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
+        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
+        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
+        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
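
A note for readers of the fused_attn_f16_arbitrary_seqlen.cu hunks: the expressions being
rewrapped there compute base pointers into a single device buffer that packs the ragged
offset arrays back to back, one section of num_bytes_per_ragged_offset bytes each, with a
Q/O pair first (when Q is ragged), a K/V pair next (when KV is ragged), and the S offsets
last. The sketch below restates that layout arithmetic in isolation. It is a minimal
illustration, not TE source: the helper name compute_section_bases and its signature are
invented for this example, and the Q/O section placement is inferred from the arithmetic
in the patch.

#include <cstddef>
#include <cstdint>

// Hypothetical helper illustrating the packed layout implied by the patch:
//   [offsetsQ | offsetsO | offsetsK | offsetsV | offsetsS]
// where each section is num_bytes_per_ragged_offset bytes and a section pair
// exists only if the corresponding input is ragged.
void compute_section_bases(void *devOffsets, size_t num_bytes_per_ragged_offset,
                           bool is_ragged_q, bool is_ragged_kv, int cudnn_runtime_version,
                           void **devOffsetsQ, void **devOffsetsO, void **devOffsetsK,
                           void **devOffsetsV, void **devOffsetsS) {
  int8_t *base = static_cast<int8_t *>(devOffsets);
  if (is_ragged_q) {
    *devOffsetsQ = base;                                // section 0 (assumed)
    *devOffsetsO = base + num_bytes_per_ragged_offset;  // section 1 (assumed)
  }
  if (is_ragged_kv) {
    // Skip the Q/O pair when present -- the same expression the patch rewraps.
    *devOffsetsK = base + static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
    *devOffsetsV = static_cast<int8_t *>(*devOffsetsK) + num_bytes_per_ragged_offset;
  }
  if (is_ragged_q && cudnn_runtime_version >= 90600) {
    // S offsets follow every preceding pair.
    *devOffsetsS = base + (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) *
                              2 * num_bytes_per_ragged_offset;
  }
}

With both inputs ragged, this places K at byte offset 2 * num_bytes_per_ragged_offset,
V at 3x, and S at 4x, which matches the patched expressions in both the forward and
backward implementations.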