diff --git a/setup.py b/setup.py
index 901ec9dd..58a1105f 100644
--- a/setup.py
+++ b/setup.py
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
     )
 
 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
+    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]
 
 def _setup_extras() -> Dict:
     return {
diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index 03a0b730..dd515976 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_matrix_size,
 )
-from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Linear, Module, Parameter
@@ -55,14 +55,23 @@ def create_transform(self, module: Module, args: TransformArgs):
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
+        exec_device = get_execution_device(module)
 
-        weight = self.weights[size, dtype, device]
+        factory_kwargs = {"construct_device": exec_device}
+        weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
         return HadamardTransform(weight, perm, args)
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size, dtype, device)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = deterministic_hadamard_matrix(size, dtype, construct_device)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
     def _create_permutation(self, weight: Parameter) -> Parameter:
diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py
index e551fc5f..47e9bcbb 100644
--- a/src/compressed_tensors/transform/factory/matrix_multiply.py
+++ b/src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -62,6 +62,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         return RandomMatrixTransform(weight, args)
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        # TODO: verify that weight is invertible (has non-zero determinant)
         data = torch.rand(
             (size, size), generator=self.generator, dtype=dtype, device=device
         )
diff --git a/src/compressed_tensors/transform/factory/random_hadamard.py b/src/compressed_tensors/transform/factory/random_hadamard.py
index 78fb6975..dd32d95e 100644
--- a/src/compressed_tensors/transform/factory/random_hadamard.py
+++ b/src/compressed_tensors/transform/factory/random_hadamard.py
@@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
     :param seed: random seed used to transform weight randomization
     """
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, dtype, device, self.generator)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py
index d8898ae4..712c4f83 100644
--- a/src/compressed_tensors/utils/helpers.py
+++ b/src/compressed_tensors/utils/helpers.py
@@ -15,10 +15,11 @@
 import contextlib
 import warnings
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 
 import numpy
 import torch
+from frozendict import frozendict
 from transformers import AutoConfig
@@ -373,11 +374,23 @@ class ParameterizedDefaultDict(dict):
 
     def __init__(self, default_factory: Callable[[Any], Any]):
         self.default_factory = default_factory
+        self._factory_kwargs = frozendict()
 
-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         if isinstance(key, tuple):
-            value = self.default_factory(*key)
+            value = self.default_factory(*key, **self._factory_kwargs)
         else:
-            value = self.default_factory(key)
+            value = self.default_factory(key, **self._factory_kwargs)
         self[key] = value
         return value
+
+    def get(self, *args, factory_kwargs: Mapping = frozendict()) -> Any:
+        """
+        Similar to `__getitem__`, but allows passing kwargs to the factory function
+
+        :param \\*args: args whose tuple will be treated as the key
+        :param factory_kwargs: keyword arguments to pass to `default_factory`
+        :return: dictionary entry for given key
+        """
+        with patch_attr(self, "_factory_kwargs", factory_kwargs):
+            return self[args]
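For illustration, a minimal sketch of how the new `factory_kwargs` plumbing is exercised (this assumes the patch above is applied, since `ParameterizedDefaultDict.get` does not exist before it; `make_weight` is a hypothetical toy factory standing in for the `_create_weight` methods):

import torch
from compressed_tensors.utils.helpers import ParameterizedDefaultDict

# Hypothetical factory mirroring _create_weight above: build the tensor on a
# "construct" device, then move it to the (possibly offloaded) cache device.
def make_weight(size, dtype, device, construct_device=torch.device("cpu")):
    data = torch.eye(size, dtype=dtype, device=construct_device)
    return data.to(device=device)

weights = ParameterizedDefaultDict(make_weight)

# Plain indexing still works; the factory is called with only the key as args
w1 = weights[4, torch.float32, torch.device("cpu")]

# `get` forwards factory_kwargs to the factory for this lookup only, without
# making them part of the cache key
w2 = weights.get(
    8,
    torch.float32,
    torch.device("cpu"),
    factory_kwargs={"construct_device": torch.device("cpu")},
)
assert w1.shape == (4, 4) and w2.shape == (8, 8)

Because `factory_kwargs` never enter the cache key, entries remain keyed only on (size, dtype, device); a weight constructed on the execution device but cached on the offload device is still shared across later lookups with the same key.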