Skip to content

Add support for Int4GroupwisePreshuffleTensor for fbgemm #2421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions test/quantization/quantize_/test_int4_groupwise_preshuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.quantization import (
FbgemmConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
_is_fbgemm_genai_gpu_available,
is_sm_at_least_90,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
class TestInt4GroupwisePreshuffleTensor(TestCase):
def setUp(self):
self.config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

def test_linear(self):
dtype = torch.bfloat16
device = "cuda"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, self.config)
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

def test_bmm(self):
class M(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"
input = torch.randn(10, 32, 128, dtype=dtype, device=device)
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()
original = m(input)
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

def test_to_device(self):
for device in self.GPU_DEVICES:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device=device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device)

def test_module_path(self):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
)


if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@
"to_fbgemm_fp8",
"FbgemmFp8Tensor",
"Int8DynamicActInt4WeightCPULayout",
"Int4GroupwisePreshuffleTensor",
]
5 changes: 5 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
dequantize_affine,
quantize_affine,
)
from .quantize_ import (
Int4GroupwisePreshuffleTensor,
)
from .smoothquant import (
SmoothFakeDynamicallyQuantizedLinear,
SmoothFakeDynQuantMixin,
Expand Down Expand Up @@ -149,6 +152,8 @@
"AOPerModuleConfig",
"ModuleFqnToConfig",
"FbgemmConfig",
# tensor subclasses
"Int4GroupwisePreshuffleTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
Expand Down
31 changes: 16 additions & 15 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
and mixed GEMM kernels
"""

import importlib.util
import logging
import types
import warnings
Expand Down Expand Up @@ -68,6 +67,9 @@
LinearActivationWeightObservedTensor,
)
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
from torchao.quantization.quantize_ import (
Int4GroupwisePreshuffleTensor,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
register_quantize_module_handler,
Expand All @@ -79,7 +81,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_fbcode,
_is_fbgemm_genai_gpu_available,
is_MI300,
is_sm_at_least_89,
is_sm_at_least_90,
Expand Down Expand Up @@ -2046,18 +2048,12 @@ class FbgemmConfig(AOBaseConfig):
block_size: Optional[List[int]] = None
activation_scale_ub: Optional[float] = None
transpose_input: bool = False
preshuffle: bool = False


@register_quantize_module_handler(FbgemmConfig)
def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
# TODO: use is_package_at_least("fbgemm_gpu", "1.2.0") when
# https://github.com/pytorch/FBGEMM/issues/4198 is fixed
if importlib.util.find_spec("fbgemm_gpu") is None:
raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

import fbgemm_gpu.experimental.gen_ai # noqa: F401

if not is_fbcode() and fbgemm_gpu.__version__ < "1.2.0":
if not _is_fbgemm_genai_gpu_available():
raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

_SUPPORTED_DTYPES = {
Expand All @@ -2070,11 +2066,16 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
and (config.weight_dtype == torch.int4)
and (config.output_dtype == torch.bfloat16)
):
weight = to_fbgemm_int4(
module.weight,
config.block_size,
config.transpose_input,
)
if config.preshuffle:
weight = Int4GroupwisePreshuffleTensor.from_float(
module.weight, config.block_size
)
else:
weight = to_fbgemm_int4(
module.weight,
config.block_size,
config.transpose_input,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
Expand Down
9 changes: 9 additions & 0 deletions torchao/quantization/quantize_/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .int4_groupwise_preshuffle_tensor import (
Int4GroupwisePreshuffleTensor,
)

Int4GroupwisePreshuffleTensor.__module__ = "torchao.quantization"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we confirm (by actually testing it) that we can change the directory location later without breaking BC?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a test that verifies the loaded weight have module path torchao.quantization.Int4GroupwisePreshuffleTensor, this (type(tensor))is used in the load code path: https://github.com/pytorch/pytorch/blob/d4b8857e51a089b7e0e722689398c5c3ada274c9/torch/_tensor.py#L262 which gives us good confidence that it would work as long as we do this

but I can do a e2e test a bit later by uploading the file in huggingface hub and change the path locally to verify as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in #2437


__all__ = [
"Int4GroupwisePreshuffleTensor",
]
Loading
Loading