Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ThinKPress #20

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ All current presses are training free. We provide the following presses associat
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries (the last 32 by default, configurable via `window_size`) ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

Expand Down
2 changes: 2 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.think_press import ThinKPress

__all__ = [
"BasePress",
Expand All @@ -21,6 +22,7 @@
"RandomPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"apply_per_layer_compression",
Expand Down
11 changes: 3 additions & 8 deletions kvpress/presses/snapkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0


import inspect
import math
from dataclasses import dataclass

Expand Down Expand Up @@ -45,13 +44,9 @@ def compute_window_attention(self, module, hidden_states, keys):
query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)

# Apply RoPE
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
cos, sin = module.rotary_emb(query_states, position_ids)
else:
cos, sin = module.rotary_emb(query_states, q_len)
cos, sin = cos[-self.window_size :].unsqueeze(0), sin[-self.window_size :].unsqueeze(0)
query_states = (query_states * cos) + (rotate_half(query_states) * sin)
position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
cos, sin = module.rotary_emb(query_states, position_ids)
query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

# Compute attention for first q_len - window_size tokens
key_states = repeat_kv(keys, module.num_key_value_groups)
Expand Down
106 changes: 106 additions & 0 deletions kvpress/presses/think_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers.cache_utils import QuantizedCache
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress


@dataclass
class ThinKPress(BasePress):
    """
    ThinK (https://arxiv.org/pdf/2407.21018) compresses the dimensions of the keys, and not the sequence length.
    Hence it can be combined with any other press that compresses the sequence length, e.g.
    press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))

    Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same).
    To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/),
    we might implement them in the future, especially if other similar presses are requested.

    This press has been reviewed by Yuhui Xu, first author of the ThinK paper.

    Attributes
    ----------
    compression_ratio:
        Fraction of the key channels zeroed out per key/value head; the number of pruned
        channels is ``int(head_dim * compression_ratio)``.
    inner_press:
        Optional press whose ``forward_hook`` is applied before the key channels are pruned.
    window_size:
        Number of trailing queries re-computed to score the key channels.
    """

    compression_ratio: float = 0.0
    inner_press: Optional[BasePress] = None
    window_size: int = 32

    def compute_window_queries(self, module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Re-compute the last window_size query states (with RoPE applied).

        Parameters
        ----------
        module:
            Attention module exposing either ``q_proj`` or a fused ``qkv_proj``, plus ``num_heads``,
            ``head_dim`` and ``rotary_emb``.
        hidden_states:
            Input of the attention module, unpacked as (bsz, q_len, hidden_size).

        Returns
        -------
        torch.Tensor
            Query states of shape (bsz, num_heads, window_size, head_dim).

        Raises
        ------
        NotImplementedError
            If the module has neither a ``q_proj`` nor a ``qkv_proj`` projection.
        """

        bsz, q_len, _ = hidden_states.shape

        # Only the last window_size positions are needed, so project just those
        window_states = hidden_states[:, -self.window_size :]
        if hasattr(module, "q_proj"):
            query_states = module.q_proj(window_states)
        elif hasattr(module, "qkv_proj"):
            # Fused projection: the queries are the leading num_heads * head_dim features
            qkv = module.qkv_proj(window_states)
            query_states = qkv[..., : module.num_heads * module.head_dim]
        else:
            raise NotImplementedError(f"ThinK not yet implemented for {module.__class__}.")

        query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)

        # Apply RoPE at the absolute positions of the window (the last window_size positions of the sequence)
        position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
        cos, sin = module.rotary_emb(query_states, position_ids)
        # unsqueeze(1) adds a head axis so cos/sin broadcast over the heads
        query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

        return query_states

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        We first apply the inner press, then we prune the key dimensions. If other similar presses are requested,
        we will create a dedicated DimensionBasePress class to avoid code duplication.

        Parameters
        ----------
        module:
            Attention module the hook is attached to; ``module.layer_idx`` selects the cache entry.
        input:
            Positional inputs of the module forward pass (only forwarded to the inner press).
        kwargs:
            Keyword inputs of the module forward pass; must contain ``"hidden_states"``.
        output:
            Output of the module forward pass; its last element is the KV cache.

        Returns
        -------
        list
            The module output, whose cache has been updated in place for this layer.
        """

        # Apply the forward hook of the inner press
        if self.inner_press is not None:
            output = self.inner_press.forward_hook(module, input, kwargs, output)

        cache = output[-1]
        hidden_states = kwargs["hidden_states"]
        q_len = hidden_states.shape[1]

        # Don't compress if the compression ratio is 0 or this is not pre-filling
        if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
            return output

        # Validate only once we know we will compress, so that decoding steps (q_len == 1)
        # and a disabled press pass through instead of failing here
        assert q_len > self.window_size, "Query length should be greater than the window size"

        # Get keys
        if isinstance(cache, QuantizedCache):
            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
        else:
            keys = cache.key_cache[module.layer_idx]
        # kv_len can be smaller than q_len when the inner press has already pruned tokens
        bsz, num_key_value_heads, kv_len, head_dim = keys.shape

        # Nothing to prune: leave the cache untouched (also avoids a needless re-quantization)
        n_pruned = int(head_dim * self.compression_ratio)
        if n_pruned == 0:
            return output

        # ThinK specific code
        queries = self.compute_window_queries(module, hidden_states)

        # Compute scores per dimension
        queries_norm = torch.pow(queries, 2).mean(dim=2)  # (bsz, num_heads, head_dim)
        # Average the query heads sharing the same key/value head
        queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2)
        keys_norm = torch.pow(keys, 2).mean(dim=2)
        key_scores = queries_norm * keys_norm  # (bsz, num_key_value_heads, head_dim)

        # Prune dimensions with the lowest scores by setting them to 0
        indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices
        # The same channels are pruned at every position of a given head
        indices = indices.unsqueeze(2).expand(-1, -1, kv_len, -1)
        keys = keys.scatter_(-1, indices, 0)

        # Update cache
        if isinstance(cache, QuantizedCache):
            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
        else:
            cache.key_cache[module.layer_idx] = keys

        return output
6 changes: 3 additions & 3 deletions notebooks/per_layer_compression_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "kvpress_2",
"display_name": ".venv",
"language": "python",
"name": "kvpress_2"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -230,7 +230,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.0.3"
version = "0.0.4"
readme = "README.md"

[tool.poetry.dependencies]
Expand Down
Empty file added tests/__init__.py
Empty file.
13 changes: 11 additions & 2 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,24 @@
SnapKVPress,
StreamingLLMPress,
TOVAPress,
ThinKPress,
)

from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401


def test_think_inner_press(unit_test_model):  # noqa: F811
    """Smoke test: a ThinKPress wrapping a KnormPress runs a forward pass on the dummy inputs."""
    inner = KnormPress(0.5)
    think_press = ThinKPress(compression_ratio=0.5, window_size=2, inner_press=inner)
    with think_press(unit_test_model):
        dummy_ids = unit_test_model.dummy_inputs["input_ids"]
        outputs = unit_test_model(dummy_ids, past_key_values=DynamicCache())
        # Only checks that the forward pass returns an object exposing the cache
        outputs.past_key_values


def test_presses_run(unit_test_model): # noqa: F811
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]:
for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
press = cls(compression_ratio=compression_ratio)
if cls == SnapKVPress:
if cls in [SnapKVPress, ThinKPress]:
press.window_size = 2
with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"]
Expand Down
Loading