From 085a452bba95dad1aa3f8e61c4e7e241a18ca59f Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 9 Oct 2025 14:13:15 -0700 Subject: [PATCH] Initialize TE/JAX primitives with stream capture to trigger module loading --- transformer_engine/jax/csrc/extensions.h | 14 +++ .../jax/csrc/extensions/attention.cpp | 69 ++++++++++++ .../jax/csrc/extensions/gemm.cpp | 88 +++++++++++++++ .../jax/csrc/extensions/pybind.cpp | 64 +++++++---- .../jax/csrc/extensions/quantization.cpp | 74 ++++++++++++ .../jax/csrc/extensions/softmax.cpp | 105 ++++++++++++++++++ 6 files changed, 393 insertions(+), 21 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3ce6dee731..8f98d2195c 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -78,10 +78,13 @@ pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_s bool zero_centered_gamma, int sm_margin); // Quantization +XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -90,21 +93,29 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ QuantizeLayout q_layout); // Softmax +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Attention +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, @@ -131,11 +142,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( int64_t window_size_right); // GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); // Cudnn helpers diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 9277569e11..5b953706b4 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -671,5 +671,74 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Attrs(), FFI_CudaGraph_Traits); +Error_Type FusedAttnForwardInitializeFFI( + cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, + Buffer_Type bias_buf, Buffer_Type seed_buf, Buffer_Type q_cu_seqlens_buf, + Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Variadic_Buffer_Type _unused_args, Result_Type output_buf, Result_Type softmax_aux_buf, + Result_Type rng_state_buf, Result_Type workspace_buf, Dictionary attrs) { + return wrapInStreamCapture(std::function(FusedAttnForwardFFI), stream, q_buf, k_buf, v_buf, + bias_buf, seed_buf, q_cu_seqlens_buf, kv_cu_seqlens_buf, + q_seq_offsets_buf, k_seq_offsets_buf, _unused_args, output_buf, + softmax_aux_buf, rng_state_buf, workspace_buf, attrs); +} + +Error_Type FusedAttnBackwardInitializeFFI( + cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, + Buffer_Type bias_buf, Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf, + Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type q_cu_seqlens_buf, + Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Variadic_Buffer_Type _unused_args, Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf, + Result_Type dbias_buf, Result_Type workspace_buf, Dictionary attrs) { + return wrapInStreamCapture(std::function(FusedAttnBackwardFFI), stream, q_buf, k_buf, v_buf, + bias_buf, softmax_aux_buf, rng_state_buf, output_buf, doutput_buf, + q_cu_seqlens_buf, kv_cu_seqlens_buf, q_seq_offsets_buf, + k_seq_offsets_buf, _unused_args, dq_buf, dk_buf, dv_buf, dbias_buf, + workspace_buf, attrs); +} + +// FFI Handler Symbols for Initialization +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardInitializeHandler, FusedAttnForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // bias + .Arg() // seed_buf + .Arg() // q_cu_seqlens + .Arg() // kv_cu_seqlens + .Arg() // q_seq_offsets + .Arg() // k_seq_offsets + .RemainingArgs() // _cp_aux_args unused + .Ret() // output + .Ret() // softmax_aux + .Ret() // rng_state + .Ret() // workspace + .Attrs()); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardInitializeHandler, FusedAttnBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // bias + .Arg() // softmax_aux + .Arg() // rng_state + .Arg() // output + .Arg() // doutput + .Arg() // q_cu_seqlens + .Arg() // kv_cu_seqlens + .Arg() // q_seq_offsets + .Arg() // k_seq_offsets + .RemainingArgs() // _cp_aux_args unused + .Ret() // dq + .Ret() // dk + .Ret() // dv + .Ret() // dbias + .Ret() // workspace + .Attrs()); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 993ec1377d..8db0eef61b 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -742,5 +742,93 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("is_grouped_dense_wgrad") .Attr("use_async_d2h_group_sizes")); +Error_Type GemmInitializeFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator, + JAXX_Collective_Op collective_op) { + return wrapInStreamCapture(std::function(GemmFFI), stream, lhs, lhs_scale_inv, rhs, rhs_scale_inv, + bias, gelu_input, output, bias_grad, pre_gelu_out, workspace, + scaling_mode, lhs_axis_boundary, rhs_axis_boundary, lhs_transposed, + rhs_transposed, fuse_bias, fuse_gelu, grad, use_split_accumulator, + collective_op); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmInitializeHandler, GemmInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + .Attr("use_split_accumulator") + .Attr("collective_op")); + +Error_Type GroupedGemmD2HGroupSizesInitializeFFI(cudaStream_t stream, Buffer_Type group_sizes, + Result_Type dummy_output, size_t num_gemms) { + return wrapInStreamCapture(std::function(GroupedGemmD2HGroupSizesFFI), stream, group_sizes, + dummy_output, num_gemms); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesInitializeHandler, + GroupedGemmD2HGroupSizesInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // group_sizes + .Ret() // dummy_output + .Attr("num_gemms")); + +Error_Type GroupedGemmInitializeFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type group_sizes, Buffer_Type group_offset, + Result_Type output, Result_Type workspace, size_t m, size_t n, + size_t k, bool lhs_is_trans, bool rhs_is_trans, + JAXX_Scaling_Mode scaling_mode, bool has_bias, + bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + return wrapInStreamCapture(std::function(GroupedGemmFFI), stream, lhs_data, lhs_sinv, rhs_data, + rhs_sinv, bias, group_sizes, group_offset, output, workspace, m, n, k, + lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, + is_grouped_dense_wgrad, use_async_d2h_group_sizes); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmInitializeHandler, GroupedGemmInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs_data + .Arg() // lhs_sinv + .Arg() // rhs_data + .Arg() // rhs_sinv + .Arg() // bias + .Arg() // group_sizes + .Arg() // group_offset + .Ret() // output + .Ret() // workspace + .Attr("M") + .Attr("N") + .Attr("K") + .Attr("lhs_is_trans") + .Attr("rhs_is_trans") + .Attr("scaling_mode") + .Attr("has_bias") + .Attr("is_grouped_dense_wgrad") + .Attr("use_async_d2h_group_sizes")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index f6b1acd439..dbbe152a48 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -30,20 +30,37 @@ pybind11::dict Registrations() { pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler)); // Quantization - dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); - dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler); - dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); + dict["te_dbias_quantize_ffi"] = + pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(DBiasQuantizeInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(DBiasQuantizeHandler)); + dict["te_grouped_quantize_ffi"] = + pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(GroupedQuantizeInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedQuantizeHandler)); + dict["te_dequantize_ffi"] = + pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(DequantizeInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(DequantizeHandler)); // Softmax - dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler); - dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler); - dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler); - dict["te_scaled_masked_softmax_backward_ffi"] = - EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler); - dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = - EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler); - dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = - EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + dict["te_scaled_softmax_forward_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(ScaledSoftmaxForwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ScaledSoftmaxForwardHandler)); + dict["te_scaled_softmax_backward_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(ScaledSoftmaxBackwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ScaledSoftmaxBackwardHandler)); + dict["te_scaled_masked_softmax_forward_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(ScaledMaskedSoftmaxForwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler)); + dict["te_scaled_masked_softmax_backward_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(ScaledMaskedSoftmaxBackwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler)); + dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = pybind11::dict( + pybind11::arg("initialize") = + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler)); + dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = pybind11::dict( + pybind11::arg("initialize") = + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler)); // Normalization dict["te_norm_forward_ffi"] = @@ -56,24 +73,29 @@ pybind11::dict Registrations() { pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler)); // Attention - dict["te_fused_attn_forward_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(FusedAttnForwardHandler)); - dict["te_fused_attn_backward_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + dict["te_fused_attn_forward_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(FusedAttnForwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnForwardHandler)); + dict["te_fused_attn_backward_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(FusedAttnBackwardInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); // GEMM dict["te_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(GemmInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM - dict["te_grouped_gemm_d2h_group_sizes_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler)); + dict["te_grouped_gemm_d2h_group_sizes_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(GroupedGemmD2HGroupSizesInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler)); dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(GroupedGemmInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); return dict; diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 05260741b6..0ba94998d6 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -199,6 +199,38 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Attr("flatten_axis"), FFI_CudaGraph_Traits); +Error_Type DBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type scale_buf, Buffer_Type amax_buf, + Result_Type output_buf, Result_Type output_trans_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t quantize_layout_enum, bool is_dbias, + int64_t flatten_axis) { + return wrapInStreamCapture(std::function(DBiasQuantizeFFI), stream, input_buf, scale_buf, + amax_buf, output_buf, output_trans_buf, scale_inv_buf, + colwise_scale_inv_buf, updated_amax_buf, dbias_buf, workspace_buf, + scaling_mode, quantize_layout_enum, is_dbias, flatten_axis); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeInitializeHandler, DBiasQuantizeInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Arg() // amax + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Ret() // dbias + .Ret() // wkspace + .Attr("scaling_mode") + .Attr("q_layout") + .Attr("is_dbias") + .Attr("flatten_axis")); + Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); @@ -230,6 +262,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI, .Ret(), // output FFI_CudaGraph_Traits); +Error_Type DequantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, + Result_Type output_buf) { + return wrapInStreamCapture(std::function(DequantizeFFI), stream, input_buf, amax_buf, scale_buf, + scale_inv_buf, output_buf); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeInitializeHandler, DequantizeInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret()); // output + Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales, Buffer_Type group_sizes, Result_Type outputs, Result_Type colwise_outputs, Result_Type scale_invs, @@ -415,5 +463,31 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Attr("q_layout") .Attr("flatten_axis")); +Error_Type GroupedQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales, + Buffer_Type group_sizes, Result_Type outputs, + Result_Type colwise_outputs, Result_Type scale_invs, + Result_Type colwise_scale_invs, Result_Type amaxs, + JAXX_Scaling_Mode scaling_mode, + int64_t quantize_layout_enum, int64_t flatten_axis) { + return wrapInStreamCapture(std::function(GroupedQuantizeFFI), stream, inputs, scales, group_sizes, + outputs, colwise_outputs, scale_invs, colwise_scale_invs, amaxs, + scaling_mode, quantize_layout_enum, flatten_axis); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeInitializeHandler, GroupedQuantizeInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Arg() // group_sizes + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("scaling_mode") + .Attr("q_layout") + .Attr("flatten_axis")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index ee3e5b35e8..702dac4f76 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -155,3 +155,108 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler, } // namespace jax } // namespace transformer_engine +// FFI Handler Initialization Wrappers for Softmax +namespace transformer_engine { +namespace jax { + +Error_Type ScaledSoftmaxForwardInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type output_buf, double scale_factor_) { + return wrapInStreamCapture(std::function(ScaledSoftmaxForwardFFI), stream, input_buf, output_buf, + scale_factor_); +} + +Error_Type ScaledMaskedSoftmaxForwardInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type mask_buf, Result_Type output_buf, + double scale_factor_) { + return wrapInStreamCapture(std::function(ScaledMaskedSoftmaxForwardFFI), stream, input_buf, + mask_buf, output_buf, scale_factor_); +} + +Error_Type ScaledUpperTriangMaskedSoftmaxForwardInitializeFFI(cudaStream_t stream, + Buffer_Type input_buf, + Result_Type output_buf, + double scale_factor_) { + return wrapInStreamCapture(std::function(ScaledUpperTriangMaskedSoftmaxForwardFFI), stream, + input_buf, output_buf, scale_factor_); +} + +Error_Type ScaledSoftmaxBackwardInitializeFFI(cudaStream_t stream, Buffer_Type grad_output_buf, + Buffer_Type softmax_output_buf, Result_Type dgrad_buf, + double scale_factor_) { + return wrapInStreamCapture(std::function(ScaledSoftmaxBackwardFFI), stream, grad_output_buf, + softmax_output_buf, dgrad_buf, scale_factor_); +} + +Error_Type ScaledUpperTriangMaskedSoftmaxBackwardInitializeFFI(cudaStream_t stream, + Buffer_Type grad_output_buf, + Buffer_Type softmax_output_buf, + Result_Type dgrad_buf, + double scale_factor_) { + return wrapInStreamCapture(std::function(ScaledUpperTriangMaskedSoftmaxBackwardFFI), stream, + grad_output_buf, softmax_output_buf, dgrad_buf, scale_factor_); +} + +// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax +Error_Type ScaledMaskedSoftmaxBackwardInitializeFFI(cudaStream_t stream, + Buffer_Type grad_output_buf, + Buffer_Type softmax_output_buf, + Result_Type dgrad_buf, double scale_factor_) { + return wrapInStreamCapture(std::function(ScaledSoftmaxBackwardFFI), stream, grad_output_buf, + softmax_output_buf, dgrad_buf, scale_factor_); +} + +// FFI Handler Symbols for Initialization +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardInitializeHandler, + ScaledSoftmaxForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("scale_factor")); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardInitializeHandler, + ScaledMaskedSoftmaxForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // mask + .Ret() // output + .Attr("scale_factor")); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardInitializeHandler, + ScaledUpperTriangMaskedSoftmaxForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("scale_factor")); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardInitializeHandler, + ScaledSoftmaxBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor")); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardInitializeHandler, + ScaledMaskedSoftmaxBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor")); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardInitializeHandler, + ScaledUpperTriangMaskedSoftmaxBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor")); + +} // namespace jax +} // namespace transformer_engine