diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index c14da51f..03a0b730 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.perms = ParameterizedDefaultDict(self._create_permutation)
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -56,24 +57,35 @@ def create_transform(self, module: Module, args: TransformArgs):
         device = get_offloaded_device(module)
 
         weight = self.weights[size, dtype, device]
-        return HadamardTransform(weight, args)
+        perm = self.perms[weight] if self.scheme.randomize else None
+        return HadamardTransform(weight, perm, args)
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         data = deterministic_hadamard_matrix(size, dtype, device)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
+    def _create_permutation(self, weight: Parameter) -> Parameter:
+        data = torch.randperm(weight.size(0), generator=self.generator)
+        return Parameter(data, requires_grad=False)
+
 
 class HadamardTransform(TransformBase):
-    def __init__(self, weight: Parameter, args: TransformArgs):
+    def __init__(
+        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+    ):
         super().__init__()
         self.weight = weight
+        self.perm = perm
         self.args = args
 
     def forward(self, value: Tensor) -> Tensor:
-        if not self.args.inverse:
-            weight = self.weight
-        else:
-            weight = self.weight.T
+        weight = self.weight
+
+        if self.perm is not None:
+            weight = weight[self.perm][:, self.perm]
+
+        if self.args.inverse:
+            weight = weight.T
 
         return apply_transform_weight(weight, value, self.args.location)
diff --git a/src/compressed_tensors/transform/transform_config.py b/src/compressed_tensors/transform/transform_config.py
index 414c21e0..df178c42 100644
--- a/src/compressed_tensors/transform/transform_config.py
+++ b/src/compressed_tensors/transform/transform_config.py
@@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
                     inverse=True,
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
         "u": TransformScheme(
             type="hadamard",
@@ -62,7 +62,7 @@ class TransformConfig(BaseModel):
                     targets=["Linear"], location="output", inverse=True  # non-mergable
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
     }
 )
diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py
index 1335063c..64d646e0 100644
--- a/src/compressed_tensors/transform/transform_scheme.py
+++ b/src/compressed_tensors/transform/transform_scheme.py
@@ -31,13 +31,12 @@ class TransformScheme(BaseModel):
         (see `Transforms.registered_names()`)
     :param apply: list of TransformationArgs containing the information about the
         modules that should be targeted by the specified transform
-    :param randomize_modules: True if unique transforms should be applied to each
-        unique module targeted by `apply`, otherwise reuse transform weights where
-        applicable
+    :param randomize: True if uniquely randomized transform weights should be used,
+        otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
     """
 
     type: str
     apply: List[TransformArgs] = Field(default_factory=list)
-    randomize_modules: bool = Field(default=False)
+    randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
diff --git a/tests/test_transform/test_transform_scheme.py b/tests/test_transform/test_transform_scheme.py
index ad851762..839ab46a 100644
--- a/tests/test_transform/test_transform_scheme.py
+++ b/tests/test_transform/test_transform_scheme.py
@@ -24,7 +24,7 @@ def test_basic_scheme():
         type="hadamard",
         apply=[basic_args],
     )
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 1
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -43,10 +43,10 @@ def test_multiple_groups_global():
     scheme = TransformScheme(
         type="hadamard",
         apply=[embedding_args, linear_args],
-        randomize_modules=True,
+        randomize=True,
     )
 
-    assert scheme.randomize_modules
+    assert scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 2
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -69,6 +69,6 @@ def test_multiple_groups():
         apply=apply,
     )
 
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 20
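
Note (not part of the diff): a minimal standalone sketch of the invariant the new `perm` handling relies on. The `sylvester_hadamard` helper and the right-multiplication below are illustrative assumptions rather than the library's `deterministic_hadamard_matrix` / `apply_transform_weight`; the point is that indexing `weight[self.perm][:, self.perm]` permutes rows and columns identically, which preserves orthogonality, so the transpose taken when `args.inverse` is set still acts as an exact inverse.

# Illustrative sketch only; names below are assumptions, not library APIs.
import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # Sylvester construction; n is assumed to be a power of two.
    h = torch.ones(1, 1)
    while h.size(0) < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / (n ** 0.5)  # normalize so rows/columns are orthonormal

size = 8
weight = sylvester_hadamard(size)
perm = torch.randperm(size)

# Same indexing pattern as HadamardTransform.forward: rows and columns
# are permuted together, i.e. P @ H @ P.T for a permutation matrix P.
permuted = weight[perm][:, perm]

# Identical row/column permutation preserves orthogonality.
assert torch.allclose(permuted @ permuted.T, torch.eye(size), atol=1e-6)

# Round trip: applying the permuted matrix and then its transpose
# recovers the original values, so "inverse" stays exact.
value = torch.randn(4, size)
restored = (value @ permuted) @ permuted.T
assert torch.allclose(restored, value, atol=1e-5)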