
Commit c1b8230

Merge branch 'main' into xiny/quantized_input
2 parents: fcd52fc + de81b7d

File tree

18 files changed: +649 -245 lines changed


qa/L1_jax_distributed_unittest/test.sh

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ set -xe
 mkdir -p "$XML_LOG_DIR"
 
 NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
+SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh

tests/jax/multi_process_launch.sh

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+#!/bin/bash
+
+SCRIPT_NAME="${SCRIPT_NAME:-test.py}"
+
+
+XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+                --xla_gpu_enable_command_buffer=''"
+
+export XLA_FLAGS="${XLA_BASE_FLAGS}"
+
+NUM_RUNS=$(nvidia-smi --query-gpu=count --format=csv,noheader)
+for ((i=1; i<NUM_RUNS; i++))
+do
+    CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_PROC > /dev/null 2>&1 &
+done
+
+CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC
+
+wait
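For context, the launcher starts one process per visible GPU and passes each one the coordinator address, its process index, and the total process count ($NUM_PROC) as positional arguments. Below is a minimal sketch of how a launched script might consume those arguments, assuming it joins JAX's multi-process runtime via jax.distributed.initialize; it is illustrative only and not the actual test_multi_process_distributed_grouped_gemm.py.

import sys
import jax

# Positional arguments supplied by multi_process_launch.sh (illustrative parsing).
coordinator_address = sys.argv[1]   # e.g. "127.0.0.1:12345"
process_id = int(sys.argv[2])       # this process's index, one process per GPU
num_processes = int(sys.argv[3])    # total number of processes ($NUM_PROC)

# Join the multi-process JAX runtime before running any distributed test logic.
jax.distributed.initialize(
    coordinator_address=coordinator_address,
    num_processes=num_processes,
    process_id=process_id,
)
print(f"process {process_id}/{num_processes}: {jax.local_device_count()} local device(s)")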

tests/jax/test_helper.py

Lines changed: 21 additions & 16 deletions
@@ -14,10 +14,11 @@
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
 from transformer_engine.jax.quantize import (
-    QuantizeConfig,
+    get_quantize_config,
     is_fp8_available,
     ScalingMode,
     update_collections,
+    TensorSource,
 )
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
 
@@ -49,7 +50,7 @@ def test_update_collections(self):
 class TestFP8Functions(unittest.TestCase):
 
     def _check_default_state(self):
-        self.assertFalse(QuantizeConfig.is_fp8_enabled())
+        self.assertFalse(get_quantize_config().is_fp8_enabled())
 
     def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.margin == test.margin)
@@ -58,17 +59,23 @@ def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
 
     def _compare_current_scaling(self, test):
-        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
-        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
+        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
+        for tensor_source in TensorSource:
+            self.assertEqual(
+                get_quantize_config().get_scaling_mode(tensor_source),
+                ScalingMode.CURRENT_TENSOR_SCALING,
+            )
 
     def _compare_mxfp8_scaling(self, test):
-        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
-        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
-        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+        self.assertEqual(get_quantize_config().MARGIN, test.margin)
+        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
+        for tensor_source in TensorSource:
+            self.assertEqual(
+                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
+            )
 
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_delayed_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
 
         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
@@ -78,21 +85,20 @@ def test_fp8_autocast_delayed_scaling(self):
 
         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)
 
         self._check_default_state()
 
         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)
 
         self._check_default_state()
 
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_current_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
 
         with fp8_autocast(
@@ -104,21 +110,20 @@ def test_fp8_autocast_current_scaling(self):
 
         cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_current_scaling(cs)
 
         self._check_default_state()
 
         cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_current_scaling(cs)
 
         self._check_default_state()
 
     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
     def test_fp8_autocast_mxfp8_block_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
 
         with fp8_autocast(
@@ -130,14 +135,14 @@ def test_fp8_autocast_mxfp8_block_scaling(self):
 
         bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)
 
         self._check_default_state()
 
         bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)
 
         self._check_default_state()
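The diff above replaces the old QuantizeConfig class attributes (FP8_FORMAT, SCALING_MODE, ...) with the get_quantize_config() accessor, and scaling modes are now queried per tensor source. A minimal usage sketch, assuming DelayedScaling is importable from transformer_engine.common.recipe alongside Format, as in the test file:

from transformer_engine.common.recipe import DelayedScaling, Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import get_quantize_config, TensorSource
from transformer_engine.jax.sharding import MeshResource

ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
    cfg = get_quantize_config()                   # read the active config instead of QuantizeConfig class attributes
    assert cfg.is_fp8_enabled()
    fmt = cfg.FP8_FORMAT                          # FP8 format of the recipe in effect
    mode = cfg.get_scaling_mode(TensorSource.X)   # scaling mode is now queried per tensor source

assert not get_quantize_config().is_fp8_enabled()  # default state restored outside the context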

tests/jax/test_layer.py

Lines changed: 21 additions & 24 deletions
@@ -23,12 +23,14 @@
 from transformer_engine.common import recipe
 from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
 from transformer_engine.jax.quantize import (
-    QuantizeConfig,
+    get_quantize_config,
     ScalingMode,
     is_fp8_available,
     update_collections,
+    TensorSource,
+    fp8_autocast,
 )
-from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+from transformer_engine.jax.sharding import MeshResource
 
 
 @pytest.fixture(autouse=True, scope="function")
@@ -356,7 +358,7 @@ def test_backward(
 
         ref_params, test_params = self._sync_params(ref_params, test_params)
 
-        if QuantizeConfig.is_fp8_enabled():
+        if get_quantize_config().is_fp8_enabled():
             for _ in range(4):
                 _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                     inputs,
@@ -365,12 +367,15 @@ def test_backward(
                     test_others,
                     test_layer,
                 )
-                if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
+                if (
+                    get_quantize_config().get_scaling_mode(TensorSource.X)
+                    == ScalingMode.DELAYED_TENSOR_SCALING
+                ):
                     _, updated_quantize_meta = flax.core.pop(
-                        updated_state[0], QuantizeConfig.COLLECTION_NAME
+                        updated_state[0], get_quantize_config().COLLECTION_NAME
                     )
                     test_others = update_collections(
-                        {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
+                        {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
                     )
                     del updated_quantize_meta
                     del updated_state
@@ -500,41 +505,33 @@ class BaseTester:
 
     def test_forward(self, data_shape, dtype, attrs):
         """Test normal datatype forward"""
-        QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Ensure FP8 disabled.
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
             self.runner(attrs).test_forward(data_shape, dtype)
 
     def test_backward(self, data_shape, dtype, attrs):
         """Test normal datatype backward"""
-        QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Ensure FP8 disabled.
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
            self.runner(attrs).test_backward(data_shape, dtype)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
-        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        QuantizeConfig.finalize()
 
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
-        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        QuantizeConfig.finalize()
 
 
 class TestEncoderLayer(BaseTester):
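In BaseTester, the explicit QuantizeConfig.initialize()/finalize() calls and the global_shard_guard(MeshResource()) wrapper are folded into a single fp8_autocast context manager, which also restores the default (FP8-disabled) state on exit. A minimal sketch of the new pattern, using a hypothetical run_with_fp8 helper purely for illustration:

from transformer_engine.jax.quantize import fp8_autocast, get_quantize_config
from transformer_engine.jax.sharding import MeshResource


def run_with_fp8(test_body, fp8_recipe):
    """Hypothetical helper mirroring the new BaseTester pattern."""
    # Empty MeshResource: single-device run, as in the tests above.
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
        test_body()
    # No explicit QuantizeConfig.finalize() needed: the context manager
    # restores the default state, so FP8 is disabled again here.
    assert not get_quantize_config().is_fp8_enabled()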
