Skip to content

Commit f93433f

Browse files
szyszyzys authored and facebook-github-bot committed
Add lut quantized embedding. (#2824)
Summary: Pull Request resolved: #2824 Reviewed By: metascroy Differential Revision: D79750002
1 parent 28f38c4 commit f93433f

File tree

3 files changed

+307
-4
lines changed

3 files changed

+307
-4
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# Every name listed in `__all__` must be bound in this module, otherwise
# `from <package> import *` raises AttributeError for the missing ones.
from .api import (
    EmbeddingLutQuantizer,
    GroupwiseLutWeightConfig,
    QuantizedLutEmbedding,
)
22
from .codebook_quantized_tensor import CodebookQuantizedPackedTensor
33

4-
__all__ = ["CodebookQuantizedPackedTensor", "GroupwiseLutWeightConfig"]
4+
__all__ = ["CodebookQuantizedPackedTensor", "GroupwiseLutWeightConfig", "QuantizedLutEmbedding", "EmbeddingLutQuantizer"]

torchao/prototype/quantization/codebook_groupwise/api.py

Lines changed: 246 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
#
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
6-
import hashlib
7-
import os
86
import types
97
from dataclasses import dataclass, field
108
from typing import List, Optional
119

1210
import torch
11+
import torch.nn as nn
1312

1413
from torchao.core.config import AOBaseConfig
1514
from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import (
@@ -18,7 +17,13 @@
1817
from torchao.prototype.quantization.codebook_groupwise.codebook_quantized_tensor import (
1918
CodebookQuantizedPackedTensor,
2019
)
20+
from torchao.prototype.quantization.codebook_utils.codebook_utils import (
21+
block_shape_to_group_size,
22+
)
2123
from torchao.quantization.transform_module import register_quantize_module_handler
24+
from torchao.quantization.quant_primitives import (
25+
_DTYPE_TO_BIT_WIDTH,
26+
)
2227

2328

2429
def _get_linear_extra_repr_for_lut(self) -> str:
@@ -100,9 +105,11 @@ def __post_init__(self):
100105
raise ValueError(
101106
"`lut_block_shape` must contain exactly one '-1' to specify the grouping dimension."
102107
)
108+
if self.has_scale == True:
109+
raise ValueError("currently only support lut quantization without scale")
103110

104111
# 3. Validate scale_block_shape if it exists
105-
if self.scale_block_shape is not None:
112+
if self.has_scale and self.scale_block_shape is not None:
106113
if not (
107114
isinstance(self.scale_block_shape, list)
108115
and len(self.scale_block_shape) == 2
@@ -144,3 +151,239 @@ def _groupwise_lut_weight_transform(
144151
module.weight.data.copy_(dequantized_weight)
145152

146153
return module
154+
155+
156+
class QuantizedLutEmbedding(nn.Module):
    """
    Embedding layer whose weight is stored as packed LUT (codebook) codes.

    The forward pass dispatches to the ``torch.ops.torchao._embedding_lut_{N}bit``
    operator, where ``N`` is the bit width of ``config.code_dtype``.

    Instances are normally created from a floating-point ``nn.Embedding`` with
    the :meth:`from_float` classmethod, which quantizes and packs the weight.
    """

    def __init__(
        self, config: GroupwiseLutWeightConfig, num_embeddings: int, embedding_dim: int
    ):
        """
        Args:
            config (GroupwiseLutWeightConfig): Quantization configuration; its
                ``code_dtype`` selects the bit width of the packed codes.
            num_embeddings (int): Number of rows of the embedding table.
            embedding_dim (int): Number of columns of the embedding table.
        """
        super().__init__()
        # Config and metadata needed by the forward pass.
        self.config = config
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.bit_width = _DTYPE_TO_BIT_WIDTH[config.code_dtype]

        # Empty placeholder; `from_float` replaces it with the packed weights.
        self.register_buffer("packed_weights", torch.empty(0, dtype=torch.uint8))

    @staticmethod
    def _group_sizes(
        config: GroupwiseLutWeightConfig, num_embeddings: int, embedding_dim: int
    ):
        """
        Return ``(scale_group_size, lut_group_size)`` for the given table shape.

        ``scale_group_size`` is ``-1`` when ``config.scale_block_shape`` is unset.
        Shared by `from_float` and `forward` so both pass the same values to
        the packing and lookup operators.
        """
        tensor_shape = (num_embeddings, embedding_dim)
        scale_group_size = (
            block_shape_to_group_size(config.scale_block_shape, tensor_shape)
            if config.scale_block_shape
            else -1
        )
        lut_group_size = block_shape_to_group_size(
            config.lut_block_shape, tensor_shape
        )
        return scale_group_size, lut_group_size

    @classmethod
    def from_float(
        cls, float_embedding: nn.Embedding, config: GroupwiseLutWeightConfig
    ) -> "QuantizedLutEmbedding":
        """
        Creates a quantized embedding module from a floating-point nn.Embedding.

        Args:
            float_embedding (nn.Embedding): The original, trained embedding module.
            config (GroupwiseLutWeightConfig): The configuration for quantization.

        Returns:
            QuantizedLutEmbedding: A new module with quantized and packed weights.

        Raises:
            TypeError: If ``float_embedding`` is not an ``nn.Embedding``.
        """
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not isinstance(float_embedding, nn.Embedding):
            raise TypeError("Input must be an nn.Embedding module.")

        weight = float_embedding.weight.data
        num_embeddings, embedding_dim = weight.shape

        # 1. Quantize the weight into per-element codes plus a codebook (LUT).
        quantized_tensor = CodebookQuantizedTensor.from_float(
            weight, code_dtype=config.code_dtype, block_size=config.lut_block_shape
        )
        codes = quantized_tensor.codes
        codebook = quantized_tensor.codebook.to(torch.float32)
        # Currently only LUT quantization without scale is supported.
        # Update this when scale support is added.
        scales = None

        # 2. Pack codes and codebook with the bit-width specific packing op.
        bit_width = _DTYPE_TO_BIT_WIDTH[config.code_dtype]
        packer_op = getattr(torch.ops.torchao, f"_pack_embedding_lut_{bit_width}bit")
        scale_group_size, lut_group_size = cls._group_sizes(
            config, num_embeddings, embedding_dim
        )
        packed_weights = packer_op(
            codes,
            codebook,
            scale_group_size,
            lut_group_size,
            scales,
        )

        # 3. Create the quantized module and swap in the packed weights.
        quantized_module = cls(config, num_embeddings, embedding_dim)
        quantized_module.register_buffer("packed_weights", packed_weights)

        return quantized_module

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Performs the embedding lookup using the packed weights.

        Args:
            indices (torch.Tensor): Indices to look up; any shape.

        Returns:
            torch.Tensor: Tensor of shape ``(*indices.shape, embedding_dim)``
            cast to ``config.weight_dtype``.
        """
        forward_op = getattr(torch.ops.torchao, f"_embedding_lut_{self.bit_width}bit")
        scale_group_size, lut_group_size = self._group_sizes(
            self.config, self.num_embeddings, self.embedding_dim
        )

        # The op receives flattened indices; the result is reshaped back below.
        result = forward_op(
            self.packed_weights,
            indices.reshape(-1),
            self.num_embeddings,
            self.embedding_dim,
            scale_group_size,
            lut_group_size,
            self.config.has_scale,
        )
        return result.reshape(*indices.shape, self.embedding_dim).to(
            self.config.weight_dtype
        )

    def __repr__(self):
        return (
            f"QuantizedLutEmbedding(num_embeddings={self.num_embeddings}, "
            f"embedding_dim={self.embedding_dim}, bit_width={self.bit_width}, "
            f"lut_block_shape={self.config.lut_block_shape})"
        )
264+
265+
266+
class EmbeddingLutQuantizer:
    """
    A quantizer that finds nn.Embedding modules in a model and replaces
    them based on a provided configuration.

    When ``config.use_qdq_reference`` is true, each embedding is replaced by a
    standard, frozen ``nn.Embedding`` holding the quantize -> dequantize
    round-tripped weight. Otherwise it is replaced by a
    ``QuantizedLutEmbedding`` built with ``QuantizedLutEmbedding.from_float``.

    This is the single definition of the class: earlier drafts that were
    defined above it under the same name were shadowed at import time and
    have been removed.
    """

    def __init__(self, config: GroupwiseLutWeightConfig):
        """
        Initializes the quantizer with a single, comprehensive configuration object.

        Args:
            config (GroupwiseLutWeightConfig): The configuration that defines
                how all embeddings should be quantized.
        """
        self.config = config

    def quantize(self, model: nn.Module) -> nn.Module:
        """
        Recursively traverses the model and replaces all nn.Embedding layers.

        The model is modified in place. Only child modules are replaced; if
        ``model`` itself is an ``nn.Embedding`` it is returned unchanged.

        Args:
            model (nn.Module): The model to be quantized.

        Returns:
            nn.Module: The same model object, with embedding layers replaced.
        """
        self._replace_embedding(model)
        return model

    def _quantize_embedding(self, embedding: nn.Embedding) -> nn.Module:
        """Build the replacement module for a single ``nn.Embedding``."""
        if not self.config.use_qdq_reference:
            return QuantizedLutEmbedding.from_float(embedding, self.config)

        # Reference path: run the full quantize -> dequantize pipeline in
        # Python and keep the result in a standard, frozen nn.Embedding.
        quantized_tensor = CodebookQuantizedTensor.from_float(
            embedding.weight.data,
            code_dtype=self.config.code_dtype,
            block_size=self.config.lut_block_shape,
        )
        ref_weight = quantized_tensor.dequantize(self.config.weight_dtype)
        return nn.Embedding.from_pretrained(ref_weight, freeze=True)

    def _replace_embedding(self, module: nn.Module):
        """Replace embedding children of ``module`` in place, recursing into the rest."""
        for name, child in module.named_children():
            if isinstance(child, nn.Embedding):
                setattr(module, name, self._quantize_embedding(child))
            else:
                self._replace_embedding(child)

torchao/prototype/quantization/codebook_utils/codebook_utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,66 @@
2222
)
2323
from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH
2424

25+
# NOTE(review): another definition named `block_shape_to_group_size` starts
# further down in this file; being later, it is the one bound at import time.
# Consolidate the two definitions.
def block_shape_to_group_size(
    block_shape: List[int], tensor_shape: Tuple[int, int]
) -> int:
    """
    Calculates the total number of elements in a group from a block_shape.

    Args:
        block_shape: Two-element block shape ``[n_group, k_group]``. An entry
            of ``-1`` means "span the whole corresponding tensor dimension".
        tensor_shape: Two-element shape ``(n_dim, k_dim)`` of the tensor.

    Returns:
        The product of the two resolved block dimensions.

    Raises:
        ValueError: If a block dimension is neither ``-1`` nor positive, or if
            either argument does not have exactly two entries.
    """
    n_group, k_group = block_shape
    n_dim, k_dim = tensor_shape

    # Reject dims that would silently yield a zero or negative group size.
    for dim in (n_group, k_group):
        if dim != -1 and dim <= 0:
            raise ValueError(
                f"block_shape entries must be -1 or positive, got {list(block_shape)}."
            )

    if n_group == -1:
        n_group = n_dim
    if k_group == -1:
        k_group = k_dim

    return n_group * k_group
36+
37+
def group_size_to_block_shapes(
    lut_group_size: int,
    tensor_shape: Tuple[int, int],
    scale_group_size: Optional[int] = None,
) -> Tuple[List[int], Optional[List[int]]]:
    """
    Translates legacy integer-based group sizes into the new block_shape list format.

    This function encodes the implicit assumptions of the old system:
    - LUTs were always grouped by rows.
    - Scales were always grouped by columns.

    Args:
        lut_group_size (int): The total number of elements that shared a single LUT.
            Must be a positive multiple of the number of columns.
        tensor_shape (Tuple[int, int]): The shape of the weight tensor (N, K).
            Only the number of columns K is used, to convert the LUT group
            size into a number of rows.
        scale_group_size (Optional[int]): The number of elements (columns) that
            shared a single scale factor. ``None`` or a non-positive value
            means scales are not used.

    Returns:
        A tuple containing:
        - lut_block_shape (List[int]): ``[lut_group_size // K, -1]``.
        - scale_block_shape (Optional[List[int]]): ``[1, scale_group_size]``,
          or None when scales are not used.

    Raises:
        ValueError: If the number of columns or ``lut_group_size`` is not
            positive, if ``lut_group_size`` is not divisible by the number of
            columns, or if the number of columns is not divisible by a
            positive ``scale_group_size``.
    """
    _, k_cols = tensor_shape

    if k_cols <= 0:
        raise ValueError(
            f"tensor_shape must have a positive number of columns, got {k_cols}."
        )
    # 0 and negative multiples of k_cols would pass the modulo check below and
    # produce an invalid block shape such as [0, -1].
    if lut_group_size <= 0:
        raise ValueError(
            f"lut_group_size must be a positive integer, got {lut_group_size}."
        )

    # --- 1. Translate LUT Group Size ---
    if lut_group_size % k_cols != 0:
        raise ValueError(
            f"lut_group_size ({lut_group_size}) must be divisible by the number "
            f"of columns ({k_cols}) for legacy row-grouping."
        )
    rows_per_lut = lut_group_size // k_cols
    lut_block_shape = [rows_per_lut, -1]

    # --- 2. Translate Scale Group Size ---
    scale_block_shape = None
    if scale_group_size is not None and scale_group_size > 0:
        if k_cols % scale_group_size != 0:
            raise ValueError(
                f"Number of columns ({k_cols}) must be divisible by "
                f"scale_group_size ({scale_group_size}) for legacy column-grouping."
            )
        scale_block_shape = [1, scale_group_size]

    return lut_block_shape, scale_block_shape
84+
2585

2686
def block_shape_to_group_size(block_shape, tensor_shape):
2787
"""Calculates the total number of elements in a group from a block_shape."""

0 commit comments

Comments
 (0)