From f587a1b2397824dc68c69c16ac9c7e49cbafe2d3 Mon Sep 17 00:00:00 2001
From: julian fong
Date: Mon, 3 Feb 2025 18:59:02 -0500
Subject: [PATCH] removed rebase

---
 src/adapters/model_mixin.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py
index 5e142ad4c..7483330f7 100644
--- a/src/adapters/model_mixin.py
+++ b/src/adapters/model_mixin.py
@@ -1351,18 +1351,27 @@ def merge_adapter(self, name: str):
         Args:
             name (str): LoRA module to merge.
         """
-        for module in self.modules():
-            if isinstance(module, LoRALayer):
-                if name in module.loras:
-                    module.merge_adapter(name)
+        with ForwardContext(self, torch.empty(0, 1)):
+            # check if there are shared parameters between adapter weights
+            if self.base_model.shared_parameters:
+                ForwardContext.get_context().shared_parameters = self.base_model.shared_parameters
+
+            for module in self.modules():
+                if isinstance(module, LoRALayer):
+                    if name in module.loras:
+                        module.merge_adapter(name)
 
     def reset_adapter(self):
         """
         Resets weights of a LoRA module merged using `model.merge_adapter(name)`.
         """
-        for module in self.modules():
-            if isinstance(module, LoRALayer):
-                module.reset_adapter()
+        with ForwardContext(self, torch.empty(0, 1)):
+            if self.base_model.shared_parameters:
+                ForwardContext.get_context().shared_parameters = self.base_model.shared_parameters
+
+            for module in self.modules():
+                if isinstance(module, LoRALayer):
+                    module.reset_adapter()
 
     # HACK Copied from transformers/generation/utils.py
     def _prepare_encoder_decoder_kwargs_for_generation(
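
For context, a minimal usage sketch of the two patched methods. The model checkpoint, adapter name, and LoRA config values below are hypothetical illustrations, not part of the patch; the sketch assumes an `adapters`-style workflow where a LoRA adapter has already been added and activated.

```python
import adapters
from adapters import LoRAConfig
from transformers import AutoModelForSequenceClassification

# Hypothetical checkpoint and adapter name, used only for illustration.
model = AutoModelForSequenceClassification.from_pretrained("roberta-base")
adapters.init(model)
model.add_adapter("my_lora", config=LoRAConfig(r=8, alpha=16))
model.set_active_adapters("my_lora")

# merge_adapter() folds the LoRA delta into the base weights in place.
# With this patch, the call runs inside a ForwardContext, so any shared
# LoRA parameters on the base model are exposed to the merge as well.
model.merge_adapter("my_lora")

# reset_adapter() undoes the merge and restores the original base weights,
# using the same ForwardContext / shared-parameter handling.
model.reset_adapter()
```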