Rename block scaling recipe #1442

Open · wants to merge 1 commit into base: release_v2.0
4 changes: 2 additions & 2 deletions tests/pytorch/distributed/run_numerics.py
@@ -16,7 +16,7 @@
 import torch.distributed as dist

 from transformer_engine.common.recipe import (
-    BlockScaling,
+    MXFP8BlockScaling,
     DelayedScaling,
     Format,
     Recipe,
@@ -44,7 +44,7 @@ def quantization_recipe() -> Recipe:
             fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
         )
     if QUANTIZATION == "mxfp8":
-        return BlockScaling()
+        return MXFP8BlockScaling()
     return te.fp8.get_default_fp8_recipe()


2 changes: 1 addition & 1 deletion tests/pytorch/distributed/test_fusible_ops.py
@@ -136,7 +136,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     if name == "mxfp8":
-        return transformer_engine.common.recipe.BlockScaling(
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     raise ValueError(f"Unsupported quantization scheme ({name})")
4 changes: 2 additions & 2 deletions tests/pytorch/test_cuda_graphs.py
@@ -53,7 +53,7 @@ class ModelConfig:

 fp8_recipes = [
     recipe.DelayedScaling(),
-    recipe.BlockScaling(),
+    recipe.MXFP8BlockScaling(),
 ]

 # Supported data types
@@ -315,7 +315,7 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
-    if fp8_recipe.block() and not mxfp8_available:
+    if fp8_recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     # Run model with different CUDA graph settings.
2 changes: 1 addition & 1 deletion tests/pytorch/test_fusible_ops.py
@@ -148,7 +148,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     if name == "mxfp8":
-        return transformer_engine.common.recipe.BlockScaling(
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
         )
     raise ValueError(f"Unsupported quantization scheme ({name})")
18 changes: 9 additions & 9 deletions tests/pytorch/test_numerics.py
@@ -98,7 +98,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq
 mask_types = ["causal", "no_mask"]

 fp8_recipes = [
-    recipe.BlockScaling(),
+    recipe.MXFP8BlockScaling(),
     recipe.DelayedScaling(),
 ]

@@ -556,7 +556,7 @@ def _test_e2e_selective_recompute(
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     config = model_configs[model]
@@ -668,7 +668,7 @@ def test_gpt_full_activation_recompute(
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     config = model_configs[model]
@@ -1418,7 +1418,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
     if fp8:
         if recipe.delayed():
             split_size = 16
-        if recipe.block():
+        if recipe.mxfp8():
             split_size = 128
         m = config.seq_len // split_size
         dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
@@ -1463,9 +1463,9 @@ def test_grouped_linear_accuracy(
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.block():  # TODO(ksivamani): debug mismatches
+    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
         pytest.skip("MXFP8 unsupported for grouped linear.")

     config = model_configs[model]
@@ -1648,9 +1648,9 @@ def test_padding_grouped_linear_accuracy(
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.block():  # TODO(ksivamani): debug mismatches
+    if fp8 and recipe.mxfp8():  # TODO(ksivamani): debug mismatches
         pytest.skip("MXFP8 unsupported for grouped linear.")

     config = model_configs[model]
@@ -1860,7 +1860,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
-    if recipe.block() and not mxfp8_available:
+    if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

     config = model_configs[model]
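For reviewers, a minimal sketch of the skip pattern that the renamed predicate gives these tests. The helper name `maybe_skip_mxfp8` is hypothetical and not part of this PR; the import mirrors the `check_mxfp8_support()` definition referenced in the `transformer_engine/pytorch/fp8.py` hunks below.

```python
# Hypothetical helper, illustrating how the renamed predicate gates MXFP8 tests.
import pytest

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.fp8 import check_mxfp8_support

mxfp8_available, reason_for_no_mxfp8 = check_mxfp8_support()


def maybe_skip_mxfp8(recipe: Recipe) -> None:
    """Skip the current test when it requests MXFP8 on unsupported hardware."""
    # Before this PR the check read `recipe.block()`.
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
```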
8 changes: 4 additions & 4 deletions transformer_engine/common/recipe/__init__.py
@@ -44,9 +44,9 @@ class Recipe:
     Base recipe class.
     """

-    def block(self):
-        """Whether the given recipe is block scaling."""
-        return isinstance(self, BlockScaling)
+    def mxfp8(self):
+        """Whether the given recipe is MXFP8 block scaling."""
+        return isinstance(self, MXFP8BlockScaling)

     def delayed(self):
         """Whether the given recipe is delayed scaling."""
@@ -162,7 +162,7 @@ def __repr__(self) -> str:


 @dataclass()
-class BlockScaling(Recipe):
+class MXFP8BlockScaling(Recipe):
     """
     Use the current scaling factor strategy.

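To make the user-facing effect of the rename concrete, here is a minimal sketch (not part of the diff) of how a recipe is constructed and queried after this change; the constructor argument mirrors the test hunks above.

```python
# Illustrative only; mirrors the renamed public API in transformer_engine.common.recipe.
from transformer_engine.common.recipe import DelayedScaling, Format, MXFP8BlockScaling

# Formerly recipe.BlockScaling(); the constructor arguments are unchanged by the rename.
mxfp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)

# The Recipe predicate is renamed from .block() to .mxfp8().
assert mxfp8_recipe.mxfp8()
assert not mxfp8_recipe.delayed()
assert DelayedScaling().delayed()
```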
24 changes: 12 additions & 12 deletions transformer_engine/pytorch/fp8.py
@@ -13,7 +13,7 @@

 import torch
 import transformer_engine_torch as tex
-from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, BlockScaling
+from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling

 from .constants import dist_group_type
 from .utils import get_device_compute_capability
@@ -46,7 +46,7 @@ def check_mxfp8_support() -> Tuple[bool, str]:
 def get_default_fp8_recipe() -> Recipe:
     """FP8 recipe with default args."""
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
-        return BlockScaling()
+        return MXFP8BlockScaling()
     return DelayedScaling()


@@ -211,7 +211,7 @@ def add_fp8_tensors_to_global_buffer(
         wrapper. For non CG case, it's called from within the module.
         """

-        if fp8_meta["recipe"].block():
+        if fp8_meta["recipe"].mxfp8():
             return

         # Every module must call this function exactly once since
@@ -414,7 +414,7 @@ def fp8_autocast_enter(
         if enabled:
             fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
             assert fp8_available, reason_for_no_fp8
-            if isinstance(fp8_recipe, BlockScaling):
+            if isinstance(fp8_recipe, MXFP8BlockScaling):
                 mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
                 assert mxfp8_available, reason_for_no_mxfp8

@@ -434,7 +434,7 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -
         to ensure both forward steps are numerically same.
         """

-        if fp8_meta["recipe"].block():
+        if fp8_meta["recipe"].mxfp8():
             return

         buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
@@ -460,7 +460,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
         1 forward for indentical numerical outputs.
         """

-        if fp8_meta["recipe"].block():
+        if fp8_meta["recipe"].mxfp8():
             return

         # Store updated amaxes and scales from phase 1 post forward.
@@ -479,7 +479,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
     def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
         """Restore latest scaling factors and amaxes after recompute forward run."""

-        if fp8_meta["recipe"].block():
+        if fp8_meta["recipe"].mxfp8():
             return

         fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
@@ -741,8 +741,8 @@ def create(
         cls = None
         if recipe.delayed():
             cls = DelayedScalingRecipeState
-        elif recipe.block():
-            cls = BlockScalingRecipeState
+        elif recipe.mxfp8():
+            cls = MXFP8BlockScalingRecipeState
         else:
             raise ValueError("{recipe.__class__.__name__} is not supported")
         return cls(
@@ -813,20 +813,20 @@ def make_quantizers(self) -> list:
         ]


-class BlockScalingRecipeState(RecipeState):
+class MXFP8BlockScalingRecipeState(RecipeState):
     """Configuration for MXFP8 quantization.

     MXFP8 quantization does not require state.

     """

-    recipe: BlockScaling
+    recipe: MXFP8BlockScaling
     mode: str
     dtype: tex.DType

     def __init__(
         self,
-        recipe: BlockScaling,
+        recipe: MXFP8BlockScaling,
         *,
         mode: str,
         num_quantizers: int = 1,
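A hedged sketch of how the default-recipe selection and the autocast-entry check read after this change. The `fp8_autocast` usage is illustrative, assumes MXFP8-capable (SM 10.0+) hardware, and is not part of this diff.

```python
# Illustrative usage; assumes an MXFP8-capable (SM 10.0+) device.
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# On Blackwell and newer, get_default_fp8_recipe() now returns MXFP8BlockScaling;
# older architectures keep DelayedScaling as the default.
default_recipe = te.fp8.get_default_fp8_recipe()

# Explicitly requesting MXFP8 block scaling; fp8_autocast_enter asserts
# is_mxfp8_available() before entering the region.
with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    pass  # FP8 modules executed here pick up the MXFP8 recipe
```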
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/base.py
@@ -22,7 +22,7 @@

 from ._common import _ParameterInitMeta
 from ..fp8 import (
-    BlockScalingRecipeState,
+    MXFP8BlockScalingRecipeState,
     DelayedScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
@@ -540,7 +540,7 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
             if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
                 self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                 return
-            if recipe.block() and isinstance(recipe_state, BlockScalingRecipeState):
+            if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
                 return

         # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
@@ -86,7 +86,7 @@ def forward(
         device = inp.device

         # TODO Support MXFP8 # pylint: disable=fixme
-        if fp8 and FP8GlobalStateManager.get_fp8_recipe().block():
+        if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
             raise NotImplementedError("GroupedLinear does not yet support MXFP8")

         # Make sure input dimensions are compatible
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/ops/op.py
@@ -15,7 +15,7 @@

 from transformer_engine.common.recipe import Recipe
 from ..fp8 import (
-    BlockScalingRecipeState,
+    MXFP8BlockScalingRecipeState,
     DelayedScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
@@ -260,7 +260,7 @@ def _update_quantization_recipe_state(
             recipe_state = self._fp8_metas[mode][fp8_meta_key]
             need_to_reset_recipe_state = (
                 recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)
-            ) or (recipe.block() and not isinstance(recipe_state, BlockScalingRecipeState))
+            ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
             if need_to_reset_recipe_state:
                 self._reset_quantization_recipe_state(recipe=recipe)
                 return
@@ -283,7 +283,7 @@ def _update_quantization_recipe_state(
             recipe_state = fp8_meta[fp8_meta_key]

             # Reallocate amax history if needed
-            if recipe.block():
+            if recipe.mxfp8():
                 continue

             current_length = recipe_state.amax_history.size(0)
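Finally, a compact sketch of the recipe-to-state dispatch that `RecipeState.create` and the fusible-ops update path now share. The standalone helper below is hypothetical and only restates the branch logic from the hunks above.

```python
# Hypothetical helper restating the dispatch after the rename; not part of this PR.
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.fp8 import (
    DelayedScalingRecipeState,
    MXFP8BlockScalingRecipeState,
)


def recipe_state_class(recipe: Recipe) -> type:
    """Map a quantization recipe to the recipe-state class that backs it."""
    if recipe.delayed():
        return DelayedScalingRecipeState
    if recipe.mxfp8():  # formerly recipe.block()
        return MXFP8BlockScalingRecipeState
    raise ValueError(f"{recipe.__class__.__name__} is not supported")
```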