@@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
}
}; // namespace

+template <typename scalar_t>
+void concat_and_cache_mla_cpu_impl(
+    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
+    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
+    scalar_t* __restrict__ kv_cache,    // [num_blocks, block_size, (kv_lora_rank
+                                        //  + pe_dim)]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int num_tokens,                      //
+    const int block_stride,                    //
+    const int entry_stride,                    //
+    const int kv_c_stride,                     //
+    const int k_pe_stride,                     //
+    const int kv_lora_rank,                    //
+    const int pe_dim,                          //
+    const int block_size                       //
+) {
+#pragma omp parallel for
+  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    const int64_t slot_idx = slot_mapping[token_idx];
+    // NOTE: slot_idx can be -1 if the token is padded
+    if (slot_idx < 0) {
+      continue;
+    }
+    const int64_t block_idx = slot_idx / block_size;
+    const int64_t block_offset = slot_idx % block_size;
+
+    auto copy = [&](const scalar_t* __restrict__ src,
+                    scalar_t* __restrict__ dst, int src_stride, int dst_stride,
+                    int size, int offset) {
+      for (int i = 0; i < size; i++) {
+        const int64_t src_idx = token_idx * src_stride + i;
+        const int64_t dst_idx =
+            block_idx * block_stride + block_offset * entry_stride + i + offset;
+        dst[dst_idx] = src[src_idx];
+      }
+    };
+
+    copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
+    copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
+  }
+}
+
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
@@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
  });
}

+void concat_and_cache_mla(
+    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
+    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
+    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
+                                  //  pe_dim)]
+    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
+    const std::string& kv_cache_dtype, torch::Tensor& scale) {
+  int num_tokens = slot_mapping.size(0);
+  int kv_lora_rank = kv_c.size(1);
+  int pe_dim = k_pe.size(1);
+  int block_size = kv_cache.size(1);
+
+  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+  TORCH_CHECK(kv_cache_dtype != "fp8");
+
+  int kv_c_stride = kv_c.stride(0);
+  int k_pe_stride = k_pe.stride(0);
+  int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
+        CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
+        concat_and_cache_mla_cpu_impl<scalar_t>(
+            kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
+            kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
+            num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
+            kv_lora_rank, pe_dim, block_size);
+        CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
+      });
+}
+
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
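For context, a minimal sketch of how the new CPU path could be exercised directly from C++. This is not part of the diff: the tensor shapes, the slot values, and the "auto" cache dtype are illustrative assumptions, and the declaration of concat_and_cache_mla is assumed to be available (here it is simply forward-declared).

#include <torch/torch.h>

// Assumed to match the signature added in this diff.
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                          torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
                          const std::string& kv_cache_dtype,
                          torch::Tensor& scale);

int main() {
  const int num_tokens = 4, kv_lora_rank = 512, pe_dim = 64;
  const int num_blocks = 2, block_size = 16;

  auto kv_c = torch::randn({num_tokens, kv_lora_rank});
  auto k_pe = torch::randn({num_tokens, pe_dim});
  auto kv_cache =
      torch::zeros({num_blocks, block_size, kv_lora_rank + pe_dim});
  // One cache slot per token; -1 marks a padded token that the kernel skips.
  auto slot_mapping = torch::tensor({0, 1, 17, -1}, torch::kInt64);
  auto scale = torch::ones({1});

  // Writes kv_c and k_pe side by side into each token's cache entry.
  concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                       /*kv_cache_dtype=*/"auto", scale);
  return 0;
}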