Commit 28f38c4

szyszyzys authored and facebook-github-bot committed

Add the ops for groupwise LUT quantization for embedding (#2823)

Summary: Pull Request resolved: #2823
Reviewed By: metascroy
Differential Revision: D79749992
1 parent bdbdc5e commit 28f38c4

6 files changed: +394 −1 lines changed

torchao/experimental/CMakeLists.txt

Lines changed: 3 additions & 1 deletion

```diff
@@ -134,6 +134,7 @@ if(TORCHAO_BUILD_ATEN_OPS)
     ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
     ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp
     ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp
+    ops/embedding_lut/op_embedding_groupwise_lowbit_lut_aten.cpp
   )
   list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")

@@ -194,7 +195,8 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
     ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
     ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp
     ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp
-    ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp)
+    ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp
+    ops/embedding_lut/op_embedding_groupwise_lowbit_lut_executorch.cpp)

   list(TRANSFORM _torchao_op_srcs_executorch PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")
   add_library(torchao_ops_executorch STATIC ${_torchao_op_srcs_executorch})
```
torchao/experimental/ops/embedding_lut/op_embedding_groupwise_lowbit_lut-impl.h

Lines changed: 241 additions & 0 deletions

```cpp
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#if defined(TORCHAO_BUILD_CPU_AARCH64)
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding_lut.h>
#endif // TORCHAO_BUILD_CPU_AARCH64

#include <torchao/experimental/ops/embedding_lut/packed_weights_header.h>
#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/parallel.h>

template <int weight_nbit>
void check_embedding_lut_inputs(
    const Tensor& packed_weight_indices,
    const Tensor& indices,
    int64_t num_embeddings,
    int64_t embedding_dim,
    int64_t scale_group_size,
    int64_t lut_group_size,
    bool has_scales) {
  // Check packed weights header
  TORCHAO_CHECK(
      packed_weight_indices.dim() == 1, "packed_weight_indices must be 1D");
#ifdef USE_ATEN
  TORCHAO_CHECK(
      packed_weight_indices.dtype() == torch::kInt8,
      "packed_weight_indices must be byte");
#endif // USE_ATEN
  TORCHAO_CHECK(
      packed_weight_indices.size(0) >=
          torchao::ops::PackedWeightsHeader::size(),
      "packed_weight_indices is not large enough to contain a header");

  // Check indices tensor
  TORCHAO_CHECK(indices.dim() == 1, "indices must be 1D");
  TORCHAO_CHECK(
      (indices.dtype() == Tensor_dtype_kInt32) ||
          (indices.dtype() == Tensor_dtype_kInt64),
      "indices must be int32 or int64");

  // Check header
  auto header = torchao::ops::PackedWeightsHeader::read(
      packed_weight_indices.const_data_ptr());
  TORCHAO_CHECK(
      header ==
          torchao::ops::embedding_lut::get_packed_weights_header(
              /*version=*/1,
              weight_nbit,
              num_embeddings,
              embedding_dim,
              scale_group_size,
              lut_group_size,
              has_scales),
      "packed_weights are not compatible with the kernel");
}

#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
template <int weight_nbit>
Tensor embedding_out_cpu(
    const Tensor& packed_weights,
    const Tensor& indices,
    int64_t num_embeddings,
    int64_t embedding_dim,
    int64_t scale_group_size,
    int64_t lut_group_size,
    bool has_scales,
    Tensor& out) {
  check_embedding_lut_inputs<weight_nbit>(
      packed_weights,
      indices,
      num_embeddings,
      embedding_dim,
      scale_group_size,
      lut_group_size,
      has_scales);

  const int num_out = indices.size(0);
  TORCHAO_RESIZE_TENSOR(out, {(int)num_out, (int)embedding_dim});

  const int32_t* index32_ptr = nullptr;
  const int64_t* index64_ptr = nullptr;
  if (indices.dtype() == Tensor_dtype_kInt32) {
    index32_ptr = indices.const_data_ptr<int32_t>();
  } else {
    index64_ptr = indices.const_data_ptr<int64_t>();
  }

  // The actual packed data starts after the header
  const void* packed_data_ptr = packed_weights.const_data_ptr<int8_t>() +
      torchao::ops::PackedWeightsHeader::size();

  torchao::parallel_1d(0, num_out, [&](int64_t idx) {
    int index = (index32_ptr != nullptr) ? index32_ptr[idx] : index64_ptr[idx];
    TORCHAO_CHECK(index >= 0 && index < num_embeddings, "Index out of bounds");

#if defined(TORCHAO_BUILD_CPU_AARCH64)
    torchao::kernels::cpu::aarch64::embedding::
        dequantize_embedding_row_at_idx_lut<weight_nbit>(
            out.mutable_data_ptr<float>() + idx * embedding_dim,
            packed_data_ptr,
            index,
            num_embeddings,
            embedding_dim,
            scale_group_size,
            lut_group_size,
            has_scales);
#else
    TORCHAO_CHECK(false, "Unsupported platform for embedding_lut kernel");
#endif // TORCHAO_BUILD_CPU_AARCH64
  });

  return out;
}
#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)

#ifdef USE_ATEN
template <int weight_nbit>
Tensor embedding_cpu(
    const Tensor& packed_weights,
    const Tensor& indices,
    int64_t num_embeddings,
    int64_t embedding_dim,
    int64_t scale_group_size,
    int64_t lut_group_size,
    bool has_scales) {
  Tensor output_tensor = torch::empty({0}, torch::kFloat32);
  embedding_out_cpu<weight_nbit>(
      packed_weights,
      indices,
      num_embeddings,
      embedding_dim,
      scale_group_size,
      lut_group_size,
      has_scales,
      output_tensor);
  return output_tensor;
}

template <int weight_nbit>
Tensor pack_embedding_cpu(
    const Tensor& weight_qval_idxs,
    const Tensor& luts,
    int64_t scale_group_size,
    int64_t lut_group_size,
    const std::optional<Tensor>& weight_scales) {
  const bool has_scales = weight_scales.has_value();
  TORCHAO_CHECK(weight_qval_idxs.dim() == 2, "weight_qval_idxs must be 2D");
  const int64_t num_embeddings = weight_qval_idxs.size(0);
  const int64_t embedding_dim = weight_qval_idxs.size(1);

  TORCHAO_CHECK(
      (embedding_dim * weight_nbit) % 8 == 0,
      "Total bits must be a multiple of 8.");

  const size_t packed_embedding_size =
      torchao::kernels::cpu::aarch64::embedding::packed_embedding_size(
          weight_nbit,
          num_embeddings,
          embedding_dim,
          scale_group_size,
          lut_group_size,
          has_scales);
  const size_t total_packed_size =
      torchao::ops::PackedWeightsHeader::size() + packed_embedding_size;

  // Allocate and pack
  auto out = torch::empty({(long)total_packed_size}, torch::kInt8);

  // Write header
  auto header = torchao::ops::embedding_lut::get_packed_weights_header(
      /*version=*/1,
      weight_nbit,
      num_embeddings,
      embedding_dim,
      scale_group_size,
      lut_group_size,
      has_scales);
  header.write(out.mutable_data_ptr());

  void* packed_table_ptr = out.mutable_data_ptr<int8_t>() +
      torchao::ops::PackedWeightsHeader::size();

  // Pack each row
  torchao::parallel_1d(0, num_embeddings, [&](int64_t i) {
#if defined(TORCHAO_BUILD_CPU_AARCH64)
    torchao::kernels::cpu::aarch64::embedding::pack_embedding_row_at_index_lut<
        weight_nbit>(
        packed_table_ptr,
        i,
        weight_qval_idxs.const_data_ptr<uint8_t>(),
        has_scales ? weight_scales->const_data_ptr<float>() : nullptr,
        luts.const_data_ptr<float>(),
        num_embeddings,
        embedding_dim,
        scale_group_size,
        lut_group_size,
        has_scales);
#else
    TORCHAO_CHECK(false, "Unsupported platform for pack_embedding kernel");
#endif // defined(TORCHAO_BUILD_CPU_AARCH64)
  });

  return out;
}

template <int weight_nbit>
Tensor pack_embedding_meta(
    const Tensor& weight_qval_idxs,
    const Tensor& luts,
    int64_t scale_group_size,
    int64_t lut_group_size,
    const std::optional<Tensor>& weight_scales) {
  const int64_t num_embeddings = weight_qval_idxs.size(0);
  const int64_t embedding_dim = weight_qval_idxs.size(1);
  const bool has_scales = weight_scales.has_value();

  TORCHAO_CHECK(
      (embedding_dim * weight_nbit) % 8 == 0,
      "Total bits must be a multiple of 8 for meta function.");

  const size_t packed_embedding_size =
      torchao::kernels::cpu::aarch64::embedding::packed_embedding_size(
          weight_nbit,
          num_embeddings,
          embedding_dim,
          scale_group_size,
          lut_group_size,
          has_scales);
  const size_t total_packed_size =
      torchao::ops::PackedWeightsHeader::size() + packed_embedding_size;

  auto options =
      torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8);
  return torch::empty({(long)total_packed_size}, options);
}
#endif // USE_ATEN
```
torchao/experimental/ops/embedding_lut/op_embedding_groupwise_lowbit_lut_aten.cpp

Lines changed: 71 additions & 0 deletions

```cpp
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torchao/experimental/ops/embedding_lut/op_embedding_groupwise_lowbit_lut-impl.h>

// This macro defines the operator signatures.
// The signatures match the C++ implementations.
#define DEFINE_LUT_OP(weight_nbit)                                           \
  m.def(                                                                     \
      "_pack_embedding_lut_" #weight_nbit                                    \
      "bit(Tensor weight_qval_idxs, Tensor luts, int scale_group_size, "     \
      "int lut_group_size, Tensor? weight_scales) -> Tensor");               \
  m.def(                                                                     \
      "_embedding_lut_" #weight_nbit                                         \
      "bit(Tensor packed_weights, Tensor indices, int num_embeddings, "      \
      "int embedding_dim, int scale_group_size, int lut_group_size, "        \
      "bool has_scales) -> Tensor");                                         \
  m.def(                                                                     \
      "_embedding_lut_" #weight_nbit                                         \
      "bit.out(Tensor packed_weights, Tensor indices, int num_embeddings, "  \
      "int embedding_dim, int scale_group_size, int lut_group_size, "        \
      "bool has_scales, *, Tensor(a!) out) -> Tensor(a!)");

// This macro registers the CPU implementations for the LUT-based operators.
#define DEFINE_CPU_IMPL(weight_nbit)                                 \
  m.impl(                                                            \
      "_pack_embedding_lut_" #weight_nbit "bit",                     \
      torch::dispatch(                                               \
          c10::DispatchKey::CPU, &pack_embedding_cpu<weight_nbit>)); \
  m.impl(                                                            \
      "_embedding_lut_" #weight_nbit "bit",                          \
      torch::dispatch(                                               \
          c10::DispatchKey::CPU, &embedding_cpu<weight_nbit>));      \
  m.impl(                                                            \
      "_embedding_lut_" #weight_nbit "bit.out",                      \
      torch::dispatch(                                               \
          c10::DispatchKey::CPU, &embedding_out_cpu<weight_nbit>));

// This macro registers the Meta (device-agnostic) implementation for packing.
#define DEFINE_META_IMPL(weight_nbit)                                 \
  m.impl(                                                             \
      "_pack_embedding_lut_" #weight_nbit "bit",                      \
      torch::dispatch(                                                \
          c10::DispatchKey::Meta, &pack_embedding_meta<weight_nbit>));

// Operator definitions
TORCH_LIBRARY_FRAGMENT(torchao, m) {
  DEFINE_LUT_OP(1);
  DEFINE_LUT_OP(2);
  DEFINE_LUT_OP(3);
  DEFINE_LUT_OP(4);
}

// CPU implementations
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  DEFINE_CPU_IMPL(1);
  DEFINE_CPU_IMPL(2);
  DEFINE_CPU_IMPL(3);
  DEFINE_CPU_IMPL(4);
}

// Meta implementations
TORCH_LIBRARY_IMPL(torchao, Meta, m) {
  DEFINE_META_IMPL(1);
  DEFINE_META_IMPL(2);
  DEFINE_META_IMPL(3);
  DEFINE_META_IMPL(4);
}
```
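With the ATen library built and loaded, these ops are callable from Python under torch.ops.torchao. The sketch below exercises the 4-bit pack and lookup ops using the schemas registered above; the shapes chosen for luts and weight_scales (one 2**nbit-entry table per lut_group_size weights, one scale per scale_group_size weights) and the library path are assumptions, since the diff defines the schemas but not the expected tensor shapes.

```python
import torch

# Assumes torchao was built with TORCHAO_BUILD_ATEN_OPS (and, for the kernel
# path, TORCHAO_BUILD_CPU_AARCH64) and the shared library has been loaded:
# torch.ops.load_library("libtorchao_ops_aten.so")  # path is an assumption

num_embeddings, embedding_dim, nbit = 128, 64, 4
scale_group_size, lut_group_size = 32, 64  # illustrative group sizes

# Assumed shapes: one (2**nbit)-entry LUT per lut_group_size weights and one
# scale per scale_group_size weights, counted over the flattened table.
qidx = torch.randint(0, 2**nbit, (num_embeddings, embedding_dim), dtype=torch.uint8)
luts = torch.randn(num_embeddings * embedding_dim // lut_group_size, 2**nbit)
scales = torch.rand(num_embeddings * embedding_dim // scale_group_size)

packed = torch.ops.torchao._pack_embedding_lut_4bit(
    qidx, luts, scale_group_size, lut_group_size, scales)

indices = torch.tensor([0, 5, 17], dtype=torch.int32)
rows = torch.ops.torchao._embedding_lut_4bit(
    packed, indices, num_embeddings, embedding_dim,
    scale_group_size, lut_group_size, True)  # has_scales
print(rows.shape)  # torch.Size([3, 64])
```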
torchao/experimental/ops/embedding_lut/op_embedding_groupwise_lowbit_lut_executorch.cpp

Lines changed: 44 additions & 0 deletions

```cpp
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torchao/experimental/ops/embedding_lut/op_embedding_groupwise_lowbit_lut-impl.h>

#define DEFINE_LUT_OP(weight_nbit)                \
  Tensor _op_lut_out_##weight_nbit(               \
      RuntimeContext& ctx,                        \
      const Tensor& packed_weights,               \
      const Tensor& indices,                      \
      const int64_t& num_embeddings,              \
      const int64_t& embedding_dim,               \
      const int64_t& scale_group_size,            \
      const int64_t& lut_group_size,              \
      const bool& has_scales,                     \
      Tensor& out) {                              \
    (void)ctx;                                    \
    embedding_out_cpu<weight_nbit>(               \
        packed_weights,                           \
        indices,                                  \
        num_embeddings,                           \
        embedding_dim,                            \
        scale_group_size,                         \
        lut_group_size,                           \
        has_scales,                               \
        out);                                     \
    return out;                                   \
  }                                               \
  EXECUTORCH_LIBRARY(                             \
      torchao,                                    \
      "_embedding_lut_" #weight_nbit "bit.out",   \
      _op_lut_out_##weight_nbit)

DEFINE_LUT_OP(1);
DEFINE_LUT_OP(2);
DEFINE_LUT_OP(3);
DEFINE_LUT_OP(4);
DEFINE_LUT_OP(5);
DEFINE_LUT_OP(6);
DEFINE_LUT_OP(7);
DEFINE_LUT_OP(8);
```
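The ExecuTorch side registers only the .out variants (for 1 through 8 bits), each a thin wrapper over the same embedding_out_cpu used by ATen. Through the ATen registrations above (which cover 1 through 4 bits), the analogous out-variant call would look like this sketch, continuing the previous example:

```python
# Continues the previous sketch. The .out variant writes into a caller-
# provided tensor, which embedding_out_cpu resizes to (len(indices), dim).
out = torch.empty(0, dtype=torch.float32)
torch.ops.torchao._embedding_lut_4bit.out(
    packed, indices, num_embeddings, embedding_dim,
    scale_group_size, lut_group_size, True, out=out)
assert torch.equal(out, rows)
```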
