
Add ThinK_Press #18

Closed
wants to merge 3 commits into from

Conversation

@yuhuixu1993 commented Nov 28, 2024

Hi KVPress authors,

Great work on this project! I’m currently working on integrating our new key cache channel pruning method, ThinK, into your framework. To minimize changes to the existing repository structure, I’ve implemented our method in a synthetic manner by zeroing out the least important channels. Alternatively, we could adjust the implementation by modifying either the cache_utils or the forward pass of the LLMs to accommodate different LLM architectures. Let me know if you have any suggestions or preferences regarding this integration approach.

The code is here for your review.

BTW, I evaluated our method on RULER with Llama-3-8B-Instruct with 50% of the key cache channels pruned. The results are as follows:
{'cwe': {'string_match': 98.96}, 'fwe': {'string_match': 94.67}, 'niah_multikey_1': {'string_match': 99.4}, 'niah_multikey_2': {'string_match': 97.6}, 'niah_multikey_3': {'string_match': 99.6}, 'niah_multiquery': {'string_match': 99.55}, 'niah_multivalue': {'string_match': 97.9}, 'niah_single_1': {'string_match': 96.8}, 'niah_single_2': {'string_match': 98.2}, 'niah_single_3': {'string_match': 99.8}, 'qa_1': {'string_match': 75.8}, 'qa_2': {'string_match': 59.4}, 'vt': {'string_match': 99.6}}
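
For illustration, here is a minimal sketch of the zero-out ("synthetic") approach described above, assuming a per-channel score of shape (bsz, num_key_value_heads, head_dim) has already been computed; the names zero_out_channels and pruning_ratio are illustrative, not the actual PR code.

import torch

def zero_out_channels(key_states: torch.Tensor, channel_score: torch.Tensor, pruning_ratio: float = 0.5) -> torch.Tensor:
    # key_states:    (bsz, num_key_value_heads, seq_len, head_dim)
    # channel_score: (bsz, num_key_value_heads, head_dim)
    head_dim = key_states.shape[-1]
    n_pruned = int(head_dim * pruning_ratio)
    # Indices of the least important channels for each head
    pruned_idx = channel_score.topk(n_pruned, dim=-1, largest=False).indices
    mask = torch.ones_like(channel_score, dtype=torch.bool)
    mask.scatter_(-1, pruned_idx, False)  # False where a channel is pruned
    # Broadcast the mask over the sequence dimension and zero the pruned channels
    return key_states * mask.unsqueeze(-2)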

@SimJeg (Collaborator) commented Nov 28, 2024

Hi @yuhuixu1993 !

Thanks for proposing your work and congrats on the promising results on RULER! Before commenting on the implementation, let's make sure I understand it. ThinK is a method that compresses the channel dimension of the KV cache, as opposed to the sequence dimension targeted by many other works.

Similarly to other presses, you define a score, this time with shape (bsz, num_key_value_heads, head_dim) as opposed to the usual (bsz, num_key_value_heads, seq_len). You then prune the channels based on this score.

The score is defined as

# query_states: (bsz, num_heads, window_size, head_dim), key_states: (bsz, num_heads, seq_len, head_dim)
channel_attn = torch.matmul(query_states.permute(0, 1, 3, 2).unsqueeze(-1), key_states.transpose(2, 3).unsqueeze(-2))
# shape: (bsz, num_heads, head_dim, window_size, seq_len)
channel_score = channel_attn.pow_(2).sum(dim=(-1, -2))
# shape: (bsz, num_heads, head_dim)

where the query_states are associated with the last 64 tokens as in SnapKV. Note that permute(0, 1, 3, 2) can be replaced by transpose(2, 3) for the sake of code uniformity.

Interpretation: for a given head $h$ and dimension $d$, channel_score is the sum of $(q_{i,d} \, k_{j,d})^2$ over all query positions $i$ and key positions $j$, and hence measures the importance of dimension $d$.

A few questions:

  • doesn't channel_attn become prohibitively large for very long sequences?
  • would you also like to implement value cache pruning, as mentioned in appendix D of your paper?

I think you might be interested in our ExpectedAttentionPress and the associated notebook. It might avoid the need to materialize channel_attn.

It seems ThinK can be combined with other presses that focus on the sequence dimension. If so, what is the best way: apply ThinK before, simultaneously with, or after the other press?

@SimJeg (Collaborator) commented Nov 28, 2024

Another question:

  • If you zero the pruned dimensions in the key, you reproduce the ThinK outputs but do not reduce the memory
  • If you don't, you indeed need to change the cache implementation to support both compressed and uncompressed keys

In the latter case, is the method compatible with flash attention? One solution would be to re-add zeros each time you need to use the key cache, but it might be inefficient (although this could be done in parallel with the previous forward pass). If you don't do that, you need to write a new kernel for efficient attention.
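
For what it's worth, here is a rough sketch of the "re-add zeros" option, assuming the kept channel indices are stored alongside the compressed keys; the names expand_pruned_keys and kept_idx are hypothetical, not an existing kvpress API.

import torch

def expand_pruned_keys(compressed_keys: torch.Tensor, kept_idx: torch.Tensor, head_dim: int) -> torch.Tensor:
    # compressed_keys: (bsz, num_key_value_heads, seq_len, n_kept)
    # kept_idx:        (bsz, num_key_value_heads, n_kept), int64 indices of the kept channels
    bsz, num_heads, seq_len, n_kept = compressed_keys.shape
    full_keys = compressed_keys.new_zeros(bsz, num_heads, seq_len, head_dim)
    # Repeat the channel indices over the sequence dimension and scatter the kept channels back
    idx = kept_idx.unsqueeze(-2).expand(bsz, num_heads, seq_len, n_kept)
    full_keys.scatter_(-1, idx, compressed_keys)
    return full_keys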

@yuhuixu1993 (Author) commented Nov 28, 2024

@SimJeg

  1. Hi, thanks for your suggestions; I did not notice the size of channel_attn. I will consider the ExpectedAttentionPress.
    I am considering using the following implementation (a quick numerical check that it matches the channel_attn score is sketched at the end of this comment):
query_norm = query_states.pow(2).sum(dim=-2)  # Shape: (bsz, num_heads, head_dim)
key_norm = key_states.pow(2).sum(dim=-2)      # Shape: (bsz, num_heads, head_dim)

# Compute the channel score as the element-wise product
channel_score = query_norm * key_norm         # Shape: (bsz, num_heads, head_dim)
  2. Yes, I can implement value cache pruning.
  3. Zeroing the cache does not reduce the memory, so we would need to store the pruned and unpruned key cache separately. Yes, one solution is to add the zeros back; another is to also prune the query, although that may not support flash attention (this implementation may involve modifying the forward function of the attention). We can support flash attention for the first token generation, just like the flash implementation of KIVI.
  4. Can we just use the synthetic version (zero-out version)? It would be more complex to add the real implementation to the current repo.
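
For reference, a quick numerical check (random tensors, made-up shapes) that this element-wise product of squared norms matches the original channel_attn score, using the identity $\sum_{i,j} (q_{i,d} k_{j,d})^2 = (\sum_i q_{i,d}^2)(\sum_j k_{j,d}^2)$:

import torch

bsz, num_heads, window_size, seq_len, head_dim = 1, 2, 4, 16, 8  # made-up shapes
query_states = torch.randn(bsz, num_heads, window_size, head_dim)
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)

# Naive score: materializes (bsz, num_heads, head_dim, window_size, seq_len)
channel_attn = torch.matmul(query_states.transpose(2, 3).unsqueeze(-1), key_states.transpose(2, 3).unsqueeze(-2))
naive_score = channel_attn.pow(2).sum(dim=(-1, -2))

# Efficient score: sum of squares over the token dimension, then an element-wise product
efficient_score = query_states.pow(2).sum(dim=-2) * key_states.pow(2).sum(dim=-2)

assert torch.allclose(naive_score, efficient_score, rtol=1e-4)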

@SimJeg (Collaborator) commented Nov 28, 2024

  1. About the size of channel_attn: how did you implement it for the figures reported in the paper (Figure 3)? For instance, for Llama-3.1-8B, channel_attn has size (bsz, num_heads, head_dim, window_size, seq_len), i.e. (bsz, 32, head_dim, 64, seq_len) for a single layer, while the full key cache for all layers, of shape (bsz, num_layers, num_key_value_heads, head_dim, seq_len), is (bsz, 32, 8, head_dim, seq_len). So channel_attn for one layer is 8x bigger than the full key cache. Maybe I am missing something here (see the quick size check at the end of this comment).

1bis. About ExpectedAttention: this is not mandatory of course, just some self-promotion. The goal is to reproduce ThinK, not to invent something new!

  2. Again, not mandatory; we can start with the key cache only if you prefer.

  3. Got it, so there are three options:

  • 3a. Add zeros in the key_cache. Pros: easy to implement. Cons: does not reduce memory.
  • 3b. Custom Cache class with a zero-back function during decoding. Pros: quite easy to implement. Cons: might have some compute overhead to add back the zeros.
  • 3c. Custom Cache + custom Attention layer that re-implements the attention as proposed above, using a modification of the eager attention. Pros: better for memory. Cons: likely a heavier implementation.
  4. You did not answer about the combination with other presses. How do you envision it? Do you want to also prune the sequence length, or do "pure" ThinK and only prune dimensions?

Am I correct on 3?
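
For item 1, a quick element-count check, assuming Llama-3.1-8B dimensions (32 layers, 32 query heads, 8 KV heads, head_dim 128, a 64-token window) and identical dtypes:

num_layers, num_heads, num_kv_heads, head_dim, window = 32, 32, 8, 128, 64
bsz, seq_len = 1, 100_000  # arbitrary

# channel_attn for a single layer: (bsz, num_heads, head_dim, window_size, seq_len)
channel_attn_elems = bsz * num_heads * head_dim * window * seq_len
# full key cache across all layers: (bsz, num_layers, num_key_value_heads, head_dim, seq_len)
key_cache_elems = bsz * num_layers * num_kv_heads * head_dim * seq_len

print(channel_attn_elems / key_cache_elems)  # 8.0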

@yuhuixu1993 (Author) commented Nov 28, 2024

@SimJeg

  1. Yes, I used this implementation. But Figure 3 was tested following the settings of KIVI, which only involve an input prompt length of 160 and an output length of 338, and we tested on Llama-2, so there is no GQA. Many thanks for helping me fix this issue! This is the link to our implementation: https://github.com/yuhuixu1993/ThinK/tree/main/ThinK_kivi
    It will be much better if we change to the following implementation:
query_norm = query_states.pow(2).sum(dim=-2)  # Shape: (bsz, num_heads, head_dim)
key_norm = key_states.pow(2).sum(dim=-2)      # Shape: (bsz, num_heads, head_dim)

# Compute the channel score as the element-wise product
channel_score = query_norm * key_norm         # Shape: (bsz, num_heads, head_dim)
  2. Thanks, I think it is better to start with key cache pruning.
  3. Yes, you are correct.
  4. Yes, we can combine it with token eviction like SnapKV. In my experiments I first pruned the tokens and then pruned the channels. Both options are OK with me (also pruning the sequence length, or doing "pure" ThinK and only pruning dimensions), but I think we can have pure ThinK first.

@SimJeg (Collaborator) commented Nov 28, 2024

  1. Your repository is private so I can't look at your code. Your edited message contains a much more efficient implementation!
  2. OK.
  3. I propose to start with the zero solution, and I will submit a new PR if there are significant changes.
  4. In that case, ThinKPress can take another press as input.

I will have a look and propose some code based on your current PR.
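
As a possible shape for point 4, a hypothetical sketch (not the actual kvpress API) of a press that first applies an inner sequence-dimension press and then zeroes the least important key channels; all names here are illustrative, and the shapes assume no GQA for simplicity.

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch

# A press step maps (keys, values, queries) to compressed (keys, values)
PressFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]

@dataclass
class ThinKStyleWrapper:
    inner_press: Optional[PressFn] = None  # e.g. a SnapKV-style token-eviction press
    channel_ratio: float = 0.5             # fraction of key channels to zero out

    def __call__(self, keys, values, queries):
        # keys, values, queries: (bsz, num_heads, seq_len, head_dim), same head count assumed
        if self.inner_press is not None:
            keys, values = self.inner_press(keys, values, queries)
        # Squared-norm channel score discussed above: (bsz, num_heads, head_dim)
        score = queries.pow(2).sum(dim=-2) * keys.pow(2).sum(dim=-2)
        n_pruned = int(keys.shape[-1] * self.channel_ratio)
        idx = score.topk(n_pruned, dim=-1, largest=False).indices
        mask = torch.ones_like(score, dtype=torch.bool).scatter_(-1, idx, False)
        return keys * mask.unsqueeze(-2), values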

@yuhuixu1993 (Author)

Sorry, this is the link: https://github.com/SalesforceAIResearch/ThinK/blob/main/ThinK_kivi/models/llama_kivi_think.py. Thanks for your kind help!

@SimJeg mentioned this pull request Nov 28, 2024
@SimJeg closed this Nov 28, 2024
@SimJeg (Collaborator) commented Nov 28, 2024

Please refer to #20 moving forward.
