Skip to content

Commit 313c015

Browse files
committed
Update PARQ README
1 parent 3809add commit 313c015

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

test/prototype/test_parq.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torchao.prototype.parq.quant import (
2222
Int4UnifTorchaoQuantizer,
2323
LSBQuantizer,
24+
TernaryUnifQuantizer,
2425
UnifQuantizer,
2526
UnifTorchaoQuantizer,
2627
)
@@ -102,14 +103,14 @@ def setUp(self):
102103
@common_utils.parametrize("unif_quant", [True, False])
103104
@common_utils.parametrize("hard_prox", [True, False])
104105
def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True):
105-
if unif_quant and b == 0:
106-
self.skipTest("Ternary uniform quantization not yet supported")
107-
108106
self.model.reset_parameters()
109107
param_groups = build_param_groups(self.model, b)
110108
base_optimizer = torch.optim.AdamW(param_groups)
111109

112-
quantizer = UnifQuantizer() if unif_quant else LSBQuantizer()
110+
if unif_quant:
111+
quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
112+
else:
113+
quantizer = LSBQuantizer()
113114
prox_map = (
114115
ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2)
115116
)

torchao/prototype/parq/README.md

+7-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ optimizer = QuantOptimizer(
4848
<td valign="top">
4949

5050
```python
51+
from torchao.quantization import quantize_
5152
from torchao.quantization.qat import (
5253
FakeQuantizeConfig,
5354
intx_quantization_aware_training,
@@ -67,16 +68,20 @@ quantize_(
6768
<td valign="top">
6869

6970
```python
70-
config = IntXWeightOnlyConfig(
71+
from torchao.quantization import IntxWeightOnlyConfig, quantize_
72+
73+
config = IntxWeightOnlyConfig(
7174
weight_dtype=torch.int4, granularity=PerGroup(32)
7275
)
73-
optimizer.torchao_quantize_(model, config)
76+
optimizer.restore_latent_params()
77+
quantize_(model, config, filter_fn=optimizer._get_filter_fn(model))
7478
```
7579

7680
</td>
7781
<td valign="top">
7882

7983
```python
84+
from torchao.quantization import quantize_
8085
from torchao.quantization.qat import from_intx_quantization_aware_training
8186

8287
quantize_(model, from_intx_quantization_aware_training())

torchao/prototype/parq/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,12 @@
1111
ProxPARQ,
1212
QuantOptimizer,
1313
)
14-
from .quant import LSBQuantizer, Quantizer, UnifQuantizer # noqa: F401
14+
from .quant import ( # noqa: F401
15+
Int4UnifTorchaoQuantizer,
16+
LSBQuantizer,
17+
MaxUnifQuantizer,
18+
Quantizer,
19+
TernaryUnifQuantizer,
20+
UnifQuantizer,
21+
UnifTorchaoQuantizer,
22+
)

0 commit comments

Comments (0)