Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions transformer_engine/jax/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,13 @@ pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_s
bool zero_centered_gamma, int sm_margin);

// Quantization
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);

pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
Expand All @@ -90,21 +93,29 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
QuantizeLayout q_layout);

// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Attention
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
Expand All @@ -131,11 +142,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
int64_t window_size_right);

// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);

// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);

// Cudnn helpers
Expand Down
69 changes: 69 additions & 0 deletions transformer_engine/jax/csrc/extensions/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,5 +671,74 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
.Attrs(),
FFI_CudaGraph_Traits);

Error_Type FusedAttnForwardInitializeFFI(
    cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
    Buffer_Type bias_buf, Buffer_Type seed_buf, Buffer_Type q_cu_seqlens_buf,
    Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
    Variadic_Buffer_Type _unused_args, Result_Type output_buf, Result_Type softmax_aux_buf,
    Result_Type rng_state_buf, Result_Type workspace_buf, Dictionary attrs) {
  // Initialization-stage entry point: run the regular forward FFI body under
  // stream capture (presumably so its kernel launches can be recorded for CUDA
  // graph replay -- see wrapInStreamCapture). All arguments pass through untouched.
  auto forward_body = std::function(FusedAttnForwardFFI);
  return wrapInStreamCapture(forward_body, stream, q_buf, k_buf, v_buf, bias_buf, seed_buf,
                             q_cu_seqlens_buf, kv_cu_seqlens_buf, q_seq_offsets_buf,
                             k_seq_offsets_buf, _unused_args, output_buf, softmax_aux_buf,
                             rng_state_buf, workspace_buf, attrs);
}

Error_Type FusedAttnBackwardInitializeFFI(
    cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
    Buffer_Type bias_buf, Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
    Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type q_cu_seqlens_buf,
    Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
    Variadic_Buffer_Type _unused_args, Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf,
    Result_Type dbias_buf, Result_Type workspace_buf, Dictionary attrs) {
  // Initialization-stage entry point for the attention backward pass: forwards
  // every buffer/attr verbatim to FusedAttnBackwardFFI, executed under stream
  // capture (presumably for CUDA graph recording -- see wrapInStreamCapture).
  auto backward_body = std::function(FusedAttnBackwardFFI);
  return wrapInStreamCapture(backward_body, stream, q_buf, k_buf, v_buf, bias_buf,
                             softmax_aux_buf, rng_state_buf, output_buf, doutput_buf,
                             q_cu_seqlens_buf, kv_cu_seqlens_buf, q_seq_offsets_buf,
                             k_seq_offsets_buf, _unused_args, dq_buf, dk_buf, dv_buf, dbias_buf,
                             workspace_buf, attrs);
}

// FFI Handler Symbols for Initialization
// Registered under the XLA FFI initialize stage (FFI::Bind<FFI_Initialize>);
// the argument/result layout below must stay in sync with the execute-stage
// binding of FusedAttnForwardHandler, since the same buffers are passed through.
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardInitializeHandler, FusedAttnForwardInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // q
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // seed_buf
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
                                  .Arg<Buffer_Type>()      // q_seq_offsets
                                  .Arg<Buffer_Type>()      // k_seq_offsets
                                  .RemainingArgs()         // _cp_aux_args unused
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // softmax_aux
                                  .Ret<Buffer_Type>()      // rng_state
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attrs());

// Initialize-stage binding for the fused-attention backward pass; mirrors the
// execute-stage FusedAttnBackwardHandler signature argument-for-argument.
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardInitializeHandler, FusedAttnBackwardInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // q
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // softmax_aux
                                  .Arg<Buffer_Type>()      // rng_state
                                  .Arg<Buffer_Type>()      // output
                                  .Arg<Buffer_Type>()      // doutput
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
                                  .Arg<Buffer_Type>()      // q_seq_offsets
                                  .Arg<Buffer_Type>()      // k_seq_offsets
                                  .RemainingArgs()         // _cp_aux_args unused
                                  .Ret<Buffer_Type>()      // dq
                                  .Ret<Buffer_Type>()      // dk
                                  .Ret<Buffer_Type>()      // dv
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attrs());

} // namespace jax
} // namespace transformer_engine
88 changes: 88 additions & 0 deletions transformer_engine/jax/csrc/extensions/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,5 +742,93 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Attr<bool>("is_grouped_dense_wgrad")
.Attr<bool>("use_async_d2h_group_sizes"));

Error_Type GemmInitializeFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv,
                             Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias,
                             Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad,
                             Result_Type pre_gelu_out, Result_Type workspace,
                             JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
                             int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
                             bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator,
                             JAXX_Collective_Op collective_op) {
  // Initialization-stage GEMM entry point: delegates to the regular GemmFFI
  // body with an unchanged argument list, executed under stream capture
  // (presumably for CUDA graph recording -- see wrapInStreamCapture).
  auto gemm_body = std::function(GemmFFI);
  return wrapInStreamCapture(gemm_body, stream, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias,
                             gelu_input, output, bias_grad, pre_gelu_out, workspace, scaling_mode,
                             lhs_axis_boundary, rhs_axis_boundary, lhs_transposed, rhs_transposed,
                             fuse_bias, fuse_gelu, grad, use_split_accumulator, collective_op);
}

