@@ -119,16 +119,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self.mappings = infer_mapping_from_model(state.model)
         self.norm_mappings = infer_norm_mapping_from_model(state.model)
+        head_dim = self._infer_head_dim(state.model)
 
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
 
         if SpinquantRotation.R2 in self.rotations:
-            config_groups["R2"] = self._create_r2_scheme(state.model)
+            config_groups["R2"] = self._create_r2_scheme(head_dim)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(head_dim)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -209,16 +210,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             ],
         )
 
-    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
-        config = model.config
-
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
-            head_dim = config.hidden_size // config.num_attention_heads
-        else:
-            raise NotImplementedError()
-
+    def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
@@ -235,9 +227,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 rotations will be added in a future release"
+    def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
@@ -258,3 +264,13 @@ def _create_r4_scheme(self) -> TransformScheme:
                 ),
             ],
         )
+
+    def _infer_head_dim(self, model: PreTrainedModel) -> int:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            return config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            return config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
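
For context, a minimal standalone sketch of the fallback that this diff factors out into `_infer_head_dim`: when the model config has no explicit `head_dim` attribute, the head dimension is derived as `hidden_size // num_attention_heads`. The `SimpleNamespace` config below is an illustrative stand-in, not a real `transformers` config object.

```python
from types import SimpleNamespace


def infer_head_dim(config) -> int:
    # Same fallback order as the _infer_head_dim helper in this diff:
    # prefer an explicit head_dim, otherwise derive it from the hidden size.
    if hasattr(config, "head_dim"):
        return config.head_dim
    elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    else:
        raise NotImplementedError()


config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)  # no head_dim set
print(infer_head_dim(config))  # -> 128
```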