Skip to content

various fixes #2133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: drisspg/stack/50
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions test/prototype/mx_formats/test_mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,16 @@ def test_inference_print_str():
)
@pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
def test_inference_subclass(elem_dtype):
@pytest.mark.parametrize("bias", [True, False])
def test_inference_subclass(elem_dtype, bias: bool):
"""
Smoke test for inference compile
"""
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

m = nn.Sequential(nn.Linear(32, 128, bias=False, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
config = MXFPConfig()
Expand All @@ -401,6 +402,8 @@ def test_inference_subclass(elem_dtype):
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)
if elem_dtype is torch.float8_e4m3fn:
assert sqnr >= 20.0
assert (
sqnr >= 20.0 if not bias else 14
) # TODO see if this is working as expected
else:
assert sqnr >= 11.5
1 change: 1 addition & 0 deletions torchao/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
"torchao.quantization",
"torchao.sparsity.sparse_api",
"torchao.prototype.quantization",
"torchao.prototype.mx_formats",
}


Expand Down
2 changes: 2 additions & 0 deletions torchao/prototype/mx_formats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
MXLinearConfig,
MXLinearRecipeName,
)
from torchao.prototype.mx_formats.mx_subclass import MXFPConfig

# import mx_linear here to register the quantize_ transform logic
# ruff: noqa: I001
Expand All @@ -14,4 +15,5 @@
"MXInferenceLinearConfig",
"MXLinearConfig",
"MXLinearRecipeName",
"MXFPConfig",
]
171 changes: 123 additions & 48 deletions torchao/prototype/mx_formats/mx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from typing import Any, Dict, Optional

import torch
from torch.utils._python_dispatch import (
return_and_correct_aliasing,
)
from torch.utils._pytree import tree_map

import torchao.ops
Expand Down Expand Up @@ -52,21 +55,11 @@ def decorator(func):
return decorator


@implements([aten.detach.default])
def mx_desugar_op(aten_op, args, kwargs=None):
old = args[0]
new_data = aten_op(old._data, *args[1:], **kwargs)
new = MXTensor(
old._scale_e8m0,
new_data,
old._elem_dtype,
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
old._gemm_kernel_choice,
old._pack_fp6,
@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(func)
)
return new


