diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py
index 39fbd265e7..846c248ca2 100644
--- a/tests/pytorch/distributed/run_numerics.py
+++ b/tests/pytorch/distributed/run_numerics.py
@@ -16,7 +16,7 @@
 import torch.distributed as dist

 from transformer_engine.common.recipe import (
-    BlockScaling,
+    MXFP8BlockScaling,
     DelayedScaling,
     Format,
     Recipe,
@@ -44,7 +44,7 @@ def quantization_recipe() -> Recipe:
             fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
         )
     if QUANTIZATION == "mxfp8":
-        return BlockScaling()
+        return MXFP8BlockScaling()
     return te.fp8.get_default_fp8_recipe()
diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py
index 11a7df5852..fe633f2b60 100644
--- a/tests/pytorch/distributed/test_fusible_ops.py
+++ b/tests/pytorch/distributed/test_fusible_ops.py
@@ -136,7 +136,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     if name == "mxfp8":
-        return transformer_engine.common.recipe.BlockScaling(
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     raise ValueError(f"Unsupported quantization scheme ({name})")
diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 920e5fce99..dcdfa771c8 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -53,7 +53,7 @@ class ModelConfig:

 fp8_recipes = [
     recipe.DelayedScaling(),
-    recipe.BlockScaling(),
+    recipe.MXFP8BlockScaling(),
 ]

 # Supported data types
@@ -315,7 +315,7 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8_recipe.block() and not mxfp8_available:
+    if fp8_recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     # Run model with different CUDA graph settings.
diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index b2bd623ad8..570d679af8 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -148,7 +148,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     if name == "mxfp8":
-        return transformer_engine.common.recipe.BlockScaling(
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     raise ValueError(f"Unsupported quantization scheme ({name})")
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index b94094111e..451c9bee3c 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -98,7 +98,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq
 mask_types = ["causal", "no_mask"]

 fp8_recipes = [
-    recipe.BlockScaling(),
+    recipe.MXFP8BlockScaling(),
     recipe.DelayedScaling(),
 ]
@@ -556,7 +556,7 @@ def _test_e2e_selective_recompute(
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     config = model_configs[model]
@@ -668,7 +668,7 @@ def test_gpt_full_activation_recompute(
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     config = model_configs[model]
@@ -1418,7 +1418,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
     if fp8:
         if recipe.delayed():
             split_size = 16
-        if recipe.block():
+        if recipe.mxfp8():
             split_size = 128
         m = config.seq_len // split_size
         dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
@@ -1463,9 +1463,9 @@ def test_grouped_linear_accuracy(
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.block():  # TODO(ksivamani): debug mismatches
+    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
         pytest.skip("MXFP8 unsupported for grouped linear.")

     config = model_configs[model]
@@ -1648,9 +1648,9 @@ def test_padding_grouped_linear_accuracy(
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.block():  # TODO(ksivamani): debug mismatches
+    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
         pytest.skip("MXFP8 unsupported for grouped linear.")

     config = model_configs[model]
@@ -1860,7 +1860,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     config = model_configs[model]
diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index efd14d5607..f68edf155c 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -44,9 +44,9 @@ class Recipe:
     Base recipe class.
""" - def block(self): - """Whether the given recipe is block scaling.""" - return isinstance(self, BlockScaling) + def mxfp8(self): + """Whether the given recipe is MXFP8 block scaling.""" + return isinstance(self, MXFP8BlockScaling) def delayed(self): """Whether the given recipe is delayed scaling.""" @@ -162,7 +162,7 @@ def __repr__(self) -> str: @dataclass() -class BlockScaling(Recipe): +class MXFP8BlockScaling(Recipe): """ Use the current scaling factor strategy. diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index a83696ddd1..254bcf12e1 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -13,7 +13,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, BlockScaling +from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling from .constants import dist_group_type from .utils import get_device_compute_capability @@ -46,7 +46,7 @@ def check_mxfp8_support() -> Tuple[bool, str]: def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if get_device_compute_capability() >= (10, 0): # blackwell and above - return BlockScaling() + return MXFP8BlockScaling() return DelayedScaling() @@ -211,7 +211,7 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ - if fp8_meta["recipe"].block(): + if fp8_meta["recipe"].mxfp8(): return # Every module must call this function exactly once since @@ -414,7 +414,7 @@ def fp8_autocast_enter( if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() assert fp8_available, reason_for_no_fp8 - if isinstance(fp8_recipe, BlockScaling): + if isinstance(fp8_recipe, MXFP8BlockScaling): mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() assert mxfp8_available, reason_for_no_mxfp8 @@ -434,7 +434,7 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - to ensure both forward steps are numerically same. """ - if fp8_meta["recipe"].block(): + if fp8_meta["recipe"].mxfp8(): return buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -460,7 +460,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non 1 forward for indentical numerical outputs. """ - if fp8_meta["recipe"].block(): + if fp8_meta["recipe"].mxfp8(): return # Store updated amaxes and scales from phase 1 post forward. @@ -479,7 +479,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - if fp8_meta["recipe"].block(): + if fp8_meta["recipe"].mxfp8(): return fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) @@ -741,8 +741,8 @@ def create( cls = None if recipe.delayed(): cls = DelayedScalingRecipeState - elif recipe.block(): - cls = BlockScalingRecipeState + elif recipe.mxfp8(): + cls = MXFP8BlockScalingRecipeState else: raise ValueError("{recipe.__class__.__name__} is not supported") return cls( @@ -813,20 +813,20 @@ def make_quantizers(self) -> list: ] -class BlockScalingRecipeState(RecipeState): +class MXFP8BlockScalingRecipeState(RecipeState): """Configuration for MXFP8 quantization. MXFP8 quantization does not require state. 
""" - recipe: BlockScaling + recipe: MXFP8BlockScaling mode: str dtype: tex.DType def __init__( self, - recipe: BlockScaling, + recipe: MXFP8BlockScaling, *, mode: str, num_quantizers: int = 1, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 19951bb2af..b7ee87afb6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -22,7 +22,7 @@ from ._common import _ParameterInitMeta from ..fp8 import ( - BlockScalingRecipeState, + MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, FP8GlobalStateManager, RecipeState, @@ -540,7 +540,7 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd) return - if recipe.block() and isinstance(recipe_state, BlockScalingRecipeState): + if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1321a9f357..2f9de58984 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -86,7 +86,7 @@ def forward( device = inp.device # TODO Support MXFP8 # pylint: disable=fixme - if fp8 and FP8GlobalStateManager.get_fp8_recipe().block(): + if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): raise NotImplementedError("GroupedLinear does not yet support MXFP8") # Make sure input dimensions are compatible diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index f3fb2c0a20..8346d31a40 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -15,7 +15,7 @@ from transformer_engine.common.recipe import Recipe from ..fp8 import ( - BlockScalingRecipeState, + MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, FP8GlobalStateManager, RecipeState, @@ -260,7 +260,7 @@ def _update_quantization_recipe_state( recipe_state = self._fp8_metas[mode][fp8_meta_key] need_to_reset_recipe_state = ( recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) - ) or (recipe.block() and not isinstance(recipe_state, BlockScalingRecipeState)) + ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) if need_to_reset_recipe_state: self._reset_quantization_recipe_state(recipe=recipe) return @@ -283,7 +283,7 @@ def _update_quantization_recipe_state( recipe_state = fp8_meta[fp8_meta_key] # Reallocate amax history if needed - if recipe.block(): + if recipe.mxfp8(): continue current_length = recipe_state.amax_history.size(0)