
Commit 2b350b0

authored Dec 12, 2024
Add KeyRerotationPress
1 parent 7503f0d commit 2b350b0

13 files changed: +315 −106 lines
 

‎.github/PULL_REQUEST_TEMPLATE.md

+10
@@ -0,0 +1,10 @@
+## PR description
+
+Description of your PR. Fixes # (issue) (if applicable)
+
+## New press checklist (if applicable)
+
+- [ ] I added `mypress_press.py` in the `presses` directory
+- [ ] I added `MyPress` in `__init__.py`
+- [ ] I updated the `README.md` with a 1 liner about my new press in the Available presses section
+- [ ] I added my press in the `default_presses` list in `tests/default_presses.py`

‎README.md

+12 −41
@@ -19,9 +19,7 @@ pip install flash-attn --no-build-isolation
 
 ## Usage
 
-This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` parameter that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
-
-
+This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:
 
 ```python
 from kvpress import ExpectedAttentionPress
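# The hunk cuts the snippet off here; a minimal usage sketch of the
# "kv-press-text-generation" pipeline described above (hypothetical model name
# and inputs, not part of this diff):
from transformers import pipeline

pipe = pipeline("kv-press-text-generation", model="meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto")
press = ExpectedAttentionPress(compression_ratio=0.5)

context = "A long document you want the model to answer questions about..."
question = "What is this document about?"
answer = pipe(context, question=question, press=press)["answer"]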
@@ -48,27 +46,28 @@ In the snippet above, the compression is only applied on the context tokens so t
 
 ## Contributing with a new press
 
-We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [FAQ](#faq) for more information on how presses work and how to create new ones or check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide.
+We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide on how presses work and what is needed to create a new one.
 
 ## Available presses
 
-All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score used to prune the KV pairs with lowest importance:
+All current presses are training free. Several of them inherit from `ScorerPress` and rely on a score to prune the KV pairs with the lowest importance:
 
 - `RandomPress`: random score
 - `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
-- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
+- `SnapKVPress`: average attention weight of the last queries ([paper](https://arxiv.org/abs/2404.14469))
 - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
 - `StreamingLLMPress`: keep only the initial and recent tokens ([paper](https://arxiv.org/abs/2309.17453))
-- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)). The input of this press is the lazy threshold, not the compression ratio.
 - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
 - `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))
 
-We also provide presses relying on a different logic:
-- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018))
+Some presses rely on a different logic:
+- `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018))
+- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))
 
-Finally we provide two special presses:
-- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental)
-- `ComposedPress`: a press that composes multiple presses together by chaining their forward hooks
+Finally, we provide special presses:
+- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a `compression_ratio`
+- `ComposedPress`: compose multiple presses together by chaining their forward hooks
+- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress` (usage sketch below).
 
 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
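Wrapper presses such as `ComposedPress` and `KeyRerotationPress` take other presses as input. A minimal sketch of how they are combined (the inner presses and compression ratios are illustrative):

```python
from kvpress import ComposedPress, KeyRerotationPress, KnormPress, SnapKVPress

# Prune with KnormPress, then rerotate the kept keys so their RoPE embeddings stay contiguous
press = KeyRerotationPress(press=KnormPress(compression_ratio=0.5))

# Chain several presses by composing their forward hooks
press = ComposedPress(presses=[SnapKVPress(compression_ratio=0.3), KnormPress(compression_ratio=0.2)])
```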

@@ -129,9 +128,7 @@ Memory usage should be reduced by around `compression_ratio * kv_cache_size`. As
 
 ### How does a press work ? </summary>
 
-A press registers a forward hook to each attention layer during the pre-filling phase:
-1. Immediately after the forward pass, the hook is called, and it computes a score for each key-value pair using the `press.score` method
-2. The key-value pairs with the lowest scores are then removed based on the `compression_ratio` parameter
+A press registers a forward hook (`press.forward_hook` method) to each attention layer during the pre-filling phase. Registration is done by using the press as a context manager (`press.__call__` method):
 
 ```python
 import torch
@@ -170,29 +167,3 @@ with press(model):
 However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not allow generating answers to multiple questions at once.
 
 </details>
-
-<details><summary>
-
-### How to create a new press ?
-</summary>
-
-All presses are stored in the `presses` directory. The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair (see `knorm_press.py` for a simple example). Check the notebook [new_press.ipynb](notebooks/new_press.ipynb) for a step-by-step guide.
-
-Before opening a pull request with a new press, make sure to register it in the `__init__.py` file of the repository and to add it in [test_presses.py](tests/presses/test_presses.py).
-
-</details>
-
-<details><summary>
-
-### Can I change the compression ratio from one layer to another ?
-</summary>
-
-We provide an experimental feature, which only works with flash attention:
-```python
-from kvpress import PerLayerCompressionPress
-# compression_ratios should have the same length as the number of layers
-press = PerLayerCompressionPress(press, compression_ratios=[...])
-```
-
-Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details.
-</details>
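The removed FAQ entry above (now covered by the [new_press.ipynb](notebooks/new_press.ipynb) notebook and the new PR template) describes the contract for a new press: inherit from `ScorerPress` and implement a `score` method. A minimal hypothetical sketch, following the `score` signature used elsewhere in this diff:

```python
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class MyNormPress(ScorerPress):  # hypothetical example, not part of this commit
    compression_ratio: float = 0.0

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # KV pairs with the lowest scores are pruned; here, score by the negative key norm
        return -keys.norm(dim=-1)
```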

‎kvpress/__init__.py

+2
@@ -6,6 +6,7 @@
 from kvpress.presses.base_press import BasePress
 from kvpress.presses.composed_press import ComposedPress
 from kvpress.presses.expected_attention_press import ExpectedAttentionPress
+from kvpress.presses.key_rerotation_press import KeyRerotationPress
 from kvpress.presses.knorm_press import KnormPress
 from kvpress.presses.observed_attention_press import ObservedAttentionPress
 from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
@@ -32,4 +33,5 @@
     "TOVAPress",
     "KVPressTextGenerationPipeline",
     "PerLayerCompressionPress",
+    "KeyRerotationPress",
 ]

‎kvpress/pipeline.py

+18 −2
@@ -12,7 +12,10 @@
 from transformers.pipelines.base import GenericTensor
 
 from kvpress.presses.base_press import BasePress
+from kvpress.presses.composed_press import ComposedPress
+from kvpress.presses.key_rerotation_press import KeyRerotationPress
 from kvpress.presses.observed_attention_press import ObservedAttentionPress
+from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
 
 logger = logging.getLogger(__name__)
 
@@ -167,7 +170,7 @@ def _forward(
         self.model(
             input_ids=context_ids,
             past_key_values=cache,
-            output_attentions=isinstance(press, ObservedAttentionPress),
+            output_attentions=self.output_attentions(press),
             num_logits_to_keep=1,
         )
 
@@ -180,13 +183,26 @@ def _forward(
             answer = self.generate_answer(
                 question_ids=question_ids.to(self.model.device),
                 cache=cache,
-                context_length=context_length,
+                context_length=(cache.get_seq_length() if isinstance(press, KeyRerotationPress) else context_length),
                 max_new_tokens=max_new_tokens,
             )
             answers.append(answer)
 
         return answers
 
+    def output_attentions(self, press: BasePress):
+        if isinstance(press, ObservedAttentionPress):
+            return True
+        if isinstance(press, (KeyRerotationPress, PerLayerCompressionPress)) and isinstance(
+            press.press, ObservedAttentionPress
+        ):
+            return True
+        if isinstance(press, ComposedPress) and any(
+            isinstance(sub_press, ObservedAttentionPress) for sub_press in press.presses
+        ):
+            return True
+        return False
+
     def postprocess(self, model_outputs, single_question):
         if single_question:
             return {"answer": model_outputs[0]}
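Two behaviours change in `_forward` above: attentions are requested whenever an `ObservedAttentionPress` is involved, even when it is wrapped by another press, and the `context_length` passed to `generate_answer` switches to the compressed cache length when a `KeyRerotationPress` is used. The reason for the latter, with illustrative numbers (not taken from the diff):

```python
# Illustrative numbers only
context_length = 1000                                   # tokens in the original context
compression_ratio = 0.5
n_kept = int(context_length * (1 - compression_ratio))  # 500 KV pairs survive pruning

# Without rerotation, the kept keys retain their original positions, so generation
# resumes at position `context_length` (1000). With KeyRerotationPress, the kept keys
# are re-embedded at positions 0..n_kept-1, so generation must resume at
# cache.get_seq_length() == n_kept (500) instead.
```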
kvpress/presses/key_rerotation_press.py

+79
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import inspect
+from dataclasses import dataclass
+
+import torch
+from torch import nn
+from transformers.models.llama.modeling_llama import rotate_half
+
+from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress
+
+
+@dataclass
+class KeyRerotationPress(BasePress):
+    """
+    Rerotate keys to have a uniform RoPE representation of keys after pruning.
+    This method is used in several key-value cache compression methods, such as
+    - SinkCache implementation in Hugging Face's transformers library
+    - FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models
+    Parameters
+    ----------
+    press : ScorerPress
+        The press whose scores are used to select which key-value pairs are kept.
+    """
+
+    press: ScorerPress
+
+    def compress(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.press.compression_ratio == 0:
+            return keys, values
+
+        # Compute scores from base press
+        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
+
+        # Get indices of the KV pairs with the highest scores (the ones to keep)
+        q_len = hidden_states.shape[1]
+        n_kept = int(q_len * (1 - self.press.compression_ratio))
+        indices = scores.topk(n_kept, dim=-1).indices
+        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
+
+        cos, sin = get_rope_embeddings(module, keys)
+        # Rerotate as follows
+        # 1. keys = RoPE(W_k * hidden_states)
+        # 2. keys_unrotated = RoPE^-1(keys)
+        # 3. keys_pruned = prune(keys_unrotated)
+        # 4. keys = RoPE(keys_pruned)
+
+        # 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below
+        keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
+        # 3. Prune keys
+        keys = keys.gather(2, indices).contiguous()
+        # 4. Apply RoPE
+        cos, sin = get_rope_embeddings(module, keys)
+        keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
+
+        values = values.gather(2, indices).contiguous()
+        return keys, values
+
+
+def get_rope_embeddings(module, x):
+    length = x.shape[2]
+    # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape
+    if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
+        position_ids = torch.arange(length).unsqueeze(0).to(x.device)
+        cos, sin = module.rotary_emb(x, position_ids)
+    else:
+        cos, sin = module.rotary_emb(x, length)
+    return cos, sin
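The "sin -> -sin" step relies on RoPE being a composition of 2D rotations, so its inverse is simply a rotation by the negated angle (a sketch of the identity, not part of the diff):

```math
R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix},
\qquad
R(\theta)^{-1} = R(-\theta) = \begin{pmatrix} \cos\theta & \sin\theta \\ -\sin\theta & \cos\theta \end{pmatrix}
```

Hence keeping `cos` and negating `sin` in `keys * cos + rotate_half(keys) * sin` undoes the original rotation, after which the kept keys are re-rotated at their new, contiguous positions.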

‎kvpress/presses/random_press.py

+6
@@ -3,6 +3,7 @@
 
 
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 from torch import nn
@@ -14,6 +15,9 @@
 class RandomPress(ScorerPress):
     """Randomly prune KV pairs"""
 
+    compression_ratio: float = 0.0
+    seed: Optional[int] = None
+
     def score(
         self,
         module: nn.Module,
@@ -23,4 +27,6 @@ def score(
         attentions: torch.Tensor,
         kwargs,
     ) -> torch.Tensor:
+        if self.seed is not None:
+            torch.manual_seed(self.seed)
         return torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype)
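With the new `seed` field, random pruning can be made reproducible across runs. A small illustrative usage (values are arbitrary):

```python
from kvpress import RandomPress

# With a fixed seed, repeated runs on the same prompt prune the same KV pairs
press = RandomPress(compression_ratio=0.5, seed=0)
```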

‎kvpress/presses/streaming_llm_press.py

+4
@@ -16,6 +16,10 @@ class StreamingLLMPress(ScorerPress):
     Prune a fixed number of KV pairs at the beginning and end of the sequence (https://arxiv.org/abs/2309.17453)
     We keep the first n_sink tokens and the last n_local tokens.
     n_local is computed using the compression ratio.
+
+    Note that the original implementation https://github.com/mit-han-lab/streaming-llm additionally rerotates keys.
+    This can be achieved by using
+    press = KeyRerotationPress(press=StreamingLLMPress(compression_ratio, n_sink))
     """
 
     compression_ratio: float = 0.0

‎notebooks/new_press.ipynb

+1 −1
@@ -242,7 +242,7 @@
    "source": [
     "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n",
     "- register it in the `__init__.py` file of repository\n",
-    "- add a test [test_presses.py](tests/presses/test_presses.py)\n",
+    "- register the press in [default_presses.py](tests/default_presses.py)\n",
     "- update the README"
    ]
   }

‎pyproject.toml

+1 −1
@@ -2,7 +2,7 @@
 name = "kvpress"
 authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
 description = "Efficiently compress the KV cache of any pretrained transformer"
-version = "0.0.4"
+version = "0.1.0"
 readme = "README.md"
 
 [tool.poetry.dependencies]

‎tests/default_presses.py

+42
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from kvpress import (
+    ExpectedAttentionPress,
+    KnormPress,
+    RandomPress,
+    SimLayerKVPress,
+    SnapKVPress,
+    StreamingLLMPress,
+    ThinKPress,
+    TOVAPress,
+)
+
+# contains all presses to be tested
+# kwargs should be ordered easy to hard compression
+default_presses = [
+    {"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+    {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+    {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+    {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+    {
+        "cls": SnapKVPress,
+        "kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
+    },
+    {"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+    {
+        "cls": ThinKPress,
+        "kwargs": [
+            {"key_channel_compression_ratio": 0.2, "window_size": 2},
+            {"key_channel_compression_ratio": 0.8, "window_size": 2},
+        ],
+    },
+    {
+        "cls": SimLayerKVPress,
+        "kwargs": [
+            {"lazy_threshold": 0.8, "n_initial": 1, "n_recent": 1, "n_last": 1},
+            {"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1},
+        ],
+    },
+]
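Tests consume this list by parametrizing over its entries. A minimal sketch of the pattern (hypothetical test name, mirroring the parametrization used in the test files below):

```python
import pytest

from tests.default_presses import default_presses


@pytest.mark.parametrize("press_dict", default_presses)
def test_press_instantiates(press_dict):  # hypothetical test, for illustration only
    # Each entry pairs a press class with kwargs ordered from easy to hard compression
    for kwargs in press_dict["kwargs"]:
        press = press_dict["cls"](**kwargs)
        assert press is not None
```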

‎tests/integration/test_ruler.py

+7 −20
@@ -4,15 +4,7 @@
 from transformers import DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache
 from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
 
-from kvpress import (
-    ExpectedAttentionPress,
-    KnormPress,
-    SimLayerKVPress,
-    SnapKVPress,
-    StreamingLLMPress,
-    ThinKPress,
-    TOVAPress,
-)
+from tests.default_presses import default_presses
 from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline  # noqa: F401
 
 
@@ -25,18 +17,13 @@ def df_ruler():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
 @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
-@pytest.mark.parametrize(
-    "cls", [KnormPress, ExpectedAttentionPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress, SimLayerKVPress]
-)
-@pytest.mark.parametrize("compression_ratio", [0.1, 0.2])
+@pytest.mark.parametrize("press_dict", default_presses)
 @pytest.mark.parametrize("cache", ["dynamic", "quantized"])
-def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, cls, compression_ratio, cache):  # noqa: F811
-    if cls == ThinKPress:
-        press = cls(key_channel_compression_ratio=compression_ratio, window_size=2)
-    elif cls == SimLayerKVPress:
-        press = cls(lazy_threshold=1 - compression_ratio)
-    else:
-        press = cls(compression_ratio=compression_ratio)
+def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press_dict, cache):  # noqa: F811
+    cls = press_dict["cls"]
+    kwargs = press_dict["kwargs"][0]
+    press = cls(**kwargs)
+
     if cache == "dynamic":
         cache = DynamicCache()
     elif cache == "quantized" and is_optimum_quanto_available():
+113
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+from torch import nn
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half
+
+from kvpress import KeyRerotationPress, ScorerPress
+from kvpress.presses.key_rerotation_press import get_rope_embeddings
+from tests.fixtures import unit_test_model  # noqa: F401
+
+
+@pytest.mark.parametrize("precision", ["full", "half"])
+def test_rerotate_keys_matches_reference_implementation(unit_test_model: LlamaForCausalLM, precision):  # noqa: F811
+    """
+    Compare KeyRerotationPress' rerotation of keys with the reference implementation.
+    In the reference implementation, we are computing
+    1. keys = W_k * hidden_states
+    2. keys_pruned = prune(keys)
+    3. keys = RoPE(keys_pruned)
+    """
+    if precision == "half" and torch.cuda.is_available():
+        unit_test_model = unit_test_model.cuda().half()
+    elif precision == "half" and not torch.cuda.is_available():
+        pytest.skip("Half precision test is skipped because CUDA is not available.")
+
+    original_press = RandomPressStoreIndices(compression_ratio=0.5)
+    key_rerotation_press = KeyRerotationPress(press=original_press)
+
+    module = unit_test_model.model.layers[0].self_attn
+    hidden_states = torch.randn(
+        8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype
+    )
+
+    keys = get_keys_with_rope(module, hidden_states)
+
+    values = torch.randn_like(keys)
+    # Press result
+    keys_compressed, _ = key_rerotation_press.compress(
+        module, hidden_states, keys, values, attentions=None, kwargs=dict()
+    )
+
+    indices = original_press.indices
+    keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices)
+
+    assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if precision == "full" else 1e-3)
+
+
+def get_keys_with_rope(module, hidden_states):
+    # Compute keys with RoPE
+    keys = get_keys_without_pos_embedding(module, hidden_states)
+    cos, sin = get_rope_embeddings(module, keys)
+    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
+    return keys
+
+
+@dataclass
+class RandomPressStoreIndices(ScorerPress):
+    compression_ratio: float = 0.0
+    seed: int = 0
+
+    def __post_init__(self):
+        self.indices = None
+        super().__post_init__()
+
+    def score(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs,
+    ) -> torch.Tensor:
+        torch.manual_seed(self.seed)
+        scores = torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype)
+        # Store the indices of the KV pairs with the highest scores (the ones that will be kept)
+        q_len = hidden_states.shape[1]
+        n_kept = int(q_len * (1 - self.compression_ratio))
+        indices = scores.topk(n_kept, dim=-1).indices
+        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
+        self.indices = indices
+
+        return scores
+
+
+def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hidden_states, indices):
+    """
+    Computes the rerotated keys for the given indices.
+    1. keys = W_k * hidden_states
+    2. keys_pruned = prune(keys)
+    3. keys = RoPE(keys_pruned)
+    """
+    # 1.
+    keys = get_keys_without_pos_embedding(module, hidden_states)
+    # 2.
+    keys = keys.gather(2, indices).contiguous()
+    # 3.
+    cos, sin = get_rope_embeddings(module, keys)
+    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
+    return keys
+
+
+def get_keys_without_pos_embedding(module, hidden_states):
+    key_states = module.k_proj(hidden_states)
+    key_states = key_states.view(
+        key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim
+    ).transpose(1, 2)
+    return key_states

‎tests/presses/test_presses.py

+20 −41
@@ -2,23 +2,15 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass
 
+import pytest
 import torch
 from torch import nn
 from transformers import DynamicCache
 
-from kvpress import (
-    ComposedPress,
-    ExpectedAttentionPress,
-    KnormPress,
-    ObservedAttentionPress,
-    RandomPress,
-    SimLayerKVPress,
-    SnapKVPress,
-    StreamingLLMPress,
-    TOVAPress,
-)
+from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress
 from kvpress.presses.scorer_press import ScorerPress
 from kvpress.presses.think_press import ThinKPress
+from tests.default_presses import default_presses
 from tests.fixtures import unit_test_model, unit_test_model_output_attention  # noqa: F401
 
 
@@ -31,40 +23,27 @@ def test_composed_press(unit_test_model):  # noqa: F811
     unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
 
 
-def test_presses_run(unit_test_model):  # noqa: F811
-    for cls in [
-        KnormPress,
-        ExpectedAttentionPress,
-        RandomPress,
-        StreamingLLMPress,
-        SimLayerKVPress,
-        SnapKVPress,
-        TOVAPress,
-        ThinKPress,
-    ]:
-        for value in [0.2, 0.4, 0.6, 0.8]:
-
-            # Load the press
-            if cls == ThinKPress:
-                press = cls(key_channel_compression_ratio=value, window_size=2)
-            elif cls == SimLayerKVPress:
-                press = cls(lazy_threshold=value, n_initial=1, n_recent=1, n_last=1)
-            else:
-                press = cls(compression_ratio=value)
-            if cls == SnapKVPress:
-                press.window_size = 2
-
-            # Run the press
-            with press(unit_test_model):
-                input_ids = unit_test_model.dummy_inputs["input_ids"]
-                unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
-            # Check that the press has a compression_ratio attribute
-            assert hasattr(press, "compression_ratio")
+@pytest.mark.parametrize("press_dict", default_presses)
+@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress])
+def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
+    cls = press_dict["cls"]
+    for kwargs in press_dict["kwargs"]:
+        press = cls(**kwargs)
+        if wrapper_press is ComposedPress:
+            press = ComposedPress(presses=[press])
+        if wrapper_press is KeyRerotationPress:
+            press = KeyRerotationPress(press=press)
+
+        with press(unit_test_model):
+            input_ids = unit_test_model.dummy_inputs["input_ids"]
+            unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
+        # Check that the press has a compression_ratio attribute
+        assert hasattr(press, "compression_ratio")
 
 
 def test_presses_run_observed_attention(unit_test_model_output_attention):  # noqa: F811
     for cls in [ObservedAttentionPress]:
-        for compresion_ratio in [0.2, 0.4, 0.6, 0.8]:
+        for compresion_ratio in [0.2, 0.8]:
             press = cls(compression_ratio=compresion_ratio)
             with press(unit_test_model_output_attention):
                 input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"]