 import torch
 from compressed_tensors.transform import TransformArgs
 from compressed_tensors.utils import TorchDtype
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator


 __all__ = ["TransformScheme"]
@@ -36,6 +36,8 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param block_size: If set, the transform matrix will be block diagonal, with each
+        block being a square matrix of this size.
     :param precision: Precision at which this transform should be applied during online
         rotations. Fused (offline) rotations are always performed in float64
     """
@@ -44,7 +46,21 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
-    head_dim: Optional[int] = Field(default=None)
+    block_size: Optional[int] = Field(default=None)
+    head_dim: Optional[int] = Field(
+        default=None, deprecated="head_dim is deprecated, use block_size instead"
+    )
     precision: TorchDtype = Field(default=torch.float32)

+    @model_validator(mode="after")
+    def validate_model_after(model: "TransformScheme") -> "TransformScheme":
+        """
+        If head_dim is used instead of block_size, set block_size to head_dim
+        and remove head_dim
+        """
+        if model.block_size is None and model.head_dim is not None:
+            model.block_size = model.head_dim
+            model.head_dim = None
+        return model
+
     model_config = ConfigDict(extra="forbid")
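
A minimal sketch of the migration behavior this validator is meant to provide, assuming `TransformScheme` is importable from `compressed_tensors.transform`, that the scheme's required `type` field accepts "hadamard", and a pydantic version that supports the `deprecated` Field argument:

    from compressed_tensors.transform import TransformScheme  # import path assumed

    # Legacy config using the deprecated head_dim argument: the "after"
    # model validator copies the value into block_size and clears head_dim.
    legacy = TransformScheme(type="hadamard", head_dim=128)
    assert legacy.block_size == 128
    assert legacy.head_dim is None

    # New-style config sets block_size directly.
    scheme = TransformScheme(type="hadamard", block_size=128)
    assert scheme.block_size == 128

Keeping `head_dim` as a deprecated alias rather than removing it outright lets existing serialized configs load unchanged while steering new code toward `block_size`.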