removed rebase
julian-fong committed Feb 3, 2025
1 parent bb019b0 commit f587a1b
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/adapters/model_mixin.py
@@ -1351,18 +1351,27 @@ def merge_adapter(self, name: str):
         Args:
             name (str): LoRA module to merge.
         """
-        for module in self.modules():
-            if isinstance(module, LoRALayer):
-                if name in module.loras:
-                    module.merge_adapter(name)
+        with ForwardContext(self, torch.empty(0, 1)):
+            #check if there are shared parameters between adapter weights
+            if self.base_model.shared_parameters:
+                ForwardContext.get_context().shared_parameters = self.base_model.shared_parameters
+
+            for module in self.modules():
+                if isinstance(module, LoRALayer):
+                    if name in module.loras:
+                        module.merge_adapter(name)
 
     def reset_adapter(self):
         """
         Resets weights of a LoRA module merged using `model.merge_adapter(name)`.
         """
-        for module in self.modules():
-            if isinstance(module, LoRALayer):
-                module.reset_adapter()
+        with ForwardContext(self, torch.empty(0, 1)):
+            if self.base_model.shared_parameters:
+                ForwardContext.get_context().shared_parameters = self.base_model.shared_parameters
+
+            for module in self.modules():
+                if isinstance(module, LoRALayer):
+                    module.reset_adapter()
 
     # HACK Copied from transformers/generation/utils.py
     def _prepare_encoder_decoder_kwargs_for_generation(
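For context, a minimal usage sketch of the two changed methods, assuming the Hugging Face adapters library with a LoRA adapter already attached; the checkpoint name and adapter name below are placeholders, not taken from this commit:

from adapters import AutoAdapterModel

# Load a model and attach a LoRA adapter (hypothetical names).
model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("my_lora", config="lora")
model.set_active_adapters("my_lora")

# Merge the LoRA weights into the base weights. With this commit, the merge runs
# inside a ForwardContext so that shared adapter parameters (if any) are visible
# to each LoRALayer during merging.
model.merge_adapter("my_lora")

# ... run inference on the merged model ...

# Undo the merge and restore the original base weights.
model.reset_adapter()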
