
Commit 727716a

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a1e0c51 commit 727716a

File tree

1 file changed: +43 −43 lines changed

transformer_engine/jax/csrc/extensions/attention.cpp

Lines changed: 43 additions & 43 deletions
@@ -174,15 +174,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
     auto ragged_offset_tensor =
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-    nvte_fused_attn_fwd(
-        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
-        ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
-        dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
-        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
-        mask_type, softmax_type, window_size_left, window_size_right,
-        query_workspace_tensor.data(), nullptr);
+    nvte_fused_attn_fwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
+        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
+        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
+        ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
+        dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+        window_size_right, query_workspace_tensor.data(), nullptr);
   }
 
   nvte_tensor_pack_destroy(&aux_output_tensors);
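
The change in this hunk is purely mechanical: pre-commit re-wraps the nvte_fused_attn_fwd argument list without altering any argument. Note that every TensorWrapper here is built over a nullptr data pointer: the call runs in query mode, sizing query_workspace_tensor rather than computing attention. A minimal sketch of that shape-only pattern, assuming the TensorWrapper constructor used above (the helper name is illustrative, not part of this file):

    // Sketch: a shape-only TensorWrapper for workspace-size queries.
    // A nullptr data pointer means the wrapper carries shape/dtype metadata
    // only, letting the nvte_* call report its workspace requirement without
    // touching device memory. (Illustrative helper, hypothetical name.)
    #include <transformer_engine/transformer_engine.h>
    #include <vector>

    using transformer_engine::DType;
    using transformer_engine::TensorWrapper;

    TensorWrapper MakeMetaTensor(const std::vector<size_t> &shape, DType dtype) {
      return TensorWrapper(nullptr, shape, dtype);  // metadata only, no buffer
    }

    // Usage, mirroring the loop above:
    //   auto q_cu_seqlens = MakeMetaTensor({num_segments + 1}, DType::kInt32);
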
@@ -270,7 +269,7 @@ static void FusedAttnForwardImpl(
 
   /* Call the underlying NVTE API */
   auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
-
+
   // Prepare Q, K, V pointers and shapes based on layout
   // Python passes dummy tensors for unused slots, so we extract from the actual packed data
   void *q_ptr = q;
@@ -279,15 +278,15 @@ static void FusedAttnForwardImpl(
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
-
+
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
     // QKV packed in q: [batch*seqlen, 3, heads, dim]
     // Python passes: q=packed_qkv, k=dummy, v=dummy
     // Extract K and V pointers from the packed q data
     size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
     q_ptr = q;
-    k_ptr = static_cast<void*>(static_cast<int8_t*>(q) + stride);
-    v_ptr = static_cast<void*>(static_cast<int8_t*>(q) + 2 * stride);
+    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
     // For packed QKV, all have same shape since they're views into the same packed tensor
     k_shape = q_shape;
     v_shape = q_shape;
@@ -298,16 +297,16 @@
     size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
     q_ptr = q;
     k_ptr = k;
-    v_ptr = static_cast<void*>(static_cast<int8_t*>(k) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
     // V has same shape as K since they're views into the same packed tensor
     v_shape = k_shape;
   }
   // else NVTE_HD_HD_HD: pointers and shapes already correct
-
+
   auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
   auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
   auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
-
+
   nvte_fused_attn_fwd(
       q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
       softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
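
The two FusedAttnForwardImpl hunks above only reformat the pointer casts (void* to void *) and trailing whitespace; the packed-layout extraction they touch is unchanged. The arithmetic is worth spelling out: for NVTE_3HD the buffer is [batch*seqlen, 3, heads, dim], so the first K element sits one stride = typeToSize(dtype) * attn_heads * qk_head_dim bytes past q and the first V element two strides past, while the qkv_layout enum tells the kernel how to step between tokens. A standalone sketch of the same math, with hypothetical names (not this file's code):

    #include <cstddef>
    #include <cstdint>

    // Sketch of the packed-layout pointer math above (illustrative only).
    // In a [tokens, 3, heads, dim] interleaved buffer, token 0's Q, K and V
    // blocks sit back to back, each elem_size * heads * dim bytes long. The
    // returned pointers mark the first element of each component; stepping
    // between tokens is the kernel's job, driven by the qkv_layout enum.
    struct QKVPtrs {
      void *q, *k, *v;
    };

    inline QKVPtrs SplitPackedQKV(void *qkv, size_t elem_size, size_t heads,
                                  size_t dim) {
      const size_t stride = elem_size * heads * dim;
      auto *base = static_cast<int8_t *>(qkv);
      return {base, base + stride, base + 2 * stride};
    }

    // NVTE_HD_2HD is analogous with a [tokens, 2, groups, dim] KV buffer:
    //   v_ptr = static_cast<int8_t *>(kv) + elem_size * groups * dim;
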
@@ -454,7 +453,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
       softmax_type == NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX) {
     dummy_d_softmax_offset_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
   }
-
+
   for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
     // the last one is the largest which will be the returned workspace size
     auto q_cu_seqlens_tensor =
@@ -463,7 +462,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
     auto dummy_ragged_offset_tensor =
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-
+
     nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                         doutput_tensor.data(),
                         s_tensor.data(),  // not used for F16
@@ -534,65 +533,66 @@ static void FusedAttnBackwardImpl(
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
-
+
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
     // QKV packed in q: [batch*seqlen, 3, heads, dim]
     size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
     q_ptr = q;
-    k_ptr = static_cast<void*>(static_cast<int8_t*>(q) + stride);
-    v_ptr = static_cast<void*>(static_cast<int8_t*>(q) + 2 * stride);
+    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
     dq_ptr = dq;
-    dk_ptr = static_cast<void*>(static_cast<int8_t*>(dq) + stride);
-    dv_ptr = static_cast<void*>(static_cast<int8_t*>(dq) + 2 * stride);
+    dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
+    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
     k_shape = q_shape;
     v_shape = q_shape;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
     size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
     q_ptr = q;
     k_ptr = k;
-    v_ptr = static_cast<void*>(static_cast<int8_t*>(k) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
     dq_ptr = dq;
     dk_ptr = dk;
-    dv_ptr = static_cast<void*>(static_cast<int8_t*>(dk) + stride);
+    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
     v_shape = k_shape;
   }
-
+
   auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
   auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
   auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
   auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
   auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
   auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);
-
+
   if (is_ragged) {
     cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
     if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
       // For packed QKV, dq contains all gradients, so clear it once with full size
-      cudaMemsetAsync(static_cast<int8_t*>(dq) + transformer_engine::jax::product(q_shape) * typeToSize(dtype),
-                      0, 2 * transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
+      cudaMemsetAsync(
+          static_cast<int8_t *>(dq) + transformer_engine::jax::product(q_shape) * typeToSize(dtype),
+          0, 2 * transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
     } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
       cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
       // For packed KV, dk contains both K and V gradients
-      cudaMemsetAsync(static_cast<int8_t*>(dk) + transformer_engine::jax::product(k_shape) * typeToSize(dtype),
-                      0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
+      cudaMemsetAsync(
+          static_cast<int8_t *>(dk) + transformer_engine::jax::product(k_shape) * typeToSize(dtype),
+          0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
     } else {
       cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
       cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
     }
   }
-
-  nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                      doutput_tensor.data(),
-                      s_tensor.data(),  // not used for F16
-                      s_tensor.data(),  // not used for F16
-                      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                      dbias_tensor.data(), dsoftmax_offset_tensor.data(),
-                      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
-                      kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
-                      mask_type, softmax_type, window_size_left, window_size_right, deterministic,
-                      workspace_tensor.data(), stream);
+
+  nvte_fused_attn_bwd(
+      q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+      doutput_tensor.data(),
+      s_tensor.data(),  // not used for F16
+      s_tensor.data(),  // not used for F16
+      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
+      dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
+      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+      window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream);
 
   nvte_tensor_pack_destroy(&aux_input_tensors);
 }
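
One subtlety in the ragged branch of this last hunk: for NVTE_3HD the dq allocation physically backs dQ, dK and dV together, so the first cudaMemsetAsync clears product(q_shape) * typeToSize(dtype) bytes and the second clears the remaining 2x region; together they zero the full packed gradient buffer before the kernel scatters into it. A minimal sketch of the sizing, with an illustrative helper name:

    #include <cstddef>
    #include <vector>

    // Sketch: bytes zeroed for the packed-QKV gradient buffer (illustrative).
    // One logical tensor occupies product(shape) * elem_size bytes; for
    // NVTE_3HD the packed dq buffer is three times that, cleared above in
    // two cudaMemsetAsync calls (1x for dQ, then 2x for dK and dV).
    inline size_t TensorBytes(const std::vector<size_t> &shape,
                              size_t elem_size) {
      size_t n = 1;
      for (size_t d : shape) n *= d;  // product(shape)
      return n * elem_size;
    }

    // Total cleared for NVTE_3HD:
    //   TensorBytes(q_shape, elem)        // dQ region
    //   + 2 * TensorBytes(q_shape, elem)  // dK and dV regions
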
