@@ -57,11 +57,11 @@
 from torchao.quantization.pt2e.quantizer.embedding_quantizer import (  # noqa: F811
     EmbeddingQuantizer,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
+from torchao.testing.pt2e._xnnpack_quantizer_utils import (
     OP_TO_ANNOTATOR,
     QuantizationConfig,
 )
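The hunk above only relocates the XNNPACK quantizer test helpers from torchao.quantization.pt2e.quantizer to torchao.testing.pt2e. As a hedged sketch (not part of this PR), downstream test code that has to run against both layouts could guard the import and fall back to the old path when the new one is unavailable:

# Sketch only, not part of the diff: tolerate both the new and the old module layout.
try:
    from torchao.testing.pt2e._xnnpack_quantizer import (  # location introduced by this PR
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
except ImportError:
    from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (  # pre-PR location
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )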
@@ -1328,6 +1328,40 @@ def validate(self, model: torch.fx.GraphModule) -> None:
         with self.assertRaises(Exception):
             m = prepare_pt2e(m, BackendAQuantizer())
 
+    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
+        # resetting dynamo cache
+        torch._dynamo.reset()
+
+        m = export_for_training(
+            m,
+            example_inputs,
+        ).module()
+        if is_qat:
+            m = prepare_qat_pt2e(m, quantizer)
+        else:
+            m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        return m
+
+    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=is_per_channel
+        )
+        quantizer.set_global(operator_config)
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        return self._quantize(m, quantizer, example_inputs)
+
     def test_fold_quantize(self):
         """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)"""
         m = self._get_pt2e_quantized_linear()
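As a hedged illustration (not part of this diff), a test on the same class could reuse the two new helpers for both the PTQ and QAT flows. The method name below is hypothetical, the is_qat keyword of get_symmetric_quantization_config is assumed, and the sketch relies on the file's existing imports (torch, XNNPACKQuantizer, get_symmetric_quantization_config):

# Hypothetical usage sketch; not part of the PR.
def test_linear_ptq_and_qat_example(self):  # hypothetical test name
    # PTQ path: export -> prepare_pt2e -> calibrate -> convert_pt2e,
    # with per-channel weight quantization.
    m_ptq = self._get_pt2e_quantized_linear(is_per_channel=True)

    # QAT path: the same _quantize helper, routed through prepare_qat_pt2e.
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(
        get_symmetric_quantization_config(is_qat=True)  # is_qat kwarg assumed
    )
    example_inputs = (torch.randn(2, 2),)
    m_qat = self._quantize(
        torch.nn.Linear(2, 2).train(), quantizer, example_inputs, is_qat=True
    )

    # Both flows end in convert_pt2e, so both results are GraphModules.
    self.assertIsInstance(m_ptq, torch.fx.GraphModule)
    self.assertIsInstance(m_qat, torch.fx.GraphModule)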
@@ -2493,10 +2527,10 @@ def check_nn_module(node):
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
 class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
     def test_channel_group_quantization(self):
-        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
-        from torchao.quantization.pt2e.pt2e._affine_quantization import (
+        from torchao.quantization.pt2e._affine_quantization import (
             AffineQuantizedMinMaxObserver,
         )
+        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
 
         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
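For context (not part of the diff), the relocated affine observer is parameterized by the granularity and mapping objects imported alongside it. The rough sketch below shows what those objects describe; the group size and the keyword names passed to with_args are assumptions, not taken from this diff:

import torch

from torchao.quantization.pt2e._affine_quantization import AffineQuantizedMinMaxObserver
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken

# Granularity objects describe how scales/zero-points are shared across a tensor.
weight_granularity = PerGroup(group_size=32)  # one qparam set per group of 32 elements (group size assumed)
act_granularity = PerToken()                  # one qparam set per token

# Observer factory as it might be handed to a QuantizationSpec;
# keyword names here are assumptions rather than quotes from this diff.
weight_observer_ctr = AffineQuantizedMinMaxObserver.with_args(
    mapping_type=MappingType.SYMMETRIC,
    target_dtype=torch.int8,
    granularity=weight_granularity,
)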
@@ -2576,14 +2610,14 @@ def forward(self, x):
     def test_dynamic_affine_act_per_channel_weights(self):
         import operator
 
+        from torchao.quantization.pt2e._affine_quantization import (
+            AffineQuantizedMovingAverageMinMaxObserver,
+        )
         from torchao.quantization.pt2e.observer import (
             MappingType,
             PerChannelMinMaxObserver,
             PerToken,
         )
-        from torchao.quantization.pt2e.pt2e._affine_quantization import (
-            AffineQuantizedMovingAverageMinMaxObserver,
-        )
 
         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -2667,13 +2701,12 @@ def forward(self, x):
     def test_dynamic_per_tok_act_per_group_weights(self):
         import operator
 
-        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
-
         # TODO: merge into torchao observer
-        from torchao.quantization.pt2e.pt2e._affine_quantization import (
+        from torchao.quantization.pt2e._affine_quantization import (
             AffineQuantizedMinMaxObserver,
             AffineQuantizedPlaceholderObserver,
         )
+        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
 
         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: