1414from transformer_engine .common .recipe import Format as FP8Format
1515from transformer_engine .jax import fp8_autocast , get_delayed_scaling
1616from transformer_engine .jax .quantize import (
17- QuantizeConfig ,
17+ get_quantize_config ,
1818 is_fp8_available ,
1919 ScalingMode ,
2020 update_collections ,
21+ TensorSource ,
2122)
2223from transformer_engine .jax .sharding import MeshResource , global_mesh_resource
2324
@@ -49,7 +50,7 @@ def test_update_collections(self):
4950class TestFP8Functions (unittest .TestCase ):
5051
def _check_default_state(self):
    """Assert that the config returned by ``get_quantize_config()`` reports FP8 as disabled."""
    fp8_enabled = get_quantize_config().is_fp8_enabled()
    self.assertFalse(fp8_enabled)
5354
5455 def _compare_delay_scaling (self , ref , test ):
5556 self .assertTrue (ref .margin == test .margin )
@@ -58,17 +59,23 @@ def _compare_delay_scaling(self, ref, test):
5859 self .assertTrue (ref .amax_compute_algo == test .amax_compute_algo )
5960
def _compare_current_scaling(self, test):
    """Check the config from ``get_quantize_config()`` against a current-scaling recipe.

    Args:
        test: Recipe object exposing ``fp8_format``; its value must equal the
            config's ``FP8_FORMAT``.
    """
    config = get_quantize_config()
    self.assertEqual(config.FP8_FORMAT, test.fp8_format)
    # One sub-test per tensor source so a failure names the offending source.
    for tensor_source in TensorSource:
        with self.subTest(tensor_source=tensor_source):
            self.assertEqual(
                config.get_scaling_mode(tensor_source),
                ScalingMode.CURRENT_TENSOR_SCALING,
            )
6368
def _compare_mxfp8_scaling(self, test):
    """Check the config from ``get_quantize_config()`` against an MXFP8 block-scaling recipe.

    Args:
        test: Recipe object exposing ``margin`` and ``fp8_format``; both must
            equal the config's ``MARGIN`` and ``FP8_FORMAT``.
    """
    config = get_quantize_config()
    self.assertEqual(config.MARGIN, test.margin)
    self.assertEqual(config.FP8_FORMAT, test.fp8_format)
    # One sub-test per tensor source so a failure names the offending source.
    for tensor_source in TensorSource:
        with self.subTest(tensor_source=tensor_source):
            self.assertEqual(
                config.get_scaling_mode(tensor_source),
                ScalingMode.MXFP8_1D_SCALING,
            )
6876
6977 @unittest .skipIf (not is_fp8_supported , reason = reason )
7078 def test_fp8_autocast_delayed_scaling (self ):
71- QuantizeConfig .finalize () # Ensure the testing not affect by previous tests.
7279 self ._check_default_state ()
7380
7481 with fp8_autocast (enabled = False , fp8_recipe = DelayedScaling (), mesh_resource = MeshResource ()):
@@ -78,21 +85,20 @@ def test_fp8_autocast_delayed_scaling(self):
7885
7986 ds = DelayedScaling (margin = 5.0 , fp8_format = FP8Format .E4M3 , amax_history_len = 1 )
8087 with fp8_autocast (enabled = True , fp8_recipe = ds , mesh_resource = MeshResource ()):
81- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
88+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
8289 self ._compare_delay_scaling (get_delayed_scaling (), ds )
8390
8491 self ._check_default_state ()
8592
8693 ds = DelayedScaling (margin = 3.0 , fp8_format = FP8Format .HYBRID , amax_history_len = 1 )
8794 with fp8_autocast (enabled = True , fp8_recipe = ds , mesh_resource = MeshResource ()):
88- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
95+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
8996 self ._compare_delay_scaling (get_delayed_scaling (), ds )
9097
9198 self ._check_default_state ()
9299
93100 @unittest .skipIf (not is_fp8_supported , reason = reason )
94101 def test_fp8_autocast_current_scaling (self ):
95- QuantizeConfig .finalize () # Ensure the testing not affect by previous tests.
96102 self ._check_default_state ()
97103
98104 with fp8_autocast (
@@ -104,21 +110,20 @@ def test_fp8_autocast_current_scaling(self):
104110
105111 cs = Float8CurrentScaling (fp8_format = FP8Format .E4M3 )
106112 with fp8_autocast (enabled = True , fp8_recipe = cs , mesh_resource = MeshResource ()):
107- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
113+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
108114 self ._compare_current_scaling (cs )
109115
110116 self ._check_default_state ()
111117
112118 cs = Float8CurrentScaling (fp8_format = FP8Format .HYBRID )
113119 with fp8_autocast (enabled = True , fp8_recipe = cs , mesh_resource = MeshResource ()):
114- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
120+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
115121 self ._compare_current_scaling (cs )
116122
117123 self ._check_default_state ()
118124
119125 @unittest .skipIf (not is_mxfp8_supported , reason = mxfp8_reason )
120126 def test_fp8_autocast_mxfp8_block_scaling (self ):
121- QuantizeConfig .finalize () # Ensure the testing not affect by previous tests.
122127 self ._check_default_state ()
123128
124129 with fp8_autocast (
@@ -130,14 +135,14 @@ def test_fp8_autocast_mxfp8_block_scaling(self):
130135
131136 bs = MXFP8BlockScaling (margin = 5.0 , fp8_format = FP8Format .E4M3 )
132137 with fp8_autocast (enabled = True , fp8_recipe = bs , mesh_resource = MeshResource ()):
133- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
138+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
134139 self ._compare_mxfp8_scaling (bs )
135140
136141 self ._check_default_state ()
137142
138143 bs = MXFP8BlockScaling (margin = 3.0 , fp8_format = FP8Format .HYBRID )
139144 with fp8_autocast (enabled = True , fp8_recipe = bs , mesh_resource = MeshResource ()):
140- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
145+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
141146 self ._compare_mxfp8_scaling (bs )
142147
143148 self ._check_default_state ()
0 commit comments