[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 25, 2025
1 parent 62cffc8 commit a06d72c
Showing 3 changed files with 26 additions and 22 deletions.
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -25,4 +25,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
 NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || FAIL=1
 
-exit $FAIL
+exit $FAIL
@@ -439,14 +439,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     void *devOffsetsK = nullptr;
     void *devOffsetsV = nullptr;
     if (is_ragged_kv) {
-      devOffsetsK =
-          static_cast<int8_t *>(devOffsets) + static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
+      devOffsetsK = static_cast<int8_t *>(devOffsets) +
+                    static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
       devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
     }
     void *devOffsetsS = nullptr;
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
       devOffsetsS = static_cast<int8_t *>(devOffsets) +
-          (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset;
+                    (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
+                        num_bytes_per_ragged_offset;
     }
     const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
     cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
@@ -900,14 +901,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     void *devOffsetsK = nullptr;
     void *devOffsetsV = nullptr;
     if (is_ragged_kv) {
-      devOffsetsK =
-          static_cast<int8_t *>(devOffsets) + static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
+      devOffsetsK = static_cast<int8_t *>(devOffsets) +
+                    static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
       devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
     }
     void *devOffsetsS = nullptr;
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
       devOffsetsS = static_cast<int8_t *>(devOffsets) +
-          (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset;
+                    (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
+                        num_bytes_per_ragged_offset;
     }
     const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
     cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
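For context, the lines re-wrapped above carve the ragged-offset arrays for K, V, and the softmax stats (S) out of a single devOffsets allocation with int8_t pointer arithmetic: each ragged input contributes a block of 2 * num_bytes_per_ragged_offset bytes, and S sits after whichever blocks are present. A minimal standalone sketch of the same layout logic, using host memory and made-up sizes (illustrative only, not the library's code):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      const bool is_ragged_q = true;
      const bool is_ragged_kv = true;
      const size_t num_bytes_per_ragged_offset = 64;  // hypothetical block size

      // One flat buffer holds every offset array; S gets one extra block at the end.
      const size_t total_bytes =
          (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
              num_bytes_per_ragged_offset +
          num_bytes_per_ragged_offset;
      std::vector<int8_t> buffer(total_bytes);
      void *devOffsets = buffer.data();

      // K and V offsets start after the Q block (present only when Q is ragged).
      void *devOffsetsK = nullptr;
      void *devOffsetsV = nullptr;
      if (is_ragged_kv) {
        devOffsetsK = static_cast<int8_t *>(devOffsets) +
                      static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
        devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
      }

      // S offsets start after both the Q and K/V blocks.
      void *devOffsetsS = static_cast<int8_t *>(devOffsets) +
                          (static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
                              num_bytes_per_ragged_offset;

      printf("K at byte %td, V at byte %td, S at byte %td\n",
             static_cast<int8_t *>(devOffsetsK) - buffer.data(),
             static_cast<int8_t *>(devOffsetsV) - buffer.data(),
             static_cast<int8_t *>(devOffsetsS) - buffer.data());
      return 0;
    }

With both flags set this prints offsets 128, 192, and 256, which is the layout the re-wrapped arithmetic in the diff computes.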
32 changes: 17 additions & 15 deletions transformer_engine/jax/csrc/extensions/attention.cpp
@@ -165,16 +165,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(),
-        nullptr);
+        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
+        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
+        mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     nvte_fused_attn_fwd(
         q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
         o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
-        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
+        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
         window_size_right, query_workspace_tensor.data(), nullptr);
   } else {
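These lines belong to GetFusedAttnForwardWorkspaceSizes, which, judging by the function name and the query_workspace_tensor/nullptr-stream arguments, invokes the attention entry points only to discover how much workspace the real launch will need; FusedAttnForwardImpl below then performs the actual call with a workspace buffer and a stream. A generic sketch of that size-query-then-execute pattern, using hypothetical function names rather than the NVTE API:

    #include <cstddef>
    #include <vector>

    // Hypothetical two-phase API, for illustration only.
    std::size_t fused_op_workspace_bytes() { return 4096; }  // made-up requirement

    void fused_op_run(void *workspace, std::size_t bytes) {
      // ... launch kernels that use `workspace` as scratch space ...
      (void)workspace;
      (void)bytes;
    }

    int main() {
      const std::size_t bytes = fused_op_workspace_bytes();  // phase 1: size query only
      std::vector<char> workspace(bytes);                     // allocate once, up front
      fused_op_run(workspace.data(), bytes);                  // phase 2: actual execution
      return 0;
    }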
@@ -271,23 +272,24 @@ static void FusedAttnForwardImpl(
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
+        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
+        is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+        window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
     auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
     auto v_shape = k_shape;
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto k_tensor = TensorWrapper(k, k_shape, dtype);
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-                        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-                        window_size_left, window_size_right, workspace_tensor.data(), stream);
+    nvte_fused_attn_fwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
+        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
+        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
+        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
+        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
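As an aside, the context lines in this hunk show how the separate-Q/K/V (NVTE_HD_HD_HD) branch flattens batch and sequence into the leading tensor dimension before wrapping the buffers. A small self-contained sketch of just that shape computation, with made-up sizes and the TensorWrapper/NVTE call omitted:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical sizes, chosen only to illustrate the flattening.
      const std::size_t input_batch = 2, q_max_seqlen = 128, kv_max_seqlen = 128;
      const std::size_t attn_heads = 16, num_gqa_groups = 4, head_dim = 64;

      // Batch and sequence collapse into one leading dimension; V reuses K's shape.
      auto q_shape = std::vector<std::size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
      auto k_shape = std::vector<std::size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
      auto v_shape = k_shape;

      printf("q: [%zu, %zu, %zu]\n", q_shape[0], q_shape[1], q_shape[2]);
      printf("k/v: [%zu, %zu, %zu]\n", k_shape[0], k_shape[1], v_shape[2]);
      return 0;
    }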