Commit 5246168

Group-quantized embedding op
Differential Revision: D84020397
Pull Request resolved: #14835
1 parent: ec56cfa

4 files changed: +173 −2 lines changed


backends/cadence/aot/functions.yaml

Lines changed: 5 additions & 0 deletions
```diff
@@ -468,3 +468,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::generic::requantize_per_tensor_out
+
+- func: cadence::quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, Tensor indices, bool pruned_weights, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantized_embedding_byte_out
```
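The `*, Tensor(a!) out` in this schema is the standard out-variant convention: a keyword-only output tensor that is mutated in place and also returned, which this YAML entry binds to the generic C++ kernel. A minimal, self-contained sketch of the same schema convention using `torch.library` (the `demo_ns`/`copy_rows` names are illustrative, not part of this commit):

```python
import torch

# Illustrative only: a tiny custom library whose out-variant schema uses the
# same "*, Tensor(a!) out" convention as the cadence entry above.
lib = torch.library.Library("demo_ns", "DEF")
lib.define(
    "copy_rows.out(Tensor src, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"
)

def copy_rows_out(src: torch.Tensor, indices: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Write the gathered rows into the preallocated buffer and return it.
    out.copy_(src[indices])
    return out

lib.impl("copy_rows.out", copy_rows_out, "CPU")

src = torch.arange(6.0).view(3, 2)
buf = torch.empty(2, 2)
torch.ops.demo_ns.copy_rows.out(src, torch.tensor([2, 0]), out=buf)
```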

backends/cadence/aot/ops_registrations.py

Lines changed: 24 additions & 2 deletions
```diff
@@ -320,7 +320,7 @@
     "float out_scale, int out_zero_point) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
+    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "Tensor indices, bool pruned_weights=False) -> (Tensor X)"
 )
 lib.define(
@@ -514,7 +514,7 @@
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
-    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
+    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
@@ -2310,6 +2310,28 @@ def transposed_im2row_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_embedding_byte")
+def quantized_embedding_byte_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: torch.Tensor | None,
+    indices: torch.Tensor,
+    pruned_weights: bool = False,
+) -> torch.Tensor:
+    assert not pruned_weights
+    assert len(weight.shape) == 2
+    assert 1 <= len(weight_scales.shape) <= 2
+    if len(weight_scales.shape) == 2:
+        num_groups = weight_scales.shape[-1]
+        assert weight.shape[1] % num_groups == 0
+
+    if weight_zero_points is not None:
+        assert weight_zero_points.shape == weight_scales.shape
+
+    assert 1 <= len(indices.shape) <= 2
+    return torch.empty(*indices.shape, weight.shape[1], dtype=torch.float32)
+
+
 @register_fake("cadence::where_Scalar")
 def where_Scalar_meta(
     condition: torch.Tensor,
```
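With the fake (meta) kernel registered, the op participates in shape propagation without touching real data: the output is always `(*indices.shape, weight.shape[1])` in float32. A small sketch of that contract under `FakeTensorMode` (assumes these cadence registrations have been imported, e.g. via `executorch.backends.cadence.aot.ops_registrations`):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Shape-only check: no real kernel runs under FakeTensorMode.
with FakeTensorMode():
    weight = torch.empty(10, 8, dtype=torch.int8)      # embedding table
    scales = torch.empty(10, 2, dtype=torch.float32)   # 2 groups per row (8 % 2 == 0)
    indices = torch.empty(4, 3, dtype=torch.int64)     # batched lookup
    out = torch.ops.cadence.quantized_embedding_byte(weight, scales, None, indices)
    assert out.shape == (4, 3, 8) and out.dtype == torch.float32
```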

backends/cadence/aot/ref_implementations.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -1572,3 +1572,34 @@ def transposed_im2row(
     # Optionally, flatten to (N, num_patches, patch_size) if needed
     patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
     return patches
+
+
+@impl(m, "quantized_embedding_byte")
+def quantized_embedding_byte(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: torch.Tensor | None,
+    indices: torch.Tensor,
+    pruned_weights: bool = False,
+) -> torch.Tensor:
+    if pruned_weights:
+        raise NotImplementedError("Pruned weights not supported")
+
+    # Cannot use torch.ops.quantized_decomposed.embedding_byte.dtype because
+    # it doesn't support num_groups == 1
+    num_groups = 1
+    if len(weight_scales.shape) == 2:
+        num_groups = weight_scales.shape[1]
+
+    group_size = weight.shape[1] // num_groups
+    weight = torch.ops.torchao.dequantize_affine.default(
+        input=weight,
+        block_size=(1, group_size),
+        scale=weight_scales,
+        zero_point=weight_zero_points,
+        input_dtype=weight.dtype,
+        quant_min=torch.iinfo(weight.dtype).min,
+        quant_max=torch.iinfo(weight.dtype).max,
+    )
+
+    return weight[indices]
```
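The dequantization here is group-wise affine: each row of `weight` is split into `num_groups` contiguous blocks of `group_size` columns, and a quantized value `q` in group `g` of row `r` maps to `(q - zero_point[r, g]) * scale[r, g]`; the embedding lookup is then a plain row gather. For example, in the `2_groups` test case below, row 2 group 0 gives `(8 - 4) * 2.5 = 10.0`. A minimal pure-PyTorch sketch of the same arithmetic (an illustration of the semantics, not the torchao kernel the reference delegates to):

```python
import torch

def group_dequant_embedding(
    weight: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor | None,
    indices: torch.Tensor,
) -> torch.Tensor:
    # Split each row into groups, apply per-group affine dequantization,
    # then gather rows by index.
    rows, cols = weight.shape
    num_groups = scales.shape[1] if scales.dim() == 2 else 1
    group_size = cols // num_groups
    q = weight.to(torch.float32).view(rows, num_groups, group_size)
    s = scales.to(torch.float32).view(rows, num_groups, 1)
    zp = (
        torch.zeros_like(s)
        if zero_points is None
        else zero_points.to(torch.float32).view(rows, num_groups, 1)
    )
    return ((q - zp) * s).view(rows, cols)[indices]

# Matches the "2_groups" test case below.
w = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8)
s = torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]])
zp = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8)
print(group_dequant_embedding(w, s, zp, torch.tensor([0, 2, 1])))
# tensor([[ 0.0000,  0.5000,  1.0000,  2.0000],
#         [10.0000, 12.5000, 15.0000, 18.0000],
#         [ 3.0000,  4.5000,  6.0000,  8.0000]])
```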

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 113 additions & 0 deletions
```diff
@@ -2306,3 +2306,116 @@ def test_transposed_im2row(
             torch.equal(output, expected_output),
             f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            (
+                "1_group",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                torch.tensor([0, 0, 0], dtype=torch.int8),
+                torch.tensor([0, 2, 1], dtype=torch.int64),
+                torch.tensor(
+                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "2_groups",
+                torch.tensor(
+                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
+                ),
+                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
+                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
+                torch.tensor([0, 2, 1], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [0.0, 0.5, 1.0, 2.0],
+                        [10.0, 12.5, 15.0, 18.0],
+                        [3.0, 4.5, 6.0, 8.0],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "1_group_none_zero_point",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                None,
+                torch.tensor([0, 2, 1], dtype=torch.int64),
+                torch.tensor(
+                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "1_group_batch2",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                torch.tensor([0, 0, 0], dtype=torch.int8),
+                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "2_groups_batch2",
+                torch.tensor(
+                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
+                ),
+                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
+                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
+                torch.tensor([[0, 2, 1], [2, 1, 0]], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [
+                            [0.0, 0.5, 1.0, 2.0],
+                            [10.0, 12.5, 15.0, 18.0],
+                            [3.0, 4.5, 6.0, 8.0],
+                        ],
+                        [
+                            [10.0, 12.5, 15.0, 18.0],
+                            [3.0, 4.5, 6.0, 8.0],
+                            [0.0, 0.5, 1.0, 2.0],
+                        ],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+            (
+                "1_group_none_zero_point_batch2",
+                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
+                torch.tensor([1, 1, 1], dtype=torch.float32),
+                None,
+                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
+                torch.tensor(
+                    [
+                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
+                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
+                    ],
+                    dtype=torch.float32,
+                ),
+            ),
+        ]
+    )
+    def test_quantized_embedding_byte(
+        self,
+        name: str,
+        weight: torch.Tensor,
+        weight_scales: torch.Tensor,
+        weight_zero_points: torch.Tensor | None,
+        indices: torch.Tensor,
+        expected_out: torch.Tensor,
+    ) -> None:
+        self.assertTrue(
+            torch.equal(
+                torch.ops.cadence.quantized_embedding_byte(
+                    weight, weight_scales, weight_zero_points, indices
+                ),
+                expected_out,
+            )
+        )
```
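Each tuple supplies the case name, quantized weights, scales, optional zero points, indices, and the expected dequantized output. These cases can be run in isolation with a test-name filter, e.g. `python -m pytest backends/cadence/aot/tests/test_ref_implementations.py -k quantized_embedding_byte` (the exact runner invocation depends on your environment).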
