119 changes: 104 additions & 15 deletions src/adapters/methods/bottleneck.py
@@ -395,46 +395,135 @@ def forward(self, hidden_states, residual_input, layer_norm):
         return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm)


-def hook_fn(adapter_layer, ln_get_fn, module, args, output):
-    # Retrieve residual from previous hook, if existing
+def hook_fn(i, adapter_name, ln_name, module, args, output):
+    """
+    Hook function to process the output of a module through an adapter layer.
+
+    Args:
+        i (int): Layer index.
+        adapter_name (str): Name of the adapter.
+        ln_name (str): Name of the layer normalization module.
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+        output (torch.Tensor or tuple): Output of the module.
+
+    Returns:
+        torch.Tensor or tuple: Processed output through the adapter layer.
+    """
     context = ForwardContext.get_context()
+    adapter_layer = getattr(context, f"{adapter_name}", None)
+    layer = getattr(context, "layer", None)
+    layer_norm = multigetattr(layer, ln_name, None)
     residual_input = getattr(context, f"{adapter_layer.location_key}_residual_input", None)
-    # Retrieve layer norm from getter fn
-    if ln_get_fn is not None:
-        layer_norm = ln_get_fn()
-    else:
-        layer_norm = None
     # Call adapter layer
     if isinstance(output, torch.Tensor):
         return adapter_layer(output, residual_input, layer_norm)
     else:
         return (adapter_layer(output[0], residual_input, layer_norm),) + output[1:]
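For readers less familiar with PyTorch hooks, the mechanics that hook_fn relies on can be sketched in isolation. This is a toy example with a plain nn.Linear and a made-up scale_output_hook, not the adapters classes: a forward hook bound via functools.partial receives its bound arguments before the standard (module, args, output) triple, and a non-None return value replaces the module's output. This is the same way hook_fn is bound to the layer index, adapter name, and layer-norm name further down.

from functools import partial

import torch
from torch import nn


def scale_output_hook(scale, module, args, output):
    # Like hook_fn above: post-process the module's output and return the replacement.
    return output * scale


proj = nn.Linear(4, 4)
handle = proj.register_forward_hook(partial(scale_output_hook, 2.0))
out = proj(torch.randn(1, 4))  # the hook doubles the projection output
handle.remove()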


+def _attention_adapters_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the multi-head attention adapters in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "mh_adapter", module.attention_adapters)
+
+
+def _output_adapter_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the output adapters in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "output_adapter", module.output_adapters)
+
+
+def _cross_attention_adapters_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the cross-attention adapters in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "crossattn_adapter", module.cross_attention_adapters)
+
+
+def _layer_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the current layer in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "layer", module)
+
+
+def _residual_hook_fn(location_key, module, args):
+    """
+    Hook function to set the residual input in the context.
+
+    Args:
+        location_key (str): Location key of the adapter.
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, f"{location_key}_residual_input", args[0])


 def init_bottleneck(model):
+    """
+    Initialize bottleneck adapters for the given model.
+
+    Args:
+        model (torch.nn.Module): The model to initialize bottleneck adapters for.
+    """
     model = model.base_model
-    for _, layer in model.iter_layers():
+
+    for i, layer in model.iter_layers():
+        if not hasattr(layer, "has_layer_hook_forward_pre_hook"):
+            layer.register_forward_pre_hook(_layer_hook_forward_pre_hook)
+            layer.has_layer_hook_forward_pre_hook = True
         if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None):
             if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None):
                 if not hasattr(layer, "attention_adapters"):
                     layer.attention_adapters = BottleneckLayer("mh_adapter", is_layer_hooked=True)
-                ln_1_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_1, None)
-                o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters, ln_1_get_fn))
+                if not hasattr(layer, "has_attention_adapters_hook_forward_pre_hook"):
+                    layer.register_forward_pre_hook(_attention_adapters_hook_forward_pre_hook)
+                    layer.has_attention_adapters_hook_forward_pre_hook = True
+                o_proj.register_forward_hook(partial(hook_fn, i, "mh_adapter", model.adapter_interface.layer_ln_1))
         if layer_output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj, None):
             if not hasattr(layer, "output_adapters"):
                 layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True)
-            ln_2_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_2, None)
-            layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn))
+            if not hasattr(layer, "has_output_adapters_hook_forward_pre_hook"):
+                layer.register_forward_pre_hook(_output_adapter_hook_forward_pre_hook)
+                layer.has_output_adapters_hook_forward_pre_hook = True
+            layer_output_proj.register_forward_hook(
+                partial(hook_fn, i, "output_adapter", model.adapter_interface.layer_ln_2)
+            )
         if cross_attn := multigetattr(layer, model.adapter_interface.layer_cross_attn, None):
             if not hasattr(cross_attn, "cross_attention_adapters"):
-                layer.attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
-            cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters, None))
+                layer.cross_attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
+            if not hasattr(layer, "has_cross_attention_adapters_hook_forward_pre_hook"):
+                layer.register_forward_pre_hook(_cross_attention_adapters_hook_forward_pre_hook)
+                layer.has_cross_attention_adapters_hook_forward_pre_hook = True
+            cross_attn.register_forward_hook(partial(hook_fn, i, "crossattn_adapter", None))

         if model.adapter_interface.layer_pre_self_attn is not None:
             if pre_self_attn := multigetattr(layer, model.adapter_interface.layer_pre_self_attn, None):
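One detail worth noting in init_bottleneck: the has_*_hook_forward_pre_hook flags make hook registration idempotent, so running the initialization again on the same model does not stack duplicate hooks. A minimal sketch of that guard pattern on a toy module, with a hypothetical install_probe helper and flag name that are not part of the adapters API:

from torch import nn


def install_probe(module):
    # Guarded registration, mirroring the has_*_hook_forward_pre_hook flags above:
    # a repeated call is a no-op instead of attaching a second copy of the hook.
    if not hasattr(module, "_has_probe_pre_hook"):
        module.register_forward_pre_hook(lambda mod, args: None)
        module._has_probe_pre_hook = True


layer = nn.Linear(4, 4)
install_probe(layer)
install_probe(layer)  # second call does not add a duplicate hook
assert len(layer._forward_pre_hooks) == 1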