
Commit ac2445e

Add ThinKPress (#20)

Authored Dec 3, 2024 · 1 parent 51f3877 · commit ac2445e

File tree

8 files changed: +127 -14 lines changed

README.md (+1)

@@ -60,6 +60,7 @@ All current presses are training free. We provide the following presses associat
 - `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
 - `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
 - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
+- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.

 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
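Since the new README entry notes that `ThinKPress` can be combined with any of the other presses, a minimal usage sketch may help. The checkpoint and prompt below are placeholders (not prescribed by this commit), and the `with press(model)` pattern mirrors the tests added further down:

```python
# Hedged sketch: ThinKPress stacked on SnapKVPress, following the class docstring
# and the `with press(model)` pattern used in tests/presses/test_presses.py.
# The checkpoint below is an arbitrary placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer
from kvpress import SnapKVPress, ThinKPress

ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

# Prune 50% of the key channels on top of 50% sequence-length compression
press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))

inputs = tokenizer("Some long context to pre-fill...", return_tensors="pt")
with press(model):  # compression hooks are active only inside this context
    outputs = model(**inputs)
```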

kvpress/__init__.py (+2)

@@ -12,6 +12,7 @@
 from kvpress.presses.snapkv_press import SnapKVPress
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.tova_press import TOVAPress
+from kvpress.presses.think_press import ThinKPress

 __all__ = [
     "BasePress",
@@ -21,6 +22,7 @@
     "RandomPress",
     "SnapKVPress",
     "StreamingLLMPress",
+    "ThinKPress",
     "TOVAPress",
     "KVPressTextGenerationPipeline",
     "apply_per_layer_compression",

kvpress/presses/snapkv_press.py (+3, -8)

@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0


-import inspect
 import math
 from dataclasses import dataclass

@@ -45,13 +44,9 @@ def compute_window_attention(self, module, hidden_states, keys):
         query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)

         # Apply RoPE
-        if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
-            position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
-            cos, sin = module.rotary_emb(query_states, position_ids)
-        else:
-            cos, sin = module.rotary_emb(query_states, q_len)
-            cos, sin = cos[-self.window_size :].unsqueeze(0), sin[-self.window_size :].unsqueeze(0)
-        query_states = (query_states * cos) + (rotate_half(query_states) * sin)
+        position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
+        cos, sin = module.rotary_emb(query_states, position_ids)
+        query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

         # Compute attention for first q_len - window_size tokens
         key_states = repeat_kv(keys, module.num_key_value_groups)
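This change drops the `inspect`-based branch and always calls `rotary_emb` with `position_ids`; the added `unsqueeze(1)` accounts for `cos`/`sin` apparently coming back without a head axis in recent transformers versions. A self-contained toy sketch of the broadcasting, with assumed shapes and stand-in tensors rather than a real module:

```python
# Toy sketch of the RoPE broadcasting above (assumed shapes, stand-in tensors).
# rotary_emb(x, position_ids) is assumed to return cos/sin of shape
# (bsz, seq_len, head_dim); query_states is (bsz, num_heads, seq_len, head_dim),
# so a head axis must be inserted before multiplying.
import torch

def rotate_half(x):
    # Same helper as transformers.models.llama.modeling_llama.rotate_half
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, num_heads, window_size, head_dim = 1, 4, 32, 64
query_states = torch.randn(bsz, num_heads, window_size, head_dim)
cos = torch.randn(bsz, window_size, head_dim)  # stand-in for rotary_emb output
sin = torch.randn(bsz, window_size, head_dim)

rotated = query_states * cos.unsqueeze(1) + rotate_half(query_states) * sin.unsqueeze(1)
assert rotated.shape == (bsz, num_heads, window_size, head_dim)
```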

kvpress/presses/think_press.py (+106)

@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import nn
+from transformers.cache_utils import QuantizedCache
+from transformers.models.llama.modeling_llama import rotate_half
+
+from kvpress.presses.base_press import BasePress
+
+
+@dataclass
+class ThinKPress(BasePress):
+    """
+    ThinK (https://arxiv.org/pdf/2407.21018) compresses the dimensions of the keys, and not the sequence length.
+    Hence it can be combined with any other press that compresses the sequence length, e.g.
+    press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))
+
+    Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same).
+    To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/),
+    we might implement them in the future, especially if other similar presses are requested.
+
+    This press has been reviewed by Yuhui Xu, first author of the ThinK paper.
+    """
+
+    compression_ratio: float = 0.0
+    inner_press: Optional[BasePress] = None
+    window_size: int = 32
+
+    def compute_window_queries(self, module, hidden_states):
+        """
+        Re-compute the last window_size query states
+        """
+        bsz, q_len, _ = hidden_states.shape
+
+        # Get last window_size queries
+        if hasattr(module, "q_proj"):
+            query_states = module.q_proj(hidden_states[:, -self.window_size :])
+        elif hasattr(module, "qkv_proj"):
+            qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
+            query_states = qkv[..., : module.num_heads * module.head_dim]
+        else:
+            raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")
+
+        query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)
+
+        # Apply RoPE
+        position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
+        cos, sin = module.rotary_emb(query_states, position_ids)
+        query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))
+
+        return query_states
+
+    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+        """
+        We first apply the inner press, then we prune the key dimensions. If other similar presses are requested,
+        we will create a dedicated DimensionBasePress class to avoid code duplication.
+        """
+
+        # Apply the forward hook of the inner press
+        if self.inner_press is not None:
+            output = self.inner_press.forward_hook(module, input, kwargs, output)
+
+        # Don't compress if the compression ratio is 0 or this is not pre-filling
+        cache = output[-1]
+        hidden_states = kwargs["hidden_states"]
+        q_len = hidden_states.shape[1]
+        assert q_len > self.window_size, "Query length should be greater than the window size"
+
+        if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
+            return output
+
+        # Get keys
+        if isinstance(cache, QuantizedCache):
+            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
+        else:
+            keys = cache.key_cache[module.layer_idx]
+        bsz, num_key_value_heads, q_len, head_dim = keys.shape
+
+        # ThinK specific code
+        queries = self.compute_window_queries(module, kwargs["hidden_states"])
+
+        # Compute scores per dimension
+        queries_norm = torch.pow(queries, 2).mean(dim=2)  # (bsz, num_heads, head_dim)
+        queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2)
+        keys_norm = torch.pow(keys, 2).mean(dim=2)
+        key_scores = queries_norm * keys_norm  # (bsz, num_key_value_heads, head_dim)
+
+        # Prune dimensions with the lowest scores by setting them to 0
+        n_pruned = int(head_dim * self.compression_ratio)
+        indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices
+        indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1)
+        keys = keys.scatter_(-1, indices, 0)
+
+        # Update cache
+        if isinstance(cache, QuantizedCache):
+            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
+        else:
+            cache.key_cache[module.layer_idx] = keys
+
+        return output
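A toy sanity check of the per-channel scoring and zeroing logic implemented above, using assumed shapes and random tensors (a sketch, not part of the commit):

```python
# Toy sanity check of ThinK's channel scoring and pruning (assumed shapes).
import torch

bsz, num_kv_heads, num_groups, q_len, window_size, head_dim = 1, 2, 2, 128, 32, 8
compression_ratio = 0.5

queries = torch.randn(bsz, num_kv_heads * num_groups, window_size, head_dim)
keys = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# Score each key channel by (mean squared query) x (mean squared key),
# averaging the query heads within each key-value group
queries_norm = queries.pow(2).mean(dim=2).view(bsz, num_kv_heads, num_groups, head_dim).mean(2)
keys_norm = keys.pow(2).mean(dim=2)
key_scores = queries_norm * keys_norm  # (bsz, num_kv_heads, head_dim)

# Zero out the lowest-scoring channels across the whole sequence
n_pruned = int(head_dim * compression_ratio)
indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices
keys.scatter_(-1, indices.unsqueeze(2).expand(-1, -1, q_len, -1), 0)
assert (keys == 0).sum().item() == bsz * num_kv_heads * q_len * n_pruned
```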

notebooks/per_layer_compression_demo.ipynb (+3, -3)

@@ -216,9 +216,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "kvpress_2",
+   "display_name": ".venv",
    "language": "python",
-   "name": "kvpress_2"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -230,7 +230,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.9"
+   "version": "3.10.12"
   }
  },
  "nbformat": 4,

pyproject.toml (+1, -1)

@@ -2,7 +2,7 @@
 name = "kvpress"
 authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
 description = "Efficiently compress the KV cache of any pretrained transformer"
-version = "0.0.3"
+version = "0.0.4"
 readme = "README.md"

 [tool.poetry.dependencies]

tests/__init__.py

Whitespace-only changes.

tests/presses/test_presses.py (+11, -2)

@@ -15,15 +15,24 @@
     SnapKVPress,
     StreamingLLMPress,
     TOVAPress,
+    ThinKPress,
 )
+
 from tests.fixtures import unit_test_model, unit_test_model_output_attention  # noqa: F401


+def test_think_inner_press(unit_test_model):  # noqa: F811
+    press = ThinKPress(compression_ratio=0.5, window_size=2, inner_press=KnormPress(0.5))
+    with press(unit_test_model):
+        input_ids = unit_test_model.dummy_inputs["input_ids"]
+        unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
+
+
 def test_presses_run(unit_test_model):  # noqa: F811
-    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
+    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]:
         for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
             press = cls(compression_ratio=compression_ratio)
-            if cls == SnapKVPress:
+            if cls in [SnapKVPress, ThinKPress]:
                 press.window_size = 2
             with press(unit_test_model):
                 input_ids = unit_test_model.dummy_inputs["input_ids"]
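The new `test_think_inner_press` exercises press composition end to end. A hypothetical way to run just the ThinK tests programmatically (a standard pytest invocation, not something this commit adds):

```python
# Hypothetical: run only the ThinK-related tests via pytest's programmatic API
import pytest

pytest.main(["tests/presses/test_presses.py", "-k", "think", "-q"])
```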
