
Commit f64c864

DrJessop and Chuck Gillen-O'Neel authored
Revert D84020397: Group-quantized embedding op (#14915)
Summary: Revert D84020397: [Cadence ops] Group-quantized embedding op

Differential Revision: D84186522

Co-authored-by: Chuck Gillen-O'Neel <[email protected]>
1 parent 0142a1a commit f64c864

File tree

4 files changed: +2 −173 lines changed


backends/cadence/aot/functions.yaml

Lines changed: 0 additions & 5 deletions

@@ -468,8 +468,3 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::generic::requantize_per_tensor_out
-
-- func: cadence::quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, Tensor indices, bool pruned_weights, *, Tensor(a!) out) -> Tensor(a!)
-  kernels:
-    - arg_meta: null
-      kernel_name: impl::generic::quantized_embedding_byte_out

backends/cadence/aot/ops_registrations.py

Lines changed: 2 additions & 24 deletions
@@ -320,7 +320,7 @@
     "float out_scale, int out_zero_point) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False) -> (Tensor X)"
 )
 lib.define(

@@ -514,7 +514,7 @@
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
-    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
 )

@@ -2310,28 +2310,6 @@ def transposed_im2row_meta(
     return input.new_empty(output_size, dtype=input.dtype)


-@register_fake("cadence::quantized_embedding_byte")
-def quantized_embedding_byte_meta(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: torch.Tensor | None,
-    indices: torch.Tensor,
-    pruned_weights: bool = False,
-) -> torch.Tensor:
-    assert not pruned_weights
-    assert len(weight.shape) == 2
-    assert 1 <= len(weight_scales.shape) <= 2
-    if len(weight_scales.shape) == 2:
-        num_groups = weight_scales.shape[-1]
-        assert weight.shape[1] % num_groups == 0
-
-    if weight_zero_points is not None:
-        assert weight_zero_points.shape == weight_scales.shape
-
-    assert 1 <= len(indices.shape) <= 2
-    return torch.empty(*indices.shape, weight.shape[1], dtype=torch.float32)
-
-
 @register_fake("cadence::where_Scalar")
 def where_Scalar_meta(
     condition: torch.Tensor,
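The fake (meta) kernel deleted above did no arithmetic; it only told the tracer the output shape: an empty float32 tensor of shape (*indices.shape, weight.shape[1]). A minimal illustration of that shape rule, with sizes invented for the example:

import torch

# Hypothetical sizes: a (10, 4) quantized table and a (2, 3) batch of lookups.
weight = torch.empty(10, 4, dtype=torch.int8)
indices = torch.zeros(2, 3, dtype=torch.int64)

# Same shape computation as the removed quantized_embedding_byte_meta.
out = torch.empty(*indices.shape, weight.shape[1], dtype=torch.float32)
assert out.shape == (2, 3, 4)  # (*indices.shape, embed_dim)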

backends/cadence/aot/ref_implementations.py

Lines changed: 0 additions & 31 deletions
@@ -1572,34 +1572,3 @@ def transposed_im2row(
     # Optionally, flatten to (N, num_patches, patch_size) if needed
     patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
     return patches
-
-
-@impl(m, "quantized_embedding_byte")
-def quantized_embedding_byte(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: torch.Tensor | None,
-    indices: torch.Tensor,
-    pruned_weights: bool = False,
-) -> torch.Tensor:
-    if pruned_weights:
-        raise NotImplementedError("Pruned weights not supported")
-
-    # Cannot use torch.ops.quantized_decomposed.embedding_byte.dtype because
-    # it doesn't support num_groups == 1
-    num_groups = 1
-    if len(weight_scales.shape) == 2:
-        num_groups = weight_scales.shape[1]
-
-    group_size = weight.shape[1] // num_groups
-    weight = torch.ops.torchao.dequantize_affine.default(
-        input=weight,
-        block_size=(1, group_size),
-        scale=weight_scales,
-        zero_point=weight_zero_points,
-        input_dtype=weight.dtype,
-        quant_min=torch.iinfo(weight.dtype).min,
-        quant_max=torch.iinfo(weight.dtype).max,
-    )
-
-    return weight[indices]
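The reverted reference implementation delegated the group-wise dequantization to torch.ops.torchao.dequantize_affine and then indexed the result. A rough plain-PyTorch sketch of the same computation (an illustration only; grouped_dequant_embedding is an invented name, not an op that existed in the tree):

import torch

def grouped_dequant_embedding(
    weight: torch.Tensor,                     # (num_rows, embed_dim), int8
    weight_scales: torch.Tensor,              # (num_rows,) or (num_rows, num_groups)
    weight_zero_points: torch.Tensor | None,  # same shape as weight_scales, or None
    indices: torch.Tensor,                    # int64 row indices
) -> torch.Tensor:
    num_rows, embed_dim = weight.shape
    # A 1-D scale tensor means a single group spanning each full row.
    num_groups = weight_scales.shape[1] if weight_scales.dim() == 2 else 1
    group_size = embed_dim // num_groups

    def per_column(params: torch.Tensor) -> torch.Tensor:
        # Broadcast each per-group value across that group's group_size columns.
        return (
            params.reshape(num_rows, num_groups, 1)
            .expand(num_rows, num_groups, group_size)
            .reshape(num_rows, embed_dim)
            .to(torch.float32)
        )

    scales = per_column(weight_scales)
    zero_points = (
        torch.zeros_like(scales)
        if weight_zero_points is None
        else per_column(weight_zero_points)
    )
    # Affine dequantization, then a plain embedding lookup.
    dequantized = (weight.to(torch.float32) - zero_points) * scales
    return dequantized[indices]

On the "2_groups" inputs from the deleted tests below, this sketch reproduces the expected float32 outputs.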

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 0 additions & 113 deletions
@@ -2306,116 +2306,3 @@ def test_transposed_im2row(
             torch.equal(output, expected_output),
             f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}",
         )
-
-    @expand(
-        [
-            (
-                "1_group",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                torch.tensor([0, 0, 0], dtype=torch.int8),
-                torch.tensor([0, 2, 1], dtype=torch.int64),
-                torch.tensor(
-                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "2_groups",
-                torch.tensor(
-                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
-                ),
-                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
-                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
-                torch.tensor([0, 2, 1], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [0.0, 0.5, 1.0, 2.0],
-                        [10.0, 12.5, 15.0, 18.0],
-                        [3.0, 4.5, 6.0, 8.0],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "1_group_none_zero_point",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                None,
-                torch.tensor([0, 2, 1], dtype=torch.int64),
-                torch.tensor(
-                    [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "1_group_batch2",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                torch.tensor([0, 0, 0], dtype=torch.int8),
-                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "2_groups_batch2",
-                torch.tensor(
-                    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8
-                ),
-                torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32),
-                torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8),
-                torch.tensor([[0, 2, 1], [2, 1, 0]], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [
-                            [0.0, 0.5, 1.0, 2.0],
-                            [10.0, 12.5, 15.0, 18.0],
-                            [3.0, 4.5, 6.0, 8.0],
-                        ],
-                        [
-                            [10.0, 12.5, 15.0, 18.0],
-                            [3.0, 4.5, 6.0, 8.0],
-                            [0.0, 0.5, 1.0, 2.0],
-                        ],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-            (
-                "1_group_none_zero_point_batch2",
-                torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8),
-                torch.tensor([1, 1, 1], dtype=torch.float32),
-                None,
-                torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64),
-                torch.tensor(
-                    [
-                        [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]],
-                        [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]],
-                    ],
-                    dtype=torch.float32,
-                ),
-            ),
-        ]
-    )
-    def test_quantized_embedding_byte(
-        self,
-        name: str,
-        weight: torch.Tensor,
-        weight_scales: torch.Tensor,
-        weight_zero_points: torch.Tensor | None,
-        indices: torch.Tensor,
-        expected_out: torch.Tensor,
-    ) -> None:
-        self.assertTrue(
-            torch.equal(
-                torch.ops.cadence.quantized_embedding_byte(
-                    weight, weight_scales, weight_zero_points, indices
-                ),
-                expected_out,
-            )
-        )
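As a sanity check on the deleted expectations, the "2_groups" arithmetic for weight row 1 ([4, 5, 6, 7], group_size = 4 // 2 = 2, per-group scales [1.5, 2.0], zero points [2, 3]) works out as (value - zero_point) * scale per column:

import torch

w_row = torch.tensor([4.0, 5.0, 6.0, 7.0])
scales = torch.tensor([1.5, 1.5, 2.0, 2.0])  # each group's scale covers 2 columns
zps = torch.tensor([2.0, 2.0, 3.0, 3.0])     # likewise for zero points
print((w_row - zps) * scales)                # tensor([3.0000, 4.5000, 6.0000, 8.0000])

which matches the expected row [3.0, 4.5, 6.0, 8.0] in the test case above.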

0 commit comments
