Skip to content

Commit 3b022db

Browse files
szyszyzysliangel-02
authored and committed
Add test for lut based embedding quantization.
Differential Revision: D79750022 Pull Request resolved: #2825
1 parent 959c29b commit 3b022db

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
import unittest
9+
10+
import torch
11+
import torch.nn as nn
12+
from parameterized import param, parameterized
13+
from torch import uint1, uint2, uint3, uint4
14+
15+
from torchao.prototype.quantization.codebook_groupwise.api import (
16+
EmbeddingLutQuantizer,
17+
GroupwiseLutWeightConfig,
18+
)
19+
20+
21+
def generate_test_cases():
    """Build the parameter grid for the LUT embedding quantizer test.

    Each code dtype (uint1 through uint4) is combined with each LUT block
    shape. Every config is created without scales (``has_scale=False``,
    ``scale_block_shape=None``), and every case uses an embedding table of
    128 rows by 256 columns.

    Returns:
        A list of ``param`` objects carrying the keyword arguments
        ``config``, ``embedding_dim`` and ``num_embeddings``.
    """
    bit_widths = (uint1, uint2, uint3, uint4)
    block_shapes = ([1, -1], [2, -1], [4, -1])

    # The dtype loop is outermost, so cases are ordered dtype-major.
    return [
        param(
            config=GroupwiseLutWeightConfig(
                code_dtype=dtype,
                lut_block_shape=shape,
                scale_block_shape=None,
                has_scale=False,
            ),
            embedding_dim=256,
            num_embeddings=128,
        )
        for dtype in bit_widths
        for shape in block_shapes
    ]
44+
45+
46+
class TestLutEmbeddingQuantizer(unittest.TestCase):
    """Compares the LUT embedding quantizer against its QDQ reference path."""

    @staticmethod
    def _quantize_and_run(model, config, indices, use_qdq_reference):
        """Quantize a copy of ``model`` and return its output for ``indices``.

        Both ``model`` and ``config`` are deep-copied first, so the caller's
        objects are left untouched and can be reused for the other path.

        Args:
            model: Float module containing the embedding to quantize.
            config: Base ``GroupwiseLutWeightConfig`` for the quantizer.
            indices: Integer tensor of embedding indices to look up.
            use_qdq_reference: Value assigned to ``config.use_qdq_reference``
                on the copied config before quantizing.

        Returns:
            The quantized model's output, computed under ``torch.no_grad()``.
        """
        quantized_model = copy.deepcopy(model)
        run_config = copy.deepcopy(config)
        run_config.use_qdq_reference = use_qdq_reference

        EmbeddingLutQuantizer(run_config).quantize(quantized_model)

        with torch.no_grad():
            return quantized_model(indices)

    @parameterized.expand(generate_test_cases())
    def test_accuracy_vs_qdq_reference(
        self,
        config: GroupwiseLutWeightConfig,
        embedding_dim: int,
        num_embeddings: int = 128,
    ):
        """
        Tests the numerical accuracy of the custom quantized embedding module
        against a QDQ (Quantize-Dequantize) reference implementation.

        The same float model and the same indices are fed through both
        paths; only the ``use_qdq_reference`` flag on the config differs.
        """
        # Seed the global RNG so the embedding weights and the lookup indices
        # are identical on every run, making any failure reproducible.
        torch.manual_seed(0)

        model = nn.Sequential(nn.Embedding(num_embeddings, embedding_dim))
        indices = torch.randint(0, num_embeddings, (10, 20), dtype=torch.int64)

        # --- 1. ACTUAL result: use_qdq_reference=False (performance path) ---
        actual_result = self._quantize_and_run(
            model, config, indices, use_qdq_reference=False
        )

        # --- 2. EXPECTED result: use_qdq_reference=True (reference path) ---
        expected_result = self._quantize_and_run(
            model, config, indices, use_qdq_reference=True
        )

        # --- 3. Compare results ---
        # torch.allclose broadcasts, so check the shapes explicitly first.
        self.assertEqual(actual_result.shape, expected_result.shape)
        if not torch.allclose(
            actual_result, expected_result, atol=1e-6, rtol=1e-5
        ):
            max_abs_diff = (actual_result - expected_result).abs().max().item()
            self.fail(
                "Quantized embedding output differs from the QDQ reference "
                f"(max abs diff = {max_abs_diff:.3e}, atol=1e-6, rtol=1e-5)"
            )
90+
91+
92+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)