Skip to content

Commit 959c29b

Browse files
szyszyzysliangel-02
authored andcommitted
Add lut quantized embedding.
Differential Revision: D79750002 Pull Request resolved: #2824
1 parent b7af6ab commit 959c29b

File tree

2 files changed

+183
-2
lines changed

2 files changed

+183
-2
lines changed
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
from .api import GroupwiseLutWeightConfig
22
from .codebook_quantized_tensor import CodebookQuantizedPackedTensor
33

4-
__all__ = ["CodebookQuantizedPackedTensor", "GroupwiseLutWeightConfig"]
4+
__all__ = [
5+
"CodebookQuantizedPackedTensor",
6+
"GroupwiseLutWeightConfig",
7+
"QuantizedLutEmbedding",
8+
"EmbeddingLutQuantizer",
9+
]

torchao/prototype/quantization/codebook_groupwise/api.py

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import List, Optional
99

1010
import torch
11+
import torch.nn as nn
1112

1213
from torchao.core.config import AOBaseConfig
1314
from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import (
@@ -16,6 +17,12 @@
1617
from torchao.prototype.quantization.codebook_groupwise.codebook_quantized_tensor import (
1718
CodebookQuantizedPackedTensor,
1819
)
20+
from torchao.prototype.quantization.codebook_utils.codebook_utils import (
21+
block_shape_to_group_size,
22+
)
23+
from torchao.quantization.quant_primitives import (
24+
_DTYPE_TO_BIT_WIDTH,
25+
)
1926
from torchao.quantization.transform_module import register_quantize_module_handler
2027

2128

@@ -98,9 +105,11 @@ def __post_init__(self):
98105
raise ValueError(
99106
"`lut_block_shape` must contain exactly one '-1' to specify the grouping dimension."
100107
)
108+
if self.has_scale == True:
109+
raise ValueError("currently only support lut quantization without scale")
101110

102111
# 3. Validate scale_block_shape if it exists
103-
if self.scale_block_shape is not None:
112+
if self.has_scale and self.scale_block_shape is not None:
104113
if not (
105114
isinstance(self.scale_block_shape, list)
106115
and len(self.scale_block_shape) == 2
@@ -142,3 +151,170 @@ def _groupwise_lut_weight_transform(
142151
module.weight.data.copy_(dequantized_weight)
143152

144153
return module
154+
155+
156+
class QuantizedLutEmbedding(nn.Module):
157+
"""
158+
A PyTorch module that holds a LUT-based quantized embedding layer and
159+
performs the forward pass using a high-performance C++ kernel.
160+
161+
This module should be created from a floating-point nn.Embedding module
162+
using the `from_float` classmethod.
163+
"""
164+
165+
def __init__(
166+
self, config: GroupwiseLutWeightConfig, num_embeddings: int, embedding_dim: int
167+
):
168+
super().__init__()
169+
# Store config and metadata needed for the forward pass
170+
self.config = config
171+
self.num_embeddings = num_embeddings
172+
self.embedding_dim = embedding_dim
173+
self.bit_width = _DTYPE_TO_BIT_WIDTH[config.code_dtype]
174+
175+
# This buffer will be populated by the from_float method
176+
self.register_buffer("packed_weights", torch.empty(0, dtype=torch.uint8))
177+
178+
@classmethod
179+
def from_float(
180+
cls, float_embedding: nn.Embedding, config: GroupwiseLutWeightConfig
181+
) -> "QuantizedLutEmbedding":
182+
"""
183+
Creates a quantized embedding module from a floating-point nn.Embedding.
184+
185+
Args:
186+
float_embedding (nn.Embedding): The original, trained embedding module.
187+
config (GroupwiseLutWeightConfig): The configuration for quantization.
188+
189+
Returns:
190+
QuantizedLutEmbedding: A new module with quantized and packed weights.
191+
"""
192+
assert isinstance(float_embedding, nn.Embedding), (
193+
"Input must be an nn.Embedding module."
194+
)
195+
196+
weight = float_embedding.weight.data
197+
num_embeddings, embedding_dim = weight.shape
198+
199+
# --- 1. Call our universal quantize_dispatch function ---
200+
quantized_tensor = CodebookQuantizedTensor.from_float(
201+
weight, code_dtype=config.code_dtype, block_size=config.lut_block_shape
202+
)
203+
codes = quantized_tensor.codes
204+
codebook = quantized_tensor.codebook.to(torch.float32)
205+
# Currently only support lut quantization without scale. Upate this when we support scale.
206+
scales = None
207+
208+
# Pack the quantized data
209+
bit_width = _DTYPE_TO_BIT_WIDTH[config.code_dtype]
210+
packer_op = getattr(torch.ops.torchao, f"_pack_embedding_lut_{bit_width}bit")
211+
packed_weights = packer_op(
212+
codes,
213+
codebook,
214+
block_shape_to_group_size(
215+
config.scale_block_shape, (num_embeddings, embedding_dim)
216+
)
217+
if config.scale_block_shape
218+
else -1,
219+
block_shape_to_group_size(
220+
config.lut_block_shape, (num_embeddings, embedding_dim)
221+
),
222+
scales,
223+
)
224+
225+
# Create and populate the new quantized module
226+
quantized_module = cls(config, num_embeddings, embedding_dim)
227+
quantized_module.register_buffer("packed_weights", packed_weights)
228+
229+
return quantized_module
230+
231+
def forward(self, indices: torch.Tensor) -> torch.Tensor:
232+
"""
233+
Performs the embedding lookup using the packed weights.
234+
"""
235+
# The forward pass logic remains the same.
236+
forward_op = getattr(torch.ops.torchao, f"_embedding_lut_{self.bit_width}bit")
237+
238+
# The C++ operator reads all metadata from the packed_weights header
239+
result = forward_op(
240+
self.packed_weights,
241+
indices.reshape(-1),
242+
self.num_embeddings,
243+
self.embedding_dim,
244+
block_shape_to_group_size(
245+
self.config.scale_block_shape, (self.num_embeddings, self.embedding_dim)
246+
)
247+
if self.config.scale_block_shape
248+
else -1,
249+
block_shape_to_group_size(
250+
self.config.lut_block_shape, (self.num_embeddings, self.embedding_dim)
251+
),
252+
self.config.has_scale,
253+
)
254+
return result.reshape(*indices.shape, self.embedding_dim).to(
255+
self.config.weight_dtype
256+
)
257+
258+
def __repr__(self):
259+
return (
260+
f"QuantizedLutEmbedding(num_embeddings={self.num_embeddings}, "
261+
f"embedding_dim={self.embedding_dim}, bit_width={self.bit_width}, "
262+
f"lut_block_shape={self.config.lut_block_shape})"
263+
)
264+
265+
266+
class EmbeddingLutQuantizer:
267+
"""
268+
A quantizer that finds nn.Embedding modules in a model and replaces
269+
them with the QuantizedLutEmbedding module based on a provided configuration.
270+
"""
271+
272+
def __init__(self, config: GroupwiseLutWeightConfig):
273+
"""
274+
Initializes the quantizer with a single, comprehensive configuration object.
275+
276+
Args:
277+
config (GroupwiseLutWeightConfig): The configuration that defines
278+
how all embeddings should be quantized.
279+
"""
280+
# The quantizer now holds the entire configuration object.
281+
self.config = config
282+
283+
def quantize(self, model: nn.Module) -> nn.Module:
284+
"""
285+
Recursively traverses the model and replaces all nn.Embedding layers.
286+
287+
Args:
288+
model (nn.Module): The model to be quantized.
289+
290+
Returns:
291+
nn.Module: The model with embedding layers replaced.
292+
"""
293+
self._replace_embedding(model)
294+
return model
295+
296+
def _replace_embedding(self, module: nn.Module):
297+
for name, child in module.named_children():
298+
if isinstance(child, nn.Embedding):
299+
if self.config.use_qdq_reference:
300+
weight = child.weight.data
301+
302+
# 1. Run the full quantize -> dequantize pipeline in Python
303+
quantized_tensor = CodebookQuantizedTensor.from_float(
304+
weight,
305+
code_dtype=self.config.code_dtype,
306+
block_size=self.config.lut_block_shape,
307+
)
308+
ref_weight = quantized_tensor.dequantize(self.config.weight_dtype)
309+
310+
# 2. Create a standard nn.Embedding with the dequantized weight
311+
ref_embedding = nn.Embedding.from_pretrained(
312+
ref_weight, freeze=True
313+
)
314+
setattr(module, name, ref_embedding)
315+
316+
else:
317+
q_embedding = QuantizedLutEmbedding.from_float(child, self.config)
318+
setattr(module, name, q_embedding)
319+
else:
320+
self._replace_embedding(child)

0 commit comments

Comments
 (0)