35 changes: 35 additions & 0 deletions src/compressed_tensors/transform/apply.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import torch
from accelerate.utils import has_offloaded_params
from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import TransformConfig, TransformFactory

@@ -34,3 +37,35 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):

# attach config to model for compression/serialization
setattr(model, TRANSFORM_CONFIG_NAME, config)

# ensure that tied-weight transforms can be serialized without aliases
# In the future, this could be done by transformers or model compressor,
# which would make this more robust to dispatch changes made after transforms are applied
_tie_offloaded_tensors(model)


def _tie_offloaded_tensors(model: torch.nn.Module):
"""
When accelerate replaces tensors with meta tensors during offloading, the meta
tensors may not be identical, even if the offloaded values are identical.

However, transformers can only serialize correctly if meta tensors are identical
(see transformers#39263).

This function collects all meta tensors which have shared offloaded values and sets
those tensors to be identical so that they can be removed during serialization

:param model: model potentially containing offloaded meta tensors to fix
"""

# ensure that if locations share an offloaded tensor pointer, their meta
# tensors are also identical (assigned to the first instance of the parameter)
ptr_to_meta: Dict[int, torch.nn.Parameter] = dict()
for module in model.modules():
if has_offloaded_params(module):
for key, _ in module.named_parameters(recurse=False):
offloaded_ptr = module._hf_hook.weights_map[key].data_ptr()

if offloaded_ptr not in ptr_to_meta:
ptr_to_meta[offloaded_ptr] = getattr(module, key)
setattr(module, key, ptr_to_meta[offloaded_ptr])
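
For context, a minimal usage sketch of `apply_transform_config`. The stand-in model and the exact config fields (`config_groups`, `type`, `apply`, `targets`, `location`) follow the library's hadamard-style presets and should be treated as assumptions, not as part of this diff:

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

# stand-in model; any torch.nn.Module with Linear layers works the same way
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Linear(64, 64, bias=False),
)

config = TransformConfig(
    config_groups={
        "u": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets=["Linear"], location="weight_input"),
                TransformArgs(targets=["Linear"], location="weight_output", inverse=True),
            ],
        )
    }
)

# attaches transform submodules and hooks, stores the config on the model, and
# ties any offloaded meta tensors so that shared weights serialize without aliases
apply_transform_config(model, config)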
38 changes: 2 additions & 36 deletions src/compressed_tensors/transform/factory/base.py
@@ -13,8 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import List, Optional, Set, Tuple
from typing import List, Optional

import torch
import torch.nn.utils.parametrize as P
@@ -57,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
self.name = name
self.scheme = scheme
self.generator = torch.Generator()
self.transforms = list()
if seed is not None:
self.generator.manual_seed(seed)

@@ -101,8 +99,6 @@ def apply_to_model(self, model: Module, use_tqdm=True):
for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
self._apply_to_module(module, arg)

self._update_tied_weights()

def _apply_to_module(self, module: Module, args: TransformArgs):
"""
Create transforms and apply them to the module
@@ -120,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
# create transform as submodule
transform_name = f"{self.name}_{args.location}"
transform = self.create_transform(module, args)
self.transforms.append(transform)
register_offload_module(module, transform_name, transform)

# register input transformation hook
@@ -165,31 +160,6 @@ def output_hook(_, _input, output):
else:
raise NotImplementedError()

def _update_tied_weights(self):
"""
Populate the `_dynamic_tied_weights_keys` attribute of transforms,
which is used by transformers to detect and remove shared pointers
during saving
"""
# map from data_ptrs to keys
ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
for transform in self.transforms:
for name, param in transform.named_parameters(recurse=False):
# NOTE: previously asserted that parent._hf_hook.place_submodules=False
if has_offloaded_params(transform):
param = transform._hf_hook.weights_map[name]
ptr_to_keys[param.data_ptr()].append((transform, name))

# populate `_dynamic_tied_weights_keys` if there is more than one key
# and ensure that they share tensors
for shared_keys in ptr_to_keys.values():
if len(shared_keys) > 1:
tensor = getattr(shared_keys[0][0], shared_keys[0][1])

for transform, name in shared_keys:
transform._dynamic_tied_weights_keys.add(name)
setattr(transform, name, tensor)


class TransformBase(InternalModule, ABC):
"""
@@ -198,11 +168,7 @@ class TransformBase(InternalModule, ABC):

args: TransformArgs
weight: Parameter
_dynamic_tied_weights_keys: Set[str]

def __init__(self):
super().__init__()
self._dynamic_tied_weights_keys = set()
_dynamic_tied_weights_keys: List[str] = ["weight"]

@abstractmethod
def forward(self, value: Tensor) -> Tensor:
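
The net effect of the changes above is that tied-weight bookkeeping moves from a per-factory `_update_tied_weights` pass to a class-level `_dynamic_tied_weights_keys` list, which transformers reads at save time to detect and remove shared pointers. A small self-contained sketch of that contract; the class below is illustrative, not part of the library:

from typing import List

import torch
from torch.nn import Parameter


class MyTransform(torch.nn.Module):  # stands in for a TransformBase subclass
    # class-level declaration replaces the old per-instance set
    _dynamic_tied_weights_keys: List[str] = ["weight"]

    def __init__(self, weight: Parameter):
        super().__init__()
        self.weight = weight

    def forward(self, value: torch.Tensor) -> torch.Tensor:
        return value @ self.weight


# two transforms sharing one Parameter object: because "weight" is listed in
# _dynamic_tied_weights_keys, the serializer can drop the duplicate entry
shared = Parameter(torch.eye(4), requires_grad=False)
a, b = MyTransform(shared), MyTransform(shared)
assert a.weight.data_ptr() == b.weight.data_ptr()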
4 changes: 3 additions & 1 deletion src/compressed_tensors/transform/factory/hadamard.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import List, Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -84,6 +84,8 @@ def _create_permutation(self, weight: Parameter) -> Parameter:


class HadamardTransform(TransformBase):
_dynamic_tied_weights_keys: List[str] = ["weight", "perm"]

def __init__(
self,
weight: Parameter,
7 changes: 5 additions & 2 deletions src/compressed_tensors/transform/utils/hadamard.py
@@ -115,13 +115,16 @@ def _fetch_hadamard_divisor(
than forcing callers to manage the file open context

:param n: size of known hadamard matrix
:param dtype: data type to move fetched hadamard to
:param device: device to move fetched hadamard to
:return: a known hadamard matrix of size `n` if one exists, else None
"""
with safe_open(file_path, framework="pt", device=str(device)) as file:
open_device = torch.device("cpu") if device.type == "meta" else device
with safe_open(file_path, framework="pt", device=str(open_device)) as file:
divisors = sorted((int(key) for key in file.keys()), reverse=True)
for divisor in divisors:
if n % divisor == 0 and is_pow2(n // divisor):
return file.get_tensor(str(divisor)).to(dtype=dtype)
return file.get_tensor(str(divisor)).to(dtype=dtype, device=device)

return None

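
A minimal sketch of the meta-device handling introduced above, using a throwaway safetensors file in place of the bundled hadamard weights (the file name and key are illustrative):

import torch
from safetensors import safe_open
from safetensors.torch import save_file

path = "example.safetensors"
save_file({"12": torch.ones(12, 12)}, path)

device = torch.device("meta")

# safe_open cannot materialize tensors directly on a meta device, so read on
# CPU first; .to(device="meta") then yields a correctly-shaped meta tensor
open_device = torch.device("cpu") if device.type == "meta" else device
with safe_open(path, framework="pt", device=str(open_device)) as file:
    tensor = file.get_tensor("12").to(dtype=torch.float32, device=device)

assert tensor.device.type == "meta" and tensor.shape == (12, 12)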
2 changes: 2 additions & 0 deletions tests/test_transform/conftest.py
@@ -19,6 +19,8 @@


class TransformableModel(PreTrainedModel):
config_class = PretrainedConfig

def __init__(self, *sizes):
super().__init__(config=PretrainedConfig())
self.fcs = torch.nn.ModuleList(
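
For reference, a minimal sketch of why the test model now declares `config_class`: PreTrainedModel subclasses use it to rebuild the config when reloading with from_pretrained. The class and save path below are illustrative, not taken from the test suite:

import torch
from transformers import PretrainedConfig, PreTrainedModel


class TinyModel(PreTrainedModel):
    config_class = PretrainedConfig  # lets save_pretrained/from_pretrained round-trip the config

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.fc = torch.nn.Linear(4, 4, bias=False)

    def forward(self, x):
        return self.fc(x)


model = TinyModel(PretrainedConfig())
model.save_pretrained("tiny_model")
reloaded = TinyModel.from_pretrained("tiny_model")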
54 changes: 49 additions & 5 deletions tests/test_transform/factory/test_serialization.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from compressed_tensors.transform import (
@@ -20,7 +22,9 @@
apply_transform_config,
)
from compressed_tensors.utils import offloaded_dispatch
from safetensors import safe_open
from tests.testing_utils import requires_accelerate, requires_gpu
from transformers import AutoModelForCausalLM, AutoTokenizer


@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
@@ -38,17 +42,57 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
apply_transform_config(model, config)

# save model
model.save_pretrained(tmp_path)
model_path = os.path.join(tmp_path, "test_model_path")
model.save_pretrained(model_path)

# check that saved values match model values
# note that shared weights are only serialized once
safetensors_path = os.path.join(model_path, "model.safetensors")
with safe_open(safetensors_path, framework="pt", device="cpu") as file:
saved_keys = set(file.keys())
assert {
"fcs.0.weight",
"fcs.1.weight",
"fcs.2.weight",
"fcs.3.weight",
"fcs.4.weight",
} <= saved_keys
for key in saved_keys:
param = model.get_parameter(key)
saved_param = file.get_tensor(key)

# TODO: reload model
if param.device.type != "meta": # skip testing values in offload case
assert torch.equal(param, saved_param)


@pytest.mark.skip(reason="Requires changes in upstream transformers")
# https://github.com/huggingface/transformers/pull/39280
# https://github.com/huggingface/transformers/pull/39263
@requires_gpu
@requires_accelerate()
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
@pytest.mark.parametrize("randomize", (True, False))
def test_serialization_offload(type, randomize, model_apply, tmp_path):
test_serialization(type, randomize, model_apply, tmp_path, offload=True)


@pytest.mark.skip("Requires transformers#40673")
@requires_gpu
@pytest.mark.parametrize(
"model_stub,exp_perplexity",
[
("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0),
("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0),
],
)
def test_load_perplexity(model_stub, exp_perplexity):
model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_stub)

prompt = "The capital of France is Paris, the capital of Germany is Berlin"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {key: value.to(model.device) for key, value in inputs.items()}
labels = inputs["input_ids"]

with torch.no_grad():
outputs = model(**inputs, labels=labels)

perplexity = torch.exp(outputs.loss)
assert perplexity <= exp_perplexity