Add ThinK_Press #18
Conversation
Hi @yuhuixu1993! Thanks for proposing your work and congrats on the promising results on RULER! Before commenting on the implementation, let's make sure I understand it. ThinK is a method that compresses the channel dimension of the KV cache, as opposed to the sequence dimension targeted by many other works. Similarly to other presses, you define a score, this time with shape `(bsz, num_heads, head_dim)` (one score per channel rather than per token). The score is defined as

```python
channel_attn = torch.matmul(
    query_states.permute(0, 1, 3, 2).unsqueeze(-1),
    key_states.transpose(2, 3).unsqueeze(-2),
)
# shape: (bsz, num_heads, head_dim, window_size, seq_len)
channel_score = channel_attn.pow_(2).sum(dim=(-1, -2))
# shape: (bsz, num_heads, head_dim)
```

Interpretation: for a given head `h` and channel `d`, `channel_score[h, d]` is the sum of all squared products `(q_i[d] * k_j[d]) ** 2` over the query positions `i` in the observation window and the key positions `j` in the sequence.
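As a sanity check on the broadcasting (a toy snippet with assumed shapes), `channel_attn` is just the per-channel outer product of queries and keys:

```python
import torch

bsz, num_heads, window_size, seq_len, head_dim = 1, 2, 4, 16, 8
query_states = torch.randn(bsz, num_heads, window_size, head_dim)
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)

channel_attn = torch.matmul(
    query_states.permute(0, 1, 3, 2).unsqueeze(-1),
    key_states.transpose(2, 3).unsqueeze(-2),
)  # (bsz, num_heads, head_dim, window_size, seq_len)

# channel_attn[b, h, d, i, j] == query_states[b, h, i, d] * key_states[b, h, j, d]
reference = torch.einsum("bhid,bhjd->bhdij", query_states, key_states)
assert torch.allclose(channel_attn, reference)
```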
A few questions:

It seems ThinK can be combined with other presses that focus on the sequence dimension — I think you might be interested in our existing presses for that. If so, what is the best way? Should ThinK be applied before, simultaneously with, or after the other press? (A toy sketch of one option follows.)
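To make that last question concrete, here is a minimal sketch (plain tensors only, not our press API; all shapes and names are illustrative) of sequential composition: prune key positions first with any sequence-dimension score, then prune key channels with the ThinK-style score:

```python
import torch

bsz, num_heads, seq_len, head_dim, window_size = 1, 8, 128, 64, 16
queries = torch.randn(bsz, num_heads, window_size, head_dim)
keys = torch.randn(bsz, num_heads, seq_len, head_dim)

# Step 1 (sequence press): keep the top 50% of key positions.
# A random score stands in for any sequence-dimension press here.
seq_score = torch.randn(bsz, num_heads, seq_len)
n_keep = seq_len // 2
seq_idx = seq_score.topk(n_keep, dim=-1).indices
keys = keys.gather(2, seq_idx.unsqueeze(-1).expand(-1, -1, -1, head_dim))

# Step 2 (ThinK-style channel press): score channels on the remaining keys.
channel_score = queries.pow(2).sum(dim=2) * keys.pow(2).sum(dim=2)
d_keep = head_dim // 2
chan_idx = channel_score.topk(d_keep, dim=-1).indices
keys = keys.gather(3, chan_idx.unsqueeze(2).expand(-1, -1, n_keep, -1))

print(keys.shape)  # torch.Size([1, 8, 64, 32])
```

Applying ThinK after the sequence press (as above) means the channel scores only reflect the keys that survive pruning, which intuitively seems preferable, but the ordering is exactly what I am asking about.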
Another question:

In the latter case, is the method compatible with flash attention? One solution would be to re-add zeros each time you need to use the key cache, but that might be inefficient (although it could be done in parallel with the previous forward pass). If you don't do that, you would need to write a new kernel for efficient attention.
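For concreteness, a hedged sketch of the re-add-zeros workaround (all names are illustrative, not the PR code): the pruned keys of width `d_keep` are scattered back into a zero-filled full-width tensor right before the attention call, so a standard kernel such as flash attention can consume them:

```python
import torch

def expand_pruned_keys(pruned_keys: torch.Tensor, chan_idx: torch.Tensor, head_dim: int) -> torch.Tensor:
    """Scatter pruned key channels back into a zero-filled full-width tensor.

    pruned_keys: (bsz, num_heads, seq_len, d_keep)
    chan_idx:    (bsz, num_heads, d_keep), indices of the kept channels
    returns:     (bsz, num_heads, seq_len, head_dim)
    """
    bsz, num_heads, seq_len, _ = pruned_keys.shape
    full = pruned_keys.new_zeros(bsz, num_heads, seq_len, head_dim)
    idx = chan_idx.unsqueeze(2).expand(-1, -1, seq_len, -1)
    return full.scatter(3, idx, pruned_keys)
```

Since a zeroed channel contributes exactly zero to every query-key dot product, this reproduces the zero-out behavior while storing the cache at the pruned width, at the cost of the extra scatter and memory traffic each time the cache is read.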
1bis. About

Am I correct on 3?
Since the score factorizes per channel, `channel_score[h, d] = (sum_i q_i[d]**2) * (sum_j k_j[d]**2)`, it can be computed without materializing the 5-D tensor:

```python
query_norm = query_states.pow(2).sum(dim=2)  # shape: (bsz, num_heads, head_dim)
key_norm = key_states.pow(2).sum(dim=2)  # shape: (bsz, num_heads, head_dim)
# Compute the channel score as the element-wise product
channel_score = query_norm * key_norm  # shape: (bsz, num_heads, head_dim)
```
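A quick sanity check (toy shapes, assumed names) that the factored product matches the original 5-D definition:

```python
import torch

bsz, num_heads, window_size, seq_len, head_dim = 1, 2, 4, 16, 8
query_states = torch.randn(bsz, num_heads, window_size, head_dim)
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)

# Original definition: sum of squared per-channel products
channel_attn = torch.matmul(
    query_states.permute(0, 1, 3, 2).unsqueeze(-1),
    key_states.transpose(2, 3).unsqueeze(-2),
)  # (bsz, num_heads, head_dim, window_size, seq_len)
score_full = channel_attn.pow(2).sum(dim=(-1, -2))

# Factored form: product of per-channel squared norms
score_fast = query_states.pow(2).sum(dim=2) * key_states.pow(2).sum(dim=2)

assert torch.allclose(score_full, score_fast, rtol=1e-4, atol=1e-4)
```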
I will have a look and propose some code based on your current PR.
Sorry, this is the link: https://github.com/SalesforceAIResearch/ThinK/blob/main/ThinK_kivi/models/llama_kivi_think.py. Thanks for your kind help!
Please refer to #20 moving forward.
Hi KVPress authors,
Great work on this project! I'm currently working on integrating our new key cache channel pruning method, ThinK, into your framework. To minimize changes to the existing repository structure, I've implemented our method in a synthetic manner by zeroing out the least important channels. Alternatively, we could adjust the implementation by modifying either `cache_utils` or the forward pass of the LLMs to accommodate different LLM architectures. Let me know if you have any suggestions or preferences regarding this integration approach.
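Concretely, the synthetic zeroing looks roughly like this (a simplified sketch with toy shapes, not the exact PR code):

```python
import torch

def zero_least_important_channels(query_states, key_states, keep_ratio=0.5):
    """Return a copy of key_states with the lowest-scoring channels zeroed per head.

    query_states: (bsz, num_heads, window_size, head_dim)
    key_states:   (bsz, num_heads, seq_len, head_dim)
    """
    head_dim = key_states.shape[-1]
    n_keep = int(head_dim * keep_ratio)
    # Per-channel score: product of squared norms over the sequence dimension
    score = query_states.pow(2).sum(dim=2) * key_states.pow(2).sum(dim=2)
    keep_idx = score.topk(n_keep, dim=-1).indices  # (bsz, num_heads, n_keep)
    mask = torch.zeros_like(score, dtype=torch.bool).scatter(-1, keep_idx, True)
    return key_states * mask.unsqueeze(2)  # broadcast the channel mask over seq_len
```

Since the zeroed channels contribute nothing to the attention logits, the rest of the modeling code runs unmodified; the actual memory saving would come from the `cache_utils` / forward-pass variant mentioned above.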
The code is here for your review.
BTW, I evaluated our method on RULER with Llama-3-8B-Instruct, pruning 50% of the key cache channels. The results are as follows:
| Task | String match |
|---|---|
| cwe | 98.96 |
| fwe | 94.67 |
| niah_multikey_1 | 99.4 |
| niah_multikey_2 | 97.6 |
| niah_multikey_3 | 99.6 |
| niah_multiquery | 99.55 |
| niah_multivalue | 97.9 |
| niah_single_1 | 96.8 |
| niah_single_2 | 98.2 |
| niah_single_3 | 99.8 |
| qa_1 | 75.8 |
| qa_2 | 59.4 |
| vt | 99.6 |