Skip to content

Commit 8807d9f

Browse files
rewu93 authored and copybara-github committed
Update AI Edge Torch to use BLOCKWISE_XX interface in AEQ to achieve blockwise quantization.
PiperOrigin-RevId: 825135423
1 parent fc6ec07 commit 8807d9f

21 files changed

+176
-119
lines changed

ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,39 +1165,36 @@ def init_tensor_min_max(
11651165
A dictionary containing the min/max values for the tensor, or an empty
11661166
dictionary if the tensor data is None.
11671167
"""
1168-
if tensor_data is None:
1168+
weight_tensor_config = op_info.op_quant_config.weight_tensor_config
1169+
if tensor_data is None or weight_tensor_config is None:
11691170
return {}
11701171
else:
1171-
weight_tensor_config = op_info.op_quant_config.weight_tensor_config
1172-
quantized_dim = None
1173-
if weight_tensor_config is not None and (
1174-
weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
1175-
):
1172+
# Get reduce dimension for min/max calculation based on quantization
1173+
# granularity.
1174+
granularity = weight_tensor_config.granularity
1175+
if granularity == qtyping.QuantGranularity.TENSORWISE:
1176+
reduce_dims = None
1177+
keep_dims = True
1178+
elif granularity == qtyping.QuantGranularity.CHANNELWISE:
11761179
quantized_dim = common_utils.get_weight_quantized_dim(
11771180
op_info, tensor_data, weight_tensor_config.granularity
11781181
)
1179-
if (
1180-
weight_tensor_config is not None
1181-
and weight_tensor_config.granularity
1182-
== qtyping.QuantGranularity.BLOCKWISE
1183-
):
1184-
reshaped_data, reduce_dims = (
1182+
reduce_dims = common_utils.get_reduce_dims(
1183+
quantized_dim, tensor_data.shape
1184+
)
1185+
keep_dims = True
1186+
elif uniform_quantize_tensor.is_blockwise(granularity):
1187+
tensor_data, reduce_dims = (
11851188
uniform_quantize_tensor.reshape_data_for_blockwise(
11861189
tensor_data,
11871190
op_info.op_name,
1188-
weight_tensor_config.block_size,
1191+
granularity,
11891192
)
11901193
)
1191-
return {
1192-
"min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
1193-
"max": np.max(reshaped_data, axis=reduce_dims, keepdims=False),
1194-
}
1195-
1194+
keep_dims = False
11961195
else:
1197-
reduce_dims = common_utils.get_reduce_dims(
1198-
quantized_dim, tensor_data.shape
1199-
)
1200-
return {
1201-
"min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
1202-
"max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
1203-
}
1196+
raise ValueError(f"Unsupported granularity: {granularity}")
1197+
return {
1198+
"min": np.min(tensor_data, axis=reduce_dims, keepdims=keep_dims),
1199+
"max": np.max(tensor_data, axis=reduce_dims, keepdims=keep_dims),
1200+
}

ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def get_tensor_quant_params(
158158
op_info, tensor_quant_config, tensor_content, tensor_qsv
159159
)
160160

161-
if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
161+
if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
162162
raise ValueError(
163163
"Blockwise quantization is not supported for dequantized weight"
164164
" recovery."

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,7 @@ def test_fully_connected_blockwise_supported(self):
147147
weight_tensor_config=_TensorQuantConfig(
148148
num_bits=8,
149149
symmetric=True,
150-
granularity=qtyping.QuantGranularity.BLOCKWISE,
151-
block_size=32,
150+
granularity=qtyping.QuantGranularity.BLOCKWISE_32,
152151
),
153152
),
154153
)

ai_edge_quantizer/algorithms/uniform_quantize/mse.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def get_tensor_quant_params(
5555
ValueError: `tensor_qsv` must contain min/max values, or `tensor_content`
5656
must be provided so that they can be inferred.
5757
"""
58-
if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
58+
if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
5959
raise ValueError(
6060
"Blockwise quantization is not supported for MSE quantization."
6161
)
@@ -113,13 +113,15 @@ def get_tensor_quant_params(
113113
num_bits=tensor_quant_config.num_bits,
114114
symmetric=tensor_quant_config.symmetric,
115115
quantized_dimension=quantized_dim,
116-
block_size=tensor_quant_config.block_size,
116+
block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
117+
tensor_quant_config.granularity
118+
),
117119
)
118120

