diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index e247e7029..28d5e94fc 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict + import torch +from accelerate.utils import has_offloaded_params from compressed_tensors import TRANSFORM_CONFIG_NAME from compressed_tensors.transform import TransformConfig, TransformFactory @@ -34,3 +37,35 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): # attach config to model for compression/serialization setattr(model, TRANSFORM_CONFIG_NAME, config) + + # ensure that tied weight transforms can be serialized without aliases + # In the future, this could be done by transformers or model compressor + # which would make this more robust to changing dispatches after transforms + _tie_offloaded_tensors(model) + + +def _tie_offloaded_tensors(model: torch.nn.Module): + """ + When accelerate replaces tensors with meta tensors during offloading, the meta + tensors may not be identical, even if the offloaded values are identical. + + However, transformers can only serialize correctly if meta tensors are identical + (see transformers#39263). + + This function collects all meta tensors which have shared offloaded values and sets + those tensors to be identical so that they can be removed during serialization + + :param model: model potentially containing offloaded meta tensors to fix + """ + + # ensure that if a location shares an offloaded tensor pointers, that the + # meta tensor is also identical (assigned to the first instance of parameter) + ptr_to_meta: Dict[int, torch.nn.Parameter] = dict() + for module in model.modules(): + if has_offloaded_params(module): + for key, _ in module.named_parameters(recurse=False): + offloaded_ptr = module._hf_hook.weights_map[key].data_ptr() + + if offloaded_ptr not in ptr_to_meta: + ptr_to_meta[offloaded_ptr] = getattr(module, key) + setattr(module, key, ptr_to_meta[offloaded_ptr]) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 94e6b4a42..34d609e74 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,8 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from collections import defaultdict -from typing import List, Optional, Set, Tuple +from typing import List, Optional import torch import torch.nn.utils.parametrize as P @@ -57,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non self.name = name self.scheme = scheme self.generator = torch.Generator() - self.transforms = list() if seed is not None: self.generator.manual_seed(seed) @@ -101,8 +99,6 @@ def apply_to_model(self, model: Module, use_tqdm=True): for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): self._apply_to_module(module, arg) - self._update_tied_weights() - def _apply_to_module(self, module: Module, args: TransformArgs): """ Create transforms and apply them to the module @@ -120,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs): # create transform as submodule transform_name = f"{self.name}_{args.location}" transform = self.create_transform(module, args) - self.transforms.append(transform) register_offload_module(module, transform_name, transform) # register input transformation hook @@ -165,31 +160,6 @@ def output_hook(_, _input, output): else: raise NotImplementedError() - def _update_tied_weights(self): - """ - Populate the `_dynamic_tied_weights_keys` attribute of transforms, - which is used by transformers to detect and remove shared pointers - during saving - """ - # map from data_ptrs to keys - ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list) - for transform in self.transforms: - for name, param in transform.named_parameters(recurse=False): - # NOTE: previously asserted that parent._hf_hook.place_submodules=False - if has_offloaded_params(transform): - param = transform._hf_hook.weights_map[name] - ptr_to_keys[param.data_ptr()].append((transform, name)) - - # populate `_dynamic_tied_weights_keys` if there is more than one key - # and ensure that they share tensors - for shared_keys in ptr_to_keys.values(): - if len(shared_keys) > 1: - tensor = getattr(shared_keys[0][0], shared_keys[0][1]) - - for transform, name in shared_keys: - transform._dynamic_tied_weights_keys.add(name) - setattr(transform, name, tensor) - class TransformBase(InternalModule, ABC): """ @@ -198,11 +168,7 @@ class TransformBase(InternalModule, ABC): args: TransformArgs weight: Parameter - _dynamic_tied_weights_keys: Set[str] - - def __init__(self): - super().__init__() - self._dynamic_tied_weights_keys = set() + _dynamic_tied_weights_keys: List[str] = ["weight"] @abstractmethod def forward(self, value: Tensor) -> Tensor: diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index de6e284bb..a843e2728 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import List, Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -84,6 +84,8 @@ def _create_permutation(self, weight: Parameter) -> Parameter: class HadamardTransform(TransformBase): + _dynamic_tied_weights_keys: List[str] = ["weight", "perm"] + def __init__( self, weight: Parameter, diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 7d361e59d..c8144ae26 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -115,13 +115,16 @@ def _fetch_hadamard_divisor( than forcing callers to manage the file open context :param n: size of known hadamard matrix + :param dtype: data type to move fetched hadamard to + :param device: device to move fetched hadamard to :return: a known hadamard matrix of size `n` if one exists, else None """ - with safe_open(file_path, framework="pt", device=str(device)) as file: + open_device = torch.device("cpu") if device.type == "meta" else device + with safe_open(file_path, framework="pt", device=str(open_device)) as file: divisors = sorted((int(key) for key in file.keys()), reverse=True) for divisor in divisors: if n % divisor == 0 and is_pow2(n // divisor): - return file.get_tensor(str(divisor)).to(dtype=dtype) + return file.get_tensor(str(divisor)).to(dtype=dtype, device=device) return None diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index a0188c429..824c06bd3 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -19,6 +19,8 @@ class TransformableModel(PreTrainedModel): + config_class = PretrainedConfig + def __init__(self, *sizes): super().__init__(config=PretrainedConfig()) self.fcs = torch.nn.ModuleList( diff --git a/tests/test_transform/factory/test_serialization.py b/tests/test_transform/factory/test_serialization.py index a688c2cf1..15fa240ba 100644 --- a/tests/test_transform/factory/test_serialization.py +++ b/tests/test_transform/factory/test_serialization.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch from compressed_tensors.transform import ( @@ -20,7 +22,9 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch +from safetensors import safe_open from tests.testing_utils import requires_accelerate, requires_gpu +from transformers import AutoModelForCausalLM, AutoTokenizer @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @@ -38,17 +42,57 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False): apply_transform_config(model, config) # save model - model.save_pretrained(tmp_path) + model_path = os.path.join(tmp_path, "test_model_path") + model.save_pretrained(model_path) + + # check that saved values match model values + # note that shared weights are only serialized once + safetensors_path = os.path.join(model_path, "model.safetensors") + with safe_open(safetensors_path, framework="pt", device="cpu") as file: + saved_keys = set(file.keys()) + assert { + "fcs.0.weight", + "fcs.1.weight", + "fcs.2.weight", + "fcs.3.weight", + "fcs.4.weight", + } <= saved_keys + for key in saved_keys: + param = model.get_parameter(key) + saved_param = file.get_tensor(key) - # TODO: reload model + if param.device.type != "meta": # skip testing values in offload case + assert torch.equal(param, saved_param) -@pytest.mark.skip(reason="Requires changes in upstream transformers") -# https://github.com/huggingface/transformers/pull/39280 -# https://github.com/huggingface/transformers/pull/39263 @requires_gpu @requires_accelerate() @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @pytest.mark.parametrize("randomize", (True, False)) def test_serialization_offload(type, randomize, model_apply, tmp_path): test_serialization(type, randomize, model_apply, tmp_path, offload=True) + + +@pytest.mark.skip("Requires transformers#40673") +@requires_gpu +@pytest.mark.parametrize( + "model_stub,exp_perplexity", + [ + ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0), + ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0), + ], +) +def test_load_perplexity(model_stub, exp_perplexity): + model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(model_stub) + + prompt = "The capital of France is Paris, the capital of Germany is Berlin" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {key: value.to(model.device) for key, value in inputs.items()} + labels = inputs["input_ids"] + + with torch.no_grad(): + outputs = model(**inputs, labels=labels) + + perplexity = torch.exp(outputs.loss) + assert perplexity <= exp_perplexity