72 changes: 72 additions & 0 deletions test/float8/test_base.py
@@ -41,6 +41,7 @@
fp8_tensor_statistics,
tensor_to_scale,
)
from torchao.float8.quantizeable_mm import _Float8MM, _QuantizeableMM
from torchao.testing.training.test_utils import get_test_float8_linear_config
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
@@ -478,6 +479,58 @@ def test_quantize(self):
m(x)


class TestFloat8MM:
@pytest.mark.parametrize(
"recipe_name",
[
Float8LinearRecipeName.TENSORWISE,
Float8LinearRecipeName.ROWWISE,
Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(
torch.cuda.is_available() and not is_sm_at_least_90(), "CUDA capability < 9.0"
)
def test_quantizeable_mm_hello_world(self, recipe_name):
class M(nn.Module):
def __init__(self):
super().__init__()
self.mm = _QuantizeableMM()

def forward(self, a, b):
c = self.mm(a, b)
return c

config = Float8LinearConfig.from_recipe_name(recipe_name)
m_ref = M()
m = copy.deepcopy(m_ref)
m = convert_to_float8_training(m, config=config)

a_ref = torch.randn(32, 64, device="cuda", requires_grad=True)
b_ref = torch.randn(64, 128, device="cuda", requires_grad=True)
go_ref = torch.randn(32, 128, device="cuda")

a = copy.deepcopy(a_ref)
b = copy.deepcopy(b_ref)
go = copy.deepcopy(go_ref)

with torch.autocast("cuda", dtype=torch.bfloat16):
c_ref = m_ref(a_ref, b_ref)
c = m(a, b)
c_ref.backward(go_ref)
c.backward(go)

with torch.no_grad():
sqnr_c = compute_error(c_ref, c)
sqnr_ga = compute_error(a_ref.grad, a.grad)
sqnr_gb = compute_error(b_ref.grad, b.grad)

assert sqnr_c > 20.0
assert sqnr_ga > 20.0
assert sqnr_gb > 20.0


class TestScaledMM:
@unittest.skipIf(
not is_sm_at_least_89(),
@@ -802,6 +855,25 @@ def test_fp8_tensor_statistics(self):
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
self.assertEqual((zero_cnt, max_cnt), (tensor_len, tensor_len))

def test_swap_quantizeable_mm(self):
module = _QuantizeableMM()
config = Float8LinearConfig()
module = convert_to_float8_training(module, config=config)
self.assertIsInstance(module, _Float8MM)

class M(nn.Module):
def __init__(self):
super().__init__()
self.mm = _QuantizeableMM()

def forward(self, a, b):
c = self.mm(a, b)
return c

module2 = M()
module2 = convert_to_float8_training(module2, config=config)
self.assertIsInstance(module2.mm, _Float8MM)


if __name__ == "__main__":
pytest.main([__file__])
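Note on the assertions above: `compute_error` is assumed here to return an SQNR in decibels, so `> 20.0` means the quantization error norm is at most roughly 10% of the reference norm. A minimal sketch of that metric, under that assumption:

import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels:
    # 20 * log10(||ref|| / ||ref - actual||); 20 dB ~= 10% relative error norm.
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - actual))
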
37 changes: 32 additions & 5 deletions torchao/float8/float8_linear_utils.py
@@ -12,14 +12,16 @@

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.quantizeable_mm import _Float8MM, _QuantizeableMM

log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())


def swap_linear_layers(
def _swap_linear_layers(
module: nn.Module,
from_float_func: Callable[[nn.Linear], nn.Linear],
mm_from_float_func: Callable[[_QuantizeableMM], _Float8MM],
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
@@ -50,6 +52,16 @@ def swap_linear_layers(
return from_float_func(
module,
)
elif isinstance(module, _QuantizeableMM) and (
module_filter_fn is None or module_filter_fn(module, "")
):
if len(list(module.children())) > 0:
raise AssertionError(
f"Does not support a root nn.Linear with children: {module}"
)
return mm_from_float_func(
module,
)

root_module = module

@@ -78,6 +90,15 @@ def post_order_traversal(
new_linear_module = from_float_func(module)
cur_module_name = cur_fqn.split(".")[-1]
setattr(parent_module, cur_module_name, new_linear_module)
elif isinstance(module, _QuantizeableMM) and (
module_filter_fn is None or module_filter_fn(module, cur_fqn)
):
assert parent_module is not None, (
f"Linear root module should return early: {module}"
)
new_linear_module = mm_from_float_func(module)
cur_module_name = cur_fqn.split(".")[-1]
setattr(parent_module, cur_module_name, new_linear_module)

post_order_traversal(root_module)
return root_module
@@ -106,17 +127,23 @@ def convert_to_float8_training(
if config is None:
config = Float8LinearConfig()

from_float = lambda m: Float8Linear.from_float(
linear_from_float = lambda m: Float8Linear.from_float(
m,
config=config,
)

return swap_linear_layers(
quantizeable_mm_from_float = lambda m: _Float8MM.from_float(
m,
config=config,
)
res = _swap_linear_layers(
module,
from_float,
linear_from_float,
quantizeable_mm_from_float,
module_filter_fn=module_filter_fn,
)

return res


def _auto_filter_for_recipe(
recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
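To make the intent of the new `mm_from_float_func` path concrete, here is a hedged sketch of the expected end-to-end behavior; the toy `Model` class is hypothetical, and the top-level imports assume the public re-exports in `torchao.float8`:

import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.quantizeable_mm import _Float8MM, _QuantizeableMM


class Model(nn.Module):
    # Hypothetical toy model mixing an nn.Linear with the new mm wrapper.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)
        self.mm = _QuantizeableMM()

    def forward(self, a, b):
        return self.proj(self.mm(a, b))


m = convert_to_float8_training(Model(), config=Float8LinearConfig())
# Both swap paths should have fired: the linear and the mm wrapper.
assert isinstance(m.proj, Float8Linear)
assert isinstance(m.mm, _Float8MM)
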
81 changes: 81 additions & 0 deletions torchao/float8/quantizeable_mm.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
from torchao.float8.float8_training_tensor import (
LinearMMConfig,
ScaledMMConfig,
)


class _QuantizeableMM(torch.nn.Module):
"""
A modularized version of `torch.mm` that can easily be module-swapped to
a quantized version.

Note: this is a prototype API which may change in a future release.
"""

def forward(self, a, b):
return torch.mm(a, b)


class _Float8MM(torch.nn.Module):
"""
A float8 quantized version of `_QuantizeableMM`.

Note: this is a prototype API which may change in a future release.
"""

def __init__(self, config):
super().__init__()
self.config = config
self.linear_mm_config = LinearMMConfig(
# output
ScaledMMConfig(
config.emulate,
self.config.gemm_config_output.use_fast_accum,
False,
self.config.pad_inner_dim,
),
# grad_input
ScaledMMConfig(
config.emulate,
self.config.gemm_config_grad_input.use_fast_accum,
False,
self.config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
config.emulate,
self.config.gemm_config_grad_weight.use_fast_accum,
False,
self.config.pad_inner_dim,
),
)

def forward(self, a, b):
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
autocast_dtype = torch.get_autocast_gpu_dtype()
a = a.to(autocast_dtype)
b = b.to(autocast_dtype)

c = matmul_with_hp_or_float8_args.apply(
a,
b,
self.linear_mm_config,
self.config,
)
return c

@classmethod
def from_float(cls, mod, config):
assert isinstance(mod, _QuantizeableMM)
return _Float8MM(config)
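Usage-wise, the point of wrapping `torch.mm` in a child module is to give the swap utilities a call site they can find and replace. A small hypothetical sketch of the intended refactor at a call site (the `PairwiseScores` class and its use are illustrative, not part of this PR):

import torch
import torch.nn as nn

from torchao.float8.quantizeable_mm import _QuantizeableMM


class PairwiseScores(nn.Module):
    # Hypothetical module: instead of calling torch.mm(a, b) inline,
    # route the matmul through a child module so that
    # convert_to_float8_training can later swap it for _Float8MM.
    def __init__(self):
        super().__init__()
        self.mm = _QuantizeableMM()

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Numerically identical to torch.mm(a, b) before swapping.
        return self.mm(a, b)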