|
8 | 8 | from typing import List, Optional
|
9 | 9 |
|
10 | 10 | import torch
|
| 11 | +import torch.nn as nn |
11 | 12 |
|
12 | 13 | from torchao.core.config import AOBaseConfig
|
13 | 14 | from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import (
|
|
16 | 17 | from torchao.prototype.quantization.codebook_groupwise.codebook_quantized_tensor import (
|
17 | 18 | CodebookQuantizedPackedTensor,
|
18 | 19 | )
|
| 20 | +from torchao.prototype.quantization.codebook_utils.codebook_utils import ( |
| 21 | + block_shape_to_group_size, |
| 22 | +) |
| 23 | +from torchao.quantization.quant_primitives import ( |
| 24 | + _DTYPE_TO_BIT_WIDTH, |
| 25 | +) |
19 | 26 | from torchao.quantization.transform_module import register_quantize_module_handler
|
20 | 27 |
|
21 | 28 |
|
@@ -98,9 +105,11 @@ def __post_init__(self):
|
98 | 105 | raise ValueError(
|
99 | 106 | "`lut_block_shape` must contain exactly one '-1' to specify the grouping dimension."
|
100 | 107 | )
|
| 108 | + if self.has_scale == True: |
| 109 | + raise ValueError("currently only support lut quantization without scale") |
101 | 110 |
|
102 | 111 | # 3. Validate scale_block_shape if it exists
|
103 |
| - if self.scale_block_shape is not None: |
| 112 | + if self.has_scale and self.scale_block_shape is not None: |
104 | 113 | if not (
|
105 | 114 | isinstance(self.scale_block_shape, list)
|
106 | 115 | and len(self.scale_block_shape) == 2
|
@@ -142,3 +151,170 @@ def _groupwise_lut_weight_transform(
|
142 | 151 | module.weight.data.copy_(dequantized_weight)
|
143 | 152 |
|
144 | 153 | return module
|
| 154 | + |
| 155 | + |
| 156 | +class QuantizedLutEmbedding(nn.Module): |
| 157 | + """ |
| 158 | + A PyTorch module that holds a LUT-based quantized embedding layer and |
| 159 | + performs the forward pass using a high-performance C++ kernel. |
| 160 | +
|
| 161 | + This module should be created from a floating-point nn.Embedding module |
| 162 | + using the `from_float` classmethod. |
| 163 | + """ |
| 164 | + |
| 165 | + def __init__( |
| 166 | + self, config: GroupwiseLutWeightConfig, num_embeddings: int, embedding_dim: int |
| 167 | + ): |
| 168 | + super().__init__() |
| 169 | + # Store config and metadata needed for the forward pass |
| 170 | + self.config = config |
| 171 | + self.num_embeddings = num_embeddings |
| 172 | + self.embedding_dim = embedding_dim |
| 173 | + self.bit_width = _DTYPE_TO_BIT_WIDTH[config.code_dtype] |
| 174 | + |
| 175 | + # This buffer will be populated by the from_float method |
| 176 | + self.register_buffer("packed_weights", torch.empty(0, dtype=torch.uint8)) |
| 177 | + |
| 178 | + @classmethod |
| 179 | + def from_float( |
| 180 | + cls, float_embedding: nn.Embedding, config: GroupwiseLutWeightConfig |
| 181 | + ) -> "QuantizedLutEmbedding": |
| 182 | + """ |
| 183 | + Creates a quantized embedding module from a floating-point nn.Embedding. |
| 184 | +
|
| 185 | + Args: |
| 186 | + float_embedding (nn.Embedding): The original, trained embedding module. |
| 187 | + config (GroupwiseLutWeightConfig): The configuration for quantization. |
| 188 | +
|
| 189 | + Returns: |
| 190 | + QuantizedLutEmbedding: A new module with quantized and packed weights. |
| 191 | + """ |
| 192 | + assert isinstance(float_embedding, nn.Embedding), ( |
| 193 | + "Input must be an nn.Embedding module." |
| 194 | + ) |
| 195 | + |
| 196 | + weight = float_embedding.weight.data |
| 197 | + num_embeddings, embedding_dim = weight.shape |
| 198 | + |
| 199 | + # --- 1. Call our universal quantize_dispatch function --- |
| 200 | + quantized_tensor = CodebookQuantizedTensor.from_float( |
| 201 | + weight, code_dtype=config.code_dtype, block_size=config.lut_block_shape |
| 202 | + ) |
| 203 | + codes = quantized_tensor.codes |
| 204 | + codebook = quantized_tensor.codebook.to(torch.float32) |
| 205 | + # Currently only support lut quantization without scale. Upate this when we support scale. |
| 206 | + scales = None |
| 207 | + |
| 208 | + # Pack the quantized data |
| 209 | + bit_width = _DTYPE_TO_BIT_WIDTH[config.code_dtype] |
| 210 | + packer_op = getattr(torch.ops.torchao, f"_pack_embedding_lut_{bit_width}bit") |
| 211 | + packed_weights = packer_op( |
| 212 | + codes, |
| 213 | + codebook, |
| 214 | + block_shape_to_group_size( |
| 215 | + config.scale_block_shape, (num_embeddings, embedding_dim) |
| 216 | + ) |
| 217 | + if config.scale_block_shape |
| 218 | + else -1, |
| 219 | + block_shape_to_group_size( |
| 220 | + config.lut_block_shape, (num_embeddings, embedding_dim) |
| 221 | + ), |
| 222 | + scales, |
| 223 | + ) |
| 224 | + |
| 225 | + # Create and populate the new quantized module |
| 226 | + quantized_module = cls(config, num_embeddings, embedding_dim) |
| 227 | + quantized_module.register_buffer("packed_weights", packed_weights) |
| 228 | + |
| 229 | + return quantized_module |
| 230 | + |
| 231 | + def forward(self, indices: torch.Tensor) -> torch.Tensor: |
| 232 | + """ |
| 233 | + Performs the embedding lookup using the packed weights. |
| 234 | + """ |
| 235 | + # The forward pass logic remains the same. |
| 236 | + forward_op = getattr(torch.ops.torchao, f"_embedding_lut_{self.bit_width}bit") |
| 237 | + |
| 238 | + # The C++ operator reads all metadata from the packed_weights header |
| 239 | + result = forward_op( |
| 240 | + self.packed_weights, |
| 241 | + indices.reshape(-1), |
| 242 | + self.num_embeddings, |
| 243 | + self.embedding_dim, |
| 244 | + block_shape_to_group_size( |
| 245 | + self.config.scale_block_shape, (self.num_embeddings, self.embedding_dim) |
| 246 | + ) |
| 247 | + if self.config.scale_block_shape |
| 248 | + else -1, |
| 249 | + block_shape_to_group_size( |
| 250 | + self.config.lut_block_shape, (self.num_embeddings, self.embedding_dim) |
| 251 | + ), |
| 252 | + self.config.has_scale, |
| 253 | + ) |
| 254 | + return result.reshape(*indices.shape, self.embedding_dim).to( |
| 255 | + self.config.weight_dtype |
| 256 | + ) |
| 257 | + |
| 258 | + def __repr__(self): |
| 259 | + return ( |
| 260 | + f"QuantizedLutEmbedding(num_embeddings={self.num_embeddings}, " |
| 261 | + f"embedding_dim={self.embedding_dim}, bit_width={self.bit_width}, " |
| 262 | + f"lut_block_shape={self.config.lut_block_shape})" |
| 263 | + ) |
| 264 | + |
| 265 | + |
| 266 | +class EmbeddingLutQuantizer: |
| 267 | + """ |
| 268 | + A quantizer that finds nn.Embedding modules in a model and replaces |
| 269 | + them with the QuantizedLutEmbedding module based on a provided configuration. |
| 270 | + """ |
| 271 | + |
| 272 | + def __init__(self, config: GroupwiseLutWeightConfig): |
| 273 | + """ |
| 274 | + Initializes the quantizer with a single, comprehensive configuration object. |
| 275 | +
|
| 276 | + Args: |
| 277 | + config (GroupwiseLutWeightConfig): The configuration that defines |
| 278 | + how all embeddings should be quantized. |
| 279 | + """ |
| 280 | + # The quantizer now holds the entire configuration object. |
| 281 | + self.config = config |
| 282 | + |
| 283 | + def quantize(self, model: nn.Module) -> nn.Module: |
| 284 | + """ |
| 285 | + Recursively traverses the model and replaces all nn.Embedding layers. |
| 286 | +
|
| 287 | + Args: |
| 288 | + model (nn.Module): The model to be quantized. |
| 289 | +
|
| 290 | + Returns: |
| 291 | + nn.Module: The model with embedding layers replaced. |
| 292 | + """ |
| 293 | + self._replace_embedding(model) |
| 294 | + return model |
| 295 | + |
| 296 | + def _replace_embedding(self, module: nn.Module): |
| 297 | + for name, child in module.named_children(): |
| 298 | + if isinstance(child, nn.Embedding): |
| 299 | + if self.config.use_qdq_reference: |
| 300 | + weight = child.weight.data |
| 301 | + |
| 302 | + # 1. Run the full quantize -> dequantize pipeline in Python |
| 303 | + quantized_tensor = CodebookQuantizedTensor.from_float( |
| 304 | + weight, |
| 305 | + code_dtype=self.config.code_dtype, |
| 306 | + block_size=self.config.lut_block_shape, |
| 307 | + ) |
| 308 | + ref_weight = quantized_tensor.dequantize(self.config.weight_dtype) |
| 309 | + |
| 310 | + # 2. Create a standard nn.Embedding with the dequantized weight |
| 311 | + ref_embedding = nn.Embedding.from_pretrained( |
| 312 | + ref_weight, freeze=True |
| 313 | + ) |
| 314 | + setattr(module, name, ref_embedding) |
| 315 | + |
| 316 | + else: |
| 317 | + q_embedding = QuantizedLutEmbedding.from_float(child, self.config) |
| 318 | + setattr(module, name, q_embedding) |
| 319 | + else: |
| 320 | + self._replace_embedding(child) |
0 commit comments