
Commit b98e61a

target attention for hooks, allow non-forced zps
Signed-off-by: Kyle Sayers <[email protected]>
1 parent e014b64 commit b98e61a

2 files changed: +34 -30 lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ def call_observer(
     if should_calculate_qparams:
         scale, zero_point = observer(value)
         update_offload_parameter(module, f"{base_name}_scale", scale)
-        update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
+        if hasattr(module, f"{base_name}_zero_point"):
+            update_offload_parameter(module, f"{base_name}_zero_point", zero_point)


 def update_weight_global_scale(module: Module):
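The new hasattr guard covers schemes that register only a scale and no zero-point parameter (the "non-forced" zero points mentioned in the commit message). A minimal, self-contained sketch of the same guard pattern, using hypothetical buffer and helper names rather than llm-compressor's own:

import torch

class FakeQuantLinear(torch.nn.Module):
    """Toy module: a symmetric scheme registers only a scale, never a zero point."""

    def __init__(self, symmetric: bool = True):
        super().__init__()
        self.register_buffer("weight_scale", torch.ones(1))
        if not symmetric:
            self.register_buffer("weight_zero_point", torch.zeros(1, dtype=torch.int32))

def write_qparams(module, base_name, scale, zero_point):
    # mirror of the patched behavior: only write the zero point if the module has one
    getattr(module, f"{base_name}_scale").copy_(scale)
    if hasattr(module, f"{base_name}_zero_point"):
        getattr(module, f"{base_name}_zero_point").copy_(zero_point)

# the symmetric module has no weight_zero_point buffer; the guard keeps this call from failing
write_qparams(FakeQuantLinear(symmetric=True), "weight", torch.tensor([0.5]), torch.tensor([0]))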

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 32 additions & 29 deletions
@@ -152,6 +152,10 @@ def resolved_targets(self) -> Set[str]:
         for config_group in self.resolved_config.config_groups.values():
             for target in config_group.targets:
                 targets.add(target)
+
+        if self.resolved_config.kv_cache_scheme is not None:
+            targets.add("re:.*self_attn$")
+
         return targets

     def initialize_quantization(self, model: torch.nn.Module):
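With a kv_cache_scheme configured, attention submodules are now pulled into the resolved targets through the regex-style name pattern "re:.*self_attn$". A quick illustration, using plain re rather than the library's actual matcher, of which module names such a pattern selects:

import re

# hypothetical module names from a typical decoder-only transformer
names = [
    "model.layers.0.self_attn",
    "model.layers.0.self_attn.q_proj",
    "model.layers.1.mlp.down_proj",
    "model.layers.1.self_attn",
]

pattern = re.compile(r".*self_attn$")  # the part after the "re:" prefix
print([name for name in names if pattern.match(name)])
# ['model.layers.0.self_attn', 'model.layers.1.self_attn']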
@@ -177,9 +181,9 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
-        self._calibration_hooks = self._initialize_hooks(model)
         for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)
+            self._calibration_hooks |= self._initialize_hooks(module)
             apply_calibration_status(module)

         model.apply(enable_quantization)  # quantize at the same time as calibrate
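Calibration hooks are now built per matched module inside the loop and unioned into the running set, rather than in a single model-wide pass up front, which keeps hook registration alongside observer setup for each module. A rough sketch of that accumulation pattern with stand-in names (not the mixin's real helpers):

import torch

def initialize_hooks_for(module: torch.nn.Module) -> set:
    # stand-in for self._initialize_hooks(module): return the handles it registered
    return {module.register_forward_pre_hook(lambda mod, args: None)}

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 4))
calibration_hooks = set()
for _, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):  # stand-in for target matching
        calibration_hooks |= initialize_hooks_for(module)

print(len(calibration_hooks))  # 2: one handle per targeted module, collected as we go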
@@ -284,35 +288,34 @@ def _initialize_observers(self, module: torch.nn.Module):
         if output:
             initialize_observer(module, base_name="output")

-    def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
+    def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
-            if not hasattr(module, "quantization_scheme"):
-                continue
+        if not hasattr(module, "quantization_scheme"):
+            return hooks

-            scheme: QuantizationScheme = module.quantization_scheme
-            input = scheme.input_activations and scheme.input_activations.dynamic in (
-                False,
-                DynamicType.LOCAL,
-            )
-            output = scheme.output_activations and not scheme.output_activations.dynamic
-            is_attention = is_attention_module(module)
-
-            # input activations
-            if input:
-                if not is_attention:
-                    hooks.add(
-                        self.register_hook(module, calibrate_input_hook, "forward_pre")
-                    )
-                else:
-                    if hasattr(module, IMPL_ATTR):
-                        hooks.add(register_query_hook(module, calibrate_query_hook))
-                    if hasattr(module, KV_CACHE_ATTR):
-                        hooks.add(register_key_hook(module, calibrate_key_hook))
-                        hooks.add(register_value_hook(module, calibrate_value_hook))
-
-            # output activations
-            if output:
-                hooks.add(self.register_hook(module, calibrate_output_hook, "forward"))
+        scheme: QuantizationScheme = module.quantization_scheme
+        input = scheme.input_activations and scheme.input_activations.dynamic in (
+            False,
+            DynamicType.LOCAL,
+        )
+        output = scheme.output_activations and not scheme.output_activations.dynamic
+        is_attention = is_attention_module(module)
+
+        # input activations
+        if input:
+            if not is_attention:
+                hooks.add(
+                    self.register_hook(module, calibrate_input_hook, "forward_pre")
+                )
+            else:
+                if hasattr(module, IMPL_ATTR):
+                    hooks.add(register_query_hook(module, calibrate_query_hook))
+                if hasattr(module, KV_CACHE_ATTR):
+                    hooks.add(register_key_hook(module, calibrate_key_hook))
+                    hooks.add(register_value_hook(module, calibrate_value_hook))
+
+        # output activations
+        if output:
+            hooks.add(self.register_hook(module, calibrate_output_hook, "forward"))

         return hooks
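The method now takes a single module instead of walking the whole model, and it still branches on module type: ordinary modules calibrate inputs and outputs with standard forward_pre/forward hooks, while attention modules get query/key/value hooks instead. A self-contained sketch of that branch structure using plain torch hooks and simplified stand-ins that only share names with the hooks above (the real query/key/value registration helpers are llm-compressor-specific):

import torch

def calibrate_input_hook(module, args):
    # observe the first positional input during calibration
    module.observed_input = args[0].detach()

def calibrate_output_hook(module, args, output):
    module.observed_output = output.detach()

def initialize_hooks(module: torch.nn.Module, is_attention: bool) -> set:
    hooks = set()
    if not is_attention:
        hooks.add(module.register_forward_pre_hook(calibrate_input_hook))
        hooks.add(module.register_forward_hook(calibrate_output_hook))
    else:
        # attention modules would instead get query/key/value hooks attached to the
        # attention implementation and kv cache, which plain torch hooks don't expose
        pass
    return hooks

linear = torch.nn.Linear(4, 4)
handles = initialize_hooks(linear, is_attention=False)
linear(torch.randn(1, 4))
print(linear.observed_input.shape, linear.observed_output.shape)
for handle in handles:
    handle.remove()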
