Skip to content

Commit 5eb5865

Browse files
committed
Added ABBA tests to test_custom_models.py
1 parent e02cc57 commit 5eb5865

File tree

1 file changed

+62
-3
lines changed

1 file changed

+62
-3
lines changed

tests/test_custom_models.py

Lines changed: 62 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -284,6 +284,41 @@
284284
{"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": False},
285285
),
286286
("Conv2d 1x1 LOHA", "Conv2d1x1", LoHaConfig, {"target_modules": ["conv2d"]}),
287+
########
288+
# ABBA #
289+
########
290+
# ABBA tests are included in main TEST_CASES for basic functionality
291+
# Note: ABBA uses SVD-based initialization, so parameters are non-zero from start
292+
("Vanilla MLP 1 ABBA", "MLP", LoHaConfig, {"target_modules": "lin0", "init_weights": "abba"}),
293+
("Vanilla MLP 2 ABBA", "MLP", LoHaConfig, {"target_modules": ["lin0"], "init_weights": "abba"}),
294+
(
295+
"Vanilla MLP 3 ABBA",
296+
"MLP",
297+
LoHaConfig,
298+
{
299+
"target_modules": ["lin0"],
300+
"alpha": 4,
301+
"module_dropout": 0.1,
302+
"init_weights": "abba",
303+
},
304+
),
305+
("Vanilla MLP 4 ABBA", "MLP", LoHaConfig, {"target_modules": "lin0", "rank_dropout": 0.5, "init_weights": "abba"}),
306+
(
307+
"Vanilla MLP 5 ABBA with Khatri-Rao",
308+
"MLP",
309+
LoHaConfig,
310+
{"target_modules": ["lin0"], "init_weights": "abba", "use_khatri_rao": True},
311+
),
312+
("Conv2d 1 ABBA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d"], "init_weights": "abba"}),
313+
("Conv1d ABBA 1", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"], "init_weights": "abba"}),
314+
("Conv1d ABBA 2", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"], "r": 2, "init_weights": "abba"}),
315+
(
316+
"Conv1d ABBA 3",
317+
"Conv1dBigger",
318+
LoHaConfig,
319+
{"target_modules": ["conv1d"], "r": 2, "init_weights": "abba"},
320+
),
321+
("Conv2d 1x1 ABBA", "Conv2d1x1", LoHaConfig, {"target_modules": ["conv2d"], "init_weights": "abba"}),
287322
# LoKr
288323
("Vanilla MLP 1 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0"}),
289324
("Vanilla MLP 2 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin0"]}),
@@ -2044,6 +2079,13 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
20442079
lr = 1e-3 # we get exploding gradients with MHA when learning rate is too high
20452080
elif issubclass(config_cls, VBLoRAConfig) or issubclass(config_cls, RandLoraConfig):
20462081
lr = 0.01 # otherwise we get nan
2082+
elif config_kwargs.get("init_weights") == "abba":
2083+
# ABBA starts closer to pretrained, use gentler updates than standard (0.5)
2084+
# Conv layers with ABBA need much lower LR due to Hadamard product amplification
2085+
if model_id in ["Conv1d", "Conv1dBigger", "Conv2d", "Conv2d1x1"]:
2086+
lr = 0.01 # Very low LR to prevent exploding gradients with Hadamard products
2087+
else:
2088+
lr = 0.1
20472089
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
20482090

20492091
# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@@ -2093,7 +2135,9 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
20932135
torch.nn.init.zeros_(model.vblora_vector_bank["default"])
20942136
model.eval()
20952137
outputs_before = model(**X)
2096-
assert torch.allclose(outputs_base, outputs_before)
2138+
# ABBA uses SVD initialization, so outputs won't match base model initially - skip that assertion
2139+
if config_kwargs.get("init_weights") != "abba":
2140+
assert torch.allclose(outputs_base, outputs_before)
20972141

20982142
if issubclass(config_cls, VBLoRAConfig):
20992143
# initialize `vblora_vector_bank` so it can be trained
@@ -2131,7 +2175,12 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
21312175
else:
21322176
rtol, atol = 1e-5, 1e-8
21332177
assert not torch.allclose(outputs_before, outputs_after, rtol=rtol, atol=atol)
2134-
assert torch.allclose(outputs_before, outputs_disabled)
2178+
# For ABBA: outputs_before != outputs_disabled because ABBA uses non-zero init
2179+
# But outputs_disabled should equal base model for both ABBA and others
2180+
if config_kwargs.get("init_weights") == "abba":
2181+
assert torch.allclose(outputs_base, outputs_disabled)
2182+
else:
2183+
assert torch.allclose(outputs_before, outputs_disabled)
21352184
assert torch.allclose(outputs_after, outputs_enabled_after_disable)
21362185

21372186
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
@@ -2147,6 +2196,8 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
21472196
# same as test_disable_adapters, but with merging
21482197
X = self.prepare_inputs_for_testing()
21492198
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
2199+
model.eval()
2200+
outputs_base = model(**X) # Save base model outputs for ABBA comparison
21502201
config = config_cls(
21512202
base_model_name_or_path=model_id,
21522203
**config_kwargs,
@@ -2169,6 +2220,9 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
21692220
else:
21702221
# Adam optimizer since SGD isn't great for small models with IA3 + Conv1D
21712222
lr = 0.01
2223+
# ABBA Conv layers need lower learning rate to prevent gradient explosion
2224+
if config_kwargs.get("init_weights") == "abba" and model_id in ["Conv1d", "Conv1dBigger", "Conv2d", "Conv2d1x1"]:
2225+
lr = 0.001 # Very low LR for ABBA Conv with Adam
21722226
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
21732227

21742228
# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@@ -2214,7 +2268,12 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
22142268
assert torch.allclose(outputs_after, outputs_unmerged, atol=atol, rtol=rtol)
22152269

22162270
# check that disabling adapters gives the same results as before training
2217-
assert torch.allclose(outputs_before, outputs_disabled, atol=atol, rtol=rtol)
2271+
# For ABBA: outputs_before != outputs_disabled because ABBA uses non-zero init
2272+
# But outputs_disabled should equal base model for both ABBA and others
2273+
if config_kwargs.get("init_weights") == "abba":
2274+
assert torch.allclose(outputs_base, outputs_disabled, atol=atol, rtol=rtol)
2275+
else:
2276+
assert torch.allclose(outputs_before, outputs_disabled, atol=atol, rtol=rtol)
22182277

22192278
# check that enabling + disabling adapters does not change the results
22202279
assert torch.allclose(outputs_after, outputs_enabled_after_disable, atol=atol, rtol=rtol)

0 commit comments

Comments (0)