Commit 87c9b36

meta hadamards
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 7734cce commit 87c9b36

File tree

2 files changed: +7 additions, -36 deletions

  src/compressed_tensors/transform/factory/base.py
  src/compressed_tensors/transform/utils/hadamard.py

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 34 deletions
@@ -13,8 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional, Set
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -101,8 +100,6 @@ def apply_to_model(self, model: Module, use_tqdm=True):
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
             self._apply_to_module(module, arg)
 
-        self._update_tied_weights()
-
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -165,31 +162,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
-    def _update_tied_weights(self):
-        """
-        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
-        which is used by transformers to detect and remove shared pointers
-        during saving
-        """
-        # map from data_ptrs to keys
-        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
-        for transform in self.transforms:
-            for name, param in transform.named_parameters(recurse=False):
-                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
-                if has_offloaded_params(transform):
-                    param = transform._hf_hook.weights_map[name]
-                ptr_to_keys[param.data_ptr()].append((transform, name))
-
-        # populate `_dynamic_tied_weights_keys` if there is more than one key
-        # and ensure that they share tensors
-        for shared_keys in ptr_to_keys.values():
-            if len(shared_keys) > 1:
-                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
-
-                for transform, name in shared_keys:
-                    transform._dynamic_tied_weights_keys.add(name)
-                    setattr(transform, name, tensor)
-
 
 class TransformBase(InternalModule, ABC):
     """
@@ -198,11 +170,7 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
    weight: Parameter
-    _dynamic_tied_weights_keys: Set[str]
-
-    def __init__(self):
-        super().__init__()
-        self._dynamic_tied_weights_keys = set()
+    _dynamic_tied_weights_keys: List[str] = ["weight"]
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
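
Note on the base.py change (editorial context, not part of the commit): the deleted _update_tied_weights pass walked every transform's parameters at save time, grouped them by data_ptr(), and populated a per-instance set of tied keys. Declaring `_dynamic_tied_weights_keys: List[str] = ["weight"]` as a class attribute conveys the same information statically, since each transform exposes its data through a single `weight` parameter. Below is a minimal sketch of how such a declaration lets a serializer drop duplicate pointers; the names ToyTransform and dedup_tied_weights are hypothetical and this is not the actual transformers machinery.

    # Illustrative sketch only: a class-level tied-keys declaration replacing a runtime scan.
    from typing import List

    import torch
    from torch import nn


    class ToyTransform(nn.Module):
        # Declared statically: this parameter may share storage with another transform's
        # weight, so a serializer can drop duplicate pointers at save time.
        _dynamic_tied_weights_keys: List[str] = ["weight"]

        def __init__(self, weight: nn.Parameter):
            super().__init__()
            self.weight = weight


    def dedup_tied_weights(named_modules) -> dict:
        """Keep only one copy of tensors that share storage (hypothetical helper)."""
        seen_ptrs = set()
        state = {}
        for prefix, module in named_modules:
            for key in getattr(module, "_dynamic_tied_weights_keys", []):
                param = getattr(module, key)
                if param.data_ptr() not in seen_ptrs:
                    seen_ptrs.add(param.data_ptr())
                    state[f"{prefix}.{key}"] = param.detach()
        return state


    shared = nn.Parameter(torch.eye(4))
    u, v = ToyTransform(shared), ToyTransform(shared)  # two transforms, one tensor
    print(len(dedup_tied_weights([("u", u), ("v", v)])))  # -> 1

Because the keys are declared on the class, no __init__ override or post-hoc scan is needed, which is what allows the commit to delete both the per-instance set and the _update_tied_weights() call in apply_to_model.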

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 5 additions & 2 deletions
@@ -115,13 +115,16 @@ def _fetch_hadamard_divisor(
         than forcing callers to manage the file open context
 
     :param n: size of known hadamard matrix
+    :param dtype: data type to move fetched hadamard to
+    :param device: device to move fetched hadamard to
     :return: a known hadamard matrix of size `n` if one exists, else None
     """
-    with safe_open(file_path, framework="pt", device=str(device)) as file:
+    open_device = torch.device("cpu") if device.type == "meta" else device
+    with safe_open(file_path, framework="pt", device=str(open_device)) as file:
         divisors = sorted((int(key) for key in file.keys()), reverse=True)
         for divisor in divisors:
            if n % divisor == 0 and is_pow2(n // divisor):
-                return file.get_tensor(str(divisor)).to(dtype=dtype)
+                return file.get_tensor(str(divisor)).to(dtype=dtype, device=device)
 
    return None
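
Note on the hadamard.py change (editorial context, not part of the commit): safe_open apparently cannot materialize tensors directly on a meta device, so the loader now opens the file on CPU whenever the requested device is meta and only moves the selected tensor afterwards; calling .to(device=device) with a meta target discards the data while keeping shape and dtype, which is all meta initialization needs. Below is a small self-contained sketch of the same pattern, using a hypothetical load_matrix helper and a throwaway safetensors file rather than the repository's bundled Hadamard file.

    # Sketch of the meta-device fallback, assuming the `safetensors` package is installed.
    import os
    import tempfile

    import torch
    from safetensors import safe_open
    from safetensors.torch import save_file


    def load_matrix(path: str, key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # safe_open cannot place tensors on a meta device, so open on CPU instead
        # and only move to the requested device after reading.
        open_device = torch.device("cpu") if device.type == "meta" else device
        with safe_open(path, framework="pt", device=str(open_device)) as f:
            return f.get_tensor(key).to(dtype=dtype, device=device)


    # Tiny usage example: a 2x2 Hadamard matrix saved, then reloaded onto meta.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "hadamards.safetensors")
        save_file({"2": torch.tensor([[1.0, 1.0], [1.0, -1.0]])}, path)
        h = load_matrix(path, "2", torch.bfloat16, torch.device("meta"))
        print(h.shape, h.dtype, h.device)  # torch.Size([2, 2]) torch.bfloat16 meta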