diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 96b400d63..918fb3d35 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -71,6 +71,7 @@ def initialize_observer(
         observer = Observer.load_from_registry(
             quantization_args.observer,
             quantization_args=quantization_args,
+            base_name=base_name,
             averaging_constant=observer_kwargs.get(
                 "averaging_constant", DEFAULT_AVERAGING_CONSTANT
             ),
diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py
index 419155f07..07c7be726 100644
--- a/src/llmcompressor/observers/mse.py
+++ b/src/llmcompressor/observers/mse.py
@@ -2,7 +2,7 @@
 
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.utils import calculate_qparams
+from compressed_tensors.quantization.utils import calculate_qparams, is_fp4
 from torch import FloatTensor, IntTensor, Tensor
 
 from llmcompressor.observers.base import Observer
@@ -13,8 +13,13 @@
 @Observer.register("mse")
 class MovingAverageMSEObserver(Observer):
     """
-    Implements a dynamic quantization observer that sets the scale and
-    zero point based on a moving average of the mse-clipped min and max observed values
+    Implements a dynamic quantization observer that sets the scale and zero
+    point based on a moving average of observed values.
+
+    Behavior:
+    - Weights: global and local scales use MSE-optimized min/max.
+    - Activations: global scale uses MSE-optimized min/max; local scales
+      use plain min-max.
     """
 
     def __init__(
@@ -25,6 +30,7 @@ def __init__(
         averaging_constant: float = 0.01,
         grid: float = 100.0,
         norm: float = 2.4,
+        base_name: str = "weight",
         **kwargs,
     ):
         super().__init__(quantization_args=quantization_args)
@@ -36,6 +42,7 @@ def __init__(
         self.averaging_constant = averaging_constant
         self.grid = grid
         self.norm = norm
+        self.is_activation = base_name != "weight"
 
     def calculate_mse_min_max(
         self,
@@ -44,8 +51,10 @@ def calculate_mse_min_max(
         global_scale: Optional[torch.Tensor] = None,
     ):
         """
-        Computes the mse-clipped min and max values of the observed tensor by
-        optimizing for quantization error
+        Computes MSE-optimized min and max values for quantization.
+
+        - Used for weights (global and local scales).
+        - Used for activations only at the global scale (local activation scales use min-max).
 
         :param observed: observed tensor to calculate quantization parameters for
         :param reduce_dims: optional tuple of dimensions to reduce along,
@@ -53,6 +62,7 @@ def calculate_mse_min_max(
         :param global_scale: optional scale to further scale local quantization scales
         :return: tuple of min and max values derived from the observed tensor
         """
+
         from compressed_tensors.quantization.lifecycle import fake_quantize
 
         if not reduce_dims:
@@ -75,18 +85,38 @@ def calculate_mse_min_max(
             shrinked_min_val = p * absolute_min_val
             shrinked_max_val = p * absolute_max_val
 
+            from compressed_tensors.quantization.utils import generate_gparam
+
+            if is_fp4(self.quantization_args) and global_scale is None:
+                # If the quantization scheme is fp4 and global_scale is still
+                # None, i.e. it has not yet been optimized, first generate the
+                # global scale and then optimize the local scales. Local scales
+                # are set by the absolute min and max.
+                iteration_global_scale = generate_gparam(
+                    updated_min_val=shrinked_min_val, updated_max_val=shrinked_max_val
+                )
+                iteration_min_val = absolute_min_val
+                iteration_max_val = absolute_max_val
+            else:
+                # Otherwise, we are optimizing the local scales and use the
+                # shrunken min and max values
+                iteration_min_val = shrinked_min_val
+                iteration_max_val = shrinked_max_val
+                iteration_global_scale = global_scale
+
             candidate_scales, candidate_zero_points = calculate_qparams(
-                min_vals=shrinked_min_val,
-                max_vals=shrinked_max_val,
+                min_vals=iteration_min_val,
+                max_vals=iteration_max_val,
                 quantization_args=self.quantization_args,
-                global_scale=global_scale,
+                global_scale=iteration_global_scale,
             )
+
             q = fake_quantize(
                 observed,
                 candidate_scales,
                 candidate_zero_points,
                 self.quantization_args,
-                global_scale=global_scale,
+                global_scale=iteration_global_scale,
             )
 
             q -= observed
@@ -116,10 +146,15 @@ def calculate_updated_min_max(
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
+        is_local: Optional[bool] = False,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the mse-clipped min and max values of the observed tensor using
-        a moving average smoothed by the averaging_constant
+        a moving average smoothed by the averaging_constant.
+
+        - Weights: global and local scales use MSE-optimized values.
+        - Activations: global scale uses MSE-optimized values; local scales use
+          min-max.
 
         :param observed: observed tensor to calculate quantization parameters for
         :param reduce_dims: optional tuple of dimensions to reduce along,
@@ -130,11 +165,20 @@ def calculate_updated_min_max(
         :param global_scale: optional scale to further scale local quantization scales
         :return: updated min and max values derived from the observed value
         """
+
+        # Activation local scales skip the MSE loop; for dynamic activations
+        # they are recomputed at runtime anyway
+        if self.is_activation and is_local:
+            # Activation local scales: plain min-max
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+        else:
+            # Weights, or activation global scale: MSE optimization loop
+            min_val, max_val = self.calculate_mse_min_max(
+                observed, reduce_dims, global_scale=global_scale
+            )
+
+        tensor_id = tensor_id or "default"
 
         running_min_val = self.min_val.get(tensor_id, None)
         running_max_val = self.max_val.get(tensor_id, None)
@@ -150,7 +194,6 @@ def calculate_updated_min_max(
                 max_val - running_max_val
             )
 
-        tensor_id = tensor_id or "default"
         self.min_val[tensor_id] = updated_min_val
         self.max_val[tensor_id] = updated_max_val
         return updated_min_val, updated_max_val
@@ -180,6 +223,7 @@ def calculate_qparams(
             tensor_id=tensor_id,
             reduce_dims=reduce_dims,
             global_scale=global_scale,
+            is_local=True,
         )
         scale, zero_point = calculate_qparams(
             min_vals=updated_min_val,
@@ -187,6 +231,7 @@ def calculate_qparams(
             quantization_args=self.quantization_args,
             global_scale=global_scale,
         )
+
         return scale, zero_point
 
     def get_qparams_along_dim(
@@ -211,3 +256,23 @@ def reset(self):
         super().reset()
         self.min_val = {}
         self.max_val = {}
+
+    def calculate_gparam(self, observed: Tensor) -> torch.Tensor:
+        """
+        Generate a global scale using the observed min and max from MSE optimization.
+
+        - Weights: global scale is computed with standard MSE optimization.
+        - Activations: global scale uses the same MSE path, averaged over batches.
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: updated global scale derived from the observed tensor
+        """
+        from compressed_tensors.quantization.utils import generate_gparam
+
+        updated_min_val, updated_max_val = self.calculate_updated_min_max(
+            observed=observed
+        )
+
+        return generate_gparam(
+            updated_min_val=updated_min_val, updated_max_val=updated_max_val
+        )
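
For context on what the fp4 branch changes, here is a minimal, self-contained sketch of the same control flow using plain tensors: an MSE grid search shrinks the absolute min/max, and when an fp4-style scheme has no global scale yet, the global scale is derived from the shrunk range while the local min/max stay at their absolute values for that pass. Everything here is illustrative: `toy_mse_min_max` is a hypothetical stand-in for `calculate_mse_min_max`, the symmetric int8 round trip stands in for `fake_quantize`, and the NVFP4-style constant `448 * 6` stands in for `generate_gparam`; none of it is the library's actual API.

```python
import torch


def toy_mse_min_max(
    observed: torch.Tensor,
    grid: float = 100.0,
    maxshrink: float = 0.20,
    norm: float = 2.4,
):
    """Grid-search the clipping range that minimizes round-trip error."""
    absolute_min_val = observed.amin()
    absolute_max_val = observed.amax()
    best_err = torch.tensor(float("inf"))
    best_min, best_max = absolute_min_val, absolute_max_val
    for i in range(int(maxshrink * grid)):
        # Shrink the observed range toward zero by a factor p
        p = 1 - i / grid
        shrunk_min = p * absolute_min_val
        shrunk_max = p * absolute_max_val
        # Toy symmetric int8 round trip standing in for fake_quantize
        scale = torch.maximum(shrunk_min.abs(), shrunk_max.abs()) / 127.0
        q = (observed / scale).round().clamp_(-128, 127) * scale
        # Keep the candidate pair with the lowest p-norm reconstruction error
        err = ((q - observed).abs() ** norm).sum()
        if err < best_err:
            best_err, best_min, best_max = err, shrunk_min, shrunk_max
    return best_min, best_max


observed = torch.randn(128, 64)
shrunk_min, shrunk_max = toy_mse_min_max(observed)

# fp4 with no global scale yet: derive the global scale from the MSE-shrunk
# range (448 * 6 is an NVFP4-style constant, shown only for illustration),
# and keep the local min/max at their absolute values for this pass.
max_val_pos = torch.maximum(shrunk_min.abs(), shrunk_max.abs())
global_scale = 448.0 * 6.0 / max_val_pos
local_min, local_max = observed.amin(), observed.amax()
print(global_scale.item(), local_min.item(), local_max.item())
```

This mirrors the `is_fp4(self.quantization_args) and global_scale is None` branch in the diff: the first pass only produces a meaningful global scale, and later passes, called with `global_scale` set, run the MSE search again to optimize the local scales.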