diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index ebdb1ab019..040d784129 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -381,7 +381,8 @@ def test_inference_print_str():
 )
 @pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
-def test_inference_subclass(elem_dtype):
+@pytest.mark.parametrize("bias", [True, False])
+def test_inference_subclass(elem_dtype, bias: bool):
     """
     Smoke test for inference compile
     """
@@ -389,7 +390,7 @@
     if not is_sm_at_least_89():
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
-    m = nn.Sequential(nn.Linear(32, 128, bias=False, dtype=torch.bfloat16))
+    m = nn.Sequential(nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
     config = MXFPConfig()
@@ -401,6 +402,8 @@
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
     if elem_dtype is torch.float8_e4m3fn:
-        assert sqnr >= 20.0
+        assert sqnr >= (
+            20.0 if not bias else 14.0
+        )  # TODO: see if the lower threshold with bias is expected
     else:
         assert sqnr >= 11.5
diff --git a/torchao/core/config.py b/torchao/core/config.py
index fe03ac225b..c9ee24f68c 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -175,6 +175,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     "torchao.quantization",
     "torchao.sparsity.sparse_api",
     "torchao.prototype.quantization",
+    "torchao.prototype.mx_formats",
 }
 
 
diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py
index 7252c33dc9..e707a4889f 100644
--- a/torchao/prototype/mx_formats/__init__.py
+++ b/torchao/prototype/mx_formats/__init__.py
@@ -4,6 +4,7 @@
     MXLinearConfig,
     MXLinearRecipeName,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
 
 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -14,4 +15,5 @@
     "MXInferenceLinearConfig",
     "MXLinearConfig",
     "MXLinearRecipeName",
+    "MXFPConfig",
 ]
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index 93eab0c4ac..87ef095bf2 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -20,6 +20,9 @@
 from typing import Any, Dict, Optional
 
 import torch
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
 from torch.utils._pytree import tree_map
 
 import torchao.ops
@@ -52,21 +55,11 @@ def decorator(func):
     return decorator
 
 
-@implements([aten.detach.default])
-def mx_desugar_op(aten_op, args, kwargs=None):
-    old = args[0]
-    new_data = aten_op(old._data, *args[1:], **kwargs)
-    new = MXTensor(
-        old._scale_e8m0,
-        new_data,
-        old._elem_dtype,
-        old._block_size,
-        old._orig_dtype,
-        old._use_fp4_custom_triton_dequant_kernel,
-        old._gemm_kernel_choice,
-        old._pack_fp6,
+@implements([aten.detach.default, aten.alias.default])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(func)
     )
-    return new
 
 
 def _get_gemm_choice(
@@ -85,12 +78,15 @@
     return choice_a if choice_a is not None else choice_b
 
 
-@implements([aten.mm.default, aten.matmul.default])
-def mx_mm(aten_op, args, kwargs=None):
-    a = args[0]
-    b = args[1]
-    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
+def _addmm_mx_dispatch(
+    a: MXTensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Core implementation shared between mx_mm and mx_addmm.
+    The only difference is whether bias is None or not.
+    """
     gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)
+
     if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
@@ -112,15 +108,21 @@
             b._data,
             a_scale_block.view(torch.float8_e8m0fnu),
             b_scale_block.view(torch.float8_e8m0fnu),
+            bias=bias,
             out_dtype=torch.bfloat16,
         )
     else:
         assert a._elem_dtype == DTYPE_FP4
         assert b._elem_dtype == DTYPE_FP4
         assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
+        # FP4 operations
         res = torchao.ops.mx_fp4_bf16(
             a._data, b._data, a_scale_block, b_scale_block
         )
+
+        # TODO: update the kernel to accept an optional bias
+        if bias is not None:
+            res = res + bias
     else:
         # emulated MX gemm
         a_hp = a.to_dtype(a._orig_dtype)
@@ -128,12 +130,41 @@
         b_hp = b.to_dtype(b._orig_dtype)
         # assert memory layout we expect to be required in hardware
         assert a_hp.is_contiguous()
         assert b_hp.t().is_contiguous()
-        res = aten_op(a_hp, b_hp)
+
+        # Call the appropriate aten_op based on whether bias is provided
+        if bias is not None:
+            res = aten_op(bias, a_hp, b_hp)  # addmm
+        else:
+            res = aten_op(a_hp, b_hp)  # mm
+
     return res
 
 
+@implements([aten.mm.default, aten.matmul.default])
+def mx_mm(func, types, args, kwargs):
+    a = args[0]
+    b = args[1]
+    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
+
+    return _addmm_mx_dispatch(a, b, func)
+
+
+@implements([aten.addmm.default])
+def mx_addmm(func, types, args, kwargs):
+    assert (
+        isinstance(args[0], torch.Tensor)
+        and isinstance(args[1], MXTensor)
+        and isinstance(args[2], MXTensor)
+    )
+    bias = args[0]
+    a = args[1]
+    b = args[2]
+
+    return _addmm_mx_dispatch(a, b, func, bias=bias)
+
+
 @implements([aten.t.default])
-def mx_t(aten_op, args, kwargs=None):
+def mx_t(func, types, args, kwargs):
     # For now, only transpose(input, 0, 1) is supported.
     old = args[0]
     new = MXTensor(
@@ -150,7 +181,7 @@
 
 
 @implements([aten.sum.dim_IntList])
-def mx_cast_up_op(aten_op, args, kwargs=None):
+def mx_cast_up_op(func, types, args, kwargs):
     """Be careful with this function, this is a "fallback" op that
     casts the output of the op to the original precision. And performs the op.
 
@@ -166,11 +197,11 @@ def unwrap(x):
 
     new_args = tree_map(unwrap, args)
     new_kwargs = tree_map(unwrap, kwargs)
-    return aten_op(*new_args, **new_kwargs)
+    return func(*new_args, **new_kwargs)
 
 
 @implements([aten.view.default])
-def mx_view_op(aten_op, args, kwargs=None):
+def mx_view_op(func, types, args, kwargs):
     data = args[0]._data
     new_size = args[1]
     if args[0]._elem_dtype == DTYPE_FP4:
@@ -179,7 +210,7 @@
     elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
         # special case fp6 as we pack 4 elements in 3 bytes
         new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
-    new_data = aten_op(data, new_size, *args[2:], **kwargs)
+    new_data = func(data, new_size, *args[2:], **kwargs)
     return MXTensor(
         args[0]._scale_e8m0,
         new_data,
@@ -192,28 +223,44 @@
 
 
 @implements([aten._to_copy.default])
-def autocast_to_copy(aten_op, args, kwargs=None):
-    """This gets called when running matmul under autocast
-    when the input is a MXTensor, presenting as a fp32
-    tensor.
-    """
+def autocast_to_copy(func, types, args, kwargs):
+    """Autocast + device movement"""
     assert isinstance(args[0], MXTensor)
-    assert len(kwargs) == 1 and "dtype" in kwargs, (
-        "Only support dtype kwarg for autocast"
-    )
-    assert kwargs["dtype"] in {
-        torch.float16,
-        torch.bfloat16,
-    }, "Only support floating point conversion for autocast w/ MXTensor"
-    res = MXTensor(
-        args[0]._scale_e8m0,
-        args[0]._data,
-        args[0]._elem_dtype,
-        args[0]._block_size,
-        kwargs["dtype"],
-        args[0]._use_fp4_custom_triton_dequant_kernel,
-        args[0]._gemm_kernel_choice,
-        args[0]._pack_fp6,
-    )
-    return res
+
+    # Handle dtype parameter
+    dtype = kwargs.pop("dtype", None)
+    if dtype is not None:
+        assert dtype in {
+            torch.float16,
+            torch.bfloat16,
+        }, "Only support floating point conversion for autocast w/ MXTensor"
+
+    # Handle device parameter
+    device = kwargs.pop("device", None)
+    if device is not None:
+        # Apply device change using _apply_fn_to_data
+        tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device))
+        tensor = return_and_correct_aliasing(func, args, {}, tensor)
+    else:
+        tensor = args[0]
+
+    # Verify no other kwargs remain
+    assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast"
+
+    # If dtype is specified, create a new MXTensor with the requested dtype
+    if dtype is not None:
+        res = MXTensor(
+            tensor._scale_e8m0,
+            tensor._data,
+            tensor._elem_dtype,
+            tensor._block_size,
+            dtype,
+            tensor._use_fp4_custom_triton_dequant_kernel,
+            tensor._gemm_kernel_choice,
+            tensor._pack_fp6,
+        )
+        return res
+
+    # If only device was changed, return the device-changed tensor
+    return tensor
diff --git a/torchao/prototype/mx_formats/mx_subclass.py b/torchao/prototype/mx_formats/mx_subclass.py
index 5961ca55d7..2c0897d195 100644
--- a/torchao/prototype/mx_formats/mx_subclass.py
+++ b/torchao/prototype/mx_formats/mx_subclass.py
@@ -24,7 +24,7 @@
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
-from torchao.utils import is_sm_at_least_100
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
 
 
 @dataclass
@@ -115,3 +115,9 @@ def _mx_inference_linear_transform(module: torch.nn.Module, config: MXFPConfig):
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
     module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    torch.serialization.add_safe_globals(
+        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
+    )
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index f3aca15a73..0a874f6e61 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -18,7 +18,7 @@
 """
 
 from enum import Enum, auto
-from typing import Dict, Union
+from typing import Callable, Dict, Union
 
 import torch
@@ -562,7 +562,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         from torchao.prototype.mx_formats.mx_ops import MX_OPS_TABLE
 
         if func in MX_OPS_TABLE:
-            return MX_OPS_TABLE[func](func, args, kwargs)
+            return MX_OPS_TABLE[func](func, types, args, kwargs)
 
         raise NotImplementedError(f"{func} not implemented")
 
@@ -631,5 +631,21 @@ def __tensor_unflatten__(
             metadata["_pack_fp6"],
         )
 
+    def _apply_fn_to_data(self, fn: Callable):
+        """Applies a fn to all tensor components stored on this class"""
+        tensor_names, ctx = self.__tensor_flatten__()
+
+        # Apply the function to each tensor component
+        new_tensors = {}
+        for name in tensor_names:
+            new_tensors[name] = fn(getattr(self, name))
+
+        return self.__class__.__tensor_unflatten__(
+            new_tensors,
+            ctx,
+            None,  # outer_size parameter
+            None,  # outer_stride parameter
+        )
+
     # Do not force the MXTensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
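For reviewers trying the change locally, here is a minimal end-to-end sketch assembled from `test_inference_subclass` above. The `quantize_` call and the input shapes are assumptions based on the test's unchanged context lines (they are not part of this diff):

```python
# Sketch assuming the standard torchao quantize_ flow exercised by the test.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats import MXFPConfig  # newly re-exported here
from torchao.quantization import quantize_

# bfloat16 linear layer; bias=True now exercises the new mx_addmm path
m = nn.Sequential(nn.Linear(32, 128, bias=True, dtype=torch.bfloat16)).cuda()
m_mx = copy.deepcopy(m)

# Swap the weight for an MXTensor-backed parameter
quantize_(m_mx, config=MXFPConfig())

x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
y_mx = m_mx(x)  # aten.addmm dispatches via mx_addmm -> _addmm_mx_dispatch
```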
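The `add_safe_globals` registration in `mx_subclass.py` is what lets a state_dict containing MXTensor weights round-trip through `torch.save` / `torch.load` with `weights_only=True` (opt-in on torch 2.5, the default from 2.6). A quick check, continuing the sketch above with an illustrative path:

```python
# Continues the sketch above; the checkpoint path is illustrative.
torch.save(m_mx.state_dict(), "/tmp/mx_checkpoint.pt")

# Without the add_safe_globals([...]) registration, this load would fail
# under weights_only=True because MXTensor et al. are not allowlisted.
state_dict = torch.load("/tmp/mx_checkpoint.pt", weights_only=True)
m_mx.load_state_dict(state_dict)
```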