def _get_gemm_choice(
Expand All @@ -85,12 +78,15 @@ def _get_gemm_choice(
return choice_a if choice_a is not None else choice_b


@implements([aten.mm.default, aten.matmul.default])
def mx_mm(aten_op, args, kwargs=None):
a = args[0]
b = args[1]
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
def _addmm_mx_dispatch(
a: MXTensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Core implementation shared between mx_mm and mx_addmm.
The only difference is whether bias is None or not.
"""
gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)

if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
# real MX gemm backed by torchao's CUTLASS kernels
M, K, N = a.shape[0], a.shape[1], b.shape[1]
Expand All @@ -112,28 +108,63 @@ def mx_mm(aten_op, args, kwargs=None):
b._data,
a_scale_block.view(torch.float8_e8m0fnu),
b_scale_block.view(torch.float8_e8m0fnu),
bias=bias,
out_dtype=torch.bfloat16,
)
else:
assert a._elem_dtype == DTYPE_FP4
assert b._elem_dtype == DTYPE_FP4
assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
# FP4 operations
res = torchao.ops.mx_fp4_bf16(
a._data, b._data, a_scale_block, b_scale_block
)

# TODO update kernel to accept optional
if bias is not None:
res = res + bias
else:
# emulated MX gemm
a_hp = a.to_dtype(a._orig_dtype)
b_hp = b.to_dtype(b._orig_dtype)
# assert memory layout we expect to be required in hardware
assert a_hp.is_contiguous()
assert b_hp.t().is_contiguous()
res = aten_op(a_hp, b_hp)

# Call appropriate aten_op based on whether bias is provided
if bias is not None:
res = aten_op(bias, a_hp, b_hp) # addmm
else:
res = aten_op(a_hp, b_hp) # mm

return res


@implements([aten.mm.default, aten.matmul.default])
def mx_mm(func, types, args, kwargs):
a = args[0]
b = args[1]
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)

return _addmm_mx_dispatch(a, b, func)


@implements([aten.addmm.default])
def mx_addmm(func, types, args, kwargs):
assert (
isinstance(args[0], torch.Tensor)
and isinstance(args[1], MXTensor)
and isinstance(args[2], MXTensor)
)
bias = args[0]
a = args[1]
b = args[2]

return _addmm_mx_dispatch(a, b, func, bias=bias)


@implements([aten.t.default])
def mx_t(aten_op, args, kwargs=None):
def mx_t(func, types, args, kwargs):
# For now, only transpose(input, 0, 1) is supported.
old = args[0]
new = MXTensor(
Expand All @@ -150,7 +181,7 @@ def mx_t(aten_op, args, kwargs=None):


@implements([aten.sum.dim_IntList])
def mx_cast_up_op(aten_op, args, kwargs=None):
def mx_cast_up_op(func, types, args, kwargs):
"""Be careful with this function, this is a "fallback" op that
casts the output of the op to the original precision. And performs the op.

Expand All @@ -166,11 +197,11 @@ def unwrap(x):

new_args = tree_map(unwrap, args)
new_kwargs = tree_map(unwrap, kwargs)
return aten_op(*new_args, **new_kwargs)
return func(*new_args, **new_kwargs)


@implements([aten.view.default])
def mx_view_op(aten_op, args, kwargs=None):
def mx_view_op(func, types, args, kwargs):
data = args[0]._data
new_size = args[1]
if args[0]._elem_dtype == DTYPE_FP4:
Expand All @@ -179,7 +210,7 @@ def mx_view_op(aten_op, args, kwargs=None):
elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
# special case fp6 as we pack 4 elements in 3 bytes
new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
new_data = aten_op(data, new_size, *args[2:], **kwargs)
new_data = func(data, new_size, *args[2:], **kwargs)
return MXTensor(
args[0]._scale_e8m0,
new_data,
Expand All @@ -192,28 +223,72 @@ def mx_view_op(aten_op, args, kwargs=None):
)


# @implements([aten._to_copy.default])
# def autocast_to_copy(func, types, args, kwargs):
# """This gets called when running matmul under autocast
# when the input is a MXTensor, presenting as a fp32
# tensor.
# """
# assert isinstance(args[0], MXTensor)
# breakpoint()
# assert len(kwargs) == 1 and "dtype" in kwargs, (
# "Only support dtype kwarg for autocast"
# )
# assert kwargs["dtype"] in {
# torch.float16,
# torch.bfloat16,
# }, "Only support floating point conversion for autocast w/ MXTensor"
# res = MXTensor(
# args[0]._scale_e8m0,
# args[0]._data,
# args[0]._elem_dtype,
# args[0]._block_size,
# kwargs["dtype"],
# args[0]._use_fp4_custom_triton_dequant_kernel,
# args[0]._gemm_kernel_choice,
# args[0]._pack_fp6,
# )
# return res


@implements([aten._to_copy.default])
def autocast_to_copy(aten_op, args, kwargs=None):
"""This gets called when running matmul under autocast
when the input is a MXTensor, presenting as a fp32
tensor.
"""
def autocast_to_copy(func, types, args, kwargs):
"""Autocast + device movement"""
assert isinstance(args[0], MXTensor)
assert len(kwargs) == 1 and "dtype" in kwargs, (
"Only support dtype kwarg for autocast"
)
assert kwargs["dtype"] in {
torch.float16,
torch.bfloat16,
}, "Only support floating point conversion for autocast w/ MXTensor"
res = MXTensor(
args[0]._scale_e8m0,
args[0]._data,
args[0]._elem_dtype,
args[0]._block_size,
kwargs["dtype"],
args[0]._use_fp4_custom_triton_dequant_kernel,
args[0]._gemm_kernel_choice,
args[0]._pack_fp6,
)
return res

# Handle dtype parameter
dtype = kwargs.pop("dtype", None)
if dtype is not None:
assert dtype in {
torch.float16,
torch.bfloat16,
}, "Only support floating point conversion for autocast w/ MXTensor"

# Handle device parameter
device = kwargs.pop("device", None)
if device is not None:
# Apply device change using _apply_fn_to_data
tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device))
tensor = return_and_correct_aliasing(func, args, {}, tensor)
else:
tensor = args[0]

# Verify no other kwargs remain
assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast"

# If dtype is specified, create a new MXTensor with the requested dtype
if dtype is not None:
res = MXTensor(
tensor._scale_e8m0,
tensor._data,
tensor._elem_dtype,
tensor._block_size,
dtype,
tensor._use_fp4_custom_triton_dequant_kernel,
tensor._gemm_kernel_choice,
tensor._pack_fp6,
)
return res

# If only device was changed, return the device-changed tensor
return tensor
8 changes: 7 additions & 1 deletion torchao/prototype/mx_formats/mx_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.utils import is_sm_at_least_100
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100


@dataclass
Expand Down Expand Up @@ -115,3 +115,9 @@ def _mx_inference_linear_transform(module: torch.nn.Module, config: MXFPConfig):
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals(
[MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
)
20 changes: 18 additions & 2 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

from enum import Enum, auto
from typing import Dict, Union
from typing import Callable, Dict, Union

import torch

Expand Down Expand Up @@ -562,7 +562,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
from torchao.prototype.mx_formats.mx_ops import MX_OPS_TABLE

if func in MX_OPS_TABLE:
return MX_OPS_TABLE[func](func, args, kwargs)
return MX_OPS_TABLE[func](func, types, args, kwargs)

raise NotImplementedError(f"{func} not implemented")

Expand Down Expand Up @@ -631,5 +631,21 @@ def __tensor_unflatten__(
metadata["_pack_fp6"],
)

def _apply_fn_to_data(self, fn: Callable):
"""Applies a fn to all tensor components stored on this class"""
tensor_names, ctx = self.__tensor_flatten__()

# Apply the function to each tensor component
new_tensors = {}
for name in tensor_names:
new_tensors[name] = fn(getattr(self, name))

return self.__class__.__tensor_unflatten__(
new_tensors,
ctx,
None, # outer_size parameter
None, # outer_stride parameter
)

# Do not force the MXTensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl
Loading