284284 {"target_modules" : ["conv1d" ], "r" : 2 , "use_effective_conv2d" : False },
285285 ),
286286 ("Conv2d 1x1 LOHA" , "Conv2d1x1" , LoHaConfig , {"target_modules" : ["conv2d" ]}),
287+ ########
288+ # ABBA #
289+ ########
290+ # ABBA tests are included in main TEST_CASES for basic functionality
291+ # Note: ABBA uses SVD-based initialization, so parameters are non-zero from start
292+ ("Vanilla MLP 1 ABBA" , "MLP" , LoHaConfig , {"target_modules" : "lin0" , "init_weights" : "abba" }),
293+ ("Vanilla MLP 2 ABBA" , "MLP" , LoHaConfig , {"target_modules" : ["lin0" ], "init_weights" : "abba" }),
294+ (
295+ "Vanilla MLP 3 ABBA" ,
296+ "MLP" ,
297+ LoHaConfig ,
298+ {
299+ "target_modules" : ["lin0" ],
300+ "alpha" : 4 ,
301+ "module_dropout" : 0.1 ,
302+ "init_weights" : "abba" ,
303+ },
304+ ),
305+ ("Vanilla MLP 4 ABBA" , "MLP" , LoHaConfig , {"target_modules" : "lin0" , "rank_dropout" : 0.5 , "init_weights" : "abba" }),
306+ (
307+ "Vanilla MLP 5 ABBA with Khatri-Rao" ,
308+ "MLP" ,
309+ LoHaConfig ,
310+ {"target_modules" : ["lin0" ], "init_weights" : "abba" , "use_khatri_rao" : True },
311+ ),
312+ ("Conv2d 1 ABBA" , "Conv2d" , LoHaConfig , {"target_modules" : ["conv2d" ], "init_weights" : "abba" }),
313+ ("Conv1d ABBA 1" , "Conv1d" , LoHaConfig , {"target_modules" : ["conv1d" ], "init_weights" : "abba" }),
314+ ("Conv1d ABBA 2" , "Conv1d" , LoHaConfig , {"target_modules" : ["conv1d" ], "r" : 2 , "init_weights" : "abba" }),
315+ (
316+ "Conv1d ABBA 3" ,
317+ "Conv1dBigger" ,
318+ LoHaConfig ,
319+ {"target_modules" : ["conv1d" ], "r" : 2 , "init_weights" : "abba" },
320+ ),
321+ ("Conv2d 1x1 ABBA" , "Conv2d1x1" , LoHaConfig , {"target_modules" : ["conv2d" ], "init_weights" : "abba" }),
287322 # LoKr
288323 ("Vanilla MLP 1 LOKR" , "MLP" , LoKrConfig , {"target_modules" : "lin0" }),
289324 ("Vanilla MLP 2 LOKR" , "MLP" , LoKrConfig , {"target_modules" : ["lin0" ]}),
@@ -2044,6 +2079,13 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
20442079 lr = 1e-3 # we get exploding gradients with MHA when learning rate is too high
20452080 elif issubclass (config_cls , VBLoRAConfig ) or issubclass (config_cls , RandLoraConfig ):
20462081 lr = 0.01 # otherwise we get nan
2082+ elif config_kwargs .get ("init_weights" ) == "abba" :
2083+ # ABBA starts closer to pretrained, use gentler updates than standard (0.5)
2084+ # Conv layers with ABBA need much lower LR due to Hadamard product amplification
2085+ if model_id in ["Conv1d" , "Conv1dBigger" , "Conv2d" , "Conv2d1x1" ]:
2086+ lr = 0.01 # Very low LR to prevent exploding gradients with Hadamard products
2087+ else :
2088+ lr = 0.1
20472089 optimizer = torch .optim .SGD (model .parameters (), lr = lr )
20482090
20492091 # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@@ -2093,7 +2135,9 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
20932135 torch .nn .init .zeros_ (model .vblora_vector_bank ["default" ])
20942136 model .eval ()
20952137 outputs_before = model (** X )
2096- assert torch .allclose (outputs_base , outputs_before )
2138+ # ABBA uses SVD initialization, so outputs won't match base model initially - skip that assertion
2139+ if config_kwargs .get ("init_weights" ) != "abba" :
2140+ assert torch .allclose (outputs_base , outputs_before )
20972141
20982142 if issubclass (config_cls , VBLoRAConfig ):
20992143 # initialize `vblora_vector_bank` so it can be trained
@@ -2131,7 +2175,12 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
21312175 else :
21322176 rtol , atol = 1e-5 , 1e-8
21332177 assert not torch .allclose (outputs_before , outputs_after , rtol = rtol , atol = atol )
2134- assert torch .allclose (outputs_before , outputs_disabled )
2178+ # For ABBA: outputs_before != outputs_disabled because ABBA uses non-zero init
2179+ # But outputs_disabled should equal base model for both ABBA and others
2180+ if config_kwargs .get ("init_weights" ) == "abba" :
2181+ assert torch .allclose (outputs_base , outputs_disabled )
2182+ else :
2183+ assert torch .allclose (outputs_before , outputs_disabled )
21352184 assert torch .allclose (outputs_after , outputs_enabled_after_disable )
21362185
21372186 @pytest .mark .parametrize ("test_name, model_id, config_cls, config_kwargs" , TEST_CASES )
@@ -2147,6 +2196,8 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
21472196 # same as test_disable_adapters, but with merging
21482197 X = self .prepare_inputs_for_testing ()
21492198 model = self .transformers_class .from_pretrained (model_id ).to (self .torch_device )
2199+ model .eval ()
2200+ outputs_base = model (** X ) # Save base model outputs for ABBA comparison
21502201 config = config_cls (
21512202 base_model_name_or_path = model_id ,
21522203 ** config_kwargs ,
@@ -2169,6 +2220,9 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
21692220 else :
21702221 # Adam optimizer since SGD isn't great for small models with IA3 + Conv1D
21712222 lr = 0.01
2223+ # ABBA Conv layers need lower learning rate to prevent gradient explosion
2224+ if config_kwargs .get ("init_weights" ) == "abba" and model_id in ["Conv1d" , "Conv1dBigger" , "Conv2d" , "Conv2d1x1" ]:
2225+ lr = 0.001 # Very low LR for ABBA Conv with Adam
21722226 optimizer = torch .optim .Adam (model .parameters (), lr = lr )
21732227
21742228 # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@@ -2214,7 +2268,12 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
22142268 assert torch .allclose (outputs_after , outputs_unmerged , atol = atol , rtol = rtol )
22152269
22162270 # check that disabling adapters gives the same results as before training
2217- assert torch .allclose (outputs_before , outputs_disabled , atol = atol , rtol = rtol )
2271+ # For ABBA: outputs_before != outputs_disabled because ABBA uses non-zero init
2272+ # But outputs_disabled should equal base model for both ABBA and others
2273+ if config_kwargs .get ("init_weights" ) == "abba" :
2274+ assert torch .allclose (outputs_base , outputs_disabled , atol = atol , rtol = rtol )
2275+ else :
2276+ assert torch .allclose (outputs_before , outputs_disabled , atol = atol , rtol = rtol )
22182277
22192278 # check that enabling + disabling adapters does not change the results
22202279 assert torch .allclose (outputs_after , outputs_enabled_after_disable , atol = atol , rtol = rtol )