@@ -32,18 +32,18 @@ class QuIPModifier(Modifier):
32
32
33
33
Lifecycle:
34
34
- on_initialize
35
- - infer SpinQuantMappings & NormMappings
36
- - as needed, create transform schemes for R1, R2, R3, & R4
35
+ - as needed, create transform schemes for V (input) and U (output)
37
36
- on_start
38
- - normalize embeddings
39
- - fuse norm layers into subsequent Linear layers
40
37
- apply TransformConfig
41
38
- fuse transforms into weights for mergeable transforms
42
39
- add hooks for online transforms
43
40
- on sequential epoch end
44
41
- on_end
45
42
- on_finalize
46
43
44
+ :param rotations: which rotation schemes to apply to the model. Including `"v"` will
45
+ rotate the input side of weights, and including `"u"` will rotate the output
46
+ side of weights (note that v does not require u and vice-versa)
47
47
:param transform_type: The type of transform to apply to the model.
48
48
`"hadamard"` has the least performance cost but only supports sizes which are
49
49
powers of two.
@@ -58,6 +58,7 @@ class QuIPModifier(Modifier):
58
58
:param transform_config: Optional transform config for overriding provided arguments
59
59
""" # noqa: E501
60
60
61
+ rotations : List [Literal ["v" , "u" ]] = Field (default_factory = lambda : ["v" , "u" ])
61
62
transform_type : Literal ["hadamard" , "random-hadamard" , "random-matrix" ] = Field (
62
63
default = "random-hadamard"
63
64
)
@@ -77,6 +78,12 @@ def validate_not_implemented(cls, value, info: ValidationInfo):
77
78
raise NotImplementedError (f"{ info .field_name } is not supported right now" )
78
79
return value
79
80
81
+ @field_validator ("rotations" , mode = "before" )
82
+ def validate_lowercase_list (cls , value ):
83
+ if isinstance (value , list ):
84
+ value = [v .lower () if isinstance (v , str ) else v for v in value ]
85
+ return value
86
+
80
87
def on_initialize (self , state : State , ** kwargs ) -> bool :
81
88
if self .transform_config is not None :
82
89
return True
@@ -111,45 +118,52 @@ def on_finalize(self, state: State, **kwargs) -> bool:
111
118
return True
112
119
113
120
def _create_config (self ) -> TransformConfig :
114
- return TransformConfig (
115
- config_groups = {
116
- "v" : TransformScheme (
117
- type = self .transform_type ,
118
- apply = [
119
- TransformArgs (
120
- targets = self .targets ,
121
- location = "input" , # non-mergable
122
- ignore = self .ignore ,
123
- ),
124
- TransformArgs (
125
- targets = self .targets ,
126
- location = "weight_input" ,
127
- inverse = True ,
128
- ignore = self .ignore ,
129
- ),
130
- ],
131
- randomize = self .randomize ,
132
- requires_grad = self .learnable ,
133
- precision = self .precision ,
121
+ config_groups = dict ()
122
+ if "v" in self .rotations :
123
+ config_groups ["v" ] = self ._create_v_scheme ()
124
+ if "u" in self .rotations :
125
+ config_groups ["u" ] = self ._create_u_scheme ()
126
+
127
+ return TransformConfig (config_groups = config_groups )
128
+
129
+ def _create_v_scheme (self ) -> TransformScheme :
130
+ return TransformScheme (
131
+ type = self .transform_type ,
132
+ apply = [
133
+ TransformArgs (
134
+ targets = self .targets ,
135
+ location = "input" , # non-mergable
136
+ ignore = self .ignore ,
137
+ ),
138
+ TransformArgs (
139
+ targets = self .targets ,
140
+ location = "weight_input" ,
141
+ inverse = True ,
142
+ ignore = self .ignore ,
143
+ ),
144
+ ],
145
+ randomize = self .randomize ,
146
+ requires_grad = self .learnable ,
147
+ precision = self .precision ,
148
+ )
149
+
150
+ def _create_u_scheme (self ) -> TransformScheme :
151
+ return TransformScheme (
152
+ type = self .transform_type ,
153
+ apply = [
154
+ TransformArgs (
155
+ targets = self .targets ,
156
+ location = "weight_output" ,
157
+ ignore = self .ignore ,
134
158
),
135
- "u" : TransformScheme (
136
- type = self .transform_type ,
137
- apply = [
138
- TransformArgs (
139
- targets = self .targets ,
140
- location = "weight_output" ,
141
- ignore = self .ignore ,
142
- ),
143
- TransformArgs (
144
- targets = self .targets ,
145
- location = "output" , # non-mergable
146
- inverse = True ,
147
- ignore = self .ignore ,
148
- ),
149
- ],
150
- randomize = self .randomize ,
151
- requires_grad = self .learnable ,
152
- precision = self .precision ,
159
+ TransformArgs (
160
+ targets = self .targets ,
161
+ location = "output" , # non-mergable
162
+ inverse = True ,
163
+ ignore = self .ignore ,
153
164
),
154
- }
165
+ ],
166
+ randomize = self .randomize ,
167
+ requires_grad = self .learnable ,
168
+ precision = self .precision ,
155
169
)
0 commit comments