@@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
28
28
}
29
29
30
30
std::tuple<TensorWrapper, std::vector<size_t >> xla_buffer_to_nvte_gemm_operand (
31
- cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv ,
32
- JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
31
+ cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode ,
32
+ size_t axis_boundary, bool rowwise) {
33
33
// Set tensor data with collapsed 2D shape
34
34
auto buffer_dims = buffer.dimensions ();
35
35
std::vector<size_t > input_shape = {product (buffer_dims, 0 , axis_boundary),
@@ -61,40 +61,6 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
61
61
} else {
62
62
input.set_columnwise_scale_inv (scale_inv.untyped_data (), scale_dtype, scale_shape);
63
63
}
64
-
65
- // Swizzle scaling factors for MXFP8
66
- if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
67
- // Get the swizzle buffer
68
- NVTE_CHECK (swizzled_scale_inv->element_count () > 0 ,
69
- " Missing swizzled inverse scale buffer in the JAX primitive." );
70
- auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype (scale_inv.element_type ());
71
- auto swizzled_scale_inv_dtype =
72
- convert_ffi_datatype_to_te_dtype (swizzled_scale_inv->element_type ());
73
- NVTE_CHECK (typeToSize (scale_inv_dtype) == 1 && typeToSize (swizzled_scale_inv_dtype) == 1 ,
74
- " Inverse scale factors need to have an 8-bit data type." );
75
-
76
- // Create tensor to hold swizzled scale factor
77
- TensorWrapper output (get_nvte_scaling_mode (scaling_mode));
78
- if (rowwise) {
79
- output.set_rowwise_data (buffer.untyped_data (), input_dtype, input_shape);
80
- output.set_rowwise_scale_inv (swizzled_scale_inv->untyped_data (), scale_dtype, scale_shape);
81
- } else {
82
- output.set_columnwise_data (buffer.untyped_data (), input_dtype, input_shape);
83
- output.set_columnwise_scale_inv (swizzled_scale_inv->untyped_data (), scale_dtype,
84
- scale_shape);
85
- }
86
-
87
- // Launch swizzle kernel
88
- nvte_swizzle_scaling_factors (input.data (), output.data (), stream);
89
-
90
- // Set swizzled scales into the input tensor
91
- if (rowwise) {
92
- input.set_rowwise_scale_inv (swizzled_scale_inv->untyped_data (), scale_dtype, scale_shape);
93
- } else {
94
- input.set_columnwise_scale_inv (swizzled_scale_inv->untyped_data (), scale_dtype,
95
- scale_shape);
96
- }
97
- }
98
64
}
99
65
100
66
return std::make_tuple (std::move (input), input_shape);
@@ -103,21 +69,19 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
103
69
Error_Type GemmFFI (cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
104
70
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
105
71
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
106
- Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
107
- JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
72
+ Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
108
73
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
109
74
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
110
- // Operands (this includes swizzling MXFP8 scaling factors)
111
75
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
112
76
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
113
77
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
114
78
(is_tensor_scaling (scaling_mode) && nvte_is_non_tn_fp8_gemm_supported ()));
115
79
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
116
80
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
117
- auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand (
118
- stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
119
- auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand (
120
- stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
81
+ auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand (stream, lhs, lhs_scale_inv, scaling_mode,
82
+ lhs_axis_boundary, make_lhs_rowwise);
83
+ auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand (stream, rhs, rhs_scale_inv, scaling_mode,
84
+ rhs_axis_boundary, make_rhs_rowwise);
121
85
122
86
// Output tensor
123
87
std::vector<size_t > out_shape = {(lhs_transposed) ? lhs_shape[1 ] : lhs_shape[0 ],
@@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
188
152
.Ret<Buffer_Type>() // output
189
153
.Ret<Buffer_Type>() // bias_grad
190
154
.Ret<Buffer_Type>() // pre_gelu_out
191
- .Ret<Buffer_Type>() // lhs_swizzled
192
- .Ret<Buffer_Type>() // rhs_swizzled
193
155
.Ret<Buffer_Type>() // workspace
194
156
.Attr<JAXX_Scaling_Mode>(" scaling_mode" )
195
157
.Attr<int64_t>(" lhs_axis_boundary" )
0 commit comments