-
Notifications
You must be signed in to change notification settings - Fork 294
PARQ quantizer support for torchao's weight-only configs #2091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
e3532e4
Add parq.quant.UnifTorchaoQuantizer for quantize_ API equivalence
lisjin bc0e52a
Test IntxWeightOnlyConfig
lisjin df8867f
Formatting fix
lisjin 1d3d7d9
Per-row IntxWeightOnlyConfig test
lisjin b5d83bb
Add end-to-end QAT prepare/convert test case
lisjin 5ac8a9b
Merge remote-tracking branch 'pytorch/main' into parq
lisjin d7710cf
Pass explicit layout to int4_weight_only
lisjin 6130cc2
Add QuantOptimizer.torchao_quantize_
lisjin 667dfe7
Merge remote-tracking branch 'pytorch/main' into parq
lisjin 9e70d6d
Update README, add Int4UnifTorchaoQuantizer
lisjin 490bdf6
Merge remote-tracking branch 'pytorch/main' into parq
lisjin 3809add
Add test_intx_weight_only_e2e, set UnifTorchaoQuantizer defaults
lisjin 073e1fa
Update PARQ README
lisjin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,46 +3,87 @@ | |
# | ||
# This source code is licensed under the BSD 3-Clause license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import copy | ||
import unittest | ||
from typing import Optional | ||
|
||
import torch | ||
from torch import nn | ||
from torch.testing._internal import common_utils | ||
|
||
from torchao.core.config import AOBaseConfig | ||
from torchao.dtypes import Int4CPULayout | ||
from torchao.prototype.parq.optim import ( | ||
ProxHardQuant, | ||
ProxPARQ, | ||
QuantOptimizer, | ||
) | ||
from torchao.prototype.parq.quant import LSBQuantizer, UnifQuantizer | ||
from torchao.prototype.parq.quant import ( | ||
Int4UnifTorchaoQuantizer, | ||
LSBQuantizer, | ||
TernaryUnifQuantizer, | ||
UnifQuantizer, | ||
UnifTorchaoQuantizer, | ||
) | ||
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE | ||
from torchao.quantization.granularity import PerGroup | ||
from torchao.quantization.quant_api import ( | ||
IntxWeightOnlyConfig, | ||
_is_linear, | ||
int4_weight_only, | ||
quantize_, | ||
) | ||
from torchao.utils import ( | ||
TORCH_VERSION_AT_LEAST_2_4, | ||
TORCH_VERSION_AT_LEAST_2_6, | ||
check_cpu_version, | ||
) | ||
|
||
_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
def split_param_groups(model): | ||
params_no_quant, params_quant = [], [] | ||
for p in model.parameters(): | ||
if p.dim() > 1: | ||
params_quant.append(p) | ||
else: | ||
params_no_quant.append(p) | ||
return params_no_quant, params_quant | ||
params_quant, params_no_quant = [], [] | ||
|
||
def get_param_groups(model): | ||
for module in model.children(): | ||
is_linear = _is_linear(module) | ||
for n, p in module.named_parameters(): | ||
if is_linear and n == "weight": | ||
params_quant.append(p) | ||
else: | ||
params_no_quant.append(p) | ||
|
||
get_param_groups(model) | ||
return params_quant, params_no_quant | ||
|
||
|
||
class M(torch.nn.Module): | ||
def __init__(self): | ||
def build_param_groups(model, b: int = 2, group_size: Optional[int] = None): | ||
params_quant, params_no_quant = split_param_groups(model) | ||
quant_kwargs = {"quant_block_size": group_size} if group_size else {} | ||
return [ | ||
{"params": params_quant, "quant_bits": b, **quant_kwargs}, | ||
{"params": params_no_quant}, | ||
] | ||
|
||
|
||
class M(nn.Module): | ||
def __init__(self, m=256, n=128, k=16, bias=False): | ||
super().__init__() | ||
self.embedding = torch.nn.Embedding(10, 256) | ||
self.linear1 = torch.nn.Linear(256, 128) | ||
self.linear2 = torch.nn.Linear(128, 16) | ||
self.relu = torch.nn.ReLU() | ||
self.sigmoid = torch.nn.Sigmoid() | ||
self.embedding = nn.Embedding(10, m) | ||
self.linear1 = nn.Linear(m, n, bias=bias) | ||
self.linear2 = nn.Linear(n, k, bias=bias) | ||
self.relu = nn.ReLU() | ||
self.sigmoid = nn.Sigmoid() | ||
|
||
def reset_parameters(self): | ||
for module in (self.linear1, self.linear2): | ||
torch.nn.init.xavier_uniform_(module.weight) | ||
torch.nn.init.zeros_(module.bias) | ||
nn.init.xavier_uniform_(module.weight) | ||
if module.bias is not None: | ||
nn.init.zeros_(module.bias) | ||
|
||
def example_inputs(self): | ||
return torch.randint(1, 10, (1, 256)) | ||
def example_inputs(self, device=None): | ||
return torch.randint(1, 10, (1, 256), device=device) | ||
|
||
def forward(self, x): | ||
x = self.embedding(x) | ||
|
@@ -53,52 +94,179 @@ def forward(self, x): | |
return x | ||
|
||
|
||
class TestPARQuantization(unittest.TestCase): | ||
class TestPARQuantization(common_utils.TestCase): | ||
def setUp(self): | ||
torch.manual_seed(123) | ||
self.model = M().to(_DEVICE) | ||
self.params_no_quant, self.params_quant = split_param_groups(self.model) | ||
self.model = M(bias=True).to(_DEVICE) | ||
|
||
def test_2bit_unif_quantizer_hard_prox(self): | ||
@common_utils.parametrize("b", [0, 1, 2, 4]) | ||
@common_utils.parametrize("unif_quant", [True, False]) | ||
@common_utils.parametrize("hard_prox", [True, False]) | ||
def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True): | ||
self.model.reset_parameters() | ||
param_groups = [ | ||
{"params": self.params_no_quant}, | ||
{"params": self.params_quant, "quant_bits": 2}, | ||
] | ||
param_groups = build_param_groups(self.model, b) | ||
base_optimizer = torch.optim.AdamW(param_groups) | ||
quantizer = UnifQuantizer() | ||
prox_map = ProxHardQuant() | ||
optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map) | ||
|
||
x = self.model.example_inputs().to(_DEVICE) | ||
out = self.model(x) | ||
out.sum().backward() | ||
optimizer.step() | ||
|
||
for child in self.model.children(): | ||
if isinstance(child, torch.nn.Linear): | ||
self.assertEqual(child.weight.unique().numel(), 4) | ||
|
||
def test_ternarybit_lsbq_parq_prox(self): | ||
self.model.reset_parameters() | ||
param_groups = [ | ||
{"params": self.params_no_quant}, | ||
{"params": self.params_quant, "quant_bits": 0}, | ||
] | ||
base_optimizer = torch.optim.AdamW(param_groups) | ||
quantizer = LSBQuantizer() | ||
prox_map = ProxPARQ(anneal_start=0, anneal_end=2) | ||
if unif_quant: | ||
quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer() | ||
else: | ||
quantizer = LSBQuantizer() | ||
prox_map = ( | ||
ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2) | ||
) | ||
optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map) | ||
|
||
for _ in range(3): | ||
x = self.model.example_inputs().to(_DEVICE) | ||
x = self.model.example_inputs(device=_DEVICE) | ||
out = self.model(x) | ||
out.sum().backward() | ||
optimizer.step() | ||
|
||
for child in self.model.children(): | ||
if isinstance(child, torch.nn.Linear): | ||
self.assertEqual(child.weight.unique().numel(), 3) | ||
if isinstance(child, nn.Linear): | ||
self.assertEqual( | ||
child.weight.unique().numel(), quantizer.get_quant_size(b) | ||
) | ||
|
||
|
||
class TestUnifTorchaoQuantizer(common_utils.TestCase): | ||
def setUp(self): | ||
torch.manual_seed(123) | ||
|
||
def compare_quantized_models( | ||
self, | ||
model: nn.Module, | ||
m_ref: nn.Module, | ||
quantizer: UnifTorchaoQuantizer, | ||
b: int, | ||
group_size: int, | ||
): | ||
for n, module in model.named_children(): | ||
if not _is_linear(module): | ||
continue | ||
|
||
# simulate grouping from QuantOptimizer.step | ||
p = module.weight | ||
original_shape = p.shape | ||
p = p.view(-1, group_size) | ||
|
||
q, Q = quantizer.quantize(p, b=b, dim=-1) | ||
q = q.view(original_shape) | ||
|
||
# compare to AffineQuantizedTensor instance | ||
ref = getattr(m_ref, n).weight.dequantize() | ||
self.assertTrue(q.equal(ref)) | ||
|
||
def compare_parq_convert( | ||
self, | ||
model: nn.Module, | ||
m_ref: nn.Module, | ||
optimizer: QuantOptimizer, | ||
config: AOBaseConfig, | ||
): | ||
# do not update model weights, just quantize | ||
optimizer.zero_grad() | ||
optimizer.step() | ||
|
||
orig_model = copy.deepcopy(model) # save copy of PARQ quantized model | ||
|
||
# equivalent to torchao's convert step | ||
model.eval() | ||
optimizer.restore_latent_params() | ||
quantize_(model, config, filter_fn=optimizer.get_filter_fn(model)) | ||
|
||
for n, module in model.named_modules(): | ||
if not _is_linear(module): | ||
continue | ||
|
||
p_orig = getattr(orig_model, n).weight # PARQ weight | ||
p = module.weight.dequantize() # PARQ weight after quantize_ | ||
p_ref = getattr(m_ref, n).weight.dequantize() # native quantize_ | ||
|
||
self.assertTrue(p_orig.equal(p_ref)) | ||
self.assertTrue(p.equal(p_ref)) | ||
|
||
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") | ||
@common_utils.parametrize("group_size", [32, 256]) | ||
def test_int4_weight_only(self, group_size: int = 32): | ||
model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE) | ||
model.reset_parameters() | ||
|
||
m_ref = copy.deepcopy(model).eval().to(_DEVICE) | ||
config = int4_weight_only(group_size=group_size) | ||
if check_cpu_version(_DEVICE): | ||
config.layout = Int4CPULayout() | ||
quantize_(m_ref, config) | ||
|
||
b = 4 | ||
self.compare_quantized_models( | ||
model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size | ||
) | ||
|
||
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") | ||
@common_utils.parametrize("b", [2, 3, 4, 8]) | ||
@common_utils.parametrize("group_size", [32, 512]) | ||
def test_intx_weight_only(self, b: int = 2, group_size: int = 32): | ||
model = M(m=512, n=512).to(_DEVICE) | ||
model.reset_parameters() | ||
|
||
m_ref = copy.deepcopy(model).eval().to(_DEVICE) | ||
quantize_( | ||
m_ref, | ||
IntxWeightOnlyConfig( | ||
weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(group_size) | ||
), | ||
) | ||
|
||
quantizer = UnifTorchaoQuantizer() | ||
self.compare_quantized_models(model, m_ref, quantizer, b, group_size) | ||
|
||
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") | ||
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available") | ||
lisjin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def test_int4_weight_only_e2e(self, group_size: int = 32): | ||
model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE) | ||
model.reset_parameters() | ||
|
||
m_ref = copy.deepcopy(model).eval().to(_DEVICE) | ||
config = int4_weight_only(group_size=group_size) | ||
if check_cpu_version(_DEVICE): | ||
config.layout = Int4CPULayout() | ||
quantize_(m_ref, config) | ||
|
||
b = 4 | ||
base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size)) | ||
optimizer = QuantOptimizer( | ||
base_optimizer, | ||
Int4UnifTorchaoQuantizer(), | ||
ProxHardQuant(), | ||
quant_per_channel=True, | ||
) | ||
self.compare_parq_convert(model, m_ref, optimizer, config) | ||
|
||
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") | ||
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available") | ||
@common_utils.parametrize("b", [2, 3, 4, 8]) | ||
def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): | ||
Comment on lines
+247
to
+248
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @metascroy Thanks for looking it over! I've added this end-to-end test, along with |
||
model = M(m=512, n=512).to(_DEVICE) | ||
model.reset_parameters() | ||
|
||
m_ref = copy.deepcopy(model).eval().to(_DEVICE) | ||
config = IntxWeightOnlyConfig( | ||
weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(group_size) | ||
) | ||
quantize_(m_ref, config) | ||
|
||
base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size)) | ||
optimizer = QuantOptimizer( | ||
base_optimizer, | ||
UnifTorchaoQuantizer(), | ||
ProxHardQuant(), | ||
quant_per_channel=True, | ||
) | ||
self.compare_parq_convert(model, m_ref, optimizer, config) | ||
|
||
|
||
common_utils.instantiate_parametrized_tests(TestPARQuantization) | ||
common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.