Add ThinKPress #20
Conversation
@yuhuixu1993 here is a proposal for the implementation of ThinK requested in #18.
Please confirm this press is what you want. I explicitly mentioned you in the docstring as a reviewer.
@SimJeg, it looks awesome! Thanks for your hard work! Sorry for the late response, there is a time difference (I live in Singapore). I notice that in this implementation channel pruning is applied to all tokens in the key cache; we would prefer to keep the most recent tokens (e.g. 32) unchanged. I have not tested the performance of the current evaluations when all tokens are pruned. Something like: `keys = torch.cat([keys[:, :, :q_len - self.window_size, :].masked_fill(mask_k, 0), keys[:, :, q_len - self.window_size:, :]], dim=-2)`
I will update. Are you using 32 for the window size too? (I've been using 64, the default in SnapKV.)
Thanks. Yes, I use 32 in the paper, but 64 is definitely OK.
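A minimal sketch of the window-preserving variant discussed above, assuming keys of shape `(batch, num_kv_heads, q_len, head_dim)` and a boolean channel mask `mask_k` that broadcasts over the non-window tokens; the helper name is illustrative, not part of the PR:

```python
import torch


def prune_key_channels(keys: torch.Tensor, mask_k: torch.Tensor, window_size: int) -> torch.Tensor:
    # keys: (batch, num_kv_heads, q_len, head_dim)
    # mask_k: boolean mask of pruned channels, broadcastable to the non-window slice
    q_len = keys.shape[-2]
    pruned = keys[:, :, : q_len - window_size, :].masked_fill(mask_k, 0)
    recent = keys[:, :, q_len - window_size :, :]  # most recent tokens left untouched
    return torch.cat([pruned, recent], dim=-2)
```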
Thanks a lot for this PR.
From the technical side, @yuhuixu1993 already gave some great feedback.
I left a comment regarding RoPE computation where I'm not sure about the dimensions.
Apart from that, it would be great to add an additional test covering the inner_press functionality as well (a rough sketch of such a test is below).
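A rough sketch of what such a test could look like, assuming ThinKPress accepts an `inner_press` argument as discussed in this thread and that presses are applied as context managers; the model name, compression ratios, and prompt are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from kvpress import SnapKVPress, ThinKPress


def test_think_press_with_inner_press():
    # Tiny random model used purely to exercise the press, not to measure quality
    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Compose ThinK channel pruning with SnapKV token pruning via inner_press
    press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))

    inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
    with press(model), torch.no_grad():
        model(**inputs)
```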
@yuhuixu1993 I tried with and without zeroing the channels for the last 32 tokens and did not see any difference on the prompt I tried. May I keep the current version? I'm asking because if other similar presses come, it would be nice to have a uniform API rather than very custom changes. Also, to implement what you ask we would need to slightly adjust the compression ratio to account for the 32 * n_pruned_channels elements that are not removed.
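For reference, a back-of-the-envelope sketch of the adjustment mentioned above, under the assumption that the last `window_size` tokens keep all their channels (variable names are illustrative):

```python
def effective_compression_ratio(ratio: float, q_len: int, window_size: int = 32) -> float:
    # `ratio` is the fraction of key channels zeroed for the first q_len - window_size tokens;
    # the last window_size tokens are untouched, so the effective ratio is slightly lower.
    return ratio * (q_len - window_size) / q_len
```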
@SimJeg I think the current version is OK, as the performance is fine. Many thanks for the experiments!
Thanks a lot, LGTM!
I checked whether `module.rotary_emb(query_states, q_len)` is still needed, and it seems that
- mixtral
- gpt neox
- open_llama
- idefics (cross attention)
are still using it. Since transformers is converging towards `module.rotary_emb(query_states, position_ids)`, I don't think we need to support these models.
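For illustration, a hedged sketch of the newer call pattern, assuming the standard `apply_rotary_pos_emb` helper from transformers; the function name and shapes are illustrative only:

```python
import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def rotate_queries(module, query_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # query_states: (batch, num_heads, q_len, head_dim)
    cos, sin = module.rotary_emb(query_states, position_ids)
    # apply_rotary_pos_emb expects (q, k, cos, sin); we only need the rotated queries here
    query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin)
    return query_states
```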
Add ThinKPress (NVIDIA#20)
Implementation of ThinK following #18