@@ -174,15 +174,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
     auto ragged_offset_tensor =
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-    nvte_fused_attn_fwd(
-        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
-        ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
-        dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
-        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
-        mask_type, softmax_type, window_size_left, window_size_right,
-        query_workspace_tensor.data(), nullptr);
+    nvte_fused_attn_fwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
+        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
+        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
+        ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
+        dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+        window_size_right, query_workspace_tensor.data(), nullptr);
   }

   nvte_tensor_pack_destroy(&aux_output_tensors);
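Note on the num_segments + 1 shapes above: a cumulative-sequence-length (cu_seqlens) array is a prefix sum over the per-segment lengths with a leading zero, so N segments always need N + 1 int32 offsets. A minimal illustrative sketch of that relationship, assuming nothing beyond the shape convention visible in the hunk (the helper name below is hypothetical, not part of the patch):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: why the cu_seqlens tensors are sized num_segments + 1.
// The array starts at 0 and each entry adds the next segment length.
std::vector<int32_t> build_cu_seqlens(const std::vector<int32_t> &seqlens) {  // hypothetical helper
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (std::size_t i = 0; i < seqlens.size(); ++i) cu[i + 1] = cu[i] + seqlens[i];
  return cu;
}
// e.g. seqlens {3, 5, 2} -> cu_seqlens {0, 3, 8, 10}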
@@ -270,7 +269,7 @@ static void FusedAttnForwardImpl(
 
   /* Call the underlying NVTE API */
   auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
-
+
   // Prepare Q, K, V pointers and shapes based on layout
   // Python passes dummy tensors for unused slots, so we extract from the actual packed data
   void *q_ptr = q;
@@ -279,15 +278,15 @@ static void FusedAttnForwardImpl(
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
-
+
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
     // QKV packed in q: [batch*seqlen, 3, heads, dim]
     // Python passes: q=packed_qkv, k=dummy, v=dummy
     // Extract K and V pointers from the packed q data
     size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
     q_ptr = q;
-    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
-    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
+    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
     // For packed QKV, all have same shape since they're views into the same packed tensor
     k_shape = q_shape;
     v_shape = q_shape;
@@ -298,16 +297,16 @@ static void FusedAttnForwardImpl(
     size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
     q_ptr = q;
     k_ptr = k;
-    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
     // V has same shape as K since they're views into the same packed tensor
     v_shape = k_shape;
   }
   // else NVTE_HD_HD_HD: pointers and shapes already correct
-
+
   auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
   auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
   auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
-
+
   nvte_fused_attn_fwd(
       q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
       softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
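Aside on the pointer arithmetic in the two hunks above: for the packed layouts, the K and V pointers are just fixed byte offsets into the packed buffer, one heads-by-head_dim slot apart. A self-contained sketch of that computation (the struct and function names here are illustrative; only the stride formula mirrors the patch, with dtype_size standing in for typeToSize(dtype)):

#include <cstddef>

// Byte offsets of the K and V slots inside a QKV buffer packed as
// [batch * seqlen, 3, heads, dim].
struct PackedQkvOffsets {
  std::size_t k_bytes;  // K starts one slot after Q
  std::size_t v_bytes;  // V starts two slots after Q
};

PackedQkvOffsets packed_qkv_offsets(std::size_t heads, std::size_t dim, std::size_t dtype_size) {
  const std::size_t stride = dtype_size * heads * dim;  // size of one Q/K/V slot per token
  return {stride, 2 * stride};
}

// Usage sketch: k_ptr = base + offsets.k_bytes, v_ptr = base + offsets.v_bytes.
// For the HD_2HD (packed KV) case only a single-slot offset is needed:
// v_ptr = k_base + dtype_size * num_gqa_groups * dim, as in the else-if branch above.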
@@ -454,7 +453,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
       softmax_type == NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX) {
     dummy_d_softmax_offset_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
   }
-
+
   for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
     // the last one is the largest which will be the returned workspace size
     auto q_cu_seqlens_tensor =
@@ -463,7 +462,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
     auto dummy_ragged_offset_tensor =
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-
+
     nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                         doutput_tensor.data(),
                         s_tensor.data(),  // not used for F16
@@ -534,65 +533,66 @@ static void FusedAttnBackwardImpl(
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
-
+
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
     // QKV packed in q: [batch*seqlen, 3, heads, dim]
     size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
     q_ptr = q;
-    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
-    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
+    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
     dq_ptr = dq;
-    dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
-    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
+    dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
+    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
     k_shape = q_shape;
     v_shape = q_shape;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
     size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
     q_ptr = q;
     k_ptr = k;
-    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
+    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
     dq_ptr = dq;
     dk_ptr = dk;
-    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
+    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
     v_shape = k_shape;
   }
-
+
   auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
   auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
   auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
   auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
   auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
   auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);
-
+
   if (is_ragged) {
     cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
     if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
       // For packed QKV, dq contains all gradients, so clear it once with full size
-      cudaMemsetAsync(static_cast<int8_t *>(dq) + transformer_engine::jax::product(q_shape) * typeToSize(dtype),
-                      0, 2 * transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
+      cudaMemsetAsync(
+          static_cast<int8_t *>(dq) + transformer_engine::jax::product(q_shape) * typeToSize(dtype),
+          0, 2 * transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
     } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
       cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
       // For packed KV, dk contains both K and V gradients
-      cudaMemsetAsync(static_cast<int8_t *>(dk) + transformer_engine::jax::product(k_shape) * typeToSize(dtype),
-                      0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
+      cudaMemsetAsync(
+          static_cast<int8_t *>(dk) + transformer_engine::jax::product(k_shape) * typeToSize(dtype),
+          0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
     } else {
       cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
       cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
     }
   }
-
-  nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                      doutput_tensor.data(),
-                      s_tensor.data(),  // not used for F16
-                      s_tensor.data(),  // not used for F16
-                      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                      dbias_tensor.data(), dsoftmax_offset_tensor.data(),
-                      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
-                      kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
-                      mask_type, softmax_type, window_size_left, window_size_right, deterministic,
-                      workspace_tensor.data(), stream);
+
+  nvte_fused_attn_bwd(
+      q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+      doutput_tensor.data(),
+      s_tensor.data(),  // not used for F16
+      s_tensor.data(),  // not used for F16
+      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
+      dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
+      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+      window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream);
 
   nvte_tensor_pack_destroy(&aux_input_tensors);
 }
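One more note on the ragged path above: before calling nvte_fused_attn_bwd the gradient buffers are zeroed in full, and for the packed layouts the memsets together have to cover every slot that lives in the shared dq (or dk) buffer. A small sketch of the total byte counts involved (the helper name is illustrative; product and typeToSize stand in for the helpers used in the patch):

#include <cstddef>

// Total bytes cleared per packed gradient buffer in the ragged path.
// NVTE_3HD:    dq holds dQ, dK and dV back to back -> 3 slots are cleared
//              (q_bytes at offset 0, then 2 * q_bytes starting at offset q_bytes).
// NVTE_HD_2HD: dk holds dK and dV back to back -> 2 slots are cleared
//              (k_bytes at offset 0, then k_bytes starting at offset k_bytes).
std::size_t total_cleared_bytes(std::size_t packed_slots, std::size_t slot_elems,
                                std::size_t dtype_size) {
  return packed_slots * slot_elems * dtype_size;
}

// e.g. for NVTE_3HD, total_cleared_bytes(3, product(q_shape), typeToSize(dtype))
// equals the sum of the two cudaMemsetAsync calls on dq in the hunk above.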