From 0fe39c7fb27a316d6a3272e00cd4c03cbed612f1 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 26 Feb 2025 12:41:25 -0500
Subject: [PATCH 1/3] fix kv_initialization

Signed-off-by: Kyle Sayers
---
 .../linear/compressed_linear.py               |  3 ++-
 .../quantization/lifecycle/initialize.py      | 19 +++++++------------
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py
index 3e2b2f5f..fcbe4606 100644
--- a/src/compressed_tensors/linear/compressed_linear.py
+++ b/src/compressed_tensors/linear/compressed_linear.py
@@ -21,6 +21,7 @@
     QuantizationStatus,
     initialize_module_for_quantization,
 )
+from compressed_tensors.utils import register_offload_parameter
 from torch import Tensor
 from torch.nn import Parameter
 from torch.nn.functional import linear
@@ -68,7 +69,7 @@ def from_linear(
             param = Parameter(
                 torch.empty(shape, device=device, dtype=dtype), requires_grad=False
             )
-            module.register_parameter(name, param)
+            register_offload_parameter(module, name, param)
 
         # mark module as compressed
         module.quantization_status = QuantizationStatus.COMPRESSED
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 8dd8fc51..9ce7a000 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -190,24 +190,19 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
-    """Initlaize k_scale, v_scale for self_attn"""
+def _initialize_attn_scales(module: Module):
+    """Initlaize k_scale, v_scale for self_attn"""
 
     expected_shape = 1  # per tensor
 
-    param = next(module.parameters())
-    scale_dtype = param.dtype
-    device = param.device
+    weight_param = getattr(module, "weight", next(module.parameters()))
+    scale_dtype = weight_param.dtype
+    device = weight_param.device
 
     init_scale = Parameter(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
-
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
-    )
-    module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale.clone())

From 4c2ec0b8d4e20185e69b25f8b2c9b3a9b8ac7741 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 26 Feb 2025 12:43:42 -0500
Subject: [PATCH 2/3] use register_offload_parameter

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/linear/compressed_linear.py          | 3 ++-
 src/compressed_tensors/quantization/lifecycle/initialize.py | 5 ++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py
index 3e2b2f5f..42d97d73 100644
--- a/src/compressed_tensors/linear/compressed_linear.py
+++ b/src/compressed_tensors/linear/compressed_linear.py
@@ -21,6 +21,7 @@
     QuantizationStatus,
     initialize_module_for_quantization,
 )
+from compressed_tensors.utils import register_offload_parameter
 from torch import Tensor
 from torch.nn import Parameter
 from torch.nn.functional import linear
@@ -68,7 +69,7 @@ def from_linear(
             param = Parameter(
                 torch.empty(shape, device=device, dtype=dtype), requires_grad=False
             )
-            module.register_parameter(name, param)
+            register_offload_parameter(name, param)
 
         # mark module as compressed
         module.quantization_status = QuantizationStatus.COMPRESSED
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 8dd8fc51..6886423a 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -203,11 +203,10 @@ def _initialize_attn_scales(module: Module) -> None:
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-
-    module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
 
     init_scale = Parameter(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)

From 15b00ac129b17a72629bd88e90ae852fb1796907 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 26 Feb 2025 12:49:33 -0500
Subject: [PATCH 3/3] fix typo

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/linear/compressed_linear.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py
index 42d97d73..fcbe4606 100644
--- a/src/compressed_tensors/linear/compressed_linear.py
+++ b/src/compressed_tensors/linear/compressed_linear.py
@@ -69,7 +69,7 @@ def from_linear(
             param = Parameter(
                 torch.empty(shape, device=device, dtype=dtype), requires_grad=False
             )
-            register_offload_parameter(name, param)
+            register_offload_parameter(module, name, param)
 
         # mark module as compressed
         module.quantization_status = QuantizationStatus.COMPRESSED