Commit 4f044b1

[Kernel][CPU] CPU MLA (#14744)
Signed-off-by: Thien Tran <[email protected]>
Parent: 4157f56

15 files changed: +1010 −17 lines

.buildkite/run-cpu-test.sh (+2)

@@ -38,6 +38,8 @@ function cpu_tests() {
     set -e
     pip install -r vllm/requirements/test.txt
     pip install -r vllm/requirements/cpu.txt
+    pytest -v -s tests/kernels/test_cache.py -m cpu_model
+    pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
     pytest -v -s tests/models/decoder_only/language -m cpu_model
     pytest -v -s tests/models/embedding/language -m cpu_model
     pytest -v -s tests/models/encoder_decoder/language -m cpu_model

cmake/cpu_extension.cmake (+1)

@@ -190,6 +190,7 @@ set(VLLM_EXT_SRC
   "csrc/cpu/cache.cpp"
   "csrc/cpu/utils.cpp"
   "csrc/cpu/layernorm.cpp"
+  "csrc/cpu/mla_decode.cpp"
   "csrc/cpu/pos_encoding.cpp"
   "csrc/cpu/torch_bindings.cpp")

csrc/cpu/cache.cpp (+74)

@@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
   }
 }; // namespace

+template <typename scalar_t>
+void concat_and_cache_mla_cpu_impl(
+    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
+    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
+    scalar_t* __restrict__ kv_cache,    // [num_blocks, block_size,
+                                        //  (kv_lora_rank + pe_dim)]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int num_tokens, const int block_stride, const int entry_stride,
+    const int kv_c_stride, const int k_pe_stride, const int kv_lora_rank,
+    const int pe_dim, const int block_size) {
+#pragma omp parallel for
+  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    const int64_t slot_idx = slot_mapping[token_idx];
+    // NOTE: slot_idx can be -1 if the token is padded
+    if (slot_idx < 0) {
+      continue;
+    }
+    const int64_t block_idx = slot_idx / block_size;
+    const int64_t block_offset = slot_idx % block_size;
+
+    auto copy = [&](const scalar_t* __restrict__ src,
+                    scalar_t* __restrict__ dst, int src_stride, int dst_stride,
+                    int size, int offset) {
+      for (int i = 0; i < size; i++) {
+        const int64_t src_idx = token_idx * src_stride + i;
+        const int64_t dst_idx =
+            block_idx * block_stride + block_offset * entry_stride + i + offset;
+        dst[dst_idx] = src[src_idx];
+      }
+    };
+
+    copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
+    copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
+  }
+}
+
 // Note: the key_caches and value_caches vectors are constant but
 // not the Tensors they contain. The vectors need to be const refs
 // in order to satisfy pytorch's C++ operator registration code.

@@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
       });
 }

+void concat_and_cache_mla(
+    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
+    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
+    torch::Tensor& kv_cache,      // [num_blocks, block_size,
+                                  //  (kv_lora_rank + pe_dim)]
+    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
+    const std::string& kv_cache_dtype, torch::Tensor& scale) {
+  int num_tokens = slot_mapping.size(0);
+  int kv_lora_rank = kv_c.size(1);
+  int pe_dim = k_pe.size(1);
+  int block_size = kv_cache.size(1);
+
+  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+  TORCH_CHECK(kv_cache_dtype != "fp8");
+
+  int kv_c_stride = kv_c.stride(0);
+  int k_pe_stride = k_pe.stride(0);
+  int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
+        CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
+        concat_and_cache_mla_cpu_impl<scalar_t>(
+            kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
+            kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
+            num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
+            kv_lora_rank, pe_dim, block_size);
+        CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
+      });
+}
+
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping) {
   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")

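The two copy calls define the MLA cache layout: each token's entry is its kv_lora_rank compressed-KV values followed by its pe_dim rotary-embedding values, placed at the block and offset derived from slot_mapping. A standalone sketch of that indexing follows; all sizes are made-up toy values, not taken from this commit:

// Toy illustration of the slot -> (block, offset) indexing and the
// [kv_c | k_pe] concatenation performed by concat_and_cache_mla_cpu_impl.
// All sizes below are hypothetical; real values come from the model config.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int kv_lora_rank = 4;  // compressed-KV width (toy value)
  const int pe_dim = 2;        // RoPE width (toy value)
  const int block_size = 2;    // tokens per cache block (toy value)
  const int num_blocks = 2;

  // Strides for a contiguous [num_blocks, block_size, kv_lora_rank + pe_dim]
  // cache; the real wrapper reads them from kv_cache.stride(0/1) instead.
  const int entry_stride = kv_lora_rank + pe_dim;
  const int block_stride = block_size * entry_stride;

  std::vector<float> kv_cache(num_blocks * block_stride, 0.0f);
  const float kv_c[] = {1, 2, 3, 4};  // one token's compressed KV
  const float k_pe[] = {9, 10};       // that token's positional part

  const int64_t slot_idx = 3;  // as read from slot_mapping[token_idx]
  const int64_t block_idx = slot_idx / block_size;     // -> block 1
  const int64_t block_offset = slot_idx % block_size;  // -> entry 1 in block

  float* dst =
      kv_cache.data() + block_idx * block_stride + block_offset * entry_stride;
  for (int i = 0; i < kv_lora_rank; i++) dst[i] = kv_c[i];           // offset 0
  for (int i = 0; i < pe_dim; i++) dst[kv_lora_rank + i] = k_pe[i];  // after kv_c

  for (int i = 0; i < entry_stride; i++) std::printf("%g ", dst[i]);
  std::printf("\n");  // prints "1 2 3 4 9 10": one contiguous cache entry
  return 0;
}

Storing each entry contiguously means the decode kernel can fetch a cached token with a single row read instead of gathering from separate K and V caches.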
csrc/cpu/cpu_types_x86.hpp (+2)

@@ -130,6 +130,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {

   __m512i reg;

+  explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}
+
   explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}

   explicit BF16Vec32(__m512i data) : reg(data) {}

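The new default constructor gives BF16Vec32 a zeroed register, so a kernel can declare an accumulator or output tile that starts at +0.0 without loading from memory; the new mla_decode.cpp presumably relies on this. A minimal check of what the constructor does, assuming an AVX-512F build (e.g. g++ -mavx512f); this is an illustration, not code from the commit:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  // Same initialization the new constructor performs: all 512 bits zeroed,
  // i.e. 32 bfloat16 lanes each holding +0.0 (bit pattern 0x0000).
  __m512i zero = _mm512_setzero_si512();

  uint16_t lanes[32];
  _mm512_storeu_si512(lanes, zero);
  std::printf("lane0 = 0x%04x\n", (unsigned)lanes[0]);  // lane0 = 0x0000
  return 0;
}

Because bfloat16 +0.0 is the all-zero bit pattern, _mm512_setzero_si512 doubles as a floating-point zero initializer, so a default-constructed BF16Vec32 can seed a reduction directly.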