# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass

import pytest
import torch
from torch import nn
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half

from kvpress import KeyRerotationPress, ScorerPress
from kvpress.presses.key_rerotation_press import get_rope_embeddings
from tests.fixtures import unit_test_model  # noqa: F401


@pytest.mark.parametrize("precision", ["full", "half"])
def test_rerotate_keys_matches_reference_implementation(unit_test_model: LlamaForCausalLM, precision):  # noqa: F811
    """
    Compare KeyRerotationPress's rerotation of keys with a reference implementation that computes:
    1. keys = W_k * hidden_states
    2. keys_pruned = prune(keys)
    3. keys = RoPE(keys_pruned)
    """
    if precision == "half" and torch.cuda.is_available():
        unit_test_model = unit_test_model.cuda().half()
    elif precision == "half" and not torch.cuda.is_available():
        pytest.skip("Half precision test is skipped because CUDA is not available.")

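    # KeyRerotationPress wraps the underlying press and rerotates the kept keys with RoPE
    # for their new positions in the compressed cache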
    original_press = RandomPressStoreIndices(compression_ratio=0.5)
    key_rerotation_press = KeyRerotationPress(press=original_press)

    module = unit_test_model.model.layers[0].self_attn
    hidden_states = torch.randn(
        8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype
    )

    keys = get_keys_with_rope(module, hidden_states)

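    # Values are random placeholders: the compressed values are discarded and only the
    # rerotated keys are checked below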
    values = torch.randn_like(keys)
    # Press result
    keys_compressed, _ = key_rerotation_press.compress(
        module, hidden_states, keys, values, attentions=None, kwargs=dict()
    )

    indices = original_press.indices
    keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices)

    assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if precision == "full" else 1e-3)


def get_keys_with_rope(module, hidden_states):
    # Compute keys with RoPE
    keys = get_keys_without_pos_embedding(module, hidden_states)
    cos, sin = get_rope_embeddings(module, keys)
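    # Standard RoPE application: rotated_k = k * cos + rotate_half(k) * sin,
    # with cos/sin unsqueezed to broadcast over the head dimension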
    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
    return keys


@dataclass
class RandomPressStoreIndices(ScorerPress):
    """ScorerPress that assigns random scores and stores the indices of the kept KV pairs."""

    compression_ratio: float = 0.0
    seed: int = 0

    def __post_init__(self):
        self.indices = None  # set by score() to the indices of the kept KV pairs
        super().__post_init__()

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        torch.manual_seed(self.seed)
        scores = torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype)
        # Get indices of the KV pairs with the highest scores, i.e. the pairs that are kept
        q_len = hidden_states.shape[1]
        n_kept = int(q_len * (1 - self.compression_ratio))
        indices = scores.topk(n_kept, dim=-1).indices
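        # Expand the indices over head_dim so they can be used directly with gather on the keys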
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
        self.indices = indices

        return scores


def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hidden_states, indices):
    """
    Reference implementation: computes the rerotated keys for the given kept indices.
    1. keys = W_k * hidden_states
    2. keys_pruned = prune(keys)
    3. keys = RoPE(keys_pruned)
    """
    # 1. Keys without positional embedding
    keys = get_keys_without_pos_embedding(module, hidden_states)
    # 2. Prune: gather the kept positions along the sequence dimension
    keys = keys.gather(2, indices).contiguous()
    # 3. Apply RoPE to the pruned keys at their new, contiguous positions
    cos, sin = get_rope_embeddings(module, keys)
    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
    return keys


def get_keys_without_pos_embedding(module, hidden_states):
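    # Project hidden states with W_k and reshape to (batch, num_kv_heads, seq_len, head_dim)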
    key_states = module.k_proj(hidden_states)
    key_states = key_states.view(
        key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim
    ).transpose(1, 2)
    return key_states