# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers.cache_utils import QuantizedCache
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress


@dataclass
class ThinKPress(BasePress):
    """
    ThinK (https://arxiv.org/pdf/2407.21018) compresses the head dimensions (channels) of the keys, not the
    sequence length. Hence it can be combined with any other press that compresses the sequence length, e.g.
    press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))

    Here, we zero out the pruned dimensions, resulting in no memory gain (the shape of the keys remains the same).
    To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/);
    we might implement them in the future, especially if other similar presses are requested.

    This press has been reviewed by Yuhui Xu, first author of the ThinK paper.
    """

    compression_ratio: float = 0.0
    inner_press: Optional[BasePress] = None
    window_size: int = 32

    def compute_window_queries(self, module, hidden_states):
        """
        Re-compute the last window_size query states
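        (queries are not stored in the KV cache, so the queries of the observation window must be
        re-computed from the hidden states and rotated with RoPE)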
        """

        bsz, q_len, _ = hidden_states.shape

        # Get last window_size queries
        if hasattr(module, "q_proj"):
            query_states = module.q_proj(hidden_states[:, -self.window_size :])
        elif hasattr(module, "qkv_proj"):
            qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
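            # For fused QKV projections (e.g. Phi-3), the queries occupy the first num_heads * head_dim channels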
            query_states = qkv[..., : module.num_heads * module.head_dim]
        else:
            raise NotImplementedError(f"ThinKPress not yet implemented for {module.__class__}.")

        query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)

        # Apply RoPE
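        # (q' = q * cos + rotate_half(q) * sin, using the absolute positions of the last window_size tokens)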
        position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
        cos, sin = module.rotary_emb(query_states, position_ids)
        query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))

        return query_states

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        We first apply the inner press, then we prune the key dimensions. If other similar presses are requested,
        we will create a dedicated DimensionBasePress class to avoid code duplication.
        """

        # Apply the forward hook of the inner press
        if self.inner_press is not None:
            output = self.inner_press.forward_hook(module, input, kwargs, output)

        # Don't compress if the compression ratio is 0 or if this is not pre-filling
        cache = output[-1]
        hidden_states = kwargs["hidden_states"]
        q_len = hidden_states.shape[1]

        if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
            return output

        # Only checked during pre-filling: the hook also fires with q_len == 1 during decoding
        assert q_len > self.window_size, "Query length should be greater than the window size"

        # Get keys
        if isinstance(cache, QuantizedCache):
            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
        else:
            keys = cache.key_cache[module.layer_idx]
        bsz, num_key_value_heads, q_len, head_dim = keys.shape

        # ThinK specific code
        queries = self.compute_window_queries(module, kwargs["hidden_states"])

        # Compute scores per dimension
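        # Each channel d is scored by (mean of q_d^2 over the window) * (mean of k_d^2 over the sequence),
        # a proxy for the channel's contribution to the attention logits q @ k.T. With GQA, the query
        # norms are first averaged over the query heads of each key/value group.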
        queries_norm = torch.pow(queries, 2).mean(dim=2)  # (bsz, num_heads, head_dim)
        queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2)
        keys_norm = torch.pow(keys, 2).mean(dim=2)
        key_scores = queries_norm * keys_norm  # (bsz, num_key_value_heads, head_dim)

        # Prune dimensions with the lowest scores by setting them to 0
        n_pruned = int(head_dim * self.compression_ratio)
        indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices
        indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1)
        keys = keys.scatter_(-1, indices, 0)

        # Update cache
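        # (for a QuantizedCache, the pruned keys are quantized back in place of the original ones)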
        if isinstance(cache, QuantizedCache):
            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
        else:
            cache.key_cache[module.layer_idx] = keys

        return output
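
# Minimal usage sketch, not part of the press itself. It assumes the BasePress context-manager API
# (`press(model)` registers the forward hooks for the duration of the block); `model` and `input_ids`
# are placeholders for a supported Hugging Face model and its tokenized prompt:
#
#     from kvpress import SnapKVPress, ThinKPress
#
#     press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))
#     with press(model):
#         output_ids = model.generate(input_ids, max_new_tokens=128)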