diff --git a/experimental/llama3_attention.py b/experimental/llama3_attention.py
new file mode 100644
index 0000000000..daa953a1df
--- /dev/null
+++ b/experimental/llama3_attention.py
@@ -0,0 +1,87 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.utils import dispatch_for_generation
+from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs
+
+# Select model and load it.
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+recipe = QuantizationModifier(
+    config_groups={
+        "attention": QuantizationScheme(
+            targets=["LlamaAttention"],
+            input_activations=QuantizationArgs(
+                num_bits=8, type="float", strategy="attn_head"
+            ),
+        )
+    }
+)
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to(model.device) for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-attention-fp8-head"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
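Reviewer note on the example above: it attaches a per-head FP8 scheme to the attention inputs only. A sketch of how the same modifier could additionally attach key/value observers by supplying a kv_cache_scheme is shown below. This is illustrative, not exercised by this PR's example; the field names mirror the existing QuantizationModifier API, and accepting `attn_head` for kv_cache_scheme is an assumption.

```python
# Hypothetical variant of the recipe above (illustrative only).
recipe_with_kv = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="attn_head"
            ),
        )
    },
    # Assumption: kv_cache_scheme accepts the same per-head strategy.
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", strategy="attn_head"),
)
```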
diff --git a/src/llmcompressor/modifiers/quantization/__init__.py b/src/llmcompressor/modifiers/quantization/__init__.py
index f6ad149fbb..1ca6912221 100644
--- a/src/llmcompressor/modifiers/quantization/__init__.py
+++ b/src/llmcompressor/modifiers/quantization/__init__.py
@@ -1,5 +1,4 @@
 # ruff: noqa
-from .cache import *
 from .gptq import *
 from .quantization import *
diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
deleted file mode 100644
index 53eca8d075..0000000000
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ /dev/null
@@ -1,218 +0,0 @@
-"""
-Quantized key-value cache implementation for efficient inference.
-
-Provides quantized KV cache classes extending HuggingFace's
-DynamicCache with quantization support. Enables memory-efficient attention
-mechanisms by quantizing cached key and value tensors during model
-inference with configurable quantization strategies.
-"""
-
-from typing import Any, Dict, List, Optional, Tuple
-
-from compressed_tensors.quantization import KVCacheScaleType, QuantizationArgs
-from torch import Tensor
-from transformers import DynamicCache
-
-from llmcompressor.observers import Observer
-
-
-class QuantizedKVParameterCache(DynamicCache):
-    """
-    Quantized KV cache used in the forward call based on HF's dynamic cache.
-    Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
-    Singleton, so that the same cache gets reused in all forward call of self_attn.
-    Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
-    gets called appropriately.
-    The size of tensor is
-    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
-
-
-    Triggered by adding kv_cache_scheme in the recipe.
-
-    Example:
-
-    ```python3
-    recipe = '''
-    quant_stage:
-        quant_modifiers:
-            QuantizationModifier:
-                kv_cache_scheme:
-                    num_bits: 8
-                    type: float
-                    strategy: tensor
-                    dynamic: false
-                    symmetric: true
-    '''
-
-    """
-
-    _instance = None
-    _initialized = False
-
-    def __new__(cls, *args, **kwargs):
-        """Singleton"""
-        if cls._instance is None:
-            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
-        return cls._instance
-
-    def __init__(self, quantization_args: QuantizationArgs):
-        if not self._initialized:
-            super().__init__()
-
-            self.quantization_args = quantization_args
-
-            self.k_observers: List[Observer] = []
-            self.v_observers: List[Observer] = []
-
-            # each index corresponds to layer_idx of the attention layer
-            self.k_scales: List[Tensor] = []
-            self.v_scales: List[Tensor] = []
-
-            self.k_zps: List[Tensor] = []
-            self.v_zps: List[Tensor] = []
-
-            self._initialized = True
-
-    def update(
-        self,
-        key_states: Tensor,
-        value_states: Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Get the k_scale and v_scale and output the
-        fakequant-ed key_states and value_states
-        """
-
-        if len(self.k_observers) <= layer_idx:
-            k_observer = Observer.load_from_registry(
-                self.quantization_args.observer,
-                base_name="k",
-                args=self.quantization_args,
-            )
-            v_observer = Observer.load_from_registry(
-                self.quantization_args.observer,
-                base_name="v",
-                args=self.quantization_args,
-            )
-
-            # NOTE: User may ignore some layers in configuration,
-            # meaning len(self.k_observers) <= layer_idx-1
-            # Must account for that case by padding list so that
-            # index of lists corresponds to layer_idx
-            _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
-            _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
-
-        q_key_states = self._quantize(
-            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
-        )
-        q_value_states = self._quantize(
-            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
-        )
-
-        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
-        qdq_value_states = self._dequantize(
-            q_value_states, KVCacheScaleType.VALUE, layer_idx
-        )
-
-        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
-
-        return keys_to_return, values_to_return
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """
-        Returns the sequence length of the cached states.
-        A layer index can be optionally passed.
-        """
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        # since we cannot get the seq_length of each layer directly and
-        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
-        # this is a hack to get the actual seq_length for the given layer_idx
-        # this part of code otherwise fails when used to
-        # verify attn_weight shape in some models
-        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
-
-    def reset_states(self):
-        """reset the kv states (used in calibration)"""
-        self.key_cache: List[Tensor] = []
-        self.value_cache: List[Tensor] = []
-        # Used in `generate` to keep tally of how many tokens the cache has seen
-        self._seen_tokens = 0
-        self._quantized_key_cache: List[Tensor] = []
-        self._quantized_value_cache: List[Tensor] = []
-
-    def reset(self):
-        """
-        Reset the instantiation, create new instance on init
-        """
-        QuantizedKVParameterCache._instance = None
-        QuantizedKVParameterCache._initialized = False
-
-    def _quantize(self, tensor, kv_type, layer_idx):
-        """Quantizes a key/value using a defined quantization method."""
-        from compressed_tensors.quantization.lifecycle.forward import quantize
-
-        if kv_type == KVCacheScaleType.KEY:  # key type
-            observer = self.k_observers[layer_idx]
-            scales = self.k_scales
-            zps = self.k_zps
-        else:
-            assert kv_type == KVCacheScaleType.VALUE
-            observer = self.v_observers[layer_idx]
-            scales = self.v_scales
-            zps = self.v_zps
-
-        scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
-
-        q_tensor = quantize(
-            x=tensor,
-            scale=scale,
-            zero_point=zp,
-            args=self.quantization_args,
-        )
-        return q_tensor
-
-    def _dequantize(self, qtensor, kv_type, layer_idx):
-        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
-        from compressed_tensors.quantization.lifecycle.forward import dequantize
-
-        if kv_type == KVCacheScaleType.KEY:
-            scale = self.k_scales[layer_idx]
-            zp = self.k_zps[layer_idx]
-        else:
-            assert kv_type == KVCacheScaleType.VALUE
-            scale = self.v_scales[layer_idx]
-            zp = self.v_zps[layer_idx]
-
-        qdq_tensor = dequantize(
-            x_q=qtensor,
-            scale=scale,
-            zero_point=zp,
-            args=self.quantization_args,
-        )
-        return qdq_tensor
-
-
-# NOTE: Using _ suffix to denote l is modified in place
-def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
-    """
-    Append value val to list lst at index idx, right padding if necessary
-    Needed because user may ignore some layers in configuration, meaning
-    len(lst) <= idx-1
-
-    >>> _pad_and_append_at_idx_([0,1,2], 5, 5)
-    [0, 1, 2, None, None, 5]
-    >>> _pad_and_append_at_idx_([0,1,2], 3, 8)
-    [0, 1, 2, 8]
-    >>> _pad_and_append_at_idx_([0,1,2], 1, 5)
-    [0, 5, 2]
-    """
-    num_to_pad = idx - len(lst) + 1
-    if num_to_pad > 0:
-        lst += [None] * num_to_pad
-    lst[idx] = val
-    return lst
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 5540532c97..da974b25b8 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,22 +1,17 @@
-import inspect
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional
 
 import torch
 from compressed_tensors.quantization import (
     DynamicType,
-    KVCacheScaleType,
     QuantizationArgs,
-    QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
-from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from torch.nn import Module
 
-from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
@@ -25,13 +20,13 @@
     "update_weight_zp_scale",
     "calibrate_input_hook",
     "calibrate_output_hook",
-    "calibrate_kv_cache_input_hook",
-    "calibrate_kv_cache_output_hook",
-    "initialize_quantized_kv_cache",
    "freeze_module_quantization",
    "apply_calibration_status",
    "reset_quantization_status",
    "update_weight_global_scale",
+    "calibrate_query_hook",
+    "calibrate_key_hook",
+    "calibrate_value_hook",
 ]
 
 
@@ -93,7 +88,8 @@ def call_observer(
     if should_calculate_qparams:
         scale, zero_point = observer(value)
         update_offload_parameter(module, f"{base_name}_scale", scale)
-        update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
+        if hasattr(module, f"{base_name}_zero_point"):
+            update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
 
 
 def update_weight_global_scale(module: Module):
@@ -151,8 +147,9 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if value.numel() == 0:
         return
 
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    quantization_args = getattr(quantization_scheme, f"{base_name}_activations", None)
+    field_name = "input" if base_name != "output" else "output"  # input,q,k,v,output
+    args_attr = f"quantization_scheme.{field_name}_activations"
+    quantization_args = getattr_chain(module, args_attr, None)
 
     calculate_qparams = True
     calculate_gparam = False
@@ -202,60 +199,16 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
     return output
 
 
-def calibrate_kv_cache_input_hook(
-    module: Module, args: Any, kwargs: Dict[str, Any]
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
-    """
-    Hook to update inputs to attention layers when running
-    kv_cache quantization. Will update the passed in
-    kv_cache to singleton QuantizedKVParameterCache.
-    """
-    kv_cache = getattr(module, "kv_cache")
-    if not hasattr(module, "_past_kv_name"):
-        # Determine which past KV parameter name to use once and cache it
-        # TODO: Find a better place to cache this
-        module._past_kv_name = (
-            "past_key_value"  # transformers#39956
-            if "past_key_value" in inspect.signature(module.forward).parameters
-            else "past_key_values"
-        )
-
-    kwargs[module._past_kv_name] = kv_cache
-    kwargs["use_cache"] = False
-    return args, kwargs
+def calibrate_query_hook(module: Module, query_states: torch.Tensor):
+    calibrate_activations(module, query_states, base_name="q")
 
 
-def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
-    """
-    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
-    """
-    kv_cache = getattr(module, "kv_cache")
-    k_scale = kv_cache.k_scales[module.layer_idx]
-    v_scale = kv_cache.v_scales[module.layer_idx]
-    update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
-    update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)
+def calibrate_key_hook(module: Module, key_states: torch.Tensor):
+    calibrate_activations(module, key_states, base_name="k")
 
 
-def initialize_quantized_kv_cache(module: Module):
-    """
-    Initialize a quantized kv_cache on a module (analogous to initializing an observer)
-    When a config specifying kv_cache quantization is applied to a model, the kv_cache
-    args are redefined as the output_activations targeting attention modules.
-
-    This function should be called on attention modules with output_activations
-    """
-    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
-    existing_kv_cache = getattr(module, "kv_cache", None)
-
-    if (
-        scheme is None
-        or not is_kv_cache_quant_scheme(scheme)
-        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
-    ):
-        return
-
-    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
-    setattr(module, "kv_cache", quantized_kv_cache)
+def calibrate_value_hook(module: Module, value_states: torch.Tensor):
+    calibrate_activations(module, value_states, base_name="v")
 
 
 def apply_calibration_status(module: Module):
@@ -284,16 +237,11 @@ def freeze_module_quantization(module: Module):
         return
 
     # remove observers
-    for name in ("input", "weight", "output"):
+    for name in ("input", "weight", "output", "q", "k", "v"):
        obs_name = f"{name}_observer"
        if hasattr(module, obs_name):
            delattr(module, obs_name)
 
-    # remove quantized kv_cache
-    kv_cache = getattr(module, "kv_cache", None)
-    if isinstance(kv_cache, QuantizedKVParameterCache):
-        delattr(module, "kv_cache")
-
     module.quantization_status = QuantizationStatus.FROZEN
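Note on the new hooks above: all three funnel into `calibrate_activations` with base names "q"/"k"/"v", so their quantization args resolve from the scheme's `input_activations` (only the "output" base name maps to `output_activations`), and the resulting scales land on the module as `q_scale`/`k_scale`/`v_scale` via `call_observer`. A minimal sketch, assuming `module` is an attention submodule that already carries a `quantization_scheme` and the (batch_size, num_heads, seq_len, head_dim) layout used elsewhere in this PR:

```python
# Illustrative only; `module` is a hypothetical calibrated attention submodule.
import torch

key_states = torch.randn(1, 8, 128, 64)  # (batch_size, num_heads, seq_len, head_dim)
calibrate_key_hook(module, key_states)
# internally: calibrate_activations(module, key_states, base_name="k")
#   -> getattr_chain(module, "quantization_scheme.input_activations", None)
#   -> updates the offloaded "k_scale" (and "k_zero_point", if present) parameters
```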
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index 719431d889..265b589d77 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -1,6 +1,10 @@
 from typing import Any, Dict, List, Optional, Set, Union
 
 import torch
+from compressed_tensors.modeling import (
+    IMPL_ATTR,
+    KV_CACHE_ATTR,
+)
 from compressed_tensors.quantization import (
     DynamicType,
     QuantizationArgs,
@@ -21,12 +25,12 @@
 from llmcompressor.modifiers.quantization.calibration import (
     apply_calibration_status,
     calibrate_input_hook,
-    calibrate_kv_cache_input_hook,
-    calibrate_kv_cache_output_hook,
+    calibrate_key_hook,
     calibrate_output_hook,
+    calibrate_query_hook,
+    calibrate_value_hook,
     freeze_module_quantization,
     initialize_observer,
-    initialize_quantized_kv_cache,
     reset_quantization_status,
 )
 from llmcompressor.modifiers.utils.hooks import HooksMixin
@@ -145,6 +149,10 @@ def resolved_targets(self) -> Set[str]:
         for config_group in self.resolved_config.config_groups.values():
             for target in config_group.targets:
                 targets.add(target)
+
+        if self.resolved_config.kv_cache_scheme is not None:
+            targets.add("re:.*self_attn$")
+
         return targets
 
     def initialize_quantization(self, model: torch.nn.Module):
@@ -170,9 +178,9 @@ def start_calibration(self, model: torch.nn.Module):
 
         :param model: model to prepare for calibration
         """
-        self._calibration_hooks = self._initialize_hooks(model)
         for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)
+            self._calibration_hooks |= self._initialize_hooks(module)
             apply_calibration_status(module)
 
         model.apply(enable_quantization)  # quantize at the same time as calibrate
@@ -260,60 +268,51 @@ def _initialize_observers(self, module: torch.nn.Module):
 
         # input activations
         if input:
-            initialize_observer(module, base_name="input")
+            if not is_attention:
+                initialize_observer(module, base_name="input")
+            else:
+                if hasattr(module, IMPL_ATTR):
+                    initialize_observer(module, base_name="q")
+                if hasattr(module, KV_CACHE_ATTR):
+                    initialize_observer(module, base_name="k")
+                    initialize_observer(module, base_name="v")
 
         # weight observers (used by `update_weight_zp_scale` or child modifier)
         if weight:
             initialize_observer(module, base_name="weight")
 
-        # kv_cache activations. Within `apply_quantization_config`, the config is
-        # modified to use attention output quantization if a kv_cache_scheme exists
-        if is_attention and output:
-            initialize_quantized_kv_cache(module)
-
         # output activations
-        elif output:
+        if output:
             initialize_observer(module, base_name="output")
 
-    def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
+    def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
-            if not hasattr(module, "quantization_scheme"):
-                continue
+        if not hasattr(module, "quantization_scheme"):
+            return hooks
 
-            scheme: QuantizationScheme = module.quantization_scheme
-            input = scheme.input_activations and scheme.input_activations.dynamic in (
-                False,
-                DynamicType.LOCAL,
-            )
-            output = scheme.output_activations and not scheme.output_activations.dynamic
-            is_attention = is_attention_module(module)
+        scheme: QuantizationScheme = module.quantization_scheme
+        input = scheme.input_activations and scheme.input_activations.dynamic in (
+            False,
+            DynamicType.LOCAL,
+        )
+        output = scheme.output_activations and not scheme.output_activations.dynamic
+        is_attention = is_attention_module(module)
 
-            # input activations
-            if input:
+        # input activations
+        if input:
+            if not is_attention:
                 hooks.add(
                     self.register_hook(module, calibrate_input_hook, "forward_pre")
                 )
+            else:
+                if hasattr(module, IMPL_ATTR):
+                    hooks.add(self.register_hook(module, calibrate_query_hook, "query"))
+                if hasattr(module, KV_CACHE_ATTR):
+                    hooks.add(self.register_hook(module, calibrate_key_hook, "key"))
+                    hooks.add(self.register_hook(module, calibrate_value_hook, "value"))
 
-            # kv_cache activations. Within `apply_quantization_config`, the config is
-            # modified to use attention output quantization if a kv_cache_scheme exists
-            if is_attention and output:
-                hooks.add(
-                    self.register_hook(
-                        module,
-                        calibrate_kv_cache_input_hook,
-                        "forward_pre",
-                        with_kwargs=True,
-                    )
-                )
-                hooks.add(
-                    self.register_hook(
-                        module, calibrate_kv_cache_output_hook, "forward"
-                    )
-                )
-
-            # output activations
-            elif output:
-                hooks.add(self.register_hook(module, calibrate_output_hook, "forward"))
+        # output activations
+        if output:
+            hooks.add(self.register_hook(module, calibrate_output_hook, "forward"))
 
         return hooks
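With the mixin changes above, attention submodules are matched directly (`re:.*self_attn$` is added as a target whenever a kv_cache_scheme is set) and receive q/k/v observers instead of a wrapped KV cache. A rough post-calibration check, assuming scale names follow `call_observer`'s `f"{base_name}_scale"` convention (`k_scale`/`v_scale` are asserted by the updated tests further down; `q_scale` only appears when the query path is quantized):

```python
# Illustrative check after oneshot/calibration; not part of this PR's test suite.
for name, submodule in model.named_modules():
    if name.endswith("self_attn"):
        print(
            name,
            hasattr(submodule, "q_scale"),
            hasattr(submodule, "k_scale"),
            hasattr(submodule, "v_scale"),
        )
```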
diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py
index 98d5240e21..09efe9e470 100644
--- a/src/llmcompressor/modifiers/utils/hooks.py
+++ b/src/llmcompressor/modifiers/utils/hooks.py
@@ -1,8 +1,13 @@
 import contextlib
-from functools import wraps
+from functools import partial, wraps
 from typing import Any, Callable, ClassVar, Optional, Set, Union
 
 import torch
+from compressed_tensors.modeling import (
+    register_key_hook,
+    register_query_hook,
+    register_value_hook,
+)
 from loguru import logger
 from pydantic import BaseModel
 from torch.utils.hooks import RemovableHandle
@@ -92,7 +97,7 @@ def wrapped_hook(*args, **kwargs):
                 return hook(*args, **kwargs)
 
-        register_function = getattr(target, f"register_{hook_type}_hook")
+        register_function = self._get_register_function(target, hook_type)
         handle = register_function(wrapped_hook, **kwargs)
         self._hooks.add(handle)
         logger.debug(f"{self} added {handle}")
@@ -113,3 +118,15 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
             hook.remove()
 
         self._hooks -= handles
+
+    def _get_register_function(
+        self, target: torch.nn.Module, hook_type: str
+    ) -> Callable:
+        if hook_type == "query":
+            return partial(register_query_hook, target)
+        elif hook_type == "key":
+            return partial(register_key_hook, target)
+        elif hook_type == "value":
+            return partial(register_value_hook, target)
+        else:
+            return getattr(target, f"register_{hook_type}_hook")
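The new "query"/"key"/"value" hook types dispatch to the compressed-tensors attention hooks rather than to torch's `register_*_hook` methods. A small usage sketch based on the signatures as used in this PR (`register_*_hook(module, fn)` returns a removable handle, and `fn` receives `(module, states)`); `attn_module` is a stand-in for a model's self_attn submodule, not a name defined here:

```python
from compressed_tensors.modeling import register_key_hook

def log_key_shape(module, key_states):
    # expected layout: (batch_size, num_heads, seq_len, head_dim)
    print("key states:", tuple(key_states.shape))

handle = register_key_hook(attn_module, log_key_shape)  # attn_module: hypothetical
# ... run calibration forward passes ...
handle.remove()
```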
diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py
index 4560da1b85..71fa75d89d 100644
--- a/src/llmcompressor/observers/helpers.py
+++ b/src/llmcompressor/observers/helpers.py
@@ -52,6 +52,8 @@ def flatten_for_calibration(
 def _flatten_weight(
     value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None
 ):
+    # value.shape = (num_rows, num_cols)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (1, 1, num_weight_elems)
         return value.reshape((1, 1, -1))
@@ -83,10 +85,15 @@ def _flatten_weight(
             .unsqueeze(0)
         )
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("Attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
 def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, seq_len, hidden_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (batch_size * seq_len, 1, hidden_dim)
         return value.reshape((-1, 1, value.size(-1)))
@@ -107,14 +114,18 @@ def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("Attention head quantization cannot be applied to activations")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
 def _flatten_attention(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, num_heads, seq_len, head_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
-        # (batch_size, seq_len, num_heads, head_dim)
         # (batch_size * seq_len, 1, num_heads * head_dim)
-        return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
+        return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
 
     if args.strategy == QuantizationStrategy.TOKEN:
         raise ValueError("Token quantization cannot be applied to attention")
@@ -128,4 +139,8 @@ def _flatten_attention(value: torch.Tensor, args: QuantizationArgs):
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, 1, head_dim)
+        return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"
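Quick sanity check of the two attention flattenings above (pure torch, mirroring the shape comments in `_flatten_attention`): the tensor strategy yields one observation per token with heads folded into the feature dimension, while the new attn_head strategy keeps one observation per (token, head) so qparams resolve per head.

```python
import torch

batch, heads, seq, dim = 2, 8, 5, 64
value = torch.randn(batch, heads, seq, dim)  # (batch_size, num_heads, seq_len, head_dim)

# TENSOR path
tensor_flat = value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
assert tensor_flat.shape == (batch * seq, 1, heads * dim)

# ATTN_HEAD path
head_flat = value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
assert head_flat.shape == (batch * seq, heads, 1, 1, dim)
```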
diff --git a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py
deleted file mode 100644
index 70f0e61259..0000000000
--- a/tests/llmcompressor/modifiers/calibration/test_cache.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-
-from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
-from llmcompressor.observers import Observer
-
-
-def test_is_quantized_cache_singleton():
-    """
-    Check if quantized_cache is a singleton, used for
-    passing in QuantizedKVParameterCache to the forward call of
-    the model's self_attn
-    """
-
-    args = QuantizationArgs()
-    cache = QuantizedKVParameterCache(args)
-    observer = args.observer
-    observer = Observer.load_from_registry(observer, base_name="k", args=args)
-
-    tensor = torch.tensor([1, 2, 3])
-    cache.k_scales.append(tensor)
-    cache.k_observers.append(observer)
-
-    same_cache = QuantizedKVParameterCache(args)
-
-    assert len(cache.k_scales) == len(same_cache.k_scales)
-    assert torch.equal(cache.k_scales[0], same_cache.k_scales[0])
-
-    assert cache.k_observers == same_cache.k_observers
-    assert hex(id(cache.k_observers[0])) == hex(id(same_cache.k_observers[0]))
-
-    cache.reset()
-
-
-def test_update():
-    num_bits = 8
-    args = QuantizationArgs(num_bits=num_bits, symmetric=True)
-    cache = QuantizedKVParameterCache(args)
-
-    max_key_states_val = 1.0
-    max_value_states_val = 2.0
-    key_states = torch.cat(
-        (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0
-    )
-    value_states = torch.cat(
-        (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0
-    )
-    layer_idx = 0
-
-    cache.update(key_states, value_states, layer_idx)
-    denom = (2 ** (num_bits) - 1) / 2
-    expected_k_scale = torch.tensor([max_key_states_val / denom])
-    expected_v_scale = torch.tensor([max_value_states_val / denom])
-
-    assert cache.k_scales[0] == expected_k_scale
-    assert cache.v_scales[0] == expected_v_scale
-
-    # new attn layer
-    layer_idx = 1
-    cache.update(key_states, value_states, layer_idx)
-
-    assert len(cache.k_scales) == 2
-    assert len(cache.v_scales) == 2
-
-    assert len(cache.k_observers) == 2
-    assert len(cache.v_observers) == 2
-
-    cache.reset()
-
-
-def test_cache_reset():
-    num_bits = 8
-    args = QuantizationArgs(num_bits=num_bits, symmetric=True)
-    cache = QuantizedKVParameterCache(args)
-
-    max_key_states_val = 1.0
-    max_value_states_val = 2.0
-    key_states = torch.cat(
-        (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0
-    )
-    value_states = torch.cat(
-        (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0
-    )
-    layer_idx = 0
-
-    cache.update(key_states, value_states, layer_idx)
-    assert len(cache.k_scales) == 1
-    assert len(cache.v_scales) == 1
-
-    assert len(cache.k_observers) == 1
-    assert len(cache.v_observers) == 1
-
-    cache.reset()
-
-    # new instance, different memory addr
-    different_cache = QuantizedKVParameterCache(args)
-
-    assert len(different_cache.k_scales) == 0
-    assert len(different_cache.v_scales) == 0
-
-    assert len(different_cache.k_observers) == 0
-    assert len(different_cache.v_observers) == 0
-
-    assert hex(id(cache)) != hex(id(different_cache))
diff --git a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py
deleted file mode 100644
index b22e7ec401..0000000000
--- a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pytest
-import torch
-from compressed_tensors.quantization import (
-    QuantizationConfig,
-    QuantizationStatus,
-    apply_quantization_config,
-    is_attention_module,
-)
-from transformers import AutoModelForCausalLM
-
-from llmcompressor.modifiers.quantization.calibration import (
-    calibrate_kv_cache_input_hook,
-    calibrate_kv_cache_output_hook,
-    freeze_module_quantization,
-    initialize_quantized_kv_cache,
-)
-
-config = {
-    "quant_method": "compressed-tensors",
-    "format": "fakequant",
-    "kv_cache_scheme": {
-        "num_bits": 8,
-        "type": "int",
-        "symmetric": True,
-        "strategy": "tensor",
-    },
-    "config_groups": {
-        "group_1": {
-            "weights": {
-                "num_bits": 4,
-                "type": "int",
-                "symmetric": True,
-                "strategy": "tensor",
-            },
-            "targets": ["Linear"],
-        },
-    },
-}
-
-
-def _prep_for_calibration(module: torch.nn.Module):
-    if is_attention_module(module):
-        module.register_forward_pre_hook(
-            calibrate_kv_cache_input_hook, with_kwargs=True
-        )
-        module.register_forward_hook(calibrate_kv_cache_output_hook)
-    module.quantization_status = QuantizationStatus.CALIBRATION
-
-
-@pytest.mark.parametrize("config", [config])
-def test_kv_cache_quantization(config):
-    sample = {
-        name: torch.ones((1, 32)).long()
-        for name in ["input_ids", "attention_mask", "labels"]
-    }
-    model = AutoModelForCausalLM.from_pretrained(
-        "HuggingFaceM4/tiny-random-LlamaForCausalLM",
-        torch_dtype="auto",
-    )
-    model.eval()
-
-    config = QuantizationConfig(**config)
-    config.quantization_status = QuantizationStatus.CALIBRATION
-    apply_quantization_config(model, config)
-    model.apply(initialize_quantized_kv_cache)
-    model.apply(_prep_for_calibration)
-
-    with torch.no_grad():
-        _ = model(**sample)
-
-    model.apply(freeze_module_quantization)
-
-    reloaded_config = QuantizationConfig.from_pretrained(model)
-
-    assert (
-        config.kv_cache_scheme.model_dump().keys()
-        == reloaded_config.kv_cache_scheme.model_dump().keys()
-    )
-    assert list(config.kv_cache_scheme.model_dump().values()) == list(
-        reloaded_config.kv_cache_scheme.model_dump().values()
-    )
diff --git a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py
index 80886e1dd8..58f139f342 100644
--- a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py
+++ b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py
@@ -3,7 +3,7 @@
 
 import pytest
 from accelerate import init_empty_weights
-from compressed_tensors.quantization import KVCacheScaleType, is_attention_module
+from compressed_tensors.quantization import is_attention_module
 from datasets import load_dataset
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils.quantization_config import CompressedTensorsConfig
@@ -14,7 +14,7 @@
 NUM_CALIBRATION_SAMPLES = 16
 MAX_SEQUENCE_LENGTH = 512
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
-DATASET_SPLIT = "train_sft"
+DATASET_SPLIT = f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"
 
 MODEL_IDS = [
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -49,9 +49,11 @@ def _oneshot_fixture(tmp_path: Path):
             symmetric=symmetric,
         )
         oneshot_args = dict(
-            dataset="open_platypus",
             recipe=recipe,
-            num_calibration_samples=16,
+            dataset="open_platypus",
+            splits={"calibration": f"train[:{NUM_CALIBRATION_SAMPLES}]"},
+            num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+            max_seq_length=MAX_SEQUENCE_LENGTH,
         )
         for model_id in MODEL_IDS:
             oneshot_args["output_dir"] = os.path.join(tmp_path, model_id)
@@ -161,8 +163,8 @@ def test_kv_cache_model_state_dict_attr(oneshot_fixture, tmp_path):
     for name, submodule in model.named_modules():
         if is_attention_module(submodule):
             counts += 1
-            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-            assert hasattr(submodule, KVCacheScaleType.KEY.value)
+            assert hasattr(submodule, "v_scale")
+            assert hasattr(submodule, "k_scale")
     assert counts > 0
 
 
@@ -200,8 +202,8 @@ def test_kv_cache_gptq_config_format(kv_cache_fixture, tmp_path):
     for name, submodule in model.named_modules():
         if is_attention_module(submodule):
             counts += 1
-            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-            assert hasattr(submodule, KVCacheScaleType.KEY.value)
+            assert hasattr(submodule, "v_scale")
+            assert hasattr(submodule, "k_scale")
     assert counts > 0
 
 
@@ -240,7 +242,7 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path):
     for name, submodule in model.named_modules():
         if is_attention_module(submodule):
             counts += 1
-            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-            assert hasattr(submodule, KVCacheScaleType.KEY.value)
+            assert hasattr(submodule, "v_scale")
+            assert hasattr(submodule, "k_scale")
     assert counts > 0
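For completeness, a reload sketch for checkpoints produced with save_compressed=True, mirroring the reload pattern these tests rely on (CompressedTensorsConfig is already imported by this test module); treat the exact kwargs as an assumption rather than a prescribed API:

```python
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

reloaded = AutoModelForCausalLM.from_pretrained(
    output_dir,  # placeholder: the save directory produced by the fixture or example
    quantization_config=CompressedTensorsConfig(run_compressed=False),
    torch_dtype="auto",
)
```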