From 72e6b387da9b2853b17450c18cfad33334eea849 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 19 Aug 2025 07:15:35 -0700
Subject: [PATCH] prototype float8 training quantizeable mm module

Summary:

Adds a `_QuantizeableMM` module and a `_Float8MM` quantized version of it.
This enables quantizing calls to `torch.mm` where none of the inputs are
weights, which requires modeling changes at the callsite. For now, this is
added as a prototype with underscored names, to help test out a customer
need. We can make the API more official at a later time, after we get more
signal on product-market fit.

Usage:

```python
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.mm = _QuantizeableMM()

    def forward(self, a, b):
        c = self.mm(a, b)
        return c

config = Float8LinearConfig.from_recipe_name(recipe_name)
m_ref = M()
m = copy.deepcopy(m_ref)
m = convert_to_float8_training(m, config=config)
```

Test Plan:

```bash
pytest test/float8/test_base.py -s -x -k quantizeable_mm
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/float8/test_base.py              | 72 ++++++++++++++++++++++++
 torchao/float8/float8_linear_utils.py | 37 ++++++++++--
 torchao/float8/quantizeable_mm.py     | 81 +++++++++++++++++++++++++++
 3 files changed, 185 insertions(+), 5 deletions(-)
 create mode 100644 torchao/float8/quantizeable_mm.py

diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 1f9ae19346..343b7570be 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -41,6 +41,7 @@
     fp8_tensor_statistics,
     tensor_to_scale,
 )
+from torchao.float8.quantizeable_mm import _Float8MM, _QuantizeableMM
 from torchao.testing.training.test_utils import get_test_float8_linear_config
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -478,6 +479,58 @@ def test_quantize(self):
         m(x)
 
 
+class TestFloat8MM:
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.TENSORWISE,
+            Float8LinearRecipeName.ROWWISE,
+            Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
+        ],
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_90(), "CUDA capability < 9.0"
+    )
+    def test_quantizeable_mm_hello_world(self, recipe_name):
+        class M(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mm = _QuantizeableMM()
+
+            def forward(self, a, b):
+                c = self.mm(a, b)
+                return c
+
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
+        m_ref = M()
+        m = copy.deepcopy(m_ref)
+        m = convert_to_float8_training(m, config=config)
+
+        a_ref = torch.randn(32, 64, device="cuda", requires_grad=True)
+        b_ref = torch.randn(64, 128, device="cuda", requires_grad=True)
+        go_ref = torch.randn(32, 128, device="cuda")
+
+        a = copy.deepcopy(a_ref)
+        b = copy.deepcopy(b_ref)
+        go = copy.deepcopy(go_ref)
+
+        with torch.autocast("cuda", dtype=torch.bfloat16):
+            c_ref = m_ref(a_ref, b_ref)
+            c = m(a, b)
+        c_ref.backward(go_ref)
+        c.backward(go)
+
+        with torch.no_grad():
+            sqnr_c = compute_error(c_ref, c)
+            sqnr_ga = compute_error(a_ref.grad, a.grad)
+            sqnr_gb = compute_error(b_ref.grad, b.grad)
+
+        assert sqnr_c > 20.0
+        assert sqnr_ga > 20.0
+        assert sqnr_gb > 20.0
+
+
 class TestScaledMM:
     @unittest.skipIf(
         not is_sm_at_least_89(),
@@ -802,6 +855,25 @@ def test_fp8_tensor_statistics(self):
         (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
         self.assertEqual((zero_cnt, max_cnt), (tensor_len, tensor_len))
 
+    def test_swap_quantizeable_mm(self):
+        module = _QuantizeableMM()
+        config = Float8LinearConfig()
+        module = convert_to_float8_training(module, config=config)
+        self.assertIsInstance(module, _Float8MM)
+
+        class M(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mm = _QuantizeableMM()
+
+            def forward(self, a, b):
+                c = self.mm(a, b)
+                return c
+
+        module2 = M()
+        module2 = convert_to_float8_training(module2, config=config)
+        self.assertIsInstance(module2.mm, _Float8MM)
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py
index e0def790b8..557b73f4a1 100644
--- a/torchao/float8/float8_linear_utils.py
+++ b/torchao/float8/float8_linear_utils.py
@@ -12,14 +12,16 @@
 
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_linear import Float8Linear
+from torchao.float8.quantizeable_mm import _Float8MM, _QuantizeableMM
 
 log = logging.getLogger(__name__)
 log.addHandler(logging.NullHandler())
 
 
-def swap_linear_layers(
+def _swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
+    mm_from_float_func: Callable[[_QuantizeableMM], _Float8MM],
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
 ) -> nn.Module:
@@ -50,6 +52,16 @@ def swap_linear_layers(
         return from_float_func(
             module,
         )
+    elif isinstance(module, _QuantizeableMM) and (
+        module_filter_fn is None or module_filter_fn(module, "")
+    ):
+        if len(list(module.children())) > 0:
+            raise AssertionError(
+                f"Does not support a root _QuantizeableMM with children: {module}"
+            )
+        return mm_from_float_func(
+            module,
+        )
 
     root_module = module
 
@@ -78,6 +90,15 @@ def post_order_traversal(
             new_linear_module = from_float_func(module)
             cur_module_name = cur_fqn.split(".")[-1]
             setattr(parent_module, cur_module_name, new_linear_module)
+        elif isinstance(module, _QuantizeableMM) and (
+            module_filter_fn is None or module_filter_fn(module, cur_fqn)
+        ):
+            assert parent_module is not None, (
+                f"_QuantizeableMM root module should return early: {module}"
+            )
+            new_mm_module = mm_from_float_func(module)
+            cur_module_name = cur_fqn.split(".")[-1]
+            setattr(parent_module, cur_module_name, new_mm_module)
 
     post_order_traversal(root_module)
     return root_module
@@ -106,17 +127,23 @@ def convert_to_float8_training(
     if config is None:
         config = Float8LinearConfig()
 
-    from_float = lambda m: Float8Linear.from_float(
+    linear_from_float = lambda m: Float8Linear.from_float(
         m,
         config=config,
     )
-
-    return swap_linear_layers(
+    quantizeable_mm_from_float = lambda m: _Float8MM.from_float(
+        m,
+        config=config,
+    )
+    res = _swap_linear_layers(
         module,
-        from_float,
+        linear_from_float,
+        quantizeable_mm_from_float,
         module_filter_fn=module_filter_fn,
     )
+    return res
+
 
 def _auto_filter_for_recipe(
     recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
diff --git a/torchao/float8/quantizeable_mm.py b/torchao/float8/quantizeable_mm.py
new file mode 100644
index 0000000000..932d5f8eb9
--- /dev/null
+++ b/torchao/float8/quantizeable_mm.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
+from torchao.float8.float8_training_tensor import (
+    LinearMMConfig,
+    ScaledMMConfig,
+)
+
+
+class _QuantizeableMM(torch.nn.Module):
+    """
+    A modularized version of `torch.mm` which is easy to module-swap to
+    a quantized version.
+
+    Note: this is a prototype API which may change in a future release.
+    """
+
+    def forward(self, a, b):
+        return torch.mm(a, b)
+
+
+class _Float8MM(torch.nn.Module):
+    """
+    A float8 quantized version of `_QuantizeableMM`.
+
+    Note: this is a prototype API which may change in a future release.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.linear_mm_config = LinearMMConfig(
+            # output
+            ScaledMMConfig(
+                config.emulate,
+                config.gemm_config_output.use_fast_accum,
+                False,
+                config.pad_inner_dim,
+            ),
+            # grad_input
+            ScaledMMConfig(
+                config.emulate,
+                config.gemm_config_grad_input.use_fast_accum,
+                False,
+                config.pad_inner_dim,
+            ),
+            # grad_weight
+            ScaledMMConfig(
+                config.emulate,
+                config.gemm_config_grad_weight.use_fast_accum,
+                False,
+                config.pad_inner_dim,
+            ),
+        )
+
+    def forward(self, a, b):
+        if torch.is_autocast_enabled():
+            # For now, hardcode to the GPU autocast dtype; if we need
+            # CPU support in the future, we can add it then.
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            a = a.to(autocast_dtype)
+            b = b.to(autocast_dtype)
+
+        c = matmul_with_hp_or_float8_args.apply(
+            a,
+            b,
+            self.linear_mm_config,
+            self.config,
+        )
+        return c
+
+    @classmethod
+    def from_float(cls, mod, config):
+        assert isinstance(mod, _QuantizeableMM)
+        return cls(config)
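
Note for reviewers: below is a minimal end-to-end sketch of how the new module composes with the existing `module_filter_fn` mechanism of `convert_to_float8_training`. The toy model `M`, the filter body, and the tensor shapes are illustrative and not part of this patch; the import paths match the files touched above.

```python
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.quantizeable_mm import _QuantizeableMM


class M(nn.Module):
    """Toy model with one weight-based linear and one weight-free matmul."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 64)
        self.mm = _QuantizeableMM()

    def forward(self, a, b):
        return self.mm(self.linear(a), b)


# illustrative filter: quantize only the standalone matmul, keep nn.Linear
# in high precision
def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    return isinstance(mod, _QuantizeableMM)


m = M().cuda()
m = convert_to_float8_training(
    m, config=Float8LinearConfig(), module_filter_fn=module_filter_fn
)

a = torch.randn(32, 64, device="cuda", requires_grad=True)
b = torch.randn(64, 128, device="cuda", requires_grad=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
    c = m(a, b)  # m.mm now runs as _Float8MM, m.linear stays unswapped
c.backward(torch.randn_like(c))
```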