From c036765b1eb611c1b7e8db72676f93b36e9ee36d Mon Sep 17 00:00:00 2001
From: Michael Goldfarb <mgoldfarb@nvidia.com>
Date: Tue, 29 Oct 2024 20:05:20 -0500
Subject: [PATCH] [JAX] Consolidate FFI and old descriptor implementation for fused attention. (#1295)

Consolidate FFI and old descriptor implementation for fused attention.

Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
---
 .../jax/csrc/extensions/attention.cpp | 270 +++++-------------
 1 file changed, 72 insertions(+), 198 deletions(-)

diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index e4cc2112c6..541c51da58 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -185,46 +185,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
 }
 
-void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  const CustomCallFusedAttnDescriptor &descriptor =
-      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
-  auto qkv_layout = descriptor.qkv_layout;
+static void FusedAttnForwardImpl(
+    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens,
+    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output,
+    void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
+    size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
+    size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size,
+    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
+    bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
 
-  /* Input buffers from XLA */
-  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
-  void *bias = buffers[3];
-  void *q_cu_seqlens = buffers[4];
-  void *kv_cu_seqlens = buffers[5];
-  void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
-  void *k_seq_offsets = is_ragged ? 
buffers[7] : nullptr; - void *seed = buffers[8]; - - /* Output buffer from XLA */ - void *output = buffers[9]; - void *softmax_aux = buffers[10]; - void *rng_state = buffers[11]; - void *workspace = buffers[12]; - - /* Descriptor */ - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - auto dtype = descriptor.dtype; - auto is_training = descriptor.is_training; - auto max_segments_per_seq = descriptor.max_segments_per_seq; - auto window_size_left = descriptor.window_size_left; - auto window_size_right = descriptor.window_size_right; - /* Input tensors */ auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; @@ -247,8 +218,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); num_segments = runtime_num_segments_q; } - cudaMemsetAsync(output, 0, - input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream); + auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; + cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); } auto q_cu_seqlens_tensor = @@ -281,28 +252,25 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s backend, softmax_aux); /* cuDNN workspace */ - auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size}, - descriptor.wkspace_dtype); + auto workspace_tensor = + TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype); - /* Call the underly NVTE API */ + /* Call the underlying NVTE API */ auto layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv = buffers[0]; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - workspace_tensor.data(), stream); + auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); + nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, is_training, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q = buffers[0]; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv 
= buffers[1]; auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv_tensor = TensorWrapper(k, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), @@ -310,14 +278,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q = buffers[0]; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k = buffers[1]; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v = buffers[2]; auto v_shape = k_shape; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, @@ -333,6 +298,37 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s nvte_tensor_pack_destroy(&aux_output_tensors); } +void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); + auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD; + + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ + void *bias = buffers[3]; + void *q_cu_seqlens = buffers[4]; + void *kv_cu_seqlens = buffers[5]; + void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; + void *k_seq_offsets = is_ragged ? 
buffers[7] : nullptr; + void *seed = buffers[8]; + + /* Output buffer from XLA */ + void *output = buffers[9]; + void *softmax_aux = buffers[10]; + void *rng_state = buffers[11]; + void *workspace = buffers[12]; + + FusedAttnForwardImpl( + stream, buffers[0], buffers[1], buffers[2], bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, + k_seq_offsets, seed, output, softmax_aux, rng_state, workspace, descriptor.input_batch, + descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, + descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, + descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor, + descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type, + descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training, + descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right); +} + Error_Type FusedAttnForwardFFI( cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, @@ -344,147 +340,25 @@ Error_Type FusedAttnForwardFFI( double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_, int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { - /* Descriptor data type conversion */ - size_t input_batch = static_cast<size_t>(input_batch_); - size_t bias_batch = static_cast<size_t>(bias_batch_); - size_t q_max_seqlen = static_cast<size_t>(q_max_seqlen_); - size_t kv_max_seqlen = static_cast<size_t>(kv_max_seqlen_); - size_t attn_heads = static_cast<size_t>(attn_heads_); - size_t num_gqa_groups = static_cast<size_t>(num_gqa_groups_); - size_t bias_heads = static_cast<size_t>(bias_heads_); - size_t head_dim = static_cast<size_t>(head_dim_); - size_t max_segments_per_seq = static_cast<size_t>(max_segments_per_seq_); - size_t wkspace_size = static_cast<size_t>(wkspace_size_); - float scaling_factor = static_cast<float>(scaling_factor_); - float dropout_probability = static_cast<float>(dropout_probability_); - NVTE_Bias_Type bias_type = static_cast<NVTE_Bias_Type>(bias_type_); - NVTE_Mask_Type mask_type = static_cast<NVTE_Mask_Type>(mask_type_); NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_); - DType dtype = static_cast<DType>(dtype_); - DType wkspace_dtype = static_cast<DType>(wkspace_dtype_); auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; - /* Input buffers from XLA */ - /* q, k, v are parsed later for different qkv_layout */ - void *bias = bias_buf.untyped_data(); - void *q_cu_seqlens = q_cu_seqlens_buf.untyped_data(); - void *kv_cu_seqlens = kv_cu_seqlens_buf.untyped_data(); - void *q_seq_offsets = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr; - void *k_seq_offsets = is_ragged ? 
k_seq_offsets_buf.untyped_data() : nullptr; - void *seed = seed_buf.untyped_data(); - - /* Output buffer from XLA */ - void *output = output_buf->untyped_data(); - void *softmax_aux = softmax_aux_buf->untyped_data(); - void *rng_state = rng_state_buf->untyped_data(); - void *workspace = workspace_buf->untyped_data(); - - /* Input tensors */ - auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; - auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - - size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments - if (is_ragged) { - auto cudnn_runtime_version = cudnnGetVersion(); - if (cudnn_runtime_version >= 90300) { - num_segments = input_batch * max_segments_per_seq; - } else { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; - } - auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; - cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); - } - - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32); - auto q_seq_offsets_tensor = - TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32); - auto k_seq_offsets_tensor = - TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32); - - /* Output tensors */ - auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16 - auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto o_tensor = TensorWrapper(output, o_shape, dtype); - - /* Prepare RNG state */ - auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); - auto backend = nvte_get_fused_attn_backend( - static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, window_size_left, window_size_right); - PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - - /* Auxiliary tensors (to be propagated to the backward pass later) */ - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads, - bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type, - backend, softmax_aux); - - /* cuDNN workspace */ - auto workspace_tensor = - TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype); - - /* Call the underlying NVTE API */ - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv = q_buf.untyped_data(); - auto qkv_shape = std::vector<size_t>{input_batch 
* q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, is_training, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q = q_buf.untyped_data(); - auto kv = k_buf.untyped_data(); - auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q = q_buf.untyped_data(); - auto k = k_buf.untyped_data(); - auto v = v_buf.untyped_data(); - auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); - } - - nvte_tensor_pack_destroy(&aux_output_tensors); + FusedAttnForwardImpl( + stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), + bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(), + is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? 
k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(), + output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), + workspace_buf->untyped_data(), static_cast<size_t>(input_batch_), + static_cast<size_t>(bias_batch_), static_cast<size_t>(q_max_seqlen_), + static_cast<size_t>(kv_max_seqlen_), static_cast<size_t>(attn_heads_), + static_cast<size_t>(num_gqa_groups_), static_cast<size_t>(bias_heads_), + static_cast<size_t>(head_dim_), static_cast<size_t>(max_segments_per_seq_), + static_cast<size_t>(wkspace_size_), static_cast<float>(scaling_factor_), + static_cast<float>(dropout_probability_), static_cast<NVTE_Bias_Type>(bias_type_), + static_cast<NVTE_Mask_Type>(mask_type_), static_cast<NVTE_QKV_Layout>(qkv_layout_), + static_cast<DType>(dtype_), static_cast<DType>(wkspace_dtype_), is_training, deterministic, + window_size_left, window_size_right); return ffi_with_cuda_error_check(); }
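
Note on the pattern (illustrative, not part of the applied diff): the consolidation factors the body of the fused-attention forward call into a single FusedAttnForwardImpl() that takes raw device pointers plus scalar parameters, so the legacy XLA custom call (which unpacks a CustomCallFusedAttnDescriptor from its opaque blob) and the new FFI handler (which receives typed buffers and attributes) both shrink to thin argument-forwarding wrappers. The sketch below shows only the shape of that refactor, with simplified, hypothetical names (Descriptor, Impl, LegacyEntry, FfiEntry) and no CUDA or XLA dependencies; it is not the TransformerEngine code itself.

// consolidation_pattern.cc -- minimal sketch of the "shared Impl + thin wrappers"
// refactor in this patch. All names and types are simplified placeholders.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for CustomCallFusedAttnDescriptor: scalars packed into an opaque blob.
struct Descriptor {
  std::int64_t batch;
  float scale;
  bool is_training;
};

// Stand-in for UnpackOpaque<T>(): reinterpret the opaque byte string as a T.
template <typename T>
const T *UnpackOpaque(const char *opaque, std::size_t opaque_len) {
  return opaque_len == sizeof(T) ? reinterpret_cast<const T *>(opaque) : nullptr;
}

// The single implementation that both entry points call (mirrors FusedAttnForwardImpl).
static void Impl(void *q, void *out, std::int64_t batch, float scale, bool is_training) {
  (void)q;
  (void)out;  // a real implementation would launch the fused-attention kernel here
  std::cout << "Impl: batch=" << batch << " scale=" << scale
            << " training=" << is_training << "\n";
}

// Legacy-style entry point: raw buffers plus an opaque descriptor (mirrors FusedAttnForward).
void LegacyEntry(void **buffers, const char *opaque, std::size_t opaque_len) {
  const Descriptor *d = UnpackOpaque<Descriptor>(opaque, opaque_len);
  Impl(buffers[0], buffers[1], d->batch, d->scale, d->is_training);
}

// FFI-style entry point: typed buffers and attributes arrive as explicit arguments
// (mirrors FusedAttnForwardFFI); it only casts and forwards.
void FfiEntry(void *q, void *out, std::int64_t batch_attr, double scale_attr, bool is_training) {
  Impl(q, out, batch_attr, static_cast<float>(scale_attr), is_training);
}

int main() {
  std::vector<float> q(16), out(16);
  void *buffers[] = {q.data(), out.data()};

  Descriptor d{8, 0.125f, true};
  LegacyEntry(buffers, reinterpret_cast<const char *>(&d), sizeof(d));  // descriptor path
  FfiEntry(q.data(), out.data(), 8, 0.125, true);                       // FFI path
  return 0;
}

Keeping one Impl means the attention call, ragged-offset handling, and workspace setup exist in a single place, and either entry point can later be removed without touching that logic.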