119121
quantized_vars = uniform_quantize_tensor.uniform_quantize(
120122
tensor_content,
121123
quant_params,
122-
tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
124+
uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity),
123125
)
124126

125127
return dataclasses.replace(quant_params, quantized_data=quantized_vars)

ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_get_tensor_quant_params_raises_error_with_unsupported_granularity(
8484
tensor_quant_config=qtyping.TensorQuantizationConfig(
8585
num_bits=4,
8686
symmetric=True,
87-
granularity=qtyping.QuantGranularity.BLOCKWISE,
87+
granularity=qtyping.QuantGranularity.BLOCKWISE_32,
8888
),
8989
tensor_content=test_data,
9090
)

ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Performs naive min/max uniform quantization."""
1717

18+
import dataclasses
1819
from typing import Any, Optional
1920
import numpy as np
2021
from ai_edge_quantizer import qtyping
@@ -91,26 +92,20 @@ def get_tensor_quant_params(
9192
num_bits=tensor_quant_config.num_bits,
9293
symmetric=tensor_quant_config.symmetric,
9394
quantized_dimension=quantized_dim,
94-
block_size=tensor_quant_config.block_size,
95+
block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
96+
tensor_quant_config.granularity
97+
),
9598
)
9699
if tensor_content is None:
97100
return quant_params
98101

99102
quantized_vars = uniform_quantize_tensor.uniform_quantize(
100103
tensor_content,
101104
quant_params,
102-
tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
105+
uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity),
103106
)
104107
# Update with quantized values.
105-
return qtyping.UniformQuantParams(
106-
scale=scale,
107-
zero_point=zp,
108-
num_bits=tensor_quant_config.num_bits,
109-
symmetric=tensor_quant_config.symmetric,
110-
quantized_dimension=quantized_dim,
111-
quantized_data=quantized_vars,
112-
block_size=tensor_quant_config.block_size,
113-
)
108+
return dataclasses.replace(quant_params, quantized_data=quantized_vars)
114109

115110

116111
# TODO: b/333731147 - Use named tuple to store min/max.

ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import cast
1818

1919
from absl.testing import parameterized
20+
import ml_dtypes
2021
import numpy as np
2122

2223
from tensorflow.python.platform import googletest
@@ -165,8 +166,7 @@ def test_get_tensor_quant_params_for_blockwise_weight(self):
165166
weight_tensor_config = _TensorQuantConfig(
166167
num_bits=4,
167168
symmetric=True,
168-
granularity=qtyping.QuantGranularity.BLOCKWISE,
169-
block_size=2,
169+
granularity=qtyping.QuantGranularity.BLOCKWISE_32,
170170
)
171171
op_info = qtyping.OpInfo(
172172
op=fc_op,
@@ -176,28 +176,32 @@ def test_get_tensor_quant_params_for_blockwise_weight(self):
176176
weight_tensor_config=weight_tensor_config,
177177
),
178178
)
179-
test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
179+
test_data = np.random.uniform(low=-10, high=10, size=(4, 32)).astype(
180+
np.float32
181+
)
180182
quant_params = naive_min_max_quantize.get_tensor_quant_params(
181183
op_info=op_info,
182184
tensor_quant_config=weight_tensor_config,
183185
tensor_content=test_data,
184186
)
185-
scale = quant_params.scale
186187
zp = quant_params.zero_point
187-
expected_scale = np.array([
188-
[1],
189-
[0.5703125],
190-
[0.5703125],
191-
[1],
192-
])
193-
expected_zp = np.zeros([4, 1])
194-
self.assertTrue(np.array_equal(zp, expected_zp))
195-
self.assertTrue(np.array_equal(scale, expected_scale))
188+
self.assertEqual(zp.shape, (4, 1))
189+
self.assertTrue(np.array_equal(zp, np.zeros([4, 1])))
190+
191+
self.assertEqual(quant_params.scale.shape, (4, 1))
192+
expected_scales = np.max(np.abs(test_data), axis=1, keepdims=True) / 7.0
193+
expected_scales = (
194+
expected_scales.astype(ml_dtypes.bfloat16)
195+
.astype(np.float16)
196+
.astype(np.float32)
197+
)
198+
self.assertTrue(np.allclose(quant_params.scale, expected_scales, atol=1e-5))
199+
196200
self.assertIsNotNone(quant_params.quantized_data)
197201
self.assertTupleEqual(
198202
cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
199203
)
200-
self.assertEqual(quant_params.block_size, 2)
204+
self.assertEqual(quant_params.block_size, 32)
201205
self.assertEqual(quant_params.quantized_dimension, 1)
202206

203207
def test_calibrate_ignores_inf_min_max(self):

ai_edge_quantizer/algorithms/uniform_quantize/octav.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,12 @@ def get_tensor_quant_params(
131131
quantized_dim = common_utils.get_weight_quantized_dim(
132132
op_info, tensor_content, tensor_quant_config.granularity
133133
)
134-
if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
134+
if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
135135
reshaped_data, reduce_dims = (
136136
uniform_quantize_tensor.reshape_data_for_blockwise(
137137
tensor_content,
138138
op_info.op_name,
139-
tensor_quant_config.block_size,
139+
tensor_quant_config.granularity,
140140
)
141141
)
142142
else:
@@ -154,7 +154,7 @@ def get_tensor_quant_params(
154154
# We created a new dimension in order to reduce properly for blockwise
155155
# quantization, so we need to reshape the clipping constants back to the
156156
# min/max shape for the next step.
157-
if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
157+
if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
158158
clipping_constants = clipping_constants.reshape(tensor_min_max["min"].shape)
159159

160160
zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
@@ -172,13 +172,17 @@ def get_tensor_quant_params(
172172
num_bits=tensor_quant_config.num_bits,
173173
symmetric=tensor_quant_config.symmetric,
174174
quantized_dimension=quantized_dim,
175-
block_size=tensor_quant_config.block_size,
175+
block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
176+
tensor_quant_config.granularity
177+
),
176178
)
177179

178180
quantized_vars = uniform_quantize_tensor.uniform_quantize(
179181
tensor_content,
180182
quant_params,
181-
tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
183+
is_blockwise_quant=uniform_quantize_tensor.is_blockwise(
184+
tensor_quant_config.granularity
185+
),
182186
)
183187

184188
return dataclasses.replace(quant_params, quantized_data=quantized_vars)

ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,7 @@ def test_get_tensor_quant_params_sanity_blockwise(self):
196196
tensor_config = qtyping.TensorQuantizationConfig(
197197
num_bits=4,
198198
symmetric=True,
199-
granularity=qtyping.QuantGranularity.BLOCKWISE,
200-
block_size=32,
199+
granularity=qtyping.QuantGranularity.BLOCKWISE_32,
201200
)
202201
fc_op_info = qtyping.OpInfo(
203202
op=self._fc_op,

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ class IntType:
2929
signed: bool
3030

3131

32+
def is_blockwise(granularity: qtyping.QuantGranularity) -> bool:
33+
"""Checks if the quantization granularity is blockwise."""
34+
return "BLOCKWISE" in str(granularity)
35+
36+
3237
def get_quantized_range(qtype: IntType) -> tuple[float, float]:
3338
"""Calculates range of the quantized type."""
3439
if qtype.signed:
@@ -40,6 +45,22 @@ def get_quantized_range(qtype: IntType) -> tuple[float, float]:
4045
return float(qmin), float(qmax)
4146

4247

48+
def extract_block_size_from_granularity(
49+
granularity: qtyping.QuantGranularity,
50+
) -> int:
51+
"""Get the block size for blockwise quantization."""
52+
if granularity == qtyping.QuantGranularity.BLOCKWISE_32:
53+
return 32
54+
elif granularity == qtyping.QuantGranularity.BLOCKWISE_64:
55+
return 64
56+
elif granularity == qtyping.QuantGranularity.BLOCKWISE_128:
57+
return 128
58+
elif granularity == qtyping.QuantGranularity.BLOCKWISE_256:
59+
return 256
60+
else:
61+
return 0
62+
63+
4364
def _round_and_clip(
4465
tensor: np.ndarray, qtype: IntType, narrow: bool
4566
) -> np.ndarray:
@@ -157,26 +178,28 @@ def _get_tensor_shape_for_blockwise(
157178

158179

159180
def reshape_data_for_blockwise(
160-
tensor_data: np.ndarray, op_name: qtyping.TFLOperationName, block_size: int
181+
tensor_data: np.ndarray,
182+
op_name: qtyping.TFLOperationName,
183+
granularity: qtyping.QuantGranularity,
161184
) -> tuple[np.ndarray, int]:
162185
"""Reshapes data for blockwise quantization.
163186
164187
Args:
165188
tensor_data: The original tensor data.
166189
op_name: The name of the TFL op.
167-
block_size: The size of the block.
190+
granularity: The quantization granularity for the tensor.
168191
169192
Returns:
170193
A tuple containing the reshaped tensor data and the new reduce dimension.
171194
"""
172195
quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
173196
op_name
174197
]
198+
block_size = extract_block_size_from_granularity(granularity)
175199
new_shape = _get_tensor_shape_for_blockwise(
176200
tensor_data.shape, quantized_dim, block_size
177201
)
178-
reshaped_data = tensor_data.reshape(new_shape)
179-
return reshaped_data, quantized_dim + 1
202+
return tensor_data.reshape(new_shape), quantized_dim + 1
180203

181204

182205
def _broadcast_scale_zp_for_blockwise(
@@ -233,21 +256,21 @@ def _broadcast_scale_zp_for_blockwise(
233256
def uniform_quantize(
234257
tensor_data: np.ndarray,
235258
quantization_params: qtyping.UniformQuantParams,
236-
is_blockwise: bool = False,
259+
is_blockwise_quant: bool = False,
237260
):
238261
"""Uniform quantize a tensor.
239262
240263
Args:
241264
tensor_data: The tensor to be quantized.
242265
quantization_params: The quantization parameters.
243-
is_blockwise: Whether the tensor is blockwise quantized.
266+
is_blockwise_quant: Whether the tensor is blockwise quantized.
244267
245268
Returns:
246269
The quantized tensor.
247270
"""
248271
# The reshaping for blockwise quantization is unique hence we do this here
249272
# to avoid unexpected broadcast behavior downstream.
250-
if is_blockwise:
273+
if is_blockwise_quant:
251274
quantization_params = _broadcast_scale_zp_for_blockwise(
252275
tensor_data, quantization_params
253276
)
@@ -435,6 +458,7 @@ def tensor_zp_scale_from_min_max(
435458
Returns:
436459
The zero point and scale of the tensor.
437460
"""
461+
438462
# TODO: b/332574603 - support unsigned data type.
439463
qtype = IntType(
440464
num_bits,
@@ -445,7 +469,7 @@ def tensor_zp_scale_from_min_max(
445469
pos_clipping_values = None if clipping_values is None else clipping_values
446470
neg_clipping_values = None if clipping_values is None else -clipping_values
447471

448-
if granularity == qtyping.QuantGranularity.BLOCKWISE:
472+
if is_blockwise(granularity):
449473
# Blockwise quantization uses float16 scale,
450474
# with 7 bit mantissa, so the maximum scale value is 65280 and maximum
451475
# representable range is [-65280 * (2 ** num_bits),
@@ -493,7 +517,7 @@ def tensor_zp_scale_from_min_max(
493517
zp = qmin - bound_min / scale
494518
zp = np.rint(zp)
495519

496-
if granularity == qtyping.QuantGranularity.BLOCKWISE:
520+
if is_blockwise(granularity):
497521
# Round the scale values to 7 bit mantissa.
498522
scale = (
499523
scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)

0 commit comments

Comments (0)