@@ -32,18 +32,18 @@ class QuIPModifier(Modifier):
 
     Lifecycle:
         - on_initialize
-            - infer SpinQuantMappings & NormMappings
-            - as needed, create transform schemes for R1, R2, R3, & R4
+            - as needed, create transform schemes for V (input) and U (output)
         - on_start
-            - normalize embeddings
-            - fuse norm layers into subsequent Linear layers
             - apply TransformConfig
             - fuse transforms into weights for mergeable transforms
             - add hooks for online transforms
         - on sequential epoch end
         - on_end
         - on_finalize
 
+    :param rotations: which rotation schemes to apply to the model. Including `"v"` will
+        rotate the input side of weights, and including `"u"` will rotate the output
+        side of weights (note that v does not require u and vice-versa)
     :param transform_type: The type of transform to apply to the model.
         `"hadamard"` has the least performance cost but only supports sizes which are
         powers of two.
@@ -58,6 +58,7 @@ class QuIPModifier(Modifier):
     :param transform_config: Optional transform config for overriding provided arguments
     """  # noqa: E501
 
+    rotations: List[Literal["v", "u"]] = Field(default_factory=lambda: ["v", "u"])
     transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
         default="random-hadamard"
     )
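A minimal usage sketch of the new `rotations` field (hypothetical values; assumes the modifier's other arguments such as `targets`, `ignore`, and `transform_type` keep their existing defaults):

```python
# Sketch: apply only the input-side (V) rotation; "u" can be omitted independently.
modifier = QuIPModifier(rotations=["v"], transform_type="hadamard")

# Default behavior keeps both rotations enabled.
both = QuIPModifier()
assert both.rotations == ["v", "u"]
```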
@@ -77,6 +78,12 @@ def validate_not_implemented(cls, value, info: ValidationInfo):
             raise NotImplementedError(f"{info.field_name} is not supported right now")
         return value
 
+    @field_validator("rotations", mode="before")
+    def validate_lowercase_list(cls, value):
+        if isinstance(value, list):
+            value = [v.lower() if isinstance(v, str) else v for v in value]
+        return value
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.transform_config is not None:
             return True
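Because the validator runs in `mode="before"`, string entries are lowercased before pydantic checks them against `Literal["v", "u"]`; a quick sketch of the expected behavior:

```python
# Sketch: mixed-case rotation names are normalized by validate_lowercase_list
# before the Literal["v", "u"] validation runs.
modifier = QuIPModifier(rotations=["V", "U"])
assert modifier.rotations == ["v", "u"]
```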
@@ -111,45 +118,52 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         return True
 
     def _create_config(self) -> TransformConfig:
-        return TransformConfig(
-            config_groups={
-                "v": TransformScheme(
-                    type=self.transform_type,
-                    apply=[
-                        TransformArgs(
-                            targets=self.targets,
-                            location="input",  # non-mergable
-                            ignore=self.ignore,
-                        ),
-                        TransformArgs(
-                            targets=self.targets,
-                            location="weight_input",
-                            inverse=True,
-                            ignore=self.ignore,
-                        ),
-                    ],
-                    randomize=self.randomize,
-                    requires_grad=self.learnable,
-                    precision=self.precision,
+        config_groups = dict()
+        if "v" in self.rotations:
+            config_groups["v"] = self._create_v_scheme()
+        if "u" in self.rotations:
+            config_groups["u"] = self._create_u_scheme()
+
+        return TransformConfig(config_groups=config_groups)
+
+    def _create_v_scheme(self) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            apply=[
+                TransformArgs(
+                    targets=self.targets,
+                    location="input",  # non-mergable
+                    ignore=self.ignore,
+                ),
+                TransformArgs(
+                    targets=self.targets,
+                    location="weight_input",
+                    inverse=True,
+                    ignore=self.ignore,
+                ),
+            ],
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+        )
+
+    def _create_u_scheme(self) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            apply=[
+                TransformArgs(
+                    targets=self.targets,
+                    location="weight_output",
+                    ignore=self.ignore,
                 ),
-                "u": TransformScheme(
-                    type=self.transform_type,
-                    apply=[
-                        TransformArgs(
-                            targets=self.targets,
-                            location="weight_output",
-                            ignore=self.ignore,
-                        ),
-                        TransformArgs(
-                            targets=self.targets,
-                            location="output",  # non-mergable
-                            inverse=True,
-                            ignore=self.ignore,
-                        ),
-                    ],
-                    randomize=self.randomize,
-                    requires_grad=self.learnable,
-                    precision=self.precision,
+                TransformArgs(
+                    targets=self.targets,
+                    location="output",  # non-mergable
+                    inverse=True,
+                    ignore=self.ignore,
                 ),
-            }
+            ],
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
         )
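To illustrate the refactor, a hedged sketch of what `_create_config()` now produces when only one rotation is requested (assumes the `TransformConfig`/`TransformScheme` objects shown in the diff, with `config_groups` behaving as a plain dict):

```python
# Sketch: with rotations=["u"], only the output-side scheme is built; its two
# TransformArgs rotate the weight output (mergeable) and the module output (online hook).
modifier = QuIPModifier(rotations=["u"])
config = modifier._create_config()

assert set(config.config_groups) == {"u"}
assert len(config.config_groups["u"].apply) == 2
```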