
[Transform] Construct on GPU, cache on CPU #352


Merged: 31 commits, Jul 7, 2025
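
In short: transform weights are now constructed on the module's execution device (e.g. a GPU) and then cached on its offload device (e.g. the CPU). Below is a minimal sketch of that pattern, with a placeholder constructor standing in for the library's Hadamard helpers; the diffs further down wire the same pattern into HadamardFactory, RandomHadamardFactory, and the ParameterizedDefaultDict weight cache.

import torch

def construct_then_offload(size: int, dtype: torch.dtype,
                           construct_device: torch.device,
                           offload_device: torch.device) -> torch.Tensor:
    # placeholder for deterministic_hadamard_matrix / random_hadamard_matrix
    data = torch.eye(size, dtype=dtype, device=construct_device)
    # a single device copy leaves the cached result on the offload device
    return data.to(device=offload_device)

# construct on GPU when available, cache on CPU
gpu = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
weight = construct_then_offload(1024, torch.float16, gpu, torch.device("cpu"))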
Commits (31, all by kylesayrs)
a27db62  use hadamards database file (Jun 11, 2025)
ce63955  try manifest (Jun 11, 2025)
7ae5863  try setup, update hadamards list (Jun 11, 2025)
67675c3  fix setup (Jun 11, 2025)
f061db9  add docstrings, cleanup (Jun 11, 2025)
4a84ce1  fix setup, thank you @dbarbuzzi (Jun 11, 2025)
cde1066  remove numpy, add tests (Jun 11, 2025)
1ba6195  solidify dtype, add gpu tests (Jun 11, 2025)
c373345  fix docstring (Jun 11, 2025)
fbaf47a  add device option (Jun 11, 2025)
5a887f4  construct on execution device, cache on offload device (Jun 11, 2025)
310fe6d  save construction device changes for later (Jun 11, 2025)
b715329  construct on execution device, cache on offload device (Jun 11, 2025)
249323c  cite nja sloane (Jun 11, 2025)
1823af4  Merge branch 'kylesayrs/extend-hadamard', remote-tracking branch 'ori… (Jun 11, 2025)
94a0bf5  Merge remote-tracking branch 'origin' into kylesayrs/extend-hadamard (Jun 11, 2025)
cf066e0  Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con… (Jun 11, 2025)
c1a4a34  remove dreg (Jun 11, 2025)
5807ee1  put on device via safe_open (Jun 11, 2025)
ccb88ed  nits and docstrings (Jun 12, 2025)
feba695  update docstring (Jun 12, 2025)
c8f6b53  Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con… (Jun 12, 2025)
b6a0dd4  Merge remote-tracking branch 'origin' into kylesayrs/transform_constr… (Jun 13, 2025)
fd3390a  construct with same dtype, constructing on fp32 found no difference (Jun 23, 2025)
ad29c15  remove unnecessary imports (Jun 23, 2025)
500af9b  use factory_kwargs (Jul 7, 2025)
8e36540  add frozen dict to deps (Jul 7, 2025)
7dc182b  Merge remote-tracking branch 'origin' into kylesayrs/transform_constr… (Jul 7, 2025)
2fda022  correct typo (Jul 7, 2025)
99a1159  Merge remote-tracking branch 'origin' into kylesayrs/transform_constr… (Jul 7, 2025)
282ba31  fix missing import (Jul 7, 2025)
setup.py (2 changes: 1 addition & 1 deletion)

@@ -88,7 +88,7 @@ def _setup_packages() -> List:
     )

 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
+    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]

 def _setup_extras() -> Dict:
     return {
src/compressed_tensors/transform/factory/hadamard.py (19 changes: 14 additions & 5 deletions)

@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_matrix_size,
 )
-from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Linear, Module, Parameter
@@ -55,14 +55,23 @@ def create_transform(self, module: Module, args: TransformArgs):
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
+        exec_device = get_execution_device(module)

-        weight = self.weights[size, dtype, device]
+        factory_kwargs = {"construct_device": exec_device}
+        weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
         return HadamardTransform(weight, perm, args)

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size, dtype, device)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = deterministic_hadamard_matrix(size, dtype, construct_device)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

     def _create_permutation(self, weight: Parameter) -> Parameter:
@@ -62,6 +62,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         return RandomMatrixTransform(weight, args)

     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        # TODO: verify that weight is invertible (has non-zero determinant)
         data = torch.rand(
             (size, size), generator=self.generator, dtype=dtype, device=device
         )
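
One way the TODO above could be addressed, purely as an illustration and not part of this PR: check the matrix rank (more robust than testing the determinant against zero) and redraw on failure.

import torch

def random_invertible(size: int, dtype: torch.dtype, device: torch.device,
                      generator: torch.Generator = None,
                      max_tries: int = 3) -> torch.Tensor:
    for _ in range(max_tries):
        data = torch.rand((size, size), generator=generator, dtype=dtype, device=device)
        # full rank <=> invertible; cast to float32 since linalg ops need a full-precision dtype
        if torch.linalg.matrix_rank(data.float()) == size:
            return data
    raise RuntimeError("failed to sample an invertible matrix")

m = random_invertible(8, torch.float32, torch.device("cpu"))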
src/compressed_tensors/transform/factory/random_hadamard.py (13 changes: 10 additions & 3 deletions)

@@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
     :param seed: random seed used to transform weight randomization
     """

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, dtype, device, self.generator)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
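
In both Hadamard factories above, `construct_device` travels through `factory_kwargs` rather than through the cache key, so the key stays `(size, dtype, device)`: modules that offload to the same device with the same size and dtype share one cached weight, no matter which device built it. A rough sketch of that lookup behavior with a plain dict (hypothetical names, not the factory's actual code); the real mechanism, `ParameterizedDefaultDict.get`, follows in the next diff.

import torch

_weight_cache = {}  # key: (size, dtype, offload_device) -> cached tensor

def get_weight(size, dtype, offload_device, construct_device):
    key = (size, dtype, offload_device)  # construct_device is not part of the key
    if key not in _weight_cache:
        data = torch.eye(size, dtype=dtype, device=construct_device)  # placeholder build
        _weight_cache[key] = data.to(device=offload_device)           # cache the offloaded copy
    return _weight_cache[key]

# both calls return the same cached CPU tensor, built only once
a = get_weight(256, torch.float16, torch.device("cpu"), torch.device("cpu"))
b = get_weight(256, torch.float16, torch.device("cpu"), torch.device("cpu"))
assert a is b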
src/compressed_tensors/utils/helpers.py (21 changes: 17 additions & 4 deletions)

@@ -15,10 +15,11 @@
 import contextlib
 import warnings
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional

 import numpy
 import torch
+from frozendict import frozendict
 from transformers import AutoConfig


@@ -373,11 +374,23 @@ class ParameterizedDefaultDict(dict):

     def __init__(self, default_factory: Callable[[Any], Any]):
         self.default_factory = default_factory
+        self._factory_kwargs = frozendict()

-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         if isinstance(key, tuple):
-            value = self.default_factory(*key)
+            value = self.default_factory(*key, **self._factory_kwargs)
         else:
-            value = self.default_factory(key)
+            value = self.default_factory(key, **self._factory_kwargs)
         self[key] = value
         return value
+
+    def get(self, *args, factory_kwargs: Mapping = frozendict()) -> Any:
+        """
+        Similar to `__getitem__`, but allows passing kwargs to factory function
+
+        :param \\*args: args whose tuple will be treated as the key
+        :param factory_kwargs: keyword arguments to pass to `default_factory`
+        :return: dictionary entry for given key
+        """
+        with patch_attr(self, "_factory_kwargs", factory_kwargs):
+            return self[args]
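
A self-contained sketch of how this `get` path behaves, assuming a minimal stand-in for `patch_attr` (the real helper lives elsewhere in this module): the positional args form the cache key, while `factory_kwargs` are forwarded to the factory only on a cache miss.

import contextlib
from frozendict import frozendict

@contextlib.contextmanager
def patch_attr(obj, name, value):  # simplified stand-in for the real helper
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)

class ParameterizedDefaultDict(dict):  # trimmed copy of the class in the diff above
    def __init__(self, default_factory):
        self.default_factory = default_factory
        self._factory_kwargs = frozendict()

    def __missing__(self, key):
        if isinstance(key, tuple):
            value = self.default_factory(*key, **self._factory_kwargs)
        else:
            value = self.default_factory(key, **self._factory_kwargs)
        self[key] = value
        return value

    def get(self, *args, factory_kwargs=frozendict()):
        with patch_attr(self, "_factory_kwargs", factory_kwargs):
            return self[args]

# hypothetical factory: returns a tuple so the example stays CPU-only and dependency-free
weights = ParameterizedDefaultDict(
    lambda size, dtype, device, construct_device=None: (size, dtype, device)
)
# kwargs reach the factory on a miss, but are not part of the cache key
first = weights.get(64, "fp16", "cpu", factory_kwargs={"construct_device": "cuda:0"})
again = weights.get(64, "fp16", "cpu")  # cache hit: factory is not called again
assert first is again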