[Transform] Attention/Cache transforms #436
Conversation
This looks good, though I have a number of questions and minor suggestions.
# assumes only one model at a time
global _original_impl
😬 I don't want to delay things, but we should briefly consider whether there are alternative solutions.
I spent 20 minutes exploring this; it requires creating specialized `_ct_hooked_attention` functions and a specialized `QuantizedAttentionImpl`, which is more complexity than value added, imho.
Can `_original_impl` be registered at the module level (i.e. on each `self_attn` block) instead of setting a global var?
Sure, but in order to register the `_original_impl`, it needs to be gotten from somewhere. The first time, you "get" it from `model.config`. However, on subsequent calls, `model.config` is overridden. This means that in order to "get" the original implementation, you'd have to go find the last attention module you registered it to, or else store it in some global store.
You could register it on the model module itself or something like that, but I think that's less reliable than a global store. If it's that functionality you're after, we can turn it into a hash table or something, keyed by model hash.
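For illustration, here is a minimal sketch of the hash-table alternative floated here, assuming only the original implementation name needs to be remembered; all names are hypothetical and this is not the PR's actual code:

```python
from weakref import WeakKeyDictionary

from torch import nn

# Maps each hooked model to the attention implementation name it had before
# hooking, so more than one model could be hooked at a time without a single
# module-level variable. (Hypothetical; the PR uses a single `_original_impl`.)
_original_impls: "WeakKeyDictionary[nn.Module, str]" = WeakKeyDictionary()


def remember_original_impl(model: nn.Module) -> None:
    # only record on the first call; afterwards model.config is overridden
    if model not in _original_impls:
        _original_impls[model] = model.config._attn_implementation


def get_original_impl(model: nn.Module) -> str:
    return _original_impls[model]
```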
Just trying to figure out what the lifetime of `_original_impl` is. Once everything is set, can it basically be treated as no longer necessary? Or is it something that's important for the entire duration, during model loading as well as during any forward passes?
This is all net-new, so it's not like we will be breaking anything pre-existing. Global vars make me nervous, but this seems like a legitimate enough use case to use them and accept the risk.
If it makes you feel any better, this variable is still only scoped to this file. This is the same as any module-scoped read, only this time we're writing to it, and therefore need the `global` keyword.
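A minimal sketch of that scoping pattern (illustrative only, not the actual file contents):

```python
# module-scoped: readable anywhere in this file, not exported
_original_impl: str = "eager"


def save_original_impl(model) -> None:
    global _original_impl  # needed because we assign to the module-level name
    _original_impl = model.config._attn_implementation


def restore_original_impl(model) -> None:
    # plain reads of the module-level name need no `global` declaration
    model.config._attn_implementation = _original_impl
```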
If the goal is to use this generally for kv_cache and attention quantization, can we move `initialize_hooked_attention` and `initialize_hooked_kv_cache` to `initialize.py`?
I understand we haven't hooked them in yet for those workflows, but I think these belong there.
Do a pass through on any missing docstrings, otherwise LGTM.
Nice work.
Following for the most part. A few clarifications, but this makes sense to me
Last nightly worked, but e2e failed due to model storage issues.
We can resolve the global var thread. I have another new comment that we might want to consider in a follow-up, but I'm marking this as approved. Cool stuff! Excited to see it in action.
# use any status from modules (in practice, use the last module)
model_status = None
This seems like something that might've been missed in the scoped quant work. If multiple statuses are found, rather than just using the last one found, don't we want to set `format="mixed-precision"` in the returned `QuantizationConfig`?
Essentially, you're right. However, this value is effectively meaningless, as it is later overridden by the model compressor.
I've tried to leave the functionality of this function as unchanged as possible for now
Sounds good, just wanted to see if we should create a ticket to fix in follow-up
I'd eventually like to remove this automatic inference altogether and simply use the config specified by `apply`. That sort of refactor would allow the config to retain the same meaningful scheme names that the user provided, be much simpler to read, and avoid all this repackaging logic.
See #494
Just some questions. Otherwise, LGTM
if scheme.weights is not None:
    raise ValueError(
        "Cannot apply weight quantization to attention. "
        "Instead, target (q|k|v)_proj"
This error doesn't make a lot of sense / it took me a while to realize you're saying that if you want to do weight quantization, you should target the linear layers in the attn block, not attention itself.
Is this clearer?
raise ValueError(
    "Cannot apply weight quantization to attention. "
    "Instead, target the (q|k|v)_proj submodule layers of attention"
)
""" | ||
if not hasattr(module, KV_CACHE_ATTR): | ||
module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module)) | ||
module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True) |
If I'm reading this correctly, `_kv_cache_attention_hook` is called before every forward pass? So we're replacing the kv_cache before every forward pass with the new quantized cache?
Yes, that's exactly correct. I've buffed up the docstrings to make this clearer.
`QuantizedKVCache` injects itself into the model by overriding the `past_key_values` input kwarg to attention and wrapping the functionality of the original cache.
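To make the mechanics concrete, a minimal sketch of that pre-hook pattern (attribute and constant names are assumptions, not the exact implementation):

```python
import torch

KV_CACHE_ATTR = "kv_cache"  # assumed name of the registered QuantizedKVCache submodule


def _kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
    # runs before every attention forward (registered with with_kwargs=True)
    quantized_cache = getattr(module, KV_CACHE_ATTR)
    # remember the cache that was passed in so the wrapper can delegate to it
    quantized_cache.past_key_values = kwargs.get("past_key_values", None)
    # hand attention the quantized wrapper in place of the original cache
    kwargs["past_key_values"] = quantized_cache
    return args, kwargs
```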
# ----- hooks ----- #


def register_key_hook(
I can't seem to find where the key / value hooks get registered
These hooks are used to attach observer hooks (and any other hooks we might want to add in the future); see here.
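As a hedged illustration, a calibration observer might be attached through these hooks roughly like this (the observer attribute and hook signature are assumptions):

```python
def calibrate_key_hook(module, key_states):
    # update quantization statistics for the keys produced by this attention module
    module.k_observer(key_states)  # hypothetical observer submodule
    return key_states


# hypothetical registration on a single attention module:
# register_key_hook(attention_module, calibrate_key_hook)
```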
  # infer format
  if format is None:
-     if quantization_status == QuantizationStatus.COMPRESSED:
+     if model_status == QuantizationStatus.COMPRESSED:
I know this is unrelated, but defaulting to int doesn't make a lot of sense either.
I agree. This was the original behavior of this logic.
quantization_status = None
ignore = {}
quantization_type_names = set()
from compressed_tensors.quantization.lifecycle.initialize import (
Thanks for cleaning this up. It doesn't seem like we're adding anything here, apart from how we're fetching the kv_cache scheme?
I still find our ignore logic very confusing
I entirely agree; I've created an issue to track potential removal: #494.
This PR does not change behavior, only makes the existing logic easier to read and adds this line to infer the kv cache scheme:
# attention quantization implies kv cache quantization
if is_attention_module(submodule):
    kv_cache_scheme = submodule.quantization_scheme.input_activations
For the sake of completeness, do you mind adding your kv_cache and attn quantized sample models to this PR description?
    )
else:
    ret = (key_states, value_states)
self.past_key_values = None
Why do we set this to None?
It ensures that the cache is only used once. This should theoretically never be a problem, since the `self.past_key_values` attribute is always written to by `_kv_cache_attention_hook`, but this is done just for peace of mind and to avoid dangling references, even if they are weak.
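A condensed sketch of that single-use pattern, with the method shape assumed from the snippet quoted above:

```python
def update(self, key_states, value_states, *args, **kwargs):
    # delegate to the wrapped cache if the pre-hook populated it for this call
    if self.past_key_values is not None:
        ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
    else:
        ret = (key_states, value_states)

    # drop the reference so the wrapped cache is only ever used once per forward;
    # the attention pre-hook repopulates it on the next call
    self.past_key_values = None
    return ret
```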
Purpose

Prerequisites

Changes

New Classes
- `QuantizedAttentionImpl` injects itself into the model by registering a new attention implementation called `ct_hooked_attention` and overriding `model.config._attn_implementation` to be the new implementation name (a sketch of this pattern follows at the end of this description)
- `QuantizedKVCache` injects itself into the model by overriding the `past_key_values` input kwarg to attention, and wrapping the functionality of the original cache
- `register_query_hook`, `register_key_hook`, `register_value_hook`

Quantization Lifecycle Changes
- `initialize_hooked_kv_cache` and `initialize_hooked_attention`, if attention modules are explicitly targeted (see `is_narrow_match`), are called from `initialize_module_for_quantization`
- `QuantizationConfig.from_pretrained` was cleaned up with additional comments
  - a `kv_cache_scheme` field is added if there are any attention modules with a `quantization_scheme` attached

Helpers
- `is_narrow_match` is used to check that attention modules are being specifically targeted (rather than targeting all modules in a layer)
- `get_num_attn_heads`, `get_num_kv_heads`, `get_head_dim` get attention config values from the config

Testing
- Tests for `is_narrow_match`
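As referenced in the New Classes list, here is a hedged sketch of the attention-implementation injection pattern; the registry usage reflects my understanding of recent transformers versions and is not the PR's exact code:

```python
from transformers import AttentionInterface, AutoModelForCausalLM
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # placeholder model id

# remember the implementation the model was loaded with (e.g. "sdpa");
# assumes that name resolves in the registry ("eager" may need special handling)
_original_impl = model.config._attn_implementation


def ct_hooked_attention(module, query, key, value, *args, **kwargs):
    # query/key/value hooks (e.g. observers) would run here, then delegate
    # to the original attention implementation
    original_fn = ALL_ATTENTION_FUNCTIONS[_original_impl]
    return original_fn(module, query, key, value, *args, **kwargs)


# register the wrapped implementation and point the model at it
AttentionInterface.register("ct_hooked_attention", ct_hooked_attention)
model.config._attn_implementation = "ct_hooked_attention"
```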