From c036765b1eb611c1b7e8db72676f93b36e9ee36d Mon Sep 17 00:00:00 2001
From: Michael Goldfarb <mgoldfarb@nvidia.com>
Date: Tue, 29 Oct 2024 20:05:20 -0500
Subject: [PATCH] [JAX] Consolidate FFI and old descriptor implementation for
 fused attention. (#1295)

Consolidate the FFI and old descriptor implementations for fused attention. Both the legacy XLA custom-call entry point and the FFI handler now delegate to a shared FusedAttnForwardImpl helper instead of carrying two copies of the forward logic.
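
The change follows the usual pattern of hoisting the shared body into a
static *Impl function that both entry points call. A minimal sketch of that
shape, using hypothetical toy names (DoWork, DoWorkImpl, Descriptor) rather
than the real attention signatures; the helper types (UnpackOpaque,
Buffer_Type, Result_Type, Error_Type, ffi_with_cuda_error_check) are the ones
already used in attention.cpp:

    // Shared body, parameterized on plain values rather than an opaque descriptor.
    static void DoWorkImpl(cudaStream_t stream, void *in, void *out, size_t n, float scale) {
      // ... logic previously duplicated in both entry points ...
    }

    // Legacy XLA custom call: unpack the opaque descriptor, then delegate.
    void DoWork(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
      const auto &desc = *UnpackOpaque<Descriptor>(opaque, opaque_len);
      DoWorkImpl(stream, buffers[0], buffers[1], desc.n, desc.scale);
    }

    // New FFI handler: take typed buffers/attributes, cast, then call the same impl.
    Error_Type DoWorkFFI(cudaStream_t stream, Buffer_Type in, Result_Type out,
                         int64_t n, double scale) {
      DoWorkImpl(stream, in.untyped_data(), out->untyped_data(),
                 static_cast<size_t>(n), static_cast<float>(scale));
      return ffi_with_cuda_error_check();
    }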

Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
---
 .../jax/csrc/extensions/attention.cpp         | 270 +++++-------------
 1 file changed, 72 insertions(+), 198 deletions(-)

diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index e4cc2112c6..541c51da58 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -185,46 +185,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
 }
 
-void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  const CustomCallFusedAttnDescriptor &descriptor =
-      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
-  auto qkv_layout = descriptor.qkv_layout;
+static void FusedAttnForwardImpl(
+    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens,
+    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output,
+    void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
+    size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
+    size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size,
+    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
+    bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
 
-  /* Input buffers from XLA */
-  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
-  void *bias = buffers[3];
-  void *q_cu_seqlens = buffers[4];
-  void *kv_cu_seqlens = buffers[5];
-  void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
-  void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
-  void *seed = buffers[8];
-
-  /* Output buffer from XLA */
-  void *output = buffers[9];
-  void *softmax_aux = buffers[10];
-  void *rng_state = buffers[11];
-  void *workspace = buffers[12];
-
-  /* Descriptor */
-  auto input_batch = descriptor.input_batch;
-  auto bias_batch = descriptor.bias_batch;
-  auto q_max_seqlen = descriptor.q_max_seqlen;
-  auto kv_max_seqlen = descriptor.kv_max_seqlen;
-  auto attn_heads = descriptor.attn_heads;
-  auto num_gqa_groups = descriptor.num_gqa_groups;
-  auto bias_heads = descriptor.bias_heads;
-  auto head_dim = descriptor.head_dim;
-  auto scaling_factor = descriptor.scaling_factor;
-  auto dropout_probability = descriptor.dropout_probability;
-  auto bias_type = descriptor.bias_type;
-  auto mask_type = descriptor.mask_type;
-  auto dtype = descriptor.dtype;
-  auto is_training = descriptor.is_training;
-  auto max_segments_per_seq = descriptor.max_segments_per_seq;
-  auto window_size_left = descriptor.window_size_left;
-  auto window_size_right = descriptor.window_size_right;
-
   /* Input tensors */
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
@@ -247,8 +218,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
       NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
       num_segments = runtime_num_segments_q;
     }
-    cudaMemsetAsync(output, 0,
-                    input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
+    auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
+    cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
   }
 
   auto q_cu_seqlens_tensor =
@@ -281,28 +252,25 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
                                     backend, softmax_aux);
 
   /* cuDNN workspace */
-  auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
-                                        descriptor.wkspace_dtype);
+  auto workspace_tensor =
+      TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);
 
-  /* Call the underly NVTE API */
+  /* Call the underlying NVTE API */
   auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv = buffers[0];
     auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
-    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
-    nvte_fused_attn_fwd_qkvpacked(
-        qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
-        dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
-        workspace_tensor.data(), stream);
+    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
+    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
+                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
+                                  q_max_seqlen, is_training, scaling_factor, dropout_probability,
+                                  qkv_layout, bias_type, mask_type, window_size_left,
+                                  window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto kv = buffers[1];
     auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
-    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
+    auto q_tensor = TensorWrapper(q, q_shape, dtype);
+    auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
@@ -310,14 +278,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
         q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
         bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto k = buffers[1];
     auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-    auto k_tensor = TensorWrapper(k, k_shape, dtype);
-    auto v = buffers[2];
     auto v_shape = k_shape;
+    auto q_tensor = TensorWrapper(q, q_shape, dtype);
+    auto k_tensor = TensorWrapper(k, k_shape, dtype);
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
     nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                         s_tensor.data(), o_tensor.data(), &aux_output_tensors,
@@ -333,6 +298,37 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
   nvte_tensor_pack_destroy(&aux_output_tensors);
 }
 
+void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
+  const CustomCallFusedAttnDescriptor &descriptor =
+      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
+  auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD;
+
+  /* Input buffers from XLA */
+  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
+  void *bias = buffers[3];
+  void *q_cu_seqlens = buffers[4];
+  void *kv_cu_seqlens = buffers[5];
+  void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
+  void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
+  void *seed = buffers[8];
+
+  /* Output buffer from XLA */
+  void *output = buffers[9];
+  void *softmax_aux = buffers[10];
+  void *rng_state = buffers[11];
+  void *workspace = buffers[12];
+
+  FusedAttnForwardImpl(
+      stream, buffers[0], buffers[1], buffers[2], bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets,
+      k_seq_offsets, seed, output, softmax_aux, rng_state, workspace, descriptor.input_batch,
+      descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen,
+      descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
+      descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
+      descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
+      descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
+      descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
+}
+
 Error_Type FusedAttnForwardFFI(
     cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
     Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
@@ -344,147 +340,25 @@ Error_Type FusedAttnForwardFFI(
     double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_,
     int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic,
     int64_t window_size_left, int64_t window_size_right) {
-  /* Descriptor data type conversion */
-  size_t input_batch = static_cast<size_t>(input_batch_);
-  size_t bias_batch = static_cast<size_t>(bias_batch_);
-  size_t q_max_seqlen = static_cast<size_t>(q_max_seqlen_);
-  size_t kv_max_seqlen = static_cast<size_t>(kv_max_seqlen_);
-  size_t attn_heads = static_cast<size_t>(attn_heads_);
-  size_t num_gqa_groups = static_cast<size_t>(num_gqa_groups_);
-  size_t bias_heads = static_cast<size_t>(bias_heads_);
-  size_t head_dim = static_cast<size_t>(head_dim_);
-  size_t max_segments_per_seq = static_cast<size_t>(max_segments_per_seq_);
-  size_t wkspace_size = static_cast<size_t>(wkspace_size_);
-  float scaling_factor = static_cast<float>(scaling_factor_);
-  float dropout_probability = static_cast<float>(dropout_probability_);
-  NVTE_Bias_Type bias_type = static_cast<NVTE_Bias_Type>(bias_type_);
-  NVTE_Mask_Type mask_type = static_cast<NVTE_Mask_Type>(mask_type_);
   NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_);
-  DType dtype = static_cast<DType>(dtype_);
-  DType wkspace_dtype = static_cast<DType>(wkspace_dtype_);
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
 
-  /* Input buffers from XLA */
-  /* q, k, v are parsed later for different qkv_layout */
-  void *bias = bias_buf.untyped_data();
-  void *q_cu_seqlens = q_cu_seqlens_buf.untyped_data();
-  void *kv_cu_seqlens = kv_cu_seqlens_buf.untyped_data();
-  void *q_seq_offsets = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr;
-  void *k_seq_offsets = is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr;
-  void *seed = seed_buf.untyped_data();
-
-  /* Output buffer from XLA */
-  void *output = output_buf->untyped_data();
-  void *softmax_aux = softmax_aux_buf->untyped_data();
-  void *rng_state = rng_state_buf->untyped_data();
-  void *workspace = workspace_buf->untyped_data();
-
-  /* Input tensors */
-  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-  auto v_shape = k_shape;
-  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
-  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
-
-  size_t num_segments = input_batch;  // Non-THD format, input_batch = num_segments
-  if (is_ragged) {
-    auto cudnn_runtime_version = cudnnGetVersion();
-    if (cudnn_runtime_version >= 90300) {
-      num_segments = input_batch * max_segments_per_seq;
-    } else {
-      // workspace can be reused here as it is not used with cuDNN graph at the same time
-      size_t runtime_num_segments_q =
-          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
-      size_t runtime_num_segments_kv =
-          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
-      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
-      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
-      num_segments = runtime_num_segments_q;
-    }
-    auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
-    cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
-  }
-
-  auto q_cu_seqlens_tensor =
-      TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-  auto kv_cu_seqlens_tensor =
-      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-  auto q_seq_offsets_tensor =
-      TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-  auto k_seq_offsets_tensor =
-      TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-
-  /* Output tensors */
-  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
-  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-  auto o_tensor = TensorWrapper(output, o_shape, dtype);
-
-  /* Prepare RNG state */
-  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
-  auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
-  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
-
-  /* Auxiliary tensors (to be propagated to the backward pass later) */
-  NVTETensorPack aux_output_tensors;
-  nvte_tensor_pack_create(&aux_output_tensors);
-  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
-                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
-                                    backend, softmax_aux);
-
-  /* cuDNN workspace */
-  auto workspace_tensor =
-      TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);
-
-  /* Call the underlying NVTE API */
-  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv = q_buf.untyped_data();
-    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
-    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
-    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
-                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
-                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
-                                  q_max_seqlen, is_training, scaling_factor, dropout_probability,
-                                  qkv_layout, bias_type, mask_type, window_size_left,
-                                  window_size_right, workspace_tensor.data(), stream);
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q = q_buf.untyped_data();
-    auto kv = k_buf.untyped_data();
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
-    nvte_fused_attn_fwd_kvpacked(
-        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q = q_buf.untyped_data();
-    auto k = k_buf.untyped_data();
-    auto v = v_buf.untyped_data();
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-    auto v_shape = k_shape;
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto k_tensor = TensorWrapper(k, k_shape, dtype);
-    auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-                        window_size_left, window_size_right, workspace_tensor.data(), stream);
-  } else {
-    NVTE_ERROR("Unsupported qkv_layout.");
-  }
-
-  nvte_tensor_pack_destroy(&aux_output_tensors);
+  FusedAttnForwardImpl(
+      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
+      bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
+      is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
+      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(),
+      output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
+      workspace_buf->untyped_data(), static_cast<size_t>(input_batch_),
+      static_cast<size_t>(bias_batch_), static_cast<size_t>(q_max_seqlen_),
+      static_cast<size_t>(kv_max_seqlen_), static_cast<size_t>(attn_heads_),
+      static_cast<size_t>(num_gqa_groups_), static_cast<size_t>(bias_heads_),
+      static_cast<size_t>(head_dim_), static_cast<size_t>(max_segments_per_seq_),
+      static_cast<size_t>(wkspace_size_), static_cast<float>(scaling_factor_),
+      static_cast<float>(dropout_probability_), static_cast<NVTE_Bias_Type>(bias_type_),
+      static_cast<NVTE_Mask_Type>(mask_type_), static_cast<NVTE_QKV_Layout>(qkv_layout_),
+      static_cast<DType>(dtype_), static_cast<DType>(wkspace_dtype_), is_training, deterministic,
+      window_size_left, window_size_right);
 
   return ffi_with_cuda_error_check();
 }