Importance weight based sparse attention implementation for auto-regressive decoding. #2
Transformers are powerful sequence models, but their time and memory requirements grow quadratically with the sequence length. To support longer input contexts, many research efforts have focused on reducing the KV cache size and speeding up model inference.
This PR implements a relatively simple way to limit the KV cache size, inspired by the findings in https://arxiv.org/abs/2305.17118. It adds a weight-based eviction policy on top of the existing circular (sliding-window) cache eviction policy. Instead of keeping only the most recent k keys and values, we also make sure the k highest-weighted keys and values are not dropped when the cache reaches its limit. The weight is simply the Q*K attention score computed in the previous decoding step.
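To make the selection rule concrete, here is a minimal sketch of how the retained cache positions could be chosen. The function and parameter names (`select_kv_to_keep`, `local_k`, `heavy_k`) are hypothetical and only illustrate the policy described above; they are not the identifiers used in this PR.

```python
import torch

def select_kv_to_keep(scores: torch.Tensor, cache_len: int,
                      local_k: int, heavy_k: int) -> torch.Tensor:
    """Illustrative sketch (not the PR's actual code): return the indices of
    cached positions to retain when the KV cache is full.

    scores  -- per-position importance weights from the previous decoding
               step (e.g. the Q*K attention scores), shape [cache_len]
    local_k -- number of most recent positions that are always kept
    heavy_k -- number of additional highest-weighted older positions kept
    """
    assert scores.shape[-1] == cache_len
    assert 0 < local_k <= cache_len
    keep = torch.zeros(cache_len, dtype=torch.bool)
    # Always retain the most recent local_k entries (the circular/local window).
    keep[-local_k:] = True
    # Among the older entries, retain the heavy_k with the largest
    # importance weights so they are not evicted.
    older_scores = scores.clone()
    older_scores[-local_k:] = float("-inf")  # exclude the local window
    k = max(0, min(heavy_k, cache_len - local_k))
    keep[torch.topk(older_scores, k=k).indices] = True
    return keep.nonzero(as_tuple=False).squeeze(-1)

# Example: a cache of 8 positions, keeping the 3 most recent entries plus
# the 2 highest-weighted older entries.
scores = torch.tensor([0.9, 0.1, 0.4, 0.05, 0.7, 0.2, 0.3, 0.6])
print(select_kv_to_keep(scores, cache_len=8, local_k=3, heavy_k=2))
# tensor([0, 4, 5, 6, 7])
```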
Empirical results measured on a few public datasets show that this simple sparse attention policy can greatly improve completion speed while retaining most of the completion quality. Please feel free to contact me if you are interested in the details.