// Initialize-stage binding for the GEMM custom call; must stay in sync with
// the execute-stage GemmHandler binding (same buffers and attributes).
XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmInitializeHandler, GemmInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // lhs
                                  .Arg<Buffer_Type>()      // lhs_scale_inv
                                  .Arg<Buffer_Type>()      // rhs
                                  .Arg<Buffer_Type>()      // rhs_scale_inv
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // gelu_input
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // bias_grad
                                  .Ret<Buffer_Type>()      // pre_gelu_out
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("lhs_axis_boundary")
                                  .Attr<int64_t>("rhs_axis_boundary")
                                  .Attr<bool>("lhs_transposed")
                                  .Attr<bool>("rhs_transposed")
                                  .Attr<bool>("fuse_bias")
                                  .Attr<bool>("fuse_gelu")
                                  .Attr<bool>("grad")
                                  .Attr<bool>("use_split_accumulator")
                                  .Attr<JAXX_Collective_Op>("collective_op"));

Error_Type GroupedGemmD2HGroupSizesInitializeFFI(cudaStream_t stream, Buffer_Type group_sizes,
                                                 Result_Type dummy_output, size_t num_gemms) {
  // Initialization-stage wrapper for the device-to-host group-size copy:
  // same arguments, run under stream capture (see wrapInStreamCapture).
  auto d2h_body = std::function(GroupedGemmD2HGroupSizesFFI);
  return wrapInStreamCapture(d2h_body, stream, group_sizes, dummy_output, num_gemms);
}

// Initialize-stage binding for the grouped-GEMM D2H group-size copy; mirrors
// the execute-stage GroupedGemmD2HGroupSizesHandler binding.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesInitializeHandler,
                              GroupedGemmD2HGroupSizesInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // group_sizes
                                  .Ret<Buffer_Type>()      // dummy_output
                                  .Attr<int64_t>("num_gemms"));

Error_Type GroupedGemmInitializeFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                                    Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
                                    Buffer_Type group_sizes, Buffer_Type group_offset,
                                    Result_Type output, Result_Type workspace, size_t m, size_t n,
                                    size_t k, bool lhs_is_trans, bool rhs_is_trans,
                                    JAXX_Scaling_Mode scaling_mode, bool has_bias,
                                    bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) {
  // Initialization-stage grouped-GEMM entry point: passes every argument
  // through to GroupedGemmFFI, executed under stream capture (presumably for
  // CUDA graph recording -- see wrapInStreamCapture).
  auto grouped_gemm_body = std::function(GroupedGemmFFI);
  return wrapInStreamCapture(grouped_gemm_body, stream, lhs_data, lhs_sinv, rhs_data, rhs_sinv,
                             bias, group_sizes, group_offset, output, workspace, m, n, k,
                             lhs_is_trans, rhs_is_trans, scaling_mode, has_bias,
                             is_grouped_dense_wgrad, use_async_d2h_group_sizes);
}

// Initialize-stage binding for grouped GEMM; mirrors the execute-stage
// GroupedGemmHandler binding. Note the attrs are declared int64_t ("M"/"N"/"K")
// while the FFI function takes size_t -- this matches the execute-path
// convention used by GroupedGemmFFI above.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmInitializeHandler, GroupedGemmInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // lhs_data
                                  .Arg<Buffer_Type>()      // lhs_sinv
                                  .Arg<Buffer_Type>()      // rhs_data
                                  .Arg<Buffer_Type>()      // rhs_sinv
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // group_sizes
                                  .Arg<Buffer_Type>()      // group_offset
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("M")
                                  .Attr<int64_t>("N")
                                  .Attr<int64_t>("K")
                                  .Attr<bool>("lhs_is_trans")
                                  .Attr<bool>("rhs_is_trans")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<bool>("has_bias")
                                  .Attr<bool>("is_grouped_dense_wgrad")
                                  .Attr<bool>("use_async_d2h_group_sizes"));

} // namespace jax
} // namespace transformer_engine
64 changes: 43 additions & 21 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,37 @@ pybind11::dict Registrations() {
pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler));

// Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
dict["te_dbias_quantize_ffi"] =
pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(DBiasQuantizeInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(DBiasQuantizeHandler));
dict["te_grouped_quantize_ffi"] =
pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(GroupedQuantizeInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedQuantizeHandler));
dict["te_dequantize_ffi"] =
pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(DequantizeInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(DequantizeHandler));

// Softmax
dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
dict["te_scaled_masked_softmax_backward_ffi"] =
EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
dict["te_scaled_softmax_forward_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(ScaledSoftmaxForwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ScaledSoftmaxForwardHandler));
dict["te_scaled_softmax_backward_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(ScaledSoftmaxBackwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ScaledSoftmaxBackwardHandler));
dict["te_scaled_masked_softmax_forward_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(ScaledMaskedSoftmaxForwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler));
dict["te_scaled_masked_softmax_backward_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(ScaledMaskedSoftmaxBackwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler));
dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = pybind11::dict(
pybind11::arg("initialize") =
EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler));
dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = pybind11::dict(
pybind11::arg("initialize") =
EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler));

// Normalization
dict["te_norm_forward_ffi"] =
Expand All @@ -56,24 +73,29 @@ pybind11::dict Registrations() {
pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));

// Attention
dict["te_fused_attn_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnForwardHandler));
dict["te_fused_attn_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));
dict["te_fused_attn_forward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(FusedAttnForwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnForwardHandler));
dict["te_fused_attn_backward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(FusedAttnBackwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));

// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(GemmInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));

// Grouped GEMM
dict["te_grouped_gemm_d2h_group_sizes_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
dict["te_grouped_gemm_d2h_group_sizes_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(GroupedGemmD2HGroupSizesInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(GroupedGemmInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));

return dict;
Expand Down
Loading
Loading