Add ThinKPress #20

Merged · 4 commits merged into main from simon/think-press on Dec 3, 2024
Conversation

@SimJeg (Collaborator) commented Nov 28, 2024

Implementation of ThinKPress following #18

@SimJeg mentioned this pull request Nov 28, 2024
@SimJeg (Collaborator, Author) commented Nov 28, 2024

@yuhuixu1993 here is a proposal for the implementation of ThinKPress based on our previous discussions in #18:

  • As in your proposal, I'm zeroing the pruned dimensions
  • I added support to optionally combine it with any other press, and in my first experiments it worked great! As you requested, the inner press is applied before ThinK (a usage sketch follows at the end of this comment)
  • I also ensured support for quantization

Please confirm this press is what you want. I explicitly mentioned you in the docstring as a reviewer.

ThinKPress is the first press in this repo that compresses the channel dimension, hence the code is a bit more complex. If other similar presses are proposed, we will refactor the code to make the implementation easier (e.g. we could imagine a SequenceBasePress and a DimensionBasePress with options to compose them one way or the other).
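
For reference, here is a minimal sketch of the composition described above. The parameter names (compression_ratio, inner_press) and the context-manager usage are assumptions based on this discussion, not necessarily the merged signature; check the final code.

```python
# Sketch of composing ThinKPress with an inner press (assumed parameter names).
from kvpress import SnapKVPress, ThinKPress

inner = SnapKVPress(compression_ratio=0.5)    # prunes along the sequence dimension
press = ThinKPress(compression_ratio=0.5,     # prunes along the key channel dimension
                   inner_press=inner)         # the inner press is applied before ThinK

# Typical kvpress usage: register the press around a forward/generate call.
# with press(model):
#     outputs = model.generate(**inputs)
```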

@SimJeg requested a review from maxjeblick November 28, 2024 17:42
@yuhuixu1993 commented Nov 29, 2024

@SimJeg, it looks awesome!!! Thanks for your hard work!! Sorry for the late response, we have a time difference (I live in Singapore). I notice that in this implementation channel pruning is applied to all tokens in the key cache, whereas we prefer to keep the most recent tokens (e.g. 32) unchanged. I have not tested the performance on the current evaluations when all tokens are pruned.
In my previous PR:

keys = torch.cat([keys[:, :, :q_len - self.window_size, :].masked_fill(mask_k, 0), keys[:, :, q_len - self.window_size:, :]], dim=-2)
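# For clarity: the pruned channels are zeroed only for tokens before the recent window;
# the last window_size tokens are re-appended unchanged.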

@SimJeg (Collaborator, Author) commented Dec 2, 2024

I will update. Are you using 32 for the window size too? (I've been using 64, as in the default of SnapKV.)

@yuhuixu1993 commented
Thanks. Yes, I use 32 in the paper, but 64 is definitely OK.

@maxjeblick (Collaborator) left a comment

Thanks a lot for this PR.
From the technical side, @yuhuixu1993 already gave some great feedback.

I left a comment regarding RoPE computation where I'm not sure about the dimensions.

Apart from that, it would be great to add a test that covers the inner_press functionality as well (a rough sketch is below).
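
A rough sketch of such a test; the unit_test_model fixture and the constructor arguments are assumptions for illustration, not the repo's actual test suite.

```python
# Hypothetical test sketch for the inner_press functionality (fixture and parameter
# names are assumed, not taken from the repo).
import torch
from kvpress import SnapKVPress, ThinKPress


def test_think_press_with_inner_press(unit_test_model):  # hypothetical small-model fixture
    press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))
    input_ids = torch.randint(0, unit_test_model.config.vocab_size, (1, 64))
    with press(unit_test_model):
        unit_test_model(input_ids)  # forward pass should run with both presses active
```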

@SimJeg (Collaborator, Author) commented Dec 3, 2024

@yuhuixu1993 I tried with and without zeroing the channels for the last 32 tokens and did not see any difference on the prompt I tried. May I keep the current version? I'm asking because, if other similar presses come, it would be nice to have a uniform API rather than very custom changes. Also, to implement what you ask, we would need to slightly adjust the compression ratio to account for the 32 * n_pruned_channels elements that are not removed (sketched below).
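
For reference, one way to express that adjustment (a sketch, not code from this PR):

```python
# Sketch of the adjustment mentioned above: if the last window_size tokens keep all
# their channels, the fraction of key elements actually removed is smaller than the
# nominal channel compression ratio.
def effective_key_compression(nominal_ratio: float, n_tokens: int, n_channels: int,
                              window_size: int = 32) -> float:
    n_pruned_channels = int(nominal_ratio * n_channels)
    removed = n_pruned_channels * max(n_tokens - window_size, 0)
    return removed / (n_channels * n_tokens)
```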

@yuhuixu1993 commented
@SimJeg I think the current version is OK, as the performance is fine. Many thanks for the experiments!!

@maxjeblick (Collaborator) left a comment

Thanks a lot, LGTM!

I checked whether module.rotary_emb(query_states, q_len) is needed, and it seems that

  • mixtral
  • gpt neox
  • open_llama
  • idefics (cross attention)

are using it. As transformers converges to using module.rotary_emb(query_states, position_ids), I don't think we need to support these models (a sketch of the targeted convention is below).
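
For reference, a sketch of the position_ids-based convention mentioned above; the recompute_rope helper is hypothetical, with module being the attention layer and query_states its query tensor.

```python
# Hypothetical helper illustrating the position_ids-based rotary embedding call
# (current transformers convention); not code from this PR.
import torch


def recompute_rope(module, query_states):
    q_len = query_states.shape[-2]
    position_ids = torch.arange(q_len, device=query_states.device).unsqueeze(0)
    cos, sin = module.rotary_emb(query_states, position_ids)
    return cos, sin
```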

@SimJeg merged commit ac2445e into main Dec 3, 2024
2 checks passed
@SimJeg deleted the simon/think-press branch December 3, 2024 15:29
FFY0 added a commit to FFY0/AdaKV-in-NVIDIA-kvpress that referenced this pull request Dec 6, 2024