Using autoround's implementation for gguf q4_k #2042

Open · wants to merge 1 commit into main
3 changes: 2 additions & 1 deletion torchao/prototype/quantization/__init__.py
@@ -1,5 +1,6 @@
-from .gguf import GGUFWeightOnlyConfig
+from .gguf import ARGGUFWeightOnlyConfig, GGUFWeightOnlyConfig
 
 __all__ = [
     "GGUFWeightOnlyConfig",
+    "ARGGUFWeightOnlyConfig",
 ]
3 changes: 2 additions & 1 deletion torchao/prototype/quantization/gguf/__init__.py
@@ -1,9 +1,10 @@
-from .api import GGUFWeightOnlyConfig
+from .api import ARGGUFWeightOnlyConfig, GGUFWeightOnlyConfig
 from .gguf_quantized_tensor import (
     GGUFQuantizedTensor,
 )
 
 __all__ = [
     "GGUFQuantizedTensor",
     "GGUFWeightOnlyConfig",
+    "ARGGUFWeightOnlyConfig",
 ]
37 changes: 36 additions & 1 deletion torchao/prototype/quantization/gguf/api.py
@@ -11,10 +11,11 @@
 from torchao.core.config import AOBaseConfig
 from torchao.quantization.transform_module import register_quantize_module_handler
 
-from .gguf_quantized_tensor import GGUFQuantizedTensor
+from .gguf_quantized_tensor import ARGGUFQuantizedTensor, GGUFQuantizedTensor
 
 __all__ = [
     "GGUFWeightOnlyConfig",
+    "ARGGUFWeightOnlyConfig",
 ]


@@ -50,3 +51,37 @@ def _gguf_weight_only_transform(
)
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
return module


@dataclass
class ARGGUFWeightOnlyConfig(AOBaseConfig):
bits: int = 4
group_size: int = 32
super_group_size: int = 8
super_bits: int = 6
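    # With the defaults, one super-block spans group_size * super_group_size
    # = 32 * 8 = 256 weights; this mirrors llama.cpp's q4_k layout (QK_K = 256,
    # 6-bit quantized per-group scales/mins, ~4.5 bits per weight).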


@register_quantize_module_handler(ARGGUFWeightOnlyConfig)
def _ar_gguf_weight_only_transform(
module: torch.nn.Module,
    config: ARGGUFWeightOnlyConfig,
):
"""
Applies gguf weight-only quantization to linear layers.

Returns:
Callable for quantization transformation.
"""
weight = module.weight
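    # GGUF k-quants operate on 256-element super-blocks, so weights that are
    # not 2-D or whose last dim is not a multiple of 256 are left untouched.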
if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0):
return module

quantized_weight = ARGGUFQuantizedTensor.from_float(
weight,
bits=config.bits,
group_size=config.group_size,
super_group_size=config.super_group_size,
super_bits=config.super_bits,
)
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
return module
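
Usage note: a minimal sketch of how the new config is meant to be applied (this assumes torchao's config-based `quantize_` entry point dispatches to the handler registered above; the toy model and shapes are illustrative):

```python
import torch

from torchao.prototype.quantization import ARGGUFWeightOnlyConfig
from torchao.quantization import quantize_

# Linear(256, 512) gives a (512, 256) weight, so the last dim is a
# multiple of 256 and the transform applies instead of no-op'ing.
model = torch.nn.Sequential(torch.nn.Linear(256, 512))
quantize_(model, ARGGUFWeightOnlyConfig(bits=4, group_size=32))

# The weight is now backed by an ARGGUFQuantizedTensor; linear calls
# dequantize on the fly via the registered aten.linear override.
x = torch.randn(2, 256)
y = model(x)
print(y.shape)  # torch.Size([2, 512])
```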
204 changes: 203 additions & 1 deletion torchao/prototype/quantization/gguf/gguf_quantized_tensor.py
@@ -11,8 +11,11 @@

 from torchao.quantization.quant_primitives import (
     choose_qparams_gguf,
+    choose_qparams_gguf_ar,
     dequantize_gguf,
+    dequantize_gguf_ar,
     quantize_gguf,
+    quantize_gguf_ar,
 )
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
@@ -24,6 +27,7 @@

 __all__ = [
     "GGUFQuantizedTensor",
+    "ARGGUFQuantizedTensor",
 ]


@@ -247,6 +251,204 @@ def _(func, types, args, kwargs):
)


@implements(aten.squeeze.default)
def _(func, types, args, kwargs):
assert len(args) == 1, "Expecting a single argument for squeeze"
t = args[0]
if all(i != 1 for i in t.shape):
return return_and_correct_aliasing(func, args, kwargs, args[0])
raise NotImplementedError(
f"squeeze for tensor of shape {t.shape} not implemented yet"
)


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if not input_tensor.is_floating_point():
raise NotImplementedError(
f"{func} is not implemented for non floating point input"
)

dtype = input_tensor.dtype

if hasattr(weight_tensor, "dequantize"):
weight_tensor = weight_tensor.dequantize(output_dtype=dtype)

return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


class ARGGUFQuantizedTensor(TorchAOBaseTensor):
"""
A Tensor subclass that when applied to a weight used in a linear op/module,
changes that linear op to a weight-only int4 quantized linear op with groupwise
affine quantization on the weight.
"""

@staticmethod
def __new__(
cls,
int_data,
qparams_dict,
shape,
**kwargs,
):
scale = qparams_dict["scale"]
kwargs["device"] = kwargs.get("device", scale.device)
kwargs["dtype"] = kwargs.get("dtype", scale.dtype)
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
int_data,
qparams_dict,
shape,
**kwargs,
):
self.int_data = int_data
self.qparams_dict = qparams_dict

def _apply_fn_to_data(self, fn):
qparams_dict = self.qparams_dict.copy()
for k, v in qparams_dict.items():
if isinstance(v, torch.Tensor):
v = fn(v)
qparams_dict[k] = v

return self.__class__(
fn(self.int_data),
qparams_dict,
self.shape,
dtype=self.dtype,
)

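    # Flattening protocol: int_data is the only inner plain tensor;
    # qparams_dict, dtype and shape travel as static attributes so the
    # subclass can be decomposed and rebuilt (e.g. for serialization).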
def __tensor_flatten__(self):
return [
"int_data",
], (
self.qparams_dict,
self.dtype,
self.shape,
)

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None
):
int_data = tensor_data_dict["int_data"]
qparams_dict, dtype, shape = attributes
return cls(
int_data,
qparams_dict,
shape if outer_size is None else outer_size,
dtype=dtype,
)

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
if output_dtype is None:
output_dtype = self.dtype

return dequantize_gguf_ar(
self.int_data,
self.qparams_dict["scale"],
self.qparams_dict["wmin_m"],
self.qparams_dict["orig_shape"],
self.qparams_dict["pad_len"],
output_dtype=output_dtype,
)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs.pop("device")
qparams_dict = self.qparams_dict.copy()
for k, v in qparams_dict.items():
if isinstance(v, torch.Tensor):
v = v.to(device)
qparams_dict[k] = v

return self.__class__(
self.int_data.to(device),
qparams_dict,
self.shape,
**kwargs,
)

def requires_grad_(self, requires_grad=False):
"""
Modifies the tensor's `requires_grad` status in-place.
"""
assert not requires_grad, "Only requires_grad == False is supported"
return self

@classmethod
def from_float(cls, input_float, bits, group_size, super_group_size, super_bits):
assert bits == 4, "only uint4 quantization is supported right now"

qparams_dict = choose_qparams_gguf_ar(
input_float,
bits=bits,
group_size=group_size,
super_group_size=super_group_size,
super_bits=super_bits,
)

int_data = quantize_gguf_ar(
input_float,
bits,
group_size,
qparams_dict["scale"],
qparams_dict["wmin_m"],
)
return cls(
int_data,
qparams_dict,
input_float.shape,
)


implements = ARGGUFQuantizedTensor.implements


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)


@implements(aten.squeeze.default)
def _(func, types, args, kwargs):
assert len(args) == 1, "Expecting a single argument for squeeze"
t = args[0]
if all(i != 1 for i in t.shape):
return return_and_correct_aliasing(func, args, kwargs, args[0])
raise NotImplementedError(
f"squeeze for tensor of shape {t.shape} not implemented yet"
)


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
@@ -269,4 +471,4 @@ def _(func, types, args, kwargs):

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True`
-    torch.serialization.add_safe_globals([GGUFQuantizedTensor])
+    torch.serialization.add_safe_globals([GGUFQuantizedTensor, ARGGUFQuantizedTensor])
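
Round-trip note: one quick way to sanity-check the new subclass is to quantize and dequantize a random weight and inspect the reconstruction error (a sketch; the shapes and the final print are illustrative, not part of this PR):

```python
import torch

from torchao.prototype.quantization.gguf.gguf_quantized_tensor import (
    ARGGUFQuantizedTensor,
)

w = torch.randn(512, 256)
qw = ARGGUFQuantizedTensor.from_float(
    w, bits=4, group_size=32, super_group_size=8, super_bits=6
)
w_hat = qw.dequantize(output_dtype=torch.float32)

assert w_hat.shape == w.shape
# With 4-bit groupwise quantization the max elementwise error should be
# on the order of the per-group scale, not catastrophic.
print((w - w_hat).abs().max())
```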