Enable range learning for QAT #2033

Open · wants to merge 1 commit into base: main
117 changes: 117 additions & 0 deletions test/quantization/test_qat.py
@@ -9,6 +9,7 @@

import copy
import unittest
from typing import List

import torch
import torch.nn.functional as F
@@ -25,7 +26,9 @@
from torchao.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
from_intx_quantization_aware_training,
initialize_fake_quantizers,
intx_quantization_aware_training,
)
from torchao.quantization.qat.embedding import (
@@ -95,6 +98,16 @@ def __init__(self):
def example_inputs(self):
return (torch.randn(1, 512).to(torch.float),)

def _get_all_weight_qparams(self) -> List[torch.Tensor]:
return [
self.linear1.weight_fake_quantizer.scale,
self.linear1.weight_fake_quantizer.zero_point,
self.sub.linear.weight_fake_quantizer.scale,
self.sub.linear.weight_fake_quantizer.zero_point,
self.linear2.weight_fake_quantizer.scale,
self.linear2.weight_fake_quantizer.zero_point,
]

def forward(self, x):
x = self.linear1(x)
x = self.sub(x)
@@ -980,6 +993,21 @@ def test_fake_quantize_config_dtype(self):
FakeQuantizeConfig(TorchAODType.INT7, "per_token")
FakeQuantizeConfig(torch.int8, "per_token")

def test_fake_quantize_config_dynamic_and_range_learning(self):
"""
Test that `is_dynamic` and `range_learning` cannot both be set.
"""
FakeQuantizeConfig(
torch.int8, "per_channel", is_dynamic=True, range_learning=False
)
FakeQuantizeConfig(
torch.int8, "per_channel", is_dynamic=False, range_learning=True
)
with self.assertRaisesRegex(ValueError, "not compatible"):
FakeQuantizeConfig(
torch.int8, "per_channel", is_dynamic=True, range_learning=True
)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
@@ -1486,6 +1514,95 @@ def test_qat_8da4w_prepare_vs_convert(self):
)
self.assertEqual(len(non_inf_sqnr), 0, fail_message)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_fake_quantizer_range_learning(self):
"""
Test that range learning requires `FakeQuantizer`s to be initialized correctly.
"""
config = FakeQuantizeConfig(
torch.int8,
"per_channel",
is_dynamic=False,
range_learning=True,
scale_precision=torch.float32,
zero_point_precision=torch.float32,
)
fake_quantizer = FakeQuantizer(config)
example_inputs = (torch.randn(2, 3),)

# Not initialized, should fail
self.assertFalse(fake_quantizer._initialized)
self.assertIsNone(fake_quantizer.scale)
self.assertIsNone(fake_quantizer.zero_point)
with self.assertRaisesRegex(
ValueError,
"Please call `torchao.quantization.qat.initialize_fake_quantizers` "
"before initializing the optimizer and beginning training.",
):
fake_quantizer(*example_inputs)

# Should pass after initializing
initialize_fake_quantizers(fake_quantizer, example_inputs)
self.assertTrue(fake_quantizer._initialized)
self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
self.assertTrue(fake_quantizer.scale.requires_grad)
self.assertTrue(fake_quantizer.zero_point.requires_grad)
fake_quantizer(*example_inputs)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_qat_range_learning(self):
"""
Test end-to-end QAT flow with range learning.
"""
config = FakeQuantizeConfig(
torch.int8,
"per_channel",
is_dynamic=False,
range_learning=True,
scale_precision=torch.float32,
zero_point_precision=torch.float32,
)
m = M()
example_inputs = m.example_inputs()
quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))

# Not initialized, should fail
for t in m._get_all_weight_qparams():
self.assertIsNone(t)
with self.assertRaisesRegex(
ValueError,
"Please call `torchao.quantization.qat.initialize_fake_quantizers` "
"before initializing the optimizer and beginning training.",
):
m(*example_inputs)

# Should pass after initializing
# All scales and zero points should be in `m.parameters()`
initialize_fake_quantizers(m, example_inputs)
params = set(m.parameters())
for t in m._get_all_weight_qparams():
self.assertIsInstance(t, torch.nn.Parameter)
self.assertTrue(t.requires_grad)
self.assertTrue(t in params)
m(*example_inputs)

# Simulate training
optimizer = torch.optim.SGD(
m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
)
loss_fn = torch.nn.CrossEntropyLoss()
target = torch.randn(1, 512).float()
out = m(*example_inputs)
loss = loss_fn(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()


if __name__ == "__main__":
unittest.main()
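
For reference, here is a minimal self-contained sketch of the flow that `test_qat_range_learning` above exercises. The toy model, learning rate, and loss below are illustrative and not part of this PR:

```python
import torch
import torch.nn.functional as F

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

# Range learning requires static (non-dynamic) scales and zero points.
config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)

# Any model with nn.Linear layers works; this one is just for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 512),
)
example_inputs = (torch.randn(1, 512),)
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=config))

# Initialize scales/zero points (and mark them trainable) *before* building
# the optimizer, so they show up in model.parameters().
initialize_fake_quantizers(model, example_inputs)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

for _ in range(3):
    out = model(*example_inputs)
    loss = F.mse_loss(out, torch.zeros_like(out))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```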
8 changes: 5 additions & 3 deletions torchao/quantization/qat/__init__.py
@@ -4,6 +4,7 @@
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
from_intx_quantization_aware_training,
initialize_fake_quantizers,
intx_quantization_aware_training,
)
from .embedding import (
@@ -17,11 +18,12 @@
__all__ = [
"ComposableQATQuantizer",
"FakeQuantizeConfig",
"Int4WeightOnlyQATQuantizer",
"FromIntXQuantizationAwareTrainingConfig",
"Int4WeightOnlyEmbeddingQATQuantizer",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"IntXQuantizationAwareTrainingConfig",
"initialize_fake_quantizers",
"intx_quantization_aware_training",
"from_intx_quantization_aware_training",
"FromIntXQuantizationAwareTrainingConfig",
"IntXQuantizationAwareTrainingConfig",
]
29 changes: 27 additions & 2 deletions torchao/quantization/qat/api.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Tuple, Union

import torch

@@ -51,7 +51,8 @@ class FakeQuantizeConfig:
zero_point_precision: zero point dtype (default torch.int32)
zero_point_domain: whether zero point is in integer (default) or float domain
is_dynamic: whether to use dynamic (default) or static scale and zero points
range_learning: whether to learn scale and zero points during training (coming soon)
range_learning: whether to learn scale and zero points during training
(default False); not compatible with `is_dynamic`

kwargs (optional):
group_size: size of each group in per group fake quantization,
@@ -120,6 +121,10 @@ def __init__(
"Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)
)

# Dynamic is not compatible with range learning
if is_dynamic and range_learning:
raise ValueError("`is_dynamic` is not compatible with `range_learning`")

def _get_granularity(
self,
granularity: Union[Granularity, str, None],
@@ -391,3 +396,23 @@ def convert(
for quantizer in self.quantizers:
model = quantizer.convert(model)
return model


def initialize_fake_quantizers(
model: torch.nn.Module,
example_inputs: Tuple[Any, ...],
) -> None:
"""
Initialize the scales and zero points on all
:class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
in the model based on the provided example inputs.
"""
# avoid circular dependencies
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

def _set_initialized(m: torch.nn.Module):
if isinstance(m, FakeQuantizer):
m._initialized = True

model.apply(_set_initialized)
model(*example_inputs)
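
A minimal usage sketch of the new helper on a single `FakeQuantizer`, mirroring `test_fake_quantizer_range_learning` above (the input shape is illustrative):

```python
import torch

from torchao.quantization.qat import FakeQuantizeConfig, initialize_fake_quantizers
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
fake_quantizer = FakeQuantizer(config)
example_inputs = (torch.randn(2, 3),)

# Calling fake_quantizer(*example_inputs) here would raise a ValueError,
# since the scales and zero points have not been initialized yet.
initialize_fake_quantizers(fake_quantizer, example_inputs)

# The forward pass triggered above computed the qparams and promoted them
# to trainable nn.Parameters.
assert isinstance(fake_quantizer.scale, torch.nn.Parameter)
assert fake_quantizer.zero_point.requires_grad
```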
47 changes: 41 additions & 6 deletions torchao/quantization/qat/fake_quantizer.py
@@ -29,6 +29,7 @@
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_Round,
)


@@ -44,26 +45,38 @@ def __init__(self, config: FakeQuantizeConfig):
self.scale: Optional[torch.Tensor] = None
self.zero_point: Optional[torch.Tensor] = None

# TODO: support range learinng
if self.config.range_learning:
raise NotImplementedError("Range learning is not supported yet")
# For range learning only
# TODO: make this configurable?
self._scale_eps = 1e-9
self._initialized = False

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply fake quantization to the tensor based on the bit-width,
granularity, symmetry, and other properties specified in the config.
"""
if not self.enabled:
return x

if (
self.config.range_learning
and not self._initialized
and (self.scale is None or self.zero_point is None)
):
raise ValueError(
"Scales and zero points must be initialized for range learning. "
"Please call `torchao.quantization.qat.initialize_fake_quantizers` "
"before initializing the optimizer and beginning training."
)

if isinstance(self.config.granularity, PerToken):
return self._per_token_forward(x)
elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
return self._per_channel_or_group_forward(x)
else:
raise ValueError("Unknown granularity '%s'" % self.config.granularity)

def _per_token_forward(self, x: torch.Tensor):
def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform per token fake quantization on the tensor.
"""
@@ -75,10 +88,11 @@ def _per_token_forward(self, x: torch.Tensor):
self.config.scale_precision,
self.config.zero_point_precision,
)
self._maybe_update_qparams_for_range_learning()
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)

def _per_channel_or_group_forward(self, x: torch.Tensor):
def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform per channel or per group fake quantization on the tensor.
We express per channel using per group where the group size is the size
@@ -117,6 +131,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
scale_precision,
)
self.zero_point = self.zero_point.to(zero_point_precision)
self._maybe_update_qparams_for_range_learning()

qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
return _fake_quantize_per_channel_group(
@@ -135,6 +150,26 @@ def _should_compute_qparams(self) -> bool:
"""
return self.config.is_dynamic or self.scale is None or self.zero_point is None

def _maybe_update_qparams_for_range_learning(self) -> None:
"""
If range learning is enabled, turn scales and zero points into trainable parameters.
This function is idempotent; any call after the first is a no-op.
"""
if (
not self.config.range_learning
or isinstance(self.scale, torch.nn.Parameter)
or isinstance(self.zero_point, torch.nn.Parameter)
):
return
scale, zero_point = self.scale, self.zero_point
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
# Stabilize range learning
scale = torch.clamp(scale, min=self._scale_eps)
zero_point = _Round.apply(zero_point)
zero_point = torch.clamp(zero_point, qmin, qmax)
self.scale = torch.nn.Parameter(scale, requires_grad=True)
self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)

def __repr__(self) -> str:
"""
Return a human readable representation of this `FakeQuantizer` with config details.
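
Outside of the diff, a standalone sketch of what `_maybe_update_qparams_for_range_learning` does to the computed qparams. Plain tensors and `torch.round` stand in for the module state and the `_Round` STE op; the bounds and eps values are illustrative:

```python
import torch

# Illustrative stand-ins; in the module these come from
# _DTYPE_TO_QVALUE_BOUNDS[config.dtype] and self._scale_eps.
qmin, qmax = -128, 127
scale_eps = 1e-9

scale = torch.tensor([0.0, 0.02, 0.5])         # freshly computed qparams
zero_point = torch.tensor([0.4, -130.0, 3.0])

# Clamp the scale away from zero and snap the zero point to an integer value
# inside [qmin, qmax] (the module uses the _Round STE so gradients still flow).
scale = torch.clamp(scale, min=scale_eps)
zero_point = torch.clamp(torch.round(zero_point), qmin, qmax)

# Promote both to trainable parameters so the optimizer can update them;
# the module only does this once, since Parameters are skipped on later calls.
scale = torch.nn.Parameter(scale, requires_grad=True)
zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
```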
14 changes: 8 additions & 6 deletions torchao/quantization/qat/linear.py
@@ -18,6 +18,7 @@
_replace_linear_int4,
groupwise_affine_quantize_tensor,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import (
TorchAODType,
ZeroPointDomain,
@@ -83,12 +84,13 @@ def __init__(

# initialize weight fake quantizer
if weight_config is not None:
group_size = weight_config.group_size
if group_size is not None and in_features % group_size != 0:
raise ValueError(
"in_features (%s) %% group_size (%s) must be == 0"
% (in_features, group_size)
)
if isinstance(weight_config.granularity, PerGroup):
group_size = weight_config.group_size
if group_size is not None and in_features % group_size != 0:
raise ValueError(
"in_features (%s) %% group_size (%s) must be == 0"
% (in_features, group_size)
)
self.weight_fake_quantizer = FakeQuantizer(weight_config)
else:
self.weight_fake_quantizer = None
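
A small sketch of the case this change guards against: per-channel weight configs no longer read `group_size` (which is only meaningful for group-wise granularity), and only group-wise configs are checked for divisibility. The shapes and group size below are illustrative:

```python
import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import FakeQuantizeConfig

in_features = 520  # not a multiple of 32

per_channel = FakeQuantizeConfig(torch.int8, "per_channel")
per_group = FakeQuantizeConfig(torch.int8, group_size=32)

for cfg in (per_channel, per_group):
    # Per-channel configs skip the check entirely; group-wise configs must
    # have in_features divisible by the group size.
    if isinstance(cfg.granularity, PerGroup) and in_features % cfg.group_size != 0:
        print(f"rejected: {in_features} % {cfg.group_size} != 0")
```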
14 changes: 14 additions & 0 deletions torchao/quantization/qat/utils.py
@@ -91,6 +91,20 @@ def backward(ctx, gy):
return (gy,)


class _Round(torch.autograd.Function):
"""
Implementation of generic round operation with backward STE.
"""

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return torch.round(x)

@staticmethod
def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
return gy


def _fake_quantize_per_channel_group(
input: torch.Tensor,
scales: torch.Tensor,
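
Outside of the diff, a quick sketch of the straight-through estimator (STE) behavior of the new `_Round` op: the forward pass rounds, while the backward pass passes gradients through unchanged (a plain `torch.round` would give zero gradients almost everywhere):

```python
import torch

from torchao.quantization.qat.utils import _Round

x = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)

y = _Round.apply(x)   # forward: torch.round -> tensor([ 0.,  2., -2.])
y.sum().backward()    # backward: straight-through, gradient of 1 everywhere

print(y)       # tensor([ 0.,  2., -2.], grad_fn=...)
print(x.grad)  # tensor([1., 1., 1.])
```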