
Commit 64b3c17

Authored Nov 21, 2024
add support for QuantizedCache (#5)
* add support for QuantizedCache
* update README
* upgrade version
* update README
1 parent 34a7f57 commit 64b3c17

File tree

4 files changed: +81 -41 lines changed
 
README.md (+19 -7)

@@ -72,6 +72,25 @@ _Average performance on the RULER dataset with 4k context length and Loogle Shor
 
 Please refer to the [evaluation](evaluation/README.md) directory for more details and results.
 
+## KV cache quantization
+
+We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:
+
+```python
+from transformers import QuantizedCacheConfig, QuantoQuantizedCache
+
+config = QuantizedCacheConfig(nbits=4)
+cache = QuantoQuantizedCache(config)
+
+pipe(..., cache=cache)
+```
+
+By default, the `DynamicCache` is used (no quantization).
+
+> [!IMPORTANT]
+> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto==0.2.4`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
+
+
 ## FAQ
 
 <details><summary>
@@ -165,10 +184,3 @@ Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more d
 </details>
 
 <details><summary>
-
-### Is quantization supported ?
-</summary>
-
-We don't support quantization of the KV cache yet. Quantization can achieve up to 4x compression moving from (b)float16 to int4 and we believe it is orthogonal to the KV cache pruning strategies proposed in this repository.
-
-</details>
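
For context, the sketch below shows how the new `cache` argument can combine with a press in a full pipeline call. It is illustrative only and not part of this commit: the task name `"kv-press-text-generation"`, the model checkpoint, the `ExpectedAttentionPress` choice, and the call signature are assumptions about the surrounding kvpress API.

```python
# Illustrative sketch only: task name, model checkpoint and press class are assumptions,
# not taken from this commit.
from transformers import pipeline, QuantizedCacheConfig, QuantoQuantizedCache
from kvpress import ExpectedAttentionPress  # any kvpress press exposing compression_ratio

pipe = pipeline(
    "kv-press-text-generation",  # assumed kvpress task name registered via PIPELINE_REGISTRY
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device_map="auto",
    torch_dtype="auto",
)

context = "A long document whose KV cache we want to compress and quantize..."
question = "What is the document about?"

press = ExpectedAttentionPress(compression_ratio=0.5)        # prune ~50% of the KV pairs
cache = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4))  # store the remaining pairs in int4

# Pruning (press) and quantization (cache) compose: the context is prefilled once
# into the quantized cache, compressed by the press, then reused for the question.
result = pipe(context, question=question, press=press, cache=cache)
```
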
kvpress/pipeline.py (+39 -28)

@@ -7,7 +7,7 @@
 from typing import Optional
 
 import torch
-from transformers import AutoModelForCausalLM, DynamicCache, Pipeline
+from transformers import AutoModelForCausalLM, Cache, DynamicCache, QuantizedCache, Pipeline
 from transformers.pipelines import PIPELINE_REGISTRY
 from transformers.pipelines.base import GenericTensor
 
@@ -32,6 +32,7 @@ def _sanitize_parameters(
         press: Optional[BasePress] = None,
         max_new_tokens: int = 50,
         max_context_length: Optional[int] = None,
+        cache: Optional[Cache] = None,
         **kwargs,
     ):
         """
@@ -42,7 +43,7 @@
         ----------
         question : str, optional
             The question to be asked about the context. Exclusive with `questions`.
-        questions : List[str], optional
+        questions : list[str], optional
             A list of questions to be asked about the context. Exclusive with `question`.
         answer_prefix : str, optional
             The prefix to be added to the generated answer.
@@ -52,12 +53,14 @@
             The maximum number of new tokens to generate for each answer.
         max_context_length : int, optional
             The maximum number of tokens in the context. By default will use the maximum length supported by the model.
+        cache : Cache, optional
+            The cache to use for the forward pass. Defaults to None (DynamicCache).
         **kwargs : dict
             Additional keyword arguments, currently ignored.
 
         Returns
         -------
-        Tuple[Dict, Dict, Dict]
+        Tuple[dict, dict, dict]
             A tuple containing three dictionaries:
                 - preprocess_kwargs: The keyword arguments for the preprocess function.
                 - forward_kwargs: The keyword arguments for the forward function.
@@ -75,7 +78,7 @@ def _sanitize_parameters(
             "answer_prefix": answer_prefix,
             "max_context_length": max_context_length,
         }
-        forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens}
+        forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens, "cache": cache}
         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
 
     def preprocess(
@@ -90,7 +93,7 @@ def preprocess(
 
         Returns
         -------
-        Dict[str, GenericTensor]
+        dict[str, GenericTensor]
             A dictionary containing the tokenized context (key: "context_ids") and questions (key: "questions_ids").
 
         """
@@ -127,47 +130,56 @@ def preprocess(
         return {"context_ids": context_ids, "questions_ids": question_ids}
 
     def _forward(
-        self, input_tensors: dict[str, GenericTensor], max_new_tokens: int = 50, press: Optional[BasePress] = None
+        self,
+        input_tensors: dict[str, GenericTensor],
+        max_new_tokens: int = 50,
+        press: Optional[BasePress] = None,
+        cache: Optional[Cache] = None,
     ):
         """
         Forward pass of the kv-press pipeline.
 
         Parameters
         ----------
-        input_tensors : Dict[str, GenericTensor]
+        input_tensors : dict[str, GenericTensor]
             A dictionary containing the tokenized context and questions.
         max_new_tokens : int, optional
             The maximum number of new tokens to generate for each answer. Defaults to 50.
         press : BasePress, optional
             The key-value press to use for compression. Defaults to None.
+        cache : Cache, optional
+            The cache to use for the forward pass. Defaults to None (DynamicCache).
 
         Returns
         -------
-        List[str]
+        list[str]
             A list of generated answers.
         """
 
         context_ids = input_tensors["context_ids"].to(self.model.device)
         context_length = context_ids.shape[1]
 
         # Prefilling using the press on the context
+        if cache is None:
+            cache = DynamicCache()
+
         with press(self.model) if press is not None else contextlib.nullcontext():
-            past_key_values = self.model(
+            self.model(
                 input_ids=context_ids,
-                past_key_values=DynamicCache(),
+                past_key_values=cache,
                 output_attentions=isinstance(press, ObservedAttentionPress),
                 num_logits_to_keep=1,
-            ).past_key_values
+            )
 
         logger.debug(f"Context Length: {context_length}")
-        logger.debug(f"Compressed Context Length: {past_key_values.get_seq_length()}")
+        logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")
 
         # Greedy decoding for each question
         answers = []
         for question_ids in input_tensors["questions_ids"]:
             answer = self.generate_answer(
                 question_ids=question_ids.to(self.model.device),
-                past_key_values=past_key_values,
+                cache=cache,
                 context_length=context_length,
                 max_new_tokens=max_new_tokens,
             )
@@ -181,7 +193,7 @@ def postprocess(self, model_outputs, single_question):
         return {"answers": model_outputs}
 
     def generate_answer(
-        self, question_ids: torch.Tensor, past_key_values: DynamicCache, context_length: int, max_new_tokens: int
+        self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
     ) -> str:
         """
         Generate an answer to a question using greedy decoding.
@@ -190,7 +202,7 @@ def generate_answer(
         ----------
         question_ids : torch.Tensor
             The tokenized question.
-        past_key_values : DynamicCache
+        cache : Cache
             The compressed key-value cache.
         context_length : int
             The length of the context.
@@ -203,18 +215,15 @@ def generate_answer(
             The generated answer.
         """
 
-        cache_seq_lengths = [
-            past_key_values.get_seq_length(layer_idx=layer_idx) for layer_idx in range(len(past_key_values))
-        ]
-
+        cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
         position_ids = torch.arange(
             context_length, context_length + question_ids.shape[1], device=self.model.device
         ).unsqueeze(0)
 
         # if the user doesn't provide a question, skip forward pass
         outputs = self.model(
             input_ids=question_ids.to(self.model.device),
-            past_key_values=past_key_values,
+            past_key_values=cache,
             position_ids=position_ids,
             num_logits_to_keep=1,
         )
@@ -229,7 +238,7 @@ def generate_answer(
         for i in range(max_new_tokens - 1):
             outputs = self.model(
                 input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
-                past_key_values=outputs.past_key_values,
+                past_key_values=cache,
                 position_ids=position_ids + i,
             )
             new_id = outputs.logits[0, -1].argmax()
@@ -238,13 +247,15 @@ def generate_answer(
                 break
         answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
 
-        # remove the generated tokens from the cache
-        past_key_values.key_cache = [
-            key[:, :, :cache_seq_len] for key, cache_seq_len in zip(past_key_values.key_cache, cache_seq_lengths)
-        ]
-        past_key_values.value_cache = [
-            value[:, :, :cache_seq_len] for value, cache_seq_len in zip(past_key_values.value_cache, cache_seq_lengths)
-        ]
+        # Remove the generated tokens from the cache
+        if isinstance(cache, QuantizedCache):
+            key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
+        else:
+            key_attr, value_attr = "key_cache", "value_cache"
+
+        setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)])
+        setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)])
+
         return answer
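
The change in `generate_answer` keeps the prefilled, compressed context reusable across questions: the per-layer sequence lengths are recorded right after prefilling (`cache_seq_lengths = [cache.get_seq_length(i) for i in range(len(cache))]`), and after each answer the cache is sliced back to those lengths. For a `QuantizedCache`, the per-layer tensors live in the private `_quantized_key_cache` / `_quantized_value_cache` attributes rather than `key_cache` / `value_cache`, hence the `setattr`/`getattr` indirection. The helper below is a minimal standalone restatement of that step; the function name is illustrative, not part of the repository.

```python
# Minimal restatement of the cache-trimming step in generate_answer (helper name is illustrative).
from transformers import Cache, QuantizedCache


def trim_cache_to_context(cache: Cache, cache_seq_lengths: list[int]) -> None:
    """Drop tokens appended during generation, keeping only the prefilled (compressed) context."""
    if isinstance(cache, QuantizedCache):
        # QuantizedCache keeps its per-layer tensors in private attributes.
        key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
    else:
        key_attr, value_attr = "key_cache", "value_cache"

    setattr(cache, key_attr, [k[:, :, :n] for k, n in zip(getattr(cache, key_attr), cache_seq_lengths)])
    setattr(cache, value_attr, [v[:, :, :n] for v, n in zip(getattr(cache, value_attr), cache_seq_lengths)])
```
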
kvpress/presses/base_press.py (+22 -5)

@@ -8,7 +8,14 @@
 
 import torch
 from torch import nn
-from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM
+from transformers import (
+    LlamaForCausalLM,
+    MistralForCausalLM,
+    Phi3ForCausalLM,
+    PreTrainedModel,
+    Qwen2ForCausalLM,
+    QuantizedCache,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -92,8 +99,12 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
             return output
 
-        keys = cache.key_cache[module.layer_idx]
-        values = cache.value_cache[module.layer_idx]
+        if isinstance(cache, QuantizedCache):
+            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
+            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
+        else:
+            keys = cache.key_cache[module.layer_idx]
+            values = cache.value_cache[module.layer_idx]
 
         with torch.no_grad():
             scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
@@ -104,8 +115,14 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
 
         # Update cache
-        cache.key_cache[module.layer_idx] = keys.gather(2, indices)
-        cache.value_cache[module.layer_idx] = values.gather(2, indices)
+        keys = keys.gather(2, indices).contiguous()
+        values = values.gather(2, indices).contiguous()
+        if isinstance(cache, QuantizedCache):
+            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
+            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
+        else:
+            cache.key_cache[module.layer_idx] = keys
+            cache.value_cache[module.layer_idx] = values
 
         return output
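
In `forward_hook`, a quantized cache cannot be pruned in place: the hook first dequantizes the layer's keys and values, scores and gathers the tokens to keep, then re-quantizes along the axes the cache was configured with (`axis_key` / `axis_value`). The sketch below condenses that round trip into a standalone function; the function name, `scores`, and `n_kept` are illustrative, and `_quantize` / `_dequantize` are private `QuantizedCache` methods that this commit relies on.

```python
# Condensed sketch of the dequantize -> prune -> requantize round trip used in forward_hook.
# `scores` is assumed to have shape (batch, num_kv_heads, seq_len), as produced by press.score().
import torch
from transformers import QuantizedCache


def prune_quantized_layer(cache: QuantizedCache, layer_idx: int, scores: torch.Tensor, n_kept: int) -> None:
    # Pruning happens in full precision, so dequantize the layer's tensors first.
    keys = cache._dequantize(cache._quantized_key_cache[layer_idx])
    values = cache._dequantize(cache._quantized_value_cache[layer_idx])

    # Keep the n_kept highest-scoring tokens per head (same gather pattern as the hook).
    indices = scores.topk(n_kept, dim=-1).indices
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])
    keys = keys.gather(2, indices).contiguous()
    values = values.gather(2, indices).contiguous()

    # Re-quantize along the cache's configured axes.
    cache._quantized_key_cache[layer_idx] = cache._quantize(keys, axis=cache.axis_key)
    cache._quantized_value_cache[layer_idx] = cache._quantize(values, axis=cache.axis_value)
```
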
pyproject.toml (+1 -1)

@@ -2,7 +2,7 @@
 name = "kvpress"
 authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
 description = "Efficiently compress the KV cache of any pretrained transformer"
-version = "0.0.1"
+version = "0.0.2"
 readme = "README.md"
 
 [tool.poetry.dependencies]
