Commit 5f2b831

[JAX] Scale swizzling via JAX transpose op (#2163)
* add swizzle in jax
* added outer_impl
* clean up FFI

Signed-off-by: Phuong Nguyen <[email protected]>
1 parent a26a7f1 commit 5f2b831

3 files changed: +81 -77 lines

transformer_engine/jax/cpp_extensions/base.py

Lines changed: 8 additions & 1 deletion
@@ -134,6 +134,13 @@ def impl():
         """
         return NotImplemented
 
+    @classmethod
+    def outer_impl(cls, *args, **kwargs):
+        """
+        to describe implementation for outer primitive
+        """
+        return cls.impl(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def batcher():
@@ -196,7 +203,7 @@ def name_of_wrapper_p():
         outer_p = core.Primitive(name_of_wrapper_p())
         dispatch.prim_requires_devices_during_lowering.add(outer_p)
         outer_p.multiple_results = cls.multiple_results
-        outer_p.def_impl(cls.impl)
+        outer_p.def_impl(cls.outer_impl)
         outer_p.def_abstract_eval(cls.outer_abstract)
         batching.primitive_batchers[outer_p] = cls.batcher
         outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
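The change above points the outer primitive's def_impl at a new outer_impl hook instead of impl, giving subclasses a way to customize the outer primitive's implementation separately from impl; the default simply forwards to impl. A minimal sketch of the pattern, using a toy SketchPrimitive rather than TE's BasePrimitive (names and shapes are illustrative only):

import jax.numpy as jnp
from jax import core


class SketchPrimitive:
    """Hypothetical stand-in for BasePrimitive (illustration only)."""

    @classmethod
    def impl(cls, x):
        # Inner implementation: what the lowered/partitioned path runs.
        return x * 2.0

    @classmethod
    def outer_impl(cls, x):
        # Outer-primitive implementation: defaults to impl, but a subclass
        # can add pre-processing (e.g. a scale swizzle) before delegating.
        return cls.impl(x)


outer_p = core.Primitive("sketch_outer_p")
outer_p.def_impl(SketchPrimitive.outer_impl)  # mirrors outer_p.def_impl(cls.outer_impl)
outer_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

print(outer_p.bind(jnp.ones((2, 2))))  # eager bind dispatches through outer_impl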

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 66 additions & 31 deletions
@@ -152,6 +152,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
     return lhs_q, rhs_q
 
 
+@partial(jax.jit, static_argnums=(1, 2))
+def swizzled_scale(scale_inv, flatten_axis, is_colwise):
+    "Swizzle scale_inv via JAX transpose ops"
+    original_shape = scale_inv.shape
+    shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:]))
+    if is_colwise:
+        scale_inv = jnp.transpose(scale_inv.reshape(shape_2d))
+        cols, rows = shape_2d
+    else:
+        rows, cols = shape_2d
+    reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4)
+    swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4))
+    return swizzled.reshape(original_shape)
+
+
 class GemmPrimitive(BasePrimitive):
     """
     Primitive for cuBLAS GEMM
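For reference, here is a standalone restatement of the rowwise branch of swizzled_scale (without the jit wrapper, so it runs in isolation). The reshape requires the flattened 2D scale matrix to have rows divisible by 128 and columns divisible by 4; each 128x4 tile is rearranged into 32 interleaved 4x4 chunks, the swizzled layout consumed by block-scaled cuBLAS GEMM. The helper name and the (256, 16) input shape are illustrative, not values from this commit:

import math

import jax.numpy as jnp


def swizzle_rowwise(scale_inv, flatten_axis=1):
    # Same transform as swizzled_scale(..., is_colwise=False).
    original_shape = scale_inv.shape
    rows = math.prod(original_shape[:flatten_axis])
    cols = math.prod(original_shape[flatten_axis:])
    tiles = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4)
    return jnp.transpose(tiles, (0, 3, 2, 1, 4)).reshape(original_shape)


scales = jnp.arange(256 * 16).reshape(256, 16).astype(jnp.uint8)
out = swizzle_rowwise(scales)
assert out.shape == scales.shape  # the element layout changes, the logical shape does not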
@@ -286,28 +301,18 @@ def _dims_are_consecutive(dims):
         )
         pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
 
-        # Need extra workspace for swizzled scale factors
-        lhs_swizzle_size = 0
-        rhs_swizzle_size = 0
-        swizzle_dtype = jnp.uint8
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            lhs_swizzle_size = lhs_scale_inv.size
-            rhs_swizzle_size = rhs_scale_inv.size
-        lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
-        rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)
-
         # Declare cuBLAS workspace
         # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
         # necessarily 256 bytes aligned, we add some padding to ensure alignment.
         workspace_size = get_cublas_workspace_size_bytes() + 256
         workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
 
-        return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace
+        return output, bias_grad, pre_gelu_out, workspace
 
     @staticmethod
     def outer_abstract(*args, **kwargs):
         outputs = GemmPrimitive.abstract(*args, **kwargs)
-        return outputs[:-3]  # discard workspace arrays
+        return outputs[:-1]  # discard workspace array
 
     @staticmethod
     def lowering(
@@ -374,24 +379,22 @@ def impl(
         grad,
         use_split_accumulator,
     ):
-        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
-        lhs_transposed, rhs_transposed = _get_gemm_layout(
-            (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
-        )
-        lhs_scale_inv = apply_padding_to_scale_inv(
-            lhs_scale_inv,
-            scaling_mode,
-            lhs.shape,
-            is_colwise=lhs_transposed,
-            flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
-        )
-        rhs_scale_inv = apply_padding_to_scale_inv(
-            rhs_scale_inv,
-            scaling_mode,
-            rhs.shape,
-            is_colwise=not rhs_transposed,
-            flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
-        )
+        if scaling_mode.is_1d_block_scaling():
+            lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
+            lhs_transposed, rhs_transposed = _get_gemm_layout(
+                (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
+            )
+            lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims)
+            rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1
+
+            lhs_scale_inv = apply_padding_to_scale_inv(
+                lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis
+            )
+            rhs_scale_inv = apply_padding_to_scale_inv(
+                rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
+            )
+            lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
+            rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
 
         outputs = GemmPrimitive.inner_primitive.bind(
             lhs,
@@ -408,7 +411,39 @@ def impl(
             grad=grad,
             use_split_accumulator=use_split_accumulator,
         )
-        return outputs[:-3]  # discard workspace arrays
+        return outputs[:-1]  # discard workspace array
+
+    @staticmethod
+    def outer_impl(
+        lhs,
+        lhs_scale_inv,
+        rhs,
+        rhs_scale_inv,
+        bias,
+        gelu_input,
+        out_dtype,
+        contracting_dims,
+        scaling_mode,
+        fuse_bias,
+        fuse_gelu,
+        grad,
+        use_split_accumulator,
+    ):
+        return GemmPrimitive.impl(
+            lhs,
+            lhs_scale_inv,
+            rhs,
+            rhs_scale_inv,
+            bias,
+            gelu_input,
+            out_dtype,
+            contracting_dims,
+            scaling_mode,
+            fuse_bias,
+            fuse_gelu,
+            grad,
+            use_split_accumulator,
+        )
 
     @staticmethod
     def batcher(
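Note the ordering in impl above: apply_padding_to_scale_inv runs before swizzled_scale, because the swizzle reshape only works when the flattened scale matrix has rows divisible by 128 and columns divisible by 4. A minimal sketch of that precondition, with tile sizes read off the reshape in swizzled_scale (this is a hypothetical helper, not TE's apply_padding_to_scale_inv):

import jax.numpy as jnp


def pad_scale_2d(scale_2d, row_tile=128, col_tile=4):
    # Hypothetical padding helper: round each dimension up to the tile size
    # the swizzle reshape requires.
    rows, cols = scale_2d.shape
    return jnp.pad(scale_2d, ((0, -rows % row_tile), (0, -cols % col_tile)))


padded = pad_scale_2d(jnp.ones((130, 6), jnp.uint8))
print(padded.shape)  # (256, 8): now divisible by (128, 4), so the swizzle reshape is valid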

transformer_engine/jax/csrc/extensions/gemm.cpp

Lines changed: 7 additions & 45 deletions
@@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
 }
 
 std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
-    cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
-    JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
+    cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode,
+    size_t axis_boundary, bool rowwise) {
   // Set tensor data with collapsed 2D shape
   auto buffer_dims = buffer.dimensions();
   std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
@@ -61,40 +61,6 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
     } else {
       input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
     }
-
-    // Swizzle scaling factors for MXFP8
-    if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
-      // Get the swizzle buffer
-      NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
-                 "Missing swizzled inverse scale buffer in the JAX primitive.");
-      auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
-      auto swizzled_scale_inv_dtype =
-          convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
-      NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
-                 "Inverse scale factors need to have an 8-bit data type.");
-
-      // Create tensor to hold swizzled scale factor
-      TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
-      if (rowwise) {
-        output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
-        output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
-      } else {
-        output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
-        output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
-                                        scale_shape);
-      }
-
-      // Launch swizzle kernel
-      nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
-
-      // Set swizzled scales into the input tensor
-      if (rowwise) {
-        input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
-      } else {
-        input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
-                                       scale_shape);
-      }
-    }
   }
 
   return std::make_tuple(std::move(input), input_shape);
@@ -103,21 +69,19 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
 Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
                    Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
                    Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
-                   Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
-                   JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
+                   Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
                    int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
                    bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
-  // Operands (this includes swizzling MXFP8 scaling factors)
   // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
   // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
   bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
                          (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
  bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
  bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
-  auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
-      stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
-  auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
-      stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
+  auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode,
+                                                           lhs_axis_boundary, make_lhs_rowwise);
+  auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
+                                                           rhs_axis_boundary, make_rhs_rowwise);
 
   // Output tensor
   std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
@@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
                                   .Ret<Buffer_Type>()      // output
                                   .Ret<Buffer_Type>()      // bias_grad
                                   .Ret<Buffer_Type>()      // pre_gelu_out
-                                  .Ret<Buffer_Type>()      // lhs_swizzled
-                                  .Ret<Buffer_Type>()      // rhs_swizzled
                                   .Ret<Buffer_Type>()      // workspace
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                   .Attr<int64_t>("lhs_axis_boundary")
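With the swizzle now expressed as reshape/transpose ops on the JAX side, the FFI no longer needs the lhs_swizzle/rhs_swizzle result buffers or the nvte_swizzle_scaling_factors launch removed above; the inner primitive receives already-swizzled scale factors. One way to see that the swizzle lives in the traced program (a sketch with illustrative shapes, restating the rowwise transform from swizzled_scale):

import jax
import jax.numpy as jnp


def swizzle_rowwise_2d(s):
    # Rowwise swizzle from gemm.py, restated on a plain 2D array.
    rows, cols = s.shape
    tiles = s.reshape(rows // 128, 4, 32, cols // 4, 4)
    return jnp.transpose(tiles, (0, 3, 2, 1, 4)).reshape(rows, cols)


# The printed jaxpr contains only reshape and transpose primitives, which XLA
# can fuse with neighbouring ops instead of requiring a dedicated swizzle
# kernel launch and extra output buffers.
print(jax.make_jaxpr(swizzle_rowwise_2d)(jnp.zeros((256, 16), jnp.uint8)))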
