@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -25,7 +26,9 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -95,6 +98,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -980,6 +993,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
+        FakeQuantizeConfig(
+            torch.int8, "per_token", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_token", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_token", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1486,6 +1514,95 @@ def test_qat_8da4w_prepare_vs_convert(self):
         )
         self.assertEqual(len(non_inf_sqnr), 0, fail_message)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()