Skip to content

Commit fb7f521

Browse files
committed
Pass explicit layout to int4_weight_only
1 parent 5ac8a9b commit fb7f521

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed

test/prototype/test_parq.py

+33-24
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from torch.testing._internal import common_utils
1212

1313
from torchao import quantize_
14+
from torchao.dtypes import Int4CPULayout, TensorCoreTiledLayout
1415
from torchao.prototype.parq.optim import (
1516
ProxHardQuant,
1617
ProxPARQ,
@@ -24,9 +25,10 @@
2425
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
2526
from torchao.quantization.granularity import PerGroup
2627
from torchao.quantization.quant_api import (
27-
Int4WeightOnlyConfig,
28+
LAYOUT_TO_ZERO_POINT_DOMAIN,
2829
IntxWeightOnlyConfig,
2930
_is_linear,
31+
int4_weight_only,
3032
)
3133
from torchao.quantization.quant_primitives import ZeroPointDomain
3234
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
@@ -66,8 +68,8 @@ def reset_parameters(self):
6668
if module.bias is not None:
6769
nn.init.zeros_(module.bias)
6870

69-
def example_inputs(self):
70-
return torch.randint(1, 10, (1, 256))
71+
def example_inputs(self, device=None):
72+
return torch.randint(1, 10, (1, 256), device=device)
7173

7274
def forward(self, x):
7375
x = self.embedding(x)
@@ -78,20 +80,19 @@ def forward(self, x):
7880
return x
7981

8082

81-
def train_loop(model, optimizer, update=True, steps=1):
82-
for _ in range(steps):
83-
x = model.example_inputs().to(_DEVICE)
84-
out = model(x)
85-
out.sum().backward()
86-
optimizer.step()
87-
88-
8983
class TestPARQuantization(common_utils.TestCase):
9084
def setUp(self):
9185
torch.manual_seed(123)
9286
self.model = M(bias=True).to(_DEVICE)
9387
self.params_quant, self.params_no_quant = split_param_groups(self.model)
9488

89+
def train_loop(self, optimizer, steps=1):
90+
for _ in range(steps):
91+
x = self.model.example_inputs(device=_DEVICE)
92+
out = self.model(x)
93+
out.sum().backward()
94+
optimizer.step()
95+
9596
def test_2bit_unif_quantizer_hard_prox(self):
9697
b = 2
9798
self.model.reset_parameters()
@@ -102,7 +103,7 @@ def test_2bit_unif_quantizer_hard_prox(self):
102103
base_optimizer = torch.optim.AdamW(param_groups)
103104
quantizer = UnifQuantizer()
104105
optimizer = QuantOptimizer(base_optimizer, quantizer, ProxHardQuant())
105-
train_loop(self.model, optimizer)
106+
self.train_loop(optimizer)
106107

107108
for child in self.model.children():
108109
if isinstance(child, nn.Linear):
@@ -122,7 +123,7 @@ def test_ternarybit_lsbq_parq_prox(self):
122123
optimizer = QuantOptimizer(
123124
base_optimizer, quantizer, ProxPARQ(anneal_start=0, anneal_end=2)
124125
)
125-
train_loop(self.model, optimizer, steps=3)
126+
self.train_loop(optimizer, steps=3)
126127

127128
for child in self.model.children():
128129
if isinstance(child, nn.Linear):
@@ -136,7 +137,7 @@ def setUp(self):
136137
torch.manual_seed(123)
137138

138139
@staticmethod
139-
def int4_torchao_quantizer(b: int = 4, config=None):
140+
def int4_torchao_quantizer(config, b: int = 4):
140141
# based off torchao.quantization.quant_api._int4_weight_only_transform
141142
return UnifTorchaoQuantizer(
142143
symmetric=False,
@@ -145,6 +146,7 @@ def int4_torchao_quantizer(b: int = 4, config=None):
145146
quant_max=2**b - 1,
146147
eps=1e-6,
147148
preserve_zero=False,
149+
zero_point_domain=LAYOUT_TO_ZERO_POINT_DOMAIN[type(config.layout)][0],
148150
config=config,
149151
)
150152

@@ -172,28 +174,31 @@ def compare_quantized_models(
172174
ref = getattr(m_ref, n).weight.dequantize()
173175
self.assertTrue(q.equal(ref))
174176

175-
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
176-
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
177177
@common_utils.parametrize("group_size", [32, 256])
178+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
178179
def test_int4_weight_only(self, group_size: int = 32):
179180
model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE)
180181
model.reset_parameters()
181182

182-
m_ref = copy.deepcopy(model)
183-
quantize_(m_ref, Int4WeightOnlyConfig(group_size))
183+
m_ref = copy.deepcopy(model).eval()
184+
config = int4_weight_only(
185+
group_size=group_size,
186+
layout=Int4CPULayout() if _DEVICE == "cpu" else TensorCoreTiledLayout(8),
187+
)
188+
quantize_(m_ref, config)
184189

185190
b = 4
186-
quantizer = self.int4_torchao_quantizer()
191+
quantizer = self.int4_torchao_quantizer(config)
187192
self.compare_quantized_models(model, m_ref, quantizer, b, group_size)
188193

189-
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
190194
@common_utils.parametrize("b", [2, 3, 4, 8])
191195
@common_utils.parametrize("group_size", [32, 512])
196+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
192197
def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
193198
model = M(m=512, n=512).to(_DEVICE)
194199
model.reset_parameters()
195200

196-
m_ref = copy.deepcopy(model)
201+
m_ref = copy.deepcopy(model).eval()
197202
quantize_(
198203
m_ref,
199204
IntxWeightOnlyConfig(
@@ -214,8 +219,11 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
214219
model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE)
215220
model.reset_parameters()
216221

217-
m_ref = copy.deepcopy(model)
218-
config = Int4WeightOnlyConfig(group_size)
222+
m_ref = copy.deepcopy(model).eval()
223+
config = int4_weight_only(
224+
group_size=group_size,
225+
layout=TensorCoreTiledLayout(8),
226+
)
219227
quantize_(m_ref, config)
220228

221229
b = 4
@@ -226,7 +234,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
226234
]
227235
base_optimizer = torch.optim.AdamW(param_groups)
228236

229-
quantizer = self.int4_torchao_quantizer(config=config)
237+
quantizer = self.int4_torchao_quantizer(config)
230238
optimizer = QuantOptimizer(
231239
base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
232240
)
@@ -238,6 +246,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
238246
orig_model = copy.deepcopy(model) # save copy of PARQ quantized model
239247

240248
# equivalent to torchao's convert step
249+
model.eval()
241250
with torch.no_grad():
242251
optimizer.restore_latent_params()
243252
quantize_(model, quantizer.config)

torchao/prototype/parq/quant/uniform_torchao.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def __init__(
3636
quant_max: Optional[Union[int, float]] = None,
3737
eps: Optional[float] = None,
3838
preserve_zero: bool = True,
39-
zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
39+
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE,
4040
config: Optional[AOBaseConfig] = None,
4141
) -> None:
4242
super().__init__(center=False)

0 commit comments

Comments
 (0)