@@ -152,6 +152,10 @@ def resolved_targets(self) -> Set[str]:
152152 for config_group in self .resolved_config .config_groups .values ():
153153 for target in config_group .targets :
154154 targets .add (target )
155+
156+ if self .resolved_config .kv_cache_scheme is not None :
157+ targets .add ("re:.*self_attn$" )
158+
155159 return targets
156160
157161 def initialize_quantization (self , model : torch .nn .Module ):
@@ -177,9 +181,9 @@ def start_calibration(self, model: torch.nn.Module):
177181
178182 :param model: model to prepare for calibration
179183 """
180- self ._calibration_hooks = self ._initialize_hooks (model )
181184 for _ , module in match_named_modules (model , self .resolved_targets , self .ignore ):
182185 self ._initialize_observers (module )
186+ self ._calibration_hooks |= self ._initialize_hooks (module )
183187 apply_calibration_status (module )
184188
185189 model .apply (enable_quantization ) # quantize at the same time as calibrate
@@ -284,35 +288,34 @@ def _initialize_observers(self, module: torch.nn.Module):
284288 if output :
285289 initialize_observer (module , base_name = "output" )
286290
287- def _initialize_hooks (self , model : torch .nn .Module ) -> Set [RemovableHandle ]:
291+ def _initialize_hooks (self , module : torch .nn .Module ) -> Set [RemovableHandle ]:
288292 hooks = set ()
289- for _ , module in match_named_modules (model , self .resolved_targets , self .ignore ):
290- if not hasattr (module , "quantization_scheme" ):
291- continue
293+ if not hasattr (module , "quantization_scheme" ):
294+ hooks
292295
293- scheme : QuantizationScheme = module .quantization_scheme
294- input = scheme .input_activations and scheme .input_activations .dynamic in (
295- False ,
296- DynamicType .LOCAL ,
297- )
298- output = scheme .output_activations and not scheme .output_activations .dynamic
299- is_attention = is_attention_module (module )
300-
301- # input activations
302- if input :
303- if not is_attention :
304- hooks .add (
305- self .register_hook (module , calibrate_input_hook , "forward_pre" )
306- )
307- else :
308- if hasattr (module , IMPL_ATTR ):
309- hooks .add (register_query_hook (module , calibrate_query_hook ))
310- if hasattr (module , KV_CACHE_ATTR ):
311- hooks .add (register_key_hook (module , calibrate_key_hook ))
312- hooks .add (register_value_hook (module , calibrate_value_hook ))
313-
314- # output activations
315- if output :
316- hooks .add (self .register_hook (module , calibrate_output_hook , "forward" ))
296+ scheme : QuantizationScheme = module .quantization_scheme
297+ input = scheme .input_activations and scheme .input_activations .dynamic in (
298+ False ,
299+ DynamicType .LOCAL ,
300+ )
301+ output = scheme .output_activations and not scheme .output_activations .dynamic
302+ is_attention = is_attention_module (module )
303+
304+ # input activations
305+ if input :
306+ if not is_attention :
307+ hooks .add (
308+ self .register_hook (module , calibrate_input_hook , "forward_pre" )
309+ )
310+ else :
311+ if hasattr (module , IMPL_ATTR ):
312+ hooks .add (register_query_hook (module , calibrate_query_hook ))
313+ if hasattr (module , KV_CACHE_ATTR ):
314+ hooks .add (register_key_hook (module , calibrate_key_hook ))
315+ hooks .add (register_value_hook (module , calibrate_value_hook ))
316+
317+ # output activations
318+ if output :
319+ hooks .add (self .register_hook (module , calibrate_output_hook , "forward" ))
317320
318321 return hooks
0 commit comments