Skip to content

Commit e2f2a0b

Browse files
[JAX] Make SR rng state always 2D (num_devices, 4) to fix partitioning issue (#2294)
* Make SR rng state always 2D (num_devices, 4) Signed-off-by: Jeremy Berchtold <[email protected]> * fix pure-jax impl Signed-off-by: Jeremy Berchtold <[email protected]> * fix test shape Signed-off-by: Jeremy Berchtold <[email protected]> --------- Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent eb34783 commit e2f2a0b

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

tests/jax/test_custom_call_compute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ def _sample_sr_qdq(
876876
for i in range(num_samples):
877877
iter_key = jax.random.fold_in(key, i)
878878
sr_rng_state = jax.random.randint(
879-
iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
879+
iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
880880
)
881881
quantizer = QuantizerFactory.create(
882882
q_dtype=q_dtype,

transformer_engine/jax/quantize/helper.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,10 +631,8 @@ def _make_stochastic_rounding_rng_state(
631631
)
632632
sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)
633633

634-
# Generate 4 random uint32 values from the JAX PRNG key
635-
shape = (4,)
636-
if get_num_devices_in_mesh() > 1:
637-
shape = (get_num_devices_in_mesh(), 4)
634+
# Generate 4 random uint32 values per device from the JAX PRNG key
635+
shape = (get_num_devices_in_mesh(), 4)
638636
sr_jax_rng_state = jax.random.randint(
639637
sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
640638
).view(jnp.uint32)

transformer_engine/jax/quantize/quantizer.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
TensorSource,
3535
)
3636
from .device_utils import is_fp8_gemm_with_all_layouts_supported
37+
from ..sharding import get_num_devices_in_mesh
3738

3839
__all__ = [
3940
"QuantizeLayout",
@@ -633,24 +634,27 @@ def _apply_stochastic_rounding(self, x):
633634
assert (
634635
self.stochastic_rounding_rng_state is not None
635636
), "Stochastic rounding RNG state is not initialized"
636-
assert self.stochastic_rounding_rng_state.shape == (
637-
4,
638-
), "Stochastic rounding RNG state must be of shape (4,)"
637+
expected_sr_rng_state_shape = (get_num_devices_in_mesh(), 4)
638+
assert self.stochastic_rounding_rng_state.shape == expected_sr_rng_state_shape, (
639+
"Stochastic rounding RNG state must be of shape (num_devices_in_mesh, 4). Expected"
640+
f" {expected_sr_rng_state_shape}, but got {self.stochastic_rounding_rng_state.shape}"
641+
)
639642
assert (
640643
self.stochastic_rounding_rng_state.dtype == jnp.uint32
641644
), "Stochastic rounding RNG state must be of dtype uint32"
642645

643646
# Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for initial state and fold in the other 2 uint32s
644647
key_bits = jnp.array(
645648
[
646-
self.stochastic_rounding_rng_state[0],
647-
self.stochastic_rounding_rng_state[1],
649+
# only take the first device's RNG state as the pure-JAX stochastic rounding impl only uses a single-device
650+
self.stochastic_rounding_rng_state[0][0],
651+
self.stochastic_rounding_rng_state[0][1],
648652
],
649653
dtype=jnp.uint32,
650654
)
651655
key = jax.random.wrap_key_data(key_bits)
652-
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2])
653-
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3])
656+
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][2])
657+
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][3])
654658

655659
abs_x = jnp.abs(x)
656660
sign_x = jnp.sign(x)

0 commit comments

Comments
 (0)