
[not for landing/review] add fake quant ops for embedding/linear #2110

Open · wants to merge 1 commit into main
277 changes: 276 additions & 1 deletion torchao/quantization/utils.py
@@ -4,14 +4,21 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib.util
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from torchao.kernel import (
    int_scaled_matmul,
)
from torchao.quantization.granularity import (
    Granularity,
    PerAxis,
    PerGroup,
    PerTensor,
    PerToken,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
@@ -141,6 +148,24 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
    return block_size


def _get_block_size_from_granularity(
    x: torch.Tensor, granularity: Granularity
) -> Tuple[int, ...]:
    if isinstance(granularity, PerTensor):
        block_size = x.shape
    elif isinstance(granularity, PerToken):
        block_size = _get_per_token_block_size(x)
    elif isinstance(granularity, PerGroup):
        # one scale block per `group_size` elements along the last dim,
        # matching torchao's existing get_block_size convention
        block_size = [1] * (x.dim() - 1) + [granularity.group_size]
    elif isinstance(granularity, PerAxis):
        block_size = list(x.shape)
        block_size[granularity.axis] = 1
    else:
        raise ValueError(f"Unsupported granularity: {granularity}")
    return tuple(block_size)

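For reference, here is a minimal sketch (not part of the diff) of what this helper returns for each granularity, using a hypothetical 2x8 weight and the per-group convention noted in the comment above:

# Illustrative only; the 2x8 shape is a hypothetical example.
w = torch.randn(2, 8)
assert _get_block_size_from_granularity(w, PerTensor()) == (2, 8)  # one block for the whole tensor
assert _get_block_size_from_granularity(w, PerToken()) == (1, 8)   # one block per token/row
assert _get_block_size_from_granularity(w, PerGroup(4)) == (1, 4)  # groups of 4 along the last dim
assert _get_block_size_from_granularity(w, PerAxis(0)) == (1, 8)   # one block per slice of axis 0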

# taken from
# https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
# and slightly modified
@@ -631,6 +656,256 @@ def per_token_dynamic_quant(
    return dq


from torchao.utils import _register_custom_op

quant_lib = torch.library.Library("quant", "FRAGMENT")
register_custom_op = _register_custom_op(quant_lib)

_GRANULARITY_STR_AND_KWARG_KEYS = [
    (PerTensor, "PER_TENSOR", []),
    (PerAxis, "PER_AXIS", ["axis"]),
    (PerGroup, "PER_GROUP", ["group_size"]),
    (PerToken, "PER_TOKEN", []),
]


def _get_str_and_kwargs_for_granularity(granularity: Granularity):
    for cls, cls_str, cls_kwarg_keys in _GRANULARITY_STR_AND_KWARG_KEYS:
        if isinstance(granularity, cls):
            return cls_str, {k: getattr(granularity, k) for k in cls_kwarg_keys}
    raise ValueError(f"Unsupported granularity: {granularity}")


def _get_granularity_from_str_and_kwargs(granularity_str, kwargs):
    for cls, cls_str, cls_kwarg_keys in _GRANULARITY_STR_AND_KWARG_KEYS:
        if cls_str == granularity_str:
            return cls(**{k: kwargs[k] for k in cls_kwarg_keys})
    raise ValueError(f"Unsupported granularity: {granularity_str}")

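These helpers exist because torch.library op schemas only accept primitive types (tensors, ints, strings, dtypes, and so on), so the public wrappers below encode a Granularity as a string plus primitive kwargs and the op implementations decode it. A minimal round-trip sketch (illustrative, not in the diff):

g_str, g_kwargs = _get_str_and_kwargs_for_granularity(PerGroup(32))
# g_str == "PER_GROUP", g_kwargs == {"group_size": 32}
assert _get_granularity_from_str_and_kwargs(g_str, g_kwargs) == PerGroup(32)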

def qdq_dynamic_affine_quantization(
    input: torch.Tensor,
    mapping_type: MappingType,
    granularity: Granularity,
    quant_min: int,
    quant_max: int,
    zero_point_domain: ZeroPointDomain,
    *,
    scale_dtype: Optional[torch.dtype] = None,
    zero_point_dtype: Optional[torch.dtype] = None,
    eps: Optional[float] = None,
) -> torch.Tensor:
    granularity_str, kwargs = _get_str_and_kwargs_for_granularity(granularity)
    return _qdq_dynamic_affine_quantization(
        input,
        mapping_type.name,
        granularity_str,
        quant_min,
        quant_max,
        zero_point_domain.name,
        scale_dtype=scale_dtype,
        zero_point_dtype=zero_point_dtype,
        eps=eps,
        **kwargs,
    )


@register_custom_op
def _qdq_dynamic_affine_quantization(
    input: torch.Tensor,
    mapping_type: str,
    granularity: str,
    quant_min: int,
    quant_max: int,
    zero_point_domain: str,
    *,
    scale_dtype: Optional[torch.dtype] = None,
    zero_point_dtype: Optional[torch.dtype] = None,
    eps: Optional[float] = None,
    # granularity kwargs
    axis: Optional[int] = None,
    group_size: Optional[int] = None,
) -> torch.Tensor:
    granularity = _get_granularity_from_str_and_kwargs(
        granularity, {"axis": axis, "group_size": group_size}
    )
    block_size = _get_block_size_from_granularity(input, granularity)
    quant_dtype = torch.int8
    assert (
        quant_max <= torch.iinfo(quant_dtype).max
    ), f"quant_max {quant_max} is too large for {quant_dtype}"
    assert (
        quant_min >= torch.iinfo(quant_dtype).min
    ), f"quant_min {quant_min} is too small for {quant_dtype}"
    output_dtype = input.dtype

    mapping_type = getattr(MappingType, mapping_type.upper())
    zero_point_domain = getattr(ZeroPointDomain, zero_point_domain.upper())
    preserve_zero = (mapping_type == MappingType.SYMMETRIC) or (
        zero_point_domain == ZeroPointDomain.NONE
    )
    scale, zero_point = choose_qparams_affine(
        input,
        mapping_type,
        block_size,
        quant_dtype,
        quant_min,
        quant_max,
        eps,
        scale_dtype,
        zero_point_dtype,
        preserve_zero,
        zero_point_domain,
    )
    q = quantize_affine(
        input,
        block_size,
        scale,
        zero_point,
        quant_dtype,
        quant_min,
        quant_max,
        zero_point_domain,
    )
    dq = dequantize_affine(
        q,
        block_size,
        scale,
        zero_point,
        quant_dtype,
        quant_min,
        quant_max,
        zero_point_domain,
        output_dtype=output_dtype,
    )
    return dq

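A hedged usage sketch of the wrapper above, fake-quantizing an activation per token with a symmetric int8 range; the tensor shape and the [-127, 127] range are illustrative assumptions, not values from the PR:

x = torch.randn(4, 16)
x_fq = qdq_dynamic_affine_quantization(
    x,
    MappingType.SYMMETRIC,
    PerToken(),
    quant_min=-127,
    quant_max=127,
    zero_point_domain=ZeroPointDomain.NONE,
)
assert x_fq.shape == x.shape and x_fq.dtype == x.dtype  # qdq preserves shape and dtype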

def qdq_weight_only_affine_quantized_embedding(
    indices: torch.Tensor,
    w_int_data: torch.Tensor,
    w_quant_min: int,
    w_quant_max: int,
    w_scale: torch.Tensor,
    w_zero_point: Optional[torch.Tensor],
    granularity: Granularity,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    granularity_str, kwargs = _get_str_and_kwargs_for_granularity(granularity)
    return _qdq_weight_only_affine_quantized_embedding(
        indices,
        w_int_data,
        w_quant_min,
        w_quant_max,
        w_scale,
        w_zero_point,
        granularity_str,
        output_dtype,
        **kwargs,
    )


@register_custom_op
def _qdq_weight_only_affine_quantized_embedding(
    indices: torch.Tensor,
    w_int_data: torch.Tensor,
    w_quant_min: int,
    w_quant_max: int,
    w_scale: torch.Tensor,
    w_zero_point: Optional[torch.Tensor],
    granularity: str,
    output_dtype: torch.dtype,
    *,
    # granularity kwargs
    axis: Optional[int] = None,
    group_size: Optional[int] = None,
) -> torch.Tensor:
    granularity = _get_granularity_from_str_and_kwargs(
        granularity, {"axis": axis, "group_size": group_size}
    )
    block_size = _get_block_size_from_granularity(w_int_data, granularity)
    input_dtype = w_int_data.dtype
    w_zero_point_domain = ZeroPointDomain.NONE
    if w_zero_point is not None:
        assert w_zero_point.dtype in [torch.int8, torch.int32, torch.int64]
        w_zero_point_domain = ZeroPointDomain.INT
    w_dq = dequantize_affine(
        w_int_data,
        block_size,
        w_scale,
        w_zero_point,
        input_dtype,
        w_quant_min,
        w_quant_max,
        w_zero_point_domain,
        output_dtype=output_dtype,
    )
    return torch.nn.functional.embedding(indices, w_dq)

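A hedged sketch of calling the embedding op with a per-row (PerAxis(0)) int8 weight and no zero point; all shapes and ranges here are hypothetical:

vocab_size, embed_dim = 10, 8
w_int8 = torch.randint(-128, 128, (vocab_size, embed_dim), dtype=torch.int8)
w_scale = torch.rand(vocab_size, 1)  # one scale per row, for block_size (1, embed_dim)
indices = torch.tensor([[1, 3], [4, 7]])
out = qdq_weight_only_affine_quantized_embedding(
    indices, w_int8, -128, 127, w_scale, None, PerAxis(0)
)
assert out.shape == (2, 2, embed_dim) and out.dtype == torch.float32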

def qdq_weight_only_affine_quantized_linear(
    x: torch.Tensor,
    w_int_data: torch.Tensor,
    w_quant_min: int,
    w_quant_max: int,
    w_scale: torch.Tensor,
    w_zero_point: Optional[torch.Tensor],
    granularity: Granularity,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    granularity_str, kwargs = _get_str_and_kwargs_for_granularity(granularity)
    return _qdq_weight_only_affine_quantized_linear(
        x,
        w_int_data,
        w_quant_min,
        w_quant_max,
        w_scale,
        w_zero_point,
        granularity_str,
        bias,
        **kwargs,
    )


@register_custom_op
def _qdq_weight_only_affine_quantized_linear(
    x: torch.Tensor,
    w_int_data: torch.Tensor,
    w_quant_min: int,
    w_quant_max: int,
    w_scale: torch.Tensor,
    w_zero_point: Optional[torch.Tensor],
    granularity: str,
    bias: Optional[torch.Tensor],
    *,
    # granularity kwargs
    axis: Optional[int] = None,
    group_size: Optional[int] = None,
) -> torch.Tensor:
    granularity = _get_granularity_from_str_and_kwargs(
        granularity, {"axis": axis, "group_size": group_size}
    )
    block_size = _get_block_size_from_granularity(w_int_data, granularity)
    output_dtype = x.dtype
    input_dtype = w_int_data.dtype
    w_zero_point_domain = ZeroPointDomain.NONE
    if w_zero_point is not None:
        assert w_zero_point.dtype in [torch.int8, torch.int32, torch.int64]
        w_zero_point_domain = ZeroPointDomain.INT
    w_dq = dequantize_affine(
        w_int_data,
        block_size,
        w_scale,
        w_zero_point,
        input_dtype,
        w_quant_min,
        w_quant_max,
        w_zero_point_domain,
        output_dtype=output_dtype,
    )
    return torch.nn.functional.linear(x, w_dq, bias)

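And a similar hedged sketch for the linear op with per-group weights (group size 4, a 4-bit-style [-8, 7] range stored in int8); shapes are again illustrative assumptions:

out_features, in_features, group_size = 6, 8, 4
w_int8 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
w_scale = torch.rand(out_features, in_features // group_size)  # one scale per (row, group) block
x = torch.randn(2, in_features)
y = qdq_weight_only_affine_quantized_linear(
    x, w_int8, -8, 7, w_scale, None, PerGroup(group_size)
)
assert y.shape == (2, out_features) and y.dtype == x.dtype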

def recommended_inductor_config_setter():
    """
    Set inductor config to use the following optimizations, which have been shown to improve performance for quantized models: